diff --git a/backend/pyproject.toml b/backend/pyproject.toml index c76662d..92f3023 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -46,11 +46,12 @@ dev-dependencies = [ "pre-commit>=3.5.0", "ruff>=0.5.1", "pytest-cov>=4.1.0", - "pytest-asyncio>=0.23.7", + "pytest-asyncio==0.20.3", # breaking change in 0.23: https://github.com/pytest-dev/pytest-asyncio/issues/706 "black>=23.12.1", "mypy>=1.8.0", "pytest-sugar>=1.0.0", "pytest-tldr>=0.2.5", + "polyfactory>=2.16.0", ] [tool.hatch.metadata] @@ -89,7 +90,7 @@ fixable = ["ALL"] [tool.ruff.lint.extend-per-file-ignores] "env.py" = ["INP001", "I001", "ERA001"] -"tests/*.py" = ["S101", "ANN201", "PLR2004"] +"tests/*.py" = ["S101", "ANN201", "PLR2004", "ANN001"] "*exceptions.py" = ["ARG001"] "models.py" = ["RUF012"] "schemas.py" = ["ANN201"] diff --git a/backend/requirements-dev.lock b/backend/requirements-dev.lock index 42f392d..a79e63a 100644 --- a/backend/requirements-dev.lock +++ b/backend/requirements-dev.lock @@ -63,6 +63,8 @@ dnspython==2.5.0 email-validator==2.1.0.post1 # via fastapi # via netsight +faker==26.0.0 + # via polyfactory fastapi==0.111.0 # via netsight fastapi-cli==0.0.4 @@ -144,6 +146,7 @@ platformdirs==4.1.0 # via virtualenv pluggy==1.3.0 # via pytest +polyfactory==2.16.0 pre-commit==3.6.0 prompt-toolkit==3.0.47 # via click-repl @@ -172,12 +175,13 @@ pytest==7.4.4 # via pytest-cov # via pytest-sugar # via pytest-tldr -pytest-asyncio==0.23.7 +pytest-asyncio==0.20.3 pytest-cov==4.1.0 pytest-sugar==1.0.0 pytest-tldr==0.2.5 python-dateutil==2.8.2 # via celery + # via faker # via pandas python-dotenv==1.0.0 # via pydantic-settings @@ -231,6 +235,7 @@ typing-extensions==4.9.0 # via alembic # via fastapi # via mypy + # via polyfactory # via pydantic # via pydantic-core # via sqlalchemy diff --git a/backend/src/core/database/session.py b/backend/src/core/database/session.py index a5b563f..c9ac3fb 100644 --- a/backend/src/core/database/session.py +++ b/backend/src/core/database/session.py @@ -1,24 +1,17 @@ import logging from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine -from sqlalchemy.pool import NullPool -from src.core.config import _Env, settings +from src.core.config import settings logger = logging.getLogger(__name__) -if _Env.PROD.name == settings.ENV: - async_engine = create_async_engine( - url=settings.SQLALCHEMY_DATABASE_URI, - pool_pre_ping=True, - future=True, - pool_size=settings.DATABASE_POOL_SIZE, - connect_args={"server_settings": {"jit": "off"}}, - max_overflow=settings.DATABASE_POOL_MAX_OVERFLOW, - ) -else: - # compatibility with pytest - async_engine = create_async_engine( - url=settings.SQLALCHEMY_DATABASE_URI, pool_pre_ping=True, future=True, poolclass=NullPool - ) +async_engine = create_async_engine( + url=settings.SQLALCHEMY_DATABASE_URI, + pool_pre_ping=True, + future=True, + pool_size=settings.DATABASE_POOL_SIZE, + connect_args={"server_settings": {"jit": "off"}}, + max_overflow=settings.DATABASE_POOL_MAX_OVERFLOW, +) async_session = async_sessionmaker(async_engine, autoflush=False, expire_on_commit=False) diff --git a/backend/src/features/deps.py b/backend/src/features/deps.py index 7f8776f..3d4b402 100644 --- a/backend/src/features/deps.py +++ b/backend/src/features/deps.py @@ -25,15 +25,15 @@ token = HTTPBearer() -async def get_session() -> AsyncGenerator["AsyncSession", None]: - async with async_session() as session: - try: - yield session - except SQLAlchemyError as e: - logger.exception("Database error: %s", e) # noqa: TRY401 - await session.rollback() - finally: - await session.aclose() +async def get_session() -> AsyncGenerator[AsyncSession, None]: + session: AsyncSession = async_session() + try: + yield session + except SQLAlchemyError as e: + logger.exception("Database error: %s", e) # noqa: TRY401 + await session.rollback() + finally: + await session.aclose() async def auth( diff --git a/backend/src/features/org/schemas.py b/backend/src/features/org/schemas.py index 2bebeab..aa310a9 100644 --- a/backend/src/features/org/schemas.py +++ b/backend/src/features/org/schemas.py @@ -53,7 +53,7 @@ class SiteBase(BaseModel): address: str latitude: float longitude: float - classfication: str | None = None + classification: str | None = None comments: str | None diff --git a/backend/tests/factoreis.py b/backend/tests/factoreis.py index 9b9682a..49f2917 100644 --- a/backend/tests/factoreis.py +++ b/backend/tests/factoreis.py @@ -1,8 +1,11 @@ +from polyfactory.factories.pydantic_factory import ModelFactory from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory from polyfactory.fields import Ignore from src.features.admin.models import Group, Menu, Permission, Role, User from src.features.admin.security import get_password_hash +from src.features.org import models as org_models +from src.features.org import schemas as org_schemas class UserFactory(SQLAlchemyFactory[User]): @@ -35,3 +38,20 @@ class MenuFactory(SQLAlchemyFactory[Menu]): id = Ignore() permission = [1, 2, 3, 4] # noqa: RUF012 + + +class SiteCreateFactory(ModelFactory[org_schemas.SiteCreate]): + __model__ = org_schemas.SiteCreate + + +class SiteGroupCreateFactory(ModelFactory[org_schemas.SiteGroupCreate]): + __model__ = org_schemas.SiteGroupCreate + + +class SiteFactory(SQLAlchemyFactory[org_models.Site]): + __model__ = org_models.Site + __set_foreign_keys__ = False + id = Ignore() + status = "Active" + country = "China" + time_zone = 8