Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sqlfactory): support nested type in pg.array types and others #530

Merged
merged 12 commits into from
May 9, 2024
Merged
8 changes: 8 additions & 0 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]:
postgresql.NUMRANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])),
postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005
postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005
postgresql.HSTORE: lambda: cls.__faker__.pydict(),
wangxin688 marked this conversation as resolved.
Show resolved Hide resolved
postgresql.JSON: lambda: cls.__faker__.pydict(),
postgresql.JSONB: lambda: cls.__faker__.pydict(),
}

@classmethod
Expand Down Expand Up @@ -126,6 +129,11 @@ def get_type_from_column(cls, column: Column) -> type:
column_type = type(column.type)
if column_type in cls.get_sqlalchemy_types():
annotation = column_type
elif issubclass(column_type, postgresql.ARRAY):
if type(column.type.item_type) in cls.get_sqlalchemy_types(): # type: ignore[attr-defined]
wangxin688 marked this conversation as resolved.
Show resolved Hide resolved
annotation = List[type(column.type.item_type)] # type: ignore # noqa: PGH003
wangxin688 marked this conversation as resolved.
Show resolved Hide resolved
else:
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
elif issubclass(column_type, types.ARRAY):
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
else:
Expand Down
44 changes: 42 additions & 2 deletions tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from enum import Enum
from typing import Any, List
from ipaddress import ip_network
from typing import Any, Dict, List
from uuid import UUID

import pytest
from sqlalchemy import ForeignKey, __version__, orm, types
from sqlalchemy import ForeignKey, Text, __version__, orm, types
from sqlalchemy.dialects.postgresql import ARRAY, CIDR, HSTORE, INET, JSON, JSONB
from sqlalchemy.ext.mutable import MutableDict, MutableList

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory

Expand Down Expand Up @@ -64,6 +68,42 @@ class ModelFactory(SQLAlchemyFactory[Model]):
assert isinstance(instance.str_array_type[0], str)


def test_pg_dialect_types() -> None:
class Base(orm.DeclarativeBase): ...

class PgModel(Base):
__tablename__ = "pgmodel"
id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
uuid_type: orm.Mapped[UUID] = orm.mapped_column(type_=types.UUID)
nested_array_inet: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(INET, dimensions=1))
nested_array_cidr: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(CIDR, dimensions=1))
hstore_type: orm.Mapped[Dict] = orm.mapped_column(type_=HSTORE)
pg_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSON)
pg_jsonb_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSONB)
mut_nested_arry_inet: orm.Mapped[List[str]] = orm.mapped_column(
type_=MutableList.as_mutable(ARRAY(INET, dimensions=1))
)
# ignore mypy type check: it's a known issue: https://github.com/sqlalchemy/sqlalchemy/discussions/9203
mut_pg_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(JSON(astext_type=Text()))) # type: ignore[no-untyped-call]

class ModelFactory(SQLAlchemyFactory[PgModel]):
__model__ = PgModel

instance = ModelFactory.build()

assert isinstance(instance.nested_array_inet[0], str)
assert ip_network(instance.nested_array_inet[0])
assert isinstance(instance.nested_array_cidr[0], str)
assert ip_network(instance.nested_array_cidr[0])
assert isinstance(instance.hstore_type, dict)
assert isinstance(instance.pg_json_type, dict)
assert isinstance(instance.pg_json_type, dict)
assert isinstance(instance.uuid_type, UUID)
assert isinstance(instance.mut_nested_arry_inet[0], str)
assert ip_network(instance.mut_nested_arry_inet[0])
assert isinstance(instance.mut_pg_json_type, dict)


@pytest.mark.parametrize(
"type_",
tuple(SQLAlchemyFactory.get_sqlalchemy_types().keys()),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_recursive_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_recursive_pydantic_models(factory_use_construct: bool) -> None:
factory = ModelFactory.create_factory(PydanticNode)

result = factory.build(factory_use_construct)
assert result.child is _Sentinel, "Default is not used"
assert result.child is _Sentinel, "Default is not used" # type: ignore[comparison-overlap]
assert isinstance(result.union_child, int)
assert result.optional_child is None
assert result.list_child == []
Expand Down
Loading