Skip to content

Commit

Permalink
chore: remove jax.tree_map
Browse files Browse the repository at this point in the history
  • Loading branch information
sash-a committed Dec 3, 2024
1 parent 900eca8 commit 81c108d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mava/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
# find the first instance of done to get the metrics at that timestep, we don't
# care about subsequent steps because we only the results from the first episode
done_idx = np.argmax(timesteps.last(), axis=0)
metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
metrics = tree.map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
del metrics["is_terminal_step"] # uneeded for logging

return key, metrics
Expand All @@ -307,7 +307,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
metrics_array.append(metric)

# flatten metrics
metrics: Metrics = jax.tree_map(lambda *x: np.array(x).reshape(-1), *metrics_array)
metrics: Metrics = tree.map(lambda *x: np.array(x).reshape(-1), *metrics_array)
return metrics

def timed_eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics:
Expand Down

0 comments on commit 81c108d

Please sign in to comment.