Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CVAT] Add threading for exchange oracle blocking requests #2044

Draft
wants to merge 19 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions packages/examples/cvat/exchange-oracle/src/.env.template
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ ENVIRONMENT=
WORKERS_AMOUNT=
WEBHOOK_MAX_RETRIES=
WEBHOOK_DELAY_IF_FAILED=
MAX_WORKER_THREADS=

# DB

MAX_DB_CONNECTIONS=
DB_CONNECTION_RECYCLE_TIMEOUT=

# Postgres_config

Expand Down Expand Up @@ -35,14 +41,17 @@ PROCESS_RECORDING_ORACLE_WEBHOOKS_CHUNK_SIZE=
TRACK_COMPLETED_PROJECTS_INT=
TRACK_COMPLETED_PROJECTS_CHUNK_SIZE=
TRACK_COMPLETED_TASKS_INT=
TRACK_COMPLETED_TASKS_CHUNK_SIZE=
TRACK_COMPLETED_ESCROWS_INT=
TRACK_COMPLETED_ESCROWS_CHUNK_SIZE=
PROCESS_JOB_LAUNCHER_WEBHOOKS_INT=
TRACK_CREATING_TASKS_INT=
TRACK_CREATING_TASKS_CHUNK_SIZE=
TRACK_ASSIGNMENTS_INT=
TRACK_ASSIGNMENTS_CHUNK_SIZE=
REJECTED_PROJECTS_CHUNK_SIZE=
ACCEPTED_PROJECTS_CHUNK_SIZE=
TRACK_ESCROW_CREATION_CHUNK_SIZE=
TRACK_ESCROW_CREATION_INT=
TRACK_ESCROW_CREATION_CHUNK_SIZE=
TRACK_COMPLETED_ESCROWS_MAX_DOWNLOADING_RETRIES=
TRACK_COMPLETED_ESCROWS_JOBS_DOWNLOADING_BATCH_SIZE=

Expand Down
15 changes: 15 additions & 0 deletions packages/examples/cvat/exchange-oracle/src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from src.endpoints import init_api
from src.handlers.error_handlers import setup_error_handlers
from src.log import setup_logging
from src.utils.concurrency import fastapi_set_max_threads

setup_logging()

Expand All @@ -31,6 +32,20 @@ async def startup_event():
logger = logging.getLogger("app")
logger.info("Exchange Oracle is up and running!")

# Warn at startup when the DB connection limit is lower than the worker thread
# limit: per the message below, such a setup can cause runtime errors on long
# blocking DB calls.
if Config.features.db_connection_limit < Config.features.thread_limit:
    # Logger.warn() is a deprecated alias of Logger.warning() and is no longer
    # available on Python 3.13+, so call warning() directly.
    logger.warning(
        "The DB connection limit {} is less than maximum number of working threads {}. "
        "This configuration can cause runtime errors on long blocking DB calls. "
        "Consider changing values of the {} and {} environment variables.".format(
            Config.features.db_connection_limit,
            Config.features.thread_limit,
            Config.features.DB_CONNECTION_LIMIT_ENV_VAR,
            Config.features.THREAD_LIMIT_ENV_VAR,
        )
    )

await fastapi_set_max_threads(Config.features.thread_limit)


is_test = Config.environment == "test"
if not is_test:
Expand Down
18 changes: 17 additions & 1 deletion packages/examples/cvat/exchange-oracle/src/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ class CronConfig:
track_completed_projects_int = int(os.environ.get("TRACK_COMPLETED_PROJECTS_INT", 30))
track_completed_projects_chunk_size = os.environ.get("TRACK_COMPLETED_PROJECTS_CHUNK_SIZE", 5)
track_completed_tasks_int = int(os.environ.get("TRACK_COMPLETED_TASKS_INT", 30))
track_creating_tasks_chunk_size = os.environ.get("TRACK_CREATING_TASKS_CHUNK_SIZE", 5)
track_completed_tasks_chunk_size = os.environ.get("TRACK_COMPLETED_TASKS_CHUNK_SIZE", 20)
track_creating_tasks_int = int(os.environ.get("TRACK_CREATING_TASKS_INT", 300))
track_creating_tasks_chunk_size = os.environ.get("TRACK_CREATING_TASKS_CHUNK_SIZE", 5)
track_assignments_int = int(os.environ.get("TRACK_ASSIGNMENTS_INT", 5))
track_assignments_chunk_size = os.environ.get("TRACK_ASSIGNMENTS_CHUNK_SIZE", 10)

Expand Down Expand Up @@ -152,6 +153,9 @@ def bucket_url(cls):


class FeaturesConfig:
THREAD_LIMIT_ENV_VAR = "MAX_WORKER_THREADS"
DB_CONNECTION_LIMIT_ENV_VAR = "MAX_DB_CONNECTIONS"

enable_custom_cloud_host = to_bool(os.environ.get("ENABLE_CUSTOM_CLOUD_HOST", "no"))
"Allows using a custom host in manifest bucket urls"

Expand All @@ -164,6 +168,18 @@ class FeaturesConfig:
profiling_enabled = to_bool(os.getenv("PROFILING_ENABLED", False))
"Allow to profile specific requests"

thread_limit = int(os.getenv(THREAD_LIMIT_ENV_VAR, 5))
"Maximum number of threads for blocking requests"

db_connection_limit = int(os.getenv(DB_CONNECTION_LIMIT_ENV_VAR, 15))
"""
Maximum number of active parallel DB connections.
The recommended value is >= thread_limit + cron jobs count
"""

db_connection_recycle_timeout = int(os.getenv("DB_CONNECTION_RECYCLE_TIMEOUT", 600))
"DB connection lifetime after the last action on the connection, in seconds"


class CoreConfig:
default_assignment_time = int(os.environ.get("DEFAULT_ASSIGNMENT_TIME", 1800))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def track_completed_projects() -> None:
projects = cvat_service.get_projects_by_status(
session,
ProjectStatuses.annotation,
task_status=TaskStatuses.completed,
limit=CronConfig.track_completed_projects_chunk_size,
for_update=ForUpdateParams(skip_locked=True),
)
Expand Down Expand Up @@ -74,7 +75,12 @@ def track_completed_tasks() -> None:
logger.debug("Starting cron job")
with SessionLocal.begin() as session:
tasks = cvat_service.get_tasks_by_status(
session, TaskStatuses.annotation, for_update=ForUpdateParams(skip_locked=True)
session,
TaskStatuses.annotation,
job_status=JobStatuses.completed,
project_status=ProjectStatuses.annotation,
limit=CronConfig.track_completed_tasks_chunk_size,
for_update=ForUpdateParams(skip_locked=True),
)

completed_task_ids = []
Expand Down
24 changes: 21 additions & 3 deletions packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import json
import logging
import zipfile
from contextlib import contextmanager
from contextvars import ContextVar
from datetime import timedelta
from enum import Enum
from http import HTTPStatus
from io import BytesIO
from time import sleep
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Generator, List, Optional, Tuple

from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models
from cvat_sdk.api_client.api_client import Endpoint
Expand Down Expand Up @@ -90,7 +92,23 @@ def _get_annotations(
return file_buffer


# Context-local override for the CVAT API client. The default is None, in which
# case get_api_client() falls through to building a client from the configuration.
_api_client_context: ContextVar[Optional[ApiClient]] = ContextVar("api_client", default=None)


@contextmanager
def api_client_context(api_client: ApiClient) -> Generator[ApiClient, None, None]:
    """
    Install ``api_client`` as the current context's API client for a ``with`` block.

    Yields the same ``api_client`` that was passed in. The value that was active
    before entering the block is restored on exit, including when the body raises.
    """
    reset_token = _api_client_context.set(api_client)
    try:
        yield api_client
    finally:
        # Put back whatever was set before this block was entered
        _api_client_context.reset(reset_token)


def get_api_client() -> ApiClient:
current_api_client = _api_client_context.get()
if current_api_client:
return current_api_client

configuration = Configuration(
host=Config.cvat_config.cvat_url,
username=Config.cvat_config.cvat_admin,
Expand Down Expand Up @@ -559,15 +577,15 @@ def update_job_assignee(id: str, assignee_id: Optional[int]):
raise


def restart_job(id: str):
def restart_job(id: str, *, assignee_id: Optional[int] = None):
logger = logging.getLogger("app")

with get_api_client() as api_client:
try:
api_client.jobs_api.partial_update(
id=id,
patched_job_write_request=models.PatchedJobWriteRequest(
stage="annotation", state="new"
stage="annotation", state="new", assignee=assignee_id
),
)
except exceptions.ApiException as e:
Expand Down
2 changes: 2 additions & 0 deletions packages/examples/cvat/exchange-oracle/src/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
DATABASE_URL,
echo="debug" if Config.loglevel <= src.utils.logging.TRACE else False,
connect_args={"options": "-c lock_timeout={:d}".format(Config.postgres_config.lock_timeout)},
pool_size=Config.features.db_connection_limit,
pool_recycle=Config.features.db_connection_recycle_timeout,
)
SessionLocal = sessionmaker(autocommit=False, bind=engine)

Expand Down
42 changes: 31 additions & 11 deletions packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@
import src.services.cvat as cvat_service
import src.services.exchange as oracle_service
from src.db import SessionLocal
from src.db import errors as db_errors
from src.schemas.exchange import AssignmentRequest, TaskResponse, UserRequest, UserResponse
from src.utils.concurrency import run_as_sync
from src.validators.signature import validate_human_app_signature

router = APIRouter()


@router.get("/tasks", description="Lists available tasks")
async def list_tasks(
def list_tasks(
wallet_address: Optional[str] = Query(default=None),
signature: str = Header(description="Calling service signature"),
) -> list[TaskResponse]:
await validate_human_app_signature(signature)
# Declare this endpoint as sync as it uses a lot of blocking IO (db, manifest, escrow)

run_as_sync(validate_human_app_signature, signature)

if not wallet_address:
return oracle_service.get_available_tasks()
Expand All @@ -28,11 +32,13 @@ async def list_tasks(


@router.put("/register", description="Binds a CVAT user a to HUMAN App user")
async def register(
def register(
user: UserRequest,
signature: str = Header(description="Calling service signature"),
) -> UserResponse:
await validate_human_app_signature(signature)
# Declare this endpoint as sync as it uses a lot of blocking IO (db, manifest, escrow, CVAT)

run_as_sync(validate_human_app_signature, signature)

with SessionLocal.begin() as session:
email_db_user = cvat_service.get_user_by_email(session, user.cvat_email, for_update=True)
Expand Down Expand Up @@ -97,19 +103,33 @@ async def register(
"/tasks/{id}/assignment",
description="Start an assignment within the task for the annotator",
)
async def create_assignment(
def create_assignment(
data: AssignmentRequest,
project_id: str = Path(alias="id"),
signature: str = Header(description="Calling service signature"),
) -> TaskResponse:
await validate_human_app_signature(signature)
# Declare this endpoint as sync as it uses a lot of blocking IO (db, manifest, escrow, CVAT)

run_as_sync(validate_human_app_signature, signature)

try:
assignment_id = oracle_service.create_assignment(
project_id=project_id, wallet_address=data.wallet_address
attempt = 0
max_attempts = 10
while attempt < max_attempts:
try:
assignment_id = oracle_service.create_assignment(
project_id=project_id, wallet_address=data.wallet_address
)
break
except oracle_service.UserHasUnfinishedAssignmentError as e:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) from e
except db_errors.LockNotAvailable:
attempt += 1

if attempt >= max_attempts:
raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
detail="Too many requests at the moment, please try again later",
)
except oracle_service.UserHasUnfinishedAssignmentError as e:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) from e

if not assignment_id:
raise HTTPException(
Expand Down
40 changes: 38 additions & 2 deletions packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import time
from typing import Any, Callable

import fastapi
import packaging.version as pv
from fastapi import FastAPI, Request, Response
from fastapi.responses import HTMLResponse, StreamingResponse
from pyinstrument import Profiler
Expand Down Expand Up @@ -58,14 +60,32 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):

"""

@staticmethod
async def _set_body(request: Request, body: bytes):
    """
    Make the already-read ``body`` available to later readers of ``request``.

    Does nothing on FastAPI >= 0.108.0. On older versions it replaces the
    request's ``_receive`` callable with one that returns ``body`` as an
    ``http.request`` message.
    """
    # Before FastAPI 0.108.0 infinite hang is expected,
    # if request body is awaited more than once.
    # It's not needed when using FastAPI >= 0.108.0.
    # https://github.com/tiangolo/fastapi/discussions/8187#discussioncomment-7962889
    workaround_not_needed = pv.parse(fastapi.__version__) >= pv.Version("0.108.0")
    if workaround_not_needed:
        return

    async def replay_body():
        return {"type": "http.request", "body": body}

    request._receive = replay_body

def __init__(self, app: FastAPI) -> None:
    """
    Set up the middleware around ``app``.

    Stores the root logger and the size limit applied to request bodies
    when they are included in log records.
    """
    super().__init__(app)

    self.max_displayed_body_size = 200
    self.logger = get_root_logger()

async def dispatch(self, request: Request, call_next: Callable) -> Response:
logging_dict: dict[str, Any] = {}

await request.body()
body = await request.body()
await self._set_body(request, body)

response, response_dict = await self._log_response(call_next, request)
request_dict = await self._log_request(request)
logging_dict["request"] = request_dict
Expand Down Expand Up @@ -97,10 +117,26 @@ async def _log_request(self, request: Request) -> dict[str, Any]:
}

try:
body = await request.json()
body = await request.body()
await self._set_body(request, body)
except Exception:
body = None
else:
if body is not None:
raw_body = False

if len(body) < self.max_displayed_body_size:
try:
body = json.loads(body)
except (json.JSONDecodeError, TypeError):
raw_body = True
else:
raw_body = True

if raw_body:
body = body.decode(errors="ignore")
body = body[: self.max_displayed_body_size]

request_logging["body"] = body

return request_logging
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _process_skeletons_from_boxes_escrows(self):
except Exception as e:
logger.error(
"Failed to handle completed projects for escrow {}: {}".format(
escrow_address, e
completed_project.escrow_address, e
)
)
continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ async def http_exception_handler(_, exc):
@app.exception_handler(Exception)
async def generic_exception_handler(_, exc: Exception):
message = (
"Something went wrong" if Config.environment != "development" else ".".join(exc.args)
"Something went wrong"
if Config.environment != "development"
else ".".join(map(str, exc.args))
)

return JSONResponse(content={"message": message}, status_code=500)
Loading