diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 6812dd99..ed6c9e39 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -120,14 +120,18 @@ def graph_data(self) -> HeteroData: Creates the graph in all workers. """ - graph_filename = Path( - self.config.hardware.paths.graph, - self.config.hardware.files.graph, - ) + if self.config.hardware.files.graph is not None: + graph_filename = Path( + self.config.hardware.paths.graph, + self.config.hardware.files.graph, + ) + + if graph_filename.exists() and not self.config.graph.overwrite: + LOGGER.info("Loading graph data from %s", graph_filename) + return torch.load(graph_filename) - if graph_filename.exists() and not self.config.graph.overwrite: - LOGGER.info("Loading graph data from %s", graph_filename) - return torch.load(graph_filename) + else: + graph_filename = None from anemoi.graphs.create import GraphCreator