diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index d0a6dab4..47f4af35 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -61,9 +61,15 @@ # is installed. from pydantic import PyObject - # prevent unbound variable warnings + # Prevent unbound variable warnings BaseModelV2 = BaseModelV1 UndefinedV2 = Undefined + + if TYPE_CHECKING: + from pydantic.dataclasses import Dataclass as PydanticDataclassV1 # pyright: ignore[reportPrivateImportUsage] + + # Prevent unbound variable warnings + PydanticDataclassV2 = PydanticDataclassV1 except ImportError: # pydantic v2 @@ -92,6 +98,8 @@ from pydantic.v1.color import Color # type: ignore[assignment] from pydantic.v1.fields import DeferredType, ModelField, Undefined + if TYPE_CHECKING: + from pydantic.dataclasses import PydanticDataclass as PydanticDataclassV2 # pyright: ignore[reportPrivateImportUsage] if TYPE_CHECKING: from collections import abc @@ -100,7 +108,6 @@ from typing_extensions import NotRequired, TypeGuard - from pydantic.dataclasses import PydanticDataclass # pyright: ignore[reportPrivateImportUsage] ModelT = TypeVar("ModelT", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm] T = TypeVar("T") @@ -627,8 +634,11 @@ def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: # pyright: ign return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2) -def is_pydantic_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclass]: - # This method is available in the `pydantic.dataclasses` module for python >= 3.9 +def _is_pydantic_v1_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclassV1]: + return is_dataclass(cls) and "__pydantic_model__" in cls.__dict__ + + +def _is_pydantic_v2_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclassV2]: return is_dataclass(cls) and "__pydantic_validator__" in cls.__dict__ @@ -639,27 +649,37 @@ class PydanticDataclassFactory(ModelFactory[T]): # type: ignore[type-var] @classmethod def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: - return is_pydantic_dataclass(value) + return _is_pydantic_v1_dataclass(value) or _is_pydantic_v2_dataclass(value) @classmethod def get_model_fields(cls) -> list[FieldMeta]: - if not is_pydantic_dataclass(cls.__model__): + if _is_pydantic_v1_dataclass(cls.__model__): + pydantic_model = cls.__model__.__pydantic_model__ + cls._fields_metadata = [ + PydanticFieldMeta.from_model_field( + field, + use_alias=not pydantic_model.__config__.allow_population_by_field_name, # type: ignore[attr-defined] + random=cls.__random__, + ) + for field in pydantic_model.__fields__.values() + ] + elif _is_pydantic_v2_dataclass(cls.__model__): + pydantic_fields = cls.__model__.__pydantic_fields__ + pydantic_config = cls.__model__.__pydantic_config__ + cls._fields_metadata = [ + PydanticFieldMeta.from_field_info( + field_info=field_info, + field_name=field_name, + random=cls.__random__, + use_alias=not pydantic_config.get( + "populate_by_name", + False, + ), + ) + for field_name, field_info in pydantic_fields.items() + ] + else: # This should be unreachable return [] - pydantic_fields = cls.__model__.__pydantic_fields__ - pydantic_config = cls.__model__.__pydantic_config__ - cls._fields_metadata = [ - PydanticFieldMeta.from_field_info( - field_info=field_info, - field_name=field_name, - random=cls.__random__, - use_alias=not pydantic_config.get( - "populate_by_name", - False, - ), - ) - for field_name, field_info in pydantic_fields.items() - ] - return cls._fields_metadata