From 77e12be2ca97520e33258bd38229eb5649a0cf13 Mon Sep 17 00:00:00 2001 From: jeffry <36665036+wangxin688@users.noreply.github.com> Date: Tue, 30 Apr 2024 03:20:54 +0000 Subject: [PATCH 01/10] feat(sqlafctory): support nested type in pg array types and other pg native types --- polyfactory/factories/sqlalchemy_factory.py | 8 ++++++ .../test_sqlalchemy_factory_v2.py | 28 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index 5a2155b2..9d44b1c0 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -95,6 +95,9 @@ def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]: postgresql.NUMRANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005 postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005 + postgresql.HSTORE: lambda: cls.__faker__.pydict(), + postgresql.JSON: lambda: cls.__faker__.pydict(), + postgresql.JSONB: lambda: cls.__faker__.pydict(), } @classmethod @@ -126,6 +129,11 @@ def get_type_from_column(cls, column: Column) -> type: column_type = type(column.type) if column_type in cls.get_sqlalchemy_types(): annotation = column_type + elif issubclass(column_type, postgresql.ARRAY): + if type(column.type.item_type) in cls.get_sqlalchemy_types(): # type: ignore[attr-defined] + annotation = List[type(column.type.item_type)] # type: ignore # noqa: PGH003 + else: + annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined] elif issubclass(column_type, types.ARRAY): annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined] else: diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py index ed293e77..a39d89fe 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py @@ -1,8 +1,10 @@ from enum import Enum +from ipaddress import ip_network from typing import Any, List import pytest from sqlalchemy import ForeignKey, __version__, orm, types +from sqlalchemy.dialects.postgresql import ARRAY, CIDR, HSTORE, INET, JSON, JSONB from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory @@ -64,6 +66,32 @@ class ModelFactory(SQLAlchemyFactory[Model]): assert isinstance(instance.str_array_type[0], str) +def test_pg_dialect_types() -> None: + class Base(orm.DeclarativeBase): ... + + class PgModel(Base): + __tablename__ = "pgmodel" + id: orm.Mapped[int] = orm.mapped_column(primary_key=True) + nested_array_inet: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(INET, dimensions=1)) + nested_array_cidr: orm.Mapped[list[str]] = orm.mapped_column(type_=ARRAY(CIDR, dimensions=1)) + hstore_type: orm.Mapped[dict] = orm.mapped_column(type_=HSTORE) + pg_json_type: orm.Mapped[dict] = orm.mapped_column(type_=JSON) + pg_jsonb_type: orm.Mapped[dict] = orm.mapped_column(type_=JSONB) + + class ModelFactory(SQLAlchemyFactory[PgModel]): + __model__ = PgModel + + instance = ModelFactory.build() + + assert isinstance(instance.nested_array_inet[0], str) + assert ip_network(instance.nested_array_inet[0]) + assert isinstance(instance.nested_array_cidr[0], str) + assert ip_network(instance.nested_array_cidr[0]) + assert isinstance(instance.hstore_type, dict) + assert isinstance(instance.pg_json_type, dict) + assert isinstance(instance.pg_json_type, dict) + + @pytest.mark.parametrize( "type_", tuple(SQLAlchemyFactory.get_sqlalchemy_types().keys()), From 4e8dd81f29acc5b7fcb08afc2546d4f1addf848c Mon Sep 17 00:00:00 2001 From: jeffry <36665036+wangxin688@users.noreply.github.com> Date: Tue, 30 Apr 2024 03:31:29 +0000 Subject: [PATCH 02/10] fix(type): fix type annation for lower python version --- tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py index a39d89fe..63e8cb5e 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py @@ -73,7 +73,7 @@ class PgModel(Base): __tablename__ = "pgmodel" id: orm.Mapped[int] = orm.mapped_column(primary_key=True) nested_array_inet: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(INET, dimensions=1)) - nested_array_cidr: orm.Mapped[list[str]] = orm.mapped_column(type_=ARRAY(CIDR, dimensions=1)) + nested_array_cidr: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(CIDR, dimensions=1)) hstore_type: orm.Mapped[dict] = orm.mapped_column(type_=HSTORE) pg_json_type: orm.Mapped[dict] = orm.mapped_column(type_=JSON) pg_jsonb_type: orm.Mapped[dict] = orm.mapped_column(type_=JSONB) From 3d9404891cdbf0db46cda6eeeb6e22d1f29982c1 Mon Sep 17 00:00:00 2001 From: jeffry <36665036+wangxin688@users.noreply.github.com> Date: Tue, 30 Apr 2024 03:33:37 +0000 Subject: [PATCH 03/10] fix(type): fix type annation for lower python version --- tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py index 63e8cb5e..ab33d227 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py @@ -1,6 +1,6 @@ from enum import Enum from ipaddress import ip_network -from typing import Any, List +from typing import Any, Dict, List import pytest from sqlalchemy import ForeignKey, __version__, orm, types @@ -74,9 +74,9 @@ class PgModel(Base): id: orm.Mapped[int] = orm.mapped_column(primary_key=True) nested_array_inet: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(INET, dimensions=1)) nested_array_cidr: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(CIDR, dimensions=1)) - hstore_type: orm.Mapped[dict] = orm.mapped_column(type_=HSTORE) - pg_json_type: orm.Mapped[dict] = orm.mapped_column(type_=JSON) - pg_jsonb_type: orm.Mapped[dict] = orm.mapped_column(type_=JSONB) + hstore_type: orm.Mapped[Dict] = orm.mapped_column(type_=HSTORE) + pg_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSON) + pg_jsonb_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSONB) class ModelFactory(SQLAlchemyFactory[PgModel]): __model__ = PgModel From 0c2241ec98ba67fe45e86a1a11765b28a0362868 Mon Sep 17 00:00:00 2001 From: jeffry <36665036+wangxin688@users.noreply.github.com> Date: Tue, 30 Apr 2024 03:44:01 +0000 Subject: [PATCH 04/10] fix(type): fix tests type error --- tests/test_recursive_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_recursive_models.py b/tests/test_recursive_models.py index b8a0669a..865ea1e8 100644 --- a/tests/test_recursive_models.py +++ b/tests/test_recursive_models.py @@ -53,7 +53,7 @@ def test_recursive_pydantic_models(factory_use_construct: bool) -> None: factory = ModelFactory.create_factory(PydanticNode) result = factory.build(factory_use_construct) - assert result.child is _Sentinel, "Default is not used" + assert result.child is _Sentinel, "Default is not used" # type: ignore[comparison-overlap] assert isinstance(result.union_child, int) assert result.optional_child is None assert result.list_child == [] From a76f23bce97e99eb682f33e05be95460cae89035 Mon Sep 17 00:00:00 2001 From: jeffry <36665036+wangxin688@users.noreply.github.com> Date: Tue, 30 Apr 2024 03:59:51 +0000 Subject: [PATCH 05/10] feat(tests): add coverage for more test case --- .../test_sqlalchemy_factory_v2.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py index ab33d227..88eb46c3 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py @@ -1,10 +1,12 @@ from enum import Enum from ipaddress import ip_network from typing import Any, Dict, List +from uuid import UUID import pytest -from sqlalchemy import ForeignKey, __version__, orm, types +from sqlalchemy import ForeignKey, Text, __version__, orm, types from sqlalchemy.dialects.postgresql import ARRAY, CIDR, HSTORE, INET, JSON, JSONB +from sqlalchemy.ext.mutable import MutableDict, MutableList from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory @@ -72,11 +74,17 @@ class Base(orm.DeclarativeBase): ... class PgModel(Base): __tablename__ = "pgmodel" id: orm.Mapped[int] = orm.mapped_column(primary_key=True) + uuid_type: orm.Mapped[UUID] = orm.mapped_column(type_=types.UUID) nested_array_inet: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(INET, dimensions=1)) nested_array_cidr: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(CIDR, dimensions=1)) hstore_type: orm.Mapped[Dict] = orm.mapped_column(type_=HSTORE) pg_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSON) pg_jsonb_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSONB) + mut_nested_arry_inet: orm.Mapped[List[str]] = orm.mapped_column( + type_=MutableList.as_mutable(ARRAY(INET, dimensions=1)) + ) + # ignore mypy type check: it's a known issue: https://github.com/sqlalchemy/sqlalchemy/discussions/9203 + mut_pg_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(JSON(astext_type=Text()))) # type: ignore[no-untyped-call] class ModelFactory(SQLAlchemyFactory[PgModel]): __model__ = PgModel @@ -90,6 +98,10 @@ class ModelFactory(SQLAlchemyFactory[PgModel]): assert isinstance(instance.hstore_type, dict) assert isinstance(instance.pg_json_type, dict) assert isinstance(instance.pg_json_type, dict) + assert isinstance(instance.uuid_type, UUID) + assert isinstance(instance.mut_nested_arry_inet[0], str) + assert ip_network(instance.mut_nested_arry_inet[0]) + assert isinstance(instance.mut_pg_json_type, dict) @pytest.mark.parametrize( From d9c6bf57bd444d8373802248e0db265c61d2aefa Mon Sep 17 00:00:00 2001 From: jeffry <36665036+wangxin688@users.noreply.github.com> Date: Tue, 30 Apr 2024 13:10:03 +0800 Subject: [PATCH 06/10] Update polyfactory/factories/sqlalchemy_factory.py Co-authored-by: Peter Schutt --- polyfactory/factories/sqlalchemy_factory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index 9d44b1c0..ab7ca057 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -127,10 +127,10 @@ def should_column_be_set(cls, column: Any) -> bool: @classmethod def get_type_from_column(cls, column: Column) -> type: column_type = type(column.type) - if column_type in cls.get_sqlalchemy_types(): + if column_type in (sqla_types := cls.get_sqlalchemy_types()): annotation = column_type elif issubclass(column_type, postgresql.ARRAY): - if type(column.type.item_type) in cls.get_sqlalchemy_types(): # type: ignore[attr-defined] + if type(column.type.item_type) in sqla_types: # type: ignore[attr-defined] annotation = List[type(column.type.item_type)] # type: ignore # noqa: PGH003 else: annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined] From 4bbac6a4a746226dd9816bb4fca157444fda3d35 Mon Sep 17 00:00:00 2001 From: jeffry <36665036+wangxin688@users.noreply.github.com> Date: Tue, 30 Apr 2024 05:12:00 +0000 Subject: [PATCH 07/10] fix(type): specifical mypy ignore error --- polyfactory/factories/sqlalchemy_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index ab7ca057..3c5d6346 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -131,7 +131,7 @@ def get_type_from_column(cls, column: Column) -> type: annotation = column_type elif issubclass(column_type, postgresql.ARRAY): if type(column.type.item_type) in sqla_types: # type: ignore[attr-defined] - annotation = List[type(column.type.item_type)] # type: ignore # noqa: PGH003 + annotation = List[type(column.type.item_type)] # type: ignore[attr-defined,misc] else: annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined] elif issubclass(column_type, types.ARRAY): From 0f7a3e4739322bdab38bf0d98c59968fa9de5716 Mon Sep 17 00:00:00 2001 From: jeffry <36665036+wangxin688@users.noreply.github.com> Date: Tue, 30 Apr 2024 10:15:57 +0000 Subject: [PATCH 08/10] fix(sqlafactory): remove json and jsonb due to uncertain python type --- polyfactory/factories/sqlalchemy_factory.py | 2 -- .../test_sqlalchemy_factory_v2.py | 13 +++---------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index 3c5d6346..8f789088 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -96,8 +96,6 @@ def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]: postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005 postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005 postgresql.HSTORE: lambda: cls.__faker__.pydict(), - postgresql.JSON: lambda: cls.__faker__.pydict(), - postgresql.JSONB: lambda: cls.__faker__.pydict(), } @classmethod diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py index 88eb46c3..941e658a 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py @@ -4,9 +4,9 @@ from uuid import UUID import pytest -from sqlalchemy import ForeignKey, Text, __version__, orm, types -from sqlalchemy.dialects.postgresql import ARRAY, CIDR, HSTORE, INET, JSON, JSONB -from sqlalchemy.ext.mutable import MutableDict, MutableList +from sqlalchemy import ForeignKey, __version__, orm, types +from sqlalchemy.dialects.postgresql import ARRAY, CIDR, HSTORE, INET +from sqlalchemy.ext.mutable import MutableList from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory @@ -78,13 +78,9 @@ class PgModel(Base): nested_array_inet: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(INET, dimensions=1)) nested_array_cidr: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(CIDR, dimensions=1)) hstore_type: orm.Mapped[Dict] = orm.mapped_column(type_=HSTORE) - pg_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSON) - pg_jsonb_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSONB) mut_nested_arry_inet: orm.Mapped[List[str]] = orm.mapped_column( type_=MutableList.as_mutable(ARRAY(INET, dimensions=1)) ) - # ignore mypy type check: it's a known issue: https://github.com/sqlalchemy/sqlalchemy/discussions/9203 - mut_pg_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(JSON(astext_type=Text()))) # type: ignore[no-untyped-call] class ModelFactory(SQLAlchemyFactory[PgModel]): __model__ = PgModel @@ -96,12 +92,9 @@ class ModelFactory(SQLAlchemyFactory[PgModel]): assert isinstance(instance.nested_array_cidr[0], str) assert ip_network(instance.nested_array_cidr[0]) assert isinstance(instance.hstore_type, dict) - assert isinstance(instance.pg_json_type, dict) - assert isinstance(instance.pg_json_type, dict) assert isinstance(instance.uuid_type, UUID) assert isinstance(instance.mut_nested_arry_inet[0], str) assert ip_network(instance.mut_nested_arry_inet[0]) - assert isinstance(instance.mut_pg_json_type, dict) @pytest.mark.parametrize( From edda6183b102d0c43aef1891feb885175bf99493 Mon Sep 17 00:00:00 2001 From: jeffry <36665036+wangxin688@users.noreply.github.com> Date: Sun, 5 May 2024 12:58:37 +0000 Subject: [PATCH 09/10] feat(sqlafactory): add json type support sqlafactory --- polyfactory/factories/sqlalchemy_factory.py | 4 ++- .../test_sqlalchemy_factory_v2.py | 29 +++++++++++++++---- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index 8f789088..b597d357 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -96,6 +96,7 @@ def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]: postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005 postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005 postgresql.HSTORE: lambda: cls.__faker__.pydict(), + types.JSON: lambda: cls.__faker__.pydict(), } @classmethod @@ -125,7 +126,8 @@ def should_column_be_set(cls, column: Any) -> bool: @classmethod def get_type_from_column(cls, column: Column) -> type: column_type = type(column.type) - if column_type in (sqla_types := cls.get_sqlalchemy_types()): + sqla_types = cls.get_sqlalchemy_types() + if column_type in sqla_types: annotation = column_type elif issubclass(column_type, postgresql.ARRAY): if type(column.type.item_type) in sqla_types: # type: ignore[attr-defined] diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py index 941e658a..5ea48f2a 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py @@ -5,7 +5,8 @@ import pytest from sqlalchemy import ForeignKey, __version__, orm, types -from sqlalchemy.dialects.postgresql import ARRAY, CIDR, HSTORE, INET +from sqlalchemy.dialects.mysql import JSON as MYSQL_JSON +from sqlalchemy.dialects.postgresql import ARRAY, CIDR, HSTORE, INET, JSON, JSONB from sqlalchemy.ext.mutable import MutableList from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory @@ -71,8 +72,8 @@ class ModelFactory(SQLAlchemyFactory[Model]): def test_pg_dialect_types() -> None: class Base(orm.DeclarativeBase): ... - class PgModel(Base): - __tablename__ = "pgmodel" + class SqlaModel(Base): + __tablename__ = "sql_models" id: orm.Mapped[int] = orm.mapped_column(primary_key=True) uuid_type: orm.Mapped[UUID] = orm.mapped_column(type_=types.UUID) nested_array_inet: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(INET, dimensions=1)) @@ -81,12 +82,20 @@ class PgModel(Base): mut_nested_arry_inet: orm.Mapped[List[str]] = orm.mapped_column( type_=MutableList.as_mutable(ARRAY(INET, dimensions=1)) ) + pg_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSON) + pg_jsonb_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSONB) + common_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=types.JSON) + mysql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MYSQL_JSON) - class ModelFactory(SQLAlchemyFactory[PgModel]): - __model__ = PgModel + multible_pg_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSON) + multible_pg_jsonb_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSONB) + multible_common_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=types.JSON) + multible_mysql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MYSQL_JSON) - instance = ModelFactory.build() + class ModelFactory(SQLAlchemyFactory[SqlaModel]): + __model__ = SqlaModel + instance = ModelFactory.build() assert isinstance(instance.nested_array_inet[0], str) assert ip_network(instance.nested_array_inet[0]) assert isinstance(instance.nested_array_cidr[0], str) @@ -95,6 +104,14 @@ class ModelFactory(SQLAlchemyFactory[PgModel]): assert isinstance(instance.uuid_type, UUID) assert isinstance(instance.mut_nested_arry_inet[0], str) assert ip_network(instance.mut_nested_arry_inet[0]) + assert isinstance(instance.pg_json_type, dict) + assert isinstance(instance.pg_jsonb_type, dict) + assert isinstance(instance.common_json_type, dict) + assert isinstance(instance.mysql_json, dict) + assert isinstance(instance.multible_pg_json_type, dict) + assert isinstance(instance.multible_pg_jsonb_type, dict) + assert isinstance(instance.multible_common_json_type, dict) + assert isinstance(instance.multible_mysql_json, dict) @pytest.mark.parametrize( From d3afb8266a9e6d9a69779e54b881c1a38f59a5eb Mon Sep 17 00:00:00 2001 From: jeffry <36665036+wangxin688@users.noreply.github.com> Date: Sun, 5 May 2024 13:05:24 +0000 Subject: [PATCH 10/10] feat(sqlafactory): add json type support sqlafactory --- polyfactory/factories/sqlalchemy_factory.py | 1 + .../test_sqlalchemy_factory_v2.py | 26 ++++++++++++++----- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index b597d357..4b62c8f9 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -96,6 +96,7 @@ def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]: postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005 postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005 postgresql.HSTORE: lambda: cls.__faker__.pydict(), + # `types.JSON` is compatible for sqlachemy extend dialects. Such as `pg.JSON` and `JSONB` types.JSON: lambda: cls.__faker__.pydict(), } diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py index 5ea48f2a..de2ab7e4 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py @@ -4,10 +4,12 @@ from uuid import UUID import pytest -from sqlalchemy import ForeignKey, __version__, orm, types +from sqlalchemy import ForeignKey, Text, __version__, orm, types +from sqlalchemy.dialects.mssql import JSON as MSSQL_JSON from sqlalchemy.dialects.mysql import JSON as MYSQL_JSON from sqlalchemy.dialects.postgresql import ARRAY, CIDR, HSTORE, INET, JSON, JSONB -from sqlalchemy.ext.mutable import MutableList +from sqlalchemy.dialects.sqlite import JSON as SQLITE_JSON +from sqlalchemy.ext.mutable import MutableDict, MutableList from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory @@ -86,11 +88,19 @@ class SqlaModel(Base): pg_jsonb_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSONB) common_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=types.JSON) mysql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MYSQL_JSON) + sqlite_json: orm.Mapped[Dict] = orm.mapped_column(type_=SQLITE_JSON) + mssql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MSSQL_JSON) - multible_pg_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSON) - multible_pg_jsonb_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSONB) - multible_common_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=types.JSON) - multible_mysql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MYSQL_JSON) + multible_pg_json_type: orm.Mapped[Dict] = orm.mapped_column( + type_=MutableDict.as_mutable(JSON(astext_type=Text())) # type: ignore[no-untyped-call] + ) + multible_pg_jsonb_type: orm.Mapped[Dict] = orm.mapped_column( + type_=MutableDict.as_mutable(JSONB(astext_type=Text())) # type: ignore[no-untyped-call] + ) + multible_common_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(types.JSON())) + multible_mysql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(MYSQL_JSON())) + multible_sqlite_json: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(SQLITE_JSON())) + multible_mssql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(MSSQL_JSON())) class ModelFactory(SQLAlchemyFactory[SqlaModel]): __model__ = SqlaModel @@ -108,10 +118,14 @@ class ModelFactory(SQLAlchemyFactory[SqlaModel]): assert isinstance(instance.pg_jsonb_type, dict) assert isinstance(instance.common_json_type, dict) assert isinstance(instance.mysql_json, dict) + assert isinstance(instance.sqlite_json, dict) + assert isinstance(instance.mssql_json, dict) assert isinstance(instance.multible_pg_json_type, dict) assert isinstance(instance.multible_pg_jsonb_type, dict) assert isinstance(instance.multible_common_json_type, dict) assert isinstance(instance.multible_mysql_json, dict) + assert isinstance(instance.multible_sqlite_json, dict) + assert isinstance(instance.multible_mssql_json, dict) @pytest.mark.parametrize(