-
-
Notifications
You must be signed in to change notification settings - Fork 88
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(sqlalchemy-factory): Add support for SQLAlchemy custom types #398
Changes from all commits
cef52d6
7c430ac
ed1d3b8
0c0f05c
c6bf43c
16d0c93
19918fa
9348488
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import datetime | ||
from typing import Any | ||
|
||
from sqlalchemy import DateTime, types | ||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column | ||
|
||
from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory | ||
|
||
|
||
class TZAwareDateTime(types.TypeDecorator): | ||
impl = DateTime(timezone=True) | ||
|
||
|
||
class Base(DeclarativeBase): | ||
... | ||
|
||
|
||
class Author(Base): | ||
__tablename__ = "authors" | ||
|
||
id: Mapped[int] = mapped_column(primary_key=True) | ||
first_publication_at: Mapped[Any] = mapped_column(type_=TZAwareDateTime(), nullable=False) | ||
|
||
|
||
class AuthorFactory(SQLAlchemyFactory[Author]): | ||
__model__ = Author | ||
|
||
|
||
def test_sqla_type_decorator_custom_type_factory() -> None: | ||
author = AuthorFactory.build() | ||
assert isinstance(author.first_publication_at, datetime.datetime) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import datetime | ||
from typing import Any, Callable, Dict | ||
|
||
from sqlalchemy import types | ||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column | ||
|
||
from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory | ||
|
||
|
||
class CustomType(types.UserDefinedType): | ||
... | ||
|
||
|
||
class Base(DeclarativeBase): | ||
... | ||
|
||
|
||
class Author(Base): | ||
__tablename__ = "authors" | ||
|
||
id: Mapped[int] = mapped_column(primary_key=True) | ||
first_publication_at: Mapped[Any] = mapped_column(type_=CustomType()) | ||
|
||
|
||
class AuthorFactory(SQLAlchemyFactory[Author]): | ||
__model__ = Author | ||
|
||
@classmethod | ||
def get_sqlalchemy_types(cls) -> Dict[Any, Callable[[], Any]]: | ||
return {**super().get_sqlalchemy_types(), CustomType: lambda: cls.__faker__.date_time()} | ||
|
||
|
||
def test_sqla_user_defined_type_custom_type_factory() -> None: | ||
author = AuthorFactory.build() | ||
assert isinstance(author.first_publication_at, datetime.datetime) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
from datetime import date, datetime | ||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, List, TypeVar, Union | ||
|
||
from polyfactory.exceptions import MissingDependencyException | ||
from polyfactory.exceptions import MissingDependencyException, ParameterException | ||
from polyfactory.factories.base import BaseFactory | ||
from polyfactory.field_meta import FieldMeta | ||
from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol | ||
|
@@ -118,6 +118,17 @@ def get_type_from_column(cls, column: Column) -> type: | |
annotation = column_type | ||
elif issubclass(column_type, types.ARRAY): | ||
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined] | ||
elif issubclass(column_type, types.TypeDecorator) and isinstance( | ||
python_type := column_type.impl.python_type, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is the test for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Essentially, I think the parsing of the annotation from the column type needs to be recursive in the case of |
||
type, | ||
): | ||
annotation = python_type | ||
elif issubclass(column_type, types.UserDefinedType): | ||
parameter_exc_msg = ( | ||
f"User defined type detected (subclass of {types.UserDefinedType}). " | ||
"Override get_sqlalchemy_types to provide factory function." | ||
) | ||
raise ParameterException(parameter_exc_msg) | ||
Comment on lines
+126
to
+131
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need to throw the exception here itself. If the provider map doesn't have the corresponding factory function, then the exception will be thrown from the |
||
else: | ||
annotation = column.type.python_type | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
import datetime | ||
from enum import Enum | ||
from typing import Any, List | ||
from typing import Any, Callable, Dict, List | ||
|
||
import pytest | ||
from sqlalchemy import ForeignKey, __version__, orm, types | ||
from sqlalchemy import ForeignKey, __version__, orm, sql, types | ||
|
||
from polyfactory.exceptions import ParameterException | ||
from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory | ||
|
||
if __version__.startswith("1"): | ||
|
@@ -85,3 +87,95 @@ class ModelFactory(SQLAlchemyFactory[Model]): | |
|
||
instance = ModelFactory.build() | ||
assert instance.overridden is not None | ||
|
||
|
||
@pytest.mark.parametrize( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this applicable to SQLA 1.4? If so, I think these tests should be moved to _common and adjusted to common so tested against both 1.4 and 2 |
||
"impl_", | ||
( | ||
sql.sqltypes.BigInteger(), | ||
sql.sqltypes.Boolean(), | ||
sql.sqltypes.Date(), | ||
sql.sqltypes.DateTime(), | ||
sql.sqltypes.Double(), | ||
sql.sqltypes.Enum(), | ||
sql.sqltypes.Float(), | ||
sql.sqltypes.Integer(), | ||
sql.sqltypes.Interval(), | ||
sql.sqltypes.LargeBinary(), | ||
sql.sqltypes.MatchType(), | ||
sql.sqltypes.Numeric(), | ||
sql.sqltypes.SmallInteger(), | ||
sql.sqltypes.String(), | ||
sql.sqltypes.Text(), | ||
sql.sqltypes.Time(), | ||
sql.sqltypes.Unicode(), # type: ignore[no-untyped-call] | ||
sql.sqltypes.UnicodeText(), # type: ignore[no-untyped-call] | ||
sql.sqltypes.Uuid(), | ||
), | ||
) | ||
def test_sqlalchemy_custom_type_from_type_decorator(impl_: types.TypeEngine) -> None: | ||
class CustomType(types.TypeDecorator): | ||
impl = impl_ | ||
|
||
class Base(orm.DeclarativeBase): | ||
type_annotation_map = {object: CustomType} | ||
|
||
class Model(Base): | ||
__tablename__ = "model_with_custom_types" | ||
|
||
id: orm.Mapped[int] = orm.mapped_column(primary_key=True) | ||
custom_type: orm.Mapped[Any] = orm.mapped_column(type_=CustomType(), nullable=False) | ||
custom_type_from_annotation_map: orm.Mapped[object] | ||
|
||
class ModelFactory(SQLAlchemyFactory[Model]): | ||
__model__ = Model | ||
|
||
instance = ModelFactory.build() | ||
assert isinstance(instance.id, int) | ||
assert isinstance(instance.custom_type, impl_.python_type) | ||
assert isinstance(instance.custom_type_from_annotation_map, impl_.python_type) | ||
|
||
|
||
def test_sqlalchemy_custom_type_from_user_defined_type__overridden() -> None: | ||
class CustomType(types.UserDefinedType): | ||
... | ||
|
||
class Base(orm.DeclarativeBase): | ||
... | ||
|
||
class Model(Base): | ||
__tablename__ = "model_with_custom_types" | ||
|
||
id: orm.Mapped[int] = orm.mapped_column(primary_key=True) | ||
custom_type: orm.Mapped[Any] = orm.mapped_column(type_=CustomType()) | ||
|
||
class ModelFactory(SQLAlchemyFactory[Model]): | ||
__model__ = Model | ||
|
||
@classmethod | ||
def get_sqlalchemy_types(cls) -> Dict[Any, Callable[[], Any]]: | ||
return {**super().get_sqlalchemy_types(), CustomType: lambda: cls.__faker__.date_time()} | ||
|
||
instance = ModelFactory.build() | ||
assert isinstance(instance.id, int) | ||
assert isinstance(instance.custom_type, datetime.datetime) | ||
|
||
|
||
def test_sqlalchemy_custom_type_from_user_defined_type__type_not_supported() -> None: | ||
class CustomType(types.UserDefinedType): | ||
... | ||
|
||
class Base(orm.DeclarativeBase): | ||
... | ||
|
||
class Model(Base): | ||
__tablename__ = "model_with_custom_types" | ||
|
||
id: orm.Mapped[int] = orm.mapped_column(primary_key=True) | ||
custom_type: orm.Mapped[Any] = orm.mapped_column(type_=CustomType()) | ||
|
||
class ModelFactory(SQLAlchemyFactory[Model]): | ||
__model__ = Model | ||
|
||
with pytest.raises(ParameterException, match="User defined type detected"): | ||
ModelFactory.build() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be changed so that the user should override the
get_provider_map
method. That makes it more consistent with how the other factories handle custom user defined types.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This provides the mapping from sqlalchemy types to provider where this is directly mapped from column vs where the implementation type could be used.
I think these could be unified but may require some reworking so the the logic of handling types in exposed separately. There wasn't a clear way to map things like
Integer
column without repeating int handling. This case is alright to handle but this becomes more complex especially when considering things likelist[str]
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that we can keep a separate
get_sqlalchemy_types
and aget_provider_map
so that we can deal withsql
types in an easier manner. However, since this is a user defined type, should it not go into theget_provider_map
which is where the user defined types are supposed to be given for the other factories?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not with the current implementation.
get_type_from_column
only considers the output ofget_sqlalchemy_types
here for SQLA types it should not map for providers. This could be changed to just usingget_provider_map
to bring inline with other factories.I don't think this is a backwards breaking change though one very minor concern it may be harder where want to provide a mapping for a field with type of Column itself. This is minor as feels a fairly obscure use case.