Skip to content

Commit

Permalink
fix(sqla_factory): fix json type error and pg dialect default value e… (
Browse files Browse the repository at this point in the history
#542)

Co-authored-by: guacs <[email protected]>
  • Loading branch information
wangxin688 and guacs authored May 12, 2024
1 parent 2f781ee commit 9e6edab
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
21 changes: 13 additions & 8 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

try:
from sqlalchemy import Column, inspect, types
from sqlalchemy.dialects import mysql, postgresql
from sqlalchemy.dialects import mssql, mysql, postgresql, sqlite
from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.orm import InstanceState, Mapper
except ImportError as e:
Expand Down Expand Up @@ -85,22 +85,28 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]):

@classmethod
def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]:
"""Get mapping of types where column type."""
"""Get mapping of types where column type.
for sqlalchemy dialect `JSON` type, accepted only basic types in pydict in case sqlalchemy process `JSON` raise serialize error.
"""
return {
types.TupleType: cls.__faker__.pytuple,
mssql.JSON: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)),
mysql.YEAR: lambda: cls.__random__.randint(1901, 2155),
postgresql.CIDR: lambda: cls.__faker__.ipv4(network=False),
mysql.JSON: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)),
postgresql.CIDR: lambda: cls.__faker__.ipv4(network=True),
postgresql.DATERANGE: lambda: (cls.__faker__.past_date(), date.today()), # noqa: DTZ011
postgresql.INET: lambda: cls.__faker__.ipv4(network=True),
postgresql.INET: lambda: cls.__faker__.ipv4(network=False),
postgresql.INT4RANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])),
postgresql.INT8RANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])),
postgresql.MACADDR: lambda: cls.__faker__.hexify(text="^^:^^:^^:^^:^^:^^", upper=True),
postgresql.NUMRANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])),
postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005
postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005
postgresql.HSTORE: lambda: cls.__faker__.pydict(),
# `types.JSON` is compatible for sqlachemy extend dialects. Such as `pg.JSON` and `JSONB`
types.JSON: lambda: cls.__faker__.pydict(),
postgresql.HSTORE: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)),
postgresql.JSON: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)),
postgresql.JSONB: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)),
sqlite.JSON: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)),
types.JSON: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)),
}

@classmethod
Expand Down Expand Up @@ -148,7 +154,6 @@ def get_type_from_column(cls, column: Column) -> type:

if column.nullable:
annotation = Union[annotation, None] # type: ignore[assignment]

return annotation

@classmethod
Expand Down
14 changes: 13 additions & 1 deletion tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class ModelFactory(SQLAlchemyFactory[Model]):
assert isinstance(instance.str_array_type[0], str)


def test_pg_dialect_types() -> None:
def test_sqla_dialect_types() -> None:
class Base(orm.DeclarativeBase): ...

class SqlaModel(Base):
Expand Down Expand Up @@ -115,11 +115,23 @@ class ModelFactory(SQLAlchemyFactory[SqlaModel]):
assert isinstance(instance.mut_nested_arry_inet[0], str)
assert ip_network(instance.mut_nested_arry_inet[0])
assert isinstance(instance.pg_json_type, dict)
for value in instance.pg_json_type.values():
assert isinstance(value, (str, int, bool, float))
assert isinstance(instance.pg_jsonb_type, dict)
for value in instance.pg_jsonb_type.values():
assert isinstance(value, (str, int, bool, float))
assert isinstance(instance.common_json_type, dict)
for value in instance.common_json_type.values():
assert isinstance(value, (str, int, bool, float))
assert isinstance(instance.mysql_json, dict)
for value in instance.mysql_json.values():
assert isinstance(value, (str, int, bool, float))
assert isinstance(instance.sqlite_json, dict)
for value in instance.sqlite_json.values():
assert isinstance(value, (str, int, bool, float))
assert isinstance(instance.mssql_json, dict)
for value in instance.mssql_json.values():
assert isinstance(value, (str, int, bool, float))
assert isinstance(instance.multible_pg_json_type, dict)
assert isinstance(instance.multible_pg_jsonb_type, dict)
assert isinstance(instance.multible_common_json_type, dict)
Expand Down

0 comments on commit 9e6edab

Please sign in to comment.