diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 3a376774..e7de20d9 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -82,7 +82,15 @@ create_random_string, ) +try: + from hypothesis import strategies as st + + _HYPOTHESIS_AVAILABLE: bool = True +except ImportError: + _HYPOTHESIS_AVAILABLE = False + if TYPE_CHECKING: + from hypothesis.strategies import SearchStrategy from typing_extensions import TypeGuard from polyfactory.field_meta import Constraints, FieldMeta @@ -771,6 +779,55 @@ async def create_batch_async(cls, size: int, **kwargs: Any) -> list[T]: """ return await cls._get_async_persistence().save_many(data=cls.batch(size, **kwargs)) + if _HYPOTHESIS_AVAILABLE: + + @classmethod + def get_hypothesis_provider_map(cls) -> dict[type, SearchStrategy[Any]]: + """Map types to search strategies. + + :notes: + - This method is distinct to allow overriding. + + + :returns: A dictionary mapping types to search strategies. + + """ + return {} + + @classmethod + def get_field_hypothesis_strategy(cls, field_meta: FieldMeta) -> SearchStrategy[Any]: + """Return a hypothesis strategy for the given field meta. + + :param field_meta: FieldMeta instance. + + :returns: An instance of SearchStrategy. + """ + + if strategy := cls.get_hypothesis_provider_map().get(field_meta.annotation, None): + return strategy + + return st.from_type(field_meta.annotation) + + @classmethod + def build_hypothesis_strategy(cls, **kwargs: SearchStrategy[Any]) -> SearchStrategy[T]: + """Build a hypothesis strategy for the factory's __model__ + + :param kwargs: Any kwargs. If field names are set in kwargs, their values will be used as the strategy for that field. + + :returns: An instance of SearchStrategy for the factory's __model__. + + """ + st_kwargs = kwargs + cls.get_hypothesis_provider_map() + + for field_meta in cls.get_model_fields(): + if field_meta.name in st_kwargs: + continue + + st_kwargs[field_meta.name] = cls.get_field_hypothesis_strategy(field_meta) + + return st.builds(cls.__model__, **st_kwargs) + def _register_builtin_factories() -> None: """This function is used to register the base factories, if present. diff --git a/tests/test_hypothesis_strategy_creation.py b/tests/test_hypothesis_strategy_creation.py new file mode 100644 index 00000000..7a2115ed --- /dev/null +++ b/tests/test_hypothesis_strategy_creation.py @@ -0,0 +1,20 @@ +import msgspec +from hypothesis import given +from msgspec import Struct +from msgspec.structs import asdict + +from polyfactory.factories.msgspec_factory import MsgspecFactory + + +def test_without_constraints() -> None: + class Foo(Struct): + int_field: int + str_field: str + + foo_st = MsgspecFactory.create_factory(Foo).build_hypothesis_strategy() + + @given(foo_st) + def test_foo(foo: Foo) -> None: + _ = msgspec.convert(asdict(foo), type=Foo) + + test_foo()