Skip to content

Commit

Permalink
fix model validation in pydantic; (#27)
Browse files Browse the repository at this point in the history
Co-authored-by: Angela <[email protected]>
  • Loading branch information
aditya1503 and axl1313 authored Feb 7, 2025
1 parent 705b9d3 commit b79317e
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 16 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed

- Pydantic model validation error when querying Project and listing Organizations.

## [0.0.1a3] - 2025-02-06

### Added
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
]
dependencies = [
"codex-sdk==0.1.0a9",
"pydantic>=1.9.0, <3",
"pydantic>=2.0.0, <3",
]

[project.urls]
Expand Down
4 changes: 3 additions & 1 deletion src/cleanlab_codex/internal/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@


def list_organizations(client: _Codex) -> list[Organization]:
return [Organization.model_validate(org) for org in client.users.myself.organizations.list().organizations]
return [
Organization.model_validate(org.model_dump()) for org in client.users.myself.organizations.list().organizations
]
6 changes: 4 additions & 2 deletions src/cleanlab_codex/internal/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@ def query_project(
) -> tuple[Optional[str], Optional[Entry]]:
maybe_entry = client.projects.entries.query(project_id, question=question)
if maybe_entry is not None:
entry = Entry.model_validate(maybe_entry)
entry = Entry.model_validate(maybe_entry.model_dump())
if entry.answer is not None:
return entry.answer, entry

return fallback_answer, entry

if not read_only:
created_entry = Entry.model_validate(client.projects.entries.add_question(project_id, question=question))
created_entry = Entry.model_validate(
client.projects.entries.add_question(project_id, question=question).model_dump()
)
return fallback_answer, created_entry

return fallback_answer, None
6 changes: 3 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from codex import AuthenticationError
from codex.types.project_return_schema import Config as ProjectReturnConfig
from codex.types.project_return_schema import ProjectReturnSchema
from codex.types.users.myself.user_organizations_schema import Organization as SDKOrganization
from codex.types.users.myself.user_organizations_schema import UserOrganizationsSchema

from cleanlab_codex.client import Client
from cleanlab_codex.project import MissingProjectError
from cleanlab_codex.types.organization import Organization
from cleanlab_codex.types.project import ProjectConfig

FAKE_PROJECT_ID = str(uuid.uuid4())
Expand All @@ -29,7 +29,7 @@ def test_client_uses_default_organization(mock_client_from_api_key: MagicMock) -
default_org_id = "default-org-id"
mock_client_from_api_key.users.myself.organizations.list.return_value = UserOrganizationsSchema(
organizations=[
Organization(
SDKOrganization(
organization_id=default_org_id,
created_at=datetime.now(),
updated_at=datetime.now(),
Expand Down Expand Up @@ -98,7 +98,7 @@ def test_get_project_not_found(mock_client_from_api_key: MagicMock) -> None:
def test_list_organizations(mock_client_from_api_key: MagicMock) -> None:
mock_client_from_api_key.users.myself.organizations.list.return_value = UserOrganizationsSchema(
organizations=[
Organization(
SDKOrganization(
organization_id=FAKE_ORGANIZATION_ID,
created_at=datetime.now(),
updated_at=datetime.now(),
Expand Down
34 changes: 25 additions & 9 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from codex import AuthenticationError
from codex.types.project_create_params import Config
from codex.types.projects.access_key_retrieve_project_id_response import AccessKeyRetrieveProjectIDResponse
from codex.types.projects.entry import Entry as SDKEntry

from cleanlab_codex.project import MissingProjectError, Project
from cleanlab_codex.types.entry import Entry, EntryCreate
from cleanlab_codex.types.entry import EntryCreate

FAKE_PROJECT_ID = str(uuid.uuid4())
FAKE_USER_ID = "Test User"
Expand Down Expand Up @@ -138,11 +139,12 @@ def test_query_read_only(mock_client_from_access_key: MagicMock) -> None:
FAKE_PROJECT_ID, question="What is the capital of France?"
)
mock_client_from_access_key.projects.entries.add_question.assert_not_called()
assert res == (None, None)
assert res[0] is None
assert res[1] is None


def test_query_question_found_fallback_answer(mock_client_from_access_key: MagicMock) -> None:
unanswered_entry = Entry(
unanswered_entry = SDKEntry(
id=str(uuid.uuid4()),
created_at=datetime.now(tz=timezone.utc),
question="What is the capital of France?",
Expand All @@ -151,22 +153,32 @@ def test_query_question_found_fallback_answer(mock_client_from_access_key: Magic
mock_client_from_access_key.projects.entries.query.return_value = unanswered_entry
project = Project(mock_client_from_access_key, FAKE_PROJECT_ID)
res = project.query("What is the capital of France?")
assert res == (None, unanswered_entry)
assert res[0] is None
assert res[1] is not None
assert res[1].model_dump() == unanswered_entry.model_dump()


def test_query_question_not_found_fallback_answer(mock_client_from_access_key: MagicMock) -> None:
mock_client_from_access_key.projects.entries.query.return_value = None
mock_client_from_access_key.projects.entries.add_question.return_value = MagicMock(spec=Entry)
mock_entry = SDKEntry(
id="fake-id",
created_at=datetime.now(tz=timezone.utc),
question="What is the capital of France?",
answer=None,
)
mock_client_from_access_key.projects.entries.add_question.return_value = mock_entry

project = Project(mock_client_from_access_key, FAKE_PROJECT_ID)
res = project.query("What is the capital of France?", fallback_answer="Paris")
assert res[0] == "Paris"
assert res[1] is not None
assert res[1].model_dump() == mock_entry.model_dump()


def test_query_add_question_when_not_found(mock_client_from_access_key: MagicMock) -> None:
"""Test that query adds question when not found and not read_only"""
mock_client_from_access_key.projects.entries.query.return_value = None
new_entry = Entry(
new_entry = SDKEntry(
id=str(uuid.uuid4()),
created_at=datetime.now(tz=timezone.utc),
question="What is the capital of France?",
Expand All @@ -180,11 +192,13 @@ def test_query_add_question_when_not_found(mock_client_from_access_key: MagicMoc
mock_client_from_access_key.projects.entries.add_question.assert_called_once_with(
FAKE_PROJECT_ID, question="What is the capital of France?"
)
assert res == (None, new_entry)
assert res[0] is None
assert res[1] is not None
assert res[1].model_dump() == new_entry.model_dump()


def test_query_answer_found(mock_client_from_access_key: MagicMock) -> None:
answered_entry = Entry(
answered_entry = SDKEntry(
id=str(uuid.uuid4()),
created_at=datetime.now(tz=timezone.utc),
question="What is the capital of France?",
Expand All @@ -193,7 +207,9 @@ def test_query_answer_found(mock_client_from_access_key: MagicMock) -> None:
mock_client_from_access_key.projects.entries.query.return_value = answered_entry
project = Project(mock_client_from_access_key, FAKE_PROJECT_ID)
res = project.query("What is the capital of France?")
assert res == ("Paris", answered_entry)
assert res[0] == answered_entry.answer
assert res[1] is not None
assert res[1].model_dump() == answered_entry.model_dump()


def test_add_entries_empty_list(mock_client_from_access_key: MagicMock) -> None:
Expand Down

0 comments on commit b79317e

Please sign in to comment.