[Codex Integrations] Add langchain + AWS integrations (#17)
ulya-tkch authored Feb 6, 2025
1 parent 7880a61 commit 4b02d57
Showing 8 changed files with 195 additions and 15 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -43,6 +43,7 @@ extra-dependencies = [
"pytest",
"llama-index-core",
"smolagents",
"langchain-core",
]
[tool.hatch.envs.types.scripts]
check = "mypy --strict --install-types --non-interactive {args:src/cleanlab_codex tests}"
@@ -53,6 +54,7 @@ allow-direct-references = true
extra-dependencies = [
"llama-index-core",
"smolagents; python_version >= '3.10'",
"langchain-core",
]

[tool.hatch.envs.hatch-test.env-vars]
34 changes: 34 additions & 0 deletions src/cleanlab_codex/codex_tool.py
@@ -111,3 +111,37 @@ def to_llamaindex_tool(self) -> Any:
tool_properties=self._tool_properties,
),
)

def to_langchain_tool(self) -> Any:
"""Converts the tool to a [LangChain tool](https://python.langchain.com/docs/concepts/tools/).
Note: You must have the [`langchain-core` library installed](https://python.langchain.com/docs/concepts/architecture/) to use this method.
"""
from langchain_core.tools.structured import StructuredTool

from cleanlab_codex.utils.langchain import create_args_schema

return StructuredTool.from_function(
func=self.query,
name=self._tool_name,
description=self._tool_description,
args_schema=create_args_schema(
name=self._tool_name,
func=self.query,
tool_properties=self._tool_properties,
),
)

def to_aws_converse_tool(self) -> Any:
"""Converts the tool to an [AWS Converse API tool](https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use-inference-call.html).
Note: You must have the [`boto3` library installed](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html) (AWS SDK for Python) to use this method.
"""
from cleanlab_codex.utils.aws import format_as_aws_converse_tool

return format_as_aws_converse_tool(
tool_name=self._tool_name,
tool_description=self._tool_description,
tool_properties=self._tool_properties,
required_properties=self._tool_requirements,
)
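
For context, the converted LangChain tool can be bound to any chat model that supports tool calling. The sketch below assumes the optional `langchain-openai` package and a placeholder access key; neither is added by this commit.

from langchain_openai import ChatOpenAI  # assumed extra, not added by this commit

from cleanlab_codex.codex_tool import CodexTool

# Placeholder access key; replace with a real Codex project access key.
codex_tool = CodexTool.from_access_key("sk-test-123")
langchain_tool = codex_tool.to_langchain_tool()

# Bind the resulting StructuredTool to a tool-calling chat model.
llm_with_tool = ChatOpenAI(model="gpt-4o-mini").bind_tools([langchain_tool])
response = llm_with_tool.invoke("Consult the Codex tool: when will my order arrive?")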
15 changes: 13 additions & 2 deletions src/cleanlab_codex/utils/__init__.py
@@ -1,6 +1,17 @@
from cleanlab_codex.utils.aws import Tool as AWSConverseTool
from cleanlab_codex.utils.aws import ToolSpec as AWSToolSpec
from cleanlab_codex.utils.aws import format_as_aws_converse_tool
from cleanlab_codex.utils.openai import Function as OpenAIFunction
from cleanlab_codex.utils.openai import FunctionParameters as OpenAIFunctionParameters
from cleanlab_codex.utils.openai import Tool as OpenAITool
from cleanlab_codex.utils.openai import format_as_openai_tool
from cleanlab_codex.utils.types import FunctionParameters

__all__ = ["OpenAIFunction", "OpenAIFunctionParameters", "OpenAITool", "format_as_openai_tool"]
__all__ = [
"FunctionParameters",
"OpenAIFunction",
"OpenAITool",
"AWSConverseTool",
"AWSToolSpec",
"format_as_openai_tool",
"format_as_aws_converse_tool",
]
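
With the widened `__all__`, the new AWS helpers are importable directly from the utils package, e.g.:

from cleanlab_codex.utils import AWSConverseTool, AWSToolSpec, format_as_aws_converse_tool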
32 changes: 32 additions & 0 deletions src/cleanlab_codex/utils/aws.py
@@ -0,0 +1,32 @@
from __future__ import annotations

from typing import Any, Dict, List

from pydantic import BaseModel, Field

from cleanlab_codex.utils.types import FunctionParameters


class ToolSpec(BaseModel):
name: str
description: str
input_schema: Dict[str, FunctionParameters] = Field(alias="inputSchema")


class Tool(BaseModel):
tool_spec: ToolSpec = Field(alias="toolSpec")


def format_as_aws_converse_tool(
tool_name: str,
tool_description: str,
tool_properties: Dict[str, Any],
required_properties: List[str],
) -> Dict[str, Any]:
return Tool(
toolSpec=ToolSpec(
name=tool_name,
description=tool_description,
inputSchema={"json": FunctionParameters(properties=tool_properties, required=required_properties)},
)
).model_dump(by_alias=True)
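
A sketch of how the formatted spec plugs into the Bedrock Converse API via `boto3`; the model ID, message, and tool metadata below are placeholders and assume AWS credentials are already configured.

import boto3

from cleanlab_codex.utils.aws import format_as_aws_converse_tool

tool_spec = format_as_aws_converse_tool(
    tool_name="ask_advisor",
    tool_description="Ask an outside advisor when unsure how to answer.",
    tool_properties={"question": {"type": "string", "description": "The question to ask."}},
    required_properties=["question"],
)

client = boto3.client("bedrock-runtime")
response = client.converse(
    modelId="anthropic.claude-3-haiku-20240307-v1:0",  # placeholder model ID
    messages=[{"role": "user", "content": [{"text": "When will my order arrive?"}]}],
    # The Converse API expects tools under toolConfig["tools"], each wrapped in "toolSpec".
    toolConfig={"tools": [tool_spec]},
)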
44 changes: 44 additions & 0 deletions src/cleanlab_codex/utils/langchain.py
@@ -0,0 +1,44 @@
from __future__ import annotations

from inspect import signature
from typing import Any, Callable, Dict

from pydantic import BaseModel, Field


def create_args_schema(name: str, func: Callable[..., Any], tool_properties: Dict[str, Any]) -> type[BaseModel]:
"""
Creates a pydantic BaseModel for langchain's args_schema.
Args:
name: Name of the schema.
func: The function for which the schema is being generated.
tool_properties: Metadata about each argument.
Returns:
type[BaseModel]: A pydantic model, annotated as required by langchain.
"""
fields = {}
params = signature(func).parameters

for param_name, param in params.items():
param_type = param.annotation if param.annotation is not param.empty else Any
param_default = param.default
description = tool_properties.get(param_name, {}).get("description", None)

if param_default is param.empty:
fields[param_name] = (param_type, Field(description=description))
else:
fields[param_name] = (
param_type,
Field(default=param_default, description=description),
)

return type(
name,
(BaseModel,),
{
"__annotations__": {k: v[0] for k, v in fields.items()},
**{k: v[1] for k, v in fields.items()},
},
)
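
To show what `create_args_schema` produces, a small sketch with a hypothetical function (the function and field names are illustrative only):

from cleanlab_codex.utils.langchain import create_args_schema


def lookup(question: str) -> str:
    return f"Answer to: {question}"


args_schema = create_args_schema(
    name="LookupArgs",
    func=lookup,
    tool_properties={"question": {"description": "The question to look up."}},
)

# One required pydantic field is generated per function parameter.
print(args_schema.model_json_schema())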
17 changes: 4 additions & 13 deletions src/cleanlab_codex/utils/openai.py
@@ -4,16 +4,12 @@

from pydantic import BaseModel


class Property(BaseModel):
type: Literal["string", "number", "integer", "boolean", "array", "object"]
description: str
from cleanlab_codex.utils.types import FunctionParameters


class FunctionParameters(BaseModel):
type: Literal["object"] = "object"
properties: Dict[str, Property]
required: List[str]
class Tool(BaseModel):
type: Literal["function"] = "function"
function: Function


class Function(BaseModel):
@@ -22,11 +18,6 @@ class Function(BaseModel):
parameters: FunctionParameters


class Tool(BaseModel):
type: Literal["function"] = "function"
function: Function


def format_as_openai_tool(
tool_name: str,
tool_description: str,
14 changes: 14 additions & 0 deletions src/cleanlab_codex/utils/types.py
@@ -0,0 +1,14 @@
from typing import Dict, List, Literal

from pydantic import BaseModel


class Property(BaseModel):
type: Literal["string", "number", "integer", "boolean", "array", "object"]
description: str


class FunctionParameters(BaseModel):
type: Literal["object"] = "object"
properties: Dict[str, Property]
required: List[str]
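
For reference, these shared types compose into the JSON-schema-style parameter block consumed by both the OpenAI and AWS formatters (values below are illustrative):

from cleanlab_codex.utils.types import FunctionParameters, Property

params = FunctionParameters(
    properties={"question": Property(type="string", description="The question to ask.")},
    required=["question"],
)
print(params.model_dump())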
52 changes: 52 additions & 0 deletions tests/test_codex_tool.py
@@ -2,6 +2,7 @@
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.tools.structured import StructuredTool
from llama_index.core.tools import FunctionTool

from cleanlab_codex.codex_tool import CodexTool
@@ -31,6 +32,57 @@ def test_to_llamaindex_tool(mock_client_from_access_key: MagicMock) -> None:
assert llama_index_tool.fn == tool.query


def test_to_langchain_tool(mock_client_from_access_key: MagicMock) -> None:
with patch("cleanlab_codex.codex_tool.Project") as mock_project:
mock_project.from_access_key.return_value = MagicMock(client=mock_client_from_access_key, id="test_project_id")

tool = CodexTool.from_access_key("sk-test-123")
langchain_tool = tool.to_langchain_tool()
assert isinstance(langchain_tool, StructuredTool)
assert callable(langchain_tool)
assert hasattr(langchain_tool, "name")
assert hasattr(langchain_tool, "description")
assert (
langchain_tool.name == tool.tool_name
), f"Expected tool name '{tool.tool_name}', got '{langchain_tool.name}'."
assert (
langchain_tool.description == tool.tool_description
), f"Expected description '{tool.tool_description}', got '{langchain_tool.description}'."


def test_to_aws_converse_tool(mock_client_from_access_key: MagicMock) -> None:
with patch("cleanlab_codex.codex_tool.Project") as mock_project:
mock_project.from_access_key.return_value = MagicMock(client=mock_client_from_access_key, id="test_project_id")

tool = CodexTool.from_access_key("sk-test-123")
aws_converse_tool = tool.to_aws_converse_tool()

assert "toolSpec" in aws_converse_tool
assert (
aws_converse_tool["toolSpec"].get("name") == tool.tool_name
), f"Expected name '{tool.tool_name}', got '{aws_converse_tool['toolSpec'].get('name')}'"
assert (
aws_converse_tool["toolSpec"].get("description") == tool.tool_description
), f"Expected description '{tool.tool_description}', got '{aws_converse_tool['toolSpec'].get('description')}'"
assert "inputSchema" in aws_converse_tool["toolSpec"], "inputSchema key is missing in toolSpec"

input_schema = aws_converse_tool["toolSpec"]["inputSchema"]
assert "json" in input_schema

json_schema = input_schema["json"]
assert json_schema.get("type") == "object"
assert "properties" in json_schema

properties = json_schema["properties"]
assert "question" in properties

question_property = properties["question"]
assert question_property.get("type") == "string"
assert "description" in question_property
assert "required" in json_schema
assert "question" in json_schema["required"]


@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
def test_to_smolagents_tool(mock_client_from_access_key: MagicMock) -> None:
from smolagents import Tool # type: ignore
