From 3a3dc635ad2f41fec690f0fcdb4a7e8806373e85 Mon Sep 17 00:00:00 2001 From: guacs Date: Sat, 14 Oct 2023 22:16:19 +0530 Subject: [PATCH] refactor: move the parsing of collection_type into its own function --- polyfactory/factories/base.py | 58 +++++++++-------------------------- polyfactory/utils/helpers.py | 26 +++++++++++++++- 2 files changed, 40 insertions(+), 44 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 9c8b42c3..5d19d976 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -52,7 +52,7 @@ ParameterException, ) from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use -from polyfactory.utils.helpers import unwrap_annotation, unwrap_args, unwrap_optional +from polyfactory.utils.helpers import get_collection_type, unwrap_annotation, unwrap_args, unwrap_optional from polyfactory.utils.predicates import ( get_type_origin, is_any, @@ -495,21 +495,15 @@ def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> pattern=constraints.get("pattern"), ) - if ( - is_safe_subclass(annotation, set) - or is_safe_subclass(annotation, list) - or is_safe_subclass(annotation, frozenset) - or is_safe_subclass(annotation, tuple) - ): - collection_type: type[list | set | tuple | frozenset] - if is_safe_subclass(annotation, list): - collection_type = list - elif is_safe_subclass(annotation, set): - collection_type = set - elif is_safe_subclass(annotation, tuple): - collection_type = tuple - else: - collection_type = frozenset + try: + collection_type = get_collection_type(annotation) + if collection_type == dict: + return handle_constrained_mapping( + factory=cls, + field_meta=field_meta, + min_items=constraints.get("min_length"), + max_items=constraints.get("max_length"), + ) return handle_constrained_collection( collection_type=collection_type, # type: ignore[type-var] @@ -520,14 +514,9 @@ def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> min_items=constraints.get("min_length"), unique_items=constraints.get("unique_items", False), ) - - if is_safe_subclass(annotation, dict): - return handle_constrained_mapping( - factory=cls, - field_meta=field_meta, - min_items=constraints.get("min_length"), - max_items=constraints.get("max_length"), - ) + except ValueError: + # implies the annotation was not a collection type + pass if is_safe_subclass(annotation, date): return handle_constrained_date( @@ -599,26 +588,9 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 return factory.batch(size=batch_size) if (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection): - # There's repetition of code from the get_constrained_field_value, but that can't be directly - # called since that function depends on the field_meta.constraints for the min_items etc. but - # there are no constraints for this field meta. if cls.__randomize_collection_length__: - if ( - is_safe_subclass(unwrapped_annotation, set) - or is_safe_subclass(unwrapped_annotation, list) - or is_safe_subclass(unwrapped_annotation, frozenset) - or is_safe_subclass(unwrapped_annotation, tuple) - ): - collection_type: type[list | set | tuple | frozenset] - if is_safe_subclass(unwrapped_annotation, list): - collection_type = list - elif is_safe_subclass(unwrapped_annotation, set): - collection_type = set - elif is_safe_subclass(unwrapped_annotation, tuple): - collection_type = tuple - else: - collection_type = frozenset - + collection_type = get_collection_type(unwrapped_annotation) + if collection_type != dict: return handle_constrained_collection( collection_type=collection_type, # type: ignore[type-var] factory=cls, diff --git a/polyfactory/utils/helpers.py b/polyfactory/utils/helpers.py index ad0c3764..4cafc7de 100644 --- a/polyfactory/utils/helpers.py +++ b/polyfactory/utils/helpers.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Mapping from typing_extensions import get_args, get_origin @@ -10,6 +10,7 @@ is_annotated, is_new_type, is_optional, + is_safe_subclass, is_union, ) @@ -130,3 +131,26 @@ def normalize_annotation(annotation: Any, random: Random) -> Any: return origin[args] if origin is not type else annotation return origin + + +def get_collection_type(annotation: Any) -> type[list | tuple | set | frozenset | dict]: + """Get the collection type from the annotation. + + :param annotation: A type annotation. + + :returns: The collection type. + """ + + if is_safe_subclass(annotation, list): + return list + if is_safe_subclass(annotation, Mapping): + return dict + if is_safe_subclass(annotation, tuple): + return tuple + if is_safe_subclass(annotation, set): + return set + if is_safe_subclass(annotation, frozenset): + return frozenset + + msg = f"Unknown collection type - {annotation}" + raise ValueError(msg)