debug epoch-level macro-f1
sfluegel committed Nov 24, 2023
1 parent ad107a9 commit 868cd75
Showing 2 changed files with 12 additions and 11 deletions.
21 changes: 11 additions & 10 deletions chebai/callbacks/epoch_metrics.py
@@ -5,6 +5,7 @@
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 import lightning as pl
 from torchmetrics.classification import MultilabelF1Score
+import torch
 
 
 class _EpochLevelMetric(Callback):
@@ -23,25 +24,25 @@ def apply_metric(self, target, pred):
         raise NotImplementedError
 
     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self.train_labels = np.empty(shape=(0,), dtype=int)
-        self.train_preds = np.empty(shape=(0,), dtype=int)
+        self.train_labels = torch.empty(size=(0,), dtype=torch.int).cuda()
+        self.train_preds = torch.empty(size=(0,), dtype=torch.int).cuda()
 
     def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT,
                            batch: Any, batch_idx: int) -> None:
-        self.train_labels = np.concatenate((self.train_labels, outputs['labels'].int(),))
-        self.train_preds = np.concatenate((self.train_preds, outputs['preds'],))
+        self.train_labels = torch.concatenate((self.train_labels, outputs['labels'],))
+        self.train_preds = torch.concatenate((self.train_preds, outputs['preds'],))
 
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         pl_module.log(f'train_{self.metric_name}', self.apply_metric(self.train_labels, self.train_preds))
 
     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self.val_labels = np.empty(shape=(0,), dtype=int)
-        self.val_preds = np.empty(shape=(0,), dtype=int)
+        self.val_labels = torch.empty(size=(0,), dtype=torch.int).cuda()
+        self.val_preds = torch.empty(size=(0,), dtype=torch.int).cuda()
 
     def on_validation_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT,
                                 batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
-        self.val_labels = np.concatenate((self.val_labels, outputs['labels'].int(),))
-        self.val_preds = np.concatenate((self.val_preds, outputs['preds'],))
+        self.val_labels = torch.concatenate((self.val_labels, outputs['labels'],))
+        self.val_preds = torch.concatenate((self.val_preds, outputs['preds'],))
 
     def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         pl_module.log(f'val_{self.metric_name}', self.apply_metric(self.val_labels, self.val_preds))
@@ -54,5 +55,5 @@ def metric_name(self):
         return 'ep_macro-f1'
 
     def apply_metric(self, target, pred):
-        f1 = MultilabelF1Score(num_labels=self.num_labels, average='macro')
-        return f1(target, pred)
+        f1 = MultilabelF1Score(num_labels=self.num_labels, average='macro').cuda()
+        return f1(pred, target)
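
The last hunk fixes an easy-to-miss bug: torchmetrics metrics take predictions first and targets second, so the old `f1(target, pred)` call computed the score with the arguments swapped. A minimal sketch of the corrected call convention (the shapes and values below are made up for illustration):

    import torch
    from torchmetrics.classification import MultilabelF1Score

    # Hypothetical batch: 4 samples, 3 labels each.
    preds = torch.rand(4, 3)              # per-label probabilities
    target = torch.randint(0, 2, (4, 3))  # binary ground truth

    f1 = MultilabelF1Score(num_labels=3, average='macro')
    score = f1(preds, target)             # predictions first, then targets

Moving the metric to the GPU with `.cuda()` keeps its internal state on the same device as the accumulated label and prediction tensors, which torchmetrics requires.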
2 changes: 1 addition & 1 deletion chebai/cli.py
@@ -16,7 +16,7 @@ def add_arguments_to_parser(self, parser):
                 "model.init_args.out_dim",
                 f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels",
             )
-        parser.link_arguments("model.init_args.out_dim", "trainer.callbacks.num_labels")
+        parser.link_arguments("model.init_args.out_dim", "trainer.callbacks.init_args.num_labels")
         # parser.link_arguments('n_splits', 'data.init_args.inner_k_folds') # doesn't work but I don't know why
 
     @staticmethod
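
For context, `link_arguments` is jsonargparse functionality exposed through `LightningCLI`: it derives one config entry from another, so a value like the number of labels only needs to be set once. A minimal sketch of the pattern, assuming a Lightning 2.x-style CLI (the class name here is hypothetical; only the linked keys come from the diff):

    from lightning.pytorch.cli import LightningCLI

    class ChebaiCLI(LightningCLI):  # hypothetical subclass name
        def add_arguments_to_parser(self, parser):
            # Reuse the model's output dimension as the callback's
            # num_labels so the two values can never drift apart.
            parser.link_arguments(
                "model.init_args.out_dim",
                "trainer.callbacks.init_args.num_labels",
            )

The changed line adds the missing `init_args` level to the target path, presumably so that jsonargparse can resolve `num_labels` inside the callback's init arguments.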
