Skip to content

Commit

Permalink
feat(type-coverage-generation): model type coverage batch generation (#…
Browse files Browse the repository at this point in the history
…390)

Co-authored-by: Jacob Coffee <[email protected]>
Co-authored-by: Andrew Truong <[email protected]>
Co-authored-by: guacs <[email protected]>
  • Loading branch information
4 people authored Nov 12, 2023
1 parent 70d49fd commit b1e8b5e
Show file tree
Hide file tree
Showing 10 changed files with 720 additions and 19 deletions.
Empty file.
38 changes: 38 additions & 0 deletions docs/examples/model_coverage/test_example_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

from polyfactory.factories.dataclass_factory import DataclassFactory


@dataclass
class Car:
model: str


@dataclass
class Boat:
can_float: bool


@dataclass
class Profile:
age: int
favourite_color: Literal["red", "green", "blue"]
vehicle: Car | Boat


class ProfileFactory(DataclassFactory[Profile]):
__model__ = Profile


def test_profile_coverage() -> None:
profiles = list(ProfileFactory.coverage())

assert profiles[0].favourite_color == "red"
assert isinstance(profiles[0].vehicle, Car)
assert profiles[1].favourite_color == "green"
assert isinstance(profiles[1].vehicle, Boat)
assert profiles[2].favourite_color == "blue"
assert isinstance(profiles[2].vehicle, Car)
47 changes: 47 additions & 0 deletions docs/examples/model_coverage/test_example_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

from polyfactory.factories.dataclass_factory import DataclassFactory


@dataclass
class Car:
model: str


@dataclass
class Boat:
can_float: bool


@dataclass
class Profile:
age: int
favourite_color: Literal["red", "green", "blue"]
vehicle: Car | Boat


@dataclass
class SocialGroup:
members: list[Profile]


class SocialGroupFactory(DataclassFactory[SocialGroup]):
__model__ = SocialGroup


def test_social_group_coverage() -> None:
groups = list(SocialGroupFactory.coverage())
assert len(groups) == 3

for group in groups:
assert len(group.members) == 1

assert groups[0].members[0].favourite_color == "red"
assert isinstance(groups[0].members[0].vehicle, Car)
assert groups[1].members[0].favourite_color == "green"
assert isinstance(groups[1].members[0].vehicle, Boat)
assert groups[2].members[0].favourite_color == "blue"
assert isinstance(groups[2].members[0].vehicle, Car)
1 change: 1 addition & 0 deletions docs/usage/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ Usage Guide
decorators
fixtures
handling_custom_types
model_coverage
29 changes: 29 additions & 0 deletions docs/usage/model_coverage.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
Model coverage generation
=========================

The ``BaseFactory.coverage()`` function is an alternative approach to ``BaseFactory.batch()``, where the examples that are generated attempt to provide full coverage of all the forms a model can take with the minimum number of instances. For example:

.. literalinclude:: /examples/model_coverage/test_example_1.py
:caption: Defining a factory and generating examples with coverage
:language: python

As you can see in the above example, the ``Profile`` model has 3 options for ``favourite_color``, and 2 options for ``vehicle``. In the output you can expect to see instances of ``Profile`` that have each of these options. The largest variance dictates the length of the output, in this case ``favourite_color`` has the most, at 3 options, so expect to see 3 ``Profile`` instances.


.. note::
Notice that the same ``Car`` instance is used in the first and final generated example. When the coverage examples for a field are exhausted before another field, values for that field are re-used.

Notes on collection types
-------------------------

When generating coverage for models with fields that are collections, in particular collections that contain sub-models, the contents of the collection will be the all coverage examples for that sub-model. For example:

.. literalinclude:: /examples/model_coverage/test_example_2.py
:caption: Coverage output for the SocialGroup model
:language: python

Known Limitations
-----------------

- Recursive models will cause an error: ``RecursionError: maximum recursion depth exceeded``.
- ``__min_collection_length__`` and ``__max_collection_length__`` are currently ignored in coverage generation.
167 changes: 160 additions & 7 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import typing
from abc import ABC, abstractmethod
from collections import Counter, abc, deque
from contextlib import suppress
Expand All @@ -22,6 +23,12 @@
from os.path import realpath
from pathlib import Path
from random import Random

try:
from types import NoneType
except ImportError:
NoneType = type(None) # type: ignore[misc,assignment]

from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -46,13 +53,16 @@
MIN_COLLECTION_LENGTH,
RANDOMIZE_COLLECTION_LENGTH,
)
from polyfactory.exceptions import (
ConfigurationException,
MissingBuildKwargException,
ParameterException,
)
from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException
from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use
from polyfactory.utils.helpers import get_collection_type, unwrap_annotation, unwrap_args, unwrap_optional
from polyfactory.utils.helpers import (
flatten_annotation,
get_collection_type,
unwrap_annotation,
unwrap_args,
unwrap_optional,
)
from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage
from polyfactory.utils.predicates import (
get_type_origin,
is_any,
Expand All @@ -61,7 +71,7 @@
is_safe_subclass,
is_union,
)
from polyfactory.value_generators.complex_types import handle_collection_type
from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage
from polyfactory.value_generators.constrained_collections import (
handle_constrained_collection,
handle_constrained_mapping,
Expand Down Expand Up @@ -263,6 +273,32 @@ def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | N

return field_value() if callable(field_value) else field_value

@classmethod
def _handle_factory_field_coverage(cls, field_value: Any, field_build_parameters: Any | None = None) -> Any:
"""Handle a value defined on the factory class itself.
:param field_value: A value defined as an attribute on the factory class.
:param field_build_parameters: Any build parameters passed to the factory as kwarg values.
:returns: An arbitrary value correlating with the given field_meta value.
"""
if is_safe_subclass(field_value, BaseFactory):
if isinstance(field_build_parameters, Mapping):
return CoverageContainer(field_value.coverage(**field_build_parameters))

if isinstance(field_build_parameters, Sequence):
return [CoverageContainer(field_value.coverage(**parameter)) for parameter in field_build_parameters]

return CoverageContainer(field_value.coverage())

if isinstance(field_value, Use):
return field_value.to_value()

if isinstance(field_value, Fixture):
return CoverageContainerCallable(field_value.to_value)

return CoverageContainerCallable(field_value) if callable(field_value) else field_value

@classmethod
def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]:
"""Get a factory from registered factories or generate a factory dynamically.
Expand Down Expand Up @@ -635,6 +671,66 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
msg,
)

@classmethod
def get_field_value_coverage( # noqa: C901
cls,
field_meta: FieldMeta,
field_build_parameters: Any | None = None,
) -> typing.Iterable[Any]:
"""Return a field value on the subclass if existing, otherwise returns a mock value.
:param field_meta: FieldMeta instance.
:param field_build_parameters: Any build parameters passed to the factory as kwarg values.
:returns: An iterable of values.
"""
if cls.is_ignored_type(field_meta.annotation):
return [None]

for unwrapped_annotation in flatten_annotation(field_meta.annotation):
if unwrapped_annotation in (None, NoneType):
yield None

elif is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)):
yield CoverageContainer(literal_args)

elif isinstance(unwrapped_annotation, EnumMeta):
yield CoverageContainer(list(unwrapped_annotation))

elif field_meta.constraints:
yield CoverageContainerCallable(
cls.get_constrained_field_value,
annotation=unwrapped_annotation,
field_meta=field_meta,
)

elif BaseFactory.is_factory_type(annotation=unwrapped_annotation):
yield CoverageContainer(
cls._get_or_create_factory(model=unwrapped_annotation).coverage(
**(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}),
),
)

elif (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection):
yield handle_collection_type_coverage(field_meta, origin, cls)

elif is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar):
yield create_random_string(cls.__random__, min_length=1, max_length=10)

elif provider := cls.get_provider_map().get(unwrapped_annotation):
yield CoverageContainerCallable(provider)

elif callable(unwrapped_annotation):
# if value is a callable we can try to naively call it.
# this will work for callables that do not require any parameters passed
yield CoverageContainerCallable(unwrapped_annotation)
else:
msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type."
raise ParameterException(
msg,
)

@classmethod
def should_set_none_value(cls, field_meta: FieldMeta) -> bool:
"""Determine whether a given model field_meta should be set to None.
Expand Down Expand Up @@ -752,6 +848,50 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:

return result

@classmethod
def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
"""Process the given kwargs and generate values for the factory's model.
:param kwargs: Any build kwargs.
:returns: A dictionary of build results.
"""
result: dict[str, Any] = {**kwargs}
generate_post: dict[str, PostGenerated] = {}

for field_meta in cls.get_model_fields():
field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs)

if cls.should_set_field_value(field_meta, **kwargs):
if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name):
field_value = getattr(cls, field_meta.name)
if isinstance(field_value, Ignore):
continue

if isinstance(field_value, Require) and field_meta.name not in kwargs:
msg = f"Require kwarg {field_meta.name} is missing"
raise MissingBuildKwargException(msg)

if isinstance(field_value, PostGenerated):
generate_post[field_meta.name] = field_value
continue

result[field_meta.name] = cls._handle_factory_field_coverage(
field_value=field_value,
field_build_parameters=field_build_parameters,
)
continue

result[field_meta.name] = CoverageContainer(
cls.get_field_value_coverage(field_meta, field_build_parameters=field_build_parameters),
)

for resolved in resolve_kwargs_coverage(result):
for field_name, post_generator in generate_post.items():
resolved[field_name] = post_generator.to_value(field_name, resolved)
yield resolved

@classmethod
def build(cls, **kwargs: Any) -> T:
"""Build an instance of the factory's __model__
Expand All @@ -776,6 +916,19 @@ def batch(cls, size: int, **kwargs: Any) -> list[T]:
"""
return [cls.build(**kwargs) for _ in range(size)]

@classmethod
def coverage(cls, **kwargs: Any) -> abc.Iterator[T]:
"""Build a batch of the factory's Meta.model will full coverage of the sub-types of the model.
:param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.
:returns: A iterator of instances of type T.
"""
for data in cls.process_kwargs_coverage(**kwargs):
instance = cls.__model__(**data)
yield cast("T", instance)

@classmethod
def create_sync(cls, **kwargs: Any) -> T:
"""Build and persists synchronously a single model instance.
Expand Down
31 changes: 30 additions & 1 deletion polyfactory/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
import sys
from typing import TYPE_CHECKING, Any, Mapping

try:
from types import NoneType
except ImportError:
NoneType = type(None) # type: ignore[misc,assignment]

from typing_extensions import get_args, get_origin

from polyfactory.constants import TYPE_MAPPING
Expand Down Expand Up @@ -52,7 +57,7 @@ def unwrap_optional(annotation: Any) -> Any:
:returns: A type annotation
"""
while is_optional(annotation):
annotation = next(arg for arg in get_args(annotation) if arg not in (type(None), None))
annotation = next(arg for arg in get_args(annotation) if arg not in (NoneType, None))
return annotation


Expand All @@ -77,6 +82,30 @@ def unwrap_annotation(annotation: Any, random: Random) -> Any:
return annotation


def flatten_annotation(annotation: Any) -> list[Any]:
"""Flattens an annotation.
:param annotation: A type annotation.
:returns: The flattened annotations.
"""
flat = []
if is_new_type(annotation):
flat.extend(flatten_annotation(unwrap_new_type(annotation)))
elif is_optional(annotation):
flat.append(NoneType)
flat.extend(flatten_annotation(arg) for arg in get_args(annotation) if arg not in (NoneType, None))
elif is_annotated(annotation):
flat.extend(flatten_annotation(get_args(annotation)[0]))
elif is_union(annotation):
for a in get_args(annotation):
flat.extend(flatten_annotation(a))
else:
flat.append(annotation)

return flat


def unwrap_args(annotation: Any, random: Random) -> tuple[Any, ...]:
"""Unwrap the annotation and return any type args.
Expand Down
Loading

0 comments on commit b1e8b5e

Please sign in to comment.