Skip to content

Commit

Permalink
Moved DummyEnv to report empty ToolRequestMessage (#183)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza authored Jan 30, 2025
1 parent f4ebd73 commit 1030170
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/aviary/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ async def step(
) -> tuple[Messages, float, bool, bool]:
msgs: Messages = await self.exec_tool_calls(
action, state=self.state, concurrency=self.concurrent_tool_calls
)
) or [Message(content=f"No tool calls input in tool request {action}.")]
self.state.messages.extend(msgs)
return msgs, self.state.reward, self.state.done, False

Expand Down
7 changes: 7 additions & 0 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ async def my_policy(obs: list[Message]) -> ToolRequestMessage: # noqa: ARG001,
assert isinstance(obs, list)
assert len(obs) == 1

# Check if we have a bad policy that gives an empty action, the env reports this
obs, reward, done, _ = await dummy_env.step(
action=ToolRequestMessage(tool_calls=[])
)
assert not done, "Should not be done after empty action"
assert "no tool calls" in obs[0].content.lower()

action = await my_policy(obs)
_, reward, done, _ = await dummy_env.step(action)
assert reward > 0
Expand Down

0 comments on commit 1030170

Please sign in to comment.