diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 72d86dc2..f38f8d0e 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -29,6 +29,7 @@ Callable, ClassVar, Collection, + Coroutine, Generic, Hashable, Iterable, @@ -1068,9 +1069,25 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: resolved[field_name] = post_generator.to_value(field_name, resolved) yield resolved + @classmethod + async def build_async(cls, **kwargs: Any) -> T: + """Asynchronously build an instance of the factory's __model__ + + :param kwargs: Any kwargs. If field names are set in kwargs, their values will be used. + + :returns: An instance of type T. + + """ + data: dict[str, Any] = cls.process_kwargs(**kwargs) + for k, v in data.items(): + if isinstance(v, Coroutine): + data[k] = await v + + return cast("T", cls.__model__(**data)) + @classmethod def build(cls, **kwargs: Any) -> T: - """Build an instance of the factory's __model__ + """Synchronously build an instance of the factory's __model__ :param kwargs: Any kwargs. If field names are set in kwargs, their values will be used. @@ -1081,7 +1098,7 @@ def build(cls, **kwargs: Any) -> T: @classmethod def batch(cls, size: int, **kwargs: Any) -> list[T]: - """Build a batch of size n of the factory's Meta.model. + """Synchronously build a batch of size n of the factory's Meta.model. :param size: Size of the batch. :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. @@ -1091,6 +1108,18 @@ def batch(cls, size: int, **kwargs: Any) -> list[T]: """ return [cls.build(**kwargs) for _ in range(size)] + @classmethod + async def batch_async(cls, size: int, **kwargs: Any) -> list[T]: + """Asynchronously build a batch of size n of the factory's Meta.model. + + :param size: Size of the batch. + :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + + :returns: A list of instances of type T. + + """ + return [await cls.build_async(**kwargs) for _ in range(size)] + @classmethod def coverage(cls, **kwargs: Any) -> abc.Iterator[T]: """Build a batch of the factory's Meta.model will full coverage of the sub-types of the model. @@ -1135,7 +1164,7 @@ async def create_async(cls, **kwargs: Any) -> T: :returns: An instance of type T. """ - return await cls._get_async_persistence().save(data=cls.build(**kwargs)) + return await cls._get_async_persistence().save(data=await cls.build_async(**kwargs)) @classmethod async def create_batch_async(cls, size: int, **kwargs: Any) -> list[T]: @@ -1147,7 +1176,7 @@ async def create_batch_async(cls, size: int, **kwargs: Any) -> list[T]: :returns: A list of instances of type T. """ - return await cls._get_async_persistence().save_many(data=cls.batch(size, **kwargs)) + return await cls._get_async_persistence().save_many(data=await cls.batch_async(size, **kwargs)) def _register_builtin_factories() -> None: diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index 4b15ac52..4ded37c6 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -15,6 +15,7 @@ from sqlalchemy import ARRAY, Column, Numeric, String, inspect, types from sqlalchemy.dialects import mssql, mysql, postgresql, sqlite from sqlalchemy.exc import NoInspectionAvailable + from sqlalchemy.ext.associationproxy import AssociationProxy from sqlalchemy.orm import InstanceState, Mapper except ImportError as e: msg = "sqlalchemy is not installed" @@ -52,16 +53,18 @@ def __init__(self, session: AsyncSession) -> None: self.session = session async def save(self, data: T) -> T: - self.session.add(data) - await self.session.commit() - await self.session.refresh(data) + async with self.session as session: + session.add(data) + await session.commit() + await session.refresh(data) return data async def save_many(self, data: list[T]) -> list[T]: - self.session.add_all(data) - await self.session.commit() - for batch_item in data: - await self.session.refresh(batch_item) + async with self.session as session: + session.add_all(data) + await session.commit() + for batch_item in data: + await session.refresh(batch_item) return data @@ -76,6 +79,8 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]): """Configuration to consider columns with foreign keys as a field or not.""" __set_relationships__: ClassVar[bool] = False """Configuration to consider relationships property as a model field or not.""" + __set_association_proxy__: ClassVar[bool] = False + """Configuration to consider AssociationProxy property as a model field or not.""" __session__: ClassVar[Session | Callable[[], Session] | None] = None __async_session__: ClassVar[AsyncSession | Callable[[], AsyncSession] | None] = None @@ -213,6 +218,23 @@ def get_model_fields(cls) -> list[FieldMeta]: random=cls.__random__, ), ) + if cls.__set_association_proxy__: + for name, attr in table.all_orm_descriptors.items(): + if isinstance(attr, AssociationProxy): + target_collection = table.relationships.get(attr.target_collection) + if target_collection: + target_class = target_collection.entity.class_ + target_attr = getattr(target_class, attr.value_attr) + if target_attr: + class_ = target_attr.entity.class_ + annotation = class_ if not target_collection.uselist else List[class_] # type: ignore[valid-type] + fields_meta.append( + FieldMeta.from_type( + name=name, + annotation=annotation, + random=cls.__random__, + ) + ) return fields_meta diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py index 9d47366e..5c8084ce 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py @@ -18,6 +18,7 @@ func, inspect, orm, + select, text, types, ) @@ -343,13 +344,15 @@ class Factory(SQLAlchemyFactory[AsyncModel]): __async_session__ = session_config(session) __model__ = AsyncModel - result = await Factory.create_async() - assert inspect(result).persistent # type: ignore[union-attr] + instance = await Factory.create_async() + result = await session.scalar(select(AsyncModel).where(AsyncModel.id == instance.id)) + assert result batch_result = await Factory.create_batch_async(size=2) assert len(batch_result) == 2 for batch_item in batch_result: - assert inspect(batch_item).persistent # type: ignore[union-attr] + result = await session.scalar(select(AsyncModel).where(AsyncModel.id == batch_item.id)) + assert result @pytest.mark.parametrize( @@ -392,8 +395,9 @@ class Factory(SQLAlchemyFactory[AsyncRefreshModel]): test_int = Ignore() test_bool = Ignore() - result = await Factory.create_async() - assert inspect(result).persistent # type: ignore[union-attr] + instance = await Factory.create_async() + result = await session.scalar(select(AsyncRefreshModel).where(AsyncRefreshModel.id == instance.id)) + assert result assert result.test_datetime is not None assert isinstance(result.test_datetime, datetime) assert result.test_str == "test_str"