Skip to content

Commit

Permalink
feat: handle SQLA column constraints (#594)
Browse files Browse the repository at this point in the history
  • Loading branch information
adhtruong authored Nov 6, 2024
1 parent 8d8f6a9 commit 6abb845
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 22 deletions.
68 changes: 50 additions & 18 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
from datetime import date, datetime
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, List, TypeVar, Union

from sqlalchemy import ARRAY, Numeric, String
from typing_extensions import Annotated

from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories.base import BaseFactory
from polyfactory.field_meta import FieldMeta
from polyfactory.field_meta import Constraints, FieldMeta
from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol

try:
Expand All @@ -20,6 +23,7 @@
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from sqlalchemy.sql.type_api import TypeEngine
from typing_extensions import TypeGuard


Expand Down Expand Up @@ -85,8 +89,9 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]):

@classmethod
def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]:
"""Get mapping of types where column type.
for sqlalchemy dialect `JSON` type, accepted only basic types in pydict in case sqlalchemy process `JSON` raise serialize error.
"""Get mapping of types where column type should be used directly.
For sqlalchemy dialect `JSON` type, accepted only basic types in pydict in case sqlalchemy process `JSON` raise serialize error.
"""
return {
types.TupleType: cls.__faker__.pytuple,
Expand All @@ -109,6 +114,19 @@ def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]:
types.JSON: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)),
}

@classmethod
def get_sqlalchemy_constraints(cls) -> dict[type[TypeEngine], dict[str, str]]:
"""Get mapping of SQLA type engine to attribute to constraints key."""
return {
String: {
"length": "max_length",
},
Numeric: {
"precision": "max_digits",
"scale": "decimal_places",
},
}

@classmethod
def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
providers_map = super().get_provider_map()
Expand All @@ -133,27 +151,41 @@ def should_column_be_set(cls, column: Any) -> bool:

return bool(cls.__set_foreign_keys__ or not column.foreign_keys)

@classmethod
def _get_type_from_type_engine(cls, type_engine: TypeEngine) -> type:
if type(type_engine) in cls.get_sqlalchemy_types():
return type(type_engine)

annotation: type
try:
annotation = type_engine.python_type
except NotImplementedError:
annotation = type_engine.impl.python_type # type: ignore[attr-defined]

constraints: Constraints = {}
for type_, constraint_fields in cls.get_sqlalchemy_constraints().items():
if not isinstance(type_engine, type_):
continue
for sqlalchemy_field, constraint_field in constraint_fields.items():
if (value := getattr(type_engine, sqlalchemy_field, None)) is not None:
constraints[constraint_field] = value # type: ignore[literal-required]
if constraints:
annotation = Annotated[annotation, constraints] # type: ignore[assignment]

return annotation

@classmethod
def get_type_from_column(cls, column: Column) -> type:
column_type = type(column.type)
sqla_types = cls.get_sqlalchemy_types()
if column_type in sqla_types:
annotation = column_type
elif issubclass(column_type, postgresql.ARRAY):
if type(column.type.item_type) in sqla_types: # type: ignore[attr-defined]
annotation = List[type(column.type.item_type)] # type: ignore[attr-defined,misc,assignment]
else:
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
elif issubclass(column_type, types.ARRAY):
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
annotation: type
if isinstance(column.type, (ARRAY, postgresql.ARRAY)):
item_type = cls._get_type_from_type_engine(column.type.item_type)
annotation = List[item_type] # type: ignore[valid-type]
else:
try:
annotation = column.type.python_type
except NotImplementedError:
annotation = column.type.impl.python_type # type: ignore[attr-defined]
annotation = cls._get_type_from_type_engine(column.type)

if column.nullable:
annotation = Union[annotation, None] # type: ignore[assignment]

return annotation

@classmethod
Expand Down
40 changes: 36 additions & 4 deletions tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
from enum import Enum
from typing import Any, Callable, Type, Union
from typing import Any, Callable, Type
from uuid import UUID

import pytest
Expand All @@ -11,6 +12,7 @@
DateTime,
ForeignKey,
Integer,
Numeric,
String,
create_engine,
func,
Expand Down Expand Up @@ -247,8 +249,7 @@ class AuthorFactory(SQLAlchemyFactory[Author]):
assert isinstance(result.books[0], Book)


def test_sqla_factory_create() -> None:
engine = create_engine("sqlite:///:memory:")
def test_sqla_factory_create(engine: Engine) -> None:
Base.metadata.create_all(engine)

class OverridenSQLAlchemyFactory(SQLAlchemyFactory):
Expand Down Expand Up @@ -415,7 +416,7 @@ class ModelFactory(SQLAlchemyFactory[ModelWithAlias]):


@pytest.mark.parametrize("python_type_", (UUID, None))
def test_sqlalchemy_custom_type_from_type_decorator(python_type_: Union[type, None]) -> None:
def test_sqlalchemy_custom_type_from_type_decorator(python_type_: type) -> None:
class CustomType(types.TypeDecorator):
impl = types.CHAR(32)
cache_ok = True
Expand Down Expand Up @@ -446,3 +447,34 @@ class ModelFactory(SQLAlchemyFactory[Model]):

expected_type = python_type_ if python_type_ is not None else CustomType.impl.python_type
assert isinstance(instance.custom_type, expected_type)


def test_constrained_types() -> None:
_registry = registry()

class Base(metaclass=DeclarativeMeta):
__abstract__ = True
__allow_unmapped__ = True

registry = _registry
metadata = _registry.metadata

class Model(Base):
__tablename__ = "constrained_model"

id: Any = Column(Integer(), primary_key=True)
constrained_string: Any = Column(String(length=1), nullable=False)
constrainted_number: Any = Column(
Numeric(precision=2, scale=1),
nullable=False,
)

class ModelFactory(SQLAlchemyFactory[Model]):
__model__ = Model

instance = ModelFactory.build()
assert len(instance.constrained_string) <= 1

constrained_number: Decimal = instance.constrainted_number
assert isinstance(constrained_number, Decimal)
assert abs(len(constrained_number.as_tuple().digits) - abs(int(constrained_number.as_tuple().exponent))) <= 2

0 comments on commit 6abb845

Please sign in to comment.