Skip to content

Commit

Permalink
fix(model): fix user name error
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxin688 committed Jan 27, 2024
1 parent 05e707a commit 96c29ef
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from contextvars import ContextVar

request_id_ctx: ContextVar[str | None] = ContextVar("x-request-id", default=None)
auth_user_ctx: ContextVar[int | None] = ContextVar("x-auth-user", default=None)
user_ctx: ContextVar[int | None] = ContextVar("x-auth-user", default=None)
locale_ctx: ContextVar[str] = ContextVar("Accept-Language", default="en_US")
orm_diff_ctx: ContextVar[dict | None] = ContextVar("x-orm-diff", default=None)
celery_current_id: ContextVar[str | None] = ContextVar("x-celery-cid", default=None)
Expand Down
32 changes: 31 additions & 1 deletion src/db/_types.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import uuid
from datetime import date, datetime
from typing import Annotated
from enum import IntEnum
from typing import Annotated, TypeVar, no_type_check

from sqlalchemy import Boolean, Date, DateTime, Integer, String, func, type_coerce
from sqlalchemy.dialects.postgresql import BYTEA, UUID
from sqlalchemy.engine import Dialect
from sqlalchemy.orm import mapped_column
from sqlalchemy.sql import expression
from sqlalchemy.sql.elements import BindParameter, ColumnElement
from sqlalchemy.types import TypeDecorator

from src.config import settings

T = TypeVar("T", bound=IntEnum)


class EncryptedString(TypeDecorator):
impl = BYTEA
Expand All @@ -20,14 +24,40 @@ def __init__(self, secret_key: str | None = settings.SECRET_KEY) -> None:
super().__init__()
self.secret = secret_key

@no_type_check
def bind_expression(self, bind_value: BindParameter) -> ColumnElement | None:
bind_value = type_coerce(bind_value, String) # type: ignore # noqa: PGH003
return func.pgp_sym_encrypt(bind_value, self.secret)

@no_type_check
def column_expression(self, column: ColumnElement) -> ColumnElement | None:
return func.pgp_sym_decrypt(column, self.secret)


class IntegerEnum(TypeDecorator):
impl = Integer
cache_ok = True

def __init__(self, enum_type: type[T]) -> None:
super().__init__()
self.enum_type = enum_type

@no_type_check
def process_bind_param(self, value: int, dialect: Dialect) -> int: # noqa: ARG002
if isinstance(value, self.enum_type):
return value.value
msg = f"expected {self.enum_type.__name__} value, got {value.__class__.__name__}"
raise ValueError(msg)

@no_type_check
def process_result_value(self, value: int, dialect: Dialect): # noqa: ANN202, ARG002
return self.enum_type(value)

@no_type_check
def copy(self, **kwargs): # noqa: ANN202, ARG002, ANN003
return IntegerEnum(self.enum_type)


uuid_pk = Annotated[uuid.UUID, mapped_column(UUID(as_uuid=True), default=uuid.uuid4, primary_key=True)]
int_pk = Annotated[int, mapped_column(Integer, primary_key=True)]
bool_true = Annotated[bool, mapped_column(Boolean, server_default=expression.true())]
Expand Down
16 changes: 8 additions & 8 deletions src/db/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sqlalchemy.orm import Mapped, Mapper, class_mapper, mapped_column, relationship
from sqlalchemy.orm.attributes import get_history

from src.context import auth_user_ctx, orm_diff_ctx, request_id_ctx
from src.context import orm_diff_ctx, request_id_ctx, user_ctx
from src.db._types import int_pk
from src.db.base import Base

Expand Down Expand Up @@ -54,8 +54,8 @@ class AuditLog:
def user_id(cls) -> Mapped[int | None]:
return mapped_column(
Integer,
ForeignKey("auth_user.id", ondelete="SET NULL"),
default=auth_user_ctx.get,
ForeignKey("user.id", ondelete="SET NULL"),
default=user_ctx.get,
nullable=True,
)

Expand Down Expand Up @@ -93,7 +93,7 @@ def log_create(cls, mapper: Mapper, connection: Connection, target: Mapper) -> N
"action": "create",
"post_change": target.dict(exclude_relationship=True),
"parent_id": target.id,
"user_id": auth_user_ctx.get(),
"user_id": user_ctx.get(),
},
)

Expand All @@ -109,7 +109,7 @@ def log_update(cls, mapper: Mapper, connection: Connection, target: Mapper) -> N
"action": "update",
"diff": changes["diff"],
"parent_id": target.id,
"user_id": auth_user_ctx.get(),
"user_id": user_ctx.get(),
},
)

Expand All @@ -122,7 +122,7 @@ def log_delete(cls, mapper: Mapper, connection: Connection, target: Mapper) -> N
"action": "delete",
"diff": target.dict(exclude_relationship=True),
"parent_id": target.id,
"user_id": auth_user_ctx.get(),
"user_id": user_ctx.get(),
},
)

Expand All @@ -137,12 +137,12 @@ class AuditUserMixin:
@declared_attr
@classmethod
def created_by_fk(cls) -> Mapped[int | None]:
return mapped_column(Integer, ForeignKey("auth_user.id"), default=auth_user_ctx.get, nullable=True)
return mapped_column(Integer, ForeignKey("user.id"), default=user_ctx.get, nullable=True)

@declared_attr
@classmethod
def updated_by_fk(cls) -> Mapped[int | None]:
return mapped_column(Integer, ForeignKey("auth_user.id"), default=auth_user_ctx.get, nullable=True)
return mapped_column(Integer, ForeignKey("user.id"), default=user_ctx.get, nullable=True)

@declared_attr
@classmethod
Expand Down

0 comments on commit 96c29ef

Please sign in to comment.