Skip to content

Commit

Permalink
Merge pull request #66 from IBM/fix/metric_names
Browse files Browse the repository at this point in the history
remove spaces in metric names
  • Loading branch information
CarlosGomes98 authored Jul 31, 2024
2 parents 90d3e49 + ee5eea5 commit d9bf33a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
12 changes: 6 additions & 6 deletions terratorch/tasks/classification_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,31 +154,31 @@ def configure_metrics(self) -> None:
class_names = self.hparams["class_names"]
metrics = MetricCollection(
{
"Overall Accuracy": MulticlassAccuracy(
"Overall_Accuracy": MulticlassAccuracy(
num_classes=num_classes,
ignore_index=ignore_index,
average="micro",
),
"Average Accuracy": MulticlassAccuracy(
"Average_Accuracy": MulticlassAccuracy(
num_classes=num_classes,
ignore_index=ignore_index,
average="macro",
),
"Multiclass Accuracy Class": ClasswiseWrapper(
"Multiclass_Accuracy_Class": ClasswiseWrapper(
MulticlassAccuracy(
num_classes=num_classes,
ignore_index=ignore_index,
average=None,
),
labels=class_names,
),
"Multiclass Jaccard Index": MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index),
"Multiclass Jaccard Index Class": ClasswiseWrapper(
"Multiclass_Jaccard_Index": MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index),
"Multiclass_Jaccard_Index_Class": ClasswiseWrapper(
MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index, average=None),
labels=class_names,
),
# why FBetaScore
"Multiclass F1 Score": MulticlassFBetaScore(
"Multiclass_F1_Score": MulticlassFBetaScore(
num_classes=num_classes,
ignore_index=ignore_index,
beta=1.0,
Expand Down
6 changes: 3 additions & 3 deletions terratorch/tasks/multilabel_classification_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ def configure_losses(self) -> None:
def configure_metrics(self) -> None:
metrics = MetricCollection(
{
"Overall Accuracy": MultilabelAccuracy(
"Overall_Accuracy": MultilabelAccuracy(
num_labels=self.hparams["model_args"]["num_classes"], average="micro"
),
"Average Accuracy": MultilabelAccuracy(
"Average_Accuracy": MultilabelAccuracy(
num_labels=self.hparams["model_args"]["num_classes"], average="macro"
),
"Multilabel F1 Score": MultilabelFBetaScore(
"Multilabel_F1_Score": MultilabelFBetaScore(
num_labels=self.hparams["model_args"]["num_classes"], beta=1.0, average="micro"
),
}
Expand Down
12 changes: 6 additions & 6 deletions terratorch/tasks/segmentation_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,13 @@ def configure_metrics(self) -> None:
class_names = self.hparams["class_names"]
metrics = MetricCollection(
{
"Multiclass Accuracy": MulticlassAccuracy(
"Multiclass_Accuracy": MulticlassAccuracy(
num_classes=num_classes,
ignore_index=ignore_index,
multidim_average="global",
average="micro",
),
"Multiclass Accuracy Class": ClasswiseWrapper(
"Multiclass_Accuracy_Class": ClasswiseWrapper(
MulticlassAccuracy(
num_classes=num_classes,
ignore_index=ignore_index,
Expand All @@ -183,18 +183,18 @@ def configure_metrics(self) -> None:
),
labels=class_names,
),
"Multiclass Jaccard Index Micro": MulticlassJaccardIndex(
"Multiclass_Jaccard_Index_Micro": MulticlassJaccardIndex(
num_classes=num_classes, ignore_index=ignore_index, average="micro"
),
"Multiclass Jaccard Index": MulticlassJaccardIndex(
"Multiclass_Jaccard_Index": MulticlassJaccardIndex(
num_classes=num_classes,
ignore_index=ignore_index,
),
"Multiclass Jaccard Index Class": ClasswiseWrapper(
"Multiclass_Jaccard_Index_Class": ClasswiseWrapper(
MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index, average=None),
labels=class_names,
),
"Multiclass F1 Score": MulticlassF1Score(
"Multiclass_F1_Score": MulticlassF1Score(
num_classes=num_classes,
ignore_index=ignore_index,
multidim_average="global",
Expand Down

0 comments on commit d9bf33a

Please sign in to comment.