-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: auto extract tool properties from fn def (#24)
* auto extract tool properties from fn def * update docstring
- Loading branch information
Showing
11 changed files
with
151 additions
and
119 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from inspect import signature | ||
from typing import Any, Callable, Dict, List, Literal, Type | ||
|
||
from pydantic import BaseModel, Field, create_model | ||
from typing_extensions import Annotated, get_args, get_origin | ||
|
||
|
||
class Property(BaseModel): | ||
type: Literal["string", "number", "integer", "boolean", "array", "object"] | ||
description: str | ||
|
||
|
||
class FunctionParameters(BaseModel): | ||
type: Literal["object"] = "object" | ||
properties: Dict[str, Property] | ||
required: List[str] | ||
|
||
|
||
def pydantic_model_from_function(name: str, func: Callable[..., Any]) -> Type[BaseModel]: | ||
""" | ||
Create a pydantic model representing a function's schema. | ||
For example, a function with the following signature: | ||
```python | ||
def my_function( | ||
a: Annotated[int, "This is an integer"], b: str = "default" | ||
) -> None: ... | ||
``` | ||
will be represented by the following pydantic model when `name="my_function"`: | ||
```python | ||
class my_function(BaseModel): | ||
a: int = Field(description="This is an integer") | ||
b: str = "default" | ||
``` | ||
Args: | ||
name: The name for the pydantic model. | ||
func: The function to create a schema for. | ||
Returns: | ||
A pydantic model representing the function's schema. | ||
""" | ||
fields = {} | ||
params = signature(func).parameters | ||
|
||
for param_name, param in params.items(): | ||
param_type = param.annotation | ||
if isinstance(param_type, str): | ||
param_type = eval(param_type) # noqa: S307 | ||
|
||
param_default = param.default | ||
description = None | ||
|
||
if get_origin(param_type) is Annotated: | ||
args = get_args(param_type) | ||
param_type = args[0] | ||
if isinstance(args[1], str): | ||
description = args[1] | ||
|
||
if param_type is param.empty: | ||
param_type = Any | ||
|
||
if param_default is param.empty: | ||
fields[param_name] = (param_type, Field(description=description)) | ||
else: | ||
fields[param_name] = ( | ||
param_type, | ||
Field(default=param_default, description=description), | ||
) | ||
|
||
return create_model(name, **fields) # type: ignore | ||
|
||
|
||
def required_properties_from_model(model: Type[BaseModel]) -> List[str]: | ||
"""Returns a list of required properties from a pydantic model.""" | ||
return [name for name, field in model.model_fields.items() if field.is_required()] |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
File renamed without changes.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from typing import Any | ||
|
||
from typing_extensions import Annotated | ||
|
||
from cleanlab_codex.utils.function import pydantic_model_from_function | ||
|
||
|
||
def test_function_schema_with_annotated_params() -> None: | ||
def function_with_annotated_params( | ||
a: Annotated[str, "This is a string"], # noqa: ARG001 | ||
) -> None: ... | ||
|
||
fn_schema = pydantic_model_from_function("test_function", function_with_annotated_params) | ||
assert fn_schema.model_json_schema()["title"] == "test_function" | ||
assert fn_schema.model_fields["a"].annotation is str | ||
assert fn_schema.model_fields["a"].description == "This is a string" | ||
assert fn_schema.model_fields["a"].is_required() | ||
|
||
|
||
def test_function_schema_without_annotations() -> None: | ||
def function_without_annotations(a) -> None: # type: ignore # noqa: ARG001 | ||
... | ||
|
||
fn_schema = pydantic_model_from_function("test_function", function_without_annotations) | ||
assert fn_schema.model_json_schema()["title"] == "test_function" | ||
assert fn_schema.model_fields["a"].annotation is Any # type: ignore[comparison-overlap] | ||
assert fn_schema.model_fields["a"].is_required() | ||
assert fn_schema.model_fields["a"].description is None | ||
|
||
|
||
def test_function_schema_with_default_param() -> None: | ||
def function_with_default_param(a: int = 1) -> None: # noqa: ARG001 | ||
... | ||
|
||
fn_schema = pydantic_model_from_function("test_function", function_with_default_param) | ||
assert fn_schema.model_json_schema()["title"] == "test_function" | ||
assert fn_schema.model_fields["a"].annotation is int | ||
assert fn_schema.model_fields["a"].default == 1 | ||
assert not fn_schema.model_fields["a"].is_required() | ||
assert fn_schema.model_fields["a"].description is None |