Skip to content

Commit

Permalink
update: update migration script (#193)
Browse files Browse the repository at this point in the history
  • Loading branch information
gary-Shen authored Feb 21, 2025
1 parent f917091 commit 047380a
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 136 deletions.
36 changes: 10 additions & 26 deletions labelu/alembic_labelu/alembic_labelu_tools.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import json

from alembic import op
from sqlalchemy import engine_from_config
from sqlalchemy.engine import reflection
import sqlalchemy as sa

def table_exist(table_name):
"""check table is not exist
Expand All @@ -13,20 +12,11 @@ def table_exist(table_name):
Returns:
bool: true or false, whether the table_name exists
"""
config = op.get_context().config
engine = engine_from_config(
config.get_section(config.config_ini_section), prefix="sqlalchemy."
)
insp = reflection.Inspector.from_engine(engine)
table_exist = False
conn = op.get_bind()
inspector = sa.inspect(conn)
tables = inspector.get_table_names()

for table in insp.get_table_names():
if table_name not in table:
continue
table_exist = True
break

return table_exist
return table_name in tables

def column_exist_in_table(table_name, column_name):
"""check column is not exist in table
Expand All @@ -38,17 +28,11 @@ def column_exist_in_table(table_name, column_name):
Returns:
bool: true or false, whether the column_name exists in the table_name
"""
config = op.get_context().config
engine = engine_from_config(
config.get_section(config.config_ini_section), prefix="sqlalchemy."
)
insp = reflection.Inspector.from_engine(engine)
column_exist = False
for col in insp.get_columns(table_name):
if column_name not in col["name"]:
continue
column_exist = True
return column_exist
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = inspector.get_columns(table_name)

return column_name in [column["name"] for column in columns]


def get_tool_label_dict(task_config: dict) -> dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import sqlalchemy as sa
from sqlalchemy.sql import table, column


# revision identifiers, used by Alembic.
revision = '2eb983c9a254'
down_revision = 'eb9c5b98168b'
Expand All @@ -21,10 +20,11 @@
def upgrade() -> None:
# Create task_collaborator table
# create the table only if it does not exist yet
is_table_exist = op.execute(
"SHOW TABLES LIKE 'task_collaborator';"
).fetchone()
if not is_table_exist:
conn = op.get_bind()
inspector = sa.inspect(conn)
tables = inspector.get_table_names()

if 'task_collaborator' not in tables:
op.create_table(
'task_collaborator',
sa.Column('task_id', sa.Integer(), nullable=False),
Expand All @@ -49,79 +49,67 @@ def upgrade() -> None:
)

# Performances index
is_index_2_exist = op.execute(
"SHOW INDEX FROM task_collaborator WHERE Key_name = 'ix_task_collaborator_task_id';"
).fetchone()
indices = inspector.get_indexes('task_collaborator')
existing_index_names = {idx['name'] for idx in indices}

if is_index_2_exist:
return

op.create_index(
'ix_task_collaborator_task_id',
'task_collaborator',
['task_id']
)
op.create_index(
'ix_task_collaborator_user_id',
'task_collaborator',
['user_id']
)
op.create_index(
'ix_task_created_by_deleted_at',
'task',
['created_by', 'deleted_at']
)
if 'ix_task_collaborator_task_id' not in existing_index_names:
op.create_index(
'ix_task_collaborator_task_id',
'task_collaborator',
['task_id']
)
op.create_index(
'ix_task_collaborator_user_id',
'task_collaborator',
['user_id']
)
op.create_index(
'ix_task_created_by_deleted_at',
'task',
['created_by', 'deleted_at']
)

# Task sample: updater -> updaters; create a new table task_sample_updater
is_task_sample_updater_table_exist = op.execute(
"SHOW TABLES LIKE 'task_sample_updater';"
).fetchone()

if is_task_sample_updater_table_exist:
return

op.create_table(
'task_sample_updater',
sa.Column('sample_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column(
'created_at',
sa.DateTime(timezone=True),
server_default=sa.text('CURRENT_TIMESTAMP'),
nullable=False
),
sa.ForeignKeyConstraint(
['sample_id'],
['task_sample.id'],
ondelete='CASCADE'
),
sa.ForeignKeyConstraint(
['user_id'],
['user.id'],
ondelete='CASCADE'
),
sa.PrimaryKeyConstraint('sample_id', 'user_id')
)

# Performances index
# check if the index is already exists
is_index_exist = op.execute(
"SHOW INDEX FROM task_sample_updater WHERE Key_name = 'ix_task_sample_updater_sample_id';"
).fetchone()

if is_index_exist:
return

op.create_index(
'ix_task_sample_updater_sample_id',
'task_sample_updater',
['sample_id']
)
op.create_index(
'ix_task_sample_updater_user_id',
'task_sample_updater',
['user_id']
)
if 'task_sample_updater' not in tables:
op.create_table(
'task_sample_updater',
sa.Column('sample_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column(
'created_at',
sa.DateTime(timezone=True),
server_default=sa.text('CURRENT_TIMESTAMP'),
nullable=False
),
sa.ForeignKeyConstraint(
['sample_id'],
['task_sample.id'],
ondelete='CASCADE'
),
sa.ForeignKeyConstraint(
['user_id'],
['user.id'],
ondelete='CASCADE'
),
sa.PrimaryKeyConstraint('sample_id', 'user_id')
)

# Performances index
# check whether the index already exists
indices = inspector.get_indexes('task_sample_updater')
existing_index_names = {idx['name'] for idx in indices}

if 'ix_task_sample_updater_sample_id' not in existing_index_names:
op.create_index(
'ix_task_sample_updater_sample_id',
'task_sample_updater',
['sample_id']
)
op.create_index(
'ix_task_sample_updater_user_id',
'task_sample_updater',
['user_id']
)

# Migrate data from task_sample.updated_by to task_sample_updater
task_sample = table(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,38 @@
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.automap import automap_base

from labelu.alembic_labelu.alembic_labelu_tools import (
get_tool_label_dict,
replace_key_with_value,
)
def get_tool_label_dict(task_config: dict) -> dict:
    """Build a label-key -> label-value lookup for one task configuration.

    Labels are collected from the task-wide ``attribute`` list first and then
    from every tool's ``config.attributeList``; when the same key appears more
    than once, the later entry wins. The mapping always starts with the
    ``"无标签"`` -> ``"noneAttribute"`` entry.

    Args:
        task_config: Parsed task configuration. Sections that are missing or
            explicitly ``None`` (``attribute``, ``tools``, a tool's ``config``
            or its ``attributeList``) are treated as empty.

    Returns:
        A dict mapping each non-empty label key to its value.
    """
    label_dict = {"无标签": "noneAttribute"}

    def _collect(labels):
        # `or []` also covers an explicit null, which `.get(key, [])` does not.
        for label in labels or []:
            key = label.get("key", "")
            if key:  # entries without a usable key are ignored
                label_dict[key] = label.get("value")

    task_config = task_config or {}

    # General labels shared by the whole task.
    _collect(task_config.get("attribute"))

    # Labels defined per tool by the user.
    for task_tool in task_config.get("tools") or []:
        tool_config = task_tool.get("config") or {}
        _collect(tool_config.get("attributeList"))

    return label_dict


def replace_key_with_value(sample_data: dict, label_dict: dict) -> dict:
    """Return the sample's decoded result with label keys swapped for values.

    The JSON string stored under ``sample_data["result"]`` is decoded, and in
    every top-level section whose name ends with ``"Tool"`` each entry of that
    section's ``result`` list has its ``attribute`` replaced by
    ``label_dict[attribute]`` when the attribute is a known key. Used to
    repair samples written by earlier versions.

    Args:
        sample_data: Sample record; ``result`` must hold a JSON document.
        label_dict: Mapping from label key to label value.

    Returns:
        The decoded (and patched) result dict. ``sample_data`` itself is not
        modified.
    """
    parsed = json.loads(sample_data.get("result"))

    for tool_name, tool_payload in parsed.items():
        # Only "...Tool" sections carry annotation entries.
        if not tool_name.endswith("Tool"):
            continue
        for entry in tool_payload.get("result", []):
            current = entry.get("attribute", "")
            if current in label_dict:
                entry["attribute"] = label_dict[current]

    return parsed

# revision identifiers, used by Alembic.
revision = "9d5da133bbe4"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
Create Date: 2024-02-07 15:58:30.618151
"""
import imp
import json
import os

Expand All @@ -14,6 +13,7 @@
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import sessionmaker
from labelu.internal.common.config import settings
from labelu.alembic_labelu.alembic_labelu_tools import table_exist, column_exist_in_table

Base = automap_base()

Expand All @@ -24,17 +24,6 @@
branch_labels = None
depends_on = None

# import alembic_labelu_tools from the absolute path
alembic_labelu_tools = imp.load_source(
"alembic_labelu_tools",
(
os.path.join(
os.path.abspath(os.path.dirname(os.path.dirname(__file__))),
"alembic_labelu_tools.py",
)
),
)

def upgrade() -> None:
bind = op.get_bind()
Base.prepare(autoload_with=bind, reflect=True)
Expand All @@ -47,7 +36,7 @@ def upgrade() -> None:

with context.begin_transaction():
# Create a new table task_pre_annotation
if not alembic_labelu_tools.table_exist("task_pre_annotation"):
if not table_exist("task_pre_annotation"):
op.create_table(
"task_pre_annotation",
sa.Column("id", sa.Integer, primary_key=True, autoincrement=True, index=True),
Expand Down Expand Up @@ -78,7 +67,7 @@ def upgrade() -> None:
),
)
# Update the task_sample table
if not alembic_labelu_tools.column_exist_in_table(
if not column_exist_in_table(
"task_sample", "file_id"
):
with op.batch_alter_table('task_sample', recreate="always") as batch_op:
Expand All @@ -93,7 +82,7 @@ def upgrade() -> None:
)

# Update the task_attachment table
if not alembic_labelu_tools.column_exist_in_table("task_attachment", "filename"):
if not column_exist_in_table("task_attachment", "filename"):
with op.batch_alter_table("task_attachment", recreate="always") as batch_op_task_attachment:
batch_op_task_attachment.add_column(
sa.Column(
Expand Down
18 changes: 3 additions & 15 deletions labelu/alembic_labelu/versions/e76c2ca5562e_add_inner_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,12 @@
Create Date: 2023-02-28 22:29:31.595257
"""
import imp
import os

from alembic import op
from alembic import context
import sqlalchemy as sa
from sqlalchemy.sql import text
from labelu.alembic_labelu.alembic_labelu_tools import column_exist_in_table

# import alembic_labelu_tools from the absolute path
alembic_labelu_tools = imp.load_source(
"alembic_labelu_tools",
(
os.path.join(
os.path.abspath(os.path.dirname(os.path.dirname(__file__))),
"alembic_labelu_tools.py",
)
),
)

# revision identifiers, used by Alembic.
revision = "e76c2ca5562e"
Expand All @@ -36,7 +24,7 @@ def upgrade() -> None:
add inner_id and last_sample_inner_id columns in the task and sample tables
"""
with context.begin_transaction():
if not alembic_labelu_tools.column_exist_in_table(
if not column_exist_in_table(
"task", "last_sample_inner_id"
):
with op.batch_alter_table("task", recreate="always") as batch_op_task:
Expand All @@ -50,7 +38,7 @@ def upgrade() -> None:
),
insert_before="config",
)
if not alembic_labelu_tools.column_exist_in_table("task_sample", "inner_id"):
if not column_exist_in_table("task_sample", "inner_id"):
with op.batch_alter_table(
"task_sample", recreate="always"
) as batch_op_task_sample:
Expand Down
12 changes: 12 additions & 0 deletions labelu/internal/application/service/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ async def get(db: Session, task_id: int, current_user: User) -> TaskResponseWith

# get task detail
task = crud_task.get(db=db, task_id=task_id)

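# reject callers who are neither the task owner nor one of its collaborators
# NOTE(review): `task` is dereferenced here before the `if not task` check below; confirm crud_task.get cannot return None, or move this check after it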
if task.created_by != current_user.id and current_user not in task.collaborators:
logger.error(
"cannot get task, the task owner is:{}, the get operator is:{}",
task.created_by,
current_user.id,
)
raise LabelUException(
code=ErrorCode.CODE_30001_NO_PERMISSION,
status_code=status.HTTP_403_FORBIDDEN,
)
if not task:
logger.error("cannot find task:{}", task_id)
raise LabelUException(
Expand Down
2 changes: 1 addition & 1 deletion labelu/internal/dependencies/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_current_user(
)
user = crud_user.get(db, id=token_data.id)
if not user:
raise LabelUException(code=ErrorCode.CODE_40002_USER_NOT_FOUND, status_code=404)
raise LabelUException(code=ErrorCode.CODE_40002_USER_NOT_FOUND, status_code=401)
return user


Expand Down

0 comments on commit 047380a

Please sign in to comment.