From c03bb368a6ee65b74185fbcf05323d954964e102 Mon Sep 17 00:00:00 2001
From: sfluegel
Date: Mon, 18 Dec 2023 10:16:14 +0100
Subject: [PATCH] sync epoch-level metrics and outputs across devices

---
 chebai/callbacks/epoch_metrics.py     |  1 +
 chebai/callbacks/model_checkpoint.py  |  4 ++--
 chebai/preprocessing/datasets/base.py |  6 ++++--
 chebai/trainer/InnerCVTrainer.py      | 10 +++++++---
 4 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/chebai/callbacks/epoch_metrics.py b/chebai/callbacks/epoch_metrics.py
index a1d02939..76b37382 100644
--- a/chebai/callbacks/epoch_metrics.py
+++ b/chebai/callbacks/epoch_metrics.py
@@ -78,6 +78,7 @@ def on_train_epoch_end(
         pl_module.log(
             f"train_{self.metric_name}",
             self.apply_metric(self.train_labels, self.train_preds, mode="train"),
+            sync_dist=True,
         )
 
     def on_validation_epoch_start(
diff --git a/chebai/callbacks/model_checkpoint.py b/chebai/callbacks/model_checkpoint.py
index c60724d3..cf7a3d8c 100644
--- a/chebai/callbacks/model_checkpoint.py
+++ b/chebai/callbacks/model_checkpoint.py
@@ -35,7 +35,7 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
 
     def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
         """Overwritten for compatibility with wandb -> saves checkpoints in same dir as wandb logs"""
-        print(f"Resolving checkpoint dir (custom)")
+        rank_zero_info(f"Resolving checkpoint dir (custom)")
         if self.dirpath is not None:
             # short circuit if dirpath was passed to ModelCheckpoint
             return self.dirpath
@@ -58,5 +58,5 @@ def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
             # if no loggers, use default_root_dir
             ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")
 
-        print(f"Now using checkpoint path {ckpt_path}")
+        rank_zero_info(f"Now using checkpoint path {ckpt_path}")
         return ckpt_path
diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index abf9a2d5..05ef0e05 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -10,6 +10,7 @@
 import lightning as pl
 
 from chebai.preprocessing import reader as dr
+from lightning_utilities.core.rank_zero import rank_zero_info
 
 
 class XYBaseDataModule(LightningDataModule):
@@ -149,6 +150,7 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]
         return self.dataloader(
             "validation" if not self.use_inner_cross_validation else "train_val",
             shuffle=False,
+            num_workers=self.num_workers,
             **kwargs,
         )
 
@@ -161,8 +163,8 @@ def predict_dataloader(
         return self.dataloader(self.prediction_kind, shuffle=False, **kwargs)
 
     def setup(self, **kwargs):
-        print("Check for processed data in ", self.processed_dir)
-        print(f"Cross-validation enabled: {self.use_inner_cross_validation}")
+        rank_zero_info(f"Check for processed data in {self.processed_dir}")
+        rank_zero_info(f"Cross-validation enabled: {self.use_inner_cross_validation}")
         if any(
             not os.path.isfile(os.path.join(self.processed_dir, f))
             for f in self.processed_file_names
diff --git a/chebai/trainer/InnerCVTrainer.py b/chebai/trainer/InnerCVTrainer.py
index 2fe2ffb2..ca947380 100644
--- a/chebai/trainer/InnerCVTrainer.py
+++ b/chebai/trainer/InnerCVTrainer.py
@@ -9,10 +9,14 @@
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.fabric.plugins.environments import SLURMEnvironment
 
-from lightning_utilities.core.rank_zero import WarningCache
+from lightning_utilities.core.rank_zero import (
+    WarningCache,
+    rank_zero_warn,
+    rank_zero_info,
+)
 from lightning.pytorch.loggers import CSVLogger
 from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
-from lightning.pytorch.callbacks.model_checkpoint import _is_dir, rank_zero_warn
+from lightning.pytorch.callbacks.model_checkpoint import _is_dir
 
 from chebai.loggers.custom import CustomLogger
 from chebai.preprocessing.datasets.base import XYBaseDataModule
@@ -55,7 +59,7 @@ def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwar
             logger = new_trainer.logger
             if isinstance(logger, CustomLogger):
                 logger.set_fold(fold)
-                print(f"Logging this fold at {logger.experiment.dir}")
+                rank_zero_info(f"Logging this fold at {logger.experiment.dir}")
             else:
                 rank_zero_warn(
                     f"Using k-fold cross-validation without an adapted logger class"
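
Note on the sync_dist=True flag added in epoch_metrics.py: under distributed training, Lightning all-reduces the logged scalar across processes (mean reduction by default), so every rank reports the same epoch-level value instead of the value computed on its local data shard; this keeps monitored metrics, and therefore checkpointing decisions, consistent across devices. Likewise, rank_zero_info emits the messages that previously used print only on global rank 0 rather than once per process. The sketch below is illustrative only and not part of the patch; DemoModule and its exact-match metric are made-up stand-ins for the repository's actual modules and metrics.

import torch
import lightning as pl


class DemoModule(pl.LightningModule):
    """Illustrative only: accumulate step outputs, log one synced value per epoch."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(16, 4)
        self.val_preds, self.val_labels = [], []

    def forward(self, x):
        return self.layer(x)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.val_preds.append(torch.sigmoid(self(x)))
        self.val_labels.append(y)

    def on_validation_epoch_end(self):
        preds = torch.cat(self.val_preds)
        labels = torch.cat(self.val_labels)
        # Made-up metric: fraction of samples whose full label vector is predicted exactly.
        exact_match = (preds.round() == labels).all(dim=1).float().mean()
        # sync_dist=True averages the value over all ranks before it is logged,
        # mirroring the synchronisation added for the epoch-level metrics above.
        self.log("val_exact_match", exact_match, sync_dist=True)
        self.val_preds.clear()
        self.val_labels.clear()

Without sync_dist=True, each process would log only its local value, and recent Lightning versions typically warn when epoch-level values are logged unsynced in a distributed setting.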