Skip to content

Commit

Permalink
fix: handle tuple and randomized length
Browse files Browse the repository at this point in the history
  • Loading branch information
adhtruong committed Aug 4, 2024
1 parent a8bb48d commit c8dc24d
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 26 deletions.
22 changes: 15 additions & 7 deletions polyfactory/collection_extender.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import random
from abc import ABC, abstractmethod
from collections import deque
from typing import Any
Expand Down Expand Up @@ -53,29 +54,36 @@ class ListLikeExtender(CollectionExtender):
__types__ = (list, deque)

@staticmethod
def _extend_type_args(type_args: tuple[Any, ...], _: int) -> tuple[Any, ...]:
return type_args
def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]:
if not type_args:
return type_args
return tuple(random.choice(type_args) for _ in range(number_of_args))


class SetExtender(CollectionExtender):
__types__ = (set, frozenset)

@staticmethod
def _extend_type_args(type_args: tuple[Any, ...], _: int) -> tuple[Any, ...]:
return type_args
def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]:
if not type_args:
return type_args
return tuple(random.choice(type_args) for _ in range(number_of_args))


class DictExtender(CollectionExtender):
__types__ = (dict,)

@staticmethod
def _extend_type_args(type_args: tuple[Any, ...], _: int) -> tuple[Any, ...]:
return type_args
def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]:
return type_args * number_of_args


class FallbackExtender(CollectionExtender):
__types__ = ()

@staticmethod
def _extend_type_args(type_args: tuple[Any, ...], _: int) -> tuple[Any, ...]:
def _extend_type_args(
type_args: tuple[Any, ...],
number_of_args: int, # noqa: ARG004
) -> tuple[Any, ...]: # - investigate @guacs
return type_args
13 changes: 10 additions & 3 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,8 +746,11 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
return factory.batch(size=batch_size, _build_context=build_context)

if (origin := get_type_origin(unwrapped_annotation)) and is_safe_subclass(origin, Collection):
if cls.__randomize_collection_length__:
collection_type = get_collection_type(unwrapped_annotation)
collection_type = get_collection_type(unwrapped_annotation)
is_fixed_length = collection_type is tuple and (
not field_meta.children or field_meta.children[-1].annotation != Ellipsis
)
if cls.__randomize_collection_length__ and not is_fixed_length:
if collection_type is not dict:
return handle_constrained_collection(
collection_type=collection_type, # type: ignore[type-var]
Expand All @@ -769,7 +772,11 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
)

return handle_collection_type(
field_meta, origin, cls, field_build_parameters=field_build_parameters, build_context=build_context
field_meta,
origin,
cls,
field_build_parameters=field_build_parameters,
build_context=build_context,
)

if provider := cls.get_provider_map().get(unwrapped_annotation):
Expand Down
5 changes: 1 addition & 4 deletions polyfactory/field_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from typing_extensions import get_args, get_origin

from polyfactory.collection_extender import CollectionExtender
from polyfactory.constants import DEFAULT_RANDOM, TYPE_MAPPING
from polyfactory.utils.deprecation import check_for_deprecated_parameters
from polyfactory.utils.helpers import get_annotation_metadata, unwrap_annotated, unwrap_new_type
Expand Down Expand Up @@ -152,14 +151,12 @@ def from_type(
)

if field.type_args and not field.children:
number_of_args = 1
extended_type_args = CollectionExtender.extend_type_args(field.annotation, field.type_args, number_of_args)
field.children = [
cls.from_type(
annotation=unwrap_new_type(arg),
random=random,
)
for arg in extended_type_args
for arg in field.type_args
if arg is not NoneType
]
return field
Expand Down
5 changes: 4 additions & 1 deletion polyfactory/utils/helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import sys
from collections import deque
from typing import TYPE_CHECKING, Any, Mapping

from typing_extensions import TypeAliasType, get_args, get_origin
Expand Down Expand Up @@ -173,7 +174,7 @@ def get_annotation_metadata(annotation: Any) -> Sequence[Any]:
return get_args(annotation)[1:]


def get_collection_type(annotation: Any) -> type[list | tuple | set | frozenset | dict]:
def get_collection_type(annotation: Any) -> type[list | tuple | set | frozenset | dict | deque]:
"""Get the collection type from the annotation.
:param annotation: A type annotation.
Expand All @@ -191,6 +192,8 @@ def get_collection_type(annotation: Any) -> type[list | tuple | set | frozenset
return set
if is_safe_subclass(annotation, frozenset):
return frozenset
if is_safe_subclass(annotation, deque):
return deque

msg = f"Unknown collection type - {annotation}"
raise ValueError(msg)
13 changes: 12 additions & 1 deletion polyfactory/value_generators/complex_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,14 @@ def handle_collection_type(
return container_type(
factory.get_field_value(child, field_build_parameters=field_build_parameters, build_context=build_context)
for child in field_meta.children
if child.annotation != Ellipsis
)

msg = f"Unsupported container type: {container_type}"
raise NotImplementedError(msg)


def handle_collection_type_coverage(
def handle_collection_type_coverage( # noqa: C901, PLR0911
field_meta: FieldMeta,
container_type: type,
factory: type[BaseFactory[Any]],
Expand Down Expand Up @@ -136,6 +137,16 @@ def handle_collection_type_coverage(
return container.union(handle_collection_type_coverage(field_meta, set, factory, build_context=build_context))

if issubclass(container_type, tuple):
if field_meta.children[-1].annotation == Ellipsis:
return (
CoverageContainer(
factory.get_field_value_coverage(
field_meta.children[0],
build_context=build_context,
)
),
)

return container_type(
CoverageContainer(factory.get_field_value_coverage(subfield_meta, build_context=build_context))
for subfield_meta in field_meta.children
Expand Down
46 changes: 36 additions & 10 deletions tests/test_dataclass_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass as vanilla_dataclass
from dataclasses import field
from types import ModuleType
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import ANY

import pytest
Expand Down Expand Up @@ -108,10 +108,6 @@ class MyFactory(DataclassFactory):
assert result.east.people


def function_with_kwargs(first: int, second: float, third: str = "moishe") -> None:
pass


def test_complex_embedded_dataclass() -> None:
@vanilla_dataclass
class VanillaDC:
Expand All @@ -133,18 +129,48 @@ class MyFactory(DataclassFactory):
assert isinstance(next(iter(next(iter(result.weirdly_nest_field[0].values())).values())), VanillaDC)


def test_tuple_ellipsis_in_vanilla_dc() -> None:
@pytest.mark.parametrize(
"factory_config, expected_length",
(
({}, 1),
(
{
"__randomize_collection_length__": True,
"__min_collection_length__": 3,
"__max_collection_length__": 3,
},
3,
),
(
{
"__randomize_collection_length__": True,
"__min_collection_length__": 0,
"__max_collection_length__": 0,
},
0,
),
),
)
def test_tuple_in_vanilla_dc(factory_config: Dict[str, Any], expected_length: int) -> None:
@vanilla_dataclass
class VanillaDC:
ids: Tuple[int, ...]
field: Tuple[int, str]

class MyFactory(DataclassFactory[VanillaDC]):
__model__ = VanillaDC

MyFactory = DataclassFactory[VanillaDC].create_factory(VanillaDC, **factory_config)
result = MyFactory.build()

assert result
assert result.ids
assert len(result.ids) == expected_length
assert len(result.field) == 2
assert isinstance(result.field[0], int)
assert isinstance(result.field[1], str)

coverage_results = list(MyFactory.coverage())
assert all(len(result.ids) == 1 for result in coverage_results)
assert all(len(result.field) == 2 for result in coverage_results)
assert all(isinstance(result.field[0], int) for result in coverage_results)
assert all(isinstance(result.field[1], str) for result in coverage_results)


def test_dataclass_factory_with_future_annotations(create_module: Callable[[str], ModuleType]) -> None:
Expand Down

0 comments on commit c8dc24d

Please sign in to comment.