Skip to content

Commit

Permalink
add tests for import errors
Browse files Browse the repository at this point in the history
  • Loading branch information
axl1313 committed Feb 7, 2025
1 parent 5249f33 commit 0495365
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/cleanlab_codex/codex_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def to_smolagents_tool(self) -> Any:
try:
from cleanlab_codex.utils.smolagents import CodexTool as SmolagentsCodexTool
except ImportError as e:
raise MissingDependencyError("smolagents", "https://github.com/huggingface/smolagents") from e
raise MissingDependencyError(e.name or "smolagents", "https://github.com/huggingface/smolagents") from e

return SmolagentsCodexTool(
query=self.query,
Expand Down
57 changes: 57 additions & 0 deletions tests/test_codex_tool.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
import builtins
import importlib
import sys
from typing import Any
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.tools.structured import StructuredTool
from llama_index.core.tools import FunctionTool

from cleanlab_codex.codex_tool import CodexTool
from cleanlab_codex.utils.errors import MissingDependencyError


def patch_import_with_import_error(missing_module: str) -> None:
def custom_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name.startswith(missing_module):
raise ImportError("test", name=missing_module)
return importlib.__import__(name, *args, **kwargs)

builtins.__import__ = custom_import


def test_to_openai_tool(mock_client_from_access_key: MagicMock) -> None:
Expand All @@ -32,6 +45,20 @@ def test_to_llamaindex_tool(mock_client_from_access_key: MagicMock) -> None:
assert llama_index_tool.fn == tool.query


def test_to_llamaindex_tool_import_error(
mock_client_from_access_key: MagicMock,
) -> None:
with patch("cleanlab_codex.codex_tool.Project") as mock_project:
mock_project.from_access_key.return_value = MagicMock(client=mock_client_from_access_key, id="test_project_id")

tool = CodexTool.from_access_key("sk-test-123")
patch_import_with_import_error("llama_index")
with pytest.raises(MissingDependencyError) as exc_info:
tool.to_llamaindex_tool()

assert exc_info.value.import_name == "llama_index"


def test_to_langchain_tool(mock_client_from_access_key: MagicMock) -> None:
with patch("cleanlab_codex.codex_tool.Project") as mock_project:
mock_project.from_access_key.return_value = MagicMock(client=mock_client_from_access_key, id="test_project_id")
Expand All @@ -50,6 +77,18 @@ def test_to_langchain_tool(mock_client_from_access_key: MagicMock) -> None:
), f"Expected description '{tool.tool_description}', got '{langchain_tool.description}'."


def test_to_langchain_tool_import_error(mock_client_from_access_key: MagicMock) -> None:
with patch("cleanlab_codex.codex_tool.Project") as mock_project:
mock_project.from_access_key.return_value = MagicMock(client=mock_client_from_access_key, id="test_project_id")

tool = CodexTool.from_access_key("sk-test-123")
patch_import_with_import_error("langchain")
with pytest.raises(MissingDependencyError) as exc_info:
tool.to_langchain_tool()

assert exc_info.value.import_name == "langchain"


def test_to_aws_converse_tool(mock_client_from_access_key: MagicMock) -> None:
with patch("cleanlab_codex.codex_tool.Project") as mock_project:
mock_project.from_access_key.return_value = MagicMock(client=mock_client_from_access_key, id="test_project_id")
Expand Down Expand Up @@ -95,3 +134,21 @@ def test_to_smolagents_tool(mock_client_from_access_key: MagicMock) -> None:
assert isinstance(smolagents_tool, Tool)
assert smolagents_tool.name == tool.tool_name
assert smolagents_tool.description == tool.tool_description


def test_to_smolagents_tool_import_error(
mock_client_from_access_key: MagicMock,
) -> None:
with patch("cleanlab_codex.codex_tool.Project") as mock_project:
mock_project.from_access_key.return_value = MagicMock(client=mock_client_from_access_key, id="test_project_id")

tool = CodexTool.from_access_key("sk-test-123")
import_module_name = "smolagents"
if sys.version_info >= (3, 10):
import_module_name = "cleanlab_codex.utils.smolagents"
patch_import_with_import_error(import_module_name)

with pytest.raises(MissingDependencyError) as exc_info:
tool.to_smolagents_tool()

assert exc_info.value.import_name == import_module_name

0 comments on commit 0495365

Please sign in to comment.