Skip to content

Commit

Permalink
feat: supports task collaborators (#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
gary-Shen authored Feb 21, 2025
1 parent 840a313 commit b159c63
Show file tree
Hide file tree
Showing 43 changed files with 906 additions and 235 deletions.
9 changes: 1 addition & 8 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,8 @@ jobs:
# ====================== release ======================

- name: Set test pip url
if: ${{ github.ref_name == 'alpha' }}
env:
NEXT_VERSION: ${{ env.NEXT_VERSION }}
run: |
echo "PYPI_URL=https://test.pypi.org/project/labelu/${{ env.NEXT_VERSION }}" >> $GITHUB_ENV
- name: Set release pip url
if: ${{ github.ref_name == 'main' }}
if: ${{ github.ref_name == 'main' || github.ref_name == 'alpha' }}
env:
NEXT_VERSION: ${{ env.NEXT_VERSION }}
run: |
Expand Down
20 changes: 0 additions & 20 deletions docker/conf.d/labelu.cnf

This file was deleted.

7 changes: 0 additions & 7 deletions labelu/alembic_labelu/env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os
import sys
from logging.config import fileConfig

from sqlalchemy import engine_from_config
Expand All @@ -8,11 +6,6 @@
from alembic import context

from labelu.internal.common.db import Base
from labelu.internal.domain.models.task import Task
from labelu.internal.domain.models.user import User
from labelu.internal.domain.models.sample import TaskSample
from labelu.internal.domain.models.attachment import TaskAttachment
from labelu.internal.common.config import settings


# this is the Alembic Config object, which provides
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""add_collaborators_and_updaters
Revision ID: 2eb983c9a254
Revises: eb9c5b98168b
Create Date: 2025-02-19 16:16:39.259779
"""
from datetime import datetime
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column


# revision identifiers, used by Alembic.
revision = '2eb983c9a254'
down_revision = 'eb9c5b98168b'
branch_labels = None
depends_on = None


def upgrade() -> None:
# Create task_collaborator table
op.create_table(
'task_collaborator',
sa.Column('task_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column(
'created_at',
sa.DateTime(timezone=True),
server_default=sa.text('CURRENT_TIMESTAMP'),
nullable=False
),
sa.ForeignKeyConstraint(
['task_id'],
['task.id'],
ondelete='CASCADE'
),
sa.ForeignKeyConstraint(
['user_id'],
['user.id'],
ondelete='CASCADE'
),
sa.PrimaryKeyConstraint('task_id', 'user_id')
)

# Performances index
op.create_index(
'ix_task_collaborator_task_id',
'task_collaborator',
['task_id']
)
op.create_index(
'ix_task_collaborator_user_id',
'task_collaborator',
['user_id']
)
op.create_index(
'ix_task_created_by_deleted_at',
'task',
['created_by', 'deleted_at']
)

# Task sample: updater -> updaters; create a new table task_sample_updater
op.create_table(
'task_sample_updater',
sa.Column('sample_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column(
'created_at',
sa.DateTime(timezone=True),
server_default=sa.text('CURRENT_TIMESTAMP'),
nullable=False
),
sa.ForeignKeyConstraint(
['sample_id'],
['task_sample.id'],
ondelete='CASCADE'
),
sa.ForeignKeyConstraint(
['user_id'],
['user.id'],
ondelete='CASCADE'
),
sa.PrimaryKeyConstraint('sample_id', 'user_id')
)

# Performances index
op.create_index(
'ix_task_sample_updater_sample_id',
'task_sample_updater',
['sample_id']
)
op.create_index(
'ix_task_sample_updater_user_id',
'task_sample_updater',
['user_id']
)

# Migrate data from task_sample.updated_by to task_sample_updater
task_sample = table(
'task_sample',
column('id', sa.Integer),
column('updated_by', sa.Integer),
column('updated_at', sa.DateTime)
)

task_sample_updater = table(
'task_sample_updater',
column('sample_id', sa.Integer),
column('user_id', sa.Integer),
column('created_at', sa.DateTime)
)

conn = op.get_bind()
for row in conn.execute(sa.select([task_sample.c.id, task_sample.c.updated_by, task_sample.c.updated_at])):
if row.updated_by:
conn.execute(
task_sample_updater.insert().values(
sample_id=row.id,
user_id=row.updated_by,
created_at=row.updated_at or datetime.now()
)
)


def downgrade() -> None:
op.drop_index('ix_task_collaborator_user_id')
op.drop_index('ix_task_collaborator_task_id')

op.drop_table('task_collaborator')

op.drop_index('ix_task_sample_updater_user_id')
op.drop_index('ix_task_sample_updater_task_sample_id')

op.drop_table('task_sample_updater')
8 changes: 4 additions & 4 deletions labelu/internal/adapter/persistence/crud_attachment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

def list_by(
db: Session,
pageSize: int,
size: int,
ids: List[int] | None = [],
task_id: Optional[int] = None,
owner_id: Optional[int] = None,
after: Optional[int] = None,
before: Optional[int] = None,
pageNo: Optional[int] = None,
page: Optional[int] = None,
sorting: Optional[str] = None,
) -> Tuple[List[TaskAttachment], int]:
# query filter
Expand All @@ -39,8 +39,8 @@ def list_by(
count = query.count()

results = (
query.offset(offset=pageNo * pageSize if pageNo else 0)
.limit(limit=pageSize)
query.offset(offset=page * size if page else 0)
.limit(limit=size)
.all()
)

Expand Down
22 changes: 9 additions & 13 deletions labelu/internal/adapter/persistence/crud_pre_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,17 @@ def batch(db: Session, pre_annotations: List[TaskPreAnnotation]) -> List[TaskPre

def list_by(
db: Session,
owner_id: int,
task_id: int | None = None,
sample_name: str | None = None,
after: int | None = None,
before: int | None = None,
pageNo: int | None = None,
page: int | None = None,
sorting: str | None = None,
pageSize: int | None = 10,
size: int | None = 10,
) -> Tuple[List[TaskPreAnnotation], int]:

# query filter
query_filter = [TaskPreAnnotation.created_by == owner_id, TaskPreAnnotation.deleted_at == None]
query_filter = [TaskPreAnnotation.deleted_at == None]
if before:
query_filter.append(TaskPreAnnotation.id < before)
if after:
Expand All @@ -48,8 +47,8 @@ def list_by(
count = query.count()

results = (
query.offset(offset=pageNo * pageSize if pageNo else 0)
.limit(limit=pageSize)
query.offset(offset=page * size if page else 0)
.limit(limit=size)
.all()
)

Expand All @@ -65,19 +64,17 @@ def list_by(

return results, count

def list_by_task_id_and_owner_id(db: Session, task_id: int, owner_id: int) -> Dict[str, List[TaskPreAnnotation]]:
def list_by_task_id_and_owner_id(db: Session, task_id: int) -> Dict[str, List[TaskPreAnnotation]]:
pre_annotations = db.query(TaskPreAnnotation).filter(
TaskPreAnnotation.task_id == task_id,
TaskPreAnnotation.deleted_at == None,
TaskPreAnnotation.created_by == owner_id
).all()

return pre_annotations

def list_by_task_id_and_file_id(db: Session, task_id: int, file_id: int, owner_id: int) -> List[TaskPreAnnotation]:
def list_by_task_id_and_file_id(db: Session, task_id: int, file_id: int) -> List[TaskPreAnnotation]:
return db.query(TaskPreAnnotation).filter(
TaskPreAnnotation.task_id == task_id,
TaskPreAnnotation.created_by == owner_id,
TaskPreAnnotation.deleted_at == None,
TaskPreAnnotation.file_id == file_id
).all()
Expand All @@ -97,7 +94,6 @@ def list_by_task_id_and_owner_id_and_sample_name(db: Session, task_id: int, owne
return db.query(TaskPreAnnotation).filter(
TaskPreAnnotation.task_id == task_id,
TaskPreAnnotation.deleted_at == None,
TaskPreAnnotation.created_by == owner_id,
TaskPreAnnotation.sample_name == sample_name
).all()

Expand Down Expand Up @@ -134,8 +130,8 @@ def delete(db: Session, pre_annotation_ids: List[int]) -> None:
)


def count(db: Session, task_id: int, owner_id: int, sample_name: str | None) -> int:
query_filter = [TaskPreAnnotation.created_by == owner_id, TaskPreAnnotation.deleted_at == None]
def count(db: Session, task_id: int, sample_name: str | None) -> int:
query_filter = [TaskPreAnnotation.deleted_at == None]
if task_id:
query_filter.append(TaskPreAnnotation.task_id == task_id)

Expand Down
18 changes: 8 additions & 10 deletions labelu/internal/adapter/persistence/crud_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,15 @@ def batch(db: Session, samples: List[TaskSample]) -> List[TaskSample]:
def list_by(
db: Session,
task_id: Union[int, None],
owner_id: int,
after: Union[int, None],
before: Union[int, None],
pageNo: Union[int, None],
pageSize: int,
page: Union[int, None],
size: int,
sorting: Union[str, None],
) -> List[TaskSample]:

# query filter
query_filter = [TaskSample.created_by == owner_id, TaskSample.deleted_at == None]
query_filter = [TaskSample.deleted_at == None]
if before:
query_filter.append(TaskSample.id < before)
if after:
Expand Down Expand Up @@ -58,8 +57,8 @@ def list_by(
else:
query = query.order_by(TaskSample.id.asc())
results = (
query.offset(offset=pageNo * pageSize if pageNo else 0)
.limit(limit=pageSize)
query.offset(offset=page * size if page else 0)
.limit(limit=size)
.all()
)
if before:
Expand Down Expand Up @@ -100,19 +99,18 @@ def delete(db: Session, sample_ids: List[int]) -> None:
)


def count(db: Session, task_id: int, owner_id: int) -> int:
query_filter = [TaskSample.created_by == owner_id, TaskSample.deleted_at == None]
def count(db: Session, task_id: int) -> int:
query_filter = [TaskSample.deleted_at == None]
if task_id:
query_filter.append(TaskSample.task_id == task_id)
return db.query(TaskSample).filter(*query_filter).count()


def statics(
db: Session,
owner_id: int,
task_ids: List[int],
) -> dict:
query_filter = [TaskSample.created_by == owner_id, TaskSample.deleted_at == None]
query_filter = [TaskSample.deleted_at == None]
if task_ids:
query_filter.append(TaskSample.task_id.in_(task_ids))

Expand Down
27 changes: 22 additions & 5 deletions labelu/internal/adapter/persistence/crud_task.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from datetime import datetime
from typing import Any, Dict, List
from sqlalchemy import or_, exists
from typing import Any, Dict, List, Tuple

from sqlalchemy.orm import Session
from fastapi.encoders import jsonable_encoder

from labelu.internal.domain.models.task import Task
from labelu.internal.domain.models.task_collaborator import TaskCollaborator


def create(db: Session, task: Task) -> Task:
Expand All @@ -14,15 +16,30 @@ def create(db: Session, task: Task) -> Task:
return task


def list_by(db: Session, owner_id: int, page: int = 0, size: int = 100) -> List[Task]:
return (
def list_by(db: Session, owner_id: int, page: int = 0, size: int = 100) -> Tuple[List[Task], int]:
collaborator_exists = exists().where(
TaskCollaborator.task_id == Task.id,
TaskCollaborator.user_id == owner_id
)
query = (
db.query(Task)
.filter(Task.created_by == owner_id, Task.deleted_at == None)
.filter(
Task.deleted_at == None,
or_(
Task.created_by == owner_id,
collaborator_exists
)
)
)
total = query.count()

return (
query
.order_by(Task.id.desc())
.offset(offset=page * size)
.limit(limit=size)
.all()
)
), total


def get(db: Session, task_id: int, lock: bool = False) -> Task:
Expand Down
Loading

0 comments on commit b159c63

Please sign in to comment.