diff --git a/docs/usage/library_factories/index.rst b/docs/usage/library_factories/index.rst index 20911acf..8ccfebb4 100644 --- a/docs/usage/library_factories/index.rst +++ b/docs/usage/library_factories/index.rst @@ -11,9 +11,13 @@ These include: :class:`TypedDictFactory ` a base factory for typed-dicts + :class:`ModelFactory ` a base factory for `pydantic `_ models +:class:`PydanticDataclassFactory ` + a base factory for `pydantic `_ dataclasses + :class:`BeanieDocumentFactory ` a base factory for `beanie `_ documents diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index 0d53f341..3c4b26d2 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -2,6 +2,7 @@ import copy from contextlib import suppress +from dataclasses import is_dataclass from datetime import timezone from functools import partial from os.path import realpath @@ -60,9 +61,15 @@ # is installed. from pydantic import PyObject - # prevent unbound variable warnings + # Prevent unbound variable warnings BaseModelV2 = BaseModelV1 UndefinedV2 = Undefined + + if TYPE_CHECKING: + from pydantic.dataclasses import Dataclass as PydanticDataclassV1 # pyright: ignore[reportPrivateImportUsage] + + # Prevent unbound variable warnings + PydanticDataclassV2 = PydanticDataclassV1 except ImportError: # pydantic v2 @@ -91,6 +98,8 @@ from pydantic.v1.color import Color # type: ignore[assignment] from pydantic.v1.fields import DeferredType, ModelField, Undefined + if TYPE_CHECKING: + from pydantic.dataclasses import PydanticDataclass as PydanticDataclassV2 # pyright: ignore[reportPrivateImportUsage] if TYPE_CHECKING: from collections import abc @@ -99,7 +108,9 @@ from typing_extensions import NotRequired, TypeGuard -T = TypeVar("T", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm] + +ModelT = TypeVar("ModelT", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm] +T = TypeVar("T") _IS_PYDANTIC_V1 = VERSION.startswith("1") @@ -370,7 +381,7 @@ def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]: return metadata -class ModelFactory(Generic[T], BaseFactory[T]): +class ModelFactory(Generic[ModelT], BaseFactory[ModelT]): """Base factory for pydantic models""" __forward_ref_resolution_type_mapping__: ClassVar[Mapping[str, type]] = {} @@ -388,7 +399,7 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__) # type: ignore[attr-defined] @classmethod - def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: + def is_supported_type(cls, value: Any) -> TypeGuard[type[ModelT]]: """Determine whether the given value is supported by the factory. :param value: An arbitrary value. @@ -454,7 +465,7 @@ def build( cls, factory_use_construct: bool = False, **kwargs: Any, - ) -> T: + ) -> ModelT: """Build an instance of the factory's __model__ :param factory_use_construct: A boolean that determines whether validations will be made when instantiating the @@ -492,7 +503,7 @@ def _get_build_context(cls, build_context: BaseBuildContext | PydanticBuildConte } @classmethod - def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> T: + def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> ModelT: """Create an instance of the factory's __model__ :param _build_context: BuildContext instance. @@ -508,7 +519,7 @@ def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> T return cls.__model__(**kwargs) # type: ignore[return-value] @classmethod - def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Iterator[T]: + def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Iterator[ModelT]: """Build a batch of the factory's Meta.model will full coverage of the sub-types of the model. :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. @@ -629,3 +640,54 @@ def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]: def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: # pyright: ignore[reportInvalidTypeForm] return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2) + + +def _is_pydantic_v1_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclassV1]: + return is_dataclass(cls) and "__pydantic_model__" in cls.__dict__ + + +def _is_pydantic_v2_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclassV2]: + return is_dataclass(cls) and "__pydantic_validator__" in cls.__dict__ + + +class PydanticDataclassFactory(ModelFactory[T]): # type: ignore[type-var] + """Base factory for pydantic dataclasses""" + + __is_base_factory__ = True + + @classmethod + def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: + return _is_pydantic_v1_dataclass(value) or _is_pydantic_v2_dataclass(value) + + @classmethod + def get_model_fields(cls) -> list[FieldMeta]: + if _is_pydantic_v1_dataclass(cls.__model__): + pydantic_model = cls.__model__.__pydantic_model__ + cls._fields_metadata = [ + PydanticFieldMeta.from_model_field( + field, + use_alias=not pydantic_model.__config__.allow_population_by_field_name, # type: ignore[attr-defined] + random=cls.__random__, + ) + for field in pydantic_model.__fields__.values() + ] + elif _is_pydantic_v2_dataclass(cls.__model__): + pydantic_fields = cls.__model__.__pydantic_fields__ + pydantic_config = cls.__model__.__pydantic_config__ + cls._fields_metadata = [ + PydanticFieldMeta.from_field_info( + field_info=field_info, + field_name=field_name, + random=cls.__random__, + use_alias=not pydantic_config.get( + "populate_by_name", + False, + ), + ) + for field_name, field_info in pydantic_fields.items() + ] + else: + # This should be unreachable + return [] + + return cls._fields_metadata diff --git a/tests/test_pydantic_factory.py b/tests/test_pydantic_factory.py index 2c24db78..1f859566 100644 --- a/tests/test_pydantic_factory.py +++ b/tests/test_pydantic_factory.py @@ -63,9 +63,11 @@ constr, validator, ) +from pydantic.dataclasses import dataclass as pydantic_dataclass +from polyfactory.exceptions import ConfigurationException from polyfactory.factories import DataclassFactory -from polyfactory.factories.pydantic_factory import _IS_PYDANTIC_V1, ModelFactory +from polyfactory.factories.pydantic_factory import _IS_PYDANTIC_V1, ModelFactory, PydanticDataclassFactory from tests.models import Person, PetFactory IS_PYDANTIC_V1 = _IS_PYDANTIC_V1 @@ -1038,3 +1040,49 @@ class A(BaseModel): AFactory = ModelFactory.create_factory(A) assert AFactory.build() + + +def test_simple_pydantic_dataclass() -> None: + @pydantic_dataclass + class DataclassModel: + a: int + b: Annotated[str, MinLen(1)] + + class DataclassModelFactory(PydanticDataclassFactory[DataclassModel]): + __model__ = DataclassModel + + instance = DataclassModelFactory.build() + assert isinstance(instance, DataclassModel) + assert isinstance(instance.a, int) + assert isinstance(instance.b, str) + assert len(instance.b) >= 1 + + +def test_nested_pydantic_dataclass() -> None: + @pydantic_dataclass + class FooDataclass: + content: int + + @pydantic_dataclass + class NestedDataclassModel: + foo: FooDataclass + + class DataclassModelFactory(PydanticDataclassFactory[NestedDataclassModel]): + __model__ = NestedDataclassModel + + instance = DataclassModelFactory.build() + assert isinstance(instance, NestedDataclassModel) + assert isinstance(instance.foo, FooDataclass) + assert isinstance(instance.foo.content, int) + + +def test_pydantic_dataclass_factory_raises_for_std_dataclasses() -> None: + @dataclass + class DataclassModel: + a: int + b: str + + with pytest.raises(ConfigurationException): + + class DataclassModelFactory(PydanticDataclassFactory[DataclassModel]): + __model__ = DataclassModel