From 4c08b5601ff3466de6cf6459810f9614a24605e8 Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Thu, 10 Aug 2023 12:25:06 -0400 Subject: [PATCH] minor fix --- graphium/data/datamodule.py | 2 +- profiling/profile_predictor.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index d81249bbf..c898e09c8 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -1191,7 +1191,7 @@ def setup( labels_size = {} labels_dtype = {} if stage == "fit" or stage is None: - if self.processed_graph_data_path is not None: + if self.dataloading_from == "disk": processed_train_data_path = self._path_to_load_from_file("train") assert self._data_ready_at_path( processed_train_data_path diff --git a/profiling/profile_predictor.py b/profiling/profile_predictor.py index 16df450c1..80ad284d4 100644 --- a/profiling/profile_predictor.py +++ b/profiling/profile_predictor.py @@ -20,7 +20,9 @@ def main(): with fsspec.open(CONFIG_PATH, "r") as f: cfg = yaml.safe_load(f) - cfg["datamodule"]["args"]["processed_graph_data_path"] = "graphium/data/cache/profiling/predictor_data.cache" + cfg["datamodule"]["args"][ + "processed_graph_data_path" + ] = "graphium/data/cache/profiling/predictor_data.cache" # cfg["datamodule"]["args"]["df_path"] = DATA_PATH cfg["trainer"]["trainer"]["max_epochs"] = 5 cfg["trainer"]["trainer"]["min_epochs"] = 5