Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: auto extract tool properties from fn def #24

Merged
merged 2 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 19 additions & 26 deletions src/cleanlab_codex/codex_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,19 @@

from __future__ import annotations

from typing import Any, ClassVar, Optional
from typing import Any, Optional

from typing_extensions import Annotated

from cleanlab_codex.project import Project
from cleanlab_codex.utils.function import pydantic_model_from_function, required_properties_from_model


class CodexTool:
"""A tool that connects to a Codex project to answer questions."""

_tool_name = "ask_advisor"
_tool_description = "Asks an all-knowing advisor this query in cases where it cannot be answered from the provided Context. If the answer is available, this returns None."
_tool_properties: ClassVar[dict[str, Any]] = {
"question": {
"type": "string",
"description": "The question to ask the advisor. This should be the same as the original user question, except in cases where the user question is missing information that could be additionally clarified.",
}
}
_tool_requirements: ClassVar[list[str]] = ["question"]
_tool_description = "Asks an all-knowing advisor this query in cases where it cannot be answered from the provided Context. If the answer is unavailable, this returns None."
DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question."

def __init__(
Expand All @@ -29,6 +25,9 @@ def __init__(
):
self._project = project
self._fallback_answer = fallback_answer
self._tool_function_schema = pydantic_model_from_function(self._tool_name, self.query)
self._tool_properties = self._tool_function_schema.model_json_schema()["properties"]
self._tool_requirements = required_properties_from_model(self._tool_function_schema)

@classmethod
def from_access_key(
Expand Down Expand Up @@ -86,11 +85,17 @@ def fallback_answer(self, value: Optional[str]) -> None:
"""Sets the fallback answer to use if the Codex project cannot answer the question."""
self._fallback_answer = value

def query(self, question: str) -> Optional[str]:
"""Asks an all-knowing advisor this question in cases where it cannot be answered from the provided Context.
def query(
self,
question: Annotated[
str,
"The question to ask the advisor. This should be the same as the original user question, except in cases where the user question is missing information that could be additionally clarified.",
],
) -> Optional[str]:
"""Asks an all-knowing advisor this question in cases where it cannot be answered from the provided Context. If the answer is unavailable, this returns a fallback answer or None.

Args:
question: The question to ask the advisor. This should be the same as the original user question, except in cases where the user question is missing information that could be additionally clarified.
question (str): The question to ask the advisor. This should be the same as the original user question, except in cases where the user question is missing information that could be additionally clarified.

Returns:
The answer to the question if available. If no answer is available, this returns a fallback answer or None.
Expand Down Expand Up @@ -130,17 +135,11 @@ def to_llamaindex_tool(self) -> Any:
"""
from llama_index.core.tools import FunctionTool

from cleanlab_codex.utils.llamaindex import get_function_schema

return FunctionTool.from_defaults(
fn=self.query,
name=self._tool_name,
description=self._tool_description,
fn_schema=get_function_schema(
name=self._tool_name,
func=self.query,
tool_properties=self._tool_properties,
),
fn_schema=self._tool_function_schema,
)

def to_langchain_tool(self) -> Any:
Expand All @@ -150,17 +149,11 @@ def to_langchain_tool(self) -> Any:
"""
from langchain_core.tools.structured import StructuredTool

from cleanlab_codex.utils.langchain import create_args_schema

return StructuredTool.from_function(
func=self.query,
name=self._tool_name,
description=self._tool_description,
args_schema=create_args_schema(
name=self._tool_name,
func=self.query,
tool_properties=self._tool_properties,
),
args_schema=self._tool_function_schema,
)

def to_aws_converse_tool(self) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion src/cleanlab_codex/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from cleanlab_codex.utils.aws import Tool as AWSConverseTool
from cleanlab_codex.utils.aws import ToolSpec as AWSToolSpec
from cleanlab_codex.utils.aws import format_as_aws_converse_tool
from cleanlab_codex.utils.function import FunctionParameters
from cleanlab_codex.utils.openai import Function as OpenAIFunction
from cleanlab_codex.utils.openai import Tool as OpenAITool
from cleanlab_codex.utils.openai import format_as_openai_tool
from cleanlab_codex.utils.types import FunctionParameters

__all__ = [
"FunctionParameters",
Expand Down
9 changes: 7 additions & 2 deletions src/cleanlab_codex/utils/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel, Field

from cleanlab_codex.utils.types import FunctionParameters
from cleanlab_codex.utils.function import FunctionParameters


class ToolSpec(BaseModel):
Expand All @@ -27,6 +27,11 @@ def format_as_aws_converse_tool(
toolSpec=ToolSpec(
name=tool_name,
description=tool_description,
inputSchema={"json": FunctionParameters(properties=tool_properties, required=required_properties)},
inputSchema={
"json": FunctionParameters(
properties=tool_properties,
required=required_properties,
)
},
)
).model_dump(by_alias=True)
79 changes: 79 additions & 0 deletions src/cleanlab_codex/utils/function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from inspect import signature
from typing import Any, Callable, Dict, List, Literal, Type

from pydantic import BaseModel, Field, create_model
from typing_extensions import Annotated, get_args, get_origin


# Schema entry describing a single tool/function parameter.
# NOTE: documented with `#` comments rather than a class docstring, because
# pydantic copies a model's docstring into the generated JSON schema.
class Property(BaseModel):
    # Type name of the parameter; restricted to the six literal values below.
    type: Literal["string", "number", "integer", "boolean", "array", "object"]
    # Human-readable explanation of the parameter.
    description: str


# Schema describing the full set of parameters accepted by a tool/function.
# NOTE: documented with `#` comments rather than a class docstring, because
# pydantic copies a model's docstring into the generated JSON schema.
class FunctionParameters(BaseModel):
    # Always the literal "object"; callers normally rely on the default.
    type: Literal["object"] = "object"
    # Maps each parameter name to its `Property` description.
    properties: Dict[str, Property]
    # Names of the parameters that must be supplied.
    required: List[str]


def pydantic_model_from_function(name: str, func: Callable[..., Any]) -> Type[BaseModel]:
    """
    Create a pydantic model representing a function's schema.

    For example, a function with the following signature:

    ```python
    def my_function(
        a: Annotated[int, "This is an integer"], b: str = "default"
    ) -> None: ...
    ```

    will be represented by the following pydantic model when `name="my_function"`:

    ```python
    class my_function(BaseModel):
        a: int = Field(description="This is an integer")
        b: str = "default"
    ```

    Parameters without an annotation are typed as `Any`. When a parameter is
    annotated with `Annotated[T, meta, ...]` and the first metadata item is a
    string, that string becomes the field description.

    String annotations (for example those produced by
    `from __future__ import annotations` in the module that defines `func`)
    are evaluated against the globals of the module where `func` was defined,
    so types imported only there (e.g. `Optional`) resolve correctly. This
    module's globals are kept as a fallback.

    Args:
        name: The name for the pydantic model.
        func: The function to create a schema for.

    Returns:
        A pydantic model representing the function's schema.

    Raises:
        NameError: If a string annotation refers to a name that is available
            neither in the defining module of `func` nor in this module
            (e.g. a type defined inside another function).
    """
    from inspect import unwrap

    # `signature` follows `__wrapped__`, so take the globals from the same
    # underlying function whose annotations it reports. Objects without
    # `__globals__` (e.g. `functools.partial`) fall back to this module only.
    func_globals = getattr(unwrap(func), "__globals__", {})
    eval_namespace = {**globals(), **func_globals}

    fields: Dict[str, Any] = {}

    for param_name, param in signature(func).parameters.items():
        param_type = param.annotation
        if isinstance(param_type, str):
            # The string comes from the function's own source code, not from
            # external input, so evaluating it is acceptable here.
            param_type = eval(param_type, eval_namespace)  # noqa: S307

        description = None
        if get_origin(param_type) is Annotated:
            # get_args() returns (underlying_type, *metadata).
            args = get_args(param_type)
            param_type = args[0]
            if isinstance(args[1], str):
                description = args[1]

        if param_type is param.empty:
            param_type = Any

        if param.default is param.empty:
            # No default -> pydantic treats the field as required.
            fields[param_name] = (param_type, Field(description=description))
        else:
            fields[param_name] = (
                param_type,
                Field(default=param.default, description=description),
            )

    return create_model(name, **fields)  # type: ignore


def required_properties_from_model(model: Type[BaseModel]) -> List[str]:
    """Collect the names of the fields a pydantic model marks as required.

    Args:
        model: The pydantic model class to inspect.

    Returns:
        The names of all fields whose `is_required()` is true, in the order
        they appear in `model.model_fields`.
    """
    required: List[str] = []
    for field_name, field_info in model.model_fields.items():
        if field_info.is_required():
            required.append(field_name)
    return required
44 changes: 0 additions & 44 deletions src/cleanlab_codex/utils/langchain.py

This file was deleted.

30 changes: 0 additions & 30 deletions src/cleanlab_codex/utils/llamaindex.py

This file was deleted.

7 changes: 5 additions & 2 deletions src/cleanlab_codex/utils/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel

from cleanlab_codex.utils.types import FunctionParameters
from cleanlab_codex.utils.function import FunctionParameters


class Tool(BaseModel):
Expand All @@ -28,6 +28,9 @@ def format_as_openai_tool(
function=Function(
name=tool_name,
description=tool_description,
parameters=FunctionParameters(properties=tool_properties, required=required_properties),
parameters=FunctionParameters(
properties=tool_properties,
required=required_properties,
),
)
).model_dump()
14 changes: 0 additions & 14 deletions src/cleanlab_codex/utils/types.py

This file was deleted.

File renamed without changes.
Empty file added tests/utils/__init__.py
Empty file.
40 changes: 40 additions & 0 deletions tests/utils/test_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Any

from typing_extensions import Annotated

from cleanlab_codex.utils.function import pydantic_model_from_function


def test_function_schema_with_annotated_params() -> None:
    """An `Annotated[str, "..."]` parameter yields a required, described `str` field."""

    def annotated_fn(
        a: Annotated[str, "This is a string"],  # noqa: ARG001
    ) -> None: ...

    model = pydantic_model_from_function("test_function", annotated_fn)
    schema_title = model.model_json_schema()["title"]
    field_a = model.model_fields["a"]

    assert schema_title == "test_function"
    assert field_a.annotation is str
    assert field_a.description == "This is a string"
    assert field_a.is_required()


def test_function_schema_without_annotations() -> None:
    """An unannotated parameter yields a required `Any` field with no description."""

    def unannotated_fn(a) -> None:  # type: ignore # noqa: ARG001
        ...

    model = pydantic_model_from_function("test_function", unannotated_fn)
    schema_title = model.model_json_schema()["title"]
    field_a = model.model_fields["a"]

    assert schema_title == "test_function"
    assert field_a.annotation is Any  # type: ignore[comparison-overlap]
    assert field_a.is_required()
    assert field_a.description is None


def test_function_schema_with_default_param() -> None:
    """A parameter with a default yields an optional field carrying that default."""

    def defaulted_fn(a: int = 1) -> None:  # noqa: ARG001
        ...

    model = pydantic_model_from_function("test_function", defaulted_fn)
    schema_title = model.model_json_schema()["title"]
    field_a = model.model_fields["a"]

    assert schema_title == "test_function"
    assert field_a.annotation is int
    assert field_a.default == 1
    assert not field_a.is_required()
    assert field_a.description is None