Skip to content

Commit

Permalink
amend
Browse files (browse the repository at this point in the history)
  • Loading branch information
matteobettini committed Oct 4, 2024
1 parent 20ad0fb commit d1b7f5f
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions benchmarl/experiment/logger.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -252,17 +252,20 @@ def finish(self):
def _get_reward(
self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
):
reward = td.get(("next", group, "reward"))
reward = td.get(("next", group, "reward"), None)
if reward is None:
reward = (
td.get(("next", "reward")).expand(td.get(group).shape).unsqueeze(-1)
)
return reward.mean(-2) if remove_agent_dim else reward

def _get_agents_done(
self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
):
group_td = td.get(("next", group))
if ("done") not in group_td.keys():
done = td.get(("next", group, "done"), None)
if done is None:
done = td.get(("next", "done")).expand(td.get(group).shape).unsqueeze(-1)
else:
done = group_td.get("done")

return done.any(-2) if remove_agent_dim else done

def _get_global_done(
Expand All @@ -275,7 +278,13 @@ def _get_global_done(
def _get_episode_reward(
self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
):
episode_reward = td.get(("next", group, "episode_reward"))
episode_reward = td.get(("next", group, "episode_reward"), None)
if episode_reward is None:
episode_reward = (
td.get(("next", "episode_reward"))
.expand(td.get(group).shape)
.unsqueeze(-1)
)
return episode_reward.mean(-2) if remove_agent_dim else episode_reward

def _log_individual_and_group_rewards(
Expand Down

0 comments on commit d1b7f5f

Please sign in to comment.