diff --git a/CHANGELOG.md b/CHANGELOG.md index 6af79f6..b987320 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed + +- Pydantic model validation error when querying Project and listing Organizations. + ## [0.0.1a3] - 2025-02-06 ### Added diff --git a/pyproject.toml b/pyproject.toml index b0da1bd..c7fa840 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ ] dependencies = [ "codex-sdk==0.1.0a9", - "pydantic>=1.9.0, <3", + "pydantic>=2.0.0, <3", ] [project.urls] diff --git a/src/cleanlab_codex/internal/organization.py b/src/cleanlab_codex/internal/organization.py index 1dfecb3..f77ccf6 100644 --- a/src/cleanlab_codex/internal/organization.py +++ b/src/cleanlab_codex/internal/organization.py @@ -9,4 +9,6 @@ def list_organizations(client: _Codex) -> list[Organization]: - return [Organization.model_validate(org) for org in client.users.myself.organizations.list().organizations] + return [ + Organization.model_validate(org.model_dump()) for org in client.users.myself.organizations.list().organizations + ] diff --git a/src/cleanlab_codex/internal/project.py b/src/cleanlab_codex/internal/project.py index 8ef8ad2..ac15e7e 100644 --- a/src/cleanlab_codex/internal/project.py +++ b/src/cleanlab_codex/internal/project.py @@ -25,14 +25,16 @@ def query_project( ) -> tuple[Optional[str], Optional[Entry]]: maybe_entry = client.projects.entries.query(project_id, question=question) if maybe_entry is not None: - entry = Entry.model_validate(maybe_entry) + entry = Entry.model_validate(maybe_entry.model_dump()) if entry.answer is not None: return entry.answer, entry return fallback_answer, entry if not read_only: - created_entry = Entry.model_validate(client.projects.entries.add_question(project_id, question=question)) + created_entry = Entry.model_validate( + client.projects.entries.add_question(project_id, question=question).model_dump() + ) return fallback_answer, created_entry return fallback_answer, None diff --git a/tests/test_client.py b/tests/test_client.py index 9a8870d..be1eddc 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,11 +8,11 @@ from codex import AuthenticationError from codex.types.project_return_schema import Config as ProjectReturnConfig from codex.types.project_return_schema import ProjectReturnSchema +from codex.types.users.myself.user_organizations_schema import Organization as SDKOrganization from codex.types.users.myself.user_organizations_schema import UserOrganizationsSchema from cleanlab_codex.client import Client from cleanlab_codex.project import MissingProjectError -from cleanlab_codex.types.organization import Organization from cleanlab_codex.types.project import ProjectConfig FAKE_PROJECT_ID = str(uuid.uuid4()) @@ -29,7 +29,7 @@ def test_client_uses_default_organization(mock_client_from_api_key: MagicMock) - default_org_id = "default-org-id" mock_client_from_api_key.users.myself.organizations.list.return_value = UserOrganizationsSchema( organizations=[ - Organization( + SDKOrganization( organization_id=default_org_id, created_at=datetime.now(), updated_at=datetime.now(), @@ -98,7 +98,7 @@ def test_get_project_not_found(mock_client_from_api_key: MagicMock) -> None: def test_list_organizations(mock_client_from_api_key: MagicMock) -> None: mock_client_from_api_key.users.myself.organizations.list.return_value = UserOrganizationsSchema( organizations=[ - Organization( + SDKOrganization( organization_id=FAKE_ORGANIZATION_ID, created_at=datetime.now(), updated_at=datetime.now(), diff --git a/tests/test_project.py b/tests/test_project.py index 6237911..b70e99a 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -6,9 +6,10 @@ from codex import AuthenticationError from codex.types.project_create_params import Config from codex.types.projects.access_key_retrieve_project_id_response import AccessKeyRetrieveProjectIDResponse +from codex.types.projects.entry import Entry as SDKEntry from cleanlab_codex.project import MissingProjectError, Project -from cleanlab_codex.types.entry import Entry, EntryCreate +from cleanlab_codex.types.entry import EntryCreate FAKE_PROJECT_ID = str(uuid.uuid4()) FAKE_USER_ID = "Test User" @@ -138,11 +139,12 @@ def test_query_read_only(mock_client_from_access_key: MagicMock) -> None: FAKE_PROJECT_ID, question="What is the capital of France?" ) mock_client_from_access_key.projects.entries.add_question.assert_not_called() - assert res == (None, None) + assert res[0] is None + assert res[1] is None def test_query_question_found_fallback_answer(mock_client_from_access_key: MagicMock) -> None: - unanswered_entry = Entry( + unanswered_entry = SDKEntry( id=str(uuid.uuid4()), created_at=datetime.now(tz=timezone.utc), question="What is the capital of France?", @@ -151,22 +153,32 @@ def test_query_question_found_fallback_answer(mock_client_from_access_key: Magic mock_client_from_access_key.projects.entries.query.return_value = unanswered_entry project = Project(mock_client_from_access_key, FAKE_PROJECT_ID) res = project.query("What is the capital of France?") - assert res == (None, unanswered_entry) + assert res[0] is None + assert res[1] is not None + assert res[1].model_dump() == unanswered_entry.model_dump() def test_query_question_not_found_fallback_answer(mock_client_from_access_key: MagicMock) -> None: mock_client_from_access_key.projects.entries.query.return_value = None - mock_client_from_access_key.projects.entries.add_question.return_value = MagicMock(spec=Entry) + mock_entry = SDKEntry( + id="fake-id", + created_at=datetime.now(tz=timezone.utc), + question="What is the capital of France?", + answer=None, + ) + mock_client_from_access_key.projects.entries.add_question.return_value = mock_entry project = Project(mock_client_from_access_key, FAKE_PROJECT_ID) res = project.query("What is the capital of France?", fallback_answer="Paris") assert res[0] == "Paris" + assert res[1] is not None + assert res[1].model_dump() == mock_entry.model_dump() def test_query_add_question_when_not_found(mock_client_from_access_key: MagicMock) -> None: """Test that query adds question when not found and not read_only""" mock_client_from_access_key.projects.entries.query.return_value = None - new_entry = Entry( + new_entry = SDKEntry( id=str(uuid.uuid4()), created_at=datetime.now(tz=timezone.utc), question="What is the capital of France?", @@ -180,11 +192,13 @@ def test_query_add_question_when_not_found(mock_client_from_access_key: MagicMoc mock_client_from_access_key.projects.entries.add_question.assert_called_once_with( FAKE_PROJECT_ID, question="What is the capital of France?" ) - assert res == (None, new_entry) + assert res[0] is None + assert res[1] is not None + assert res[1].model_dump() == new_entry.model_dump() def test_query_answer_found(mock_client_from_access_key: MagicMock) -> None: - answered_entry = Entry( + answered_entry = SDKEntry( id=str(uuid.uuid4()), created_at=datetime.now(tz=timezone.utc), question="What is the capital of France?", @@ -193,7 +207,9 @@ def test_query_answer_found(mock_client_from_access_key: MagicMock) -> None: mock_client_from_access_key.projects.entries.query.return_value = answered_entry project = Project(mock_client_from_access_key, FAKE_PROJECT_ID) res = project.query("What is the capital of France?") - assert res == ("Paris", answered_entry) + assert res[0] == answered_entry.answer + assert res[1] is not None + assert res[1].model_dump() == answered_entry.model_dump() def test_add_entries_empty_list(mock_client_from_access_key: MagicMock) -> None: