Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add non-strict mode to AnomalibMetric #2508

Merged
merged 32 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
148f789
examples/notebooks/500_use_cases/501_dobot/501a_training_a_model_with…
djdameln Jan 8, 2025
b4652f7
make post-processor compatible with multi-gpu
djdameln Jan 10, 2025
785679a
remove obsolete repr method, fix adap thresh test
djdameln Jan 10, 2025
9ede7bc
add strict param to anomalibmetric class
djdameln Jan 10, 2025
1feebc0
use clamp instead of min/max
djdameln Jan 14, 2025
6000512
Merge branch 'fix-multi-gpu-pp' into metric-add-strict-param
djdameln Jan 14, 2025
8b76900
_threshold -> _apply_threshold, minor refactor
djdameln Jan 14, 2025
2923d1e
add outer update count
djdameln Jan 14, 2025
b187761
revert to using single minmax metric
djdameln Jan 14, 2025
e8e5be3
remove individual min and max metrics
djdameln Jan 14, 2025
91a3947
fix threshold metric test
djdameln Jan 14, 2025
2dbc9e5
remove usage of minmax metric in ai-vad
djdameln Jan 14, 2025
5ec700d
make minmax metric non-persistent
djdameln Jan 14, 2025
fedd932
fix docstring
djdameln Jan 14, 2025
a447fd9
add create_anomalib_metric to module init
djdameln Jan 14, 2025
4f91157
refactor update, add example to docstring
djdameln Jan 14, 2025
af94c9c
update error messages
djdameln Jan 14, 2025
61c588b
add unit tests for AnomalibMetric class
djdameln Jan 14, 2025
1e9f647
typing
djdameln Jan 14, 2025
09d7505
merge main
djdameln Jan 14, 2025
5745d5f
Merge branch 'main' into metric-add-strict-param
djdameln Jan 17, 2025
faec7eb
improve update_called logic
djdameln Jan 20, 2025
640f05a
set pixel metrics as non-strict
djdameln Jan 20, 2025
0f2eb0e
add task type tests
djdameln Jan 20, 2025
af984ea
update metric unit tests
djdameln Jan 20, 2025
d5104e7
Merge branch 'main' into metric-add-strict-param
djdameln Jan 20, 2025
1661250
update minmax tests after merging main
djdameln Jan 20, 2025
06120f8
license header
djdameln Jan 20, 2025
e8666b2
remove outer update count
djdameln Jan 20, 2025
fe3c117
typing
djdameln Jan 20, 2025
f307638
Revert "typing"
djdameln Jan 20, 2025
aced62c
fix import
djdameln Jan 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/anomalib/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from .aupr import AUPR
from .aupro import AUPRO
from .auroc import AUROC
from .base import AnomalibMetric
from .base import AnomalibMetric, create_anomalib_metric
from .evaluator import Evaluator
from .f1_score import F1Max, F1Score
from .min_max import MinMax
Expand All @@ -60,6 +60,7 @@
"AnomalibMetric",
"AnomalyScoreDistribution",
"BinaryPrecisionRecallCurve",
"create_anomalib_metric",
"Evaluator",
"F1AdaptiveThreshold",
"F1Max",
Expand Down
57 changes: 55 additions & 2 deletions src/anomalib/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,34 @@
>>> from anomalib.metrics import create_anomalib_metric
>>> F1Score = create_anomalib_metric(BinaryF1Score)
>>> f1_score = F1Score(fields=["pred_label", "gt_label"])

Strict mode vs non-strict mode::

>>> F1Score = create_anomalib_metric(BinaryF1Score)
>>>
>>> # create metric in strict mode (default), and non-strict mode
>>> f1_score_strict = F1Score(fields=["pred_label", "gt_label"], strict=True)
>>> f1_score_nonstrict = F1Score(fields=["pred_label", "gt_label"], strict=False)
>>>
>>> # create a batch in which 'pred_label' field is None
>>> batch = ImageBatch(
... image=torch.rand(4, 3, 256, 256),
... gt_label=torch.tensor([0, 0, 1, 1])
... )
>>>
>>> f1_score_strict.update(batch) # ValueError
>>> f1_score_strict.compute() # UserWarning, tensor(0.)
>>>
>>> f1_score_nonstrict.update(batch) # No error
>>> f1_score_nonstrict.compute() # None
djdameln marked this conversation as resolved.
Show resolved Hide resolved
"""

# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Sequence

import torch
from torchmetrics import Metric, MetricCollection

from anomalib.data import Batch
Expand All @@ -67,6 +88,7 @@ class AnomalibMetric:
fields (Sequence[str] | None): Names of fields to extract from batch.
If None, uses class's ``default_fields``. Required if no defaults.
prefix (str): Prefix added to metric name. Defaults to "".
strict (bool): Whether to raise an error if batch is missing fields.
**kwargs: Additional arguments passed to parent metric class.

Raises:
Expand Down Expand Up @@ -97,6 +119,7 @@ def __init__(
self,
fields: Sequence[str] | None = None,
prefix: str = "",
strict: bool = True,
**kwargs,
) -> None:
fields = fields or getattr(self, "default_fields", None)
Expand All @@ -109,6 +132,7 @@ def __init__(
raise ValueError(msg)
self.fields = fields
self.name = prefix + self.__class__.__name__
self.strict = strict
super().__init__(**kwargs)

def __init_subclass__(cls, **kwargs) -> None:
Expand All @@ -132,11 +156,40 @@ def update(self, batch: Batch, *args, **kwargs) -> None:
"""
for key in self.fields:
if getattr(batch, key, None) is None:
msg = f"Batch object is missing required field: {key}"
# We cannot update the metric if the batch is missing required fields,
# so we need to decrement the update count of the super class.
self._update_count -= 1 # type: ignore[attr-defined]
if not self.strict:
# If not in strict mode, skip updating the metric but don't raise an error
return
# otherwise, raise an error
if not hasattr(batch, key):
msg = (
f"Cannot update metric of type {type(self)}. Passed dataclass instance "
f"is missing required field: {key}"
)
else:
msg = (
f"Cannot update metric of type {type(self)}. Passed dataclass instance "
f"does not have a value for field with name {key}."
)
raise ValueError(msg)

values = [getattr(batch, key) for key in self.fields]
super().update(*values, *args, **kwargs) # type: ignore[misc]

def compute(self) -> torch.Tensor:
"""Compute the metric value.

If the metric has not been updated, and metric is not in strict mode, return None.

Returns:
torch.Tensor: Computed metric value or None.
"""
if self._update_count == 0 and not self.strict: # type: ignore[attr-defined]
return None
return super().compute() # type: ignore[misc]
Comment on lines +189 to +191
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate the types are ignored? compute always expected to return tensor?

Copy link
Contributor Author

@djdameln djdameln Jan 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without the ignore, mypy raises the following error:

src/anomalib/metrics/base.py:191: error: "compute" undefined in superclass  [misc]

Mypy does not recognize the parent class, because this class can be added to any Metric class, sort of like a mixin. We do assert that the super class is a Metric, in the __init_subclass__ method above, so it's safe to ignore this check here.



def create_anomalib_metric(metric_cls: type) -> type:
"""Create an Anomalib version of a torchmetrics metric.
Expand Down
18 changes: 12 additions & 6 deletions src/anomalib/metrics/min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
import torch
from torchmetrics import Metric

from anomalib.metrics import AnomalibMetric

class MinMax(Metric):

class _MinMax(Metric):
"""Track minimum and maximum values across batches.

This metric maintains running minimum and maximum values across all batches
Expand All @@ -48,10 +50,10 @@ class MinMax(Metric):
max (torch.Tensor): Running maximum value seen across all batches

Example:
>>> from anomalib.metrics import MinMax
>>> from anomalib.metrics.min_max import _MinMax
>>> import torch
>>> # Create metric
>>> minmax = MinMax()
>>> minmax = _MinMax()
>>> # Update with batches
>>> batch1 = torch.tensor([0.1, 0.2, 0.3])
>>> batch2 = torch.tensor([0.2, 0.4, 0.5])
Expand All @@ -67,8 +69,8 @@ class MinMax(Metric):

def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.add_state("min", torch.tensor(float("inf")), persistent=True, dist_reduce_fx="min")
self.add_state("max", torch.tensor(float("-inf")), persistent=True, dist_reduce_fx="max")
self.add_state("min", torch.tensor(float("inf")), dist_reduce_fx="min")
self.add_state("max", torch.tensor(float("-inf")), dist_reduce_fx="max")

self.min = torch.tensor(float("inf"))
self.max = torch.tensor(float("-inf"))
Expand All @@ -94,4 +96,8 @@ def compute(self) -> tuple[torch.Tensor, torch.Tensor]:
tuple[torch.Tensor, torch.Tensor]: Tuple containing the (min, max)
values tracked across all batches
"""
return self.min, self.max
return torch.stack([self.min, self.max])


class MinMax(AnomalibMetric, _MinMax): # type: ignore[misc]
"""Wrapper to add AnomalibMetric functionality to MinMax metric."""
7 changes: 6 additions & 1 deletion src/anomalib/metrics/threshold/f1_adaptive_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@

import torch

from anomalib.metrics import AnomalibMetric
from anomalib.metrics.precision_recall_curve import BinaryPrecisionRecallCurve

from .base import Threshold

logger = logging.getLogger(__name__)


class F1AdaptiveThreshold(BinaryPrecisionRecallCurve, Threshold):
class _F1AdaptiveThreshold(BinaryPrecisionRecallCurve, Threshold):
"""Adaptive threshold that maximizes F1 score.

This class computes and stores the optimal threshold for converting anomaly
Expand Down Expand Up @@ -95,3 +96,7 @@ def compute(self) -> torch.Tensor:
# account for special case where recall is 1.0 even for the highest threshold.
# In this case 'thresholds' will be scalar.
return thresholds if thresholds.dim() == 0 else thresholds[torch.argmax(f1_score)]


class F1AdaptiveThreshold(AnomalibMetric, _F1AdaptiveThreshold): # type: ignore[misc]
"""Wrapper to add AnomalibMetric functionality to F1AdaptiveThreshold metric."""
4 changes: 2 additions & 2 deletions src/anomalib/models/components/base/anomalib_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,8 @@ def configure_evaluator() -> Evaluator:
"""
image_auroc = AUROC(fields=["pred_score", "gt_label"], prefix="image_")
image_f1score = F1Score(fields=["pred_label", "gt_label"], prefix="image_")
pixel_auroc = AUROC(fields=["anomaly_map", "gt_mask"], prefix="pixel_")
pixel_f1score = F1Score(fields=["pred_mask", "gt_mask"], prefix="pixel_")
pixel_auroc = AUROC(fields=["anomaly_map", "gt_mask"], prefix="pixel_", strict=False)
pixel_f1score = F1Score(fields=["pred_mask", "gt_mask"], prefix="pixel_", strict=False)
test_metrics = [image_auroc, image_f1score, pixel_auroc, pixel_f1score]
return Evaluator(test_metrics=test_metrics)

Expand Down
32 changes: 17 additions & 15 deletions src/anomalib/models/video/ai_vad/density.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import torch
from torch import Tensor, nn

from anomalib.metrics.min_max import MinMax
from anomalib.models.components.base import DynamicBufferMixin
from anomalib.models.components.cluster.gmm import GaussianMixture

Expand Down Expand Up @@ -296,10 +295,14 @@ def __init__(self, n_neighbors: int) -> None:
self.n_neighbors = n_neighbors
self.feature_collection: dict[str, list[torch.Tensor]] = {}
self.group_index: dict[str, int] = {}
self.normalization_statistics = MinMax()

self.register_buffer("memory_bank", Tensor())
self.memory_bank: torch.Tensor = Tensor()
self.register_buffer("min", torch.tensor(torch.inf))
self.register_buffer("max", torch.tensor(-torch.inf))

self.memory_bank: torch.Tensor
self.min: torch.Tensor
self.max: torch.Tensor

def update(self, features: torch.Tensor, group: str | None = None) -> None:
"""Update the internal feature bank while keeping track of the group.
Expand Down Expand Up @@ -428,9 +431,8 @@ def _compute_normalization_statistics(self, grouped_features: dict[str, Tensor])
"""
for group, features in grouped_features.items():
distances = self.predict(features, group, normalize=False)
self.normalization_statistics.update(distances)

self.normalization_statistics.compute()
self.min = torch.min(self.min, torch.min(distances))
self.max = torch.max(self.min, torch.max(distances))

def _normalize(self, distances: torch.Tensor) -> torch.Tensor:
"""Normalize distance predictions.
Expand All @@ -441,9 +443,7 @@ def _normalize(self, distances: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: Normalized distances.
"""
return (distances - self.normalization_statistics.min) / (
self.normalization_statistics.max - self.normalization_statistics.min
)
return (distances - self.min) / (self.max - self.min)


class GMMEstimator(BaseDensityEstimator):
Expand Down Expand Up @@ -474,7 +474,11 @@ def __init__(self, n_components: int = 2) -> None:
self.gmm = GaussianMixture(n_components=n_components)
self.memory_bank: list[torch.Tensor] | torch.Tensor = []

self.normalization_statistics = MinMax()
self.register_buffer("min", torch.tensor(torch.inf))
self.register_buffer("max", torch.tensor(-torch.inf))

self.min: torch.Tensor
self.max: torch.Tensor

def update(self, features: torch.Tensor, group: str | None = None) -> None:
"""Update the feature bank with new features.
Expand Down Expand Up @@ -528,8 +532,8 @@ def _compute_normalization_statistics(self) -> None:
statistics used for score normalization during inference.
"""
training_scores = self.predict(self.memory_bank, normalize=False)
self.normalization_statistics.update(training_scores)
self.normalization_statistics.compute()
self.min = torch.min(self.min, torch.min(training_scores))
self.max = torch.max(self.min, torch.max(training_scores))

def _normalize(self, density: torch.Tensor) -> torch.Tensor:
"""Normalize anomaly scores using min-max statistics.
Expand All @@ -540,6 +544,4 @@ def _normalize(self, density: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: Normalized anomaly scores of shape ``(N,)``.
"""
return (density - self.normalization_statistics.min) / (
self.normalization_statistics.max - self.normalization_statistics.min
)
return (density - self.min) / (self.max - self.min)
Loading
Loading