[BugFix] More flexible episode_reward computation in logger #136

Merged 6 commits on Oct 4, 2024
Changes from all commits
232 changes: 147 additions & 85 deletions benchmarl/experiment/logger.py
@@ -6,7 +6,9 @@

import json
import os
import warnings
from pathlib import Path

from typing import Dict, List, Optional

import numpy as np
@@ -91,28 +93,27 @@ def log_collection(
step: int,
) -> float:
to_log = {}
json_metrics = {}
groups_episode_rewards = []
global_done = self._get_global_done(batch) # Does not have agent dim
any_episode_ended = global_done.nonzero().numel() > 0
if not any_episode_ended:
warnings.warn(
"No episode terminated this iteration and thus the episode rewards will be NaN, "
"this is normal if your horizon is longer then one iteration. Learning is proceeding fine."
"The episodes will probably terminate in a future iteration."
)
for group in self.group_map.keys():
episode_reward = self._get_episode_reward(group, batch)
done = self._get_done(group, batch)
reward = self._get_reward(group, batch)
to_log.update(
{
f"collection/{group}/reward/reward_min": reward.min().item(),
f"collection/{group}/reward/reward_mean": reward.mean().item(),
f"collection/{group}/reward/reward_max": reward.max().item(),
}
group_episode_rewards = self._log_individual_and_group_rewards(
group,
batch,
global_done,
any_episode_ended,
to_log,
log_individual_agents=False, # Turn on if you want single agent granularity
)
json_metrics[group + "_return"] = episode_reward.mean(-2)[done.any(-2)]
episode_reward = episode_reward[done]
if episode_reward.numel() > 0:
to_log.update(
{
f"collection/{group}/reward/episode_reward_min": episode_reward.min().item(),
f"collection/{group}/reward/episode_reward_mean": episode_reward.mean().item(),
f"collection/{group}/reward/episode_reward_max": episode_reward.max().item(),
}
)
# group_episode_rewards has shape (n_episodes) as we took the mean over agents in the group
groups_episode_rewards.append(group_episode_rewards)

if "info" in batch.get(("next", group)).keys():
to_log.update(
{
@@ -130,19 +131,13 @@ def log_collection(
}
)
to_log.update(task.log_info(batch))
mean_group_return = torch.stack(
[value for key, value in json_metrics.items()], dim=0
).mean(0)
if mean_group_return.numel() > 0:
to_log.update(
{
"collection/reward/episode_reward_min": mean_group_return.min().item(),
"collection/reward/episode_reward_mean": mean_group_return.mean().item(),
"collection/reward/episode_reward_max": mean_group_return.max().item(),
}
)
# global_episode_rewards has shape (n_episodes) as we took the mean over groups
global_episode_rewards = self._log_global_episode_reward(
groups_episode_rewards, to_log, prefix="collection"
)

self.log(to_log, step=step)
return mean_group_return.mean().item()
return global_episode_rewards.mean().item()
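
A minimal sketch of the reduction performed above, with made-up group names and values: each group contributes a tensor of per-episode returns, aligned across groups because the global done is shared, and the scalar returned by log_collection is the mean over groups and then over episodes.

    import torch

    # One (n_episodes,) tensor per group, already averaged over the agents in that group.
    groups_episode_rewards = [
        torch.tensor([1.0, 3.0]),  # hypothetical group "agents_a"
        torch.tensor([2.0, 4.0]),  # hypothetical group "agents_b"
    ]
    # Mean over groups keeps shape (n_episodes,): tensor([1.5, 3.5])
    global_episode_rewards = torch.stack(groups_episode_rewards, dim=0).mean(0)
    print(global_episode_rewards.mean().item())  # 2.5, the value log_collection returns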

def log_training(self, group: str, training_td: TensorDictBase, step: int):
if not len(self.loggers):
@@ -164,57 +159,45 @@ def log_evaluation(
not len(self.loggers) and not self.experiment_config.create_json
) or not len(rollouts):
return

# Cut rollouts at first done
max_length_rollout_0 = 0
for i in range(len(rollouts)):
r = rollouts[i]
next_done = self._get_global_done(r).squeeze(-1)

# First done index for this traj
done_index = next_done.nonzero(as_tuple=True)[0]
if done_index.numel() > 0:
done_index = done_index[0]
r = r[: done_index + 1]
if i == 0:
max_length_rollout_0 = max(r.batch_size[0], max_length_rollout_0)
rollouts[i] = r

to_log = {}
json_metrics = {}
max_length_rollout_0 = 0
for group in self.group_map.keys():
# Cut the rollouts at the first done
rollouts_group = []
for i, r in enumerate(rollouts):
next_done = self._get_done(group, r)
# Reduce it to batch size
next_done = next_done.sum(
tuple(range(r.batch_dims, next_done.ndim)),
dtype=torch.bool,
)
# First done index for this traj
done_index = next_done.nonzero(as_tuple=True)[0]
if done_index.numel() > 0:
done_index = done_index[0]
r = r[: done_index + 1]
if i == 0:
max_length_rollout_0 = max(r.batch_size[0], max_length_rollout_0)
rollouts_group.append(r)

returns = [
self._get_reward(group, td).sum(0).mean().item()
for td in rollouts_group
]
json_metrics[group + "_return"] = torch.tensor(
returns, device=rollouts_group[0].device
# returns has shape (n_episodes)
returns = torch.stack(
[self._get_reward(group, td).sum(0).mean() for td in rollouts],
dim=0,
)
to_log.update(
{
f"eval/{group}/reward/episode_reward_min": min(returns),
f"eval/{group}/reward/episode_reward_mean": sum(returns)
/ len(rollouts_group),
f"eval/{group}/reward/episode_reward_max": max(returns),
}
self._log_min_mean_max(
to_log, f"eval/{group}/reward/episode_reward", returns
)
json_metrics[group + "_return"] = returns

mean_group_return = torch.stack(
[value for key, value in json_metrics.items()], dim=0
).mean(0)
to_log.update(
{
"eval/reward/episode_reward_min": mean_group_return.min().item(),
"eval/reward/episode_reward_mean": mean_group_return.mean().item(),
"eval/reward/episode_reward_max": mean_group_return.max().item(),
"eval/reward/episode_len_mean": sum(td.batch_size[0] for td in rollouts)
/ len(rollouts),
}
mean_group_return = self._log_global_episode_reward(
list(json_metrics.values()), to_log, prefix="eval"
)
# mean_group_return has shape (n_episodes) as we take the mean over groups
json_metrics["return"] = mean_group_return

to_log["eval/reward/episode_len_mean"] = sum(
td.batch_size[0] for td in rollouts
) / len(rollouts)

if self.json_writer is not None:
self.json_writer.write(
metrics=json_metrics,
@@ -265,34 +248,113 @@ def finish(self):
def _get_reward(
self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
):
if ("next", group, "reward") not in td.keys(True, True):
reward = td.get(("next", group, "reward"), None)
if reward is None:
reward = (
td.get(("next", "reward")).expand(td.get(group).shape).unsqueeze(-1)
)
else:
reward = td.get(("next", group, "reward"))
return reward.mean(-2) if remove_agent_dim else reward
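
A shape sketch of the fallback branch above, using made-up sizes: when a group has no ("next", group, "reward") entry, the root-level reward is broadcast to the group's shape and given a trailing dimension so it lines up with per-agent rewards.

    import torch

    batch_size, n_agents = 4, 2
    root_reward = torch.rand(batch_size, 1)  # ("next", "reward") has no agent dimension
    group_shape = (batch_size, n_agents)     # shape of the group sub-tensordict
    reward = root_reward.expand(group_shape).unsqueeze(-1)
    print(reward.shape)                      # torch.Size([4, 2, 1]), i.e. (*batch, n_agents, 1)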

def _get_done(self, group: str, td: TensorDictBase, remove_agent_dim: bool = False):
if ("next", group, "done") not in td.keys(True, True):
def _get_agents_done(
self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
):
done = td.get(("next", group, "done"), None)
if done is None:
done = td.get(("next", "done")).expand(td.get(group).shape).unsqueeze(-1)
else:
done = td.get(("next", group, "done"))

return done.any(-2) if remove_agent_dim else done

def _get_global_done(
self,
td: TensorDictBase,
):
done = td.get(("next", "done"))
return done

def _get_episode_reward(
self, group: str, td: TensorDictBase, remove_agent_dim: bool = False
):
if ("next", group, "episode_reward") not in td.keys(True, True):
episode_reward = td.get(("next", group, "episode_reward"), None)
if episode_reward is None:
episode_reward = (
td.get(("next", "episode_reward"))
.expand(td.get(group).shape)
.unsqueeze(-1)
)
else:
episode_reward = td.get(("next", group, "episode_reward"))
return episode_reward.mean(-2) if remove_agent_dim else episode_reward

def _log_individual_and_group_rewards(
self,
group: str,
batch: TensorDictBase,
global_done: Tensor,
any_episode_ended: bool,
to_log: Dict[str, Tensor],
prefix: str = "collection",
log_individual_agents: bool = True,
):
reward = self._get_reward(group, batch) # Has agent dim
episode_reward = self._get_episode_reward(group, batch) # Has agent dim
n_agents_in_group = episode_reward.shape[-2]

# Add multiagent dim
unsqueeze_global_done = global_done.unsqueeze(-1).expand(
(*batch.get_item_shape(group), 1)
)
#######
# All trajectories are considered done at the global done
#######

# 1. Here we log rewards from individual agent data
if log_individual_agents:
for i in range(n_agents_in_group):
self._log_min_mean_max(
to_log,
f"{prefix}/{group}/reward/agent_{i}/reward",
reward[..., i, :],
)
if any_episode_ended:
agent_global_done = unsqueeze_global_done[..., i, :]
self._log_min_mean_max(
to_log,
f"{prefix}/{group}/reward/agent_{i}/episode_reward",
episode_reward[..., i, :][agent_global_done],
)

# 2. Here we log rewards from group data taking the mean over agents
group_episode_reward = episode_reward.mean(-2)[global_done]
if any_episode_ended:
self._log_min_mean_max(
to_log, f"{prefix}/{group}/reward/episode_reward", group_episode_reward
)
self._log_min_mean_max(to_log, f"{prefix}/reward/reward", reward)

return group_episode_reward
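
A shape sketch with toy sizes of how the group episode reward above ends up with shape (n_episodes,): the agent dimension is averaged out and the result is masked with the shared global done.

    import torch

    batch_size, n_agents = 5, 3
    episode_reward = torch.rand(batch_size, n_agents, 1)  # has the agent dimension
    global_done = torch.zeros(batch_size, 1, dtype=torch.bool)
    global_done[2] = True  # two episodes ended in this batch
    global_done[4] = True
    group_episode_reward = episode_reward.mean(-2)[global_done]
    print(group_episode_reward.shape)  # torch.Size([2]), i.e. (n_episodes,)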

def _log_global_episode_reward(
self, episode_rewards: List[Tensor], to_log: Dict[str, Tensor], prefix: str
):
# Each element in the list is a group's episode reward (shape (n_episodes,)) at the global done,
# so all elements have the same shape since the done is shared across groups
episode_rewards = torch.stack(episode_rewards, dim=0).mean(
0
) # Mean over groups
if episode_rewards.numel() > 0:
self._log_min_mean_max(
to_log, f"{prefix}/reward/episode_reward", episode_rewards
)

return episode_rewards
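
When no episode terminated, each per-group tensor is empty, so nothing is added to to_log and the scalar propagated back by log_collection is NaN, which is the situation its warning describes. A small sketch, assuming two groups and zero finished episodes:

    import torch

    episode_rewards = [torch.empty(0), torch.empty(0)]     # no global done this iteration
    stacked = torch.stack(episode_rewards, dim=0).mean(0)  # shape (0,), numel() == 0
    print(stacked.numel())                                 # 0 -> min/mean/max logging is skipped
    print(stacked.mean().item())                           # nan, the value returned upstream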

def _log_min_mean_max(self, to_log: Dict[str, Tensor], key: str, value: Tensor):
to_log.update(
{
key + "_min": value.min().item(),
key + "_mean": value.mean().item(),
key + "_max": value.max().item(),
}
)


class JsonWriter:
"""