Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sqlalchemy-factory): Add support for SQLAlchemy custom types #398

Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import datetime
from typing import Any

from sqlalchemy import DateTime, types
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory


class TZAwareDateTime(types.TypeDecorator):
    """Custom column type wrapping a timezone-aware ``DateTime``.

    The wrapped SQLAlchemy type is declared through ``impl``; the factory
    infers the Python type to generate from that ``impl``.
    """

    impl = DateTime(timezone=True)
    # This decorator carries no per-instance state of its own, so it is safe
    # to take part in SQLAlchemy's statement cache. Leaving ``cache_ok`` unset
    # makes SQLAlchemy emit a warning and skip caching for statements that
    # use this type.
    cache_ok = True


class Base(DeclarativeBase):
    """Declarative base class for the models defined in this example."""


class Author(Base):
    """Example model whose ``first_publication_at`` column uses ``TZAwareDateTime``."""

    __tablename__ = "authors"

    # Integer primary key.
    id: Mapped[int] = mapped_column(primary_key=True)
    # Annotated as ``Any``; the column type is given explicitly via ``type_``.
    first_publication_at: Mapped[Any] = mapped_column(
        type_=TZAwareDateTime(),
        nullable=False,
    )


class AuthorFactory(SQLAlchemyFactory[Author]):
    """Factory that builds ``Author`` instances."""

    __model__ = Author


def test_sqla_type_decorator_custom_type_factory() -> None:
    """Check that the ``TypeDecorator``-typed column is built as a ``datetime``."""
    built_author = AuthorFactory.build()
    publication_moment = built_author.first_publication_at
    assert isinstance(publication_moment, datetime.datetime)
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import datetime
from typing import Any, Callable, Dict

from sqlalchemy import types
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory


class CustomType(types.UserDefinedType):
    """Bare user-defined column type with no behaviour of its own."""


class Base(DeclarativeBase):
    """Declarative base class for the models defined in this example."""


class Author(Base):
    """Example model whose ``first_publication_at`` column uses ``CustomType``."""

    __tablename__ = "authors"

    # Integer primary key.
    id: Mapped[int] = mapped_column(primary_key=True)
    # Annotated as ``Any``; the column type is given explicitly via ``type_``.
    first_publication_at: Mapped[Any] = mapped_column(type_=CustomType())


class AuthorFactory(SQLAlchemyFactory[Author]):
    """Factory that builds ``Author`` instances, with a provider for ``CustomType``."""

    __model__ = Author

    @classmethod
    def get_sqlalchemy_types(cls) -> Dict[Any, Callable[[], Any]]:
        """Return the inherited type-to-provider mapping extended with ``CustomType``.

        Returns:
            A new dictionary holding every entry from the parent mapping plus
            a ``CustomType`` entry that produces a Faker ``date_time()`` value.
        """
        # Copy first so the mapping returned by the parent is never mutated.
        providers: Dict[Any, Callable[[], Any]] = dict(super().get_sqlalchemy_types())
        providers[CustomType] = lambda: cls.__faker__.date_time()
        return providers


def test_sqla_user_defined_type_custom_type_factory() -> None:
    """Check that the ``CustomType`` column is built as a ``datetime``."""
    built_author = AuthorFactory.build()
    publication_moment = built_author.first_publication_at
    assert isinstance(publication_moment, datetime.datetime)
23 changes: 23 additions & 0 deletions docs/usage/library_factories/sqlalchemy_factory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,29 @@ By default, this will add generated models to the session and then commit. This

Similarly for ``__async_session__`` and ``create_async``.

SQLAlchemy custom types
------------------------------

TypeDecorator based
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The type will be inferred from the SQLAlchemy type defined in ``impl``.

.. literalinclude:: /examples/library_factories/sqlalchemy_factory/test_example_4.py
:caption: Custom types with TypeDecorator
:language: python


UserDefinedType based
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The ``get_sqlalchemy_types`` classmethod needs to be overridden to provide a factory function for the custom type.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be changed so that the user should override the get_provider_map method. That makes it more consistent with how the other factories handle custom user defined types.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This provides the mapping from sqlalchemy types to provider where this is directly mapped from column vs where the implementation type could be used.

I think these could be unified, but that may require some reworking so that the logic of handling types is exposed separately. There wasn't a clear way to map things like an Integer column without repeating the int handling. This case is alright to handle, but it becomes more complex, especially when considering things like list[str].

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This provides the mapping from sqlalchemy types to provider where this is directly mapped from column vs where the implementation type could be used.

I think these could be unified, but that may require some reworking so that the logic of handling types is exposed separately. There wasn't a clear way to map things like an Integer column without repeating the int handling. This case is alright to handle, but it becomes more complex, especially when considering things like list[str].

I agree that we can keep a separate get_sqlalchemy_types and a get_provider_map so that we can deal with sql types in an easier manner. However, since this is a user defined type, should it not go into the get_provider_map which is where the user defined types are supposed to be given for the other factories?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not with the current implementation.

get_type_from_column only considers the output of get_sqlalchemy_types here; for SQLA types it should not map for providers. This could be changed to just use get_provider_map, to bring it in line with the other factories.

I don't think this is a backwards-breaking change, though one very minor concern is that it may be harder when you want to provide a mapping for a field whose type is Column itself. This is minor, as it feels like a fairly obscure use case.


.. literalinclude:: /examples/library_factories/sqlalchemy_factory/test_example_5.py
:caption: Custom types with UserDefinedType
:language: python

More info on `SQLAlchemy Custom Types <https://docs.sqlalchemy.org/en/20/core/custom_types.html>`_.

API reference
------------------------------
Expand Down
13 changes: 12 additions & 1 deletion polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, List, TypeVar, Union

from polyfactory.exceptions import MissingDependencyException
from polyfactory.exceptions import MissingDependencyException, ParameterException
from polyfactory.factories.base import BaseFactory
from polyfactory.field_meta import FieldMeta
from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol
Expand Down Expand Up @@ -118,6 +118,17 @@ def get_type_from_column(cls, column: Column) -> type:
annotation = column_type
elif issubclass(column_type, types.ARRAY):
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
elif issubclass(column_type, types.TypeDecorator) and isinstance(
python_type := column_type.impl.python_type,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the test for python_type being an instance of type needed? Also, what happens if there is no python_type implemented for the column_type.impl? For example, it could be set to postgresql.CIDR or types.ARRAY.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Essentially, I think the parsing of the annotation from the column type needs to be recursive in the case of types.TypeDecorator.

type,
):
annotation = python_type
elif issubclass(column_type, types.UserDefinedType):
parameter_exc_msg = (
f"User defined type detected (subclass of {types.UserDefinedType}). "
"Override get_sqlalchemy_types to provide factory function."
)
raise ParameterException(parameter_exc_msg)
Comment on lines +126 to +131
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to throw the exception here itself. If the provider map doesn't have the corresponding factory function, then the exception will be thrown from the get_field_value method. Instead, this can just return the column type directly if the column type is a subclass of UserDefinedType.

else:
annotation = column.type.python_type

Expand Down
98 changes: 96 additions & 2 deletions tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import datetime
from enum import Enum
from typing import Any, List
from typing import Any, Callable, Dict, List

import pytest
from sqlalchemy import ForeignKey, __version__, orm, types
from sqlalchemy import ForeignKey, __version__, orm, sql, types

from polyfactory.exceptions import ParameterException
from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory

if __version__.startswith("1"):
Expand Down Expand Up @@ -85,3 +87,95 @@ class ModelFactory(SQLAlchemyFactory[Model]):

instance = ModelFactory.build()
assert instance.overridden is not None


@pytest.mark.parametrize(
Copy link
Collaborator

@adhtruong adhtruong Oct 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this applicable to SQLA 1.4? If so, I think these tests should be moved to _common and adjusted so that they are tested against both 1.4 and 2.

"impl_",
(
sql.sqltypes.BigInteger(),
sql.sqltypes.Boolean(),
sql.sqltypes.Date(),
sql.sqltypes.DateTime(),
sql.sqltypes.Double(),
sql.sqltypes.Enum(),
sql.sqltypes.Float(),
sql.sqltypes.Integer(),
sql.sqltypes.Interval(),
sql.sqltypes.LargeBinary(),
sql.sqltypes.MatchType(),
sql.sqltypes.Numeric(),
sql.sqltypes.SmallInteger(),
sql.sqltypes.String(),
sql.sqltypes.Text(),
sql.sqltypes.Time(),
sql.sqltypes.Unicode(), # type: ignore[no-untyped-call]
sql.sqltypes.UnicodeText(), # type: ignore[no-untyped-call]
sql.sqltypes.Uuid(),
),
)
def test_sqlalchemy_custom_type_from_type_decorator(impl_: types.TypeEngine) -> None:
    """Check that ``TypeDecorator`` columns are built with the ``impl``'s Python type.

    Args:
        impl_: SQLAlchemy type instance used as the ``impl`` of the decorator
            under test.
    """

    class CustomType(types.TypeDecorator):
        impl = impl_

    class Base(orm.DeclarativeBase):
        # Lets the ``Mapped[object]`` annotation below resolve to ``CustomType``.
        type_annotation_map = {object: CustomType}

    class Model(Base):
        __tablename__ = "model_with_custom_types"

        id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
        # Custom type passed explicitly through ``type_``.
        custom_type: orm.Mapped[Any] = orm.mapped_column(
            type_=CustomType(),
            nullable=False,
        )
        # Custom type picked up through ``type_annotation_map``.
        custom_type_from_annotation_map: orm.Mapped[object]

    class ModelFactory(SQLAlchemyFactory[Model]):
        __model__ = Model

    built = ModelFactory.build()
    expected_type = impl_.python_type

    assert isinstance(built.id, int)
    assert isinstance(built.custom_type, expected_type)
    assert isinstance(built.custom_type_from_annotation_map, expected_type)


def test_sqlalchemy_custom_type_from_user_defined_type__overridden() -> None:
    """Check that a provider registered for a ``UserDefinedType`` is used when building."""

    class CustomType(types.UserDefinedType):
        """Bare user-defined type with no behaviour of its own."""

    class Base(orm.DeclarativeBase):
        """Declarative base local to this test."""

    class Model(Base):
        __tablename__ = "model_with_custom_types"

        id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
        custom_type: orm.Mapped[Any] = orm.mapped_column(type_=CustomType())

    class ModelFactory(SQLAlchemyFactory[Model]):
        __model__ = Model

        @classmethod
        def get_sqlalchemy_types(cls) -> Dict[Any, Callable[[], Any]]:
            # Copy the inherited mapping, then add the provider for CustomType.
            providers: Dict[Any, Callable[[], Any]] = dict(super().get_sqlalchemy_types())
            providers[CustomType] = lambda: cls.__faker__.date_time()
            return providers

    built = ModelFactory.build()

    assert isinstance(built.id, int)
    assert isinstance(built.custom_type, datetime.datetime)


def test_sqlalchemy_custom_type_from_user_defined_type__type_not_supported() -> None:
    """Check that building fails when no provider is registered for a ``UserDefinedType``."""

    class CustomType(types.UserDefinedType):
        """Bare user-defined type with no behaviour of its own."""

    class Base(orm.DeclarativeBase):
        """Declarative base local to this test."""

    class Model(Base):
        __tablename__ = "model_with_custom_types"

        id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
        custom_type: orm.Mapped[Any] = orm.mapped_column(type_=CustomType())

    class ModelFactory(SQLAlchemyFactory[Model]):
        __model__ = Model

    # The factory does not override get_sqlalchemy_types, so build() must raise.
    expected_failure = pytest.raises(ParameterException, match="User defined type detected")
    with expected_failure:
        ModelFactory.build()