From 6a00494b56b13a1d866a9f3044abf76e49f67849 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Mon, 20 Jan 2025 15:18:54 +0100 Subject: [PATCH] [BUGFIX] [ENHANCEMENT] create users and workspaces with predefined ID (#5786) # Description This PR allows users and workspaces to be created with predefined IDs. Passing a predefined ID when creating workspaces and users simplifies the migration process since users don't need to review responses and map them to new users, avoiding potential issues in the process. **Type of change** - Bug fix (non-breaking change which fixes an issue) - Improvement (change adding some improvement to an existing functionality) - Documentation update **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- argilla-server/CHANGELOG.md | 5 ++ .../argilla_server/api/schemas/v1/users.py | 1 + .../api/schemas/v1/workspaces.py | 4 +- .../src/argilla_server/contexts/accounts.py | 15 +++- .../api/handlers/v1/users/test_create_user.py | 82 ++++++++++++++++++- .../v1/workspaces/test_create_workspace.py | 72 ++++++++++++++++ argilla/CHANGELOG.md | 3 + .../migrate_from_legacy_datasets.md | 16 ++-- argilla/src/argilla/_api/_workspaces.py | 2 +- argilla/src/argilla/users/_resource.py | 3 +- argilla/src/argilla/workspaces/_resource.py | 4 +- .../tests/integration/test_manage_users.py | 7 ++ .../integration/test_manage_workspaces.py | 7 ++ 13 files changed, 208 insertions(+), 13 deletions(-) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index e652289960..0b0c212cd2 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -16,6 +16,11 @@ These are the section headers that we use: ## [Unreleased]() +### Added + +- Added support to create users with predefined ids. ([#5786](https://github.com/argilla-io/argilla/pull/5786)) +- Added support to create workspaces with predefined ids. ([#5786](https://github.com/argilla-io/argilla/pull/5786)) + ## [2.6.0](https://github.com/argilla-io/argilla/compare/v2.5.0...v2.6.0) ### Added diff --git a/argilla-server/src/argilla_server/api/schemas/v1/users.py b/argilla-server/src/argilla_server/api/schemas/v1/users.py index 56ec19bf33..7d8478f8cd 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/users.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/users.py @@ -59,6 +59,7 @@ class User(BaseModel): class UserCreate(BaseModel): + id: Optional[UUID] = None first_name: UserFirstName last_name: Optional[UserLastName] = None username: UserUsername diff --git a/argilla-server/src/argilla_server/api/schemas/v1/workspaces.py b/argilla-server/src/argilla_server/api/schemas/v1/workspaces.py index 5fba689f95..4b301aa214 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/workspaces.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/workspaces.py @@ -13,7 +13,8 @@ # limitations under the License. from datetime import datetime -from typing import List +from optparse import Option +from typing import List, Optional from uuid import UUID from pydantic import BaseModel, Field, ConfigDict @@ -29,6 +30,7 @@ class Workspace(BaseModel): class WorkspaceCreate(BaseModel): + id: Optional[UUID] = None name: str = Field(min_length=1) diff --git a/argilla-server/src/argilla_server/contexts/accounts.py b/argilla-server/src/argilla_server/contexts/accounts.py index f12fbadf0c..6eb16851b4 100644 --- a/argilla-server/src/argilla_server/contexts/accounts.py +++ b/argilla-server/src/argilla_server/contexts/accounts.py @@ -66,7 +66,15 @@ async def create_workspace(db: AsyncSession, workspace_attrs: dict) -> Workspace if await Workspace.get_by(db, name=workspace_attrs["name"]) is not None: raise NotUniqueError(f"Workspace name `{workspace_attrs['name']}` is not unique") - return await Workspace.create(db, name=workspace_attrs["name"]) + if workspace_id := workspace_attrs.get("id"): + if await Workspace.get(db, id=workspace_id) is not None: + raise NotUniqueError(f"Workspace with id `{workspace_id}` is not unique") + + return await Workspace.create( + db, + id=workspace_attrs.get("id"), + name=workspace_attrs["name"], + ) async def delete_workspace(db: AsyncSession, workspace: Workspace): @@ -108,8 +116,13 @@ async def create_user(db: AsyncSession, user_attrs: dict, workspaces: Union[List if await get_user_by_username(db, user_attrs["username"]) is not None: raise NotUniqueError(f"User username `{user_attrs['username']}` is not unique") + if user_id := user_attrs.get("id"): + if await User.get(db, id=user_id) is not None: + raise NotUniqueError(f"User with id `{user_id}` is not unique") + user = await User.create( db, + id=user_attrs.get("id"), first_name=user_attrs["first_name"], last_name=user_attrs["last_name"], username=user_attrs["username"], diff --git a/argilla-server/tests/unit/api/handlers/v1/users/test_create_user.py b/argilla-server/tests/unit/api/handlers/v1/users/test_create_user.py index ebe95aa6b7..4c9bf2f2e7 100644 --- a/argilla-server/tests/unit/api/handlers/v1/users/test_create_user.py +++ b/argilla-server/tests/unit/api/handlers/v1/users/test_create_user.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from uuid import uuid4 import pytest from argilla_server.constants import API_KEY_HEADER_NAME @@ -146,6 +146,86 @@ async def test_create_user_with_non_default_role( assert response.json()["role"] == UserRole.owner assert user.role == UserRole.owner + async def test_create_user_with_predefined_id( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + user_id = uuid4() + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "id": str(user_id), + "first_name": "First name", + "last_name": "Last name", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 201 + + user = (await db.execute(select(User).filter_by(username="username"))).scalar_one() + assert user.id == user_id + + async def test_create_user_with_none_user_id( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "id": None, + "first_name": "First name", + "last_name": "Last name", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 201 + + user = (await db.execute(select(User).filter_by(username="username"))).scalar_one() + assert user.id is not None + + async def test_create_user_with_wrong_user_id( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "id": "wrong_id", + "first_name": "First name", + "last_name": "Last name", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(User.id)))).scalar() == 1 + + async def test_create_user_with_existing_id( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + user_id = uuid4() + await UserFactory.create(id=user_id) + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "id": str(user_id), + "first_name": "First name", + "last_name": "Last name", + "username": "username", + "password": "12345678", + }, + ) + + assert response.status_code == 409 + assert (await db.execute(select(func.count(User.id)))).scalar() == 2 + async def test_create_user_without_authentication(self, db: AsyncSession, async_client: AsyncClient): response = await async_client.post( self.url(), diff --git a/argilla-server/tests/unit/api/handlers/v1/workspaces/test_create_workspace.py b/argilla-server/tests/unit/api/handlers/v1/workspaces/test_create_workspace.py index d2bb2b4e40..679ce3caca 100644 --- a/argilla-server/tests/unit/api/handlers/v1/workspaces/test_create_workspace.py +++ b/argilla-server/tests/unit/api/handlers/v1/workspaces/test_create_workspace.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from uuid import uuid4 import pytest from argilla_server.constants import API_KEY_HEADER_NAME @@ -47,6 +48,77 @@ async def test_create_workspace(self, db: AsyncSession, async_client: AsyncClien "updated_at": workspace.updated_at.isoformat(), } + async def test_create_workspace_with_predefined_id( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace_id = uuid4() + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={"id": str(workspace_id), "name": "workspace"}, + ) + + assert response.status_code == 201 + + assert (await db.execute(select(func.count(Workspace.id)))).scalar() == 1 + workspace = (await db.execute(select(Workspace).filter_by(name="workspace"))).scalar_one() + + assert response.json() == { + "id": str(workspace_id), + "name": "workspace", + "inserted_at": workspace.inserted_at.isoformat(), + "updated_at": workspace.updated_at.isoformat(), + } + + async def test_create_workspace_with_none_id( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={"id": None, "name": "workspace"}, + ) + + assert response.status_code == 201 + + assert (await db.execute(select(func.count(Workspace.id)))).scalar() == 1 + workspace = (await db.execute(select(Workspace).filter_by(name="workspace"))).scalar_one() + + assert response.json() == { + "id": str(workspace.id), + "name": "workspace", + "inserted_at": workspace.inserted_at.isoformat(), + "updated_at": workspace.updated_at.isoformat(), + } + + async def test_create_workspace_with_wrong_id( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={"id": "wrong_id", "name": "workspace"}, + ) + + assert response.status_code == 422 + + assert (await db.execute(select(func.count(Workspace.id)))).scalar() == 0 + + async def test_create_workspace_with_existing_id( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace_id = uuid4() + await WorkspaceFactory.create(id=workspace_id) + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={"id": str(workspace_id), "name": "workspace"}, + ) + + assert response.status_code == 409 + assert (await db.execute(select(func.count(Workspace.id)))).scalar() == 1 + async def test_create_workspace_without_authentication(self, db: AsyncSession, async_client: AsyncClient): response = await async_client.post( self.url(), diff --git a/argilla/CHANGELOG.md b/argilla/CHANGELOG.md index 1b8d87b9e9..6f3141ae90 100644 --- a/argilla/CHANGELOG.md +++ b/argilla/CHANGELOG.md @@ -19,6 +19,9 @@ These are the section headers that we use: ### Added - Return similarity score when searching by similarity. ([#5778](https://github.com/argilla-io/argilla/pull/5778)) +- Added support to create users with predefined ids. ([#5786](https://github.com/argilla-io/argilla/pull/5786)) +- Added support to create workspaces with predefined ids. ([#5786](https://github.com/argilla-io/argilla/pull/5786)) + ## [2.6.0](https://github.com/argilla-io/argilla/compare/v2.5.0...v2.6.0) diff --git a/argilla/docs/how_to_guides/migrate_from_legacy_datasets.md b/argilla/docs/how_to_guides/migrate_from_legacy_datasets.md index f7cb1db770..83751bbc7b 100644 --- a/argilla/docs/how_to_guides/migrate_from_legacy_datasets.md +++ b/argilla/docs/how_to_guides/migrate_from_legacy_datasets.md @@ -69,26 +69,30 @@ Next, recreate the users and workspaces on the Argilla V2 server: ```python for workspace in workspaces_v1: rg.Workspace( - name=workspace.name + id=workspace.id, + name=workspace.name, ).create() ``` ```python for user in users_v1: - user = rg.User( + user_v2 = rg.User( + id=user.id, username=user.username, first_name=user.first_name, last_name=user.last_name, role=user.role, password="" # (1) ).create() + if user.role == "owner": continue - for workspace_name in user.workspaces: - if workspace_name != user.name: - workspace = client.workspaces(name=workspace_name) - user.add_to_workspace(workspace) + for workspace in user.workspaces: + workspace_v2 = client.workspaces(name=workspace.name) + if workspace_v2 is None: + continue + user.add_to_workspace(workspace_v2) ``` 1. You need to chose a new password for the user, to do this programmatically you can use the `uuid` package to generate a random password. Take care to keep track of the passwords you chose, since you will not be able to retrieve them later. diff --git a/argilla/src/argilla/_api/_workspaces.py b/argilla/src/argilla/_api/_workspaces.py index 1bf80020c6..46fbfe2708 100644 --- a/argilla/src/argilla/_api/_workspaces.py +++ b/argilla/src/argilla/_api/_workspaces.py @@ -35,7 +35,7 @@ class WorkspacesAPI(ResourceAPI[WorkspaceModel]): @api_error_handler def create(self, workspace: WorkspaceModel) -> WorkspaceModel: # TODO: Unify API endpoint - response = self.http_client.post(url="/api/v1/workspaces", json={"name": workspace.name}) + response = self.http_client.post(url="/api/v1/workspaces", json=workspace.model_dump()) response.raise_for_status() response_json = response.json() workspace = self._model_from_json(json_workspace=response_json) diff --git a/argilla/src/argilla/users/_resource.py b/argilla/src/argilla/users/_resource.py index 09d5f40b29..fb082787f7 100644 --- a/argilla/src/argilla/users/_resource.py +++ b/argilla/src/argilla/users/_resource.py @@ -45,8 +45,8 @@ def __init__( last_name: Optional[str] = None, role: Optional[str] = None, password: Optional[str] = None, - client: Optional["Argilla"] = None, id: Optional[UUID] = None, + client: Optional["Argilla"] = None, _model: Optional[UserModel] = None, ) -> None: """Initializes a User object with a client and a username @@ -57,6 +57,7 @@ def __init__( last_name (str): The last name of the user role (str): The role of the user, either 'annotator', admin, or 'owner' password (str): The password of the user + id (UUID): The ID of the user. If provided before a .create, the will be created with this ID client (Argilla): The client used to interact with Argilla Returns: diff --git a/argilla/src/argilla/workspaces/_resource.py b/argilla/src/argilla/workspaces/_resource.py index c508bf6372..c891873f85 100644 --- a/argilla/src/argilla/workspaces/_resource.py +++ b/argilla/src/argilla/workspaces/_resource.py @@ -50,9 +50,9 @@ def __init__( """Initializes a Workspace object with a client and a name or id Parameters: - client (Argilla): The client used to interact with Argilla name (str): The name of the workspace - id (UUID): The id of the workspace + id (UUID): The id of the workspace. If provided before a .create, the workspace will be created with this ID + client (Argilla): The client used to interact with Argilla Returns: Workspace: The initialized workspace object diff --git a/argilla/tests/integration/test_manage_users.py b/argilla/tests/integration/test_manage_users.py index 80c0f7b2d9..fdbf6a7860 100644 --- a/argilla/tests/integration/test_manage_users.py +++ b/argilla/tests/integration/test_manage_users.py @@ -26,6 +26,13 @@ def test_create_user(self, client: Argilla): assert user.id is not None assert client.users(username=user.username).id == user.id + def test_create_user_with_id(self, client: Argilla): + user_id = uuid.uuid4() + user = User(id=user_id, username=f"test_user_{uuid.uuid4()}", password="test_password") + client.users.add(user) + assert user.id is not None + assert client.users(username=user.username).id == user_id + def test_create_user_without_password(self, client: Argilla): user = User(username=f"test_user_{uuid.uuid4()}") with pytest.raises(expected_exception=UnprocessableEntityError): diff --git a/argilla/tests/integration/test_manage_workspaces.py b/argilla/tests/integration/test_manage_workspaces.py index 6462003a65..6c5167679e 100644 --- a/argilla/tests/integration/test_manage_workspaces.py +++ b/argilla/tests/integration/test_manage_workspaces.py @@ -24,6 +24,13 @@ def test_create_workspace(self, client: Argilla): assert workspace in client.workspaces assert client.api.workspaces.exists(workspace.id) + def test_create_workspace_with_id(self, client: Argilla): + workspace_id = uuid.uuid4() + workspace = Workspace(id=workspace_id, name=f"test_workspace{uuid.uuid4()}") + client.workspaces.add(workspace) + assert workspace in client.workspaces + assert client.workspaces(workspace.name).id == workspace_id + def test_create_and_delete_workspace(self, client: Argilla): workspace = client.workspaces(name="test_workspace") if workspace: