diff --git a/src/zeroband/train.py b/src/zeroband/train.py index d4d1e039..90e4454b 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -295,7 +295,7 @@ def train(config: Config): logger.info(log) - if memory_profiler is not None: + if config.train.memory_profiler is not None: memory_profiler.step() if config.diloco is not None: