core.py
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import typing as tp
import salina
import torch
import torch.utils.data
from salina import Agent, Workspace, instantiate_class
from salina.agents import Agents, TemporalAgent
from salina.agents.brax import EpisodesDone
class Task:
    """A Reinforcement Learning task defined as a SaLinA agent. Use the make() method
    to instantiate the SaLinA agent corresponding to the task.

    Parameters
    ----------
    env_agent_cfg : The OmegaConf (or dict) that configures the SaLinA environment agent
    task_id : An identifier of the task
    n_interactions : Defaults to None. Number of environment interactions allowed for training
    input_dimension : The input dimension of the observations
    output_dimension : The output dimension of the actions (i.e. the size of the action tensor, or the number of actions if they are discrete)
    """

    def __init__(
        self,
        env_agent_cfg: dict,
        task_id: int,
        n_interactions: tp.Union[None, int] = None,
        input_dimension: tp.Union[None, int] = None,
        output_dimension: tp.Union[None, int] = None,
    ) -> None:
        self._task_id = task_id
        self._n_interactions = n_interactions
        self._env_agent_cfg = env_agent_cfg

        if input_dimension is None or output_dimension is None:
            # Build one environment instance to infer the dimensions from its spaces.
            env = env_agent_cfg["make_env_fn"](**env_agent_cfg["make_env_args"])
            self._input_dimension = env.observation_space.shape[0]
            self._output_dimension = env.action_space.shape[0]
        else:
            self._input_dimension = input_dimension
            self._output_dimension = output_dimension

    def input_dimension(self) -> int:
        return self._input_dimension

    def output_dimension(self) -> int:
        return self._output_dimension

    def task_id(self) -> int:
        return self._task_id

    def env_cfg(self) -> dict:
        return self._env_agent_cfg

    def make(self) -> salina.Agent:
        agent = instantiate_class(self._env_agent_cfg)
        agent.set_name("env")
        return agent

    def n_interactions(self) -> int:
        return self._n_interactions
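# Usage sketch (illustrative, not part of the original module). The config dict is
# hypothetical; any dict accepted by salina.instantiate_class and carrying the
# "make_env_fn"/"make_env_args" entries read in Task.__init__ would work the same way.
def _example_task_usage(env_agent_cfg: dict) -> salina.Agent:
    task = Task(
        env_agent_cfg,
        task_id=0,
        n_interactions=1_000_000,
        input_dimension=17,  # hypothetical observation size, avoids building an env to infer it
        output_dimension=6,  # hypothetical action size
    )
    return task.make()  # the SaLinA environment agent, named "env"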
class Scenario:
    """
    A scenario is a sequence of train tasks and a sequence of test tasks.
    """

    def __init__(self) -> None:
        self._train_tasks = []
        self._test_tasks = []

    def train_tasks(self) -> tp.List[Task]:
        return self._train_tasks

    def test_tasks(self) -> tp.List[Task]:
        return self._test_tasks
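# Illustrative scenario sketch (hypothetical, not part of the original module): a concrete
# Scenario subclass is expected to fill self._train_tasks and self._test_tasks with Task
# objects in its constructor.
class ExampleScenario(Scenario):
    def __init__(self, env_agent_cfgs: tp.List[dict], n_interactions: int) -> None:
        super().__init__()
        for task_id, cfg in enumerate(env_agent_cfgs):
            self._train_tasks.append(Task(cfg, task_id, n_interactions=n_interactions))
            # Here the same environments are reused for evaluation; real scenarios may
            # use held-out variations instead.
            self._test_tasks.append(Task(cfg, task_id))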
class Framework:
    """A (CRL) model that can be updated over one new task, and evaluated over any task.

    Parameters
    ----------
    seed
    params : The OmegaConf (or dict) that configures the model
    """

    def __init__(self, seed: int, params: dict) -> None:
        self.seed = seed
        self.cfg = params
        self._stage = 0

    def memory_size(self) -> dict:
        raise NotImplementedError

    def get_stage(self) -> int:
        return self._stage

    def train(self, task: Task, logger: tp.Any, **extra_args) -> None:
        """Update the model over a particular task.

        Parameters
        ----------
        task: The task to train on
        logger: a SaLinA logger used to log metrics and messages
        """
        logger.message("-- Train stage " + str(self._stage))
        output = self._train(task, logger.get_logger("stage_" + str(self._stage) + "/"))
        # Log each metric returned by _train() under the current stage.
        for k in output:
            logger.add_scalar("monitor_per_stage/" + k, output[k], self._stage)
        self._stage += 1
    def evaluate(self, test_tasks: tp.List[Task], logger: tp.Any) -> dict:
        """Evaluate the model over a set of test tasks.

        Parameters
        ----------
        test_tasks: The set of tasks to evaluate on
        logger: a SaLinA logger

        Returns
        ----------
        evaluation: A dict mapping each task id to its evaluation metrics
        """
        logger.message("Starting evaluation...")
        with torch.no_grad():
            evaluation = {}
            for k, task in enumerate(test_tasks):
                metrics = self._evaluate_single_task(task)
                evaluation[task.task_id()] = metrics
                logger.message("Evaluation over task " + str(k) + ":" + str(metrics))
        logger.message("-- End evaluation...")
        return evaluation
    def _train(self, task: Task, logger: tp.Any) -> dict:
        raise NotImplementedError

    def get_evaluation_agent(self, task_id: int) -> salina.Agent:
        raise NotImplementedError

    def _evaluate_single_task(self, task: Task) -> dict:
        metrics = {}
        env_agent = task.make()
        policy_agent = self.get_evaluation_agent(task.task_id())

        if policy_agent is not None:
            policy_agent.eval()
            # Chain env -> EpisodesDone -> policy so rollouts stop at the end of each
            # episode instead of auto-resetting.
            no_autoreset = EpisodesDone()
            acquisition_agent = TemporalAgent(Agents(env_agent, no_autoreset, policy_agent))
            acquisition_agent.seed(self.seed * 13 + self._stage * 100)
            acquisition_agent.to(self.cfg.evaluation.device)

            avg_reward = 0.0
            n = 0
            avg_success = 0.0
            for r in range(self.cfg.evaluation.n_rollouts):
                workspace = Workspace()
                acquisition_agent(workspace, t=0, stop_variable="env/done")
                # Index of the last time step of each episode in the batch.
                ep_lengths = workspace["env/done"].max(0)[1] + 1
                B = ep_lengths.size()[0]
                arange = torch.arange(B).to(ep_lengths.device)
                # Cumulated reward (and optionally success) at the final step of each episode.
                cr = workspace["env/cumulated_reward"][ep_lengths - 1, arange]
                avg_reward += cr.sum().item()
                if self.cfg.evaluation.evaluate_success:
                    cr = workspace["env/success"][ep_lengths - 1, arange]
                    avg_success += cr.sum().item()
                n += B

            avg_reward /= n
            metrics["avg_reward"] = avg_reward

            if self.cfg.evaluation.evaluate_success:
                avg_success /= n
                metrics["success_rate"] = avg_success
        return metrics
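# Sketch of the expected driver loop (illustrative, not part of the original module): a
# concrete Framework subclass implements _train() and get_evaluation_agent(), and a
# training script alternates train/evaluate over the scenario's task sequences.
def _example_run(framework: Framework, scenario: Scenario, logger: tp.Any) -> tp.List[dict]:
    evaluations = []
    for task in scenario.train_tasks():
        framework.train(task, logger)  # one CRL stage per train task
        evaluations.append(framework.evaluate(scenario.test_tasks(), logger))
    return evaluations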
class CRLAgent(Agent):
    """A SaLinA Agent that supports the set_task() and add_regularizer() methods."""

    def set_task(self, task_id: tp.Union[None, int] = None) -> None:
        pass

    def add_regularizer(self, *args) -> torch.Tensor:
        # Default: no regularization, returned on the same device as the parameters.
        return torch.Tensor([0.0]).to(list(self.parameters())[0].device)
class CRLAgents(Agents):
    """A batch of CRL Agents called sequentially."""

    def set_task(self, task_id: tp.Union[None, int] = None) -> None:
        for agent in self:
            agent.set_task(task_id)

    def add_regularizer(self, *args) -> torch.Tensor:
        return torch.cat([agent.add_regularizer(*args) for agent in self]).sum()
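# Illustrative CRLAgent subclass (hypothetical, not part of the original module): one
# linear policy head per task, selected by set_task(), with an L2 penalty on the active
# head exposed through add_regularizer(). The observation variable name "env/env_obs"
# is an assumption about what the environment agent writes into the workspace.
class ExampleMultiHeadAgent(CRLAgent):
    def __init__(self, input_dimension: int, output_dimension: int, n_tasks: int) -> None:
        super().__init__()
        self.heads = torch.nn.ModuleList(
            [torch.nn.Linear(input_dimension, output_dimension) for _ in range(n_tasks)]
        )
        self.current_task = 0

    def set_task(self, task_id: tp.Union[None, int] = None) -> None:
        if task_id is not None:
            self.current_task = task_id

    def forward(self, t: int, **kwargs) -> None:
        obs = self.get(("env/env_obs", t))
        action = torch.tanh(self.heads[self.current_task](obs))
        self.set(("action", t), action)

    def add_regularizer(self, *args) -> torch.Tensor:
        # Return a 1-D tensor so CRLAgents.add_regularizer can concatenate across agents.
        head = self.heads[self.current_task]
        return sum((p ** 2).sum() for p in head.parameters()).unsqueeze(0)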