diff --git a/polyfactory/value_generators/constrained_numbers.py b/polyfactory/value_generators/constrained_numbers.py index 23516ceb..88b22049 100644 --- a/polyfactory/value_generators/constrained_numbers.py +++ b/polyfactory/value_generators/constrained_numbers.py @@ -1,6 +1,8 @@ from __future__ import annotations +import decimal from decimal import Decimal +from math import ceil, floor, isinf from sys import float_info from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast @@ -227,16 +229,62 @@ def generate_constrained_number( :returns: A value of type T. """ - if minimum is None or maximum is None: - return multiple_of if multiple_of is not None else method(random=random) if multiple_of is None: return method(random=random, minimum=minimum, maximum=maximum) - if multiple_of >= minimum: + + def passes_all_constraints(value: T) -> bool: + return ( + (minimum is None or value >= minimum) + and (maximum is None or value <= maximum) + and (multiple_of is None or passes_pydantic_multiple_validator(value, multiple_of)) + ) + + # If the arguments are Decimals, they might have precision that is greater than the current decimal context. If + # so, recreate them under the current context to ensure they have the appropriate precision. This is important + # because otherwise, x * 1 == x may not always hold, which can cause the algorithm below to fail in unintuitive + # ways. + if isinstance(minimum, Decimal): + minimum = decimal.getcontext().create_decimal(minimum) + if isinstance(maximum, Decimal): + maximum = decimal.getcontext().create_decimal(maximum) + if isinstance(multiple_of, Decimal): + multiple_of = decimal.getcontext().create_decimal(multiple_of) + + max_attempts = 10 + for _ in range(max_attempts): + # We attempt to generate a random number and find the nearest valid multiple, but a naive approach of rounding + # to the nearest multiple may push the number out of range. To handle edge cases, we find both the nearest + # multiple in both the negative and positive directions (floor and ceil), and we pick one that fits within + # range. We should be guaranteed to find a number other than in the case where the range (minimum, maximum) is + # narrow and does not contain any multiple of multiple_of. + random_value = method(random=random, minimum=minimum, maximum=maximum) + quotient = random_value / multiple_of + if isinf(quotient): + continue + lower = floor(quotient) * multiple_of + upper = ceil(quotient) * multiple_of + + # If both the lower and upper candidates are out of bounds, then there are no valid multiples that fit within + # the specified range. + if minimum is not None and maximum is not None and lower < minimum and upper > maximum: + msg = f"no multiple of {multiple_of} exists between {minimum} and {maximum}" + raise ParameterException(msg) + + for candidate in [lower, upper]: + if not passes_all_constraints(candidate): + continue + return candidate + + # Try last-ditch attempt at using the multiple_of, 0, or -multiple_of as the value + if passes_all_constraints(multiple_of): return multiple_of - result = minimum - while not passes_pydantic_multiple_validator(result, multiple_of): - result = round(method(random=random, minimum=minimum, maximum=maximum) / multiple_of) * multiple_of - return result + if passes_all_constraints(-multiple_of): + return -multiple_of + if passes_all_constraints(multiple_of * 0): + return multiple_of * 0 + + msg = f"could not find solution in {max_attempts} attempts" + raise ValueError(msg) def handle_constrained_int( diff --git a/tests/constraints/test_decimal_constraints.py b/tests/constraints/test_decimal_constraints.py index b6a76a3c..0acb5566 100644 --- a/tests/constraints/test_decimal_constraints.py +++ b/tests/constraints/test_decimal_constraints.py @@ -3,7 +3,7 @@ from typing import Optional, cast import pytest -from hypothesis import given +from hypothesis import assume, given from hypothesis.strategies import decimals, integers from pydantic import BaseModel, condecimal @@ -239,19 +239,23 @@ def test_handle_constrained_decimal_handles_multiple_of_with_le(val1: Decimal, v decimals( allow_nan=False, allow_infinity=False, - min_value=-1000000000, - max_value=1000000000, + min_value=-100000000, + max_value=100000000, ), decimals( allow_nan=False, allow_infinity=False, - min_value=-1000000000, - max_value=1000000000, + min_value=-100000000, + max_value=100000000, ), ) def test_handle_constrained_decimal_handles_multiple_of_with_ge(val1: Decimal, val2: Decimal) -> None: min_value, multiple_of = sorted([val1, val2]) if multiple_of != Decimal("0"): + # When multiple_of is too many orders of magnitude smaller than min_value, then floating-point precision issues + # prevent us from constructing a number that can pass passes_pydantic_multiple_validator(). This scenario is + # very unlikely to occur in practice, so we tell Hypothesis to not generate these cases. + assume(abs(min_value / multiple_of) < Decimal("1e8")) result = handle_constrained_decimal( random=Random(), multiple_of=multiple_of, @@ -267,23 +271,37 @@ def test_handle_constrained_decimal_handles_multiple_of_with_ge(val1: Decimal, v ) +# Note: The magnitudes of the min and max values have been specifically chosen to avoid issues with floating-point +# rounding errors. Despite these tests using Decimal numbers, the function under test will convert them to floats when +# calling `passes_pydantic_multiple_validator()`. Because `passes_pydantic_multiple_validator()` uses the modulus +# operator (%) with a fixed modulo of 1.0, we actually have to care about the absolute rounding error, not the relative +# error. IEEE 754 double-precision floating-point numbers are guaranteed to have at least 15 decimal digits of +# significand and up to 17 decimal digits of significant. `passes_pydantic_multiple_validator()` requires that the +# remainder modulo 1.0 be within 1e-8 of 0.0 or 1.0. Therefore, we can support a maximum value of approximately 10**(15 +# - 8) = 10**7. We have some probabilistic buffer, so can set a maximum value of 10**8 and expect the tests to pass with +# reasonable confidence. @given( decimals( allow_nan=False, allow_infinity=False, - min_value=-1000000000, - max_value=1000000000, + min_value=-100000000, + max_value=100000000, ), decimals( allow_nan=False, allow_infinity=False, - min_value=-1000000000, - max_value=1000000000, + min_value=-100000000, + max_value=100000000, ), ) def test_handle_constrained_decimal_handles_multiple_of_with_gt(val1: Decimal, val2: Decimal) -> None: min_value, multiple_of = sorted([val1, val2]) if multiple_of != Decimal("0"): + # Despite the note above about choosing a max_value to avoid _absolute_ rounding errors, we also have to worry + # about _relative_ rounding errors between min_value and multiple_of. Once again, + # `passes_pydantic_multiple_validator()` requires that the remainder be no greater than 1e-8, so we tell + # Hypothesis not to generate cases where the min_value and multiple_of have a ratio greater than that. + assume(abs(min_value / multiple_of) < Decimal("1e8")) result = handle_constrained_decimal( random=Random(), multiple_of=multiple_of, diff --git a/tests/test_number_generation.py b/tests/test_number_generation.py index 4800ecc7..4296617a 100644 --- a/tests/test_number_generation.py +++ b/tests/test_number_generation.py @@ -1,3 +1,4 @@ +from decimal import Decimal, localcontext from random import Random import pytest @@ -6,24 +7,61 @@ generate_constrained_number, passes_pydantic_multiple_validator, ) -from polyfactory.value_generators.primitives import create_random_float +from polyfactory.value_generators.primitives import create_random_decimal, create_random_float @pytest.mark.parametrize( ("maximum", "minimum", "multiple_of"), - ((100, 2, 8), (-100, -187, -10), (7.55, 0.13, 0.0123)), + ( + (100, 2, 8), + (-100, -187, -10), + (7.55, 0.13, 0.0123), + (None, 10, 3), + (None, -10, 3), + (13, 2, None), + (50, None, 7), + (-50, None, 7), + (None, None, 4), + (900, None, 1000), + ), ) -def test_generate_constrained_number(maximum: float, minimum: float, multiple_of: float) -> None: - assert passes_pydantic_multiple_validator( +def test_generate_constrained_number(maximum: float | None, minimum: float | None, multiple_of: float | None) -> None: + value = generate_constrained_number( + random=Random(), + minimum=minimum, + maximum=maximum, multiple_of=multiple_of, - value=generate_constrained_number( + method=create_random_float, + ) + if maximum is not None: + assert value <= maximum + if minimum is not None: + assert value >= minimum + if multiple_of is not None: + assert passes_pydantic_multiple_validator(multiple_of=multiple_of, value=value) + + +def test_generate_constrained_number_with_overprecise_decimals() -> None: + minimum = Decimal("1.0005") + maximum = Decimal("2") + multiple_of = Decimal("1.0005") + + with localcontext() as ctx: + ctx.prec = 3 + + value = generate_constrained_number( random=Random(), minimum=minimum, maximum=maximum, multiple_of=multiple_of, - method=create_random_float, - ), - ) + method=create_random_decimal, + ) + if maximum is not None: + assert value <= ctx.create_decimal(maximum) + if minimum is not None: + assert value >= ctx.create_decimal(minimum) + if multiple_of is not None: + assert passes_pydantic_multiple_validator(multiple_of=ctx.create_decimal(multiple_of), value=value) def test_passes_pydantic_multiple_validator_handles_zero_multiplier() -> None: