Skip to content

Commit

Permalink
fix: introduce pydantic v1/v2 code to hanble v1 dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
Slyces committed Dec 6, 2024
1 parent e9daae0 commit ddb77ff
Showing 1 changed file with 41 additions and 21 deletions.
62 changes: 41 additions & 21 deletions polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,15 @@
# is installed.
from pydantic import PyObject

# prevent unbound variable warnings
# Prevent unbound variable warnings
BaseModelV2 = BaseModelV1
UndefinedV2 = Undefined

if TYPE_CHECKING:
from pydantic.dataclasses import Dataclass as PydanticDataclassV1 # pyright: ignore[reportPrivateImportUsage]

# Prevent unbound variable warnings
PydanticDataclassV2 = PydanticDataclassV1
except ImportError:
# pydantic v2

Expand Down Expand Up @@ -92,6 +98,8 @@
from pydantic.v1.color import Color # type: ignore[assignment]
from pydantic.v1.fields import DeferredType, ModelField, Undefined

if TYPE_CHECKING:
from pydantic.dataclasses import PydanticDataclass as PydanticDataclassV2 # pyright: ignore[reportPrivateImportUsage]

if TYPE_CHECKING:
from collections import abc
Expand All @@ -100,7 +108,6 @@

from typing_extensions import NotRequired, TypeGuard

from pydantic.dataclasses import PydanticDataclass # pyright: ignore[reportPrivateImportUsage]

ModelT = TypeVar("ModelT", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm]
T = TypeVar("T")
Expand Down Expand Up @@ -627,8 +634,11 @@ def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: # pyright: ign
return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2)


def is_pydantic_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclass]:
# This method is available in the `pydantic.dataclasses` module for python >= 3.9
def _is_pydantic_v1_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclassV1]:
return is_dataclass(cls) and "__pydantic_model__" in cls.__dict__


def _is_pydantic_v2_dataclass(cls: type[Any]) -> TypeGuard[PydanticDataclassV2]:
return is_dataclass(cls) and "__pydantic_validator__" in cls.__dict__


Expand All @@ -639,27 +649,37 @@ class PydanticDataclassFactory(ModelFactory[T]): # type: ignore[type-var]

@classmethod
def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
return is_pydantic_dataclass(value)
return _is_pydantic_v1_dataclass(value) or _is_pydantic_v2_dataclass(value)

@classmethod
def get_model_fields(cls) -> list[FieldMeta]:
if not is_pydantic_dataclass(cls.__model__):
if _is_pydantic_v1_dataclass(cls.__model__):
pydantic_model = cls.__model__.__pydantic_model__
cls._fields_metadata = [
PydanticFieldMeta.from_model_field(
field,
use_alias=not pydantic_model.__config__.allow_population_by_field_name, # type: ignore[attr-defined]
random=cls.__random__,
)
for field in pydantic_model.__fields__.values()
]
elif _is_pydantic_v2_dataclass(cls.__model__):
pydantic_fields = cls.__model__.__pydantic_fields__
pydantic_config = cls.__model__.__pydantic_config__
cls._fields_metadata = [
PydanticFieldMeta.from_field_info(
field_info=field_info,
field_name=field_name,
random=cls.__random__,
use_alias=not pydantic_config.get(
"populate_by_name",
False,
),
)
for field_name, field_info in pydantic_fields.items()
]
else:
# This should be unreachable
return []

pydantic_fields = cls.__model__.__pydantic_fields__
pydantic_config = cls.__model__.__pydantic_config__
cls._fields_metadata = [
PydanticFieldMeta.from_field_info(
field_info=field_info,
field_name=field_name,
random=cls.__random__,
use_alias=not pydantic_config.get(
"populate_by_name",
False,
),
)
for field_name, field_info in pydantic_fields.items()
]

return cls._fields_metadata

0 comments on commit ddb77ff

Please sign in to comment.