Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(type-coverage-generation): Model type coverage batch generation #390

Merged
merged 42 commits into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from 41 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
d689ce0
feat(type-coverage-gen): Initial implementation of type coverage gene…
sam-or Sep 27, 2023
84c9c51
fix: revert change to .pre-commit-config.yaml
sam-or Sep 27, 2023
e97b1c1
fix: Update NoneType importing for older python versions
sam-or Sep 27, 2023
8df692f
fix: apply sourcery refactor
sam-or Sep 27, 2023
9734172
fix: import ParamSpec from typing_extensions
sam-or Sep 27, 2023
373fea4
fix: Skip tests on py versions < 3.10
sam-or Sep 27, 2023
ae38e54
fix: revert changes to .pre-commit-config.yaml
sam-or Sep 27, 2023
1fb0608
chore: Create devcontainer.json
sam-or Sep 27, 2023
f2289d0
fix: remove .devcontainer dir
sam-or Sep 27, 2023
d9adc27
fix: Add missing test skip for older python versions
sam-or Sep 27, 2023
20f4813
test: Add test for post generated in coverage generation
sam-or Sep 27, 2023
176689a
Merge remote-tracking branch 'upstream/main' into coverage
sam-or Oct 3, 2023
7f71339
test: Simplify type coverage generation tests
sam-or Oct 4, 2023
f93a498
test: Add back min python3.10 version condition
sam-or Oct 4, 2023
e306abc
fix(infra): update makefile (#399)
JacobCoffee Oct 7, 2023
8752b81
docs: Install all dependencies for docs build (#404)
adhtruong Oct 11, 2023
f84b771
fix: decouple the handling of collection length configuration from `F…
guacs Oct 14, 2023
8a3ac1f
refactor: refactor the msgspec factory to use the fields API (#409)
guacs Oct 14, 2023
5baf00a
chore: prepare for releasing v2.10 (#410)
guacs Oct 16, 2023
7570acd
feat(type-coverage-gen): Initial implementation of type coverage gene…
sam-or Sep 27, 2023
87df742
fix: Update NoneType importing for older python versions
sam-or Sep 27, 2023
811658b
fix: Make CoverageContainer generic
sam-or Oct 16, 2023
bac1971
fix: linting and rebase issues
sam-or Oct 16, 2023
2c0a71a
Merge remote-tracking branch 'upstream/main' into coverage
sam-or Oct 16, 2023
deb72a1
fix: revert pre-commit conf change
sam-or Oct 16, 2023
85e2525
docs(type-coverage-gen): Add docs for coverage gen
sam-or Oct 18, 2023
88923de
docs: Fix formatting in coverage docs
sam-or Oct 18, 2023
006307d
Merge branch 'main' into coverage
sam-or Oct 18, 2023
a304a99
Merge branch 'main' into coverage
sam-or Oct 23, 2023
2df1a26
docs: Move profile coverage exmaple into test func
sam-or Nov 2, 2023
f989775
docs: Update social group example to use test func
sam-or Nov 2, 2023
b337bc0
fix: Address review comments
sam-or Nov 3, 2023
be5e712
test: Remove 3.10 requirement for coverage tests
sam-or Nov 3, 2023
eb075e0
test: Move CustomInt definition outside of test
sam-or Nov 3, 2023
5299535
test: disable ruff UP006 in test file
sam-or Nov 3, 2023
93f8658
test: fix social group test in docs example
sam-or Nov 3, 2023
d70c526
test: fix social group test in doc example
sam-or Nov 3, 2023
d6e2ce8
test: Change hint dict to Dict in coverage test
sam-or Nov 3, 2023
d0cf706
test: Fix tuple annotation in coverage tests
sam-or Nov 3, 2023
3201015
Merge branch 'main' into coverage
sam-or Nov 3, 2023
ba355e5
chore: fix formatting in docstring
sam-or Nov 6, 2023
9c84551
chore: Add docstring to CoverageContainerCallable
sam-or Nov 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
38 changes: 38 additions & 0 deletions docs/examples/model_coverage/test_example_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

from polyfactory.factories.dataclass_factory import DataclassFactory


@dataclass
class Car:
model: str


@dataclass
class Boat:
can_float: bool


@dataclass
class Profile:
age: int
favourite_color: Literal["red", "green", "blue"]
vehicle: Car | Boat


class ProfileFactory(DataclassFactory[Profile]):
__model__ = Profile


def test_profile_coverage() -> None:
profiles = list(ProfileFactory.coverage())

assert profiles[0].favourite_color == "red"
assert isinstance(profiles[0].vehicle, Car)
assert profiles[1].favourite_color == "green"
assert isinstance(profiles[1].vehicle, Boat)
assert profiles[2].favourite_color == "blue"
assert isinstance(profiles[2].vehicle, Car)
47 changes: 47 additions & 0 deletions docs/examples/model_coverage/test_example_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

from polyfactory.factories.dataclass_factory import DataclassFactory


@dataclass
class Car:
model: str


@dataclass
class Boat:
can_float: bool


@dataclass
class Profile:
age: int
favourite_color: Literal["red", "green", "blue"]
vehicle: Car | Boat


@dataclass
class SocialGroup:
members: list[Profile]


class SocialGroupFactory(DataclassFactory[SocialGroup]):
__model__ = SocialGroup


def test_social_group_coverage() -> None:
groups = list(SocialGroupFactory.coverage())
assert len(groups) == 3

for group in groups:
assert len(group.members) == 1

assert groups[0].members[0].favourite_color == "red"
assert isinstance(groups[0].members[0].vehicle, Car)
assert groups[1].members[0].favourite_color == "green"
assert isinstance(groups[1].members[0].vehicle, Boat)
assert groups[2].members[0].favourite_color == "blue"
assert isinstance(groups[2].members[0].vehicle, Car)
1 change: 1 addition & 0 deletions docs/usage/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ Usage Guide
decorators
fixtures
handling_custom_types
model_coverage
29 changes: 29 additions & 0 deletions docs/usage/model_coverage.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
Model coverage generation
=========================

The ``BaseFactory.coverage()`` function is an alternative approach to ``BaseFactory.batch()``, where the examples that are generated attempt to provide full coverage of all the forms a model can take with the minimum number of instances. For example:

.. literalinclude:: /examples/model_coverage/test_example_1.py
:caption: Defining a factory and generating examples with coverage
:language: python

As you can see in the above example, the ``Profile`` model has 3 options for ``favourite_color``, and 2 options for ``vehicle``. In the output you can expect to see instances of ``Profile`` that have each of these options. The largest variance dictates the length of the output, in this case ``favourite_color`` has the most, at 3 options, so expect to see 3 ``Profile`` instances.


.. note::
Notice that the same ``Car`` instance is used in the first and final generated example. When the coverage examples for a field are exhausted before another field, values for that field are re-used.

Notes on collection types
-------------------------

When generating coverage for models with fields that are collections, in particular collections that contain sub-models, the contents of the collection will be the all coverage examples for that sub-model. For example:

.. literalinclude:: /examples/model_coverage/test_example_2.py
:caption: Coverage output for the SocialGroup model
:language: python

Known Limitations
-----------------

- Recursive models will cause an error: ``RecursionError: maximum recursion depth exceeded``.
- ``__min_collection_length__`` and ``__max_collection_length__`` are currently ignored in coverage generation.
167 changes: 160 additions & 7 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import typing
from abc import ABC, abstractmethod
from collections import Counter, abc, deque
from contextlib import suppress
Expand All @@ -22,6 +23,12 @@
from os.path import realpath
from pathlib import Path
from random import Random

try:
from types import NoneType
except ImportError:
NoneType = type(None) # type: ignore[misc,assignment]

from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -46,13 +53,16 @@
MIN_COLLECTION_LENGTH,
RANDOMIZE_COLLECTION_LENGTH,
)
from polyfactory.exceptions import (
ConfigurationException,
MissingBuildKwargException,
ParameterException,
)
from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException
from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use
from polyfactory.utils.helpers import get_collection_type, unwrap_annotation, unwrap_args, unwrap_optional
from polyfactory.utils.helpers import (
flatten_annotation,
get_collection_type,
unwrap_annotation,
unwrap_args,
unwrap_optional,
)
from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage
from polyfactory.utils.predicates import (
get_type_origin,
is_any,
Expand All @@ -61,7 +71,7 @@
is_safe_subclass,
is_union,
)
from polyfactory.value_generators.complex_types import handle_collection_type
from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage
from polyfactory.value_generators.constrained_collections import (
handle_constrained_collection,
handle_constrained_mapping,
Expand Down Expand Up @@ -263,6 +273,32 @@ def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | N

return field_value() if callable(field_value) else field_value

@classmethod
def _handle_factory_field_coverage(cls, field_value: Any, field_build_parameters: Any | None = None) -> Any:
"""Handle a value defined on the factory class itself.

:param field_value: A value defined as an attribute on the factory class.
:param field_build_parameters: Any build parameters passed to the factory as kwarg values.

:returns: An arbitrary value correlating with the given field_meta value.
"""
if is_safe_subclass(field_value, BaseFactory):
if isinstance(field_build_parameters, Mapping):
return CoverageContainer(field_value.coverage(**field_build_parameters))

if isinstance(field_build_parameters, Sequence):
return [CoverageContainer(field_value.coverage(**parameter)) for parameter in field_build_parameters]

return CoverageContainer(field_value.coverage())

if isinstance(field_value, Use):
return field_value.to_value()

if isinstance(field_value, Fixture):
return CoverageContainerCallable(field_value.to_value)

return CoverageContainerCallable(field_value) if callable(field_value) else field_value

@classmethod
def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]:
"""Get a factory from registered factories or generate a factory dynamically.
Expand Down Expand Up @@ -635,6 +671,66 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
msg,
)

@classmethod
def get_field_value_coverage( # noqa: C901
cls,
field_meta: FieldMeta,
field_build_parameters: Any | None = None,
) -> typing.Iterable[Any]:
"""Return a field value on the subclass if existing, otherwise returns a mock value.

:param field_meta: FieldMeta instance.
:param field_build_parameters: Any build parameters passed to the factory as kwarg values.

:returns: An iterable of values.

"""
if cls.is_ignored_type(field_meta.annotation):
return [None]

for unwrapped_annotation in flatten_annotation(field_meta.annotation):
if unwrapped_annotation in (None, NoneType):
yield None

elif is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)):
yield CoverageContainer(literal_args)

elif isinstance(unwrapped_annotation, EnumMeta):
yield CoverageContainer(list(unwrapped_annotation))

elif field_meta.constraints:
yield CoverageContainerCallable(
cls.get_constrained_field_value,
annotation=unwrapped_annotation,
field_meta=field_meta,
)

elif BaseFactory.is_factory_type(annotation=unwrapped_annotation):
yield CoverageContainer(
cls._get_or_create_factory(model=unwrapped_annotation).coverage(
**(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}),
),
)

elif (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection):
yield handle_collection_type_coverage(field_meta, origin, cls)

elif is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar):
yield create_random_string(cls.__random__, min_length=1, max_length=10)

elif provider := cls.get_provider_map().get(unwrapped_annotation):
yield CoverageContainerCallable(provider)

elif callable(unwrapped_annotation):
# if value is a callable we can try to naively call it.
# this will work for callables that do not require any parameters passed
yield CoverageContainerCallable(unwrapped_annotation)
else:
msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type."
raise ParameterException(
msg,
)

@classmethod
def should_set_none_value(cls, field_meta: FieldMeta) -> bool:
"""Determine whether a given model field_meta should be set to None.
Expand Down Expand Up @@ -752,6 +848,50 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:

return result

@classmethod
def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
sam-or marked this conversation as resolved.
Show resolved Hide resolved
"""Process the given kwargs and generate values for the factory's model.

:param kwargs: Any build kwargs.

:returns: A dictionary of build results.

"""
result: dict[str, Any] = {**kwargs}
generate_post: dict[str, PostGenerated] = {}

for field_meta in cls.get_model_fields():
field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs)

if cls.should_set_field_value(field_meta, **kwargs):
if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name):
field_value = getattr(cls, field_meta.name)
if isinstance(field_value, Ignore):
continue

if isinstance(field_value, Require) and field_meta.name not in kwargs:
msg = f"Require kwarg {field_meta.name} is missing"
raise MissingBuildKwargException(msg)

if isinstance(field_value, PostGenerated):
generate_post[field_meta.name] = field_value
continue

result[field_meta.name] = cls._handle_factory_field_coverage(
field_value=field_value,
field_build_parameters=field_build_parameters,
)
continue

result[field_meta.name] = CoverageContainer(
cls.get_field_value_coverage(field_meta, field_build_parameters=field_build_parameters),
)

for resolved in resolve_kwargs_coverage(result):
for field_name, post_generator in generate_post.items():
resolved[field_name] = post_generator.to_value(field_name, resolved)
yield resolved

@classmethod
def build(cls, **kwargs: Any) -> T:
"""Build an instance of the factory's __model__
Expand All @@ -776,6 +916,19 @@ def batch(cls, size: int, **kwargs: Any) -> list[T]:
"""
return [cls.build(**kwargs) for _ in range(size)]

@classmethod
def coverage(cls, **kwargs: Any) -> abc.Iterator[T]:
sam-or marked this conversation as resolved.
Show resolved Hide resolved
"""Build a batch of the factory's Meta.model will full coverage of the sub-types of the model.

:param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.

:returns: A iterator of instances of type T.

"""
for data in cls.process_kwargs_coverage(**kwargs):
instance = cls.__model__(**data)
yield cast("T", instance)

@classmethod
def create_sync(cls, **kwargs: Any) -> T:
"""Build and persists synchronously a single model instance.
Expand Down
31 changes: 30 additions & 1 deletion polyfactory/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
import sys
from typing import TYPE_CHECKING, Any, Mapping

try:
from types import NoneType
except ImportError:
NoneType = type(None) # type: ignore[misc,assignment]

from typing_extensions import get_args, get_origin

from polyfactory.constants import TYPE_MAPPING
Expand Down Expand Up @@ -52,7 +57,7 @@ def unwrap_optional(annotation: Any) -> Any:
:returns: A type annotation
"""
while is_optional(annotation):
annotation = next(arg for arg in get_args(annotation) if arg not in (type(None), None))
annotation = next(arg for arg in get_args(annotation) if arg not in (NoneType, None))
return annotation


Expand All @@ -77,6 +82,30 @@ def unwrap_annotation(annotation: Any, random: Random) -> Any:
return annotation


def flatten_annotation(annotation: Any) -> list[Any]:
"""Flattens an annotation.

:param annotation: A type annotation.

:returns: The flattened annotations.
"""
flat = []
if is_new_type(annotation):
flat.extend(flatten_annotation(unwrap_new_type(annotation)))
elif is_optional(annotation):
flat.append(NoneType)
flat.extend(flatten_annotation(arg) for arg in get_args(annotation) if arg not in (NoneType, None))
elif is_annotated(annotation):
flat.extend(flatten_annotation(get_args(annotation)[0]))
elif is_union(annotation):
for a in get_args(annotation):
flat.extend(flatten_annotation(a))
else:
flat.append(annotation)

return flat


def unwrap_args(annotation: Any, random: Random) -> tuple[Any, ...]:
"""Unwrap the annotation and return any type args.

Expand Down
Loading
Loading