Skip to content

Commit

Permalink
fix: Adjust min/max items to valid lengths for Set[Enum] fields (#567)
Browse files Browse the repository at this point in the history
Co-authored-by: guacs <[email protected]>
  • Loading branch information
adrianeboyd and guacs authored Sep 13, 2024
1 parent 841831d commit 9a83ad6
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
13 changes: 12 additions & 1 deletion polyfactory/value_generators/constrained_collections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, List, Mapping, TypeVar, cast
from enum import EnumMeta
from typing import TYPE_CHECKING, Any, Callable, List, Literal, Mapping, TypeVar, cast

from polyfactory.exceptions import ParameterException
from polyfactory.field_meta import FieldMeta
Expand Down Expand Up @@ -43,6 +44,16 @@ def handle_constrained_collection(
msg = "max_items must be larger or equal to min_items"
raise ParameterException(msg)

if collection_type in (frozenset, set) or unique_items:
max_field_values = max_items
if hasattr(field_meta.annotation, "__origin__") and field_meta.annotation.__origin__ is Literal:
if field_meta.children is not None:
max_field_values = len(field_meta.children)
elif isinstance(field_meta.annotation, EnumMeta):
max_field_values = len(field_meta.annotation)
min_items = min(min_items, max_field_values)
max_items = min(max_items, max_field_values)

collection: set[T] | list[T] = set() if (collection_type in (frozenset, set) or unique_items) else []

try:
Expand Down
52 changes: 51 additions & 1 deletion tests/test_collection_length.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Any, Dict, List, Optional, Set, Tuple
from enum import Enum
from typing import Any, Dict, FrozenSet, List, Literal, Optional, Set, Tuple, get_args

import pytest

from pydantic import BaseModel
from pydantic.dataclasses import dataclass

from polyfactory.factories import DataclassFactory
from polyfactory.factories.pydantic_factory import ModelFactory

MIN_MAX_PARAMETERS = ((10, 15), (20, 25), (30, 40), (40, 50))

Expand Down Expand Up @@ -132,3 +135,50 @@ class FooFactory(DataclassFactory[Foo]):

assert len(foo.foo) >= min_val, len(foo.foo)
assert len(foo.foo) <= max_val, len(foo.foo)


@pytest.mark.parametrize("type_", (List, FrozenSet, Set))
@pytest.mark.parametrize("min_items", (0, 2, 4))
@pytest.mark.parametrize("max_inc", (0, 1, 4))
def test_collection_length_with_literal(type_: type, min_items: int, max_inc: int) -> None:
max_items = min_items + max_inc
literal_type = Literal["Dog", "Cat", "Monkey"]

@dataclass
class MyModel:
animal_collection: type_[literal_type] # type: ignore

class MyFactory(DataclassFactory):
__model__ = MyModel
__randomize_collection_length__ = True
__min_collection_length__ = min_items
__max_collection_length__ = max_items

result = MyFactory.build()
assert len(result.animal_collection) >= min(min_items, len(get_args(literal_type)))
assert len(result.animal_collection) <= max_items


@pytest.mark.parametrize("type_", (List, FrozenSet, Set))
@pytest.mark.parametrize("min_items", (0, 2, 4))
@pytest.mark.parametrize("max_inc", (0, 1, 4))
def test_collection_length_with_enum(type_: type, min_items: int, max_inc: int) -> None:
max_items = min_items + max_inc

class Animal(str, Enum):
DOG = "Dog"
CAT = "Cat"
MONKEY = "Monkey"

class MyModel(BaseModel):
animal_collection: type_[Animal] # type: ignore

class MyFactory(ModelFactory):
__model__ = MyModel
__randomize_collection_length__ = True
__min_collection_length__ = min_items
__max_collection_length__ = max_items

result = MyFactory.build()
assert len(result.animal_collection) >= min(min_items, len(Animal))
assert len(result.animal_collection) <= max_items

0 comments on commit 9a83ad6

Please sign in to comment.