From 79360f18340da6ab67808a4a177349b206432bc9 Mon Sep 17 00:00:00 2001 From: guacs <126393040+guacs@users.noreply.github.com> Date: Sun, 1 Oct 2023 06:39:56 +0530 Subject: [PATCH] refactor: move creation of pydantic provider map (#396) --- polyfactory/factories/base.py | 79 +---------------------- polyfactory/factories/pydantic_factory.py | 73 ++++++++++++++++++++- tests/test_provider_map.py | 4 +- 3 files changed, 75 insertions(+), 81 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index e7c5efd5..3a376774 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections import Counter, abc, deque from contextlib import suppress -from datetime import date, datetime, time, timedelta, timezone +from datetime import date, datetime, time, timedelta from decimal import Decimal from enum import EnumMeta from functools import partial @@ -35,7 +35,7 @@ TypeVar, cast, ) -from uuid import NAMESPACE_DNS, UUID, uuid1, uuid3, uuid5 +from uuid import UUID from faker import Faker from typing_extensions import get_args @@ -89,80 +89,6 @@ from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol -def _create_pydantic_type_map(cls: type[BaseFactory[Any]]) -> dict[type, Callable[[], Any]]: - """Creates a mapping of pydantic types to mock data functions. - - :param cls: The base factory class. - :return: A dict mapping types to callables. - """ - try: - import pydantic - - mapping = { - pydantic.ByteSize: cls.__faker__.pyint, - pydantic.PositiveInt: cls.__faker__.pyint, - pydantic.NegativeFloat: lambda: cls.__random__.uniform(-100, -1), - pydantic.NegativeInt: lambda: cls.__faker__.pyint() * -1, - pydantic.PositiveFloat: cls.__faker__.pyint, - pydantic.NonPositiveFloat: lambda: cls.__random__.uniform(-100, 0), - pydantic.NonNegativeInt: cls.__faker__.pyint, - pydantic.StrictInt: cls.__faker__.pyint, - pydantic.StrictBool: cls.__faker__.pybool, - pydantic.StrictBytes: partial(create_random_bytes, cls.__random__), - pydantic.StrictFloat: cls.__faker__.pyfloat, - pydantic.StrictStr: cls.__faker__.pystr, - pydantic.EmailStr: cls.__faker__.free_email, - pydantic.NameEmail: cls.__faker__.free_email, - pydantic.Json: cls.__faker__.json, - pydantic.PaymentCardNumber: cls.__faker__.credit_card_number, - pydantic.AnyUrl: cls.__faker__.url, - pydantic.AnyHttpUrl: cls.__faker__.url, - pydantic.HttpUrl: cls.__faker__.url, - pydantic.SecretBytes: partial(create_random_bytes, cls.__random__), - pydantic.SecretStr: cls.__faker__.pystr, - pydantic.IPvAnyAddress: cls.__faker__.ipv4, - pydantic.IPvAnyInterface: cls.__faker__.ipv4, - pydantic.IPvAnyNetwork: lambda: cls.__faker__.ipv4(network=True), - pydantic.PastDate: cls.__faker__.past_date, - pydantic.FutureDate: cls.__faker__.future_date, - } - - if pydantic.VERSION.startswith("1"): - # v1 only values - these will raise an exception in v2 - # in pydantic v2 these are all aliases for Annotated with a constraint. - # we therefore do not need them in v2 - mapping.update( - { - pydantic.PyObject: lambda: "decimal.Decimal", - pydantic.AmqpDsn: lambda: "amqps://example.com", - pydantic.KafkaDsn: lambda: "kafka://localhost:9092", - pydantic.PostgresDsn: lambda: "postgresql://user:secret@localhost", - pydantic.RedisDsn: lambda: "redis://localhost:6379/0", - pydantic.FilePath: lambda: Path(realpath(__file__)), - pydantic.DirectoryPath: lambda: Path(realpath(__file__)).parent, - pydantic.UUID1: uuid1, - pydantic.UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()), - pydantic.UUID4: cls.__faker__.uuid4, - pydantic.UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()), - pydantic.color.Color: cls.__faker__.hex_color, # pyright: ignore[reportGeneralTypeIssues] - }, - ) - else: - mapping.update( - { - pydantic.PastDatetime: cls.__faker__.past_datetime, - pydantic.FutureDatetime: cls.__faker__.future_datetime, - pydantic.AwareDatetime: partial(cls.__faker__.date_time, timezone.utc), - pydantic.NaiveDatetime: cls.__faker__.date_time, - }, - ) - - except ImportError: - mapping = {} - - return mapping - - T = TypeVar("T") F = TypeVar("F", bound="BaseFactory[Any]") @@ -488,7 +414,6 @@ def _create_generic_fn() -> Callable: Callable: _create_generic_fn, abc.Callable: _create_generic_fn, Counter: lambda: Counter(cls.__faker__.pystr()), - **_create_pydantic_type_map(cls), } @classmethod diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index 8d4a45d3..e7b0f672 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -1,6 +1,10 @@ from __future__ import annotations from contextlib import suppress +from datetime import timezone +from functools import partial +from os.path import realpath +from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -12,6 +16,7 @@ TypeVar, cast, ) +from uuid import NAMESPACE_DNS, uuid1, uuid3, uuid5 from typing_extensions import Literal, get_args, get_origin @@ -27,8 +32,10 @@ from polyfactory.field_meta import Constraints, FieldMeta, Null from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional from polyfactory.utils.predicates import is_optional, is_safe_subclass, is_union +from polyfactory.value_generators.primitives import create_random_bytes try: + import pydantic from pydantic import VERSION, BaseModel, Json from pydantic.fields import FieldInfo except ImportError as e: @@ -47,7 +54,7 @@ if TYPE_CHECKING: from random import Random - from typing_extensions import NotRequired, TypeGuard + from typing_extensions import Callable, NotRequired, TypeGuard T = TypeVar("T", bound=BaseModel) @@ -405,3 +412,67 @@ def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool: return field_meta.name not in kwargs and ( not field_meta.name.startswith("_") or cls.is_custom_root_field(field_meta) ) + + @classmethod + def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: + mapping = { + pydantic.ByteSize: cls.__faker__.pyint, + pydantic.PositiveInt: cls.__faker__.pyint, + pydantic.NegativeFloat: lambda: cls.__random__.uniform(-100, -1), + pydantic.NegativeInt: lambda: cls.__faker__.pyint() * -1, + pydantic.PositiveFloat: cls.__faker__.pyint, + pydantic.NonPositiveFloat: lambda: cls.__random__.uniform(-100, 0), + pydantic.NonNegativeInt: cls.__faker__.pyint, + pydantic.StrictInt: cls.__faker__.pyint, + pydantic.StrictBool: cls.__faker__.pybool, + pydantic.StrictBytes: partial(create_random_bytes, cls.__random__), + pydantic.StrictFloat: cls.__faker__.pyfloat, + pydantic.StrictStr: cls.__faker__.pystr, + pydantic.EmailStr: cls.__faker__.free_email, + pydantic.NameEmail: cls.__faker__.free_email, + pydantic.Json: cls.__faker__.json, + pydantic.PaymentCardNumber: cls.__faker__.credit_card_number, + pydantic.AnyUrl: cls.__faker__.url, + pydantic.AnyHttpUrl: cls.__faker__.url, + pydantic.HttpUrl: cls.__faker__.url, + pydantic.SecretBytes: partial(create_random_bytes, cls.__random__), + pydantic.SecretStr: cls.__faker__.pystr, + pydantic.IPvAnyAddress: cls.__faker__.ipv4, + pydantic.IPvAnyInterface: cls.__faker__.ipv4, + pydantic.IPvAnyNetwork: lambda: cls.__faker__.ipv4(network=True), + pydantic.PastDate: cls.__faker__.past_date, + pydantic.FutureDate: cls.__faker__.future_date, + } + + if pydantic.VERSION.startswith("1"): + # v1 only values - these will raise an exception in v2 + # in pydantic v2 these are all aliases for Annotated with a constraint. + # we therefore do not need them in v2 + mapping.update( + { + pydantic.PyObject: lambda: "decimal.Decimal", + pydantic.AmqpDsn: lambda: "amqps://example.com", + pydantic.KafkaDsn: lambda: "kafka://localhost:9092", + pydantic.PostgresDsn: lambda: "postgresql://user:secret@localhost", + pydantic.RedisDsn: lambda: "redis://localhost:6379/0", + pydantic.FilePath: lambda: Path(realpath(__file__)), + pydantic.DirectoryPath: lambda: Path(realpath(__file__)).parent, + pydantic.UUID1: uuid1, + pydantic.UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()), + pydantic.UUID4: cls.__faker__.uuid4, + pydantic.UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()), + pydantic.color.Color: cls.__faker__.hex_color, # pyright: ignore[reportGeneralTypeIssues] + }, + ) + else: + mapping.update( + { + pydantic.PastDatetime: cls.__faker__.past_datetime, + pydantic.FutureDatetime: cls.__faker__.future_datetime, + pydantic.AwareDatetime: partial(cls.__faker__.date_time, timezone.utc), + pydantic.NaiveDatetime: cls.__faker__.date_time, + }, + ) + + mapping.update(super().get_provider_map()) + return mapping diff --git a/tests/test_provider_map.py b/tests/test_provider_map.py index 0dd42bb4..07fd538f 100644 --- a/tests/test_provider_map.py +++ b/tests/test_provider_map.py @@ -1,13 +1,11 @@ from typing import Any -from polyfactory.factories.base import BaseFactory, _create_pydantic_type_map +from polyfactory.factories.base import BaseFactory def test_provider_map() -> None: provider_map = BaseFactory.get_provider_map() provider_map.pop(Any) - for key in _create_pydantic_type_map(BaseFactory): # type: ignore[type-abstract] - provider_map.pop(key) for type_, handler in provider_map.items(): value = handler()