From 8f49d33a491dd29dc721de44d659ae83ee9abf72 Mon Sep 17 00:00:00 2001 From: mityuha Date: Wed, 6 Dec 2023 20:35:12 +0100 Subject: [PATCH 01/11] add optional model field tests --- tests/optional_model_field/__init__.py | 0 .../test_attrs_factory.py | 48 ++++++++++ .../test_beanie_factory.py | 48 ++++++++++ .../test_generic_class_inference.py | 17 ++++ .../test_model_inference_error.py | 59 ++++++++++++ .../test_msgspec_factory.py | 25 ++++++ .../test_odmantic_factory.py | 90 +++++++++++++++++++ .../test_pydantic_factory.py | 21 +++++ .../test_sqlalchemy_factory.py | 44 +++++++++ .../test_typeddict_factory.py | 26 ++++++ 10 files changed, 378 insertions(+) create mode 100644 tests/optional_model_field/__init__.py create mode 100644 tests/optional_model_field/test_attrs_factory.py create mode 100644 tests/optional_model_field/test_beanie_factory.py create mode 100644 tests/optional_model_field/test_generic_class_inference.py create mode 100644 tests/optional_model_field/test_model_inference_error.py create mode 100644 tests/optional_model_field/test_msgspec_factory.py create mode 100644 tests/optional_model_field/test_odmantic_factory.py create mode 100644 tests/optional_model_field/test_pydantic_factory.py create mode 100644 tests/optional_model_field/test_sqlalchemy_factory.py create mode 100644 tests/optional_model_field/test_typeddict_factory.py diff --git a/tests/optional_model_field/__init__.py b/tests/optional_model_field/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/optional_model_field/test_attrs_factory.py b/tests/optional_model_field/test_attrs_factory.py new file mode 100644 index 00000000..5578f7b3 --- /dev/null +++ b/tests/optional_model_field/test_attrs_factory.py @@ -0,0 +1,48 @@ +import datetime as dt +from decimal import Decimal +from enum import Enum +from typing import Any, Dict, List, Tuple +from uuid import UUID + +import pytest +from attrs import asdict, define + +from polyfactory.factories.attrs_factory import AttrsFactory + +pytestmark = [pytest.mark.attrs] + + +def test_with_basic_types_annotated() -> None: + class SampleEnum(Enum): + FOO = "foo" + BAR = "bar" + + @define + class Foo: + bool_field: bool + int_field: int + float_field: float + str_field: str + bytse_field: bytes + bytearray_field: bytearray + tuple_field: Tuple[int, str] + tuple_with_variadic_args: Tuple[int, ...] + list_field: List[int] + dict_field: Dict[str, int] + datetime_field: dt.datetime + date_field: dt.date + time_field: dt.time + uuid_field: UUID + decimal_field: Decimal + enum_type: SampleEnum + any_type: Any + + class FooFactory(AttrsFactory[Foo]): + ... + + assert getattr(FooFactory, "__model__") is Foo + + foo: Foo = FooFactory.build() + foo_dict = asdict(foo) + + assert foo == Foo(**foo_dict) diff --git a/tests/optional_model_field/test_beanie_factory.py b/tests/optional_model_field/test_beanie_factory.py new file mode 100644 index 00000000..5d3ce028 --- /dev/null +++ b/tests/optional_model_field/test_beanie_factory.py @@ -0,0 +1,48 @@ +from typing import Callable, List + +import pymongo +import pytest + +try: + from beanie import Document, init_beanie + from beanie.odm.fields import Indexed, PydanticObjectId + from mongomock_motor import AsyncMongoMockClient + + from polyfactory.factories.beanie_odm_factory import BeanieDocumentFactory +except ImportError: + pytest.importorskip("beanie") + + BeanieDocumentFactory = None # type: ignore + Document = None # type: ignore + init_beanie = None # type: ignore + Indexed = None # type: ignore + PydanticObjectId = None # type: ignore + + +@pytest.fixture() +def mongo_connection() -> AsyncMongoMockClient: + return AsyncMongoMockClient() + + +class MyDocument(Document): + id: PydanticObjectId + name: str + index: Indexed(str, pymongo.DESCENDING) # type: ignore + siblings: List[PydanticObjectId] + + +class MyFactory(BeanieDocumentFactory[MyDocument]): + ... + + +@pytest.fixture() +async def beanie_init(mongo_connection: AsyncMongoMockClient) -> None: + await init_beanie(database=mongo_connection.db_name, document_models=[MyDocument]) # type: ignore + + +async def test_handling_of_beanie_types(beanie_init: Callable) -> None: + assert getattr(MyFactory, "__model__") is MyDocument + result: MyDocument = MyFactory.build() + assert result.name + assert result.index + assert isinstance(result.index, str) diff --git a/tests/optional_model_field/test_generic_class_inference.py b/tests/optional_model_field/test_generic_class_inference.py new file mode 100644 index 00000000..ba2f0367 --- /dev/null +++ b/tests/optional_model_field/test_generic_class_inference.py @@ -0,0 +1,17 @@ +from typing import Generic, TypeVar + +from pydantic import BaseModel + +from polyfactory.factories.pydantic_factory import ModelFactory + + +def test_generic_model_is_not_an_error() -> None: + T = TypeVar("T") + + class Foo(BaseModel, Generic[T]): + val: T + + class FooFactory(ModelFactory[Foo[str]]): + ... + + assert isinstance(FooFactory.build().val, str) diff --git a/tests/optional_model_field/test_model_inference_error.py b/tests/optional_model_field/test_model_inference_error.py new file mode 100644 index 00000000..af211510 --- /dev/null +++ b/tests/optional_model_field/test_model_inference_error.py @@ -0,0 +1,59 @@ +from typing import Type, TypedDict + +import pytest +from pydantic import BaseModel + +from polyfactory import ConfigurationException +from polyfactory.factories import TypedDictFactory +from polyfactory.factories.attrs_factory import AttrsFactory +from polyfactory.factories.base import BaseFactory +from polyfactory.factories.msgspec_factory import MsgspecFactory +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory + + +@pytest.mark.parametrize( + "base_factory", + [ + AttrsFactory, + ModelFactory, + MsgspecFactory, + SQLAlchemyFactory, + TypedDictFactory, + ], +) +def test_model_without_generic_type_inference_error(base_factory: Type[BaseFactory]) -> None: + with pytest.raises(ConfigurationException): + + class Foo(base_factory): # type: ignore + ... + + +@pytest.mark.parametrize( + "base_factory", + [ + AttrsFactory, + ModelFactory, + MsgspecFactory, + SQLAlchemyFactory, + TypedDictFactory, + ], +) +def test_model_type_error(base_factory: Type[BaseFactory]) -> None: + with pytest.raises(ConfigurationException): + + class Foo(base_factory[int]): # type: ignore + ... + + +def test_model_multiple_inheritance_cannot_infer_error() -> None: + class PFoo(BaseModel): + val: int + + class TDFoo(TypedDict): + val: str + + with pytest.raises(ConfigurationException): + + class Foo(ModelFactory[PFoo], TypedDictFactory[TDFoo]): # type: ignore + ... diff --git a/tests/optional_model_field/test_msgspec_factory.py b/tests/optional_model_field/test_msgspec_factory.py new file mode 100644 index 00000000..85287dfd --- /dev/null +++ b/tests/optional_model_field/test_msgspec_factory.py @@ -0,0 +1,25 @@ +import msgspec +from msgspec import Struct, structs + +from polyfactory.factories.msgspec_factory import MsgspecFactory + + +def test_with_nested_struct() -> None: + class Foo(Struct): + int_field: int + + class Bar(Struct): + int_field: int + foo_field: Foo + + class BarFactory(MsgspecFactory[Bar]): + ... + + assert getattr(BarFactory, "__model__") is Bar + + bar: Bar = BarFactory.build() + bar_dict = structs.asdict(bar) + bar_dict["foo_field"] = structs.asdict(bar_dict["foo_field"]) + + validated_bar = msgspec.convert(bar_dict, type=Bar) + assert validated_bar == bar diff --git a/tests/optional_model_field/test_odmantic_factory.py b/tests/optional_model_field/test_odmantic_factory.py new file mode 100644 index 00000000..44216b24 --- /dev/null +++ b/tests/optional_model_field/test_odmantic_factory.py @@ -0,0 +1,90 @@ +from datetime import datetime +from typing import Any, List +from uuid import UUID + +import bson +import pytest + +try: + from odmantic import AIOEngine, EmbeddedModel, Model + + from polyfactory.factories.odmantic_odm_factory import OdmanticModelFactory +except ImportError: + AIOEngine, EmbeddedModel, Model, OdmanticModelFactory = None, None, None, None # type: ignore + pytest.importorskip("odmantic") + + +class OtherEmbeddedDocument(EmbeddedModel): # type: ignore + name: str + serial: UUID + created_on: datetime + bson_id: bson.ObjectId + bson_int64: bson.Int64 + bson_dec128: bson.Decimal128 + bson_binary: bson.Binary + + +class MyEmbeddedDocument(EmbeddedModel): # type: ignore + name: str + serial: UUID + other_embedded_document: OtherEmbeddedDocument + created_on: datetime + bson_id: bson.ObjectId + bson_int64: bson.Int64 + bson_dec128: bson.Decimal128 + bson_binary: bson.Binary + + +class MyModel(Model): # type: ignore + created_on: datetime + bson_id: bson.ObjectId + bson_int64: bson.Int64 + bson_dec128: bson.Decimal128 + bson_binary: bson.Binary + name: str + embedded: MyEmbeddedDocument + embedded_list: List[MyEmbeddedDocument] + + +@pytest.fixture() +async def odmantic_engine(mongo_connection: Any) -> AIOEngine: + return AIOEngine(client=mongo_connection, database=mongo_connection.db_name) + + +def test_handles_odmantic_models() -> None: + class MyFactory(OdmanticModelFactory[MyModel]): + ... + + assert getattr(MyFactory, "__model__") is MyModel + + result: MyModel = MyFactory.build() + + assert isinstance(result, MyModel) + assert isinstance(result.id, bson.ObjectId) + assert isinstance(result.created_on, datetime) + assert isinstance(result.bson_id, bson.ObjectId) + assert isinstance(result.bson_int64, bson.Int64) + assert isinstance(result.bson_dec128, bson.Decimal128) + assert isinstance(result.bson_binary, bson.Binary) + assert isinstance(result.name, str) + assert isinstance(result.embedded, MyEmbeddedDocument) + assert isinstance(result.embedded_list, list) + for item in result.embedded_list: + assert isinstance(item, MyEmbeddedDocument) + assert isinstance(item.name, str) + assert isinstance(item.serial, UUID) + assert isinstance(item.created_on, datetime) + assert isinstance(item.bson_id, bson.ObjectId) + assert isinstance(item.bson_int64, bson.Int64) + assert isinstance(item.bson_dec128, bson.Decimal128) + assert isinstance(item.bson_binary, bson.Binary) + + other = item.other_embedded_document + assert isinstance(other, OtherEmbeddedDocument) + assert isinstance(other.name, str) + assert isinstance(other.serial, UUID) + assert isinstance(other.created_on, datetime) + assert isinstance(other.bson_id, bson.ObjectId) + assert isinstance(other.bson_int64, bson.Int64) + assert isinstance(other.bson_dec128, bson.Decimal128) + assert isinstance(other.bson_binary, bson.Binary) diff --git a/tests/optional_model_field/test_pydantic_factory.py b/tests/optional_model_field/test_pydantic_factory.py new file mode 100644 index 00000000..b936eb34 --- /dev/null +++ b/tests/optional_model_field/test_pydantic_factory.py @@ -0,0 +1,21 @@ +from typing import Dict + +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from polyfactory.factories.pydantic_factory import ModelFactory + + +def test_mapping_with_annotated_item_types() -> None: + ConstrainedInt = Annotated[int, Field(ge=100, le=200)] + ConstrainedStr = Annotated[str, Field(min_length=1, max_length=3)] + + class Foo(BaseModel): + dict_field: Dict[ConstrainedStr, ConstrainedInt] + + class FooFactory(ModelFactory[Foo]): + ... + + assert getattr(FooFactory, "__model__") is Foo + + assert FooFactory.build() diff --git a/tests/optional_model_field/test_sqlalchemy_factory.py b/tests/optional_model_field/test_sqlalchemy_factory.py new file mode 100644 index 00000000..3a4395fe --- /dev/null +++ b/tests/optional_model_field/test_sqlalchemy_factory.py @@ -0,0 +1,44 @@ +from enum import Enum +from typing import Any + +from sqlalchemy import Column, Integer, String, types +from sqlalchemy.orm.decl_api import DeclarativeMeta, registry + +from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory + + +def test_python_type_handling() -> None: + _registry = registry() + + class Base(metaclass=DeclarativeMeta): + __abstract__ = True + + registry = _registry + metadata = _registry.metadata + + class Animal(str, Enum): + DOG = "Dog" + CAT = "Cat" + + class Model(Base): + __tablename__ = "model" + + id: Any = Column(Integer(), primary_key=True) + str_type: Any = Column(String(), nullable=False) + enum_type: Any = Column(types.Enum(Animal), nullable=False) + str_array_type: Any = Column( + types.ARRAY(types.String), + nullable=False, + ) + + class ModelFactory(SQLAlchemyFactory[Model]): + ... + + assert getattr(ModelFactory, "__model__") is Model + + instance: Model = ModelFactory.build() + assert isinstance(instance.id, int) + assert isinstance(instance.str_type, str) + assert isinstance(instance.enum_type, Animal) + assert isinstance(instance.str_array_type, list) + assert isinstance(instance.str_array_type[0], str) diff --git a/tests/optional_model_field/test_typeddict_factory.py b/tests/optional_model_field/test_typeddict_factory.py new file mode 100644 index 00000000..579f3eca --- /dev/null +++ b/tests/optional_model_field/test_typeddict_factory.py @@ -0,0 +1,26 @@ +from typing import Dict, List, Optional + +from typing_extensions import TypedDict + +from polyfactory.factories import TypedDictFactory + + +class TypedDictModel(TypedDict): + id: int + name: str + list_field: List[Dict[str, int]] + int_field: Optional[int] + + +def test_factory_with_typeddict() -> None: + class MyFactory(TypedDictFactory[TypedDictModel]): + ... + + assert getattr(MyFactory, "__model__") is TypedDictModel + result: TypedDictModel = MyFactory.build() + + assert isinstance(result, dict) + assert result["id"] + assert result["name"] + assert result["list_field"][0] + assert type(result["int_field"]) in (type(None), int) From c6f544a2b0d6ec5ac1ecc3aaf7873cb22882f6c0 Mon Sep 17 00:00:00 2001 From: mityuha Date: Wed, 6 Dec 2023 20:58:10 +0100 Subject: [PATCH 02/11] implement infer_model_type classmethod --- polyfactory/factories/base.py | 31 +++++++++++++++++++++++++++---- pyproject.toml | 4 ++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 07a6c97d..f8b1d812 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -1,6 +1,5 @@ from __future__ import annotations -import typing from abc import ABC, abstractmethod from collections import Counter, abc, deque from contextlib import suppress @@ -36,7 +35,9 @@ ClassVar, Collection, Generic, + Iterable, Mapping, + Optional, Sequence, Type, TypeVar, @@ -45,7 +46,7 @@ from uuid import UUID from faker import Faker -from typing_extensions import get_args +from typing_extensions import get_args, get_origin from polyfactory.constants import ( DEFAULT_RANDOM, @@ -190,12 +191,13 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # noqa: C901 ) if "__is_base_factory__" not in cls.__dict__ or not cls.__is_base_factory__: - model = getattr(cls, "__model__", None) + model: Optional[type[T]] = getattr(cls, "__model__", None) or cls._infer_model_type() if not model: msg = f"required configuration attribute '__model__' is not set on {cls.__name__}" raise ConfigurationException( msg, ) + cls.__model__ = model if not cls.is_supported_type(model): for factory in BaseFactory._base_factories: if factory.is_supported_type(model): @@ -219,6 +221,27 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # noqa: C901 if cls.__set_as_default_factory_for_type__: BaseFactory._factory_type_mapping[cls.__model__] = cls + @classmethod + def _infer_model_type(cls: type[F]) -> Optional[type[T]]: + """Return model type inferred from class declaration. + class Foo(ModelFactory[MyModel]): # <<< MyModel + ... + + If more than one base class and/or generic arguments specified return None. + + :returns: Inferred model type or None + """ + factory_bases: Iterable[type[BaseFactory[T]]] = ( + b for b in getattr(cls, "__orig_bases__", []) if issubclass(get_origin(b), BaseFactory) + ) + generic_args: Sequence[type[T]] = [ + arg for factory_base in factory_bases for arg in get_args(factory_base) if not isinstance(arg, TypeVar) + ] + if len(generic_args) != 1: + return None + + return generic_args[0] + @classmethod def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]: """Return a SyncPersistenceHandler if defined for the factory, otherwise raises a ConfigurationException. @@ -676,7 +699,7 @@ def get_field_value_coverage( # noqa: C901 cls, field_meta: FieldMeta, field_build_parameters: Any | None = None, - ) -> typing.Iterable[Any]: + ) -> Iterable[Any]: """Return a field value on the subclass if existing, otherwise returns a mock value. :param field_meta: FieldMeta instance. diff --git a/pyproject.toml b/pyproject.toml index 31544fac..6faf9e7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -189,6 +189,10 @@ line-length = 120 src = ["polyfactory", "tests", "docs/examples"] target-version = "py38" +[tool.ruff.lint.pyupgrade] +# Preserve types, even if a file imports `from __future__ import annotations`. +keep-runtime-typing = true + [tool.ruff.pydocstyle] convention = "google" From 4f660c408de65a15a2294538390c1ea07add332b Mon Sep 17 00:00:00 2001 From: mityuha Date: Fri, 8 Dec 2023 14:41:00 +0100 Subject: [PATCH 03/11] base model does not support generics within 1.10 --- .../test_generic_class_inference.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/optional_model_field/test_generic_class_inference.py b/tests/optional_model_field/test_generic_class_inference.py index ba2f0367..d26f87fa 100644 --- a/tests/optional_model_field/test_generic_class_inference.py +++ b/tests/optional_model_field/test_generic_class_inference.py @@ -1,17 +1,20 @@ from typing import Generic, TypeVar -from pydantic import BaseModel +from pydantic.generics import GenericModel from polyfactory.factories.pydantic_factory import ModelFactory def test_generic_model_is_not_an_error() -> None: T = TypeVar("T") + P = TypeVar("P") - class Foo(BaseModel, Generic[T]): - val: T + class Foo(GenericModel, Generic[T, P]): # type: ignore[misc] + val1: T + val2: P - class FooFactory(ModelFactory[Foo[str]]): + class FooFactory(ModelFactory[Foo[str, int]]): ... - assert isinstance(FooFactory.build().val, str) + assert isinstance(FooFactory.build().val1, str) + assert isinstance(FooFactory.build().val2, int) From 712f003ed8357f9ffebba671de430a3b52c4cc3a Mon Sep 17 00:00:00 2001 From: mityuha Date: Sat, 9 Dec 2023 16:50:22 +0100 Subject: [PATCH 04/11] fix x | y syntax --- polyfactory/factories/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index f8b1d812..7461688f 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -191,7 +191,7 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # noqa: C901 ) if "__is_base_factory__" not in cls.__dict__ or not cls.__is_base_factory__: - model: Optional[type[T]] = getattr(cls, "__model__", None) or cls._infer_model_type() + model: type[T] | None = getattr(cls, "__model__", None) or cls._infer_model_type() if not model: msg = f"required configuration attribute '__model__' is not set on {cls.__name__}" raise ConfigurationException( From 1bda552a93b553b01ab2228455b3dba9391329df Mon Sep 17 00:00:00 2001 From: mityuha Date: Sat, 9 Dec 2023 17:21:16 +0100 Subject: [PATCH 05/11] the use of get_original_bases --- polyfactory/factories/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 7461688f..f4aeffb2 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -46,7 +46,7 @@ from uuid import UUID from faker import Faker -from typing_extensions import get_args, get_origin +from typing_extensions import get_args, get_origin, get_original_bases from polyfactory.constants import ( DEFAULT_RANDOM, @@ -231,8 +231,9 @@ class Foo(ModelFactory[MyModel]): # <<< MyModel :returns: Inferred model type or None """ + factory_bases: Iterable[type[BaseFactory[T]]] = ( - b for b in getattr(cls, "__orig_bases__", []) if issubclass(get_origin(b), BaseFactory) + b for b in get_original_bases(cls) if get_origin(b) and issubclass(get_origin(b), BaseFactory) ) generic_args: Sequence[type[T]] = [ arg for factory_base in factory_bases for arg in get_args(factory_base) if not isinstance(arg, TypeVar) From 798920cba6a0ca2638955789589b99f27394b99b Mon Sep 17 00:00:00 2001 From: mityuha Date: Sat, 9 Dec 2023 17:22:31 +0100 Subject: [PATCH 06/11] revert ruff config back --- pyproject.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6faf9e7a..31544fac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -189,10 +189,6 @@ line-length = 120 src = ["polyfactory", "tests", "docs/examples"] target-version = "py38" -[tool.ruff.lint.pyupgrade] -# Preserve types, even if a file imports `from __future__ import annotations`. -keep-runtime-typing = true - [tool.ruff.pydocstyle] convention = "google" From 587e10d709ec917cb2a95b4c3c738e7cdbc547e8 Mon Sep 17 00:00:00 2001 From: mityuha Date: Sat, 9 Dec 2023 18:04:53 +0100 Subject: [PATCH 07/11] move optional model field tests into single file --- tests/optional_model_field/__init__.py | 0 .../test_attrs_factory.py | 48 ------ .../test_beanie_factory.py | 48 ------ .../test_generic_class_inference.py | 20 --- .../test_model_inference_error.py | 59 ------- .../test_msgspec_factory.py | 25 --- .../test_odmantic_factory.py | 90 ---------- .../test_pydantic_factory.py | 21 --- .../test_sqlalchemy_factory.py | 44 ----- .../test_typeddict_factory.py | 26 --- tests/test_optional_model_field_inference.py | 161 ++++++++++++++++++ 11 files changed, 161 insertions(+), 381 deletions(-) delete mode 100644 tests/optional_model_field/__init__.py delete mode 100644 tests/optional_model_field/test_attrs_factory.py delete mode 100644 tests/optional_model_field/test_beanie_factory.py delete mode 100644 tests/optional_model_field/test_generic_class_inference.py delete mode 100644 tests/optional_model_field/test_model_inference_error.py delete mode 100644 tests/optional_model_field/test_msgspec_factory.py delete mode 100644 tests/optional_model_field/test_odmantic_factory.py delete mode 100644 tests/optional_model_field/test_pydantic_factory.py delete mode 100644 tests/optional_model_field/test_sqlalchemy_factory.py delete mode 100644 tests/optional_model_field/test_typeddict_factory.py create mode 100644 tests/test_optional_model_field_inference.py diff --git a/tests/optional_model_field/__init__.py b/tests/optional_model_field/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/optional_model_field/test_attrs_factory.py b/tests/optional_model_field/test_attrs_factory.py deleted file mode 100644 index 5578f7b3..00000000 --- a/tests/optional_model_field/test_attrs_factory.py +++ /dev/null @@ -1,48 +0,0 @@ -import datetime as dt -from decimal import Decimal -from enum import Enum -from typing import Any, Dict, List, Tuple -from uuid import UUID - -import pytest -from attrs import asdict, define - -from polyfactory.factories.attrs_factory import AttrsFactory - -pytestmark = [pytest.mark.attrs] - - -def test_with_basic_types_annotated() -> None: - class SampleEnum(Enum): - FOO = "foo" - BAR = "bar" - - @define - class Foo: - bool_field: bool - int_field: int - float_field: float - str_field: str - bytse_field: bytes - bytearray_field: bytearray - tuple_field: Tuple[int, str] - tuple_with_variadic_args: Tuple[int, ...] - list_field: List[int] - dict_field: Dict[str, int] - datetime_field: dt.datetime - date_field: dt.date - time_field: dt.time - uuid_field: UUID - decimal_field: Decimal - enum_type: SampleEnum - any_type: Any - - class FooFactory(AttrsFactory[Foo]): - ... - - assert getattr(FooFactory, "__model__") is Foo - - foo: Foo = FooFactory.build() - foo_dict = asdict(foo) - - assert foo == Foo(**foo_dict) diff --git a/tests/optional_model_field/test_beanie_factory.py b/tests/optional_model_field/test_beanie_factory.py deleted file mode 100644 index 5d3ce028..00000000 --- a/tests/optional_model_field/test_beanie_factory.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Callable, List - -import pymongo -import pytest - -try: - from beanie import Document, init_beanie - from beanie.odm.fields import Indexed, PydanticObjectId - from mongomock_motor import AsyncMongoMockClient - - from polyfactory.factories.beanie_odm_factory import BeanieDocumentFactory -except ImportError: - pytest.importorskip("beanie") - - BeanieDocumentFactory = None # type: ignore - Document = None # type: ignore - init_beanie = None # type: ignore - Indexed = None # type: ignore - PydanticObjectId = None # type: ignore - - -@pytest.fixture() -def mongo_connection() -> AsyncMongoMockClient: - return AsyncMongoMockClient() - - -class MyDocument(Document): - id: PydanticObjectId - name: str - index: Indexed(str, pymongo.DESCENDING) # type: ignore - siblings: List[PydanticObjectId] - - -class MyFactory(BeanieDocumentFactory[MyDocument]): - ... - - -@pytest.fixture() -async def beanie_init(mongo_connection: AsyncMongoMockClient) -> None: - await init_beanie(database=mongo_connection.db_name, document_models=[MyDocument]) # type: ignore - - -async def test_handling_of_beanie_types(beanie_init: Callable) -> None: - assert getattr(MyFactory, "__model__") is MyDocument - result: MyDocument = MyFactory.build() - assert result.name - assert result.index - assert isinstance(result.index, str) diff --git a/tests/optional_model_field/test_generic_class_inference.py b/tests/optional_model_field/test_generic_class_inference.py deleted file mode 100644 index d26f87fa..00000000 --- a/tests/optional_model_field/test_generic_class_inference.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Generic, TypeVar - -from pydantic.generics import GenericModel - -from polyfactory.factories.pydantic_factory import ModelFactory - - -def test_generic_model_is_not_an_error() -> None: - T = TypeVar("T") - P = TypeVar("P") - - class Foo(GenericModel, Generic[T, P]): # type: ignore[misc] - val1: T - val2: P - - class FooFactory(ModelFactory[Foo[str, int]]): - ... - - assert isinstance(FooFactory.build().val1, str) - assert isinstance(FooFactory.build().val2, int) diff --git a/tests/optional_model_field/test_model_inference_error.py b/tests/optional_model_field/test_model_inference_error.py deleted file mode 100644 index af211510..00000000 --- a/tests/optional_model_field/test_model_inference_error.py +++ /dev/null @@ -1,59 +0,0 @@ -from typing import Type, TypedDict - -import pytest -from pydantic import BaseModel - -from polyfactory import ConfigurationException -from polyfactory.factories import TypedDictFactory -from polyfactory.factories.attrs_factory import AttrsFactory -from polyfactory.factories.base import BaseFactory -from polyfactory.factories.msgspec_factory import MsgspecFactory -from polyfactory.factories.pydantic_factory import ModelFactory -from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory - - -@pytest.mark.parametrize( - "base_factory", - [ - AttrsFactory, - ModelFactory, - MsgspecFactory, - SQLAlchemyFactory, - TypedDictFactory, - ], -) -def test_model_without_generic_type_inference_error(base_factory: Type[BaseFactory]) -> None: - with pytest.raises(ConfigurationException): - - class Foo(base_factory): # type: ignore - ... - - -@pytest.mark.parametrize( - "base_factory", - [ - AttrsFactory, - ModelFactory, - MsgspecFactory, - SQLAlchemyFactory, - TypedDictFactory, - ], -) -def test_model_type_error(base_factory: Type[BaseFactory]) -> None: - with pytest.raises(ConfigurationException): - - class Foo(base_factory[int]): # type: ignore - ... - - -def test_model_multiple_inheritance_cannot_infer_error() -> None: - class PFoo(BaseModel): - val: int - - class TDFoo(TypedDict): - val: str - - with pytest.raises(ConfigurationException): - - class Foo(ModelFactory[PFoo], TypedDictFactory[TDFoo]): # type: ignore - ... diff --git a/tests/optional_model_field/test_msgspec_factory.py b/tests/optional_model_field/test_msgspec_factory.py deleted file mode 100644 index 85287dfd..00000000 --- a/tests/optional_model_field/test_msgspec_factory.py +++ /dev/null @@ -1,25 +0,0 @@ -import msgspec -from msgspec import Struct, structs - -from polyfactory.factories.msgspec_factory import MsgspecFactory - - -def test_with_nested_struct() -> None: - class Foo(Struct): - int_field: int - - class Bar(Struct): - int_field: int - foo_field: Foo - - class BarFactory(MsgspecFactory[Bar]): - ... - - assert getattr(BarFactory, "__model__") is Bar - - bar: Bar = BarFactory.build() - bar_dict = structs.asdict(bar) - bar_dict["foo_field"] = structs.asdict(bar_dict["foo_field"]) - - validated_bar = msgspec.convert(bar_dict, type=Bar) - assert validated_bar == bar diff --git a/tests/optional_model_field/test_odmantic_factory.py b/tests/optional_model_field/test_odmantic_factory.py deleted file mode 100644 index 44216b24..00000000 --- a/tests/optional_model_field/test_odmantic_factory.py +++ /dev/null @@ -1,90 +0,0 @@ -from datetime import datetime -from typing import Any, List -from uuid import UUID - -import bson -import pytest - -try: - from odmantic import AIOEngine, EmbeddedModel, Model - - from polyfactory.factories.odmantic_odm_factory import OdmanticModelFactory -except ImportError: - AIOEngine, EmbeddedModel, Model, OdmanticModelFactory = None, None, None, None # type: ignore - pytest.importorskip("odmantic") - - -class OtherEmbeddedDocument(EmbeddedModel): # type: ignore - name: str - serial: UUID - created_on: datetime - bson_id: bson.ObjectId - bson_int64: bson.Int64 - bson_dec128: bson.Decimal128 - bson_binary: bson.Binary - - -class MyEmbeddedDocument(EmbeddedModel): # type: ignore - name: str - serial: UUID - other_embedded_document: OtherEmbeddedDocument - created_on: datetime - bson_id: bson.ObjectId - bson_int64: bson.Int64 - bson_dec128: bson.Decimal128 - bson_binary: bson.Binary - - -class MyModel(Model): # type: ignore - created_on: datetime - bson_id: bson.ObjectId - bson_int64: bson.Int64 - bson_dec128: bson.Decimal128 - bson_binary: bson.Binary - name: str - embedded: MyEmbeddedDocument - embedded_list: List[MyEmbeddedDocument] - - -@pytest.fixture() -async def odmantic_engine(mongo_connection: Any) -> AIOEngine: - return AIOEngine(client=mongo_connection, database=mongo_connection.db_name) - - -def test_handles_odmantic_models() -> None: - class MyFactory(OdmanticModelFactory[MyModel]): - ... - - assert getattr(MyFactory, "__model__") is MyModel - - result: MyModel = MyFactory.build() - - assert isinstance(result, MyModel) - assert isinstance(result.id, bson.ObjectId) - assert isinstance(result.created_on, datetime) - assert isinstance(result.bson_id, bson.ObjectId) - assert isinstance(result.bson_int64, bson.Int64) - assert isinstance(result.bson_dec128, bson.Decimal128) - assert isinstance(result.bson_binary, bson.Binary) - assert isinstance(result.name, str) - assert isinstance(result.embedded, MyEmbeddedDocument) - assert isinstance(result.embedded_list, list) - for item in result.embedded_list: - assert isinstance(item, MyEmbeddedDocument) - assert isinstance(item.name, str) - assert isinstance(item.serial, UUID) - assert isinstance(item.created_on, datetime) - assert isinstance(item.bson_id, bson.ObjectId) - assert isinstance(item.bson_int64, bson.Int64) - assert isinstance(item.bson_dec128, bson.Decimal128) - assert isinstance(item.bson_binary, bson.Binary) - - other = item.other_embedded_document - assert isinstance(other, OtherEmbeddedDocument) - assert isinstance(other.name, str) - assert isinstance(other.serial, UUID) - assert isinstance(other.created_on, datetime) - assert isinstance(other.bson_id, bson.ObjectId) - assert isinstance(other.bson_int64, bson.Int64) - assert isinstance(other.bson_dec128, bson.Decimal128) - assert isinstance(other.bson_binary, bson.Binary) diff --git a/tests/optional_model_field/test_pydantic_factory.py b/tests/optional_model_field/test_pydantic_factory.py deleted file mode 100644 index b936eb34..00000000 --- a/tests/optional_model_field/test_pydantic_factory.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Dict - -from pydantic import BaseModel, Field -from typing_extensions import Annotated - -from polyfactory.factories.pydantic_factory import ModelFactory - - -def test_mapping_with_annotated_item_types() -> None: - ConstrainedInt = Annotated[int, Field(ge=100, le=200)] - ConstrainedStr = Annotated[str, Field(min_length=1, max_length=3)] - - class Foo(BaseModel): - dict_field: Dict[ConstrainedStr, ConstrainedInt] - - class FooFactory(ModelFactory[Foo]): - ... - - assert getattr(FooFactory, "__model__") is Foo - - assert FooFactory.build() diff --git a/tests/optional_model_field/test_sqlalchemy_factory.py b/tests/optional_model_field/test_sqlalchemy_factory.py deleted file mode 100644 index 3a4395fe..00000000 --- a/tests/optional_model_field/test_sqlalchemy_factory.py +++ /dev/null @@ -1,44 +0,0 @@ -from enum import Enum -from typing import Any - -from sqlalchemy import Column, Integer, String, types -from sqlalchemy.orm.decl_api import DeclarativeMeta, registry - -from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory - - -def test_python_type_handling() -> None: - _registry = registry() - - class Base(metaclass=DeclarativeMeta): - __abstract__ = True - - registry = _registry - metadata = _registry.metadata - - class Animal(str, Enum): - DOG = "Dog" - CAT = "Cat" - - class Model(Base): - __tablename__ = "model" - - id: Any = Column(Integer(), primary_key=True) - str_type: Any = Column(String(), nullable=False) - enum_type: Any = Column(types.Enum(Animal), nullable=False) - str_array_type: Any = Column( - types.ARRAY(types.String), - nullable=False, - ) - - class ModelFactory(SQLAlchemyFactory[Model]): - ... - - assert getattr(ModelFactory, "__model__") is Model - - instance: Model = ModelFactory.build() - assert isinstance(instance.id, int) - assert isinstance(instance.str_type, str) - assert isinstance(instance.enum_type, Animal) - assert isinstance(instance.str_array_type, list) - assert isinstance(instance.str_array_type[0], str) diff --git a/tests/optional_model_field/test_typeddict_factory.py b/tests/optional_model_field/test_typeddict_factory.py deleted file mode 100644 index 579f3eca..00000000 --- a/tests/optional_model_field/test_typeddict_factory.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Dict, List, Optional - -from typing_extensions import TypedDict - -from polyfactory.factories import TypedDictFactory - - -class TypedDictModel(TypedDict): - id: int - name: str - list_field: List[Dict[str, int]] - int_field: Optional[int] - - -def test_factory_with_typeddict() -> None: - class MyFactory(TypedDictFactory[TypedDictModel]): - ... - - assert getattr(MyFactory, "__model__") is TypedDictModel - result: TypedDictModel = MyFactory.build() - - assert isinstance(result, dict) - assert result["id"] - assert result["name"] - assert result["list_field"][0] - assert type(result["int_field"]) in (type(None), int) diff --git a/tests/test_optional_model_field_inference.py b/tests/test_optional_model_field_inference.py new file mode 100644 index 00000000..3aeefda7 --- /dev/null +++ b/tests/test_optional_model_field_inference.py @@ -0,0 +1,161 @@ +from typing import Any, Dict, Generic, Type, TypedDict, TypeVar + +import pytest +from attrs import define +from msgspec import Struct +from pydantic import BaseModel +from pydantic.generics import GenericModel +from sqlalchemy import Column, Integer +from sqlalchemy.orm.decl_api import DeclarativeMeta, registry + +from polyfactory import ConfigurationException +from polyfactory.factories import TypedDictFactory +from polyfactory.factories.attrs_factory import AttrsFactory +from polyfactory.factories.base import BaseFactory +from polyfactory.factories.msgspec_factory import MsgspecFactory +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory + +try: + from odmantic import Model + + from polyfactory.factories.odmantic_odm_factory import OdmanticModelFactory +except ImportError: + Model, OdmanticModelFactory = None, None # type: ignore + +try: + from beanie import Document + + from polyfactory.factories.beanie_odm_factory import BeanieDocumentFactory +except ImportError: + BeanieDocumentFactory = None # type: ignore + Document = None # type: ignore + + +@define +class AttrsBase: + bool_field: bool + + +class ModelBase(BaseModel): + dict_field: Dict[str, int] + + +class MsgspecBase(Struct): + int_field: int + + +class Base(metaclass=DeclarativeMeta): + __abstract__ = True + + registry = registry() + + +class SQLAlchemyBase(Base): + __tablename__ = "model" + + id: Any = Column(Integer(), primary_key=True) + + +class TypedDictBase(TypedDict): + name: str + + +@pytest.mark.parametrize( + "base_factory, generic_arg", + [ + (AttrsFactory, AttrsBase), + (ModelFactory, ModelBase), + (MsgspecFactory, MsgspecBase), + (SQLAlchemyFactory, SQLAlchemyBase), + (TypedDictFactory, TypedDictBase), + ], +) +def test_modeL_inference_ok(base_factory: Type[BaseFactory], generic_arg: Type[Any]) -> None: + class Foo(base_factory[generic_arg]): # type: ignore + ... + + assert getattr(Foo, "__model__") is generic_arg + + +@pytest.mark.skipif(Model is None, reason="Odmantic import error") +def test_odmantic_model_inference_ok() -> None: + class OdmanticModelBase(Model): # type: ignore + name: str + + class Foo(OdmanticModelFactory[OdmanticModelBase]): + ... + + assert getattr(Foo, "__model__") is OdmanticModelBase + + +@pytest.mark.skipif(Document is None, reason="Beanie import error") +def test_beanie_model_inference_ok() -> None: + class BeanieBase(Document): + name: str + + class Foo(BeanieDocumentFactory[BeanieBase]): + ... + + assert getattr(Foo, "__model__") is BeanieBase + + +@pytest.mark.parametrize( + "base_factory", + [ + AttrsFactory, + ModelFactory, + MsgspecFactory, + SQLAlchemyFactory, + TypedDictFactory, + ], +) +def test_model_without_generic_type_inference_error(base_factory: Type[BaseFactory]) -> None: + with pytest.raises(ConfigurationException): + + class Foo(base_factory): # type: ignore + ... + + +@pytest.mark.parametrize( + "base_factory", + [ + AttrsFactory, + ModelFactory, + MsgspecFactory, + SQLAlchemyFactory, + TypedDictFactory, + ], +) +def test_model_type_error(base_factory: Type[BaseFactory]) -> None: + with pytest.raises(ConfigurationException): + + class Foo(base_factory[int]): # type: ignore + ... + + +def test_model_multiple_inheritance_cannot_infer_error() -> None: + class PFoo(BaseModel): + val: int + + class TDFoo(TypedDict): + val: str + + with pytest.raises(ConfigurationException): + + class Foo(ModelFactory[PFoo], TypedDictFactory[TDFoo]): # type: ignore + ... + + +def test_generic_model_is_not_an_error() -> None: + T = TypeVar("T") + P = TypeVar("P") + + class Foo(GenericModel, Generic[T, P]): # type: ignore[misc] + val1: T + val2: P + + class FooFactory(ModelFactory[Foo[str, int]]): + ... + + assert getattr(FooFactory, "__model__") is Foo[str, int] From fd7cfed4218b195bb90339258523ebf0ca378e87 Mon Sep 17 00:00:00 2001 From: mityuha Date: Sat, 9 Dec 2023 19:25:19 +0100 Subject: [PATCH 08/11] fix README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 32d6e065..12e1f636 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ class Person: class PersonFactory(DataclassFactory[Person]): - __model__ = Person + ... def test_is_person() -> None: From ef3f747118230dddffb6f9115f9879cd2fbae349 Mon Sep 17 00:00:00 2001 From: mityuha Date: Sat, 9 Dec 2023 19:38:37 +0100 Subject: [PATCH 09/11] remove __mode__ declaration for all examples --- docs/PYPI_README.md | 2 +- docs/examples/configuration/test_example_1.py | 1 - docs/examples/configuration/test_example_2.py | 1 - docs/examples/configuration/test_example_3.py | 1 - docs/examples/configuration/test_example_4.py | 1 - docs/examples/configuration/test_example_5.py | 3 +-- docs/examples/configuration/test_example_6.py | 2 -- docs/examples/configuration/test_example_7.py | 1 - docs/examples/configuration/test_example_8.py | 1 - docs/examples/declaring_factories/test_example_1.py | 2 +- docs/examples/declaring_factories/test_example_2.py | 2 +- docs/examples/declaring_factories/test_example_3.py | 2 +- docs/examples/declaring_factories/test_example_4.py | 2 +- docs/examples/declaring_factories/test_example_5.py | 2 +- docs/examples/declaring_factories/test_example_7.py | 2 +- docs/examples/decorators/test_example_1.py | 2 -- docs/examples/fields/test_example_1.py | 2 -- docs/examples/fields/test_example_2.py | 4 ---- docs/examples/fields/test_example_3.py | 4 ---- docs/examples/fields/test_example_4.py | 4 ---- docs/examples/fields/test_example_5.py | 2 -- docs/examples/fields/test_example_6.py | 2 -- docs/examples/fields/test_example_7.py | 2 -- docs/examples/fields/test_example_8.py | 4 ---- docs/examples/fixtures/test_example_1.py | 2 +- docs/examples/fixtures/test_example_2.py | 2 +- docs/examples/fixtures/test_example_3.py | 2 +- docs/examples/fixtures/test_example_4.py | 4 +--- docs/examples/handling_custom_types/test_example_1.py | 2 -- docs/examples/handling_custom_types/test_example_2.py | 2 +- .../library_factories/sqlalchemy_factory/test_example_1.py | 2 +- .../library_factories/sqlalchemy_factory/test_example_2.py | 3 +-- .../library_factories/sqlalchemy_factory/test_example_3.py | 1 - docs/examples/model_coverage/test_example_1.py | 2 +- docs/examples/model_coverage/test_example_2.py | 2 +- 35 files changed, 17 insertions(+), 58 deletions(-) diff --git a/docs/PYPI_README.md b/docs/PYPI_README.md index b9530e76..ac3ce48a 100644 --- a/docs/PYPI_README.md +++ b/docs/PYPI_README.md @@ -44,7 +44,7 @@ class Person: class PersonFactory(DataclassFactory[Person]): - __model__ = Person + ... def test_is_person() -> None: diff --git a/docs/examples/configuration/test_example_1.py b/docs/examples/configuration/test_example_1.py index 6eef5dae..0136ec90 100644 --- a/docs/examples/configuration/test_example_1.py +++ b/docs/examples/configuration/test_example_1.py @@ -12,7 +12,6 @@ class Person: class PersonFactory(DataclassFactory[Person]): - __model__ = Person __random_seed__ = 1 @classmethod diff --git a/docs/examples/configuration/test_example_2.py b/docs/examples/configuration/test_example_2.py index 85676ec4..38c45339 100644 --- a/docs/examples/configuration/test_example_2.py +++ b/docs/examples/configuration/test_example_2.py @@ -13,7 +13,6 @@ class Person: class PersonFactory(DataclassFactory[Person]): - __model__ = Person __random__ = Random(10) @classmethod diff --git a/docs/examples/configuration/test_example_3.py b/docs/examples/configuration/test_example_3.py index 1b55e5be..b5a07acc 100644 --- a/docs/examples/configuration/test_example_3.py +++ b/docs/examples/configuration/test_example_3.py @@ -14,7 +14,6 @@ class Person: class PersonFactory(DataclassFactory[Person]): - __model__ = Person __faker__ = Faker(locale="es_ES") __random_seed__ = 10 diff --git a/docs/examples/configuration/test_example_4.py b/docs/examples/configuration/test_example_4.py index d52cf175..ac5e5787 100644 --- a/docs/examples/configuration/test_example_4.py +++ b/docs/examples/configuration/test_example_4.py @@ -50,7 +50,6 @@ async def save_many(self, data: List[Person]) -> List[Person]: class PersonFactory(DataclassFactory[Person]): - __model__ = Person __sync_persistence__ = SyncPersistenceHandler __async_persistence__ = AsyncPersistenceHandler diff --git a/docs/examples/configuration/test_example_5.py b/docs/examples/configuration/test_example_5.py index db63d9ef..e6b88742 100644 --- a/docs/examples/configuration/test_example_5.py +++ b/docs/examples/configuration/test_example_5.py @@ -32,14 +32,13 @@ class Person: class PetFactory(DataclassFactory[Pet]): - __model__ = Pet __set_as_default_factory_for_type__ = True name = Use(DataclassFactory.__random__.choice, ["Roxy", "Spammy", "Moshe"]) class PersonFactory(DataclassFactory[Person]): - __model__ = Person + ... def test_default_pet_factory() -> None: diff --git a/docs/examples/configuration/test_example_6.py b/docs/examples/configuration/test_example_6.py index a84aa60c..b962ba08 100644 --- a/docs/examples/configuration/test_example_6.py +++ b/docs/examples/configuration/test_example_6.py @@ -10,8 +10,6 @@ class Owner: class OwnerFactory(DataclassFactory[Owner]): - __model__ = Owner - __randomize_collection_length__ = True __min_collection_length__ = 2 __max_collection_length__ = 5 diff --git a/docs/examples/configuration/test_example_7.py b/docs/examples/configuration/test_example_7.py index db521b9f..22a29d9e 100644 --- a/docs/examples/configuration/test_example_7.py +++ b/docs/examples/configuration/test_example_7.py @@ -12,7 +12,6 @@ class Person: class PersonFactory(DataclassFactory[Person]): - __model__ = Person __allow_none_optionals__ = False diff --git a/docs/examples/configuration/test_example_8.py b/docs/examples/configuration/test_example_8.py index 3a440b23..0b607c68 100644 --- a/docs/examples/configuration/test_example_8.py +++ b/docs/examples/configuration/test_example_8.py @@ -19,6 +19,5 @@ def test_check_factory_fields() -> None: ): class PersonFactory(DataclassFactory[Person]): - __model__ = Person __check_model__ = True unknown_field = PostGenerated(lambda: "foo") diff --git a/docs/examples/declaring_factories/test_example_1.py b/docs/examples/declaring_factories/test_example_1.py index a8dda4cc..e24c3be9 100644 --- a/docs/examples/declaring_factories/test_example_1.py +++ b/docs/examples/declaring_factories/test_example_1.py @@ -12,7 +12,7 @@ class Person: class PersonFactory(DataclassFactory[Person]): - __model__ = Person + ... def test_is_person() -> None: diff --git a/docs/examples/declaring_factories/test_example_2.py b/docs/examples/declaring_factories/test_example_2.py index f5ce2474..6ebf8040 100644 --- a/docs/examples/declaring_factories/test_example_2.py +++ b/docs/examples/declaring_factories/test_example_2.py @@ -11,7 +11,7 @@ class Person(TypedDict): class PersonFactory(TypedDictFactory[Person]): - __model__ = Person + ... def test_is_person() -> None: diff --git a/docs/examples/declaring_factories/test_example_3.py b/docs/examples/declaring_factories/test_example_3.py index 1dcb28da..4dafe47b 100644 --- a/docs/examples/declaring_factories/test_example_3.py +++ b/docs/examples/declaring_factories/test_example_3.py @@ -11,7 +11,7 @@ class Person(BaseModel): class PersonFactory(ModelFactory[Person]): - __model__ = Person + ... def test_is_person() -> None: diff --git a/docs/examples/declaring_factories/test_example_4.py b/docs/examples/declaring_factories/test_example_4.py index f862f719..b192348e 100644 --- a/docs/examples/declaring_factories/test_example_4.py +++ b/docs/examples/declaring_factories/test_example_4.py @@ -12,7 +12,7 @@ class Person: class PersonFactory(DataclassFactory[Person]): - __model__ = Person + ... def test_is_person() -> None: diff --git a/docs/examples/declaring_factories/test_example_5.py b/docs/examples/declaring_factories/test_example_5.py index 24d80edd..f64de77f 100644 --- a/docs/examples/declaring_factories/test_example_5.py +++ b/docs/examples/declaring_factories/test_example_5.py @@ -31,7 +31,7 @@ class Person: class PersonFactory(DataclassFactory[Person]): - __model__ = Person + ... def test_dynamic_factory_generation() -> None: diff --git a/docs/examples/declaring_factories/test_example_7.py b/docs/examples/declaring_factories/test_example_7.py index c48bbe13..d17aee82 100644 --- a/docs/examples/declaring_factories/test_example_7.py +++ b/docs/examples/declaring_factories/test_example_7.py @@ -20,7 +20,7 @@ class Person: class PersonFactory(AttrsFactory[Person]): - __model__ = Person + ... def test_person_factory() -> None: diff --git a/docs/examples/decorators/test_example_1.py b/docs/examples/decorators/test_example_1.py index 924498ce..b64f90e2 100644 --- a/docs/examples/decorators/test_example_1.py +++ b/docs/examples/decorators/test_example_1.py @@ -13,8 +13,6 @@ class DatetimeRange: class DatetimeRangeFactory(DataclassFactory[DatetimeRange]): - __model__ = DatetimeRange - @post_generated @classmethod def to_dt(cls, from_dt: datetime) -> datetime: diff --git a/docs/examples/fields/test_example_1.py b/docs/examples/fields/test_example_1.py index 9e3ae799..5d8926af 100644 --- a/docs/examples/fields/test_example_1.py +++ b/docs/examples/fields/test_example_1.py @@ -34,8 +34,6 @@ class Person: class PersonFactory(DataclassFactory[Person]): - __model__ = Person - pets = [pet_instance] diff --git a/docs/examples/fields/test_example_2.py b/docs/examples/fields/test_example_2.py index 57639363..d8275241 100644 --- a/docs/examples/fields/test_example_2.py +++ b/docs/examples/fields/test_example_2.py @@ -32,15 +32,11 @@ class Person: class PetFactory(DataclassFactory[Pet]): - __model__ = Pet - name = Use(DataclassFactory.__random__.choice, ["Ralph", "Roxy"]) species = Use(DataclassFactory.__random__.choice, list(Species)) class PersonFactory(DataclassFactory[Person]): - __model__ = Person - pets = Use(PetFactory.batch, size=2) diff --git a/docs/examples/fields/test_example_3.py b/docs/examples/fields/test_example_3.py index 81a2bc1f..b312df40 100644 --- a/docs/examples/fields/test_example_3.py +++ b/docs/examples/fields/test_example_3.py @@ -32,15 +32,11 @@ class Person: class PetFactory(DataclassFactory[Pet]): - __model__ = Pet - name = lambda: DataclassFactory.__random__.choice(["Ralph", "Roxy"]) species = lambda: DataclassFactory.__random__.choice(list(Species)) class PersonFactory(DataclassFactory[Person]): - __model__ = Person - pets = Use(PetFactory.batch, size=2) diff --git a/docs/examples/fields/test_example_4.py b/docs/examples/fields/test_example_4.py index 84ed2ae4..31ede624 100644 --- a/docs/examples/fields/test_example_4.py +++ b/docs/examples/fields/test_example_4.py @@ -32,8 +32,6 @@ class Person: class PetFactory(DataclassFactory[Pet]): - __model__ = Pet - @classmethod def name(cls) -> str: return cls.__random__.choice(["Ralph", "Roxy"]) @@ -44,8 +42,6 @@ def species(cls) -> str: class PersonFactory(DataclassFactory[Person]): - __model__ = Person - pets = Use(PetFactory.batch, size=2) diff --git a/docs/examples/fields/test_example_5.py b/docs/examples/fields/test_example_5.py index 4435b9e8..fbe9b405 100644 --- a/docs/examples/fields/test_example_5.py +++ b/docs/examples/fields/test_example_5.py @@ -10,8 +10,6 @@ class Person(TypedDict): class PersonFactory(TypedDictFactory[Person]): - __model__ = Person - id = Ignore() diff --git a/docs/examples/fields/test_example_6.py b/docs/examples/fields/test_example_6.py index 893466bf..0c6a5f80 100644 --- a/docs/examples/fields/test_example_6.py +++ b/docs/examples/fields/test_example_6.py @@ -13,8 +13,6 @@ class Person(TypedDict): class PersonFactory(TypedDictFactory[Person]): - __model__ = Person - id = Require() diff --git a/docs/examples/fields/test_example_7.py b/docs/examples/fields/test_example_7.py index 86debbe6..c99f884c 100644 --- a/docs/examples/fields/test_example_7.py +++ b/docs/examples/fields/test_example_7.py @@ -18,8 +18,6 @@ class DatetimeRange: class DatetimeRangeFactory(DataclassFactory[DatetimeRange]): - __model__ = DatetimeRange - to_dt = PostGenerated(add_timedelta) diff --git a/docs/examples/fields/test_example_8.py b/docs/examples/fields/test_example_8.py index d079e6a0..b204ade5 100644 --- a/docs/examples/fields/test_example_8.py +++ b/docs/examples/fields/test_example_8.py @@ -31,14 +31,10 @@ class Person: class PetFactory(DataclassFactory[Pet]): - __model__ = Pet - name = lambda: DataclassFactory.__random__.choice(["Ralph", "Roxy"]) class PersonFactory(DataclassFactory[Person]): - __model__ = Person - pet = PetFactory diff --git a/docs/examples/fixtures/test_example_1.py b/docs/examples/fixtures/test_example_1.py index c5f7b51e..663493a1 100644 --- a/docs/examples/fixtures/test_example_1.py +++ b/docs/examples/fixtures/test_example_1.py @@ -19,7 +19,7 @@ class Person: @register_fixture class PersonFactory(DataclassFactory[Person]): - __model__ = Person + ... def test_person_factory(person_factory: PersonFactory) -> None: diff --git a/docs/examples/fixtures/test_example_2.py b/docs/examples/fixtures/test_example_2.py index ee26d811..58864fd5 100644 --- a/docs/examples/fixtures/test_example_2.py +++ b/docs/examples/fixtures/test_example_2.py @@ -18,7 +18,7 @@ class Person: class PersonFactory(DataclassFactory[Person]): - __model__ = Person + ... person_factory_fixture = register_fixture(PersonFactory) diff --git a/docs/examples/fixtures/test_example_3.py b/docs/examples/fixtures/test_example_3.py index f5cd3253..bebd96d8 100644 --- a/docs/examples/fixtures/test_example_3.py +++ b/docs/examples/fixtures/test_example_3.py @@ -18,7 +18,7 @@ class Person: class PersonFactory(DataclassFactory[Person]): - __model__ = Person + ... person_factory_fixture = register_fixture(PersonFactory, name="aliased_person_factory") diff --git a/docs/examples/fixtures/test_example_4.py b/docs/examples/fixtures/test_example_4.py index 431cc087..d6803f2a 100644 --- a/docs/examples/fixtures/test_example_4.py +++ b/docs/examples/fixtures/test_example_4.py @@ -26,12 +26,10 @@ class ClassRoom: @register_fixture class PersonFactory(DataclassFactory[Person]): - __model__ = Person + ... class ClassRoomFactory(DataclassFactory[ClassRoom]): - __model__ = ClassRoom - teacher = Fixture(PersonFactory, name="Ludmilla Newman") pupils = Fixture(PersonFactory, size=20) diff --git a/docs/examples/handling_custom_types/test_example_1.py b/docs/examples/handling_custom_types/test_example_1.py index 6f21076d..5f94ac29 100644 --- a/docs/examples/handling_custom_types/test_example_1.py +++ b/docs/examples/handling_custom_types/test_example_1.py @@ -26,8 +26,6 @@ class Person: # by default the factory class cannot handle unknown types, # so we need to override the provider map to add it: class PersonFactory(DataclassFactory[Person]): - __model__ = Person - @classmethod def get_provider_map(cls) -> Dict[Type, Any]: providers_map = super().get_provider_map() diff --git a/docs/examples/handling_custom_types/test_example_2.py b/docs/examples/handling_custom_types/test_example_2.py index 9cd954ce..5e35fcc9 100644 --- a/docs/examples/handling_custom_types/test_example_2.py +++ b/docs/examples/handling_custom_types/test_example_2.py @@ -42,7 +42,7 @@ class Person: # we use our CustomDataclassFactory as a base for the PersonFactory class PersonFactory(CustomDataclassFactory[Person]): - __model__ = Person + ... def test_custom_dataclass_base_factory() -> None: diff --git a/docs/examples/library_factories/sqlalchemy_factory/test_example_1.py b/docs/examples/library_factories/sqlalchemy_factory/test_example_1.py index 9b2a3044..292c10d0 100644 --- a/docs/examples/library_factories/sqlalchemy_factory/test_example_1.py +++ b/docs/examples/library_factories/sqlalchemy_factory/test_example_1.py @@ -15,7 +15,7 @@ class Author(Base): class AuthorFactory(SQLAlchemyFactory[Author]): - __model__ = Author + ... def test_sqla_factory() -> None: diff --git a/docs/examples/library_factories/sqlalchemy_factory/test_example_2.py b/docs/examples/library_factories/sqlalchemy_factory/test_example_2.py index 9dc55f05..80a51463 100644 --- a/docs/examples/library_factories/sqlalchemy_factory/test_example_2.py +++ b/docs/examples/library_factories/sqlalchemy_factory/test_example_2.py @@ -27,11 +27,10 @@ class Book(Base): class AuthorFactory(SQLAlchemyFactory[Author]): - __model__ = Author + ... class AuthorFactoryWithRelationship(SQLAlchemyFactory[Author]): - __model__ = Author __set_relationships__ = True diff --git a/docs/examples/library_factories/sqlalchemy_factory/test_example_3.py b/docs/examples/library_factories/sqlalchemy_factory/test_example_3.py index a958c877..d51985d3 100644 --- a/docs/examples/library_factories/sqlalchemy_factory/test_example_3.py +++ b/docs/examples/library_factories/sqlalchemy_factory/test_example_3.py @@ -27,7 +27,6 @@ class Book(Base): class AuthorFactory(SQLAlchemyFactory[Author]): - __model__ = Author __set_relationships__ = True diff --git a/docs/examples/model_coverage/test_example_1.py b/docs/examples/model_coverage/test_example_1.py index ced96f4a..a832b6ce 100644 --- a/docs/examples/model_coverage/test_example_1.py +++ b/docs/examples/model_coverage/test_example_1.py @@ -24,7 +24,7 @@ class Profile: class ProfileFactory(DataclassFactory[Profile]): - __model__ = Profile + ... def test_profile_coverage() -> None: diff --git a/docs/examples/model_coverage/test_example_2.py b/docs/examples/model_coverage/test_example_2.py index bf67f959..cd8787b8 100644 --- a/docs/examples/model_coverage/test_example_2.py +++ b/docs/examples/model_coverage/test_example_2.py @@ -29,7 +29,7 @@ class SocialGroup: class SocialGroupFactory(DataclassFactory[SocialGroup]): - __model__ = SocialGroup + ... def test_social_group_coverage() -> None: From 6f6c7b09454ecf0a8141d8e7d62b16aa27518ec9 Mon Sep 17 00:00:00 2001 From: mityuha Date: Sat, 9 Dec 2023 19:50:17 +0100 Subject: [PATCH 10/11] fix description for the __model__ attribute --- polyfactory/factories/base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index f4aeffb2..4b831a15 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -37,7 +37,6 @@ Generic, Iterable, Mapping, - Optional, Sequence, Type, TypeVar, @@ -111,7 +110,7 @@ class BaseFactory(ABC, Generic[T]): __model__: type[T] """ The model for the factory. - This attribute is required for non-base factories and an exception will be raised if its not set. + This attribute is required for non-base factories and an exception will be raised if it's not set. Can be automatically inferred from the factory generic argument. """ __check_model__: bool = False """ @@ -222,7 +221,7 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # noqa: C901 BaseFactory._factory_type_mapping[cls.__model__] = cls @classmethod - def _infer_model_type(cls: type[F]) -> Optional[type[T]]: + def _infer_model_type(cls: type[F]) -> type[T] | None: """Return model type inferred from class declaration. class Foo(ModelFactory[MyModel]): # <<< MyModel ... From 10aa7e272be029cfa4f084be8b787ef1af5c8c34 Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Sat, 9 Dec 2023 18:52:59 +0000 Subject: [PATCH 11/11] 'Refactored by Sourcery' --- polyfactory/factories/base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 242b73cd..fac77be8 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -237,10 +237,7 @@ class Foo(ModelFactory[MyModel]): # <<< MyModel generic_args: Sequence[type[T]] = [ arg for factory_base in factory_bases for arg in get_args(factory_base) if not isinstance(arg, TypeVar) ] - if len(generic_args) != 1: - return None - - return generic_args[0] + return None if len(generic_args) != 1 else generic_args[0] @classmethod def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]: