-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
238 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,11 @@ | ||
name: CI | ||
on: | ||
push: | ||
branches: | ||
- main | ||
pull_request: | ||
branches: | ||
- main | ||
|
||
jobs: | ||
typecheck: | ||
|
@@ -27,3 +31,20 @@ jobs: | |
python-version: "3.13" | ||
- uses: pypa/hatch@install | ||
- run: hatch fmt --check | ||
test: | ||
name: Test | ||
runs-on: ubuntu-22.04 | ||
strategy: | ||
matrix: | ||
python: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: actions/setup-python@v5 | ||
with: | ||
python-version: ${{ matrix.python }} | ||
- uses: pypa/hatch@install | ||
# TODO: remove after we release codex-sdk package (and are no longer installing from github) | ||
- name: setup git url rewrite | ||
run: git config --global url."https://${{ secrets.GH_USERNAME }}:${{ secrets.CLEANLAB_BOT_PAT }}@github.com".insteadOf ssh://[email protected] | ||
- run: hatch test -v --cover --include python=$(echo ${{ matrix.python }}) | ||
- run: hatch run coverage:report |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from tests.fixtures.client import mock_client | ||
|
||
__all__ = ["mock_client"] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from typing import Generator | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
|
||
|
||
@pytest.fixture | ||
def mock_client() -> Generator[MagicMock, None, None]: | ||
with patch("cleanlab_codex.codex.init_codex_client") as mock_init: | ||
mock_client = MagicMock() | ||
mock_init.return_value = mock_client | ||
yield mock_client |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,150 @@ | ||
from typing import Generator | ||
from unittest.mock import MagicMock, patch | ||
# ruff: noqa: DTZ005 | ||
|
||
import uuid | ||
from datetime import datetime | ||
from unittest.mock import MagicMock | ||
|
||
import pytest | ||
from codex import Codex as _Codex | ||
from codex.types.project_return_schema import Config, ProjectReturnSchema | ||
from codex.types.users.myself.user_organizations_schema import UserOrganizationsSchema | ||
|
||
from cleanlab_codex.codex import Codex | ||
from cleanlab_codex.internal.project import MissingProjectIdError | ||
from cleanlab_codex.types.entry import Entry, EntryCreate | ||
from cleanlab_codex.types.organization import Organization | ||
from cleanlab_codex.types.project import ProjectConfig | ||
|
||
FAKE_PROJECT_ID = 1 | ||
FAKE_USER_ID = "Test User" | ||
FAKE_ORGANIZATION_ID = "Test Organization" | ||
FAKE_PROJECT_NAME = "Test Project" | ||
FAKE_PROJECT_DESCRIPTION = "Test Description" | ||
DEFAULT_PROJECT_CONFIG = ProjectConfig() | ||
|
||
|
||
def test_list_organizations(mock_client: MagicMock): | ||
mock_client.users.myself.organizations.list.return_value = UserOrganizationsSchema( | ||
organizations=[ | ||
Organization( | ||
organization_id=FAKE_ORGANIZATION_ID, | ||
created_at=datetime.now(), | ||
updated_at=datetime.now(), | ||
user_id=FAKE_USER_ID, | ||
) | ||
], | ||
) | ||
codex = Codex("") | ||
organizations = codex.list_organizations() | ||
assert len(organizations) == 1 | ||
assert organizations[0].organization_id == FAKE_ORGANIZATION_ID | ||
assert organizations[0].user_id == FAKE_USER_ID | ||
|
||
|
||
def test_create_project(mock_client: MagicMock): | ||
mock_client.projects.create.return_value = ProjectReturnSchema( | ||
id=FAKE_PROJECT_ID, | ||
config=Config(), | ||
created_at=datetime.now(), | ||
created_by_user_id=FAKE_USER_ID, | ||
name=FAKE_PROJECT_NAME, | ||
organization_id=FAKE_ORGANIZATION_ID, | ||
updated_at=datetime.now(), | ||
description=FAKE_PROJECT_DESCRIPTION, | ||
) | ||
codex = Codex("") | ||
project_id = codex.create_project(FAKE_PROJECT_NAME, FAKE_ORGANIZATION_ID, FAKE_PROJECT_DESCRIPTION) | ||
mock_client.projects.create.assert_called_once_with( | ||
config=DEFAULT_PROJECT_CONFIG, | ||
organization_id=FAKE_ORGANIZATION_ID, | ||
name=FAKE_PROJECT_NAME, | ||
description=FAKE_PROJECT_DESCRIPTION, | ||
) | ||
assert project_id == FAKE_PROJECT_ID | ||
|
||
fake_project_id = 1 | ||
|
||
def test_add_entries(mock_client: MagicMock): | ||
answered_entry_create = EntryCreate( | ||
question="What is the capital of France?", | ||
answer="Paris", | ||
) | ||
unanswered_entry_create = EntryCreate( | ||
question="What is the capital of Germany?", | ||
) | ||
codex = Codex("") | ||
codex.add_entries([answered_entry_create, unanswered_entry_create], project_id=FAKE_PROJECT_ID) | ||
|
||
@pytest.fixture | ||
def mock_client() -> Generator[_Codex, None, None]: | ||
with patch("cleanlab_codex.codex.init_codex_client", return_value=MagicMock()) as mock: | ||
yield mock | ||
for call, entry in zip( | ||
mock_client.projects.entries.create.call_args_list, | ||
[answered_entry_create, unanswered_entry_create], | ||
): | ||
assert call.args[0] == FAKE_PROJECT_ID | ||
assert call.kwargs["question"] == entry["question"] | ||
assert call.kwargs["answer"] == entry.get("answer") | ||
|
||
|
||
def test_query_read_only(mock_client: _Codex): | ||
mock_client.projects.entries.query.return_value = None # type: ignore | ||
def test_create_project_access_key(mock_client: MagicMock): | ||
codex = Codex("") | ||
res = codex.query("What is the capital of France?", read_only=True, project_id=fake_project_id) | ||
mock_client.projects.entries.query.assert_called_once_with( # type: ignore | ||
fake_project_id, "What is the capital of France?" | ||
access_key_name = "Test Access Key" | ||
access_key_description = "Test Access Key Description" | ||
codex.create_project_access_key(FAKE_PROJECT_ID, access_key_name, access_key_description) | ||
mock_client.projects.access_keys.create.assert_called_once_with( | ||
project_id=FAKE_PROJECT_ID, | ||
name=access_key_name, | ||
description=access_key_description, | ||
) | ||
mock_client.projects.entries.add_question.assert_not_called() # type: ignore | ||
|
||
|
||
def test_query_no_project_id(mock_client: MagicMock): | ||
mock_client.access_key = None | ||
codex = Codex("") | ||
|
||
with pytest.raises(MissingProjectIdError): | ||
codex.query("What is the capital of France?") | ||
|
||
|
||
def test_query_read_only(mock_client: MagicMock): | ||
mock_client.access_key = None | ||
mock_client.projects.entries.query.return_value = None | ||
|
||
codex = Codex("") | ||
res = codex.query("What is the capital of France?", read_only=True, project_id=FAKE_PROJECT_ID) | ||
mock_client.projects.entries.query.assert_called_once_with( | ||
FAKE_PROJECT_ID, question="What is the capital of France?" | ||
) | ||
mock_client.projects.entries.add_question.assert_not_called() | ||
assert res == (None, None) | ||
|
||
|
||
def test_query_question_found_fallback_answer(mock_client: MagicMock): | ||
unanswered_entry = Entry( | ||
id=str(uuid.uuid4()), | ||
created_at=datetime.now(), | ||
question="What is the capital of France?", | ||
answer=None, | ||
) | ||
mock_client.projects.entries.query.return_value = unanswered_entry | ||
codex = Codex("") | ||
res = codex.query("What is the capital of France?", project_id=FAKE_PROJECT_ID) | ||
assert res == (None, unanswered_entry) | ||
|
||
|
||
def test_query_question_not_found_fallback_answer(mock_client: MagicMock): | ||
mock_client.projects.entries.query.return_value = None | ||
mock_client.projects.entries.add_question.return_value = None | ||
|
||
codex = Codex("") | ||
res = codex.query("What is the capital of France?", fallback_answer="Paris") | ||
assert res == ("Paris", None) | ||
|
||
|
||
def test_query_answer_found(mock_client: MagicMock): | ||
answered_entry = Entry( | ||
id=str(uuid.uuid4()), | ||
created_at=datetime.now(), | ||
question="What is the capital of France?", | ||
answer="Paris", | ||
) | ||
mock_client.projects.entries.query.return_value = answered_entry | ||
codex = Codex("") | ||
res = codex.query("What is the capital of France?", project_id=FAKE_PROJECT_ID) | ||
assert res == ("Paris", answered_entry) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from unittest.mock import MagicMock | ||
|
||
from llama_index.core.tools import FunctionTool | ||
|
||
from cleanlab_codex.codex_tool import CodexTool | ||
|
||
|
||
def test_to_openai_tool(mock_client: MagicMock): # noqa: ARG001 | ||
tool = CodexTool.from_access_key("") | ||
openai_tool = tool.to_openai_tool() | ||
assert openai_tool.get("type") == "function" | ||
assert openai_tool.get("function", {}).get("name") == tool.tool_name | ||
assert openai_tool.get("function", {}).get("description") == tool.tool_description | ||
assert openai_tool.get("function", {}).get("parameters", {}).get("type") == "object" | ||
|
||
|
||
def test_to_llamaindex_tool(mock_client: MagicMock): # noqa: ARG001 | ||
tool = CodexTool.from_access_key("") | ||
llama_index_tool = tool.to_llamaindex_tool() | ||
assert isinstance(llama_index_tool, FunctionTool) | ||
assert llama_index_tool.metadata.name == tool.tool_name | ||
assert llama_index_tool.metadata.description == tool.tool_description | ||
assert llama_index_tool.fn == tool.query |