From 6c74cf6c879c707806ab4b2a92f8d2d6e5b63050 Mon Sep 17 00:00:00 2001 From: Viraj Kanwade Date: Fri, 12 Apr 2024 23:28:37 -0700 Subject: [PATCH] add: support for pydantic models - first cut --- .gitignore | 1 + pyproject.toml | 9 +- src/sqlacodegen/generators.py | 229 ++++++++++++++++++++++++++++++- tests/test_cli.py | 28 ++++ tests/test_generator_pydantic.py | 45 ++++++ 5 files changed, 309 insertions(+), 3 deletions(-) create mode 100644 tests/test_generator_pydantic.py diff --git a/.gitignore b/.gitignore index b5b2f478..7478b082 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ dist build venv* +.venv diff --git a/pyproject.toml b/pyproject.toml index c6434266..7d386f34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "SQLAlchemy >= 2.0.23", "inflect >= 4.0.0", "importlib_metadata; python_version < '3.10'", + "pydantic >= 2.0", ] dynamic = ["version"] @@ -55,6 +56,7 @@ pgvector = ["pgvector >= 0.2.4"] tables = "sqlacodegen.generators:TablesGenerator" declarative = "sqlacodegen.generators:DeclarativeGenerator" dataclasses = "sqlacodegen.generators:DataclassGenerator" +pydanticmodels = "sqlacodegen.generators:PydanticGenerator" sqlmodels = "sqlacodegen.generators:SQLModelGenerator" [project.scripts] @@ -65,7 +67,7 @@ version_scheme = "post-release" local_scheme = "dirty-tag" [tool.ruff] -select = [ +lint.select = [ "E", "F", "W", # default Flake8 "I", # isort "ISC", # flake8-implicit-str-concat @@ -97,6 +99,9 @@ skip_missing_interpreters = true minversion = 4.0.0 [testenv] -extras = test +extras = + test + sqlmodel + pydantic commands = python -m pytest {posargs} """ diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index 21eadb63..04db5e7c 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -13,9 +13,10 @@ from keyword import iskeyword from pprint import pformat from textwrap import indent -from typing import Any, ClassVar +from typing import Any, ClassVar, 
class PydanticGenerator(DeclarativeGenerator):
    """Generate plain Pydantic ``BaseModel`` classes from reflected tables.

    Each model class is rendered with
    ``model_config = ConfigDict(from_attributes=True)`` followed by one
    annotated attribute per column. Relationships are collected on the models
    but are not rendered.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "BaseModel",
    ):
        """Forward all arguments to ``DeclarativeGenerator``.

        The only difference from the parent signature is the default
        ``base_class_name``, which is ``"BaseModel"`` here.
        """
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )

    def generate_base(self) -> None:
        """Set ``self.base`` to import ``BaseModel`` and ``ConfigDict``.

        No base class declaration is emitted and no metadata reference is used.
        """
        self.base = Base(
            literal_imports=[
                LiteralImport("pydantic", "BaseModel"),
                LiteralImport("pydantic", "ConfigDict"),
            ],
            declarations=[],
            metadata_ref="",
        )

    def generate_models(self) -> list[Model]:
        """Build the models for every table in ``self.metadata``.

        Per the original author's note this follows
        ``DeclarativeGenerator.generate_models`` with one difference: every
        table that is not an association table becomes a ``ModelClass``.

        Returns:
            The models, in the order the tables were visited.
        """
        models_by_table_name: dict[str, Model] = {}

        # Pick association tables from the metadata into their own set, don't process
        # them normally
        links: defaultdict[str, list[Model]] = defaultdict(list)
        for table in self.metadata.sorted_tables:
            qualified_name = qualified_table_name(table)

            # Link tables have exactly two foreign key constraints and all columns are
            # involved in them
            fk_constraints = sorted(
                table.foreign_key_constraints, key=get_constraint_sort_key
            )
            if len(fk_constraints) == 2 and all(
                col.foreign_keys for col in table.columns
            ):
                model = models_by_table_name[qualified_name] = Model(table)
                tablename = fk_constraints[0].elements[0].column.table.name
                links[tablename].append(model)
                continue

            # Only difference from DeclarativeGenerator.generate_models:
            # every remaining table becomes a model class.
            model = ModelClass(table)
            models_by_table_name[qualified_name] = model

            # Fill in the columns
            for column in table.c:
                model.columns.append(ColumnAttribute(model, column))

        # Add relationships
        for model in models_by_table_name.values():
            if isinstance(model, ModelClass):
                self.generate_relationships(
                    model, models_by_table_name, links[model.table.name]
                )

        # Nest inherited classes in their superclasses to ensure proper ordering
        if "nojoined" not in self.options:
            for model in list(models_by_table_name.values()):
                if not isinstance(model, ModelClass):
                    continue

                pk_column_names = {col.name for col in model.table.primary_key.columns}
                for constraint in model.table.foreign_key_constraints:
                    if set(get_column_names(constraint)) == pk_column_names:
                        target = models_by_table_name[
                            qualified_table_name(constraint.elements[0].column.table)
                        ]
                        if isinstance(target, ModelClass):
                            model.parent_class = target
                            target.children.append(model)

        # Change base if we only have tables
        if not any(
            isinstance(model, ModelClass) for model in models_by_table_name.values()
        ):
            super().generate_base()

        # Collect the imports
        self.collect_imports(models_by_table_name.values())

        # Rename models and their attributes that conflict with imports or other
        # attributes
        global_names = {
            name for namespace in self.imports.values() for name in namespace
        }
        for model in models_by_table_name.values():
            self.generate_model_name(model, global_names)
            global_names.add(model.name)

        return list(models_by_table_name.values())

    def collect_imports(self, models: Iterable[Model]) -> None:
        """Collect imports while skipping ``DeclarativeGenerator``'s override."""
        # Start the MRO lookup after DeclarativeGenerator so its implementation
        # is bypassed (the original author's note names TablesGenerator as the
        # target).
        super(DeclarativeGenerator, self).collect_imports(models)

    def collect_imports_for_model(self, model: Model) -> None:
        """Collect imports for the columns of ``model``.

        Constraints and indexes are deliberately not inspected.
        """
        for column in model.table.c:
            self.collect_imports_for_column(column)

    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """Register the import needed for the column's Python type, if any.

        Nothing is written to stdout here: the generated module itself may be
        emitted on stdout, so diagnostic prints would corrupt it.
        """
        try:
            python_type = column.type.python_type
        except NotImplementedError:
            # No Python type is defined for this column type;
            # render_column_attribute falls back to ``Any`` and registers that
            # import itself.
            return

        self.add_import(python_type)

    def add_import(self, obj: Any) -> None:
        """Register a ``from <package> import <name>`` import for ``obj``.

        Builtins, SQLAlchemy dialect types and names exported by the top-level
        ``sqlalchemy`` package are skipped. A name that is exported by
        ``pydantic`` is imported from ``pydantic``; anything else is imported
        from the module that defines it.
        """
        # Imported here so this block does not rely on a module-level import.
        import pydantic

        # Don't store builtin imports
        if getattr(obj, "__module__", "builtins") == "builtins":
            return

        type_ = obj if isinstance(obj, type) else type(obj)
        if type_.__module__.startswith("sqlalchemy.dialects."):
            return
        if type_.__name__ in dir(sqlalchemy):
            return

        pkgname = (
            "pydantic" if type_.__name__ in dir(pydantic) else type_.__module__
        )
        if pkgname:
            self.add_literal_import(pkgname, type_.__name__)

    def render_class(self, model: ModelClass) -> str:
        """Render ``model`` as a Pydantic class definition.

        The body consists of the ``model_config`` line and, when the model has
        columns, a block with one attribute per column. Relationship attributes
        are not rendered.
        """
        sections: list[str] = ["model_config = ConfigDict(from_attributes=True)"]

        rendered_column_attributes = [
            self.render_column_attribute(column_attr) for column_attr in model.columns
        ]
        if rendered_column_attributes:
            sections.append("\n".join(rendered_column_attributes))

        declaration = self.render_class_declaration(model)
        rendered_sections = "\n\n".join(
            indent(section, self.indentation) for section in sections
        )
        return f"{declaration}\n{rendered_sections}"

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """Render one column as ``name: annotation`` (optionally ``= None``).

        * Builtin Python types are rendered by name; a ``str`` column whose
          type has a non-``None`` ``length`` becomes a length-constrained
          ``Annotated`` string.
        * Other Python types are rendered as ``module.Name`` and the module is
          registered as a module import.
        * Column types without a Python type are rendered as ``Any``.
        * Nullable columns are wrapped in ``Optional[...]`` and default to
          ``None``.
        """
        column = column_attr.column

        try:
            python_type = column.type.python_type
        except NotImplementedError:
            self.add_literal_import("typing", "Any")
            column_python_type = "Any"
        else:
            python_type_name = python_type.__name__
            python_type_module = python_type.__module__
            # Not every str-typed column type defines ``length``; treat a
            # missing attribute like an unbounded string instead of raising
            # AttributeError.
            length = getattr(column.type, "length", None)

            if python_type_module != "builtins":
                column_python_type = f"{python_type_module}.{python_type_name}"
                self.add_module_import(python_type_module)
            elif python_type_name == "str" and length is not None:
                column_python_type = self.render_column_type_str_length(length)
            else:
                column_python_type = python_type_name

        if not column.nullable:
            return f"{column_attr.name}: {column_python_type}"

        self.add_literal_import("typing", "Optional")
        return f"{column_attr.name}: Optional[{column_python_type}] = None"

    def render_column_type_str_length(self, length: int) -> str:
        """Return an ``Annotated[str, ...]`` annotation capped at ``length``.

        Registers the ``Annotated`` and ``StringConstraints`` imports that the
        rendered annotation needs.
        """
        self.add_literal_import("typing_extensions", "Annotated")
        self.add_literal_import("pydantic", "StringConstraints")

        return f"Annotated[str, StringConstraints(max_length={length})]"


# tests/test_cli.py -- carried over from the same source lines.
def test_cli_pydanticmodels(db_path: Path, tmp_path: Path) -> None:
    """Run the CLI with the ``pydanticmodels`` generator and compare the file."""
    output_path = tmp_path / "outfile"
    subprocess.run(
        [
            "sqlacodegen",
            f"sqlite:///{db_path}",
            "--generator",
            "pydanticmodels",
            "--outfile",
            str(output_path),
        ],
        check=True,
    )

    assert (
        output_path.read_text()
        == """\
from pydantic import BaseModel, ConfigDict

class Foo(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
"""
    )
import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy.dialects import mysql
from sqlalchemy.engine import Engine
from sqlalchemy.schema import Column, MetaData, Table

from sqlacodegen.generators import CodeGenerator, PydanticGenerator

from .conftest import validate_code


@pytest.fixture
def generator(
    request: FixtureRequest, metadata: MetaData, engine: Engine
) -> CodeGenerator:
    """Return a ``PydanticGenerator`` bound to the test metadata and engine.

    Generator options come from the fixture's indirect parameter when one is
    given; otherwise an empty list is used.
    """
    generator_options = getattr(request, "param", [])
    return PydanticGenerator(metadata, engine, generator_options)


@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"])
def test_mysql_column_types(generator: CodeGenerator) -> None:
    """Check the rendered annotations for MySQL INTEGER, VARCHAR and TEXT."""
    columns = (
        Column("id", mysql.INTEGER),
        Column("name", mysql.VARCHAR(255)),
        Column("text", mysql.TEXT),
    )
    Table("simple_items", generator.metadata, *columns)

    validate_code(
        generator.generate(),
        """\
        from typing import Optional

        from pydantic import BaseModel, ConfigDict, StringConstraints
        from typing_extensions import Annotated

        class SimpleItems(BaseModel):
            model_config = ConfigDict(from_attributes=True)

            id: Optional[int] = None
            name: Optional[Annotated[str, StringConstraints(max_length=255)]] = None
            text: Optional[str] = None
        """,
    )