
Commit

sync epoch-level metrics and outputs across devices
sfluegel committed Dec 18, 2023
1 parent 4491834 commit c03bb36
Showing 4 changed files with 14 additions and 7 deletions.
1 change: 1 addition & 0 deletions chebai/callbacks/epoch_metrics.py
@@ -78,6 +78,7 @@ def on_train_epoch_end(
         pl_module.log(
             f"train_{self.metric_name}",
             self.apply_metric(self.train_labels, self.train_preds, mode="train"),
+            sync_dist=True,
         )

     def on_validation_epoch_start(
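For context: passing sync_dist=True to log() makes Lightning reduce the logged value across all distributed processes before recording it, so the epoch-level metric is identical on every device instead of reflecting only the local rank's data. A minimal sketch of the same pattern in a stand-alone LightningModule (the toy model and metric below are invented for illustration, not taken from chebai):

import torch
import lightning as pl


class ToyRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # sync_dist=True averages the value over all DDP processes,
        # so every device logs the same epoch-level number
        self.log("train_loss", loss, on_epoch=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)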
4 changes: 2 additions & 2 deletions chebai/callbacks/model_checkpoint.py
@@ -35,7 +35,7 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
 
     def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
         """Overwritten for compatibility with wandb -> saves checkpoints in same dir as wandb logs"""
-        print(f"Resolving checkpoint dir (custom)")
+        rank_zero_info(f"Resolving checkpoint dir (custom)")
         if self.dirpath is not None:
             # short circuit if dirpath was passed to ModelCheckpoint
             return self.dirpath
@@ -58,5 +58,5 @@ def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
         # if no loggers, use default_root_dir
         ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")
 
-        print(f"Now using checkpoint path {ckpt_path}")
+        rank_zero_info(f"Now using checkpoint path {ckpt_path}")
         return ckpt_path
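In a multi-GPU run every process executes this callback, so a plain print appears once per device. rank_zero_info only emits the message from the rank-0 process and routes it through Python's logging module rather than stdout. A small illustrative sketch (report_ckpt_dir and the path are made up, not part of the repository):

import logging

from lightning_utilities.core.rank_zero import rank_zero_info

logging.basicConfig(level=logging.INFO)


def report_ckpt_dir(ckpt_path: str) -> None:
    # emitted once, by rank 0 only, instead of once per GPU process
    rank_zero_info(f"Now using checkpoint path {ckpt_path}")


report_ckpt_dir("logs/checkpoints")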
6 changes: 4 additions & 2 deletions chebai/preprocessing/datasets/base.py
@@ -10,6 +10,7 @@
 import lightning as pl
 
 from chebai.preprocessing import reader as dr
+from lightning_utilities.core.rank_zero import rank_zero_info
 
 
 class XYBaseDataModule(LightningDataModule):
@@ -149,6 +150,7 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]
         return self.dataloader(
             "validation" if not self.use_inner_cross_validation else "train_val",
             shuffle=False,
+            num_workers=self.num_workers,
             **kwargs,
         )

@@ -161,8 +163,8 @@ def predict_dataloader(
         return self.dataloader(self.prediction_kind, shuffle=False, **kwargs)
 
     def setup(self, **kwargs):
-        print("Check for processed data in ", self.processed_dir)
-        print(f"Cross-validation enabled: {self.use_inner_cross_validation}")
+        rank_zero_info("Check for processed data in ", self.processed_dir)
+        rank_zero_info(f"Cross-validation enabled: {self.use_inner_cross_validation}")
         if any(
             not os.path.isfile(os.path.join(self.processed_dir, f))
             for f in self.processed_file_names
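Forwarding num_workers to the validation dataloader matters because torch's DataLoader defaults to num_workers=0 (batches prepared in the main process) when the argument is omitted. A rough sketch of the effect with a stand-in dataset (shapes and batch size invented for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 4), torch.randn(100, 1))

# with num_workers > 0, validation batches are loaded by worker
# subprocesses instead of blocking the main training process
val_loader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=4)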
10 changes: 7 additions & 3 deletions chebai/trainer/InnerCVTrainer.py
@@ -9,10 +9,14 @@
 from lightning.pytorch.callbacks import ModelCheckpoint
 
 from lightning.fabric.plugins.environments import SLURMEnvironment
-from lightning_utilities.core.rank_zero import WarningCache
+from lightning_utilities.core.rank_zero import (
+    WarningCache,
+    rank_zero_warn,
+    rank_zero_info,
+)
 from lightning.pytorch.loggers import CSVLogger
 from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
-from lightning.pytorch.callbacks.model_checkpoint import _is_dir, rank_zero_warn
+from lightning.pytorch.callbacks.model_checkpoint import _is_dir
 
 from chebai.loggers.custom import CustomLogger
 from chebai.preprocessing.datasets.base import XYBaseDataModule
@@ -55,7 +59,7 @@ def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwar
             logger = new_trainer.logger
             if isinstance(logger, CustomLogger):
                 logger.set_fold(fold)
-                print(f"Logging this fold at {logger.experiment.dir}")
+                rank_zero_info(f"Logging this fold at {logger.experiment.dir}")
             else:
                 rank_zero_warn(
                     f"Using k-fold cross-validation without an adapted logger class"
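The surrounding cv_fit loop trains one model per fold and points a fold-aware logger at a per-fold directory. A toy sketch of how MultilabelStratifiedKFold (imported above) yields those fold indices; the features and multilabel targets below are invented, the real ones come from the datamodule:

import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

X = np.zeros((8, 1))  # placeholder features
y = np.array([[0, 1], [1, 0], [1, 1], [0, 1],
              [1, 0], [1, 1], [0, 1], [1, 0]])  # toy multilabel targets

kfold = MultilabelStratifiedKFold(n_splits=2, shuffle=True, random_state=0)
for fold, (train_idx, val_idx) in enumerate(kfold.split(X, y)):
    # in cv_fit, each fold would get its own trainer and logger directory
    print(f"fold {fold}: train={train_idx.tolist()}, val={val_idx.tolist()}")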
