Skip to content

Commit

Permalink
Fixed argref_by_name bug with kwargs (#50)
Browse files Browse the repository at this point in the history
Found an awful dirty bug when using @argref_by_name. To route strings, I wrote it so if a kwarg is passed - like (foo='bar') instead of ('bar') then it will not try to dereference that string to an object from the state. However - SimpleAgent calls all tools via kwarg and ReactAgent calls all tools via args. So, ReactAgent looked great and SimpleAgent was passing strings around and breaking code.

I made it so that if strings are passed and if they do not match a key, just assume they're meant to be used as a literal. However, I did keep an error check if the first argument is a string and not found in state as a valid key, then it throws an error.

I also cleaned up the function a bit and split up tests
  • Loading branch information
whitead authored Sep 27, 2024
1 parent 7ab69ab commit a6eb7d1
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 48 deletions.
97 changes: 58 additions & 39 deletions src/aviary/tools/argref.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,17 @@ def argref_wraps(wrapped):
return partial(argref_wrapper, wrapped=wrapped)


def argref_by_name( # noqa: C901
fxn_requires_state: bool = False, prefix: str = "", return_direct: bool = False
def argref_by_name( # noqa: C901,PLR0915
fxn_requires_state: bool = False,
prefix: str = "",
return_direct: bool = False,
):
"""Decorator to allow args to be a string key into a refs dict instead of the full object.
This can prevent LLM-powered tool selections from getting confused by full objects,
instead it enables them to work using named references.
instead it enables them to work using named references. If a reference is not found, it
will fallback on passing the original argument unless it is the first argument. If the
first argument str is not found in the state object, it will raise an error.
Args:
fxn_requires_state: Whether to pass the state object to the decorated function.
Expand Down Expand Up @@ -91,7 +95,7 @@ def argref_by_name( # noqa: C901
"""

def decorator(func): # noqa: C901
def get_call_args(*args, **kwargs):
def get_call_args(*args, **kwargs): # noqa: C901
if "state" not in kwargs:
raise ValueError(
"argref_by_name decorated function must have a 'state' argument. "
Expand All @@ -101,39 +105,54 @@ def get_call_args(*args, **kwargs):
# pop the state argument
state = kwargs["state"] if fxn_requires_state else kwargs.pop("state")

# pop all string arguments
other_args = []
keyname_args = []
# now convert the keynames to actual references (if they are a string)
# tuple is (arg, if was dereferenced)
def maybe_deref_arg(arg):
if isinstance(arg, str):
try:
if arg in state.refs:
return [state.refs[arg]], True
# sometimes it is not correctly converted to a tuple
# so as an attempt to be helpful...
if all(a.strip() in state.refs for a in arg.split(",")):
return [state.refs[a.strip()] for a in arg.split(",")], True
# fall through
except AttributeError as e:
raise AttributeError(
"The state object must have a 'refs' attribute to use argref_by_name decorator."
) from e
return arg, False

# the split thing makes it complicated and we cannot use comprehension
deref_args = []
largs = list(args)
while largs and isinstance(largs[0], str):
keyname_args.append(largs.pop(0))
if largs:
other_args.extend(largs)
if not keyname_args and kwargs:
name = next(iter(kwargs))
keyname_args = [kwargs.pop(name)]

# now convert the keynames to actual references
for arg in keyname_args:
try:
if arg in state.refs:
deref_args.append(state.refs[arg])
# sometimes it is not correctly converted to a tuple
# so as an attempt to be helpful...
elif all(a.strip() in state.refs for a in arg.split(",")):
deref_args.extend([
state.refs[a.strip()] for a in arg.split(",")
])
else:
raise KeyError(
f"Key '{arg}' not found in state. Available keys: {list(state.refs.keys())}"
atleast_one_deref = False
for i, arg in enumerate(args):
a, dr = maybe_deref_arg(arg)
if dr:
deref_args.extend(a)
atleast_one_deref = True
else:
if i == 0 and isinstance(arg, str):
# This is a bit of a heuristic, but if the first arg is a string and not found
# likely the user intended to use a reference
raise KeyError(f"The key {arg} is not found in state.")
deref_args.append(a)
deref_kwargs = {}
for i, (k, v) in enumerate(kwargs.items()):
a, dr = maybe_deref_arg(v)
if dr:
if len(a) > 1:
raise ValueError(
f"Multiple values for argument '{k}' found in state. "
" cannot use split notation for kwargs."
)
except AttributeError as e:
raise AttributeError(
"The state object must have a 'refs' attribute to use argref_by_name decorator."
) from e
return deref_args, other_args, kwargs, state
deref_kwargs[k] = a[0]
else:
if i == 0 and isinstance(a, str) and not atleast_one_deref:
raise KeyError(f"The key {a} is not found in state.")
deref_kwargs[k] = a

return deref_args, deref_kwargs, state

def update_state(state, result):
if return_direct:
Expand All @@ -153,14 +172,14 @@ def update_state(state, result):

@argref_wraps(func)
def wrapper(*args, **kwargs):
args, other_args, kwargs, state = get_call_args(*args, **kwargs)
result = func(*args, *other_args, **kwargs)
args, kwargs, state = get_call_args(*args, **kwargs)
result = func(*args, **kwargs)
return update_state(state, result)

@argref_wraps(func)
async def awrapper(*args, **kwargs):
args, other_args, kwargs, state = get_call_args(*args, **kwargs)
result = await func(*args, *other_args, **kwargs)
args, kwargs, state = get_call_args(*args, **kwargs)
result = await func(*args, **kwargs)
return update_state(state, result)

wrapper.requires_state = True
Expand Down
57 changes: 48 additions & 9 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def get_todo_list_no_args():


@pytest.mark.asyncio
async def test_argref_by_name() -> None:
async def test_argref_by_name_basic_usage() -> None:
class MyState:
def __init__(self):
self.refs = {"foo": 1}
Expand All @@ -572,6 +572,16 @@ def __init__(self):
result = wrapped_add(6, 2, state=s)
assert s.refs[result.split()[0]] == 6 + 2


@pytest.mark.asyncio
async def test_argref_by_name_error_handling() -> None:
class MyState:
def __init__(self):
self.refs = {"foo": 1}

wrapped_add = argref_by_name()(add)
s = MyState()

# Check if we use a key name that doesn't exist, we blow up
with pytest.raises(KeyError, match="not found in state"):
wrapped_add("bar", 2, state=MyState())
Expand All @@ -580,26 +590,35 @@ def __init__(self):
with pytest.raises(AttributeError, match="must have a 'refs' attribute"):
wrapped_add("foo", 2, state="not a state")

# now try with async and decorator

@pytest.mark.asyncio
async def test_argref_by_name_async_functions() -> None:
class MyState:
def __init__(self):
self.refs = {"foo": 1, "bar": 7}

# Define the async_add function with the decorator
@argref_by_name()
async def async_add(a: int, b: int) -> int:
"""Some docstring."""
return a + b

s = MyState()
result = await async_add("foo", 2, state=s)
assert s.refs[result.split()[0]] == 1 + 2

result = await async_add(6, 2, state=s)
assert s.refs[result.split()[0]] == 6 + 2

# now try with lists
s.refs["bar"] = 7
# Now try with lists
result = await async_add("foo", "bar", state=s)
assert s.refs[result.split()[0]] == 1 + 7

# try the convenience of comma splitting on key
# Try the convenience of comma splitting on key
result = await async_add("foo,bar", state=s)
assert s.refs[result.split()[0]] == 1 + 7

# Define and test async_list
@argref_by_name()
async def async_list(a: int, b: int) -> list[int]:
"""Some docstring."""
Expand All @@ -610,15 +629,35 @@ async def async_list(a: int, b: int) -> list[int]:
assert s.refs[name1] == 1
assert s.refs[name2] == 2

# Define and test async_list_direct
@argref_by_name(return_direct=True)
async def async_list_direct(a: int, b: int) -> list[int]:
"""Some docstring."""
return [a, b]

assert await async_list_direct("foo", 2, state=s) == [1, 2]

# call in context
tool = Tool.from_function(argref_by_name()(add))

@pytest.mark.asyncio
async def test_argref_by_name_advanced_features() -> None:
class MyState:
def __init__(self):
self.refs = {"foo": 1}

s = MyState()

# Define and test dereference via no state value found with return_direct
@argref_by_name(return_direct=True)
def skip_deref_test(foo: float, a: str) -> str:
"""Some docstring."""
return f"{foo} {a}"

assert skip_deref_test("foo", a="not in state", state=s) == "1 not in state"
assert skip_deref_test("foo", "foo", state=s) == "1 1"

# Call in context using Tool and related classes
wrapped_add = argref_by_name()(add)
tool = Tool.from_function(wrapped_add)

tool_call = ToolCall.from_tool(tool, "foo", 2)
action = ToolRequestMessage(tool_calls=[tool_call])
Expand All @@ -627,13 +666,13 @@ async def async_list_direct(a: int, b: int) -> list[int]:
new_messages = await my_env.exec_tool_calls(action, state=MyState())
assert new_messages[0].content.endswith("3")

# assert that we can describe the tool
# Assert that we can describe the tool
assert tool.info.describe_str()
assert (
"(set via a string key instead of the full object)" in tool.info.describe_str()
)

# now try state passing
# Test state passing with fxn_requires_state
@argref_by_name(fxn_requires_state=True)
async def want_state(a: int, state: MyState) -> int: # noqa: ARG001
"""Some docstring.
Expand Down

0 comments on commit a6eb7d1

Please sign in to comment.