Skip to content

Commit

Permalink
refactor: move the parsing of collection_type into its own function
Browse files Browse the repository at this point in the history
  • Loading branch information
guacs committed Oct 14, 2023
1 parent 80b8b46 commit 3a3dc63
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 44 deletions.
58 changes: 15 additions & 43 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
ParameterException,
)
from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use
from polyfactory.utils.helpers import unwrap_annotation, unwrap_args, unwrap_optional
from polyfactory.utils.helpers import get_collection_type, unwrap_annotation, unwrap_args, unwrap_optional
from polyfactory.utils.predicates import (
get_type_origin,
is_any,
Expand Down Expand Up @@ -495,21 +495,15 @@ def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) ->
pattern=constraints.get("pattern"),
)

if (
is_safe_subclass(annotation, set)
or is_safe_subclass(annotation, list)
or is_safe_subclass(annotation, frozenset)
or is_safe_subclass(annotation, tuple)
):
collection_type: type[list | set | tuple | frozenset]
if is_safe_subclass(annotation, list):
collection_type = list
elif is_safe_subclass(annotation, set):
collection_type = set
elif is_safe_subclass(annotation, tuple):
collection_type = tuple
else:
collection_type = frozenset
try:
collection_type = get_collection_type(annotation)
if collection_type == dict:
return handle_constrained_mapping(
factory=cls,
field_meta=field_meta,
min_items=constraints.get("min_length"),
max_items=constraints.get("max_length"),
)

return handle_constrained_collection(
collection_type=collection_type, # type: ignore[type-var]
Expand All @@ -520,14 +514,9 @@ def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) ->
min_items=constraints.get("min_length"),
unique_items=constraints.get("unique_items", False),
)

if is_safe_subclass(annotation, dict):
return handle_constrained_mapping(
factory=cls,
field_meta=field_meta,
min_items=constraints.get("min_length"),
max_items=constraints.get("max_length"),
)
except ValueError:
# implies the annotation was not a collection type
pass

if is_safe_subclass(annotation, date):
return handle_constrained_date(
Expand Down Expand Up @@ -599,26 +588,9 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
return factory.batch(size=batch_size)

if (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection):
# There's repetition of code from the get_constrained_field_value, but that can't be directly
# called since that function depends on the field_meta.constraints for the min_items etc. but
# there are no constraints for this field meta.
if cls.__randomize_collection_length__:
if (
is_safe_subclass(unwrapped_annotation, set)
or is_safe_subclass(unwrapped_annotation, list)
or is_safe_subclass(unwrapped_annotation, frozenset)
or is_safe_subclass(unwrapped_annotation, tuple)
):
collection_type: type[list | set | tuple | frozenset]
if is_safe_subclass(unwrapped_annotation, list):
collection_type = list
elif is_safe_subclass(unwrapped_annotation, set):
collection_type = set
elif is_safe_subclass(unwrapped_annotation, tuple):
collection_type = tuple
else:
collection_type = frozenset

collection_type = get_collection_type(unwrapped_annotation)
if collection_type != dict:
return handle_constrained_collection(
collection_type=collection_type, # type: ignore[type-var]
factory=cls,
Expand Down
26 changes: 25 additions & 1 deletion polyfactory/utils/helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Mapping

from typing_extensions import get_args, get_origin

Expand All @@ -10,6 +10,7 @@
is_annotated,
is_new_type,
is_optional,
is_safe_subclass,
is_union,
)

Expand Down Expand Up @@ -130,3 +131,26 @@ def normalize_annotation(annotation: Any, random: Random) -> Any:
return origin[args] if origin is not type else annotation

return origin


def get_collection_type(annotation: Any) -> type[list | tuple | set | frozenset | dict]:
"""Get the collection type from the annotation.
:param annotation: A type annotation.
:returns: The collection type.
"""

if is_safe_subclass(annotation, list):
return list
if is_safe_subclass(annotation, Mapping):
return dict
if is_safe_subclass(annotation, tuple):
return tuple
if is_safe_subclass(annotation, set):
return set
if is_safe_subclass(annotation, frozenset):
return frozenset

msg = f"Unknown collection type - {annotation}"
raise ValueError(msg)

0 comments on commit 3a3dc63

Please sign in to comment.