From 86eb6b2f0dd1fdddf5582d7fd82d67555e4fcaba Mon Sep 17 00:00:00 2001 From: wangxin688 <182467653@qq.com> Date: Wed, 15 May 2024 08:54:06 +0800 Subject: [PATCH] feat(new-arch): update to new arch --- .env_example | 2 +- .python-version | 2 +- README.md | 4 + alembic/env.py | 2 +- deploy/init_pg.py | 4 +- pyproject.toml | 3 +- requirements-dev.lock | 95 ++++++++- requirements.lock | 70 ++++++- src/app.py | 12 +- src/config.py | 42 ---- src/core/__init__.py | 1 + src/{ => core}/_types.py | 11 +- src/core/config.py | 47 +++++ src/{auth => core/database}/__init__.py | 0 src/core/database/engine.py | 122 ++++++++++++ src/{db => core/database}/session.py | 12 +- src/core/database/types/__init__.py | 29 +++ src/core/database/types/annotated.py | 19 ++ src/core/database/types/datetime.py | 32 +++ src/core/database/types/encrypted_string.py | 131 +++++++++++++ src/core/database/types/enum.py | 34 ++++ src/core/database/types/guid.py | 97 ++++++++++ src/core/errors/_error.py | 20 ++ .../errors/auth_exceptions.py} | 26 +-- .../errors/base_exceptions.py} | 0 src/{utils => core/models}/__init__.py | 0 src/{db => core/models}/base.py | 10 +- src/core/models/mixins/__init__.py | 5 + .../models/mixins/audit_log.py} | 54 +----- src/core/models/mixins/audit_time.py | 14 ++ src/core/models/mixins/audit_user.py | 50 +++++ src/core/repositories/__init__.py | 3 + .../repositories/repository.py} | 182 +++++++++++------- src/{consts.py => core/utils/__init__.py} | 0 src/{ => core/utils}/cbv.py | 0 src/{ => core/utils}/context.py | 3 +- src/{ => core/utils}/i18n.py | 4 +- src/core/utils/singleton.py | 25 +++ src/{ => core/utils}/validators.py | 0 src/db/__init__.py | 2 +- src/db/_types.py | 68 ------- src/deps.py | 30 +-- src/enums.py | 12 -- .../auth/__init__.py} | 0 src/features/auth/consts.py | 5 + src/features/auth/graphql.py | 0 src/{ => features}/auth/models.py | 24 +-- .../auth/repositories.py} | 31 +-- src/{auth/api.py => features/auth/router.py} | 31 ++- src/{ => 
features}/auth/schemas.py | 2 +- src/{ => features/auth}/security.py | 2 +- src/features/auth/utils.py | 0 src/internal/__init__ | 0 src/internal/api.py | 2 +- src/libs/__init__.py | 0 src/libs/redis/__init__.py | 0 src/{utils => libs/redis}/cache.py | 2 +- src/libs/redis/rate_limiter.py | 0 src/libs/redis/redis.py | 0 src/loggers.py | 6 +- src/register/__init__.py | 0 src/{ => register}/middlewares.py | 4 +- src/{ => register}/routers.py | 0 src/utils/singleton.py | 13 -- tests/test_pydantic.py | 34 ++++ 65 files changed, 1072 insertions(+), 363 deletions(-) delete mode 100644 src/config.py create mode 100644 src/core/__init__.py rename src/{ => core}/_types.py (91%) create mode 100644 src/core/config.py rename src/{auth => core/database}/__init__.py (100%) create mode 100644 src/core/database/engine.py rename src/{db => core/database}/session.py (58%) create mode 100644 src/core/database/types/__init__.py create mode 100644 src/core/database/types/annotated.py create mode 100644 src/core/database/types/datetime.py create mode 100644 src/core/database/types/encrypted_string.py create mode 100644 src/core/database/types/enum.py create mode 100644 src/core/database/types/guid.py create mode 100644 src/core/errors/_error.py rename src/{exceptions.py => core/errors/auth_exceptions.py} (85%) rename src/{errors.py => core/errors/base_exceptions.py} (100%) rename src/{utils => core/models}/__init__.py (100%) rename src/{db => core/models}/base.py (74%) create mode 100644 src/core/models/mixins/__init__.py rename src/{db/mixins.py => core/models/mixins/audit_log.py} (68%) create mode 100644 src/core/models/mixins/audit_time.py create mode 100644 src/core/models/mixins/audit_user.py create mode 100644 src/core/repositories/__init__.py rename src/{db/dtobase.py => core/repositories/repository.py} (90%) rename src/{consts.py => core/utils/__init__.py} (100%) rename src/{ => core/utils}/cbv.py (100%) rename src/{ => core/utils}/context.py (80%) rename src/{ => 
core/utils}/i18n.py (92%) create mode 100644 src/core/utils/singleton.py rename src/{ => core/utils}/validators.py (100%) delete mode 100644 src/db/_types.py delete mode 100644 src/enums.py rename src/{utils/rate_limiter.py => features/auth/__init__.py} (100%) create mode 100644 src/features/auth/consts.py create mode 100644 src/features/auth/graphql.py rename src/{ => features}/auth/models.py (86%) rename src/{auth/services.py => features/auth/repositories.py} (61%) rename src/{auth/api.py => features/auth/router.py} (91%) rename src/{ => features}/auth/schemas.py (97%) rename src/{ => features/auth}/security.py (98%) create mode 100644 src/features/auth/utils.py create mode 100644 src/internal/__init__ create mode 100644 src/libs/__init__.py create mode 100644 src/libs/redis/__init__.py rename src/{utils => libs/redis}/cache.py (99%) create mode 100644 src/libs/redis/rate_limiter.py create mode 100644 src/libs/redis/redis.py create mode 100644 src/register/__init__.py rename src/{ => register}/middlewares.py (95%) rename src/{ => register}/routers.py (100%) delete mode 100644 src/utils/singleton.py create mode 100644 tests/test_pydantic.py diff --git a/.env_example b/.env_example index 4f754a5..b1ee515 100644 --- a/.env_example +++ b/.env_example @@ -6,7 +6,7 @@ REFRESH_TOKEN_EXPIRE_MINUTES=11520 SQLALCHEMY_DATABASE_URI=postgresql+asyncpg://demo:91fb8e9e009f5b9ce1854d947e6fe4a3@localhost:5432/naas REDIS_DSN=redis://:cfe1c2c4703abb205d71abdc07cc3f3d@localhost:6379 -APP_ENV=PRD +APP_ENV=PROD # docker compose DEFAULT_DB_PASSWORD=91fb8e9e009f5b9ce1854d947e6fe4a3 diff --git a/.python-version b/.python-version index 375f5ca..871f80a 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.11.6 +3.12.3 diff --git a/README.md b/README.md index d8bc027..2ed8b60 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,12 @@ Please notice that this project is still working in progress. 11. I18N support for backed db and error message. 12. 
X-request-id for logging and request. 13. ... + +## Project Structure + ## planning 1. Fix some errors and type hint issues 2. Enhance DTO base for better crud support. 3. Release beta version. 4. enhance docs. + diff --git a/alembic/env.py b/alembic/env.py index e9a8a0d..cf0e4ca 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -4,7 +4,7 @@ from sqlalchemy import pool, Connection from sqlalchemy.ext.asyncio import AsyncEngine import asyncio -from src import config as app_config +from src.core import config as app_config from alembic import context diff --git a/deploy/init_pg.py b/deploy/init_pg.py index 755ac13..a2d3429 100644 --- a/deploy/init_pg.py +++ b/deploy/init_pg.py @@ -4,8 +4,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from src.auth.models import Group, Role, User -from src.db.session import async_session -from src.enums import ReservedRoleSlug +from src.core.database.session import async_session +from src.features.auth.consts import ReservedRoleSlug async def create_pg_extensions(session: AsyncSession) -> None: diff --git a/pyproject.toml b/pyproject.toml index 0bee71f..8c32947 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,9 +20,10 @@ dependencies = [ "pydantic_extra_types>=2.4.1", "phonenumbers>=8.13.27", "asyncpg>=0.29.0", + "sqladmin>=0.16.1", ] readme = "README.md" -requires-python = ">= 3.11" +requires-python = ">= 3.12" # [project.scripts] # hello = "naas-backend:hello" diff --git a/requirements-dev.lock b/requirements-dev.lock index aaaa110..27cfda7 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -5,69 +5,158 @@ # pre: false # features: [] # all-features: false +# with-sources: false -e file:. 
alembic==1.13.1 + # via fastapi-enterprise-template annotated-types==0.6.0 + # via pydantic anyio==4.2.0 -async-timeout==4.0.3 + # via httpx + # via starlette + # via watchfiles asyncpg==0.29.0 + # via fastapi-enterprise-template black==23.12.1 certifi==2023.11.17 + # via httpcore + # via httpx + # via sentry-sdk cfgv==3.4.0 + # via pre-commit click==8.1.7 + # via black + # via uvicorn coverage==7.4.0 + # via pytest-cov distlib==0.3.8 + # via virtualenv fastapi==0.108.0 + # via fastapi-enterprise-template filelock==3.13.1 + # via virtualenv greenlet==3.0.3 + # via sqlalchemy h11==0.14.0 + # via httpcore + # via uvicorn httpcore==1.0.2 + # via httpx httptools==0.6.1 + # via uvicorn httpx==0.26.0 + # via fastapi-enterprise-template identify==2.5.33 + # via pre-commit idna==3.6 + # via anyio + # via httpx iniconfig==2.0.0 + # via pytest +jinja2==3.1.3 + # via sqladmin mako==1.3.0 + # via alembic markupsafe==2.1.3 + # via jinja2 + # via mako + # via wtforms mypy==1.8.0 mypy-extensions==1.0.0 + # via black + # via mypy nodeenv==1.8.0 + # via pre-commit numpy==1.26.3 + # via pandas packaging==23.2 + # via black + # via pytest pandas==2.1.4 + # via fastapi-enterprise-template passlib==1.7.4 + # via fastapi-enterprise-template pathspec==0.12.1 + # via black phonenumbers==8.13.27 + # via fastapi-enterprise-template platformdirs==4.1.0 + # via black + # via virtualenv pluggy==1.3.0 + # via pytest pre-commit==3.6.0 pydantic==2.5.3 + # via fastapi + # via pydantic-extra-types + # via pydantic-settings pydantic-core==2.14.6 + # via pydantic pydantic-extra-types==2.4.1 + # via fastapi-enterprise-template pydantic-settings==2.1.0 + # via fastapi-enterprise-template pyjwt==2.8.0 + # via fastapi-enterprise-template pytest==7.4.4 + # via pytest-asyncio + # via pytest-cov pytest-asyncio==0.23.3 pytest-cov==4.1.0 python-dateutil==2.8.2 + # via pandas python-dotenv==1.0.0 + # via pydantic-settings + # via uvicorn +python-multipart==0.0.9 + # via sqladmin pytz==2023.3.post1 + # via 
pandas pyyaml==6.0.1 + # via pre-commit + # via uvicorn redis==5.0.1 + # via fastapi-enterprise-template ruff==0.1.11 sentry-sdk==1.39.2 + # via fastapi-enterprise-template +setuptools==69.0.3 + # via nodeenv six==1.16.0 + # via python-dateutil sniffio==1.3.0 + # via anyio + # via httpx +sqladmin==0.16.1 + # via fastapi-enterprise-template sqlalchemy==2.0.25 + # via alembic + # via fastapi-enterprise-template + # via sqladmin starlette==0.32.0.post1 + # via fastapi + # via sqladmin typing-extensions==4.9.0 + # via alembic + # via fastapi + # via mypy + # via pydantic + # via pydantic-core + # via sqlalchemy tzdata==2023.4 + # via pandas urllib3==2.1.0 + # via sentry-sdk uvicorn==0.25.0 + # via fastapi-enterprise-template uvloop==0.19.0 + # via uvicorn virtualenv==20.25.0 + # via pre-commit watchfiles==0.21.0 + # via uvicorn websockets==12.0 -# The following packages are considered to be unsafe in a requirements file: -setuptools==69.0.3 + # via uvicorn +wtforms==3.1.2 + # via sqladmin diff --git a/requirements.lock b/requirements.lock index 7d2d36b..227cf63 100644 --- a/requirements.lock +++ b/requirements.lock @@ -5,47 +5,115 @@ # pre: false # features: [] # all-features: false +# with-sources: false -e file:. 
alembic==1.13.1 + # via fastapi-enterprise-template annotated-types==0.6.0 + # via pydantic anyio==4.2.0 -async-timeout==4.0.3 + # via httpx + # via starlette + # via watchfiles asyncpg==0.29.0 + # via fastapi-enterprise-template certifi==2023.11.17 + # via httpcore + # via httpx + # via sentry-sdk click==8.1.7 + # via uvicorn fastapi==0.108.0 + # via fastapi-enterprise-template greenlet==3.0.3 + # via sqlalchemy h11==0.14.0 + # via httpcore + # via uvicorn httpcore==1.0.2 + # via httpx httptools==0.6.1 + # via uvicorn httpx==0.26.0 + # via fastapi-enterprise-template idna==3.6 + # via anyio + # via httpx +jinja2==3.1.3 + # via sqladmin mako==1.3.0 + # via alembic markupsafe==2.1.3 + # via jinja2 + # via mako + # via wtforms numpy==1.26.3 + # via pandas pandas==2.1.4 + # via fastapi-enterprise-template passlib==1.7.4 + # via fastapi-enterprise-template phonenumbers==8.13.27 + # via fastapi-enterprise-template pydantic==2.5.3 + # via fastapi + # via pydantic-extra-types + # via pydantic-settings pydantic-core==2.14.6 + # via pydantic pydantic-extra-types==2.4.1 + # via fastapi-enterprise-template pydantic-settings==2.1.0 + # via fastapi-enterprise-template pyjwt==2.8.0 + # via fastapi-enterprise-template python-dateutil==2.8.2 + # via pandas python-dotenv==1.0.0 + # via pydantic-settings + # via uvicorn +python-multipart==0.0.9 + # via sqladmin pytz==2023.3.post1 + # via pandas pyyaml==6.0.1 + # via uvicorn redis==5.0.1 + # via fastapi-enterprise-template sentry-sdk==1.39.2 + # via fastapi-enterprise-template six==1.16.0 + # via python-dateutil sniffio==1.3.0 + # via anyio + # via httpx +sqladmin==0.16.1 + # via fastapi-enterprise-template sqlalchemy==2.0.25 + # via alembic + # via fastapi-enterprise-template + # via sqladmin starlette==0.32.0.post1 + # via fastapi + # via sqladmin typing-extensions==4.9.0 + # via alembic + # via fastapi + # via pydantic + # via pydantic-core + # via sqlalchemy tzdata==2023.4 + # via pandas urllib3==2.1.0 + # via sentry-sdk 
uvicorn==0.25.0 + # via fastapi-enterprise-template uvloop==0.19.0 + # via uvicorn watchfiles==0.21.0 + # via uvicorn websockets==12.0 + # via uvicorn +wtforms==3.1.2 + # via sqladmin diff --git a/src/app.py b/src/app.py index d65c89d..97a5658 100644 --- a/src/app.py +++ b/src/app.py @@ -7,13 +7,13 @@ from starlette.middleware.cors import CORSMiddleware from starlette.middleware.errors import ServerErrorMiddleware -from src.config import settings +from src.core.config import settings +from src.core.errors.auth_exceptions import default_exception_handler, exception_handlers, sentry_ignore_errors from src.enums import Env -from src.exceptions import default_exception_handler, exception_handlers, sentry_ignore_errors -from src.middlewares import RequestMiddleware +from src.libs.redis import cache from src.openapi import openapi_description -from src.routers import router -from src.utils import cache +from src.register.middlewares import RequestMiddleware +from src.register.routers import router def create_app() -> FastAPI: @@ -26,7 +26,7 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: # noqa: ARG001 yield await pool.disconnect() - if settings.ENV == Env.PRD.name: # noqa: SIM300 + if settings.ENV == Env.PROD.name: # noqa: SIM300 sentry_sdk.init( dsn=settings.WEB_SENTRY_DSN, sample_rate=settings.SENTRY_SAMPLE_RATE, diff --git a/src/config.py b/src/config.py deleted file mode 100644 index f0e9f05..0000000 --- a/src/config.py +++ /dev/null @@ -1,42 +0,0 @@ -import tomllib -from pathlib import Path - -from pydantic import Field -from pydantic_settings import BaseSettings, SettingsConfigDict - -from src.enums import Env - -PROJECT_DIR = Path(__file__).parent.parent -with Path.open(f"{PROJECT_DIR}/pyproject.toml", "rb") as f: - PYPROJECT_CONTENT = tomllib.load(f)["project"] - - -class Settings(BaseSettings): - SECRET_KEY: str - SECURITY_BCRYPT_ROUNDS: int = 4 - ACCESS_TOKEN_EXPIRE_MINUTES: int = 120 - REFRESH_TOKEN_EXPIRE_MINUTES: int = 11520 - BACKEND_CORS:
list[str] = Field(default=["*"]) - ALLOWED_HOST: list[str] = Field(default=["*"]) - - PROJECT_NAME: str = PYPROJECT_CONTENT["name"] - VERSION: str = PYPROJECT_CONTENT["version"] - DESCRIPTION: str = PYPROJECT_CONTENT["description"] - LIMITED_RATE: tuple[int, int] = (20, 10) - - WEB_SENTRY_DSN: str | None = None - CELERY_SENTRY_DSN: str | None = None - SENTRY_SAMPLE_RATE: float = 1.0 - SENTRY_TRACES_SAMPLE_RATE: float | None = 1.0 - - SQLALCHEMY_DATABASE_URI: str - DATABASE_POOL_SIZE: int | None = 50 - DATABASE_POOL_MAX_OVERFLOW: int | None = 10 - REDIS_DSN: str - - ENV: str = Env.DEV.name - - model_config = SettingsConfigDict(env_file=f"{PROJECT_DIR}/.env", case_sensitive=True, extra="allow") - - -settings = Settings() # type: ignore # noqa: PGH003 diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..32c823f --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1 @@ +"""DDD infrastructure layer. named as core for easy understanding(migrated from MVC).""" diff --git a/src/_types.py b/src/core/_types.py similarity index 91% rename from src/_types.py rename to src/core/_types.py index 5ffdd89..b2cff7f 100644 --- a/src/_types.py +++ b/src/core/_types.py @@ -1,5 +1,4 @@ from datetime import datetime -from enum import Enum from typing import Annotated, Generic, Literal, ParamSpec, TypeAlias, TypedDict, TypeVar import pydantic @@ -7,7 +6,7 @@ from pydantic import ConfigDict, Field, StringConstraints from pydantic.functional_validators import BeforeValidator -from src.validators import items_to_list, mac_address_validator +from src.core.utils.validators import items_to_list, mac_address_validator T = TypeVar("T") P = ParamSpec("P") @@ -58,14 +57,6 @@ class ListT(BaseModel, Generic[T]): results: list[T] | None = None -class AppStrEnum(str, Enum): - def __str__(self) -> str: - return str.__str__(self) - - @classmethod - def to_list(cls) -> list[str]: - return [c.value for c in cls] - class AuditTimeQuery(BaseModel): created_at__lte: 
datetime diff --git a/src/core/config.py b/src/core/config.py new file mode 100644 index 0000000..bcb04d8 --- /dev/null +++ b/src/core/config.py @@ -0,0 +1,47 @@ +import tomllib +from enum import StrEnum +from pathlib import Path + +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + +PROJECT_DIR = Path(__file__).parent.parent.parent +with Path.open(Path(f"{PROJECT_DIR}/pyproject.toml"), "rb") as f: + PYPROJECT_CONTENT = tomllib.load(f)["project"] + + +class _Env(StrEnum): + DEV = "dev" + PROD = "prod" + STAGE = "stage" + + +class Settings(BaseSettings): + SECRET_KEY: str = Field(default="ea90084454f1f94244f779d605286ae482ffb1f33570dcd1f6a683e5c002b492") + SECURITY_BCRYPT_ROUNDS: int = Field(default=4) + ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(default=120) + REFRESH_TOKEN_EXPIRE_MINUTES: int = Field(default=11520) + BACKEND_CORS: list[str] = Field(default=["*"]) + ALLOWED_HOST: list[str] = Field(default=["*"]) + + PROJECT_NAME: str = Field(default=PYPROJECT_CONTENT["name"]) + VERSION: str = Field(default=PYPROJECT_CONTENT["version"]) + DESCRIPTION: str = Field(default=PYPROJECT_CONTENT["description"]) + LIMITED_RATE: tuple[int, int] = Field(default=(20, 10)) + + WEB_SENTRY_DSN: str | None = Field(default=None) + CELERY_SENTRY_DSN: str | None = Field(default=None) + SENTRY_SAMPLE_RATE: float = Field(default=1.0, gt=0.0, le=1.0) + SENTRY_TRACES_SAMPLE_RATE: float | None = Field(default=None, gt=0.0, le=1.0) + + SQLALCHEMY_DATABASE_URI: str = Field(default= + "postgresql+asyncpg://demo:91fb8e9e009f5b9ce1854d947e6fe4a3@localhost:5432/demo") + DATABASE_POOL_SIZE: int | None = Field(default=50) + DATABASE_POOL_MAX_OVERFLOW: int | None = Field(default=10) + REDIS_DSN: str = Field(default="redis://:cfe1c2c4703abb205d71abdc07cc3f3d@localhost:6379") + + ENV: str = _Env.DEV.name + + model_config = SettingsConfigDict(env_file=f"{PROJECT_DIR}/.env", case_sensitive=True, extra="allow") + +settings = Settings() diff --git 
a/src/auth/__init__.py b/src/core/database/__init__.py similarity index 100% rename from src/auth/__init__.py rename to src/core/database/__init__.py diff --git a/src/core/database/engine.py b/src/core/database/engine.py new file mode 100644 index 0000000..1bc9c65 --- /dev/null +++ b/src/core/database/engine.py @@ -0,0 +1,122 @@ +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import BaseModel, Field + +from src.core.utils._serailization import json_dumps, json_loads +from src.core.utils.dataclass import Empty, EmptyType + +if TYPE_CHECKING: + from sqlalchemy.engine.interfaces import IsolationLevel + from sqlalchemy.pool import Pool + +__all__ = ("EngineConfig",) + +type _EchoFlagType = bool | Literal["debug"] | None + + +class EngineConfig(BaseModel): + """Sqlalchemy engine configuration. + see details: https://docs.sqlalchemy.org/en/20/core/engines.html + """ + + connect_args: dict[Any, Any] | EmptyType = Field( + default=Empty, + description= + """A dictionary of arguments which will be passed directly + to the DBAPI's ``connect()`` method as keyword arguments""", + ) + + echo: _EchoFlagType | EmptyType = Field( + default=Empty, + description="""If ``True``, the Engine will log all statements as well as a ``repr()`` of their parameter lists + to the defaultlog handler, which defaults to ``sys.stdout`` for output. If set to the string "debug", result + rows will be printed to the standard output as well.
The echo attribute of Engine can be modified at any time + to turn logging on and off; direct control of logging is also available using the standard Python logging module + """, + ) + + echo_pool: _EchoFlagType | EmptyType = Empty + + isolation_level: "IsolationLevel| EmptyType" = Field( + default=Empty, + description="""Optional string name of an isolation level which will be set on all new connections + unconditionally Isolation levels are typically some subset of the string names "SERIALIZABLE", + "REPEATABLE READ", "READ COMMITTED", "READ UNCOMMITTED" and "AUTOCOMMIT" based on backend.""" + ) + + json_serializer: Callable[[Any], str] = json_dumps + + json_deserializer: Callable[[str], Any] = json_loads + + max_overflow: int | EmptyType = Field( + default=Empty, + description="""The number of connections to allow in connection pool “overflow”, + that is connections that can be opened above and beyond the pool_size setting, + which defaults to five. This is only used with:class:`QueuePool `.""" + ) + + pool_size: int | EmptyType = Field( + default=Empty, + description="""The number of connections to keep open inside the connection pool. This used with + :class:`QueuePool ` as well as + :class:`SingletonThreadPool `. With + :class:`QueuePool `, a pool_size setting of ``0`` indicates no limit; to disable pooling, + set ``poolclass`` to :class:`NullPool ` instead.""" + ) + + pool_recycle: int | EmptyType = Field( + default=Empty, + description="""This setting causes the pool to recycle connections after the given number of seconds has passed. + It defaults to``-1``, or no timeout. For example, setting to ``3600`` means connections will be recycled after + one hour.
Note that MySQL in particular will disconnect automatically if no activity is detected on a connection + for eight hours (although this is configurable with the MySQLDB connection itself and the server configuration + as well).""" + ) + + pool_use_lifo: bool | EmptyType = Field( + default=Empty, + description="""Use LIFO (last-in-first-out) when retrieving connections from :class:`QueuePool ` instead of FIFO (first-in-first-out). Using LIFO, a server-side timeout scheme can reduce the number + of connections used during non-peak periods of use. When planning for server-side timeouts, ensure that a + recycle or pre-ping strategy is in use to gracefully handle stale connections.""" + ) + + pool_pre_ping: bool | EmptyType = Field( + default=Empty, + description="""If True will enable the connection pool “pre-ping” feature that tests connections for liveness + upon eachcheckout.""" + ) + + pool_timeout: int | EmptyType = Field( + default=Empty, + description="""Number of seconds to wait before giving up on getting a connection from the pool. + This is only used with :class:`QueuePool `. This can be a float but + is subject to the limitations of Python time functions which may not be reliable in the tens of milliseconds.""" + ) + + pool: "Pool| EmptyType" = Field( + default=Empty, + description="""An already-constructed instance of :class:`Pool `, such as a + :class:`QueuePool ` instance. If non-None, this pool will be used directly as the + underlying connection pool for the engine, bypassing whatever connection parameters are present in the URL argument. + For information on constructing connection pools manually, see + `Connection Pooling `_.""" + ) + + poolclass: "type[Pool]| EmptyType" = Field( + default=Empty, + description="""A :class:`Pool ` subclass, which will be used to create a connection pool + instance using the connection parameters given in the URL. 
Note this differs from pool in that you don`t + actually instantiate the pool in this case, you just indicate what type of pool to be used.""" + ) + + query_cache_size: int | EmptyType = Field( + default=Empty, + description="""Size of the cache used to cache the SQL string form of queries. Set to zero to disable caching. + + See :attr:`query_cache_size ` for more info. + """ + ) + + pool_reset_on_return: Literal["reset", "rollback", "commit"] | EmptyType = Empty diff --git a/src/db/session.py b/src/core/database/session.py similarity index 58% rename from src/db/session.py rename to src/core/database/session.py index 520a759..8327a14 100644 --- a/src/db/session.py +++ b/src/core/database/session.py @@ -1,8 +1,13 @@ import logging +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine -from src.config import settings +from src.core.config import settings + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession logger = logging.getLogger(__name__) @@ -14,3 +19,8 @@ max_overflow=settings.DATABASE_POOL_MAX_OVERFLOW, ) async_session = async_sessionmaker(async_engine, autoflush=False, expire_on_commit=False) + + +async def get_session() -> AsyncGenerator["AsyncSession", None]: + async with async_session() as session: + yield session diff --git a/src/core/database/types/__init__.py b/src/core/database/types/__init__.py new file mode 100644 index 0000000..d6f1f94 --- /dev/null +++ b/src/core/database/types/__init__.py @@ -0,0 +1,29 @@ +from src.core.database.types.annotated import ( + bool_false, + bool_true, + date_optional, + date_required, + datetime_optional, + datetime_required, + int_pk, + uuid_pk, +) +from src.core.database.types.datetime import DateTimeTZ +from src.core.database.types.encrypted_string import EncryptedString +from src.core.database.types.enum import IntegerEnum +from src.core.database.types.guid import GUID + +__all__ = ( + "DateTimeTZ", 
+ "EncryptedString", + "IntegerEnum", + "GUID", + "uuid_pk", + "int_pk", + "bool_true", + "bool_false", + "datetime_optional", + "date_optional", + "datetime_required", + "date_required", +) diff --git a/src/core/database/types/annotated.py b/src/core/database/types/annotated.py new file mode 100644 index 0000000..4dd10b8 --- /dev/null +++ b/src/core/database/types/annotated.py @@ -0,0 +1,19 @@ +import uuid +from datetime import date, datetime +from typing import Annotated + +from sqlalchemy import Boolean, Date, Integer +from sqlalchemy.orm import mapped_column +from sqlalchemy.sql import expression + +from src.core.database.types.datetime import DateTimeTZ +from src.core.database.types.guid import GUID + +uuid_pk = Annotated[uuid.UUID, mapped_column(GUID, default=uuid.uuid4, primary_key=True, index=True, nullable=False)] +int_pk = Annotated[int, mapped_column(Integer, primary_key=True, index=True, nullable=False)] +bool_true = Annotated[bool, mapped_column(Boolean, server_default=expression.true(), nullable=False)] +bool_false = Annotated[bool, mapped_column(Boolean, server_default=expression.false(), nullable=False)] +datetime_required = Annotated[datetime, mapped_column(DateTimeTZ, nullable=False)] +datetime_optional = Annotated[datetime, mapped_column(DateTimeTZ, nullable=True)] +date_required = Annotated[date, mapped_column(Date, nullable=False)] +date_optional = Annotated[date | None, mapped_column(Date, nullable=True)] diff --git a/src/core/database/types/datetime.py b/src/core/database/types/datetime.py new file mode 100644 index 0000000..45f0180 --- /dev/null +++ b/src/core/database/types/datetime.py @@ -0,0 +1,32 @@ +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from sqlalchemy import DateTime +from sqlalchemy.types import TypeDecorator + +if TYPE_CHECKING: + from sqlalchemy.engine import Dialect + + +class DateTimeTZ(TypeDecorator[datetime]): + impl = DateTime(timezone=True) + cache_ok = True + + @property + def 
python_type(self) -> type[datetime]: + return datetime + + def process_bind_param(self, value: datetime | None, dialect: "Dialect") -> datetime | None: # noqa: ARG002 + if value is None: + return value + if not value.tzinfo: + msg = "tzinfo is required" + raise TypeError(msg) + return value.astimezone(UTC) + + def process_result_value(self, value: datetime | None, dialect: "Dialect") -> datetime | None: # noqa: ARG002 + if value is None: + return value + if value.tzinfo is None: + return value.replace(tzinfo=UTC) + return value diff --git a/src/core/database/types/encrypted_string.py b/src/core/database/types/encrypted_string.py new file mode 100644 index 0000000..3b904e5 --- /dev/null +++ b/src/core/database/types/encrypted_string.py @@ -0,0 +1,131 @@ +import abc +import base64 +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from cryptography.fernet import Fernet +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from sqlalchemy import String, Text, TypeDecorator +from sqlalchemy import func as sql_func + +from src.core.config import settings + +cryptography = None + +if TYPE_CHECKING: + from sqlalchemy.engine import Dialect + + +class EncryptionBackend(abc.ABC): + def mount_vault(self, key: str | bytes) -> None: + if isinstance(key, str): + key = key.encode() + + @abc.abstractmethod + def init_engine(self, key: bytes | str) -> None: # pragma: nocover + pass + + @abc.abstractmethod + def encrypt(self, value: Any) -> str: # pragma: nocover + pass + + @abc.abstractmethod + def decrypt(self, value: Any) -> str: # pragma: nocover + pass + + +class PGCryptoBackend(EncryptionBackend): + """PG Crypto backend.""" + + def init_engine(self, key: bytes | str) -> None: + if isinstance(key, str): + key = key.encode() + self.passphrase = base64.urlsafe_b64encode(key) + + def encrypt(self, value: Any) -> str: + if not isinstance(value, str): # pragma: nocover + value = repr(value) + value 
= value.encode() + return sql_func.pgp_sym_encrypt(value, self.passphrase) # type: ignore[return-value] + + def decrypt(self, value: Any) -> str: + if not isinstance(value, str): # pragma: nocover + value = str(value) + return sql_func.pgp_sym_decrypt(value, self.passphrase) # type: ignore[return-value] + + +class FernetBackend(EncryptionBackend): + """Encryption Using a Fernet backend""" + + def mount_vault(self, key: str | bytes) -> None: + if isinstance(key, str): + key = key.encode() + digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) + digest.update(key) + engine_key = digest.finalize() + self.init_engine(engine_key) + + def init_engine(self, key: bytes | str) -> None: + if isinstance(key, str): + key = key.encode() + self.key = base64.urlsafe_b64encode(key) + self.fernet = Fernet(self.key) + + def encrypt(self, value: Any) -> str: + if not isinstance(value, str): + value = repr(value) + value = value.encode() + encrypted = self.fernet.encrypt(value) + return encrypted.decode("utf-8") + + def decrypt(self, value: Any) -> str: + if not isinstance(value, str): # pragma: nocover + value = str(value) + decrypted: str | bytes = self.fernet.decrypt(value.encode()) + if not isinstance(decrypted, str): + decrypted = decrypted.decode("utf-8") + return decrypted + + +class EncryptedString(TypeDecorator[str]): + """Used to store encrypted values in a database""" + + impl = String + cache_ok = True + + def __init__( + self, + key: str | bytes | Callable[[], str | bytes] = settings.SECRET_KEY, + backend: type[EncryptionBackend] = FernetBackend, + ) -> None: + super().__init__() + self.key = key + self.backend = backend() + + @property + def python_type(self) -> type[str]: + return str + + def load_dialect_impl(self, dialect: "Dialect") -> Any: + if dialect.name in {"mysql", "mariadb"}: + return dialect.type_descriptor(Text()) + if dialect.name == "oracle": + return dialect.type_descriptor(String(length=4000)) + return dialect.type_descriptor(String()) + + 
def process_bind_param(self, value: Any, dialect: "Dialect") -> str | None: # noqa: ARG002 + if value is None: + return value + self.mount_vault() + return self.backend.encrypt(value) + + def process_result_value(self, value: Any, dialect: "Dialect") -> str | None: # noqa: ARG002 + if value is None: + return value + self.mount_vault() + return self.backend.decrypt(value) + + def mount_vault(self) -> None: + key = self.key() if callable(self.key) else self.key + self.backend.mount_vault(key) diff --git a/src/core/database/types/enum.py b/src/core/database/types/enum.py new file mode 100644 index 0000000..bd10922 --- /dev/null +++ b/src/core/database/types/enum.py @@ -0,0 +1,34 @@ +from enum import IntEnum +from typing import TYPE_CHECKING, Any, TypeVar, no_type_check + +from sqlalchemy import Integer +from sqlalchemy.types import TypeDecorator + +if TYPE_CHECKING: + from sqlalchemy.engine import Dialect + +T = TypeVar("T", bound=IntEnum) + + +class IntegerEnum(TypeDecorator[T]): + impl = Integer + cache_ok = True + + def __init__(self, enum_type: type[T]) -> None: + super().__init__() + self.enum_type = enum_type + + @no_type_check + def process_bind_param(self, value: int, dialect: "Dialect") -> int: # noqa: ARG002 + if isinstance(value, self.enum_type): + return value.value + msg = f"expected {self.enum_type.__name__} value, got {value.__class__.__name__}" + raise ValueError(msg) + + @no_type_check + def process_result_value(self, value: int, dialect: "Dialect")-> "T": # noqa: ARG002 + return self.enum_type(value) + + @no_type_check + def copy(self, **kwargs: Any)-> "IntegerEnum[T]": # noqa: ARG002 + return IntegerEnum(self.enum_type) diff --git a/src/core/database/types/guid.py b/src/core/database/types/guid.py new file mode 100644 index 0000000..39fc1a3 --- /dev/null +++ b/src/core/database/types/guid.py @@ -0,0 +1,97 @@ +from base64 import b64decode +from importlib.util import find_spec +from typing import TYPE_CHECKING, Any, cast +from uuid import UUID + +from 
sqlalchemy.dialects.mssql import UNIQUEIDENTIFIER as MSSQL_UNIQUEIDENTIFIER +from sqlalchemy.dialects.oracle import RAW as ORA_RAW +from sqlalchemy.dialects.postgresql import UUID as PG_UUID +from sqlalchemy.types import BINARY, CHAR, TypeDecorator + +if TYPE_CHECKING: + from collections.abc import Buffer + + from sqlalchemy.engine import Dialect + +UUID_UTILS_INSTALLED = find_spec("uuid_utils") + + +class GUID(TypeDecorator[UUID]): + """Platform-independent GUID type. + + Uses PostgreSQL's UUID type (Postgres, DuckDB, Cockroach), + MSSQL's UNIQUEIDENTIFIER type, Oracle's RAW(16) type, + otherwise uses BINARY(16) or CHAR(32), + storing as stringified hex values. + + Will accept stringified UUIDs as a hexstring or an actual UUID + + """ + + impl = BINARY(16) + cache_ok = True + + @property + def python_type(self) -> type[UUID]: + return UUID + + def __init__(self, *args: Any, binary: bool = True, **kwargs: Any) -> None: # noqa: ARG002 + self.binary = binary + + def load_dialect_impl(self, dialect: "Dialect") -> Any: + if dialect.name in {"postgresql", "duckdb", "cockroachdb"}: + return dialect.type_descriptor(PG_UUID()) + if dialect.name == "oracle": + return dialect.type_descriptor(ORA_RAW(16)) + if dialect.name == "mssql": + return dialect.type_descriptor(MSSQL_UNIQUEIDENTIFIER()) + if self.binary: + return dialect.type_descriptor(BINARY(16)) + return dialect.type_descriptor(CHAR(32)) + + def process_bind_param( + self, + value: bytes | str | UUID | None, + dialect: "Dialect", + ) -> bytes | str | None: + if value is None: + return value + if dialect.name in {"postgresql", "duckdb", "cockroachdb", "mssql"}: + return str(value) + value = self.to_uuid(value) + if value is None: + return value + if dialect.name in {"oracle", "spanner+spanner"}: + return value.bytes + return value.bytes if self.binary else value.hex + + def process_result_value( + self, + value: bytes | str | UUID | None, + dialect: "Dialect", + ) -> UUID | None: + if value is None: + return value + 
if value.__class__.__name__ == "UUID": + return cast("UUID", value) + if dialect.name == "spanner+spanner": + return UUID(bytes=b64decode(cast("str | Buffer", value))) + if self.binary: + return UUID(bytes=cast("bytes", value)) + return UUID(hex=cast("str", value)) + + @staticmethod + def to_uuid(value: Any) -> UUID | None: + if value.__class__.__name__ == "UUID" or value is None: + return cast("UUID | None", value) + try: + value = UUID(hex=value) + except (TypeError, ValueError): + value = UUID(bytes=value) + return cast("UUID | None", value) + + def compare_values(self, x: Any, y: Any) -> bool: + """Compare two values for equality.""" + if x.__class__.__name__ == "UUID" and y.__class__.__name__ == "UUID": + return cast("bool", x.bytes == y.bytes) + return cast("bool", x == y) diff --git a/src/core/errors/_error.py b/src/core/errors/_error.py new file mode 100644 index 0000000..7e7221f --- /dev/null +++ b/src/core/errors/_error.py @@ -0,0 +1,20 @@ +from typing import Any, NamedTuple, TypedDict + + +class Error(TypedDict): + code: int + message: str + details: list[Any] | None + + +class ErrorCode(NamedTuple): + error: int + message: str + details: list[Any] | None = None + + def dict(self)-> "Error": + return { + "code": self.error, + "message": self.message, + "details": self.details + } diff --git a/src/exceptions.py b/src/core/errors/auth_exceptions.py similarity index 85% rename from src/exceptions.py rename to src/core/errors/auth_exceptions.py index 8b5a0c1..4f6c334 100644 --- a/src/exceptions.py +++ b/src/core/errors/auth_exceptions.py @@ -8,10 +8,10 @@ from fastapi import Request, status from fastapi.responses import JSONResponse -from src import errors -from src.context import locale_ctx, request_id_ctx -from src.errors import ErrorCode -from src.i18n import _ +from src.core.errors import base_exceptions +from src.core.errors.base_exceptions import ErrorCode +from src.core.utils.context import locale_ctx, request_id_ctx +from src.core.utils.i18n import _ 
_E = NewType("_E", Exception) logger = logging.getLogger(__name__) @@ -107,36 +107,36 @@ def log_exception(exc: type[BaseException] | Exception, logger_trace_info: bool) async def token_invalid_handler(request: Request, exc: TokenInvalidError) -> JSONResponse: log_exception(exc, False) - response_content = errors.ERR_10002.dict() + response_content = base_exceptions.ERR_10002.dict() return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=response_content) async def invalid_token_for_refresh_handler(request: Request, exc: TokenInvalidForRefreshError) -> JSONResponse: log_exception(exc, False) - return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=errors.ERR_10004.dict()) + return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=base_exceptions.ERR_10004.dict()) async def token_expired_handler(request: Request, exc: TokenExpireError) -> JSONResponse: log_exception(exc, False) - return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=errors.ERR_10003.dict()) + return JSONResponse(status_code=status.HTTP_401_UNAUTHORIZED, content=base_exceptions.ERR_10003.dict()) async def permission_deny_handler(request: Request, exc: PermissionDenyError) -> JSONResponse: log_exception(exc, False) - return JSONResponse(status_code=status.HTTP_403_FORBIDDEN, content=errors.ERR_10004.dict()) + return JSONResponse(status_code=status.HTTP_403_FORBIDDEN, content=base_exceptions.ERR_10004.dict()) async def resource_not_found_handler(request: Request, exc: NotFoundError) -> JSONResponse: log_exception(exc, True) - error_message = _(errors.ERR_404.message, name=exc.name, filed=exc.field, value=exc.value) - content = {"error": errors.ERR_404.error, "message": error_message} + error_message = _(base_exceptions.ERR_404.message, name=exc.name, filed=exc.field, value=exc.value) + content = {"error": base_exceptions.ERR_404.error, "message": error_message} return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content=content) async def 
resource_exist_handler(request: Request, exc: ExistError) -> JSONResponse: log_exception(exc, True) - error_message = _(errors.ERR_409.message, name=exc.name, filed=exc.field, value=exc.value) - content = {"error": errors.ERR_409.error, "message": error_message} + error_message = _(base_exceptions.ERR_409.message, name=exc.name, filed=exc.field, value=exc.value) + content = {"error": base_exceptions.ERR_409.error, "message": error_message} return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content=content) @@ -155,7 +155,7 @@ def default_exception_handler(request: Request, exc: Exception) -> JSONResponse: log_exception(exc, logger_trace_info=True) return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"error": errors.ERR_500.error, "message": _(errors.ERR_500.message, request_id=request_id_ctx.get())}, + content={"error": base_exceptions.ERR_500.error, "message": _(base_exceptions.ERR_500.message, request_id=request_id_ctx.get())}, ) diff --git a/src/errors.py b/src/core/errors/base_exceptions.py similarity index 100% rename from src/errors.py rename to src/core/errors/base_exceptions.py diff --git a/src/utils/__init__.py b/src/core/models/__init__.py similarity index 100% rename from src/utils/__init__.py rename to src/core/models/__init__.py diff --git a/src/db/base.py b/src/core/models/base.py similarity index 74% rename from src/db/base.py rename to src/core/models/base.py index 7302cf5..fc2b34e 100644 --- a/src/db/base.py +++ b/src/core/models/base.py @@ -1,9 +1,9 @@ -from typing import Any, ClassVar +from typing import Any, ClassVar, TypeVar from fastapi.encoders import jsonable_encoder from sqlalchemy.orm import DeclarativeBase -from src._types import VisibleName +from src.core._types import VisibleName class Base(DeclarativeBase): @@ -16,3 +16,9 @@ def dict(self, exclude: set[str] | None = None, native_dict: bool = False) -> di if not native_dict: return jsonable_encoder(self, exclude=exclude) return {c.name: 
getattr(self, c.name) for c in self.__table__.columns if c.name not in exclude} + + def __getattribute__(self, name: str) -> Any: + return super().__getattribute__(name) + + +ModelT = TypeVar("ModelT", bound=Base) diff --git a/src/core/models/mixins/__init__.py b/src/core/models/mixins/__init__.py new file mode 100644 index 0000000..a77c364 --- /dev/null +++ b/src/core/models/mixins/__init__.py @@ -0,0 +1,5 @@ +from src.core.models.mixins.audit_log import AuditLog, AuditLogMixin +from src.core.models.mixins.audit_time import AuditTimeMixin +from src.core.models.mixins.audit_user import AuditUserMixin + +__all__ = ("AuditLogMixin", "AuditTimeMixin", "AuditUserMixin", "AuditLog") diff --git a/src/db/mixins.py b/src/core/models/mixins/audit_log.py similarity index 68% rename from src/db/mixins.py rename to src/core/models/mixins/audit_log.py index 39f65c5..457ffad 100644 --- a/src/db/mixins.py +++ b/src/core/models/mixins/audit_log.py @@ -1,28 +1,23 @@ from typing import TYPE_CHECKING from fastapi.encoders import jsonable_encoder -from sqlalchemy import DateTime, ForeignKey, Integer, String, event, func, insert, inspect -from sqlalchemy.dialects.postgresql import JSON +from sqlalchemy import JSON, ForeignKey, Integer, String, event, func, insert, inspect from sqlalchemy.engine import Connection from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import Mapped, Mapper, class_mapper, mapped_column, relationship from sqlalchemy.orm.attributes import get_history -from src.context import orm_diff_ctx, request_id_ctx, user_ctx -from src.db._types import int_pk -from src.db.base import Base +from src.core.database.types import DateTimeTZ, int_pk +from src.core.models.base import Base +from src.core.utils.context import orm_diff_ctx, request_id_ctx, user_ctx if TYPE_CHECKING: from datetime import datetime - from src.auth.models import User - from src.db.dtobase import ModelT + from src.core.models.base import ModelT + from src.features.auth.models import 
User -class AuditTimeMixin: - created_at: Mapped["datetime"] = mapped_column(DateTime(timezone=True), default=func.now(), index=True) - updated_at: Mapped["datetime"] = mapped_column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) - def get_object_change(obj: Mapper) -> dict: insp = inspect(obj) @@ -47,7 +42,7 @@ def get_object_change(obj: Mapper) -> dict: class AuditLog: id: Mapped[int_pk] - created_at: Mapped["datetime"] = mapped_column(DateTime(timezone=True), default=func.now()) + created_at: Mapped["datetime"] = mapped_column(DateTimeTZ, default=func.now()) request_id: Mapped[str] action: Mapped[str] = mapped_column(String, nullable=False) diff: Mapped[dict | None] = mapped_column(JSON) @@ -136,38 +131,3 @@ def __declare_last__(cls) -> None: event.listen(cls, "after_delete", cls.log_delete, propagate=True) -class AuditUserMixin: - created_at: Mapped["datetime"] = mapped_column(DateTime(timezone=True), default=func.now(), index=True) - updated_at: Mapped["datetime"] = mapped_column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) - - @declared_attr - @classmethod - def created_by_fk(cls) -> Mapped[int | None]: - return mapped_column(Integer, ForeignKey("user.id"), default=user_ctx.get) - - @declared_attr - @classmethod - def updated_by_fk(cls) -> Mapped[int | None]: - return mapped_column(Integer, ForeignKey("user.id"), default=user_ctx.get, nullable=True) - - @declared_attr - @classmethod - def created_by(cls) -> Mapped["User"]: - return relationship( - "User", - foreign_keys=[cls.created_by_fk], - primaryjoin=f"{cls.__name__}.created_by_fk==User.id", - enable_typechecks=False, - uselist=False, - ) - - @declared_attr - @classmethod - def updated_by(cls) -> Mapped["User"]: - return relationship( - "User", - foreign_keys=[cls.updated_by_fk], - primaryjoin=f"{cls.__name__}.updated_by_fk==User.id", - enable_typechecks=False, - uselist=False, - ) diff --git a/src/core/models/mixins/audit_time.py 
b/src/core/models/mixins/audit_time.py new file mode 100644 index 0000000..767ba26 --- /dev/null +++ b/src/core/models/mixins/audit_time.py @@ -0,0 +1,14 @@ +from typing import TYPE_CHECKING + +from sqlalchemy import func +from sqlalchemy.orm import Mapped, mapped_column + +from src.core.database.types import DateTimeTZ + +if TYPE_CHECKING: + from datetime import datetime + + +class AuditTimeMixin: + created_at: Mapped["datetime"] = mapped_column(DateTimeTZ, default=func.now(), index=True) + updated_at: Mapped["datetime"] = mapped_column(DateTimeTZ, default=func.now(), onupdate=func.now()) diff --git a/src/core/models/mixins/audit_user.py b/src/core/models/mixins/audit_user.py new file mode 100644 index 0000000..09b14d1 --- /dev/null +++ b/src/core/models/mixins/audit_user.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from sqlalchemy import Integer, func +from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from src.core.database.types import DateTimeTZ +from src.core.utils.context import user_ctx + +if TYPE_CHECKING: + from datetime import datetime + + from src.features.auth.models import User + + +class AuditUserMixin: + created_at: Mapped["datetime"] = mapped_column(DateTimeTZ, default=func.now(), index=True) + updated_at: Mapped["datetime"] = mapped_column(DateTimeTZ, default=func.now(), onupdate=func.now()) + + @declared_attr + @classmethod + def created_by_fk(cls) -> Mapped[int | None]: + return mapped_column(Integer, default=user_ctx.get) + + @declared_attr + @classmethod + def updated_by_fk(cls) -> Mapped[int | None]: + return mapped_column(Integer, default=user_ctx.get, nullable=True) + + @declared_attr + @classmethod + def created_by(cls) -> Mapped["User"]: + return relationship( + "User", + foreign_keys=[cls.created_by_fk], + primaryjoin=f"{cls.__name__}.created_by_fk==User.id", + enable_typechecks=False, + uselist=False, + ) + + @declared_attr + @classmethod + def 
updated_by(cls) -> Mapped["User"]: + return relationship( + "User", + foreign_keys=[cls.updated_by_fk], + primaryjoin=f"{cls.__name__}.updated_by_fk==User.id", + enable_typechecks=False, + uselist=False, + ) diff --git a/src/core/repositories/__init__.py b/src/core/repositories/__init__.py new file mode 100644 index 0000000..e67369f --- /dev/null +++ b/src/core/repositories/__init__.py @@ -0,0 +1,3 @@ +from src.core.repositories.repository import BaseRepository + +__all__ = ("BaseRepository",) diff --git a/src/db/dtobase.py b/src/core/repositories/repository.py similarity index 90% rename from src/db/dtobase.py rename to src/core/repositories/repository.py index 6c74b86..5b92a19 100644 --- a/src/db/dtobase.py +++ b/src/core/repositories/repository.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypedDict, TypeVar, overload +from typing import TYPE_CHECKING, Any, Generic, TypedDict, TypeVar, overload from uuid import UUID from pydantic import BaseModel @@ -10,16 +10,16 @@ from sqlalchemy.orm import InstrumentedAttribute, selectinload, undefer from sqlalchemy.sql.base import ExecutableOption -from src._types import Order, QueryParams -from src.context import locale_ctx -from src.db.base import Base -from src.db.session import async_engine -from src.exceptions import ExistError, NotFoundError +from src.core._types import Order, QueryParams +from src.core.database.session import async_engine +from src.core.errors.auth_exceptions import ExistError, NotFoundError +from src.core.models.base import Base +from src.core.utils.context import locale_ctx if TYPE_CHECKING: from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint, ReflectedUniqueConstraint - from src.db.mixins import AuditLog + from src.core.models.mixins import AuditLog ModelT = TypeVar("ModelT", bound=Base) PkIdT = TypeVar("PkIdT", int, UUID) @@ -30,12 +30,6 @@ TABLE_PARAMS: dict[str, "InspectorTableConstraint"] = {} - -class 
OrmField(NamedTuple): - field: str - value: Any - - class InspectorTableConstraint(TypedDict, total=False): foreign_keys: dict[str, tuple[str, str]] unique_constraints: list[list[str]] @@ -70,8 +64,10 @@ async def inspect_table(table_name: str) -> InspectorTableConstraint: return result -class DtoBase(Generic[ModelT, CreateSchemaType, UpdateSchemaType, QuerySchemaType]): +class BaseRepository(Generic[ModelT, CreateSchemaType, UpdateSchemaType, QuerySchemaType]): id_attribute: str = "id" + check_nullable: bool = True + check_unique_constriants: bool = True def __init__(self, model: type[ModelT]) -> None: """ @@ -111,8 +107,8 @@ def get_id_attribute_value( """ return getattr(obj, id_attribute if id_attribute is not None else cls.id_attribute) - def inspect_relationship(self) -> dict[str, type[Any]]: - result = {} + def inspect_relationship(self) -> dict[str, type["RelationT"]]: + result: dict[str, type["RelationT"]] = {} insp = inspect(self.model) for relationship in insp.relationships: key = relationship.key @@ -207,6 +203,19 @@ def _apply_operator_filter(self, stmt: Select[tuple[ModelT]], key: str, value: A key format should be `field_name__operator`. eg: "start_at__lte", "end_at__gt", "name__ic", "name__nic" value (Any): The value used for the filter. + support: + eq: equal, accepts list of values + ne: not equal, accepts list of values + ic: ignore case + nic: not ignore case + le: less than equal + ge: greater than equal + lte: less than + gte: greater than + sw: starts with + nsw: not starts with + ew: ends with + new: not ends with Returns: Select[tuple[ModelT]]: The filtered statement. 
@@ -220,6 +229,10 @@ def _apply_operator_filter(self, stmt: Select[tuple[ModelT]], key: str, value: A "ge": lambda col, value: col > value, "lte": lambda col, value: col <= value, "gte": lambda col, value: col >= value, + "sw": lambda col, value: col.like(f"{value}%"), + "nsw": lambda col, value: not_(col.like(f"{value}%")), + "ew": lambda col, value: col.like(f"%{value}"), + "new": lambda col, value: not_(col.like(f"%{value}%")), } field_name, operator = key.split("__") @@ -532,34 +545,6 @@ async def _apply_foreign_keys_check( fk_result = (await session.execute(text(stmt_text))).one_or_none() self._check_not_found(fk_result, column, value) - async def list_and_count( - self, session: AsyncSession, query: QuerySchemaType, *options: ExecutableOption, undefer_load: bool = True - ) -> tuple[int, Sequence[ModelT]]: - """ - Asynchronously retrieves a list of items from the database and returns the count and results. - - Args: - session (AsyncSession): The async session object for the database connection. - query (QuerySchemaType): The query schema object containing the query parameters. - options (tuple | None, optional): Additional options for the query. Defaults to None. - undefer_load (bool, optional): Whether to undefer the load. Defaults to True. - Returns: - tuple[int, Sequence[ModelT]]: A tuple containing the count of items and the list of results. 
- """ - stmt = self._get_base_stmt() - stmt = self._apply_list(stmt, query) - if query.q: - stmt = self._apply_search(stmt, query.q) - c_stmt = stmt.with_only_columns(func.count()).order_by(None) - if query.limit is not None and query.offset is not None: - stmt = self._apply_pagination(stmt, query.limit, query.offset) - if query.order_by and query.order: - stmt = self._apply_order_by(stmt, query.order_by, query.order) - stmt = self._apply_selectinload(stmt, *options, undefer_load=undefer_load) - _count = await session.scalar(c_stmt) - results = (await session.scalars(stmt)).all() - return _count if _count is not None else 0, results - async def create( self, session: AsyncSession, @@ -584,9 +569,12 @@ async def create( Raises: None """ - insp = await inspect_table(self.model.__tablename__) - await self._apply_foreign_keys_check(session, obj_in, insp) - await self._apply_unique_constraints_when_create(session, obj_in, insp) + if any((self.check_nullable, self.check_unique_constriants)): + insp = await inspect_table(self.model.__tablename__) + if self.check_nullable: + await self._apply_foreign_keys_check(session, obj_in, insp) + if self.check_unique_constriants: + await self._apply_unique_constraints_when_create(session, obj_in, insp) m2m = self.inspect_relationship() extra_excluded = set(m2m.keys()) if excludes: @@ -599,7 +587,7 @@ async def create( if m2m: for key, value in m2m.items(): if hasattr(obj_in, key) and getattr(obj_in, key) is not None: - dto_m2m = DtoBase(value) + dto_m2m = BaseRepository(value) db_m2m = await dto_m2m.get_multi_by_pks_or_404(session, [r.id for r in getattr(obj_in, key)]) setattr(new_obj, key, db_m2m) setattr(obj_in, key, value) @@ -628,9 +616,12 @@ async def update( Returns: ModelT: The updated database object. 
""" - insp = await inspect_table(self.model.__tablename__) - await self._apply_foreign_keys_check(session, obj_in, insp) - await self._apply_unique_constraints_when_update(session, obj_in, insp, db_obj) + if any((self.check_nullable, self.check_unique_constriants)): + insp = await inspect_table(self.model.__tablename__) + if self.check_nullable: + await self._apply_foreign_keys_check(session, obj_in, insp) + if self.check_unique_constriants: + await self._apply_unique_constraints_when_update(session, obj_in, insp, db_obj) m2m = self.inspect_relationship() extra_excluded = set(m2m.keys()) if excludes: @@ -680,7 +671,7 @@ async def update_relationship_field( if getattr(fk_value, relationship_pk_name) not in fk_values: getattr(obj, relationship_name).remove(fk_value) for fk_value in fk_values: - target_dto = DtoBase(model=m2m_model) + target_dto = BaseRepository(model=m2m_model) if fk_value not in local_fk_value_ids: target_obj = await target_dto.get_one_or_404(session, fk_value) getattr(obj, relationship_name).append(target_obj) @@ -688,6 +679,59 @@ async def update_relationship_field( setattr(obj, relationship_name, None) return obj + async def list_and_count( + self, session: AsyncSession, query: QuerySchemaType, *options: ExecutableOption, undefer_load: bool = True + ) -> tuple[int, Sequence[ModelT]]: + """ + Asynchronously retrieves a list of items from the database and returns the count and results. + + Args: + session (AsyncSession): The async session object for the database connection. + query (QuerySchemaType): The query schema object containing the query parameters. + options (tuple | None, optional): Additional options for the query. Defaults to None. + undefer_load (bool, optional): Whether to undefer the load. Defaults to True. + Returns: + tuple[int, Sequence[ModelT]]: A tuple containing the count of items and the list of results. 
+ """ + stmt = self._get_base_stmt() + stmt = self._apply_list(stmt, query) + if query.q: + stmt = self._apply_search(stmt, query.q) + c_stmt = stmt.with_only_columns(func.count()).order_by(None) + if query.limit is not None and query.offset is not None: + stmt = self._apply_pagination(stmt, query.limit, query.offset) + if query.order_by and query.order: + stmt = self._apply_order_by(stmt, query.order_by, query.order) + stmt = self._apply_selectinload(stmt, *options, undefer_load=undefer_load) + _count = await session.scalar(c_stmt) + results = (await session.scalars(stmt)).all() + return _count if _count is not None else 0, results + + async def get_all(self, session: AsyncSession) -> Sequence[ModelT]: + return (await session.scalars(self._get_base_stmt())).all() + + async def get_one_by_id( + self, session: AsyncSession, pk_id: PkIdT, *options: ExecutableOption, undefer_load: bool = False)-> ModelT | None: + """ + Retrieves a single instance of ModelT from the database based on the provided \n + primary key (pk_id) and optional query options (options). + + Parameters: + session (AsyncSession): The asynchronous session used to execute the database query. + pk_id (PkIdT): The primary key value used to identify the instance to be retrieved. + *options ExecutableOption: query options to apply to the database query. + undefer_load (bool, optional): Whether to undefer the load. Defaults to False. + + Returns: + ModelT: The retrieved instance of ModelT from the database. 
+ """ + stmt = self._get_base_stmt() + id_str = self.get_id_attribute_value(self.model) + stmt = stmt.where(id_str == pk_id) + if options: + stmt = self._apply_selectinload(stmt, *options, undefer_load=undefer_load) + return (await session.scalars(stmt)).one_or_none() + async def get_one_or_404( self, session: AsyncSession, pk_id: PkIdT, *options: ExecutableOption, undefer_load: bool = False ) -> ModelT: @@ -707,12 +751,7 @@ async def get_one_or_404( Raises: NotFoundError: If no instance with the given primary key (pk_id) is found in the database. """ - stmt = self._get_base_stmt() - id_str = self.get_id_attribute_value(self.model) - stmt = stmt.where(id_str == pk_id) - if options: - stmt = self._apply_selectinload(stmt, *options, undefer_load=undefer_load) - result = (await session.scalars(stmt)).one_or_none() + result = await self.get_one_by_id(session, pk_id, *options, undefer_load=undefer_load) if not result: raise NotFoundError(self.model.__visible_name__[locale_ctx.get()], self.id_attribute, pk_id) return result @@ -778,6 +817,17 @@ async def get_multi_by_filter( stmt = self._apply_selectinload(stmt, *options, undefer_load=undefer_load) return (await session.scalars(stmt)).all() + async def get_multi_by_ids( + self, session: AsyncSession, pk_ids: list[PkIdT], *options: ExecutableOption, undefer_load: bool = False + ) -> Sequence[ModelT]: + stmt = self._get_base_stmt() + id_str = self.get_id_attribute_value(self.model) + stmt = stmt.where(id_str.in_(pk_ids)) + if options: + stmt = self._apply_selectinload(stmt, *options, undefer_load=undefer_load) + return (await session.scalars(stmt)).all() + + async def get_multi_by_pks_or_404( self, session: AsyncSession, pk_ids: list[PkIdT], *options: ExecutableOption, undefer_load: bool = False ) -> Sequence[ModelT]: @@ -796,12 +846,7 @@ async def get_multi_by_pks_or_404( Raises: NotFoundError: If no records are found with the given primary key IDs. 
""" - stmt = self._get_base_stmt() - id_str = self.get_id_attribute_value(self.model) - stmt = stmt.where(id_str.in_(pk_ids)) - if options: - stmt = self._apply_selectinload(stmt, *options, undefer_load=undefer_load) - results = (await session.scalars(stmt)).all() + results = await self.get_multi_by_ids(session, pk_ids, *options, undefer_load=undefer_load) if not results: raise NotFoundError(self.model.__visible_name__[locale_ctx.get()], self.id_attribute, pk_ids) for r in results: @@ -821,10 +866,7 @@ async def get_one_and_delete(self, session: AsyncSession, pk_id: PkIdT) -> None: Returns: None: This function does not return anything. """ - stmt = self._get_base_stmt() - id_str = self.get_id_attribute_value(self.model) - result = (await session.scalars(stmt.where(id_str == pk_id))).one_or_none() - result = self._check_not_found(result, self.id_attribute, pk_id) + result = await self.get_one_or_404(session, pk_id) await self.delete(session, result) async def get_multi_and_delete(self, session: AsyncSession, pk_ids: list[PkIdT]) -> None: @@ -842,9 +884,7 @@ async def get_multi_and_delete(self, session: AsyncSession, pk_ids: list[PkIdT]) NotFoundError: If any of the primary key IDs are not found in the database. 
""" - stmt = self._get_base_stmt() - id_str = self.get_id_attribute_value(self.model) - results = (await session.scalars(stmt.where(id_str.in_(pk_ids)))).all() + results = await self.get_multi_by_ids(session, pk_ids) for r in results: id_value = self.get_id_attribute_value(r) if id_value not in pk_ids: diff --git a/src/consts.py b/src/core/utils/__init__.py similarity index 100% rename from src/consts.py rename to src/core/utils/__init__.py diff --git a/src/cbv.py b/src/core/utils/cbv.py similarity index 100% rename from src/cbv.py rename to src/core/utils/cbv.py diff --git a/src/context.py b/src/core/utils/context.py similarity index 80% rename from src/context.py rename to src/core/utils/context.py index 07a11e0..1efe910 100644 --- a/src/context.py +++ b/src/core/utils/context.py @@ -1,6 +1,7 @@ from contextvars import ContextVar +from uuid import uuid4 -request_id_ctx: ContextVar[str | None] = ContextVar("x-request-id", default=None) +request_id_ctx: ContextVar[str] = ContextVar("x-request-id", default=str(uuid4())) user_ctx: ContextVar[int | None] = ContextVar("x-auth-user", default=None) locale_ctx: ContextVar[str] = ContextVar("Accept-Language", default="en_US") orm_diff_ctx: ContextVar[dict | None] = ContextVar("x-orm-diff", default=None) diff --git a/src/i18n.py b/src/core/utils/i18n.py similarity index 92% rename from src/i18n.py rename to src/core/utils/i18n.py index c3c9a7d..7ff56ba 100644 --- a/src/i18n.py +++ b/src/core/utils/i18n.py @@ -4,9 +4,9 @@ from operator import getitem from typing import Any, Literal, TypeAlias -from src.context import locale_ctx +from src.core.utils.context import locale_ctx +from src.core.utils.singleton import singleton from src.openapi import translations -from src.utils.singleton import singleton ACCEPTED_LANGUAGES: TypeAlias = Literal["en_US", "zh_CN"] diff --git a/src/core/utils/singleton.py b/src/core/utils/singleton.py new file mode 100644 index 0000000..6089268 --- /dev/null +++ b/src/core/utils/singleton.py @@ -0,0 
+1,25 @@ +from collections.abc import Callable +from typing import ParamSpec, TypeVar + +T = TypeVar("T") +P = ParamSpec("P") + +def singleton(cls: type[T]) -> Callable[..., T]: + """ + Singleton decorator for any class implements. + + Args: + cls (Type[T]): Class type. + + Returns: + Callable[[P.args, P.kwargs], T]: Instance of cls with the given arguments. + """ + _instance: dict[type[T], T] = {} + + def _singleton(*args: P.args, **kwargs: P.kwargs) -> T: + if cls not in _instance: + _instance[cls] = cls(*args, **kwargs) + return _instance[cls] + + return _singleton + diff --git a/src/validators.py b/src/core/utils/validators.py similarity index 100% rename from src/validators.py rename to src/core/utils/validators.py diff --git a/src/db/__init__.py b/src/db/__init__.py index 680ad64..020e639 100644 --- a/src/db/__init__.py +++ b/src/db/__init__.py @@ -1,5 +1,5 @@ from src.auth.models import * -from src.db.base import Base +from src.core.models.base import Base def orm_by_table_name(table_name: str) -> type[Base] | None: diff --git a/src/db/_types.py b/src/db/_types.py deleted file mode 100644 index a3aff7e..0000000 --- a/src/db/_types.py +++ /dev/null @@ -1,68 +0,0 @@ -import uuid -from datetime import date, datetime -from enum import IntEnum -from typing import Annotated, TypeVar, no_type_check - -from sqlalchemy import Boolean, Date, DateTime, Integer, String, func, type_coerce -from sqlalchemy.dialects.postgresql import BYTEA, UUID -from sqlalchemy.engine import Dialect -from sqlalchemy.orm import mapped_column -from sqlalchemy.sql import expression -from sqlalchemy.sql.elements import BindParameter, ColumnElement -from sqlalchemy.types import TypeDecorator - -from src.config import settings - -T = TypeVar("T", bound=IntEnum) - - -class EncryptedString(TypeDecorator): - impl = BYTEA - cache_ok = True - - def __init__(self, secret_key: str | None = settings.SECRET_KEY) -> None: - super().__init__() - self.secret = secret_key - - @no_type_check - def 
bind_expression(self, bind_value: BindParameter) -> ColumnElement | None: - bind_value = type_coerce(bind_value, String) # type: ignore # noqa: PGH003 - return func.pgp_sym_encrypt(bind_value, self.secret) - - @no_type_check - def column_expression(self, column: ColumnElement) -> ColumnElement | None: - return func.pgp_sym_decrypt(column, self.secret) - - -class IntegerEnum(TypeDecorator): - impl = Integer - cache_ok = True - - def __init__(self, enum_type: type[T]) -> None: - super().__init__() - self.enum_type = enum_type - - @no_type_check - def process_bind_param(self, value: int, dialect: Dialect) -> int: # noqa: ARG002 - if isinstance(value, self.enum_type): - return value.value - msg = f"expected {self.enum_type.__name__} value, got {value.__class__.__name__}" - raise ValueError(msg) - - @no_type_check - def process_result_value(self, value: int, dialect: Dialect): # noqa: ANN202, ARG002 - return self.enum_type(value) - - @no_type_check - def copy(self, **kwargs): # noqa: ANN202, ARG002, ANN003 - return IntegerEnum(self.enum_type) - - -uuid_pk = Annotated[uuid.UUID, mapped_column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)] -int_pk = Annotated[int, mapped_column(Integer, primary_key=True)] -bool_true = Annotated[bool, mapped_column(Boolean, server_default=expression.true())] -bool_false = Annotated[bool, mapped_column(Boolean, server_default=expression.false())] -datetime_required = Annotated[datetime, mapped_column(DateTime(timezone=True))] -datetime_required = Annotated[datetime, mapped_column(DateTime(timezone=True))] -date_required = Annotated[date, mapped_column(Date)] -date_optional = Annotated[date | None, mapped_column(Date)] diff --git a/src/deps.py b/src/deps.py index 5c64ee2..13834b0 100644 --- a/src/deps.py +++ b/src/deps.py @@ -9,14 +9,14 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from src import exceptions -from src.auth.models import RolePermission, User -from src.config import 
settings -from src.context import locale_ctx -from src.db.session import async_session -from src.enums import ReservedRoleSlug -from src.security import API_WHITE_LISTS, JWT_ALGORITHM, JwtTokenPayload -from src.utils.cache import CacheNamespace, redis_client +from src.core.config import settings +from src.core.database.session import async_session +from src.core.errors import auth_exceptions +from src.core.utils.context import locale_ctx +from src.features.auth.consts import ReservedRoleSlug +from src.features.auth.models import RolePermission, User +from src.features.auth.security import API_WHITE_LISTS, JWT_ALGORITHM, JwtTokenPayload +from src.libs.redis.cache import CacheNamespace, redis_client token = HTTPBearer() @@ -30,16 +30,16 @@ async def auth(request: Request, session: AsyncSession = Depends(get_session), t try: payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[JWT_ALGORITHM]) except jwt.DecodeError as e: - raise exceptions.TokenInvalidError from e + raise auth_exceptions.TokenInvalidError from e token_data = JwtTokenPayload(**payload) if token_data.refresh: - raise exceptions.TokenInvalidError + raise auth_exceptions.TokenInvalidError now = datetime.now(tz=UTC) if now < token_data.issued_at or now > token_data.expires_at: - raise exceptions.TokenExpireError + raise auth_exceptions.TokenExpireError user = await session.get(User, token_data.sub, options=[selectinload(User.role)]) if not user: - raise exceptions.NotFoundError(User.__visible_name__[locale_ctx.get()], "id", id) + raise auth_exceptions.NotFoundError(User.__visible_name__[locale_ctx.get()], "id", id) check_user_active(user.is_active) operation_id = request.scope.get("operation_id") if not operation_id: @@ -53,7 +53,7 @@ async def auth(request: Request, session: AsyncSession = Depends(get_session), t def check_user_active(is_active: bool) -> None: if not is_active: - raise exceptions.PermissionDenyError + raise auth_exceptions.PermissionDenyError def check_privileged_role(slug: str, 
operation_id: str) -> bool: @@ -71,11 +71,11 @@ async def check_role_permissions(role_id: int, session: AsyncSession, operation_ await session.scalars(select(RolePermission.permission_id).where(RolePermission.role_id == role_id)) ).all() if not _permissions: - raise exceptions.PermissionDenyError + raise auth_exceptions.PermissionDenyError permissions = [str(p) for p in _permissions] await redis_client.set_nx(name=str(role_id), value=permissions, namespace=CacheNamespace.ROLE_CACHE) if operation_id not in permissions: - raise exceptions.PermissionDenyError + raise auth_exceptions.PermissionDenyError SqlaSession = Annotated[AsyncSession, Depends(get_session)] diff --git a/src/enums.py b/src/enums.py deleted file mode 100644 index 895069c..0000000 --- a/src/enums.py +++ /dev/null @@ -1,12 +0,0 @@ -from enum import IntEnum - -from src._types import AppStrEnum - - -class Env(IntEnum): - PRD = 0 - DEV = 1 - - -class ReservedRoleSlug(AppStrEnum): - ADMIN = "admin" diff --git a/src/utils/rate_limiter.py b/src/features/auth/__init__.py similarity index 100% rename from src/utils/rate_limiter.py rename to src/features/auth/__init__.py diff --git a/src/features/auth/consts.py b/src/features/auth/consts.py new file mode 100644 index 0000000..0428db1 --- /dev/null +++ b/src/features/auth/consts.py @@ -0,0 +1,5 @@ +from enum import StrEnum + + +class ReservedRoleSlug(StrEnum): + ADMIN = "admin" diff --git a/src/features/auth/graphql.py b/src/features/auth/graphql.py new file mode 100644 index 0000000..e69de29 diff --git a/src/auth/models.py b/src/features/auth/models.py similarity index 86% rename from src/auth/models.py rename to src/features/auth/models.py index 4501849..8934700 100644 --- a/src/auth/models.py +++ b/src/features/auth/models.py @@ -8,9 +8,9 @@ from sqlalchemy.orm import Mapped, backref, column_property, mapped_column, relationship from sqlalchemy.orm.collections import attribute_mapped_collection -from src.db import _types -from src.db.base import Base -from 
src.db.mixins import AuditTimeMixin +from src.core.database import types +from src.core.models.base import Base +from src.core.models.mixins import AuditTimeMixin class RolePermission(Base): @@ -29,7 +29,7 @@ class Role(Base, AuditTimeMixin): __tablename__ = "role" __search_fields__: ClassVar = {"name"} __visible_name__ = {"en_US": "Role", "zh_CN": "用户角色"} - id: Mapped[_types.int_pk] + id: Mapped[types.int_pk] name: Mapped[str] slug: Mapped[str] description: Mapped[str | None] @@ -42,7 +42,7 @@ class Role(Base, AuditTimeMixin): class Permission(Base): __tablename__ = "permission" __visible_name__ = {"en_US": "Permission", "zh_CN": "权限"} - id: Mapped[_types.uuid_pk] + id: Mapped[types.uuid_pk] name: Mapped[str] url: Mapped[str] method: Mapped[str] @@ -54,7 +54,7 @@ class Group(Base, AuditTimeMixin): __tablename__ = "group" __search_fields__: ClassVar = {"name"} __visible_name__ = {"en_US": "Group", "zh_CN": "用户组"} - id: Mapped[_types.int_pk] + id: Mapped[types.int_pk] name: Mapped[str] description: Mapped[str | None] role_id: Mapped[int] = mapped_column(ForeignKey(Role.id, ondelete="RESTRICT")) @@ -66,14 +66,14 @@ class User(Base, AuditTimeMixin): __tablename__ = "user" __search_fields__: ClassVar = {"email", "name", "phone"} __visible_name__ = {"en_US": "User", "zh_CN": "用户"} - id: Mapped[_types.int_pk] + id: Mapped[types.int_pk] name: Mapped[str] email: Mapped[str | None] = mapped_column(unique=True) phone: Mapped[str | None] = mapped_column(unique=True) password: Mapped[str] avatar: Mapped[str | None] last_login: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) - is_active: Mapped[_types.bool_true] + is_active: Mapped[types.bool_true] group_id: Mapped[int] = mapped_column(ForeignKey(Group.id, ondelete="RESTRICT")) group: Mapped["Group"] = relationship(back_populates="user", passive_deletes=True) role_id: Mapped[int] = mapped_column(ForeignKey(Role.id, ondelete="RESTRICT")) @@ -86,14 +86,14 @@ class Menu(Base): __visible_name__ = {"en_US": "Menu", 
"zh_CN": "菜单"} id: Mapped[int] = mapped_column(Integer, primary_key=True) name: Mapped[str] = mapped_column(unique=True, comment="the unique name of route") - hidden: Mapped[_types.bool_false] + hidden: Mapped[types.bool_false] redirect: Mapped[str] = mapped_column(comment="redirect url for the route") - hideChildrenInMenu: Mapped[_types.bool_false] = mapped_column(comment="hide children in menu force or not") # noqa: N815 + hideChildrenInMenu: Mapped[types.bool_false] = mapped_column(comment="hide children in menu force or not") # noqa: N815 order: Mapped[int] title: Mapped[str] = mapped_column(comment="the title of the route, 面包屑") icon: Mapped[str | None] - keepAlive: Mapped[_types.bool_false] = mapped_column(comment="cache route, 开启multi-tab时为true") # noqa: N815 - hiddenHeaderContent: Mapped[_types.bool_false] = mapped_column(comment="隐藏pageheader页面带的面包屑和标题栏") # noqa: N815 + keepAlive: Mapped[types.bool_false] = mapped_column(comment="cache route, 开启multi-tab时为true") # noqa: N815 + hiddenHeaderContent: Mapped[types.bool_false] = mapped_column(comment="隐藏pageheader页面带的面包屑和标题栏") # noqa: N815 permission: Mapped[list[int] | None] = mapped_column(ARRAY(Integer, dimensions=1)) parent_id: Mapped[int | None] = mapped_column(ForeignKey(id, ondelete="CASCADE")) children: Mapped[list["Menu"]] = relationship( diff --git a/src/auth/services.py b/src/features/auth/repositories.py similarity index 61% rename from src/auth/services.py rename to src/features/auth/repositories.py index 52ef351..8c15305 100644 --- a/src/auth/services.py +++ b/src/features/auth/repositories.py @@ -4,16 +4,15 @@ from sqlalchemy import or_, select from sqlalchemy.ext.asyncio import AsyncSession -from src.auth import schemas -from src.auth.models import Menu, Permission, User -from src.auth.schemas import PermissionCreate, PermissionUpdate -from src.context import locale_ctx -from src.db.dtobase import DtoBase -from src.exceptions import NotFoundError, PermissionDenyError -from src.security import 
verify_password +from src.core.errors.auth_exceptions import NotFoundError, PermissionDenyError +from src.core.repositories import BaseRepository +from src.core.utils.context import locale_ctx +from src.features.auth import schemas +from src.features.auth.models import Group, Menu, Permission, Role, User +from src.features.auth.security import verify_password -class UserDto(DtoBase[User, schemas.UserCreate, schemas.UserUpdate, schemas.UserQuery]): +class UserRepo(BaseRepository[User, schemas.UserCreate, schemas.UserUpdate, schemas.UserQuery]): async def verify_user(self, session: AsyncSession, user: OAuth2PasswordRequestForm) -> User: stmt = self._get_base_stmt().where(or_(self.model.email == user.username, self.model.phone == user.username)) db_user = await session.scalar(stmt) @@ -24,11 +23,12 @@ async def verify_user(self, session: AsyncSession, user: OAuth2PasswordRequestFo return db_user -class PermissionDto(DtoBase[Permission, schemas.PermissionCreate, schemas.PermissionUpdate, schemas.PermissionQuery]): + +class PermissionRepo(BaseRepository[Permission, schemas.PermissionCreate, schemas.PermissionUpdate, schemas.PermissionQuery]): async def create( self, session: AsyncSession, - obj_in: PermissionCreate, + obj_in: schemas.PermissionCreate, excludes: set[str] | None = None, exclude_unset: bool = False, exclude_none: bool = False, @@ -40,7 +40,7 @@ async def update( self, session: AsyncSession, db_obj: Permission, - obj_in: PermissionUpdate, + obj_in: schemas.PermissionUpdate, excludes: set[str] | None = None, commit: bool | None = True, ) -> Permission: @@ -50,10 +50,17 @@ async def delete(self, session: AsyncSession, db_obj: Permission) -> None: raise NotImplementedError -class MenuDto(DtoBase[Menu, schemas.MenuCreate, schemas.MenuUpdate, schemas.MenuQuery]): +class MenuRepo(BaseRepository[Menu, schemas.MenuCreate, schemas.MenuUpdate, schemas.MenuQuery]): async def get_all(self, session: AsyncSession) -> Sequence[Menu]: return (await 
session.scalars(select(self.model))).all() @staticmethod def menu_tree_transform(menus: Sequence[Menu]) -> list[dict]: ... + + +class GroupRepo(BaseRepository[Group, schemas.GroupCreate, schemas.GroupUpdate, schemas.GroupQuery]): + ... + +class RoleRepo(BaseRepository[Role, schemas.RoleCreate, schemas.RoleUpdate, schemas.RoleQuery]): + ... diff --git a/src/auth/api.py b/src/features/auth/router.py similarity index 91% rename from src/auth/api.py rename to src/features/auth/router.py index cac66a6..f03def0 100644 --- a/src/auth/api.py +++ b/src/features/auth/router.py @@ -3,17 +3,16 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from src import errors -from src._types import IdResponse, ListT -from src.auth import schemas -from src.auth.models import Group, Menu, Role, User -from src.auth.services import MenuDto, UserDto -from src.cbv import cbv -from src.db.dtobase import DtoBase +from src.core._types import IdResponse, ListT +from src.core.errors import base_exceptions +from src.core.errors.auth_exceptions import GenerError +from src.core.utils.cbv import cbv +from src.core.utils.validators import list_to_tree from src.deps import auth, get_session -from src.exceptions import GenerError -from src.security import generate_access_token_response -from src.validators import list_to_tree +from src.features.auth import schemas +from src.features.auth.models import Group, Menu, Role, User +from src.features.auth.repositories import GroupRepo, MenuRepo, RoleRepo, UserRepo +from src.features.auth.security import generate_access_token_response router = APIRouter() @@ -23,7 +22,7 @@ async def login_pwd( user: OAuth2PasswordRequestForm = Depends(), session: AsyncSession = Depends(get_session), ) -> schemas.AccessToken: - dto = UserDto(User) + dto = UserRepo(User) result = await dto.verify_user(session, user) return generate_access_token_response(result.id) @@ -32,7 +31,7 @@ async def login_pwd( class UserAPI: user: User = 
Depends(auth) session: AsyncSession = Depends(get_session) - dto = UserDto(User) + dto = UserRepo(User) @router.post("/users", operation_id="5091fff6-1adc-4a22-8a8c-ef0107122df7", summary="创建新用户/Create new user") async def create_user(self, user: schemas.UserCreate) -> IdResponse: @@ -68,7 +67,7 @@ async def get_users(self, query: schemas.UserQuery = Depends()) -> ListT[schemas async def update_user(self, id: int, user: schemas.UserUpdate) -> IdResponse: update_user = user.model_dump(exclude_unset=True) if "password" in update_user and update_user["password"] is None: - raise GenerError(errors.ERR_10005, status_code=status.HTTP_406_NOT_ACCEPTABLE) + raise GenerError(base_exceptions.ERR_10005, status_code=status.HTTP_406_NOT_ACCEPTABLE) db_user = await self.dto.get_one_or_404(self.session, id) await self.dto.update(self.session, db_user, user) return IdResponse(id=id) @@ -84,7 +83,7 @@ async def delete_user(self, id: int) -> IdResponse: class GroupAPI: user: User = Depends(auth) session: AsyncSession = Depends(get_session) - dto = DtoBase(Group) + dto = GroupRepo(Group) @router.post("/groups", operation_id="9e3e639d-c694-467d-9209-717b038cf267") async def create_group(self, group: schemas.GroupCreate) -> IdResponse: @@ -118,7 +117,7 @@ async def delete_group(self, id: int) -> IdResponse: class RoleAPI: user: User = Depends(auth) session: AsyncSession = Depends(get_session) - dto = DtoBase(Role) + dto = RoleRepo(Role) @router.post("/roles", operation_id="a18a152b-e9e9-4128-b8be-8a8e9c842abb") async def create_role(self, role: schemas.RoleCreate) -> IdResponse: @@ -152,7 +151,7 @@ async def delete_role(self, id: int) -> IdResponse: class MenuAPI: user: User = Depends(auth) session: AsyncSession = Depends(get_session) - dto = MenuDto(Menu) + dto = MenuRepo(Menu) @router.post("/menus", operation_id="008bf4d4-cc01-48b0-82b8-1a67c0348b31") async def create_menu(self, meun: schemas.MenuCreate) -> IdResponse: diff --git a/src/auth/schemas.py b/src/features/auth/schemas.py 
similarity index 97% rename from src/auth/schemas.py rename to src/features/auth/schemas.py index 09c7ceb..ce68145 100644 --- a/src/auth/schemas.py +++ b/src/features/auth/schemas.py @@ -2,7 +2,7 @@ from pydantic_extra_types.phone_numbers import PhoneNumber -from src._types import AuditTime, BaseModel, IdCreate, QueryParams +from src.core._types import AuditTime, BaseModel, IdCreate, QueryParams class AccessToken(BaseModel): diff --git a/src/security.py b/src/features/auth/security.py similarity index 98% rename from src/security.py rename to src/features/auth/security.py index 80a275f..68d1081 100644 --- a/src/security.py +++ b/src/features/auth/security.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from src.auth.schemas import AccessToken -from src.config import settings +from src.core.config import settings JWT_ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_SECS = settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 diff --git a/src/features/auth/utils.py b/src/features/auth/utils.py new file mode 100644 index 0000000..e69de29 diff --git a/src/internal/__init__.py b/src/internal/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/internal/api.py b/src/internal/api.py index 4f67e24..e3240ba 100644 --- a/src/internal/api.py +++ b/src/internal/api.py @@ -1,9 +1,9 @@ from fastapi import APIRouter, Request from sqlalchemy import delete, select -from src._types import ResultT from src.auth import schemas from src.auth.models import Permission +from src.core._types import ResultT from src.deps import AuthUser, SqlaSession router = APIRouter() diff --git a/src/libs/__init__.py b/src/libs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/libs/redis/__init__.py b/src/libs/redis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/cache.py b/src/libs/redis/cache.py similarity index 99% rename from src/utils/cache.py rename to src/libs/redis/cache.py index 6c06460..8fef4e8 100644 --- a/src/utils/cache.py +++ 
b/src/libs/redis/cache.py @@ -20,8 +20,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from src._types import AppStrEnum, P, R from src.auth.models import User +from src.core._types import AppStrEnum, P, R DEFAULT_CACHE_HEADER = "X-Cache" logger = logging.getLogger(__name__) diff --git a/src/libs/redis/rate_limiter.py b/src/libs/redis/rate_limiter.py new file mode 100644 index 0000000..e69de29 diff --git a/src/libs/redis/redis.py b/src/libs/redis/redis.py new file mode 100644 index 0000000..e69de29 diff --git a/src/loggers.py b/src/loggers.py index 8e473ef..cf5155e 100644 --- a/src/loggers.py +++ b/src/loggers.py @@ -1,13 +1,13 @@ import logging from logging import LogRecord, setLogRecordFactory from logging.config import dictConfig +from typing import Any -from src._types import P -from src.context import request_id_ctx +from src.core.utils.context import request_id_ctx class LogExtraFactory(LogRecord): - def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) request_id = request_id_ctx.get() or "N/A" self.__dict__["request_id"] = request_id diff --git a/src/register/__init__.py b/src/register/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/middlewares.py b/src/register/middlewares.py similarity index 95% rename from src/middlewares.py rename to src/register/middlewares.py index dd19410..f63fe70 100644 --- a/src/middlewares.py +++ b/src/register/middlewares.py @@ -12,8 +12,8 @@ from starlette.responses import Response, StreamingResponse from starlette.types import ASGIApp -from src.context import locale_ctx, request_id_ctx -from src.i18n import ACCEPTED_LANGUAGES +from src.core.utils.context import locale_ctx, request_id_ctx +from src.core.utils.i18n import ACCEPTED_LANGUAGES @dataclass diff --git a/src/routers.py b/src/register/routers.py similarity index 100% rename from src/routers.py rename to 
src/register/routers.py diff --git a/src/utils/singleton.py b/src/utils/singleton.py deleted file mode 100644 index 2b042e4..0000000 --- a/src/utils/singleton.py +++ /dev/null @@ -1,13 +0,0 @@ -from src._types import P, T - - -def singleton(cls: T) -> T: - """Singleton decorator for any class implements.""" - _instance = {} - - def _singleton(*args: P.args, **kwargs: P.kwargs) -> T: - if cls not in _instance: - _instance[cls] = cls(*args, **kwargs) - return _instance[cls] - - return _singleton diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py new file mode 100644 index 0000000..28274b8 --- /dev/null +++ b/tests/test_pydantic.py @@ -0,0 +1,34 @@ +from src.core.models.base import Base + +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy import create_engine, ForeignKey +from sqlalchemy.orm import Session +from sqlalchemy.exc import IntegrityError + +from sqlalchemy import event + +class TestModel(Base): + __tablename__ = "test" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(unique=True) + email: Mapped[str | None] = mapped_column(unique=True) + t_id: Mapped[int] = mapped_column(ForeignKey("test1.id", ondelete="CASCADE")) + test1: Mapped["Test1Model"] = relationship(back_populates="test") + +class Test1Model(Base): + __tablename__ = "test1" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(unique=True) + test: Mapped[list[TestModel]] = relationship(back_populates="test1") + +engine = create_engine("sqlite://", echo=True) + +Base.metadata.create_all(engine) + +with Session(engine) as session: + t = TestModel(name="test", email="test", t_id=1) + session.add(TestModel(name="test", email="test", t_id=1)) + session.commit() + print(t.__dict__) + print(t.test1) +