From 96c29ef62d9d8467c6607d56fc3488ec155c1627 Mon Sep 17 00:00:00 2001 From: wangxin688 <182467653@qq.com> Date: Sat, 27 Jan 2024 21:54:00 +0800 Subject: [PATCH] fix(model): fix user name error --- src/context.py | 2 +- src/db/_types.py | 32 +++++++++++++++++++++++++++++++- src/db/mixins.py | 16 ++++++++-------- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/src/context.py b/src/context.py index d759bb8..07a11e0 100644 --- a/src/context.py +++ b/src/context.py @@ -1,7 +1,7 @@ from contextvars import ContextVar request_id_ctx: ContextVar[str | None] = ContextVar("x-request-id", default=None) -auth_user_ctx: ContextVar[int | None] = ContextVar("x-auth-user", default=None) +user_ctx: ContextVar[int | None] = ContextVar("x-auth-user", default=None) locale_ctx: ContextVar[str] = ContextVar("Accept-Language", default="en_US") orm_diff_ctx: ContextVar[dict | None] = ContextVar("x-orm-diff", default=None) celery_current_id: ContextVar[str | None] = ContextVar("x-celery-cid", default=None) diff --git a/src/db/_types.py b/src/db/_types.py index cd523fb..a3aff7e 100644 --- a/src/db/_types.py +++ b/src/db/_types.py @@ -1,9 +1,11 @@ import uuid from datetime import date, datetime -from typing import Annotated +from enum import IntEnum +from typing import Annotated, TypeVar, no_type_check from sqlalchemy import Boolean, Date, DateTime, Integer, String, func, type_coerce from sqlalchemy.dialects.postgresql import BYTEA, UUID +from sqlalchemy.engine import Dialect from sqlalchemy.orm import mapped_column from sqlalchemy.sql import expression from sqlalchemy.sql.elements import BindParameter, ColumnElement @@ -11,6 +13,8 @@ from src.config import settings +T = TypeVar("T", bound=IntEnum) + class EncryptedString(TypeDecorator): impl = BYTEA @@ -20,14 +24,40 @@ def __init__(self, secret_key: str | None = settings.SECRET_KEY) -> None: super().__init__() self.secret = secret_key + @no_type_check def bind_expression(self, bind_value: BindParameter) -> ColumnElement | None: bind_value = type_coerce(bind_value, String) # type: ignore # noqa: PGH003 return func.pgp_sym_encrypt(bind_value, self.secret) + @no_type_check def column_expression(self, column: ColumnElement) -> ColumnElement | None: return func.pgp_sym_decrypt(column, self.secret) +class IntegerEnum(TypeDecorator): + impl = Integer + cache_ok = True + + def __init__(self, enum_type: type[T]) -> None: + super().__init__() + self.enum_type = enum_type + + @no_type_check + def process_bind_param(self, value: int, dialect: Dialect) -> int: # noqa: ARG002 + if isinstance(value, self.enum_type): + return value.value + msg = f"expected {self.enum_type.__name__} value, got {value.__class__.__name__}" + raise ValueError(msg) + + @no_type_check + def process_result_value(self, value: int, dialect: Dialect): # noqa: ANN202, ARG002 + return self.enum_type(value) + + @no_type_check + def copy(self, **kwargs): # noqa: ANN202, ARG002, ANN003 + return IntegerEnum(self.enum_type) + + uuid_pk = Annotated[uuid.UUID, mapped_column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)] int_pk = Annotated[int, mapped_column(Integer, primary_key=True)] bool_true = Annotated[bool, mapped_column(Boolean, server_default=expression.true())] diff --git a/src/db/mixins.py b/src/db/mixins.py index 6e853b1..bae2879 100644 --- a/src/db/mixins.py +++ b/src/db/mixins.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import Mapped, Mapper, class_mapper, mapped_column, relationship from sqlalchemy.orm.attributes import get_history -from src.context import auth_user_ctx, orm_diff_ctx, request_id_ctx +from src.context import orm_diff_ctx, request_id_ctx, user_ctx from src.db._types import int_pk from src.db.base import Base @@ -54,8 +54,8 @@ class AuditLog: def user_id(cls) -> Mapped[int | None]: return mapped_column( Integer, - ForeignKey("auth_user.id", ondelete="SET NULL"), - default=auth_user_ctx.get, + ForeignKey("user.id", ondelete="SET NULL"), + default=user_ctx.get, nullable=True, ) @@ -93,7 +93,7 @@ def log_create(cls, mapper: Mapper, connection: Connection, target: Mapper) -> N "action": "create", "post_change": target.dict(exclude_relationship=True), "parent_id": target.id, - "user_id": auth_user_ctx.get(), + "user_id": user_ctx.get(), }, ) @@ -109,7 +109,7 @@ def log_update(cls, mapper: Mapper, connection: Connection, target: Mapper) -> N "action": "update", "diff": changes["diff"], "parent_id": target.id, - "user_id": auth_user_ctx.get(), + "user_id": user_ctx.get(), }, ) @@ -122,7 +122,7 @@ def log_delete(cls, mapper: Mapper, connection: Connection, target: Mapper) -> N "action": "delete", "diff": target.dict(exclude_relationship=True), "parent_id": target.id, - "user_id": auth_user_ctx.get(), + "user_id": user_ctx.get(), }, ) @@ -137,12 +137,12 @@ class AuditUserMixin: @declared_attr @classmethod def created_by_fk(cls) -> Mapped[int | None]: - return mapped_column(Integer, ForeignKey("auth_user.id"), default=auth_user_ctx.get, nullable=True) + return mapped_column(Integer, ForeignKey("user.id"), default=user_ctx.get, nullable=True) @declared_attr @classmethod def updated_by_fk(cls) -> Mapped[int | None]: - return mapped_column(Integer, ForeignKey("auth_user.id"), default=auth_user_ctx.get, nullable=True) + return mapped_column(Integer, ForeignKey("user.id"), default=user_ctx.get, nullable=True) @declared_attr @classmethod