fix: respect both multiple_of and minimum/maximum constraints
Previously, `generate_constrained_number()` would potentially generate
invalid numbers when `multiple_of` is not None and exactly one of either
`minimum` or `maximum` is not None, since it would just return
`multiple_of` without respecting the upper or lower bound.
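
To make the failure concrete, here is a minimal sketch of the old early-return path, reduced to the two branches visible in the removed code. The name `old_generate` is hypothetical, the sketch assumes `multiple_of` is not None, and the random search loop is left out.

```python
def old_generate(minimum, maximum, multiple_of):
    """Reduced sketch of the removed early-return logic (hypothetical name)."""
    if minimum is None or maximum is None:
        # At most one bound is set, and the old code ignored it entirely.
        return multiple_of
    if multiple_of >= minimum:
        # The upper bound is never consulted on this path either.
        return multiple_of
    raise NotImplementedError("random search loop omitted from this sketch")


# ``ge=10`` with ``multiple_of=5``: a multiple of 5, but below the lower bound.
result = old_generate(minimum=10, maximum=None, multiple_of=5)
print(result, result >= 10)
```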

This significantly changes the implementation of the code to correctly
handle this case. The `generate_constrained_number()` function has been
completely removed and replaced with a
`generate_constrained_multiple_of()` function. A major difference
between the old function and the new function is that the new one does
not accept a `method` argument for generating random numbers. This is
because in the new function, we always use `create_random_integer()`,
since the problem reduces to generating a random integer multiplier.

The high-level algorithm behind `generate_constrained_multiple_of()` is
that we need to constrain the random integer generator to generate
numbers such that when they are multiplied with `multiple_of`, they
still fit within the original bounds constraints. This simply involves
dividing the original bounds by `multiple_of`, with some special
handling for negative `multiple_of` numbers as well as carefully chosen
rounding behavior.
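
The bound translation described above can be sketched roughly as follows. This is an illustration rather than the committed function: `sketch_constrained_multiple_of` and the `window` fallback for open-ended ranges are hypothetical, and `Random.randint` stands in for `create_random_integer()`.

```python
from math import ceil, floor
from random import Random


def sketch_constrained_multiple_of(rng, minimum, maximum, multiple_of, window=100):
    """Return a random multiple of ``multiple_of`` within the given bounds."""
    if multiple_of < 0:
        # Dividing by a negative number flips the inequalities, so swap the bounds.
        minimum, maximum = maximum, minimum

    # Round inward so that multiplier * multiple_of cannot leave the original range.
    lo = ceil(minimum / multiple_of) if minimum is not None else None
    hi = floor(maximum / multiple_of) if maximum is not None else None

    # Hypothetical fallback: pick an arbitrary finite window for open-ended ranges.
    if lo is None and hi is None:
        lo, hi = -window, window
    elif lo is None:
        lo = hi - window
    elif hi is None:
        hi = lo + window

    return rng.randint(lo, hi) * multiple_of


rng = Random(0)
print(sketch_constrained_multiple_of(rng, 10, 20, 3))     # one of 12, 15, 18
print(sketch_constrained_multiple_of(rng, 10, 20, -3))    # also one of 12, 15, 18
print(sketch_constrained_multiple_of(rng, 10, None, 5))   # a multiple of 5 that is >= 10
```

Rounding the lower multiplier bound up and the upper one down is what keeps the product inside the range. If no multiple fits, `lo` exceeds `hi` and `randint` raises `ValueError`, so callers are assumed to have checked feasibility first.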

We also need to make some changes to other functions.

`get_increment()` needs to take an additional argument for the actual
value that the increment is for. This is because floating-point numbers
can't use a static increment or else it might get rounded away if the
numbers are too large. Python fortunately provides a `math.ulp()`
function for computing this for a given float value, so we make use of
that function. We still use the original `float_info.epsilon` constant
as a lower bound on the increment, though, since in the case that the
value is too close to zero, we still need to make sure that the
increment doesn't disappear when used against other numbers.
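
A short illustration of why a static epsilon is not enough for large floats. The helper name `sketch_increment` is hypothetical; it only mirrors the `max(ulp(value), float_info.epsilon)` expression from the diff. `math.ulp()` requires Python 3.9 or later.

```python
from math import ulp
from sys import float_info


def sketch_increment(value: float) -> float:
    """Smallest step that survives being added to ``value``, floored at epsilon."""
    return max(ulp(value), float_info.epsilon)


big = 1e17
# The static epsilon is rounded away when added to a large value...
print(big + float_info.epsilon == big)   # True
# ...while the ulp-based increment actually moves the value.
print(big + sketch_increment(big) > big)  # True
# Near zero, ulp() is a tiny subnormal, so epsilon acts as the lower bound.
print(sketch_increment(0.0) == float_info.epsilon)  # True
```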

Finally, we rename `passes_pydantic_multiple_validator()` to
`is_almost_multiple_of()` and change its implementation to defer the
casting of values to `float()`, which minimizes rounding errors. This
specifically affects Decimal numbers, where casting to float too early
causes too much loss of precision.
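
The effect of casting too early can be seen with a Decimal that has more digits than a float can hold. The sketch below compares the two formulas from the diff. The names `old_check`, `new_check`, and `_close` are hypothetical, and `_close` is only a stand-in for the library's `almost_equal_floats()` with an assumed tolerance of 1e-8.

```python
from decimal import Decimal
from math import isclose


def _close(a: float, b: float) -> bool:
    # Stand-in for almost_equal_floats(); the tolerance is an assumption.
    return isclose(a, b, rel_tol=0.0, abs_tol=1e-8)


def old_check(value, multiple_of) -> bool:
    # Casts to float first, so digits beyond float precision are already lost.
    mod = float(value) / float(multiple_of) % 1
    return _close(mod, 0.0) or _close(mod, 1.0)


def new_check(value, multiple_of) -> bool:
    # Takes the remainder in the original type and casts only the small result.
    mod = value % multiple_of
    return _close(float(mod), 0.0) or _close(float(abs(mod)), float(abs(multiple_of)))


value = Decimal("12345678901234567890.5")  # clearly not a multiple of 1
step = Decimal("1")
print(old_check(value, step))  # True: the .5 vanished in the float cast
print(new_check(value, step))  # False: the Decimal remainder is exactly 0.5
```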

A significant number of changes were made to the tests as well, since
the original tests missed the bug being fixed here. Each of the integer,
floating-point, and decimal tests has been updated to assert that the
result is actually within the minimum and maximum constraints. In
addition, we remove the sorting of the randomly generated test input
values, since it needlessly forced `multiple_of` to be no smaller than
the minimum value or no larger than the maximum value. This was causing
many of the scenarios involving negative values to be skipped.

Lastly, the floating-point and decimal tests need additional constraints
to keep unrealistic extreme values from hitting precision issues. This
was done by constraining the number of significant digits in the input
numbers and the relative magnitudes of the input numbers.
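
For the significant-digit constraint, the distinction from decimal places can be checked with the standard library alone. The helper names below are hypothetical, and the sketch assumes finite Decimal inputs.

```python
from decimal import Decimal


def significant_digits(x: Decimal) -> int:
    """Count the digits stored in the coefficient of ``x``."""
    return len(x.as_tuple().digits)


def decimal_places(x: Decimal) -> int:
    """Count digits after the decimal point when written without an exponent."""
    return max(0, -x.as_tuple().exponent)


for text in ("12.51", "0.000001", "1E+9"):
    d = Decimal(text)
    print(text, significant_digits(d), decimal_places(d))
```

A value such as `0.000001` has six decimal places but only one significant digit, which is why a limit on places alone does not bound the precision the generator has to handle.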
richardxia committed Mar 18, 2024
1 parent c4e3d91 commit 8881603
Showing 5 changed files with 172 additions and 130 deletions.
113 changes: 62 additions & 51 deletions polyfactory/value_generators/constrained_numbers.py
@@ -1,8 +1,9 @@
from __future__ import annotations

from decimal import Decimal
from math import ceil, floor, ulp
from sys import float_info
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
from typing import TYPE_CHECKING, Protocol, TypeVar, cast

from polyfactory.exceptions import ParameterException
from polyfactory.value_generators.primitives import create_random_decimal, create_random_float, create_random_integer
@@ -99,8 +100,8 @@ def is_multiply_of_multiple_of_in_range(
return False


def passes_pydantic_multiple_validator(value: T, multiple_of: T) -> bool:
"""Determine whether a given value passes the pydantic multiple_of validation.
def is_almost_multiple_of(value: T, multiple_of: T) -> bool:
"""Determine whether a given ``value`` is close enough to a multiple of ``multiple_of``.
:param value: A numeric value.
:param multiple_of: Another numeric value.
@@ -110,23 +111,33 @@ def passes_pydantic_multiple_validator(value: T, multiple_of: T) -> bool:
"""
if multiple_of == 0:
return True
mod = float(value) / float(multiple_of) % 1
return almost_equal_floats(mod, 0.0) or almost_equal_floats(mod, 1.0)
mod = value % multiple_of
return almost_equal_floats(float(mod), 0.0) or almost_equal_floats(float(abs(mod)), float(abs(multiple_of)))


def get_increment(t_type: type[T]) -> T:
def get_increment(value: T, t_type: type[T]) -> T:
"""Get a small increment base to add to constrained values, i.e. lt/gt entries.
:param t_type: A value of type T.
:param value: A value of type T.
:param t_type: The type of ``value``.
:returns: An increment T.
"""
values: dict[Any, Any] = {
int: 1,
float: float_info.epsilon,
Decimal: Decimal("0.001"),
}
return cast("T", values[t_type])
# See https://github.com/python/mypy/issues/17045 for why the redundant casts are ignored.
if t_type == int:
return cast("T", 1)
if t_type == float:
# When ``value`` is large in magnitude, we need to choose an increment that is large enough
# to not be rounded away, but when ``value`` is small in magnitude, we need to prevent the
# increment from vanishing. ``float_info.epsilon`` is defined as the smallest delta that
# can be represented between 1.0 and the next largest number, but it's not sufficient for
# larger values. ``ulp(x)`` will return the smallest delta that can be added to ``x``.
return cast("T", max(ulp(value), float_info.epsilon)) # type: ignore[redundant-cast]
if t_type == Decimal:
return cast("T", Decimal("0.001")) # type: ignore[redundant-cast]

msg = f"invalid t_type: {t_type}"
raise AssertionError(msg)


def get_value_or_none(
@@ -147,14 +158,14 @@ def get_value_or_none(
if ge is not None:
minimum_value = ge
elif gt is not None:
minimum_value = gt + get_increment(t_type)
minimum_value = gt + get_increment(gt, t_type)
else:
minimum_value = None

if le is not None:
maximum_value = le
elif lt is not None:
maximum_value = lt - get_increment(t_type)
maximum_value = lt - get_increment(lt, t_type)
else:
maximum_value = None
return minimum_value, maximum_value
@@ -210,33 +221,36 @@ def get_constrained_number_range(
return minimum, maximum


def generate_constrained_number(
def generate_constrained_multiple_of(
random: Random,
minimum: T | None,
maximum: T | None,
multiple_of: T | None,
method: "NumberGeneratorProtocol[T]",
multiple_of: T,
) -> T:
"""Generate a constrained number, output depends on the passed in callbacks.
"""Generate a constrained multiple of ``multiple_of``.
:param random: An instance of random.
:param minimum: A minimum value.
:param maximum: A maximum value.
:param multiple_of: A multiple of value.
:param method: A function that generates numbers of type T.
:returns: A value of type T.
"""
if minimum is None or maximum is None:
return multiple_of if multiple_of is not None else method(random=random)
if multiple_of is None:
return method(random=random, minimum=minimum, maximum=maximum)
if multiple_of >= minimum:
return multiple_of
result = minimum
while not passes_pydantic_multiple_validator(result, multiple_of):
result = round(method(random=random, minimum=minimum, maximum=maximum) / multiple_of) * multiple_of
return result

# Regardless of the type of ``multiple_of``, we can generate a valid multiple of it by
# multiplying it with any integer, which we call a multiplier. We will randomly generate the
# multiplier as a random integer, but we need to translate the original bounds, if any, to the
# correct bounds on the multiplier so that the resulting product will meet the original
# constraints.

if multiple_of < 0:
minimum, maximum = maximum, minimum

multiplier_min = ceil(minimum / multiple_of) if minimum is not None else None
multiplier_max = floor(maximum / multiple_of) if maximum is not None else None
multiplier = create_random_integer(random=random, minimum=multiplier_min, maximum=multiplier_max)

return multiplier * multiple_of


def handle_constrained_int(
@@ -269,13 +283,11 @@ def handle_constrained_int(
multiple_of=multiple_of,
random=random,
)
return generate_constrained_number(
random=random,
minimum=minimum,
maximum=maximum,
multiple_of=multiple_of,
method=create_random_integer,
)

if multiple_of is None:
return create_random_integer(random=random, minimum=minimum, maximum=maximum)

return generate_constrained_multiple_of(random=random, minimum=minimum, maximum=maximum, multiple_of=multiple_of)


def handle_constrained_float(
@@ -308,13 +320,10 @@ def handle_constrained_float(
random=random,
)

return generate_constrained_number(
random=random,
minimum=minimum,
maximum=maximum,
multiple_of=multiple_of,
method=create_random_float,
)
if multiple_of is None:
return create_random_float(random=random, minimum=minimum, maximum=maximum)

return generate_constrained_multiple_of(random=random, minimum=minimum, maximum=maximum, multiple_of=multiple_of)


def validate_max_digits(
@@ -422,13 +431,15 @@ def handle_constrained_decimal(
if max_digits is not None:
validate_max_digits(max_digits=max_digits, minimum=minimum, decimal_places=decimal_places)

generated_decimal = generate_constrained_number(
random=random,
minimum=minimum,
maximum=maximum,
multiple_of=multiple_of,
method=create_random_decimal,
)
if multiple_of is None:
generated_decimal = create_random_decimal(random=random, minimum=minimum, maximum=maximum)
else:
generated_decimal = generate_constrained_multiple_of(
random=random,
minimum=minimum,
maximum=maximum,
multiple_of=multiple_of,
)

if max_digits is not None or decimal_places is not None:
return handle_decimal_length(
61 changes: 43 additions & 18 deletions tests/constraints/test_decimal_constraints.py
@@ -3,7 +3,7 @@
from typing import Optional, cast

import pytest
from hypothesis import given
from hypothesis import assume, given
from hypothesis.strategies import decimals, integers

from pydantic import BaseModel, condecimal
@@ -13,11 +13,24 @@
from polyfactory.value_generators.constrained_numbers import (
handle_constrained_decimal,
handle_decimal_length,
is_almost_multiple_of,
is_multiply_of_multiple_of_in_range,
passes_pydantic_multiple_validator,
)


def assume_max_digits(x: Decimal, max_digits: int) -> None:
"""
Signal to Hypothesis that ``x`` should have at most ``max_digits`` significant digits.
This is different than the ``decimals()`` strategy function's ``places`` keyword argument, which
only counts the digits after the decimal point when the number is written without an exponent.
E.g. 12.51 has 4 significant digits but 2 decimal places.
"""

assume(len(x.as_tuple().digits) <= max_digits)


def test_handle_constrained_decimal_without_constraints() -> None:
result = handle_constrained_decimal(
random=Random(),
@@ -162,7 +175,7 @@ def test_handle_constrained_decimal_handles_multiple_of(multiple_of: Decimal) ->
random=Random(),
multiple_of=multiple_of,
)
assert passes_pydantic_multiple_validator(result, multiple_of)
assert is_almost_multiple_of(result, multiple_of)
else:
with pytest.raises(ParameterException):
handle_constrained_decimal(
@@ -185,15 +198,17 @@ def test_handle_constrained_decimal_handles_multiple_of(multiple_of: Decimal) ->
max_value=1000000000,
),
)
def test_handle_constrained_decimal_handles_multiple_of_with_lt(val1: Decimal, val2: Decimal) -> None:
multiple_of, max_value = sorted([val1, val2])
def test_handle_constrained_decimal_handles_multiple_of_with_lt(max_value: Decimal, multiple_of: Decimal) -> None:
if multiple_of != Decimal("0"):
assume_max_digits(max_value, 10)
assume_max_digits(multiple_of, 10)
result = handle_constrained_decimal(
random=Random(),
multiple_of=multiple_of,
lt=max_value,
)
assert passes_pydantic_multiple_validator(result, multiple_of)
assert result < max_value
assert is_almost_multiple_of(result, multiple_of)
else:
with pytest.raises(ParameterException):
handle_constrained_decimal(
@@ -217,15 +232,17 @@ def test_handle_constrained_decimal_handles_multiple_of_with_lt(val1: Decimal, v
max_value=1000000000,
),
)
def test_handle_constrained_decimal_handles_multiple_of_with_le(val1: Decimal, val2: Decimal) -> None:
multiple_of, max_value = sorted([val1, val2])
def test_handle_constrained_decimal_handles_multiple_of_with_le(max_value: Decimal, multiple_of: Decimal) -> None:
if multiple_of != Decimal("0"):
assume_max_digits(max_value, 10)
assume_max_digits(multiple_of, 10)
result = handle_constrained_decimal(
random=Random(),
multiple_of=multiple_of,
le=max_value,
)
assert passes_pydantic_multiple_validator(result, multiple_of)
assert result <= max_value
assert is_almost_multiple_of(result, multiple_of)
else:
with pytest.raises(ParameterException):
handle_constrained_decimal(
@@ -249,15 +266,17 @@ def test_handle_constrained_decimal_handles_multiple_of_with_le(val1: Decimal, v
max_value=1000000000,
),
)
def test_handle_constrained_decimal_handles_multiple_of_with_ge(val1: Decimal, val2: Decimal) -> None:
min_value, multiple_of = sorted([val1, val2])
def test_handle_constrained_decimal_handles_multiple_of_with_ge(min_value: Decimal, multiple_of: Decimal) -> None:
if multiple_of != Decimal("0"):
assume_max_digits(min_value, 10)
assume_max_digits(multiple_of, 10)
result = handle_constrained_decimal(
random=Random(),
multiple_of=multiple_of,
ge=min_value,
)
assert passes_pydantic_multiple_validator(result, multiple_of)
assert min_value <= result
assert is_almost_multiple_of(result, multiple_of)
else:
with pytest.raises(ParameterException):
handle_constrained_decimal(
@@ -281,15 +300,17 @@ def test_handle_constrained_decimal_handles_multiple_of_with_ge(val1: Decimal, v
max_value=1000000000,
),
)
def test_handle_constrained_decimal_handles_multiple_of_with_gt(val1: Decimal, val2: Decimal) -> None:
min_value, multiple_of = sorted([val1, val2])
def test_handle_constrained_decimal_handles_multiple_of_with_gt(min_value: Decimal, multiple_of: Decimal) -> None:
if multiple_of != Decimal("0"):
assume_max_digits(min_value, 10)
assume_max_digits(multiple_of, 10)
result = handle_constrained_decimal(
random=Random(),
multiple_of=multiple_of,
gt=min_value,
)
assert passes_pydantic_multiple_validator(result, multiple_of)
assert min_value < result
assert is_almost_multiple_of(result, multiple_of)
else:
with pytest.raises(ParameterException):
handle_constrained_decimal(
@@ -322,21 +343,25 @@ def test_handle_constrained_decimal_handles_multiple_of_with_gt(val1: Decimal, v
def test_handle_constrained_decimal_handles_multiple_of_with_ge_and_le(
val1: Decimal,
val2: Decimal,
val3: Decimal,
multiple_of: Decimal,
) -> None:
min_value, multiple_of, max_value = sorted([val1, val2, val3])
min_value, max_value = sorted([val1, val2])
if multiple_of != Decimal("0") and is_multiply_of_multiple_of_in_range(
minimum=min_value,
maximum=max_value,
multiple_of=multiple_of,
):
assume_max_digits(min_value, 10)
assume_max_digits(max_value, 10)
assume_max_digits(multiple_of, 10)
result = handle_constrained_decimal(
random=Random(),
multiple_of=multiple_of,
ge=min_value,
le=max_value,
)
assert passes_pydantic_multiple_validator(result, multiple_of)
assert min_value <= result <= max_value
assert is_almost_multiple_of(result, multiple_of)
else:
with pytest.raises(ParameterException):
handle_constrained_decimal(