diff --git a/src/dipdup/aerich.py b/src/dipdup/aerich.py index a89bb961a..e69de29bb 100644 --- a/src/dipdup/aerich.py +++ b/src/dipdup/aerich.py @@ -1,23 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from collections.abc import Iterable - from pathlib import Path - from types import ModuleType - - from aerich import Command as AerichCommand # type: ignore[import-untyped] - - -async def create_aerich_command(db_url: str, package: str, migrations_dir: 'Path') -> 'AerichCommand': - """Create and return an `AerichCommand` instance - - The AerichCommand is the entry point to manage database migrations using aerich. - """ - from aerich import Command as AerichCommand - from tortoise.backends.base.config_generator import generate_config - - # TODO: Refactor building the app_modules dict and use here and in the tortoise_wrapper function ? - # Or maybe add the tortoise_config to database config ? - app_modules: dict[str, Iterable[str | ModuleType]] = {'models': [f'{package}.models', 'aerich.models']} - tortoise_config = generate_config(db_url=db_url, app_modules=app_modules) - return AerichCommand(tortoise_config=tortoise_config, app='models', location=migrations_dir.as_posix()) diff --git a/src/dipdup/cli.py b/src/dipdup/cli.py index c696feed0..9ab5e075f 100644 --- a/src/dipdup/cli.py +++ b/src/dipdup/cli.py @@ -673,9 +673,14 @@ async def schema(ctx: click.Context) -> None: Run `dipdup schema init` or `dipdup run` to the run the indexer and it'll be initialized automatically.""" ) - from dipdup.aerich import create_aerich_command + from aerich import Command as AerichCommand # type: ignore[import-untyped] - aerich_command = await create_aerich_command(config.database.connection_string, config.package, migrations_dir) + from dipdup.database import get_tortoise_config + + tortoise_config = get_tortoise_config(config.database.connection_string, config.package) + aerich_command = AerichCommand( + tortoise_config=tortoise_config, app='models', location=migrations_dir.as_posix() + ) await aerich_command.init() ctx.obj['command'] = aerich_command diff --git a/src/dipdup/database.py b/src/dipdup/database.py index 9917822ad..651790abe 100644 --- a/src/dipdup/database.py +++ b/src/dipdup/database.py @@ -3,6 +3,7 @@ import decimal import hashlib import importlib +import importlib.util import logging from collections.abc import AsyncIterator from collections.abc import Iterable @@ -35,6 +36,7 @@ if TYPE_CHECKING: from types import ModuleType + _logger = logging.getLogger(__name__) DEFAULT_CONNECTION_NAME = 'default' @@ -53,6 +55,32 @@ def set_connection(conn: SupportedClient) -> None: connections.set(DEFAULT_CONNECTION_NAME, conn) +def get_tortoise_config(db_url: str, project_models: str | None = None) -> dict[str, Any]: + """Get Tortoise config for the given URL and internal, aerich and project models""" + from tortoise.backends.base.config_generator import generate_config + + app_modules: dict[str, Iterable[str | ModuleType]] = { + 'int_models': ['dipdup.models'], + } + + models = [] + + if project_models: + if not project_models.endswith('.models'): + project_models += '.models' + models.append(project_models) + + if 'sqlite' not in db_url: + import importlib + + if importlib.util.find_spec('aerich') is not None: + models.append('aerich.models') + + app_modules['models'] = models + + return generate_config(db_url=db_url, app_modules=app_modules) + + @asynccontextmanager async def tortoise_wrapper( url: str, @@ -67,15 +95,6 @@ async def tortoise_wrapper( if '/tmp/' in url: _logger.warning('Using tmpfs database; data will be lost on reboot') - model_modules: dict[str, Iterable[str | ModuleType]] = { - 'int_models': ['dipdup.models'], - } - - if models: - if not models.endswith('.models'): - models += '.models' - model_modules['models'] = [models, 'dipdup.models'] - # NOTE: Must be called before entering Tortoise context decimal_precision = decimal_precision or guess_decimal_precision(models) set_decimal_precision(decimal_precision) @@ -84,10 +103,7 @@ async def tortoise_wrapper( try: for attempt in range(timeout): try: - await Tortoise.init( - db_url=url, - modules=model_modules, - ) + await Tortoise.init(config=get_tortoise_config(url, models)) conn = get_connection() try: diff --git a/src/dipdup/dipdup.py b/src/dipdup/dipdup.py index 1f517794d..ac7c91608 100644 --- a/src/dipdup/dipdup.py +++ b/src/dipdup/dipdup.py @@ -21,7 +21,6 @@ from tortoise.exceptions import OperationalError from dipdup import env -from dipdup.aerich import create_aerich_command from dipdup.codegen import CodeGenerator from dipdup.codegen import CommonCodeGenerator from dipdup.codegen import generate_environments @@ -937,9 +936,16 @@ async def _initialize_migrations(self) -> None: migrations_dir = self._ctx.package.migrations try: - aerich_command = await create_aerich_command( - self._config.database.connection_string, self._config.package, migrations_dir + from aerich import Command as AerichCommand # type: ignore[import-untyped] + + from dipdup.database import get_tortoise_config + + tortoise_config = get_tortoise_config(self._config.database.connection_string, self._config.package) + aerich_command = AerichCommand( + tortoise_config=tortoise_config, app='models', location=migrations_dir.as_posix() ) + await aerich_command.init() + _logger.info("Initializing database migrations at '%s'", migrations_dir) await aerich_command.init_db(safe=True) except ModuleNotFoundError as e: