add adjustment factor for macro-f1
sfluegel committed Dec 14, 2023
1 parent b197932 commit 8a3b87a
Showing 2 changed files with 48 additions and 10 deletions.
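For context: torchmetrics' `MultilabelF1Score` with `average="macro"` averages the per-class F1 over all `num_labels` classes, so a class with no positive example in the evaluated split contributes a hard zero and drags the mean down. Since an absent class always scores an F1 of 0 (it has no true positives), multiplying the macro score by `total_classes / classes_present` turns it into an average over only the classes that actually occur. A minimal sketch of that effect, with toy tensors invented for illustration:

```python
import torch
from torchmetrics.classification import MultilabelF1Score

# Toy data: 4 classes, but class 3 never occurs in the targets.
target = torch.tensor([[1, 0, 1, 0],
                       [1, 1, 0, 0],
                       [0, 1, 1, 0]])
preds = target.float()  # perfect predictions on the classes that do occur

f1 = MultilabelF1Score(num_labels=4, average="macro")
raw = f1(preds, target)  # 0.75: the absent class contributes an F1 of 0

classes_present = torch.sum(torch.sum(target, dim=0) > 0).item()  # 3
macro_adjust = target.shape[1] / classes_present                  # 4 / 3
print(raw * macro_adjust)  # 1.0: the average over the 3 classes that occur
```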
40 changes: 35 additions & 5 deletions chebai/callbacks/epoch_metrics.py
@@ -15,12 +15,16 @@ def __init__(self, num_labels):
        self.train_labels, self.val_labels = None, None
        self.train_preds, self.val_preds = None, None
        self.num_labels = num_labels
        self.val_macro_adjust, self.train_macro_adjust = (
            None,
            None,
        )  # factor to compensate for classes that are not present

    @property
    def metric_name(self):
        raise NotImplementedError

    def apply_metric(self, target, pred):
    def apply_metric(self, target, pred, mode="test"):
        raise NotImplementedError

    def on_train_epoch_start(
@@ -51,12 +55,25 @@ def on_train_batch_end(
            )
        )

    def _calculate_macro_adjust(self, labels):
        classes_present = torch.sum(torch.sum(labels, dim=0) > 0).item()
        total_classes = labels.shape[1]
        macro_adjust = total_classes / classes_present
        return macro_adjust

    def on_train_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        if self.train_macro_adjust is None:
            self.train_macro_adjust = self._calculate_macro_adjust(self.train_labels)
            if self.train_macro_adjust != 1:
                print(
                    f"some classes are missing in train set, calculating macro-scores with adjustment factor {self.train_macro_adjust}"
                )

        pl_module.log(
            f"train_{self.metric_name}",
            self.apply_metric(self.train_labels, self.train_preds),
            self.apply_metric(self.train_labels, self.train_preds, mode="train"),
        )

    def on_validation_epoch_start(
Expand Down Expand Up @@ -94,9 +111,17 @@ def on_validation_batch_end(
def on_validation_epoch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> None:
if self.val_macro_adjust is None:
self.val_macro_adjust = self._calculate_macro_adjust(self.val_labels)
if self.val_macro_adjust != 1:
print(
f"some classes are missing in val set, calculating macro-scores with adjustment factor {macro_adjust}"
)

pl_module.log(
f"val_{self.metric_name}",
self.apply_metric(self.val_labels, self.val_preds),
self.apply_metric(self.val_labels, self.val_preds, mode="val"),
sync_dist=True,
)


@@ -105,8 +130,13 @@ class EpochLevelMacroF1(_EpochLevelMetric):
    def metric_name(self):
        return "ep_macro-f1"

    def apply_metric(self, target, pred):
    def apply_metric(self, target, pred, mode="train"):
        f1 = MultilabelF1Score(num_labels=self.num_labels, average="macro")
        if target.get_device() != -1:  # -1 == CPU
            f1 = f1.to(device=target.get_device())
        return f1(pred, target)
        if mode == "train":
            return f1(pred, target) * self.train_macro_adjust
        elif mode == "val":
            return f1(pred, target) * self.val_macro_adjust
        else:
            return f1(pred, target)
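The constructor and attribute names below are taken directly from the diff; driving the callback by hand outside a Lightning training loop is only an illustration and assumes `EpochLevelMacroF1` can be instantiated standalone (its state is normally accumulated by the batch-end hooks):

```python
import torch

metric = EpochLevelMacroF1(num_labels=4)

# Simulate a validation epoch in which class 3 never appears.
metric.val_labels = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 0]])
metric.val_preds = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 0]]).float()

metric.val_macro_adjust = metric._calculate_macro_adjust(metric.val_labels)  # 4 / 3
score = metric.apply_metric(metric.val_labels, metric.val_preds, mode="val")
print(score)  # 0.75 * 4/3 = 1.0, i.e. macro-F1 over the classes that occur
```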
18 changes: 13 additions & 5 deletions chebai/result/classification.py
@@ -130,8 +130,16 @@ def print_metrics(preds, labels, device, classes=None, top_k=10, markdown_output
    """Prints relevant metrics, including micro and macro F1, recall and precision, best k classes and worst classes."""
    f1_macro = MultilabelF1Score(preds.shape[1], average="macro").to(device=device)
    f1_micro = MultilabelF1Score(preds.shape[1], average="micro").to(device=device)

    classes_present = torch.sum(torch.sum(labels, dim=0) > 0).item()
    total_classes = labels.shape[1]
    macro_adjust = total_classes / classes_present
    if classes_present != total_classes:
        print(
            f"{total_classes - classes_present} classes are missing, calculating macro-scores with adjustment factor {macro_adjust}"
        )
    print(
        f"Macro-F1 on test set with {preds.shape[1]} classes: {f1_macro(preds, labels):3f}"
        f"Macro-F1 on test set with {preds.shape[1]} classes: {f1_macro(preds, labels) * macro_adjust:3f}"
    )
    print(
        f"Micro-F1 on test set with {preds.shape[1]} classes: {f1_micro(preds, labels):3f}"
@@ -144,9 +152,9 @@ def print_metrics(preds, labels, device, classes=None, top_k=10, markdown_output
    )
    recall_macro = MultilabelRecall(preds.shape[1], average="macro").to(device=device)
    recall_micro = MultilabelRecall(preds.shape[1], average="micro").to(device=device)
    print(f"Macro-Precision: {precision_macro(preds, labels):3f}")
    print(f"Macro-Precision: {precision_macro(preds, labels) * macro_adjust:3f}")
    print(f"Micro-Precision: {precision_micro(preds, labels):3f}")
    print(f"Macro-Recall: {recall_macro(preds, labels):3f}")
    print(f"Macro-Recall: {recall_macro(preds, labels) * macro_adjust:3f}")
    print(f"Micro-Recall: {recall_micro(preds, labels):3f}")
    if markdown_output:
        print(
@@ -169,6 +177,6 @@ def print_metrics(preds, labels, device, classes=None, top_k=10, markdown_output

    zeros = []
    for i, f1 in enumerate(classwise_f1):
        if f1 == 0.0:
        if f1 == 0.0 and torch.sum(labels[:, i]):
            zeros.append(f"{classes[i] if classes is not None else i}")
    print(f'Classes with F1-score = 0: {", ".join(zeros)}')
    print(f'Classes with F1-score == 0 (and non-zero labels): {", ".join(zeros)}')

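A rough usage sketch for the updated `print_metrics`; the function name, parameters, and module path come from the diff above, while the random data, the forced-empty class, and the device choice are invented for illustration:

```python
import torch
from chebai.result.classification import print_metrics

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 64 samples, 10 classes; zero out class 7 so classes_present < total_classes
# and the new adjustment factor actually kicks in.
labels = torch.randint(0, 2, (64, 10), device=device)
labels[:, 7] = 0
preds = torch.rand(64, 10, device=device)

print_metrics(preds, labels, device, top_k=5, markdown_output=False)
```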