Skip to content

Commit

Permalink
distinguish between import and package name in error (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
axl1313 authored Feb 7, 2025
1 parent 31709e3 commit 705b9d3
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 6 deletions.
13 changes: 9 additions & 4 deletions src/cleanlab_codex/codex_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def to_smolagents_tool(self) -> Any:
try:
from cleanlab_codex.utils.smolagents import CodexTool as SmolagentsCodexTool
except ImportError as e:
raise MissingDependencyError("smolagents", "https://github.com/huggingface/smolagents") from e
raise MissingDependencyError(e.name or "smolagents", "https://github.com/huggingface/smolagents") from e

return SmolagentsCodexTool(
query=self.query,
Expand All @@ -145,8 +145,9 @@ def to_llamaindex_tool(self) -> Any:

except ImportError as e:
raise MissingDependencyError(
"llama-index-core",
"https://docs.llamaindex.ai/en/stable/getting_started/installation/",
import_name=e.name or "llama_index",
package_name="llama-index-core",
package_url="https://docs.llamaindex.ai/en/stable/getting_started/installation/",
) from e

return FunctionTool.from_defaults(
Expand All @@ -165,7 +166,11 @@ def to_langchain_tool(self) -> Any:
from langchain_core.tools.structured import StructuredTool

except ImportError as e:
raise MissingDependencyError("langchain", "https://pypi.org/project/langchain/") from e
raise MissingDependencyError(
import_name=e.name or "langchain",
package_name="langchain",
package_url="https://pypi.org/project/langchain/",
) from e

return StructuredTool.from_function(
func=self.query,
Expand Down
11 changes: 9 additions & 2 deletions src/cleanlab_codex/utils/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,19 @@
class MissingDependencyError(Exception):
"""Raised when a lazy import is missing."""

def __init__(self, import_name: str, package_url: str | None = None) -> None:
def __init__(self, import_name: str, package_name: str | None = None, package_url: str | None = None) -> None:
"""
Args:
import_name: The name of the import that failed.
package_name: The name of the package to install.
package_url: The URL for more information about the package.
"""
self.import_name = import_name
self.package_name = package_name
self.package_url = package_url

def __str__(self) -> str:
message = f"Failed to import {self.import_name}. Please install the package using `pip install {self.import_name}` and try again."
message = f"Failed to import {self.import_name}. Please install the package using `pip install {self.package_name or self.import_name}` and try again."
if self.package_url:
message += f" For more information, see {self.package_url}."
return message
57 changes: 57 additions & 0 deletions tests/test_codex_tool.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
import builtins
import importlib
import sys
from typing import Any
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.tools.structured import StructuredTool
from llama_index.core.tools import FunctionTool

from cleanlab_codex.codex_tool import CodexTool
from cleanlab_codex.utils.errors import MissingDependencyError


def patch_import_with_import_error(missing_module: str) -> None:
def custom_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name.startswith(missing_module):
raise ImportError("test", name=missing_module)
return importlib.__import__(name, *args, **kwargs)

builtins.__import__ = custom_import


def test_to_openai_tool(mock_client_from_access_key: MagicMock) -> None:
Expand All @@ -32,6 +45,20 @@ def test_to_llamaindex_tool(mock_client_from_access_key: MagicMock) -> None:
assert llama_index_tool.fn == tool.query


def test_to_llamaindex_tool_import_error(
mock_client_from_access_key: MagicMock,
) -> None:
with patch("cleanlab_codex.codex_tool.Project") as mock_project:
mock_project.from_access_key.return_value = MagicMock(client=mock_client_from_access_key, id="test_project_id")

tool = CodexTool.from_access_key("sk-test-123")
patch_import_with_import_error("llama_index")
with pytest.raises(MissingDependencyError) as exc_info:
tool.to_llamaindex_tool()

assert exc_info.value.import_name == "llama_index"


def test_to_langchain_tool(mock_client_from_access_key: MagicMock) -> None:
with patch("cleanlab_codex.codex_tool.Project") as mock_project:
mock_project.from_access_key.return_value = MagicMock(client=mock_client_from_access_key, id="test_project_id")
Expand All @@ -50,6 +77,18 @@ def test_to_langchain_tool(mock_client_from_access_key: MagicMock) -> None:
), f"Expected description '{tool.tool_description}', got '{langchain_tool.description}'."


def test_to_langchain_tool_import_error(mock_client_from_access_key: MagicMock) -> None:
with patch("cleanlab_codex.codex_tool.Project") as mock_project:
mock_project.from_access_key.return_value = MagicMock(client=mock_client_from_access_key, id="test_project_id")

tool = CodexTool.from_access_key("sk-test-123")
patch_import_with_import_error("langchain")
with pytest.raises(MissingDependencyError) as exc_info:
tool.to_langchain_tool()

assert exc_info.value.import_name == "langchain"


def test_to_aws_converse_tool(mock_client_from_access_key: MagicMock) -> None:
with patch("cleanlab_codex.codex_tool.Project") as mock_project:
mock_project.from_access_key.return_value = MagicMock(client=mock_client_from_access_key, id="test_project_id")
Expand Down Expand Up @@ -95,3 +134,21 @@ def test_to_smolagents_tool(mock_client_from_access_key: MagicMock) -> None:
assert isinstance(smolagents_tool, Tool)
assert smolagents_tool.name == tool.tool_name
assert smolagents_tool.description == tool.tool_description


def test_to_smolagents_tool_import_error(
mock_client_from_access_key: MagicMock,
) -> None:
with patch("cleanlab_codex.codex_tool.Project") as mock_project:
mock_project.from_access_key.return_value = MagicMock(client=mock_client_from_access_key, id="test_project_id")

tool = CodexTool.from_access_key("sk-test-123")
import_module_name = "smolagents"
if sys.version_info >= (3, 10):
import_module_name = "cleanlab_codex.utils.smolagents"
patch_import_with_import_error(import_module_name)

with pytest.raises(MissingDependencyError) as exc_info:
tool.to_smolagents_tool()

assert exc_info.value.import_name == import_module_name

0 comments on commit 705b9d3

Please sign in to comment.