-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcore.py
275 lines (221 loc) · 10.7 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
"""Core API for human-agent collaboration environments and collaborative acts for team coordination."""
import re
import logging
from enum import Enum
from typing import Dict, SupportsFloat, Any, Tuple, Optional
from pydantic import BaseModel
from collaborative_gym.spaces import MultiSpace
from collaborative_gym.spaces import UnicodeWithRegexPattern, MAX_UNICODE_LENGTH
logging.basicConfig(
level=logging.INFO, format="%(name)s : %(levelname)-8s : %(message)s"
)
logger = logging.getLogger(__name__)
ActType = str
ObsType = dict[str, Any]
class ObservationTypes(Enum):
"""Enumeration of observation rendering types for the GUI interface."""
NO_RENDER = "NoRender"
JUPYTER_NOTEBOOK = "JupyterEditor"
TEXT_EDITOR = "TextEditor"
PAPER_LIBRARY = "PaperLibrary"
PAPER_SEARCH = "PaperSearchInterface"
TRAVEL_SEARCH = "TravelSearchInterface"
DISTANCE_MATRIX = "DistanceMatrix"
def __str__(self):
return self.value
def __eq__(self, other):
if isinstance(other, ObservationTypes):
return self.value == other.value
if isinstance(other, str):
return self.value == other
return False
class CoEnv:
"""Base environment class for human-agent collaboration tasks.
The class encapsulates an environment with arbitrary behind-the-scenes dynamics through the `step` and `reset` functions.
An environment can be partially or fully observed by a single agent based on the private properties of observations.
Attributes:
task_description: Description of the collaborative task
team_members: List of team member identifiers
action_space: Actions that affect the shared environment
private_action_space: Actions that only affect the acting member's environment
additional_task_info: Optional task-specific information for creating information asymmetry
example_question: Example task scenario for team members
example_trajectory: Example sequence of actions and observations
env_id: Unique identifier for this environment instance
"""
task_description: str # Description of the task
team_members: list[str] # List of team member identifiers
action_space: MultiSpace # Actions that will affect the shared environment
private_action_space: MultiSpace # Actions that will affect the private environment
# Additional information about the task (can be used to create information asymmetry)
additional_task_info: dict[str, Any] = {}
# Associate each task environment with an example question and trajectory for team members to understand the task
example_question: str = ""
example_trajectory: list[Tuple[str, ActType, ObsType]] = (
[]
) # A sequence of Thought/Action/Observation tuples
env_id: str # Unique identifier for this environment instance
def __init__(self, team_members: list[str], env_id: str):
self.team_members = team_members
self.env_id = env_id
def step(
self, role: str, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Run one timestep of the environment's dynamics using the agent actions.
Returns:
observation (ObsType): An element of the environment's observation_space as the current observation due to the agent actions.
reward (SupportsFloat): The reward as a result of taking the action. 0 for success, -1 for failure unless otherwise specified.
terminated (bool): Whether the agent reaches the terminal state.
private (bool): Whether the change shall notify the action taker or be broadcasted to all team members.
info (dict): Contains auxiliary diagnostic information (helpful for debugging, learning, and logging).
"""
raise NotImplementedError
def handle_action_error(
self, error_msg: str, private: bool = True
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Handle action errors in a consistent way across all environments.
Args:
error_msg (str): The error message to log and return
private (bool): Whether this is a private action. Defaults to True for errors.
Returns:
tuple containing standard step() return values with error information
"""
logging.error(error_msg)
return self.get_obs(), -1, False, private, {"action_error": error_msg}
def parse_and_validate_action(
self, role: str, action: ActType
) -> tuple[dict[str, Any], bool, Optional[str], Optional[str]]:
"""Common routine for action parsing and validation.
Args:
role (str): The team member attempting the action
action (ActType): The action string to parse and validate
Returns:
tuple containing:
parsed_action (dict): Dictionary of parsed action parameters
private (bool): Whether this is a private action
action_id (Optional[str]): The machine readable identifier of the action if successful
error_message (Optional[str]): Error message if validation fails, None if successful
"""
# Validate role
if role not in self.team_members:
return {}, True, None, f"{role!r} is not a valid team member."
# Validate action against action spaces
private = False
if self.private_action_space.contains(action):
private = True
elif not self.action_space.contains(action):
return (
{},
True,
None,
f"{action!r} invalid. Please strictly follow the action space specifications.",
)
# Parse action parameters using regex pattern
action_space = self.private_action_space if private else self.action_space
parsed_action = None
for space in action_space:
parsed_action = space.parse(action)
if parsed_action is not None: # {} is also a valid parsed action
return parsed_action, private, space.machine_readable_identifier, None
# Should not reach here if action spaces are properly defined
return {}, True, None, f"Failed to parse parameters from {action!r}"
def reset(
self,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]: # type: ignore
"""Resets the environment to an initial internal state, returning an initial observation and info."""
raise NotImplementedError
def close(self):
"""After the user has finished using the environment, close contains the code necessary to "clean up" the environment.
This is critical for closing rendering windows, database or HTTP connections.
Calling ``close`` on an already closed environment has no effect and won't raise an error.
"""
pass
def get_obs(self) -> ObsType:
"""Get the current observation of the environment.
Returns:
ObsType: Current observation dictionary containing the environment state.
"""
raise NotImplementedError
def obs_type(self) -> Dict[str, ObservationTypes]:
"""Return the type of each observation field for GUI rendering.
Returns:
Dict[str, ObservationTypes]: Mapping of observation keys to their rendering types.
"""
raise NotImplementedError
def action_space_to_description(self) -> str:
"""Convert the action space into a human-readable description.
Returns:
str: A string describing available actions and their parameters.
"""
raise NotImplementedError
def evaluate_task_performance(self) -> Dict:
"""Evaluate the performance of the team on the collaborative task.
Returns:
Dict: Performance metrics and statistics for the task completion.
"""
pass
def dump_action_space(self):
"""Dumps the action space to a JSON serializable format."""
return [action.dump_json() for action in self.action_space] + [
action.dump_json() for action in self.private_action_space
]
def __str__(self):
"""Returns a string of the environment with :attr:`spec` id's if :attr:`spec.
Returns:
A string identifying the environment
"""
return f"<{type(self).__name__} instance>"
def __repr__(self):
"""Gives a string representation of this environment."""
pass
def __enter__(self):
"""Support with-statement for the environment."""
return self
def __exit__(self, *args: Any):
"""Support with-statement for the environment and closes the environment."""
self.close()
# propagate exception
return False
class TeamMemberConfig(BaseModel):
"""
Configuration class for team members in collaborative environments.
Defines the properties and initialization parameters for both human and AI
team members, including their type (agent, user, etc.), name, and startup
configuration for node processes.
Attributes:
name: Unique identifier for the team member
type: Type of team member ('agent', 'cmd_user', 'gui_user', etc.)
start_node_base_command: Command used to start this member's node process
"""
name: str
type: str
start_node_base_command: str
# Collaborative Actions
class SendTeammateMessage(UnicodeWithRegexPattern):
def __init__(self):
super().__init__(
min_length=0,
max_length=MAX_UNICODE_LENGTH,
regex_pattern=re.compile(
r"^SEND_TEAMMATE_MESSAGE\(message=(.*)\)$", re.DOTALL
),
params=["message"],
machine_readable_identifier="SEND_TEAMMATE_MESSAGE",
human_readable_name="Send a message to your teammate(s).",
human_readable_description="Send a message to your teammate(s) to provide information, ask for feedback, "
"allocate task, etc. This action is useful for collaboration.",
)
class WaitTeammateContinue(UnicodeWithRegexPattern):
def __init__(self):
super().__init__(
min_length=0,
max_length=MAX_UNICODE_LENGTH,
regex_pattern=re.compile(r"^WAIT_TEAMMATE_CONTINUE\(\)$", re.DOTALL),
params=[],
machine_readable_identifier="WAIT_TEAMMATE_CONTINUE",
human_readable_name="Wait for your teammate(s) to continue.",
human_readable_description="Skip your turn and wait for your teammate(s) to continue. This action is useful"
" for collaboration, especially when you need to wait for your teammate(s) to "
"provide information, complete certain parts before you can proceed, etc.",
)