Skip to content

Commit

Permalink
fix: handle SQLA column constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
adhtruong committed Oct 24, 2024
1 parent 37a9894 commit 8950c95
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 21 deletions.
68 changes: 52 additions & 16 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from __future__ import annotations

from datetime import date, datetime
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, List, TypeVar, Union

from sqlalchemy import ARRAY
from typing_extensions import Annotated

from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories.base import BaseFactory
from polyfactory.field_meta import FieldMeta
from polyfactory.field_meta import Constraints, FieldMeta
from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol
from polyfactory.utils.predicates import is_safe_subclass

try:
from sqlalchemy import Column, inspect, types
Expand All @@ -20,6 +25,7 @@
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from sqlalchemy.sql.type_api import TypeEngine
from typing_extensions import TypeGuard


Expand Down Expand Up @@ -133,27 +139,57 @@ def should_column_be_set(cls, column: Any) -> bool:

return bool(cls.__set_foreign_keys__ or not column.foreign_keys)

@classmethod
def _get_python_type(cls, type_engine: TypeEngine) -> type:
if type(type_engine) in cls.get_sqlalchemy_types():
return type(type_engine)

try:
annotation = type_engine.python_type
except NotImplementedError:
annotation = type_engine.impl.python_type # type: ignore[attr-defined]

return annotation

@classmethod
def _get_column_type_constraints(cls) -> dict[type, dict[str, str]]:
return {
str: {
"max_length": "length",
},
Decimal: {
"max_digits": "precision",
"decimal_places": "scale",
},
}

@classmethod
def _set_column_constraints(cls, type_engine: TypeEngine, annotation: type) -> type:
constraints: Constraints = {}
for type_, constraint_fields in cls._get_column_type_constraints().items():
if not is_safe_subclass(annotation, type_):
continue
for constraint_field, sqlalchemy_field in constraint_fields.items():
if (value := getattr(type_engine, sqlalchemy_field, None)) is not None:
constraints[constraint_field] = value # type: ignore[literal-required]
if constraints:
annotation = Annotated[annotation, constraints] # type: ignore[assignment]

return annotation

@classmethod
def get_type_from_column(cls, column: Column) -> type:
column_type = type(column.type)
sqla_types = cls.get_sqlalchemy_types()
if column_type in sqla_types:
annotation = column_type
elif issubclass(column_type, postgresql.ARRAY):
if type(column.type.item_type) in sqla_types: # type: ignore[attr-defined]
annotation = List[type(column.type.item_type)] # type: ignore[attr-defined,misc,assignment]
else:
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
elif issubclass(column_type, types.ARRAY):
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
annotation: type
if isinstance(column.type, (ARRAY, postgresql.ARRAY)):
item_type = cls._get_python_type(column.type.item_type)
annotation = List[item_type] # type: ignore[valid-type]
else:
try:
annotation = column.type.python_type
except NotImplementedError:
annotation = column.type.impl.python_type # type: ignore[attr-defined]
annotation = cls._get_python_type(column.type)
annotation = cls._set_column_constraints(column.type, annotation)

if column.nullable:
annotation = Union[annotation, None] # type: ignore[assignment]

return annotation

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion polyfactory/value_generators/constrained_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ def handle_constrained_date(
elif lt:
end_date = lt - timedelta(days=1)

return faker.date_between(start_date=start_date, end_date=end_date)
return faker.date_between(start_date=start_date, end_date=end_date) # type: ignore[return-value]
40 changes: 36 additions & 4 deletions tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
from enum import Enum
from typing import Any, Callable, Type, Union
from typing import Any, Callable, Type
from uuid import UUID

import pytest
Expand All @@ -11,6 +12,7 @@
DateTime,
ForeignKey,
Integer,
Numeric,
String,
create_engine,
func,
Expand Down Expand Up @@ -247,8 +249,7 @@ class AuthorFactory(SQLAlchemyFactory[Author]):
assert isinstance(result.books[0], Book)


def test_sqla_factory_create() -> None:
engine = create_engine("sqlite:///:memory:")
def test_sqla_factory_create(engine: Engine) -> None:
Base.metadata.create_all(engine)

class OverridenSQLAlchemyFactory(SQLAlchemyFactory):
Expand Down Expand Up @@ -415,7 +416,7 @@ class ModelFactory(SQLAlchemyFactory[ModelWithAlias]):


@pytest.mark.parametrize("python_type_", (UUID, None))
def test_sqlalchemy_custom_type_from_type_decorator(python_type_: Union[type, None]) -> None:
def test_sqlalchemy_custom_type_from_type_decorator(python_type_: type) -> None:
class CustomType(types.TypeDecorator):
impl = types.CHAR(32)
cache_ok = True
Expand Down Expand Up @@ -446,3 +447,34 @@ class ModelFactory(SQLAlchemyFactory[Model]):

expected_type = python_type_ if python_type_ is not None else CustomType.impl.python_type
assert isinstance(instance.custom_type, expected_type)


def test_constrained_types() -> None:
_registry = registry()

class Base(metaclass=DeclarativeMeta):
__abstract__ = True
__allow_unmapped__ = True

registry = _registry
metadata = _registry.metadata

class Model(Base):
__tablename__ = "constrained_model"

id: Any = Column(Integer(), primary_key=True)
constrained_string: Any = Column(String(length=1), nullable=False)
constrainted_number: Any = Column(
Numeric(precision=2, scale=1),
nullable=False,
)

class ModelFactory(SQLAlchemyFactory[Model]):
__model__ = Model

instance = ModelFactory.build()
assert len(instance.constrained_string) <= 1

constrained_number: Decimal = instance.constrainted_number
assert isinstance(constrained_number, Decimal)
assert abs(len(constrained_number.as_tuple().digits) - abs(int(constrained_number.as_tuple().exponent))) <= 2

0 comments on commit 8950c95

Please sign in to comment.