From 868cd75d94d44d2ddfd4fea1e863b6990b7bf57f Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 24 Nov 2023 16:14:21 +0100 Subject: [PATCH] debug epoch-level macro-f1 --- chebai/callbacks/epoch_metrics.py | 21 +++++++++++---------- chebai/cli.py | 2 +- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/chebai/callbacks/epoch_metrics.py b/chebai/callbacks/epoch_metrics.py index 947ef061..b235a9cc 100644 --- a/chebai/callbacks/epoch_metrics.py +++ b/chebai/callbacks/epoch_metrics.py @@ -5,6 +5,7 @@ from lightning.pytorch.utilities.types import STEP_OUTPUT import lightning as pl from torchmetrics.classification import MultilabelF1Score +import torch class _EpochLevelMetric(Callback): @@ -23,25 +24,25 @@ def apply_metric(self, target, pred): raise NotImplementedError def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - self.train_labels = np.empty(shape=(0,), dtype=int) - self.train_preds = np.empty(shape=(0,), dtype=int) + self.train_labels = torch.empty(size=(0,), dtype=torch.int).cuda() + self.train_preds = torch.empty(size=(0,), dtype=torch.int).cuda() def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None: - self.train_labels = np.concatenate((self.train_labels, outputs['labels'].int(),)) - self.train_preds = np.concatenate((self.train_preds, outputs['preds'],)) + self.train_labels = torch.concatenate((self.train_labels, outputs['labels'],)) + self.train_preds = torch.concatenate((self.train_preds, outputs['preds'],)) def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: pl_module.log(f'train_{self.metric_name}', self.apply_metric(self.train_labels, self.train_preds)) def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - self.val_labels = np.empty(shape=(0,), dtype=int) - self.val_preds = np.empty(shape=(0,), dtype=int) + self.val_labels = torch.empty(size=(0,), dtype=torch.int).cuda() + self.val_preds = torch.empty(size=(0,), dtype=torch.int).cuda() def on_validation_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: - self.val_labels = np.concatenate((self.val_labels, outputs['labels'].int(),)) - self.val_preds = np.concatenate((self.val_preds, outputs['preds'],)) + self.val_labels = torch.concatenate((self.val_labels, outputs['labels'],)) + self.val_preds = torch.concatenate((self.val_preds, outputs['preds'],)) def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: pl_module.log(f'val_{self.metric_name}', self.apply_metric(self.val_labels, self.val_preds)) @@ -54,5 +55,5 @@ def metric_name(self): return 'ep_macro-f1' def apply_metric(self, target, pred): - f1 = MultilabelF1Score(num_labels=self.num_labels, average='macro') - return f1(target, pred) + f1 = MultilabelF1Score(num_labels=self.num_labels, average='macro').cuda() + return f1(pred, target) diff --git a/chebai/cli.py b/chebai/cli.py index 75b940bd..0c091db9 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -16,7 +16,7 @@ def add_arguments_to_parser(self, parser): "model.init_args.out_dim", f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels", ) - parser.link_arguments("model.init_args.out_dim", "trainer.callbacks.num_labels") + parser.link_arguments("model.init_args.out_dim", "trainer.callbacks.init_args.num_labels") # parser.link_arguments('n_splits', 'data.init_args.inner_k_folds') # doesn't work but I don't know why @staticmethod