Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: check if fields declared on the factory belong to the model #405

24 changes: 24 additions & 0 deletions docs/examples/configuration/test_example_8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from dataclasses import dataclass
from uuid import UUID

import pytest

from polyfactory import ConfigurationException, PostGenerated
from polyfactory.factories.dataclass_factory import DataclassFactory


@dataclass
class Person:
id: UUID


def test_check_factory_fields() -> None:
with pytest.raises(
ConfigurationException,
match="unknown_field is declared on the factory PersonFactory but it is not part of the model Person",
):

class PersonFactory(DataclassFactory[Person]):
__model__ = Person
__check_model__ = True
unknown_field = PostGenerated(lambda: "foo")
12 changes: 12 additions & 0 deletions docs/usage/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,15 @@ By setting to `False`, then optional types will always be treated as the wrapped
.. literalinclude:: /examples/configuration/test_example_7.py
:caption: Disable Allow None Optionals
:language: python

Check Factory Fields
--------------------
When `__check_model__` is set to `True`, declaring fields on the factory that don't exist on the model will trigger an exception.

This is only true when fields are declared with ``Use``, ``PostGenerated``, ``Ignore`` and ``Require``.
Any other field definition will not be checked.


.. literalinclude:: /examples/configuration/test_example_8.py
:caption: Enable Check Factory Fields
:language: python
1 change: 1 addition & 0 deletions docs/usage/declaring_factories.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Or for attrs models:
Validators are not currently supported - neither the built in validators that come
with `attrs` nor custom validators.


Imperative Factory Creation
---------------------------

Expand Down
41 changes: 40 additions & 1 deletion polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ class BaseFactory(ABC, Generic[T]):
The model for the factory.
This attribute is required for non-base factories and an exception will be raised if its not set.
"""
__check_model__: bool = False
"""
Flag dictating whether to check if fields defined on the factory exists on the model or not.
If 'True', checks will be done against Use, PostGenerated, Ignore, Require constructs fields only.
"""
__allow_none_optionals__: ClassVar[bool] = True
"""
Flag dictating whether to allow 'None' for optional values.
Expand Down Expand Up @@ -184,14 +189,16 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # noqa: C901
if not cls.is_supported_type(model):
for factory in BaseFactory._base_factories:
if factory.is_supported_type(model):
msg = f"{cls.__name__} does not support {model.__name__}, but this type is support by the {factory.__name__} base factory class. To resolve this error, subclass the factory from {factory.__name__} instead of {cls.__name__}"
msg = f"{cls.__name__} does not support {model.__name__}, but this type is supported by the {factory.__name__} base factory class. To resolve this error, subclass the factory from {factory.__name__} instead of {cls.__name__}"
raise ConfigurationException(
msg,
)
msg = f"Model type {model.__name__} is not supported. To support it, register an appropriate base factory and subclass it for your factory."
raise ConfigurationException(
msg,
)
if cls.__check_model__:
cls._check_declared_fields_exist_in_model()
else:
BaseFactory._base_factories.append(cls)

Expand Down Expand Up @@ -665,6 +672,38 @@ def get_model_fields(cls) -> list[FieldMeta]: # pragma: no cover
"""
raise NotImplementedError

@classmethod
def get_factory_fields(cls) -> list[tuple[str, Any]]:
guacs marked this conversation as resolved.
Show resolved Hide resolved
"""Retrieve a list of fields from the factory.

Trying to be smart about what should be considered a field on the model,
ignoring dunder methods and some parent class attributes.

:returns: A list of tuples made of field name and field definition
"""
factory_fields = cls.__dict__.items()
return [
(field_name, field_value)
for field_name, field_value in factory_fields
if not (field_name.startswith("__") or field_name == "_abc_impl")
]

@classmethod
def _check_declared_fields_exist_in_model(cls) -> None:
model_fields_names = {field_meta.name for field_meta in cls.get_model_fields()}
factory_fields = cls.get_factory_fields()

for field_name, field_value in factory_fields:
if field_name in model_fields_names:
continue

error_message = (
f"{field_name} is declared on the factory {cls.__name__}"
f" but it is not part of the model {cls.__model__.__name__}"
)
if isinstance(field_value, (Use, PostGenerated, Ignore, Require)):
raise ConfigurationException(error_message)
guacs marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
"""Process the given kwargs and generate values for the factory's model.
Expand Down
44 changes: 42 additions & 2 deletions tests/test_factory_fields.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import random
from datetime import datetime, timedelta
from typing import Any, Optional
from typing import Any, Optional, Union

import pytest
from pydantic import BaseModel

from polyfactory.decorators import post_generated
from polyfactory.exceptions import MissingBuildKwargException
from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.fields import Ignore, PostGenerated, Require, Use

Expand Down Expand Up @@ -161,3 +161,43 @@ def caption(cls, is_long: bool) -> str:
assert result.caption == "this was really long for me"
else:
assert result.caption == "just this"


@pytest.mark.parametrize(
"factory_field",
[
Use(lambda: "foo"),
PostGenerated(lambda: "foo"),
Require(),
Ignore(),
],
)
def test_non_existing_model_fields_does_not_raise_by_default(
factory_field: Union[Use, PostGenerated, Require, Ignore],
) -> None:
class NoFieldModel(BaseModel):
pass

ModelFactory.create_factory(NoFieldModel, bases=None, unknown_field=factory_field)


@pytest.mark.parametrize(
"factory_field",
[
Use(lambda: "foo"),
PostGenerated(lambda: "foo"),
Require(),
Ignore(),
],
)
def test_non_existing_model_fields_raises_with__check__model__(
factory_field: Union[Use, PostGenerated, Require, Ignore],
) -> None:
class NoFieldModel(BaseModel):
pass

with pytest.raises(
ConfigurationException,
match="unknown_field is declared on the factory NoFieldModelFactory but it is not part of the model NoFieldModel",
):
ModelFactory.create_factory(NoFieldModel, bases=None, __check_model__=True, unknown_field=factory_field)