diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index 4b62c8f9..f9897dd8 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -50,11 +50,14 @@ def __init__(self, session: AsyncSession) -> None: async def save(self, data: T) -> T: self.session.add(data) await self.session.commit() + await self.session.refresh(data) return data async def save_many(self, data: list[T]) -> list[T]: self.session.add_all(data) await self.session.commit() + for batch_item in data: + await self.session.refresh(batch_item) return data diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py index 78968e8a..2442918a 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py @@ -1,10 +1,24 @@ from dataclasses import dataclass +from datetime import datetime from enum import Enum from typing import Any, Callable, Type, Union from uuid import UUID import pytest -from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, inspect, orm, types +from sqlalchemy import ( + Boolean, + Column, + DateTime, + ForeignKey, + Integer, + String, + create_engine, + func, + inspect, + orm, + text, + types, +) from sqlalchemy.engine import Engine from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from sqlalchemy.ext.hybrid import hybrid_property @@ -14,6 +28,7 @@ from polyfactory.exceptions import ConfigurationException from polyfactory.factories.base import BaseFactory from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory +from polyfactory.fields import Ignore @pytest.fixture() @@ -336,6 +351,55 @@ class Factory(SQLAlchemyFactory[AsyncModel]): assert inspect(batch_item).persistent # type: ignore[union-attr] +@pytest.mark.parametrize( + "session_config", + ( + lambda session: session, + lambda session: (lambda: session), + ), +) +async def test_async_server_default_refresh( + async_engine: AsyncEngine, + session_config: Callable[[AsyncSession], Any], +) -> None: + _registry = registry() + + class Base(metaclass=DeclarativeMeta): + __abstract__ = True + + registry = _registry + metadata = _registry.metadata + + class AsyncRefreshModel(Base): + __tablename__ = "server_default_test" + + id: Any = Column(Integer(), primary_key=True) + test_datetime: Any = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + test_str: Any = Column(String, nullable=False, server_default=text("test_str")) + test_int: Any = Column(Integer, nullable=False, server_default=text("123")) + test_bool: Any = Column(Boolean, nullable=False, server_default=text("False")) + + await create_tables(async_engine, Base) + + async with AsyncSession(async_engine) as session: + + class Factory(SQLAlchemyFactory[AsyncRefreshModel]): + __async_session__ = session_config(session) + __model__ = AsyncRefreshModel + test_datetime = Ignore() + test_str = Ignore() + test_int = Ignore() + test_bool = Ignore() + + result = await Factory.create_async() + assert inspect(result).persistent # type: ignore[union-attr] + assert result.test_datetime is not None + assert isinstance(result.test_datetime, datetime) + assert result.test_str == "test_str" + assert result.test_int == 123 + assert result.test_bool is False + + def test_alias() -> None: class ModelWithAlias(Base): __tablename__ = "table"