From 4532f2d6e62351643eae3cb7ea6c0c4e61dcaf8d Mon Sep 17 00:00:00 2001 From: Abdeldjalil-H Date: Tue, 19 Nov 2024 23:27:44 +0100 Subject: [PATCH] Add `table_name_generator` attribute to Meta for dynamic table name generation (#1770) * add table_name_generator attribute to Meta * add changelog * fix typing hints * change combinision behaviour * change to glabal table name generator * remove extra lines * remove print * add annotation for example --- CHANGELOG.rst | 1 + examples/global_table_name_generator.py | 56 +++++++++++++++++++++++++ tests/test_table_name.py | 43 +++++++++++++++++++ tortoise/__init__.py | 27 +++++++++++- 4 files changed, 125 insertions(+), 2 deletions(-) create mode 100644 examples/global_table_name_generator.py create mode 100644 tests/test_table_name.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 100f6b270..dca756084 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -38,6 +38,7 @@ Added ^^^^^ - Add POSIX Regex support for PostgreSQL and MySQL (#1714) - support app=None for tortoise.contrib.fastapi.RegisterTortoise (#1733) +- Added ``table_name_generator`` param to `Tortoise.init` to allow global dynamic table name generation (#1770) 0.21.6 <../0.21.6>`_ - 2024-08-17 ------ diff --git a/examples/global_table_name_generator.py b/examples/global_table_name_generator.py new file mode 100644 index 000000000..0bf43eb0c --- /dev/null +++ b/examples/global_table_name_generator.py @@ -0,0 +1,56 @@ +""" +This example demonstrates how to use the global table name generator to automatically +generate snake_case table names for all models, and how explicit table names take precedence. +""" + +from tortoise import Tortoise, fields, run_async +from tortoise.models import Model + + +def snake_case_table_names(cls): + """Convert CamelCase class name to snake_case table name""" + name = cls.__name__ + return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_") + + +class UserProfile(Model): + id = fields.IntField(primary_key=True) + name = fields.TextField() + created_at = fields.DatetimeField(auto_now_add=True) + + def __str__(self): + return self.name + + +class BlogPost(Model): + id = fields.IntField(primary_key=True) + title = fields.TextField() + author: fields.ForeignKeyRelation[UserProfile] = fields.ForeignKeyField( + "models.UserProfile", related_name="posts" + ) + + class Meta: + table = "custom_blog_posts" + + def __str__(self): + return self.title + + +async def run(): + # Initialize with snake_case table name generator + await Tortoise.init( + db_url="sqlite://:memory:", + modules={"models": ["__main__"]}, + table_name_generator=snake_case_table_names, + ) + await Tortoise.generate_schemas() + + # UserProfile uses generated name, BlogPost uses explicit table name + print(f"UserProfile table name: {UserProfile._meta.db_table}") # >>> user_profile + print(f"BlogPost table name: {BlogPost._meta.db_table}") # >>> custom_blog_posts + + await Tortoise.close_connections() + + +if __name__ == "__main__": + run_async(run()) diff --git a/tests/test_table_name.py b/tests/test_table_name.py new file mode 100644 index 000000000..53cf5d4b3 --- /dev/null +++ b/tests/test_table_name.py @@ -0,0 +1,43 @@ +from typing import Type + +from tortoise import Tortoise, fields +from tortoise.contrib.test import SimpleTestCase +from tortoise.models import Model + + +def table_name_generator(model_cls: Type[Model]): + return f"test_{model_cls.__name__.lower()}" + + +class Tournament(Model): + id = fields.IntField(pk=True) + name = fields.TextField() + created_at = fields.DatetimeField(auto_now_add=True) + + +class CustomTable(Model): + id = fields.IntField(pk=True) + name = fields.TextField() + + class Meta: + table = "my_custom_table" + + +class TestTableNameGenerator(SimpleTestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + await Tortoise.init( + db_url="sqlite://:memory:", + modules={"models": [__name__]}, + table_name_generator=table_name_generator, + ) + await Tortoise.generate_schemas() + + async def asyncTearDown(self): + await Tortoise.close_connections() + + async def test_glabal_name_generator(self): + self.assertEqual(Tournament._meta.db_table, "test_tournament") + + async def test_custom_table_name_precedence(self): + self.assertEqual(CustomTable._meta.db_table, "my_custom_table") diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 9cdb78d94..9b96b6e0f 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -7,7 +7,18 @@ from copy import deepcopy from inspect import isclass from types import ModuleType -from typing import Coroutine, Dict, Iterable, List, Optional, Tuple, Type, Union, cast +from typing import ( + Callable, + Coroutine, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + Union, + cast, +) from pypika import Table @@ -30,6 +41,7 @@ class Tortoise: apps: Dict[str, Dict[str, Type["Model"]]] = {} + table_name_generator: Optional[Callable[[Type["Model"]], str]] = None _inited: bool = False @classmethod @@ -223,7 +235,11 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None: continue model._meta._inited = True if not model._meta.db_table: - model._meta.db_table = model.__name__.lower() + model._meta.db_table = ( + cls.table_name_generator(model) + if cls.table_name_generator + else (model.__name__.lower()) + ) for field in sorted(model._meta.fk_fields): init_fk_o2o_field(model, field) @@ -396,6 +412,7 @@ async def init( use_tz: bool = False, timezone: str = "UTC", routers: Optional[List[Union[str, Type]]] = None, + table_name_generator: Optional[Callable[[Type["Model"]], str]] = None, ) -> None: """ Sets up Tortoise-ORM. @@ -455,6 +472,10 @@ async def init( Timezone to use, default is UTC. :param routers: A list of db routers str path or module. + :param table_name_generator: + A callable that generates table names. The model class will be passed as its argument. + If not provided, Tortoise will use the lowercase model name as the table name. + Example: ``lambda cls: f"prefix_{cls.__name__.lower()}"`` :raises ConfigurationError: For any configuration error """ @@ -487,6 +508,8 @@ async def init( timezone = config.get("timezone", timezone) # type: ignore routers = config.get("routers", routers) # type: ignore + cls.table_name_generator = table_name_generator + # Mask passwords in logs output passwords = [] for name, info in connections_config.items():