From feeeefc71189acc632fc6a9b8acdc662461d8524 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sat, 1 Feb 2025 10:56:53 +0000 Subject: [PATCH] amend --- benchmarl/experiment/logger.py | 47 ++++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/benchmarl/experiment/logger.py b/benchmarl/experiment/logger.py index 20e516aa..5fc67d85 100644 --- a/benchmarl/experiment/logger.py +++ b/benchmarl/experiment/logger.py @@ -7,6 +7,7 @@ import json import os import warnings +from collections.abc import MutableMapping, Sequence from pathlib import Path from typing import Dict, List, Optional @@ -17,6 +18,8 @@ from tensordict import TensorDictBase from torch import Tensor + +from torchrl.record import TensorboardLogger from torchrl.record.loggers import get_logger from torchrl.record.loggers.wandb import WandbLogger @@ -73,17 +76,41 @@ def __init__( ) def log_hparams(self, **kwargs): + kwargs.update( + { + "algorithm_name": self.algorithm_name, + "model_name": self.model_name, + "task_name": self.task_name, + "environment_name": self.environment_name, + "seed": self.seed, + } + ) for logger in self.loggers: - kwargs.update( - { - "algorithm_name": self.algorithm_name, - "model_name": self.model_name, - "task_name": self.task_name, - "environment_name": self.environment_name, - "seed": self.seed, - } - ) - logger.log_hparams(kwargs) + if isinstance(logger, TensorboardLogger): + # Tensorboard does not like nested dictionaries -> flatten them + def flatten(dictionary, parent_key="", separator="_"): + items = [] + for key, value in dictionary.items(): + new_key = parent_key + separator + key if parent_key else key + if isinstance(value, MutableMapping): + items.extend( + flatten(value, new_key, separator=separator).items() + ) + elif isinstance(value, Sequence): + for i, v in enumerate(value): + items.append((new_key + separator + str(i), v)) + else: + items.append((new_key, value)) + return dict(items) + + # Convert any non-supported values + for key, value in kwargs.items(): + if not isinstance(value, (int, float, str, Tensor)): + kwargs[key] = str(value) + + logger.log_hparams(flatten(kwargs)) + else: + logger.log_hparams(kwargs) def log_collection( self,