Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Apr 29, 2024
1 parent d96a40f commit affe867
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions benchmarl/environments/meltingpot/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,23 @@ def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
return env.group_map

def get_env_transforms(self, env: EnvBase) -> List[Transform]:
return [
DoubleToFloat(),
FlattenObservation(
in_keys=[
(group, "observation", "INTERACTION_INVENTORIES")
for group in self.group_map(env).keys()
],
first_dim=-2,
last_dim=-1,
),
interaction_inventories_keys = [
(group, "observation", "INTERACTION_INVENTORIES")
for group in self.group_map(env).keys()
if (group, "observation", "INTERACTION_INVENTORIES")
in env.observation_spec.keys(True, True)
]
return [DoubleToFloat()] + (
[
FlattenObservation(
in_keys=interaction_inventories_keys,
first_dim=-2,
last_dim=-1,
)
]
if len(interaction_inventories_keys)
else []
)

def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]:
return [
Expand Down

0 comments on commit affe867

Please sign in to comment.