diff --git a/labelu/alembic_labelu/alembic_labelu_tools.py b/labelu/alembic_labelu/alembic_labelu_tools.py index 0ed629e..ff6a1af 100644 --- a/labelu/alembic_labelu/alembic_labelu_tools.py +++ b/labelu/alembic_labelu/alembic_labelu_tools.py @@ -1,8 +1,7 @@ import json from alembic import op -from sqlalchemy import engine_from_config -from sqlalchemy.engine import reflection +import sqlalchemy as sa def table_exist(table_name): """check table is not exist @@ -13,20 +12,11 @@ def table_exist(table_name): Returns: bool: true or false, whether the table_name exists """ - config = op.get_context().config - engine = engine_from_config( - config.get_section(config.config_ini_section), prefix="sqlalchemy." - ) - insp = reflection.Inspector.from_engine(engine) - table_exist = False + conn = op.get_bind() + inspector = sa.inspect(conn) + tables = inspector.get_table_names() - for table in insp.get_table_names(): - if table_name not in table: - continue - table_exist = True - break - - return table_exist + return table_name in tables def column_exist_in_table(table_name, column_name): """check column is not exist in table @@ -38,17 +28,11 @@ def column_exist_in_table(table_name, column_name): Returns: bool: true or false, whether the column_name exists in the table_name """ - config = op.get_context().config - engine = engine_from_config( - config.get_section(config.config_ini_section), prefix="sqlalchemy." - ) - insp = reflection.Inspector.from_engine(engine) - column_exist = False - for col in insp.get_columns(table_name): - if column_name not in col["name"]: - continue - column_exist = True - return column_exist + conn = op.get_bind() + inspector = sa.inspect(conn) + columns = inspector.get_columns(table_name) + + return column_name in [column["name"] for column in columns] def get_tool_label_dict(task_config: dict) -> dict: diff --git a/labelu/alembic_labelu/versions/2eb983c9a254_add_collaborators_and_updaters.py b/labelu/alembic_labelu/versions/2eb983c9a254_add_collaborators_and_updaters.py index ce761d5..4d78c92 100644 --- a/labelu/alembic_labelu/versions/2eb983c9a254_add_collaborators_and_updaters.py +++ b/labelu/alembic_labelu/versions/2eb983c9a254_add_collaborators_and_updaters.py @@ -10,7 +10,6 @@ import sqlalchemy as sa from sqlalchemy.sql import table, column - # revision identifiers, used by Alembic. revision = '2eb983c9a254' down_revision = 'eb9c5b98168b' @@ -21,10 +20,11 @@ def upgrade() -> None: # Create task_collaborator table # if the table is not exists then create it - is_table_exist = op.execute( - "SHOW TABLES LIKE 'task_collaborator';" - ).fetchone() - if not is_table_exist: + conn = op.get_bind() + inspector = sa.inspect(conn) + tables = inspector.get_table_names() + + if 'task_collaborator' not in tables: op.create_table( 'task_collaborator', sa.Column('task_id', sa.Integer(), nullable=False), @@ -49,79 +49,67 @@ def upgrade() -> None: ) # Performances index - is_index_2_exist = op.execute( - "SHOW INDEX FROM task_collaborator WHERE Key_name = 'ix_task_collaborator_task_id';" - ).fetchone() + indices = inspector.get_indexes('task_collaborator') + existing_index_names = {idx['name'] for idx in indices} - if is_index_2_exist: - return - - op.create_index( - 'ix_task_collaborator_task_id', - 'task_collaborator', - ['task_id'] - ) - op.create_index( - 'ix_task_collaborator_user_id', - 'task_collaborator', - ['user_id'] - ) - op.create_index( - 'ix_task_created_by_deleted_at', - 'task', - ['created_by', 'deleted_at'] - ) + if 'ix_task_collaborator_task_id' not in existing_index_names: + op.create_index( + 'ix_task_collaborator_task_id', + 'task_collaborator', + ['task_id'] + ) + op.create_index( + 'ix_task_collaborator_user_id', + 'task_collaborator', + ['user_id'] + ) + op.create_index( + 'ix_task_created_by_deleted_at', + 'task', + ['created_by', 'deleted_at'] + ) # Task sample: updater -> updaters; create a new table task_sample_updater - is_task_sample_updater_table_exist = op.execute( - "SHOW TABLES LIKE 'task_sample_updater';" - ).fetchone() - - if is_task_sample_updater_table_exist: - return - - op.create_table( - 'task_sample_updater', - sa.Column('sample_id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column( - 'created_at', - sa.DateTime(timezone=True), - server_default=sa.text('CURRENT_TIMESTAMP'), - nullable=False - ), - sa.ForeignKeyConstraint( - ['sample_id'], - ['task_sample.id'], - ondelete='CASCADE' - ), - sa.ForeignKeyConstraint( - ['user_id'], - ['user.id'], - ondelete='CASCADE' - ), - sa.PrimaryKeyConstraint('sample_id', 'user_id') - ) - - # Performances index - # check if the index is already exists - is_index_exist = op.execute( - "SHOW INDEX FROM task_sample_updater WHERE Key_name = 'ix_task_sample_updater_sample_id';" - ).fetchone() - - if is_index_exist: - return - - op.create_index( - 'ix_task_sample_updater_sample_id', - 'task_sample_updater', - ['sample_id'] - ) - op.create_index( - 'ix_task_sample_updater_user_id', - 'task_sample_updater', - ['user_id'] - ) + if 'task_sample_updater' not in tables: + op.create_table( + 'task_sample_updater', + sa.Column('sample_id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column( + 'created_at', + sa.DateTime(timezone=True), + server_default=sa.text('CURRENT_TIMESTAMP'), + nullable=False + ), + sa.ForeignKeyConstraint( + ['sample_id'], + ['task_sample.id'], + ondelete='CASCADE' + ), + sa.ForeignKeyConstraint( + ['user_id'], + ['user.id'], + ondelete='CASCADE' + ), + sa.PrimaryKeyConstraint('sample_id', 'user_id') + ) + + # Performances index + # check if the index is already exists + indices = inspector.get_indexes('task_sample_updater') + existing_index_names = {idx['name'] for idx in indices} + + if 'ix_task_sample_updater_sample_id' not in existing_index_names: + op.create_index( + 'ix_task_sample_updater_sample_id', + 'task_sample_updater', + ['sample_id'] + ) + op.create_index( + 'ix_task_sample_updater_user_id', + 'task_sample_updater', + ['user_id'] + ) # Migrate data from task_sample.updated_by to task_sample_updater task_sample = table( diff --git a/labelu/alembic_labelu/versions/9d5da133bbe4_replace_key_with_value_in_sample_table.py b/labelu/alembic_labelu/versions/9d5da133bbe4_replace_key_with_value_in_sample_table.py index f229c4e..632e2ae 100644 --- a/labelu/alembic_labelu/versions/9d5da133bbe4_replace_key_with_value_in_sample_table.py +++ b/labelu/alembic_labelu/versions/9d5da133bbe4_replace_key_with_value_in_sample_table.py @@ -13,11 +13,38 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.automap import automap_base -from labelu.alembic_labelu.alembic_labelu_tools import ( - get_tool_label_dict, - replace_key_with_value, -) +def get_tool_label_dict(task_config: dict) -> dict: + """get the key value of labels in a given task_id and task_config""" + label_dict = {"无标签": "noneAttribute"} + # obtain the labels in current task + # get the general labels + for normal_label in task_config.get("attribute", []): + if normal_label.get("key", ""): + label_dict[normal_label.get("key")] = normal_label.get("value") + # get the labels in configuration defined by user + for task_tool in task_config.get("tools", []): + if "config" not in task_tool.keys(): + continue + labels = task_tool.get("config").get("attributeList", []) + for label in labels: + if label.get("key", ""): + label_dict[label.get("key")] = label.get("value") + return label_dict + + +def replace_key_with_value(sample_data: dict, label_dict: dict) -> dict: + """replace the key with value in task_sample table to modify the error for the history version""" + + annotated_result = sample_data.get("result") + annotated_result = json.loads(annotated_result) + for sample_tool, sample_tool_results in annotated_result.items(): + if sample_tool.endswith("Tool"): + for sample_tool_result in sample_tool_results.get("result", []): + tool_label = sample_tool_result.get("attribute", "") + if tool_label in label_dict: + sample_tool_result["attribute"] = label_dict[tool_label] + return annotated_result # revision identifiers, used by Alembic. revision = "9d5da133bbe4" diff --git a/labelu/alembic_labelu/versions/bc8fcb35b66b_add_media_and_pre_annotation.py b/labelu/alembic_labelu/versions/bc8fcb35b66b_add_media_and_pre_annotation.py index 416c5c0..c2619b7 100644 --- a/labelu/alembic_labelu/versions/bc8fcb35b66b_add_media_and_pre_annotation.py +++ b/labelu/alembic_labelu/versions/bc8fcb35b66b_add_media_and_pre_annotation.py @@ -5,7 +5,6 @@ Create Date: 2024-02-07 15:58:30.618151 """ -import imp import json import os @@ -14,6 +13,7 @@ from sqlalchemy.ext.automap import automap_base from sqlalchemy.orm import sessionmaker from labelu.internal.common.config import settings +from labelu.alembic_labelu.alembic_labelu_tools import table_exist, column_exist_in_table Base = automap_base() @@ -24,17 +24,6 @@ branch_labels = None depends_on = None -# import alembic_labelu_tools from the absolute path -alembic_labelu_tools = imp.load_source( - "alembic_labelu_tools", - ( - os.path.join( - os.path.abspath(os.path.dirname(os.path.dirname(__file__))), - "alembic_labelu_tools.py", - ) - ), -) - def upgrade() -> None: bind = op.get_bind() Base.prepare(autoload_with=bind, reflect=True) @@ -47,7 +36,7 @@ def upgrade() -> None: with context.begin_transaction(): # Create a new table task_pre_annotation - if not alembic_labelu_tools.table_exist("task_pre_annotation"): + if not table_exist("task_pre_annotation"): op.create_table( "task_pre_annotation", sa.Column("id", sa.Integer, primary_key=True, autoincrement=True, index=True), @@ -78,7 +67,7 @@ def upgrade() -> None: ), ) # Update the task_sample table - if not alembic_labelu_tools.column_exist_in_table( + if not column_exist_in_table( "task_sample", "file_id" ): with op.batch_alter_table('task_sample', recreate="always") as batch_op: @@ -93,7 +82,7 @@ def upgrade() -> None: ) # Update the task_attachment table - if not alembic_labelu_tools.column_exist_in_table("task_attachment", "filename"): + if not column_exist_in_table("task_attachment", "filename"): with op.batch_alter_table("task_attachment", recreate="always") as batch_op_task_attachment: batch_op_task_attachment.add_column( sa.Column( diff --git a/labelu/alembic_labelu/versions/e76c2ca5562e_add_inner_id.py b/labelu/alembic_labelu/versions/e76c2ca5562e_add_inner_id.py index c1d3256..1ade91a 100644 --- a/labelu/alembic_labelu/versions/e76c2ca5562e_add_inner_id.py +++ b/labelu/alembic_labelu/versions/e76c2ca5562e_add_inner_id.py @@ -5,24 +5,12 @@ Create Date: 2023-02-28 22:29:31.595257 """ -import imp -import os - from alembic import op from alembic import context import sqlalchemy as sa from sqlalchemy.sql import text +from labelu.alembic_labelu.alembic_labelu_tools import column_exist_in_table -# import alembic_labelu_tools from the absolute path -alembic_labelu_tools = imp.load_source( - "alembic_labelu_tools", - ( - os.path.join( - os.path.abspath(os.path.dirname(os.path.dirname(__file__))), - "alembic_labelu_tools.py", - ) - ), -) # revision identifiers, used by Alembic. revision = "e76c2ca5562e" @@ -36,7 +24,7 @@ def upgrade() -> None: add inner_id and last_sample_inner_id columns in the task and sample tables """ with context.begin_transaction(): - if not alembic_labelu_tools.column_exist_in_table( + if not column_exist_in_table( "task", "last_sample_inner_id" ): with op.batch_alter_table("task", recreate="always") as batch_op_task: @@ -50,7 +38,7 @@ def upgrade() -> None: ), insert_before="config", ) - if not alembic_labelu_tools.column_exist_in_table("task_sample", "inner_id"): + if not column_exist_in_table("task_sample", "inner_id"): with op.batch_alter_table( "task_sample", recreate="always" ) as batch_op_task_sample: diff --git a/labelu/internal/application/service/task.py b/labelu/internal/application/service/task.py index 974c315..5bab831 100644 --- a/labelu/internal/application/service/task.py +++ b/labelu/internal/application/service/task.py @@ -102,6 +102,18 @@ async def get(db: Session, task_id: int, current_user: User) -> TaskResponseWith # get task detail task = crud_task.get(db=db, task_id=task_id) + + # not the collaborators + if task.created_by != current_user.id and current_user not in task.collaborators: + logger.error( + "cannot get task, the task owner is:{}, the get operator is:{}", + task.created_by, + current_user.id, + ) + raise LabelUException( + code=ErrorCode.CODE_30001_NO_PERMISSION, + status_code=status.HTTP_403_FORBIDDEN, + ) if not task: logger.error("cannot find task:{}", task_id) raise LabelUException( diff --git a/labelu/internal/dependencies/user.py b/labelu/internal/dependencies/user.py index 509cdb3..2d9ae87 100644 --- a/labelu/internal/dependencies/user.py +++ b/labelu/internal/dependencies/user.py @@ -43,7 +43,7 @@ def get_current_user( ) user = crud_user.get(db, id=token_data.id) if not user: - raise LabelUException(code=ErrorCode.CODE_40002_USER_NOT_FOUND, status_code=404) + raise LabelUException(code=ErrorCode.CODE_40002_USER_NOT_FOUND, status_code=401) return user