From 148f789d58c4fd23621f0d2eba8cca791f3585f9 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Wed, 8 Jan 2025 15:55:22 +0100 Subject: [PATCH 01/28] examples/notebooks/500_use_cases/501_dobot/501a_training_a_model_with_cubes_from_a_robotic_arm.ipynb: convert to Git LFS From b4652f7bebf2daac5f494ce5c17779f33e10ca91 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Fri, 10 Jan 2025 19:35:40 +0100 Subject: [PATCH 02/28] make post-processor compatible with multi-gpu --- src/anomalib/metrics/min_max.py | 6 +- .../threshold/f1_adaptive_threshold.py | 25 +---- src/anomalib/post_processing/one_class.py | 91 ++++++------------- 3 files changed, 34 insertions(+), 88 deletions(-) diff --git a/src/anomalib/metrics/min_max.py b/src/anomalib/metrics/min_max.py index c071597bf5..ce04484603 100644 --- a/src/anomalib/metrics/min_max.py +++ b/src/anomalib/metrics/min_max.py @@ -67,8 +67,8 @@ class MinMax(Metric): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - self.add_state("min", torch.tensor(float("inf")), persistent=True) - self.add_state("max", torch.tensor(float("-inf")), persistent=True) + self.add_state("min", torch.tensor(float("inf")), persistent=True, dist_reduce_fx="min") + self.add_state("max", torch.tensor(float("-inf")), persistent=True, dist_reduce_fx="max") self.min = torch.tensor(float("inf")) self.max = torch.tensor(float("-inf")) @@ -84,8 +84,8 @@ def update(self, predictions: torch.Tensor, *args, **kwargs) -> None: """ del args, kwargs # These variables are not used. - self.max = torch.max(self.max, torch.max(predictions)) self.min = torch.min(self.min, torch.min(predictions)) + self.max = torch.max(self.min, torch.max(predictions)) def compute(self) -> tuple[torch.Tensor, torch.Tensor]: """Compute final minimum and maximum values. diff --git a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py index 1d1461ddaf..aa9122fc82 100644 --- a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py +++ b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py @@ -46,14 +46,6 @@ class F1AdaptiveThreshold(BinaryPrecisionRecallCurve, Threshold): This class computes and stores the optimal threshold for converting anomaly scores to binary predictions by maximizing the F1 score on validation data. - Args: - default_value: Initial threshold value used before computation. - Defaults to ``0.5``. - **kwargs: Additional arguments passed to parent classes. - - Attributes: - value (torch.Tensor): Current threshold value. - Example: >>> from anomalib.metrics import F1AdaptiveThreshold >>> import torch @@ -68,12 +60,6 @@ class F1AdaptiveThreshold(BinaryPrecisionRecallCurve, Threshold): Optimal threshold: 0.5000 """ - def __init__(self, default_value: float = 0.5, **kwargs) -> None: - super().__init__(**kwargs) - - self.add_state("value", default=torch.tensor(default_value), persistent=True) - self.value = torch.tensor(default_value) - def compute(self) -> torch.Tensor: """Compute optimal threshold by maximizing F1 score. @@ -105,13 +91,10 @@ def compute(self) -> torch.Tensor: precision, recall, thresholds = super().compute() f1_score = (2 * precision * recall) / (precision + recall + 1e-10) - if thresholds.dim() == 0: - # special case where recall is 1.0 even for the highest threshold. - # In this case 'thresholds' will be scalar. - self.value = thresholds - else: - self.value = thresholds[torch.argmax(f1_score)] - return self.value + + # account for special case where recall is 1.0 even for the highest threshold. + # In this case 'thresholds' will be scalar. + return thresholds if thresholds.dim() == 0 else thresholds[torch.argmax(f1_score)] def __repr__(self) -> str: """Return string representation including current threshold value. diff --git a/src/anomalib/post_processing/one_class.py b/src/anomalib/post_processing/one_class.py index ca89ba4df5..4dd0d9ce47 100644 --- a/src/anomalib/post_processing/one_class.py +++ b/src/anomalib/post_processing/one_class.py @@ -59,13 +59,30 @@ def __init__( ) -> None: super().__init__(**kwargs) + # configure sensitivity values + self.image_sensitivity = image_sensitivity + self.pixel_sensitivity = pixel_sensitivity + + # initialize threshold and normalization metrics self._image_threshold = F1AdaptiveThreshold() self._pixel_threshold = F1AdaptiveThreshold() self._image_normalization_stats = MinMax() self._pixel_normalization_stats = MinMax() - self.image_sensitivity = image_sensitivity - self.pixel_sensitivity = pixel_sensitivity + # register buffers to persist threshold and normalization values + self.register_buffer("image_threshold", torch.tensor(0)) + self.register_buffer("pixel_threshold", torch.tensor(0)) + self.register_buffer("image_min", torch.tensor(0)) + self.register_buffer("image_max", torch.tensor(1)) + self.register_buffer("pixel_min", torch.tensor(0)) + self.register_buffer("pixel_max", torch.tensor(1)) + + self.image_threshold: torch.Tensor + self.pixel_threshold: torch.Tensor + self.image_min: torch.Tensor + self.image_max: torch.Tensor + self.pixel_min: torch.Tensor + self.pixel_max: torch.Tensor def on_validation_batch_end( self, @@ -103,13 +120,13 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) """ del trainer, pl_module if self._image_threshold.update_called: - self._image_threshold.compute() + self.image_threshold = self._image_threshold.compute() if self._pixel_threshold.update_called: - self._pixel_threshold.compute() + self.pixel_threshold = self._pixel_threshold.compute() if self._image_normalization_stats.update_called: - self._image_normalization_stats.compute() + self.image_min, self.image_max = self._image_normalization_stats.compute() if self._pixel_normalization_stats.update_called: - self._pixel_normalization_stats.compute() + self.pixel_min, self.pixel_max = self._pixel_normalization_stats.compute() def on_test_batch_end( self, @@ -168,8 +185,8 @@ def forward(self, predictions: InferenceBatch) -> InferenceBatch: msg = "At least one of pred_score or anomaly_map must be provided." raise ValueError(msg) pred_score = predictions.pred_score or torch.amax(predictions.anomaly_map, dim=(-2, -1)) - pred_score = self._normalize(pred_score, self.image_min, self.image_max, self.raw_image_threshold) - anomaly_map = self._normalize(predictions.anomaly_map, self.pixel_min, self.pixel_max, self.raw_pixel_threshold) + pred_score = self._normalize(pred_score, self.image_min, self.image_max, self.image_threshold) + anomaly_map = self._normalize(predictions.anomaly_map, self.pixel_min, self.pixel_max, self.pixel_threshold) pred_label = self._threshold(pred_score, self.normalized_image_threshold) pred_mask = self._threshold(anomaly_map, self.normalized_pixel_threshold) return InferenceBatch( @@ -216,9 +233,9 @@ def normalize_batch(self, batch: Batch) -> None: batch (Batch): Batch containing model predictions. """ # normalize pixel-level predictions - batch.anomaly_map = self._normalize(batch.anomaly_map, self.pixel_min, self.pixel_max, self.raw_pixel_threshold) + batch.anomaly_map = self._normalize(batch.anomaly_map, self.pixel_min, self.pixel_max, self.pixel_threshold) # normalize image-level predictions - batch.pred_score = self._normalize(batch.pred_score, self.image_min, self.image_max, self.raw_image_threshold) + batch.pred_score = self._normalize(batch.pred_score, self.image_min, self.image_max, self.image_threshold) @staticmethod def _threshold(preds: torch.Tensor | None, threshold: float) -> torch.Tensor | None: @@ -259,24 +276,6 @@ def _normalize( preds = torch.minimum(preds, torch.tensor(1)) return torch.maximum(preds, torch.tensor(0)) - @property - def raw_image_threshold(self) -> float: - """Get the raw image-level threshold. - - Returns: - float: Raw image-level threshold value. - """ - return self._image_threshold.value - - @property - def raw_pixel_threshold(self) -> float: - """Get the raw pixel-level threshold. - - Returns: - float: Raw pixel-level threshold value. - """ - return self._pixel_threshold.value - @property def normalized_image_threshold(self) -> float: """Get the normalized image-level threshold. @@ -298,39 +297,3 @@ def normalized_pixel_threshold(self) -> float: if self.pixel_sensitivity is not None: return 1 - self.pixel_sensitivity return 0.5 - - @property - def image_min(self) -> float: - """Get the minimum value for image-level normalization. - - Returns: - float: Minimum image-level value. - """ - return self._image_normalization_stats.min - - @property - def image_max(self) -> float: - """Get the maximum value for image-level normalization. - - Returns: - float: Maximum image-level value. - """ - return self._image_normalization_stats.max - - @property - def pixel_min(self) -> float: - """Get the minimum value for pixel-level normalization. - - Returns: - float: Minimum pixel-level value. - """ - return self._pixel_normalization_stats.min - - @property - def pixel_max(self) -> float: - """Get the maximum value for pixel-level normalization. - - Returns: - float: Maximum pixel-level value. - """ - return self._pixel_normalization_stats.max From 785679a29e66db05db92e2f0f2ed8cac6a0fcfbf Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Fri, 10 Jan 2025 19:41:36 +0100 Subject: [PATCH 03/28] remove obsolete repr method, fix adap thresh test --- src/anomalib/metrics/threshold/f1_adaptive_threshold.py | 8 -------- tests/unit/metrics/test_adaptive_threshold.py | 2 +- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py index aa9122fc82..80e6445d9f 100644 --- a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py +++ b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py @@ -95,11 +95,3 @@ def compute(self) -> torch.Tensor: # account for special case where recall is 1.0 even for the highest threshold. # In this case 'thresholds' will be scalar. return thresholds if thresholds.dim() == 0 else thresholds[torch.argmax(f1_score)] - - def __repr__(self) -> str: - """Return string representation including current threshold value. - - Returns: - str: String in format "ClassName(value=X.XX)" - """ - return f"{super().__repr__()} (value={self.value:.2f})" diff --git a/tests/unit/metrics/test_adaptive_threshold.py b/tests/unit/metrics/test_adaptive_threshold.py index 7f7dbf1796..e76a7effa2 100644 --- a/tests/unit/metrics/test_adaptive_threshold.py +++ b/tests/unit/metrics/test_adaptive_threshold.py @@ -18,7 +18,7 @@ ) def test_adaptive_threshold(labels: torch.Tensor, preds: torch.Tensor, target_threshold: int | float) -> None: """Test if the adaptive threshold computation returns the desired value.""" - adaptive_threshold = F1AdaptiveThreshold(default_value=0.5) + adaptive_threshold = F1AdaptiveThreshold() adaptive_threshold.update(preds, labels) threshold_value = adaptive_threshold.compute() From 9ede7bc90b45fe6b4875443ec31020b85313ce93 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Fri, 10 Jan 2025 21:05:41 +0100 Subject: [PATCH 04/28] add strict param to anomalibmetric class --- src/anomalib/metrics/base.py | 24 ++- src/anomalib/metrics/min_max.py | 140 +++++++++++++++++- .../threshold/f1_adaptive_threshold.py | 7 +- src/anomalib/post_processing/one_class.py | 41 +++-- 4 files changed, 188 insertions(+), 24 deletions(-) diff --git a/src/anomalib/metrics/base.py b/src/anomalib/metrics/base.py index 040fd00ece..257fc2202d 100644 --- a/src/anomalib/metrics/base.py +++ b/src/anomalib/metrics/base.py @@ -47,6 +47,7 @@ from collections.abc import Sequence +import torch from torchmetrics import Metric, MetricCollection from anomalib.data import Batch @@ -67,6 +68,7 @@ class AnomalibMetric: fields (Sequence[str] | None): Names of fields to extract from batch. If None, uses class's ``default_fields``. Required if no defaults. prefix (str): Prefix added to metric name. Defaults to "". + strict (bool): Whether to raise an error if batch is missing fields. **kwargs: Additional arguments passed to parent metric class. Raises: @@ -97,6 +99,7 @@ def __init__( self, fields: Sequence[str] | None = None, prefix: str = "", + strict: bool = True, **kwargs, ) -> None: fields = fields or getattr(self, "default_fields", None) @@ -109,6 +112,7 @@ def __init__( raise ValueError(msg) self.fields = fields self.name = prefix + self.__class__.__name__ + self.strict = strict super().__init__(**kwargs) def __init_subclass__(cls, **kwargs) -> None: @@ -131,12 +135,30 @@ def update(self, batch: Batch, *args, **kwargs) -> None: ValueError: If batch is missing any required fields. """ for key in self.fields: - if getattr(batch, key, None) is None: + if not hasattr(batch, key): msg = f"Batch object is missing required field: {key}" raise ValueError(msg) + if getattr(batch, key, None) is None: + if self.strict: + msg = f"Field {key} in batch object is None" + raise ValueError(msg) + self._update_count -= 1 # type: ignore[attr-defined] + return values = [getattr(batch, key) for key in self.fields] super().update(*values, *args, **kwargs) # type: ignore[misc] + def compute(self) -> torch.Tensor: + """Compute the metric value. + + If the metric has not been updated, and metric is not in strict mode, return None. + + Returns: + torch.Tensor: Computed metric value or None. + """ + if self._update_count == 0 and not self.strict: # type: ignore[attr-defined] + return None + return super().compute() # type: ignore[misc] + def create_anomalib_metric(metric_cls: type) -> type: """Create an Anomalib version of a torchmetrics metric. diff --git a/src/anomalib/metrics/min_max.py b/src/anomalib/metrics/min_max.py index ce04484603..3bfa154566 100644 --- a/src/anomalib/metrics/min_max.py +++ b/src/anomalib/metrics/min_max.py @@ -30,8 +30,10 @@ import torch from torchmetrics import Metric +from anomalib.metrics import AnomalibMetric -class MinMax(Metric): + +class _MinMax(Metric): """Track minimum and maximum values across batches. This metric maintains running minimum and maximum values across all batches @@ -95,3 +97,139 @@ def compute(self) -> tuple[torch.Tensor, torch.Tensor]: values tracked across all batches """ return self.min, self.max + + +class _Min(Metric): + """Track minimum value across batches. + + This metric maintains running minimum values across all batches + it processes. It is useful for tasks like normalization or monitoring the + range of values during training. + + Args: + full_state_update (bool, optional): Whether to update the internal state + with each new batch. Defaults to ``True``. + kwargs: Additional keyword arguments passed to the parent class. + + Attributes: + min (torch.Tensor): Running minimum value seen across all batches + + Example: + >>> from anomalib.metrics import MinMax + >>> import torch + >>> # Create metric + >>> minmax = Min() + >>> # Update with batches + >>> batch1 = torch.tensor([0.1, 0.2, 0.3]) + >>> batch2 = torch.tensor([0.2, 0.4, 0.5]) + >>> minmax.update(batch1) + >>> minmax.update(batch2) + >>> # Get final min/max values + >>> min_val, max_val = minmax.compute() + >>> min_val, max_val + (tensor(0.1000), tensor(0.5000)) + """ + + full_state_update: bool = True + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.add_state("min", torch.tensor(float("inf")), persistent=True, dist_reduce_fx="min") + + self.min = torch.tensor(float("inf")) + + def update(self, predictions: torch.Tensor, *args, **kwargs) -> None: + """Update running min and max values with new predictions. + + Args: + predictions (torch.Tensor): New tensor of values to include in min/max + tracking + *args: Additional positional arguments (unused) + **kwargs: Additional keyword arguments (unused) + """ + del args, kwargs # These variables are not used. + + self.min = torch.min(self.min, torch.min(predictions)) + + def compute(self) -> tuple[torch.Tensor, torch.Tensor]: + """Compute final minimum and maximum values. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple containing the (min, max) + values tracked across all batches + """ + return self.min + + +class _Max(Metric): + """Track maximum value across batches. + + This metric maintains running maximum values across all batches + it processes. It is useful for tasks like normalization or monitoring the + range of values during training. + + Args: + full_state_update (bool, optional): Whether to update the internal state + with each new batch. Defaults to ``True``. + kwargs: Additional keyword arguments passed to the parent class. + + Attributes: + max (torch.Tensor): Running maximum value seen across all batches + + Example: + >>> from anomalib.metrics import MinMax + >>> import torch + >>> # Create metric + >>> minmax = Min() + >>> # Update with batches + >>> batch1 = torch.tensor([0.1, 0.2, 0.3]) + >>> batch2 = torch.tensor([0.2, 0.4, 0.5]) + >>> minmax.update(batch1) + >>> minmax.update(batch2) + >>> # Get final min/max values + >>> min_val, max_val = minmax.compute() + >>> min_val, max_val + (tensor(0.1000), tensor(0.5000)) + """ + + full_state_update: bool = True + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.add_state("max", torch.tensor(float("-inf")), persistent=True, dist_reduce_fx="max") + + self.max = torch.tensor(float("-inf")) + + def update(self, predictions: torch.Tensor, *args, **kwargs) -> None: + """Update running min and max values with new predictions. + + Args: + predictions (torch.Tensor): New tensor of values to include in min/max + tracking + *args: Additional positional arguments (unused) + **kwargs: Additional keyword arguments (unused) + """ + del args, kwargs # These variables are not used. + + self.max = torch.max(self.max, torch.min(predictions)) + + def compute(self) -> tuple[torch.Tensor, torch.Tensor]: + """Compute final minimum and maximum values. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple containing the (min, max) + values tracked across all batches + """ + return self.max + + +class MinMax(AnomalibMetric, _MinMax): # type: ignore[misc] + """Wrapper to add AnomalibMetric functionality to MinMax metric.""" + + +class Min(AnomalibMetric, _Min): # type: ignore[misc] + """Wrapper to add AnomalibMetric functionality to Min metric.""" + + +class Max(AnomalibMetric, _Max): # type: ignore[misc] + """Wrapper to add AnomalibMetric functionality to Max metric.""" diff --git a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py index 80e6445d9f..1cc0d6fa7b 100644 --- a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py +++ b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py @@ -33,6 +33,7 @@ import torch +from anomalib.metrics import AnomalibMetric from anomalib.metrics.precision_recall_curve import BinaryPrecisionRecallCurve from .base import Threshold @@ -40,7 +41,7 @@ logger = logging.getLogger(__name__) -class F1AdaptiveThreshold(BinaryPrecisionRecallCurve, Threshold): +class _F1AdaptiveThreshold(BinaryPrecisionRecallCurve, Threshold): """Adaptive threshold that maximizes F1 score. This class computes and stores the optimal threshold for converting anomaly @@ -95,3 +96,7 @@ def compute(self) -> torch.Tensor: # account for special case where recall is 1.0 even for the highest threshold. # In this case 'thresholds' will be scalar. return thresholds if thresholds.dim() == 0 else thresholds[torch.argmax(f1_score)] + + +class F1AdaptiveThreshold(AnomalibMetric, _F1AdaptiveThreshold): # type: ignore[misc] + """Wrapper to add AnomalibMetric functionality to F1AdaptiveThreshold metric.""" diff --git a/src/anomalib/post_processing/one_class.py b/src/anomalib/post_processing/one_class.py index 4dd0d9ce47..0f2c2a932a 100644 --- a/src/anomalib/post_processing/one_class.py +++ b/src/anomalib/post_processing/one_class.py @@ -22,7 +22,8 @@ from lightning import LightningModule, Trainer from anomalib.data import Batch, InferenceBatch -from anomalib.metrics import F1AdaptiveThreshold, MinMax +from anomalib.metrics import F1AdaptiveThreshold +from anomalib.metrics.min_max import Max, Min from .base import PostProcessor @@ -64,10 +65,12 @@ def __init__( self.pixel_sensitivity = pixel_sensitivity # initialize threshold and normalization metrics - self._image_threshold = F1AdaptiveThreshold() - self._pixel_threshold = F1AdaptiveThreshold() - self._image_normalization_stats = MinMax() - self._pixel_normalization_stats = MinMax() + self._image_threshold = F1AdaptiveThreshold(fields=["pred_score", "gt_label"], strict=False) + self._pixel_threshold = F1AdaptiveThreshold(fields=["anomaly_map", "gt_mask"], strict=False) + self._image_min = Min(fields=["pred_score"], strict=False) + self._image_max = Max(fields=["pred_score"], strict=False) + self._pixel_min = Min(fields=["anomaly_map"], strict=False) + self._pixel_max = Max(fields=["anomaly_map"], strict=False) # register buffers to persist threshold and normalization values self.register_buffer("image_threshold", torch.tensor(0)) @@ -102,14 +105,12 @@ def on_validation_batch_end( **kwargs: Arbitrary keyword arguments. """ del trainer, pl_module, args, kwargs # Unused arguments. - if outputs.pred_score is not None: - self._image_threshold.update(outputs.pred_score, outputs.gt_label) - if outputs.anomaly_map is not None: - self._pixel_threshold.update(outputs.anomaly_map, outputs.gt_mask) - if outputs.pred_score is not None: - self._image_normalization_stats.update(outputs.pred_score) - if outputs.anomaly_map is not None: - self._pixel_normalization_stats.update(outputs.anomaly_map) + self._image_threshold.update(outputs) + self._pixel_threshold.update(outputs) + self._image_min.update(outputs) + self._image_max.update(outputs) + self._pixel_min.update(outputs) + self._pixel_max.update(outputs) def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Compute final threshold and normalization values. @@ -119,14 +120,12 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) pl_module (LightningModule): PyTorch Lightning module instance. """ del trainer, pl_module - if self._image_threshold.update_called: - self.image_threshold = self._image_threshold.compute() - if self._pixel_threshold.update_called: - self.pixel_threshold = self._pixel_threshold.compute() - if self._image_normalization_stats.update_called: - self.image_min, self.image_max = self._image_normalization_stats.compute() - if self._pixel_normalization_stats.update_called: - self.pixel_min, self.pixel_max = self._pixel_normalization_stats.compute() + self.image_threshold = self._image_threshold.compute() + self.pixel_threshold = self._pixel_threshold.compute() + self.image_min = self._image_min.compute() + self.image_max = self._image_max.compute() + self.pixel_min = self._pixel_min.compute() + self.pixel_max = self._pixel_max.compute() def on_test_batch_end( self, From 1feebc0dc6124cac5634aedacc4752fe0cebb057 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 14 Jan 2025 10:30:11 +0100 Subject: [PATCH 05/28] use clamp instead of min/max --- src/anomalib/post_processing/one_class.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/anomalib/post_processing/one_class.py b/src/anomalib/post_processing/one_class.py index 4dd0d9ce47..b7e135f1e8 100644 --- a/src/anomalib/post_processing/one_class.py +++ b/src/anomalib/post_processing/one_class.py @@ -273,8 +273,7 @@ def _normalize( if preds is None: return None preds = ((preds - threshold) / (norm_max - norm_min)) + 0.5 - preds = torch.minimum(preds, torch.tensor(1)) - return torch.maximum(preds, torch.tensor(0)) + return preds.clamp(min=0, max=1) @property def normalized_image_threshold(self) -> float: From 8b76900868edcf6c30671b427d629816f116a692 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 14 Jan 2025 14:21:31 +0100 Subject: [PATCH 06/28] _threshold -> _apply_threshold, minor refactor --- src/anomalib/post_processing/one_class.py | 24 +++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/anomalib/post_processing/one_class.py b/src/anomalib/post_processing/one_class.py index b173c14ef8..eb6202c8f7 100644 --- a/src/anomalib/post_processing/one_class.py +++ b/src/anomalib/post_processing/one_class.py @@ -186,8 +186,8 @@ def forward(self, predictions: InferenceBatch) -> InferenceBatch: pred_score = predictions.pred_score or torch.amax(predictions.anomaly_map, dim=(-2, -1)) pred_score = self._normalize(pred_score, self.image_min, self.image_max, self.image_threshold) anomaly_map = self._normalize(predictions.anomaly_map, self.pixel_min, self.pixel_max, self.pixel_threshold) - pred_label = self._threshold(pred_score, self.normalized_image_threshold) - pred_mask = self._threshold(anomaly_map, self.normalized_pixel_threshold) + pred_label = self._apply_threshold(pred_score, self.normalized_image_threshold) + pred_mask = self._apply_threshold(anomaly_map, self.normalized_pixel_threshold) return InferenceBatch( pred_label=pred_label, pred_score=pred_score, @@ -217,12 +217,12 @@ def threshold_batch(self, batch: Batch) -> None: batch.pred_label = ( batch.pred_label if batch.pred_label is not None - else self._threshold(batch.pred_score, self.normalized_image_threshold) + else self._apply_threshold(batch.pred_score, self.normalized_image_threshold) ) batch.pred_mask = ( batch.pred_mask if batch.pred_mask is not None - else self._threshold(batch.anomaly_map, self.normalized_pixel_threshold) + else self._apply_threshold(batch.anomaly_map, self.normalized_pixel_threshold) ) def normalize_batch(self, batch: Batch) -> None: @@ -237,7 +237,7 @@ def normalize_batch(self, batch: Batch) -> None: batch.pred_score = self._normalize(batch.pred_score, self.image_min, self.image_max, self.image_threshold) @staticmethod - def _threshold(preds: torch.Tensor | None, threshold: float) -> torch.Tensor | None: + def _apply_threshold(preds: torch.Tensor | None, threshold: float) -> torch.Tensor | None: """Apply thresholding to a single tensor. Args: @@ -247,16 +247,16 @@ def _threshold(preds: torch.Tensor | None, threshold: float) -> torch.Tensor | N Returns: torch.Tensor | None: Thresholded predictions or None if input is None. """ - if preds is None: - return None + if preds is None or threshold is None: + return preds return preds > threshold @staticmethod def _normalize( preds: torch.Tensor | None, - norm_min: float, - norm_max: float, - threshold: float, + norm_min: float | None, + norm_max: float | None, + threshold: float | None, ) -> torch.Tensor | None: """Normalize a tensor using min, max, and threshold values. @@ -269,8 +269,8 @@ def _normalize( Returns: torch.Tensor | None: Normalized predictions or None if input is None. """ - if preds is None: - return None + if preds is None or norm_min is None or norm_max is None or threshold is None: + return preds preds = ((preds - threshold) / (norm_max - norm_min)) + 0.5 return preds.clamp(min=0, max=1) From 2923d1e132c9f9587cc92b40b803611ea4bca24d Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 14 Jan 2025 14:24:13 +0100 Subject: [PATCH 07/28] add outer update count --- src/anomalib/metrics/base.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/anomalib/metrics/base.py b/src/anomalib/metrics/base.py index 257fc2202d..978f9be178 100644 --- a/src/anomalib/metrics/base.py +++ b/src/anomalib/metrics/base.py @@ -113,6 +113,7 @@ def __init__( self.fields = fields self.name = prefix + self.__class__.__name__ self.strict = strict + self.__update_count = 0 # keeps track of the update calls of the wrapper class super().__init__(**kwargs) def __init_subclass__(cls, **kwargs) -> None: @@ -134,14 +135,17 @@ def update(self, batch: Batch, *args, **kwargs) -> None: Raises: ValueError: If batch is missing any required fields. """ + self.__update_count += 1 for key in self.fields: if not hasattr(batch, key): msg = f"Batch object is missing required field: {key}" raise ValueError(msg) if getattr(batch, key, None) is None: if self.strict: - msg = f"Field {key} in batch object is None" + msg = f"Cannot update metric of type {type(self)}. Field {key} in batch object is None" raise ValueError(msg) + # we need to decrement the update count of the super class + # if we are not actually updating the metric states. self._update_count -= 1 # type: ignore[attr-defined] return values = [getattr(batch, key) for key in self.fields] @@ -159,6 +163,11 @@ def compute(self) -> torch.Tensor: return None return super().compute() # type: ignore[misc] + @property + def update_called(self) -> bool: + """Check if the update method has been called.""" + return self.__update_count > 0 + def create_anomalib_metric(metric_cls: type) -> type: """Create an Anomalib version of a torchmetrics metric. From b187761fffb5071ecf09e461e4fb4f300c935a0f Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 14 Jan 2025 16:52:54 +0100 Subject: [PATCH 08/28] revert to using single minmax metric --- src/anomalib/metrics/min_max.py | 2 +- src/anomalib/post_processing/one_class.py | 21 +++++++-------------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/src/anomalib/metrics/min_max.py b/src/anomalib/metrics/min_max.py index 3bfa154566..94f1a2dfb0 100644 --- a/src/anomalib/metrics/min_max.py +++ b/src/anomalib/metrics/min_max.py @@ -96,7 +96,7 @@ def compute(self) -> tuple[torch.Tensor, torch.Tensor]: tuple[torch.Tensor, torch.Tensor]: Tuple containing the (min, max) values tracked across all batches """ - return self.min, self.max + return torch.stack([self.min, self.max]) class _Min(Metric): diff --git a/src/anomalib/post_processing/one_class.py b/src/anomalib/post_processing/one_class.py index eb6202c8f7..e9098aa37a 100644 --- a/src/anomalib/post_processing/one_class.py +++ b/src/anomalib/post_processing/one_class.py @@ -22,8 +22,7 @@ from lightning import LightningModule, Trainer from anomalib.data import Batch, InferenceBatch -from anomalib.metrics import F1AdaptiveThreshold -from anomalib.metrics.min_max import Max, Min +from anomalib.metrics import F1AdaptiveThreshold, MinMax from .base import PostProcessor @@ -67,10 +66,8 @@ def __init__( # initialize threshold and normalization metrics self._image_threshold = F1AdaptiveThreshold(fields=["pred_score", "gt_label"], strict=False) self._pixel_threshold = F1AdaptiveThreshold(fields=["anomaly_map", "gt_mask"], strict=False) - self._image_min = Min(fields=["pred_score"], strict=False) - self._image_max = Max(fields=["pred_score"], strict=False) - self._pixel_min = Min(fields=["anomaly_map"], strict=False) - self._pixel_max = Max(fields=["anomaly_map"], strict=False) + self._image_min_max = MinMax(fields=["pred_score"], strict=False) + self._pixel_min_max = MinMax(fields=["anomaly_map"], strict=False) # register buffers to persist threshold and normalization values self.register_buffer("image_threshold", torch.tensor(0)) @@ -107,10 +104,8 @@ def on_validation_batch_end( del trainer, pl_module, args, kwargs # Unused arguments. self._image_threshold.update(outputs) self._pixel_threshold.update(outputs) - self._image_min.update(outputs) - self._image_max.update(outputs) - self._pixel_min.update(outputs) - self._pixel_max.update(outputs) + self._image_min_max.update(outputs) + self._pixel_min_max.update(outputs) def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Compute final threshold and normalization values. @@ -122,10 +117,8 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) del trainer, pl_module self.image_threshold = self._image_threshold.compute() self.pixel_threshold = self._pixel_threshold.compute() - self.image_min = self._image_min.compute() - self.image_max = self._image_max.compute() - self.pixel_min = self._pixel_min.compute() - self.pixel_max = self._pixel_max.compute() + image_min_max = self._image_min_max.compute() + self.image_min, self.image_max = image_min_max if image_min_max is not None else (None, None) def on_test_batch_end( self, From e8e5be34b60e4db68ae3b6d50790508ba919b1b2 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 14 Jan 2025 16:56:06 +0100 Subject: [PATCH 09/28] remove individual min and max metrics --- src/anomalib/metrics/min_max.py | 132 -------------------------------- 1 file changed, 132 deletions(-) diff --git a/src/anomalib/metrics/min_max.py b/src/anomalib/metrics/min_max.py index 94f1a2dfb0..a8164d0659 100644 --- a/src/anomalib/metrics/min_max.py +++ b/src/anomalib/metrics/min_max.py @@ -99,137 +99,5 @@ def compute(self) -> tuple[torch.Tensor, torch.Tensor]: return torch.stack([self.min, self.max]) -class _Min(Metric): - """Track minimum value across batches. - - This metric maintains running minimum values across all batches - it processes. It is useful for tasks like normalization or monitoring the - range of values during training. - - Args: - full_state_update (bool, optional): Whether to update the internal state - with each new batch. Defaults to ``True``. - kwargs: Additional keyword arguments passed to the parent class. - - Attributes: - min (torch.Tensor): Running minimum value seen across all batches - - Example: - >>> from anomalib.metrics import MinMax - >>> import torch - >>> # Create metric - >>> minmax = Min() - >>> # Update with batches - >>> batch1 = torch.tensor([0.1, 0.2, 0.3]) - >>> batch2 = torch.tensor([0.2, 0.4, 0.5]) - >>> minmax.update(batch1) - >>> minmax.update(batch2) - >>> # Get final min/max values - >>> min_val, max_val = minmax.compute() - >>> min_val, max_val - (tensor(0.1000), tensor(0.5000)) - """ - - full_state_update: bool = True - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - self.add_state("min", torch.tensor(float("inf")), persistent=True, dist_reduce_fx="min") - - self.min = torch.tensor(float("inf")) - - def update(self, predictions: torch.Tensor, *args, **kwargs) -> None: - """Update running min and max values with new predictions. - - Args: - predictions (torch.Tensor): New tensor of values to include in min/max - tracking - *args: Additional positional arguments (unused) - **kwargs: Additional keyword arguments (unused) - """ - del args, kwargs # These variables are not used. - - self.min = torch.min(self.min, torch.min(predictions)) - - def compute(self) -> tuple[torch.Tensor, torch.Tensor]: - """Compute final minimum and maximum values. - - Returns: - tuple[torch.Tensor, torch.Tensor]: Tuple containing the (min, max) - values tracked across all batches - """ - return self.min - - -class _Max(Metric): - """Track maximum value across batches. - - This metric maintains running maximum values across all batches - it processes. It is useful for tasks like normalization or monitoring the - range of values during training. - - Args: - full_state_update (bool, optional): Whether to update the internal state - with each new batch. Defaults to ``True``. - kwargs: Additional keyword arguments passed to the parent class. - - Attributes: - max (torch.Tensor): Running maximum value seen across all batches - - Example: - >>> from anomalib.metrics import MinMax - >>> import torch - >>> # Create metric - >>> minmax = Min() - >>> # Update with batches - >>> batch1 = torch.tensor([0.1, 0.2, 0.3]) - >>> batch2 = torch.tensor([0.2, 0.4, 0.5]) - >>> minmax.update(batch1) - >>> minmax.update(batch2) - >>> # Get final min/max values - >>> min_val, max_val = minmax.compute() - >>> min_val, max_val - (tensor(0.1000), tensor(0.5000)) - """ - - full_state_update: bool = True - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - self.add_state("max", torch.tensor(float("-inf")), persistent=True, dist_reduce_fx="max") - - self.max = torch.tensor(float("-inf")) - - def update(self, predictions: torch.Tensor, *args, **kwargs) -> None: - """Update running min and max values with new predictions. - - Args: - predictions (torch.Tensor): New tensor of values to include in min/max - tracking - *args: Additional positional arguments (unused) - **kwargs: Additional keyword arguments (unused) - """ - del args, kwargs # These variables are not used. - - self.max = torch.max(self.max, torch.min(predictions)) - - def compute(self) -> tuple[torch.Tensor, torch.Tensor]: - """Compute final minimum and maximum values. - - Returns: - tuple[torch.Tensor, torch.Tensor]: Tuple containing the (min, max) - values tracked across all batches - """ - return self.max - - class MinMax(AnomalibMetric, _MinMax): # type: ignore[misc] """Wrapper to add AnomalibMetric functionality to MinMax metric.""" - - -class Min(AnomalibMetric, _Min): # type: ignore[misc] - """Wrapper to add AnomalibMetric functionality to Min metric.""" - - -class Max(AnomalibMetric, _Max): # type: ignore[misc] - """Wrapper to add AnomalibMetric functionality to Max metric.""" From 91a394735e2ef5707a51db458e9f3b680970ccf8 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 14 Jan 2025 16:59:27 +0100 Subject: [PATCH 10/28] fix threshold metric test --- tests/unit/metrics/test_adaptive_threshold.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/metrics/test_adaptive_threshold.py b/tests/unit/metrics/test_adaptive_threshold.py index e76a7effa2..b12a564c64 100644 --- a/tests/unit/metrics/test_adaptive_threshold.py +++ b/tests/unit/metrics/test_adaptive_threshold.py @@ -6,7 +6,7 @@ import pytest import torch -from anomalib.metrics import F1AdaptiveThreshold +from anomalib.metrics.threshold.f1_adaptive_threshold import _F1AdaptiveThreshold @pytest.mark.parametrize( @@ -18,7 +18,7 @@ ) def test_adaptive_threshold(labels: torch.Tensor, preds: torch.Tensor, target_threshold: int | float) -> None: """Test if the adaptive threshold computation returns the desired value.""" - adaptive_threshold = F1AdaptiveThreshold() + adaptive_threshold = _F1AdaptiveThreshold() adaptive_threshold.update(preds, labels) threshold_value = adaptive_threshold.compute() From 2dbc9e5e5b6a1698aa2f8baa90ac369d32192e6f Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 14 Jan 2025 17:14:13 +0100 Subject: [PATCH 11/28] remove usage of minmax metric in ai-vad --- src/anomalib/models/video/ai_vad/density.py | 32 +++++++++++---------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/anomalib/models/video/ai_vad/density.py b/src/anomalib/models/video/ai_vad/density.py index 65cef958f9..6c81bf57d3 100644 --- a/src/anomalib/models/video/ai_vad/density.py +++ b/src/anomalib/models/video/ai_vad/density.py @@ -33,7 +33,6 @@ import torch from torch import Tensor, nn -from anomalib.metrics.min_max import MinMax from anomalib.models.components.base import DynamicBufferMixin from anomalib.models.components.cluster.gmm import GaussianMixture @@ -296,10 +295,14 @@ def __init__(self, n_neighbors: int) -> None: self.n_neighbors = n_neighbors self.feature_collection: dict[str, list[torch.Tensor]] = {} self.group_index: dict[str, int] = {} - self.normalization_statistics = MinMax() self.register_buffer("memory_bank", Tensor()) - self.memory_bank: torch.Tensor = Tensor() + self.register_buffer("min", torch.tensor(torch.inf)) + self.register_buffer("max", torch.tensor(-torch.inf)) + + self.memory_bank: torch.Tensor + self.min: torch.Tensor + self.max: torch.Tensor def update(self, features: torch.Tensor, group: str | None = None) -> None: """Update the internal feature bank while keeping track of the group. @@ -428,9 +431,8 @@ def _compute_normalization_statistics(self, grouped_features: dict[str, Tensor]) """ for group, features in grouped_features.items(): distances = self.predict(features, group, normalize=False) - self.normalization_statistics.update(distances) - - self.normalization_statistics.compute() + self.min = torch.min(self.min, torch.min(distances)) + self.max = torch.max(self.min, torch.max(distances)) def _normalize(self, distances: torch.Tensor) -> torch.Tensor: """Normalize distance predictions. @@ -441,9 +443,7 @@ def _normalize(self, distances: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Normalized distances. """ - return (distances - self.normalization_statistics.min) / ( - self.normalization_statistics.max - self.normalization_statistics.min - ) + return (distances - self.min) / (self.max - self.min) class GMMEstimator(BaseDensityEstimator): @@ -474,7 +474,11 @@ def __init__(self, n_components: int = 2) -> None: self.gmm = GaussianMixture(n_components=n_components) self.memory_bank: list[torch.Tensor] | torch.Tensor = [] - self.normalization_statistics = MinMax() + self.register_buffer("min", torch.tensor(torch.inf)) + self.register_buffer("max", torch.tensor(-torch.inf)) + + self.min: torch.Tensor + self.max: torch.Tensor def update(self, features: torch.Tensor, group: str | None = None) -> None: """Update the feature bank with new features. @@ -528,8 +532,8 @@ def _compute_normalization_statistics(self) -> None: statistics used for score normalization during inference. """ training_scores = self.predict(self.memory_bank, normalize=False) - self.normalization_statistics.update(training_scores) - self.normalization_statistics.compute() + self.min = torch.min(self.min, torch.min(training_scores)) + self.max = torch.max(self.min, torch.max(training_scores)) def _normalize(self, density: torch.Tensor) -> torch.Tensor: """Normalize anomaly scores using min-max statistics. @@ -540,6 +544,4 @@ def _normalize(self, density: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Normalized anomaly scores of shape ``(N,)``. """ - return (density - self.normalization_statistics.min) / ( - self.normalization_statistics.max - self.normalization_statistics.min - ) + return (density - self.min) / (self.max - self.min) From 5ec700d2b7a51032e197cf6488c9fa6543588f2f Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 14 Jan 2025 17:15:47 +0100 Subject: [PATCH 12/28] make minmax metric non-persistent --- src/anomalib/metrics/min_max.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anomalib/metrics/min_max.py b/src/anomalib/metrics/min_max.py index a8164d0659..a2a9f8a0b8 100644 --- a/src/anomalib/metrics/min_max.py +++ b/src/anomalib/metrics/min_max.py @@ -69,8 +69,8 @@ class _MinMax(Metric): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - self.add_state("min", torch.tensor(float("inf")), persistent=True, dist_reduce_fx="min") - self.add_state("max", torch.tensor(float("-inf")), persistent=True, dist_reduce_fx="max") + self.add_state("min", torch.tensor(float("inf")), dist_reduce_fx="min") + self.add_state("max", torch.tensor(float("-inf")), dist_reduce_fx="max") self.min = torch.tensor(float("inf")) self.max = torch.tensor(float("-inf")) From fedd93246d8a03a2901ce3c513403e89e39804f6 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 14 Jan 2025 17:53:30 +0100 Subject: [PATCH 13/28] fix docstring --- src/anomalib/metrics/min_max.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anomalib/metrics/min_max.py b/src/anomalib/metrics/min_max.py index a2a9f8a0b8..b608522506 100644 --- a/src/anomalib/metrics/min_max.py +++ b/src/anomalib/metrics/min_max.py @@ -50,10 +50,10 @@ class _MinMax(Metric): max (torch.Tensor): Running maximum value seen across all batches Example: - >>> from anomalib.metrics import MinMax + >>> from anomalib.metrics.min_max import _MinMax >>> import torch >>> # Create metric - >>> minmax = MinMax() + >>> minmax = _MinMax() >>> # Update with batches >>> batch1 = torch.tensor([0.1, 0.2, 0.3]) >>> batch2 = torch.tensor([0.2, 0.4, 0.5]) From a447fd9f72df7d926518029d702cbc66a38cc2f3 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 14 Jan 2025 17:54:20 +0100 Subject: [PATCH 14/28] add create_anomalib_metric to module init --- src/anomalib/metrics/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py index e6ea10908d..cd9258c7f8 100644 --- a/src/anomalib/metrics/__init__.py +++ b/src/anomalib/metrics/__init__.py @@ -44,7 +44,7 @@ from .aupr import AUPR from .aupro import AUPRO from .auroc import AUROC -from .base import AnomalibMetric +from .base import AnomalibMetric, create_anomalib_metric from .evaluator import Evaluator from .f1_score import F1Max, F1Score from .min_max import MinMax @@ -60,6 +60,7 @@ "AnomalibMetric", "AnomalyScoreDistribution", "BinaryPrecisionRecallCurve", + "create_anomalib_metric", "Evaluator", "F1AdaptiveThreshold", "F1Max", From 4f91157fba2b0256fd42e426e30ebed080a88efa Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 14 Jan 2025 18:01:35 +0100 Subject: [PATCH 15/28] refactor update, add example to docstring --- src/anomalib/metrics/base.py | 45 ++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/src/anomalib/metrics/base.py b/src/anomalib/metrics/base.py index 978f9be178..c5179f0de4 100644 --- a/src/anomalib/metrics/base.py +++ b/src/anomalib/metrics/base.py @@ -40,6 +40,26 @@ >>> from anomalib.metrics import create_anomalib_metric >>> F1Score = create_anomalib_metric(BinaryF1Score) >>> f1_score = F1Score(fields=["pred_label", "gt_label"]) + + Strict mode vs non-strict mode:: + + >>> F1Score = create_anomalib_metric(BinaryF1Score) + >>> + >>> # create metric in strict mode (default), and non-strict mode + >>> f1_score_strict = F1Score(fields=["pred_label", "gt_label"], strict=True) + >>> f1_score_nonstrict = F1Score(fields=["pred_label", "gt_label"], strict=False) + >>> + >>> # create a batch in which 'pred_label' field is None + >>> batch = ImageBatch( + ... image=torch.rand(4, 3, 256, 256), + ... gt_label=torch.tensor([0, 0, 1, 1]) + ... ) + >>> + >>> f1_score_strict.update(batch) # ValueError + >>> f1_score_strict.compute() # UserWarning, tensor(0.) + >>> + >>> f1_score_nonstrict.update(batch) # No error + >>> f1_score_nonstrict.compute() # None """ # Copyright (C) 2024 Intel Corporation @@ -137,17 +157,22 @@ def update(self, batch: Batch, *args, **kwargs) -> None: """ self.__update_count += 1 for key in self.fields: - if not hasattr(batch, key): - msg = f"Batch object is missing required field: {key}" - raise ValueError(msg) if getattr(batch, key, None) is None: - if self.strict: - msg = f"Cannot update metric of type {type(self)}. Field {key} in batch object is None" - raise ValueError(msg) - # we need to decrement the update count of the super class - # if we are not actually updating the metric states. + # We cannot update the metric if the batch is missing required fields, + # so we need to decrement the update count of the super class. self._update_count -= 1 # type: ignore[attr-defined] - return + if not self.strict: + # If not in strict mode, skip updating the metric but don't raise an error + return + # otherwise, raise an error + if not hasattr(batch, key): + msg = f"Cannot update metric of type {type(self)}. Batch object \ + is missing required field: {key}" + else: + msg = f"Cannot update metric of type {type(self)}. Passed item \ + does not have a value for field with name {key}." + raise ValueError(msg) + values = [getattr(batch, key) for key in self.fields] super().update(*values, *args, **kwargs) # type: ignore[misc] @@ -166,7 +191,7 @@ def compute(self) -> torch.Tensor: @property def update_called(self) -> bool: """Check if the update method has been called.""" - return self.__update_count > 0 + return self._update_count > 0 if self.strict else self.__update_count > 0 # type: ignore[attr-defined] def create_anomalib_metric(metric_cls: type) -> type: From af94c9c6bfb46c142b35ba3899ce969a1193691d Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 14 Jan 2025 18:29:25 +0100 Subject: [PATCH 16/28] update error messages --- src/anomalib/metrics/base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/anomalib/metrics/base.py b/src/anomalib/metrics/base.py index c5179f0de4..aa5bb7ea01 100644 --- a/src/anomalib/metrics/base.py +++ b/src/anomalib/metrics/base.py @@ -166,11 +166,15 @@ def update(self, batch: Batch, *args, **kwargs) -> None: return # otherwise, raise an error if not hasattr(batch, key): - msg = f"Cannot update metric of type {type(self)}. Batch object \ - is missing required field: {key}" + msg = ( + f"Cannot update metric of type {type(self)}. Passed dataclass instance " + f"is missing required field: {key}" + ) else: - msg = f"Cannot update metric of type {type(self)}. Passed item \ - does not have a value for field with name {key}." + msg = ( + f"Cannot update metric of type {type(self)}. Passed dataclass instance " + f"does not have a value for field with name {key}." + ) raise ValueError(msg) values = [getattr(batch, key) for key in self.fields] From 61c588b6e852f88fbc1f0cb381e04f32bc350256 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 14 Jan 2025 18:39:14 +0100 Subject: [PATCH 17/28] add unit tests for AnomalibMetric class --- tests/unit/metrics/test_anomalib_metric.py | 94 ++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 tests/unit/metrics/test_anomalib_metric.py diff --git a/tests/unit/metrics/test_anomalib_metric.py b/tests/unit/metrics/test_anomalib_metric.py new file mode 100644 index 0000000000..6043a205c4 --- /dev/null +++ b/tests/unit/metrics/test_anomalib_metric.py @@ -0,0 +1,94 @@ +"""Tests for the AnomalibMetric base class.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from torchmetrics import Metric + +from anomalib.data import ImageBatch +from anomalib.metrics import AnomalibMetric, create_anomalib_metric + + +class DummyMetric(Metric): + """Dummy metric that does nothing.""" + + def update(self, *args, **kwargs) -> None: + """Dummy update method.""" + + def compute(self) -> None: + """Dummy compute method.""" + + +class TestMetricCreation: + """Test the creation of Anomalib metrics.""" + + @staticmethod + def test_create_anomalib_metric_function() -> None: + """Test if defining a metric using the function works.""" + metric_cls = create_anomalib_metric(DummyMetric) + assert issubclass(metric_cls, AnomalibMetric) + assert issubclass(metric_cls, Metric) + + @staticmethod + def test_create_anomalib_metric_subclass() -> None: + """Test if defining a metric using a subclass works.""" + + class AnomalibDummyMetric(AnomalibMetric, DummyMetric): + pass + + assert issubclass(AnomalibDummyMetric, AnomalibMetric) + assert issubclass(AnomalibDummyMetric, Metric) + + +class TestStrictMode: + """Test the strict mode of Anomalib metrics.""" + + @staticmethod + def test_raises_error_on_missing_fields() -> None: + """Test that an error is raised when required fields are missing in strict mode.""" + metric_cls = create_anomalib_metric(DummyMetric) + + metric = metric_cls(fields=["non_existent_field"]) + batch = ImageBatch(image=torch.rand(4, 3, 10, 10)) # batch without field + with pytest.raises(ValueError, match="instance is missing required field"): + metric.update(batch) + assert metric._update_count == 0 # noqa: SLF001 + assert metric.update_called is False + + @staticmethod + def test_raises_error_when_field_is_none() -> None: + """Test that an error is raised when a required field is None in strict mode.""" + metric_cls = create_anomalib_metric(DummyMetric) + + metric = metric_cls(fields=["pred_score"]) + batch = ImageBatch(image=torch.rand(4, 3, 10, 10), pred_score=None) # batch where field is None + with pytest.raises(ValueError, match="instance does not have a value for field with name"): + metric.update(batch) + assert metric._update_count == 0 # noqa: SLF001 + assert metric.update_called is False + + @staticmethod + def test_no_error_on_missing_fields() -> None: + """Test that no error is raised when required fields are missing in non-strict mode.""" + metric_cls = create_anomalib_metric(DummyMetric) + + metric = metric_cls(fields=["pred_score"], strict=False) + batch = ImageBatch(image=torch.rand(4, 3, 10, 10)) # batch without pred_score field + metric.update(batch) + assert metric.compute() is None + assert metric._update_count == 0 # noqa: SLF001 + assert metric.update_called is True + + @staticmethod + def test_no_error_when_field_is_none() -> None: + """Test that no error is raised when a required field is None in non-strict mode.""" + metric_cls = create_anomalib_metric(DummyMetric) + + metric = metric_cls(fields=["pred_score"], strict=False) + batch = ImageBatch(image=torch.rand(4, 3, 10, 10), pred_score=None) + metric.update(batch) + assert metric.compute() is None + assert metric._update_count == 0 # noqa: SLF001 + assert metric.update_called is True From 1e9f647d358709d2259be719bf65772fbaa85ef3 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 14 Jan 2025 19:07:50 +0100 Subject: [PATCH 18/28] typing --- src/anomalib/post_processing/one_class.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/anomalib/post_processing/one_class.py b/src/anomalib/post_processing/one_class.py index e9098aa37a..124581169e 100644 --- a/src/anomalib/post_processing/one_class.py +++ b/src/anomalib/post_processing/one_class.py @@ -230,7 +230,10 @@ def normalize_batch(self, batch: Batch) -> None: batch.pred_score = self._normalize(batch.pred_score, self.image_min, self.image_max, self.image_threshold) @staticmethod - def _apply_threshold(preds: torch.Tensor | None, threshold: float) -> torch.Tensor | None: + def _apply_threshold( + preds: torch.Tensor | None, + threshold: torch.Tensor | None, + ) -> torch.Tensor | None: """Apply thresholding to a single tensor. Args: @@ -247,9 +250,9 @@ def _apply_threshold(preds: torch.Tensor | None, threshold: float) -> torch.Tens @staticmethod def _normalize( preds: torch.Tensor | None, - norm_min: float | None, - norm_max: float | None, - threshold: float | None, + norm_min: torch.Tensor | None, + norm_max: torch.Tensor | None, + threshold: torch.Tensor | None, ) -> torch.Tensor | None: """Normalize a tensor using min, max, and threshold values. @@ -275,8 +278,8 @@ def normalized_image_threshold(self) -> float: float: Normalized image-level threshold value, adjusted by sensitivity. """ if self.image_sensitivity is not None: - return 1 - self.image_sensitivity - return 0.5 + return torch.tensor(1.0) - self.image_sensitivity + return torch.tensor(0.5) @property def normalized_pixel_threshold(self) -> float: @@ -286,5 +289,5 @@ def normalized_pixel_threshold(self) -> float: float: Normalized pixel-level threshold value, adjusted by sensitivity. """ if self.pixel_sensitivity is not None: - return 1 - self.pixel_sensitivity - return 0.5 + return torch.tensor(1.0) - self.pixel_sensitivity + return torch.tensor(0.5) From faec7eb3ab8c077df5edc8d54cbd0079cb9257b2 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Mon, 20 Jan 2025 13:33:36 +0100 Subject: [PATCH 19/28] improve update_called logic --- src/anomalib/metrics/base.py | 5 ----- src/anomalib/post_processing/one_class.py | 12 ++++++++---- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/anomalib/metrics/base.py b/src/anomalib/metrics/base.py index aa5bb7ea01..b843cf6857 100644 --- a/src/anomalib/metrics/base.py +++ b/src/anomalib/metrics/base.py @@ -192,11 +192,6 @@ def compute(self) -> torch.Tensor: return None return super().compute() # type: ignore[misc] - @property - def update_called(self) -> bool: - """Check if the update method has been called.""" - return self._update_count > 0 if self.strict else self.__update_count > 0 # type: ignore[attr-defined] - def create_anomalib_metric(metric_cls: type) -> type: """Create an Anomalib version of a torchmetrics metric. diff --git a/src/anomalib/post_processing/one_class.py b/src/anomalib/post_processing/one_class.py index 124581169e..ffa906a176 100644 --- a/src/anomalib/post_processing/one_class.py +++ b/src/anomalib/post_processing/one_class.py @@ -115,10 +115,14 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) pl_module (LightningModule): PyTorch Lightning module instance. """ del trainer, pl_module - self.image_threshold = self._image_threshold.compute() - self.pixel_threshold = self._pixel_threshold.compute() - image_min_max = self._image_min_max.compute() - self.image_min, self.image_max = image_min_max if image_min_max is not None else (None, None) + if self._image_threshold.update_called: + self.image_threshold = self._image_threshold.compute() + if self._pixel_threshold.update_called: + self.pixel_threshold = self._pixel_threshold.compute() + if self._image_min_max.update_called: + self.image_min, self.image_max = self._image_min_max.compute() + if self._pixel_min_max.update_called: + self.pixel_min, self.pixel_max = self._pixel_min_max.compute() def on_test_batch_end( self, From 640f05a64faa5aad8333b9048dd39f3fc555a093 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Mon, 20 Jan 2025 13:34:28 +0100 Subject: [PATCH 20/28] set pixel metrics as non-strict --- src/anomalib/models/components/base/anomalib_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anomalib/models/components/base/anomalib_module.py b/src/anomalib/models/components/base/anomalib_module.py index 3b2bd75a62..862ce1df1e 100644 --- a/src/anomalib/models/components/base/anomalib_module.py +++ b/src/anomalib/models/components/base/anomalib_module.py @@ -374,8 +374,8 @@ def configure_evaluator() -> Evaluator: """ image_auroc = AUROC(fields=["pred_score", "gt_label"], prefix="image_") image_f1score = F1Score(fields=["pred_label", "gt_label"], prefix="image_") - pixel_auroc = AUROC(fields=["anomaly_map", "gt_mask"], prefix="pixel_") - pixel_f1score = F1Score(fields=["pred_mask", "gt_mask"], prefix="pixel_") + pixel_auroc = AUROC(fields=["anomaly_map", "gt_mask"], prefix="pixel_", strict=False) + pixel_f1score = F1Score(fields=["pred_mask", "gt_mask"], prefix="pixel_", strict=False) test_metrics = [image_auroc, image_f1score, pixel_auroc, pixel_f1score] return Evaluator(test_metrics=test_metrics) From 0f2eb0e381eafddbfceec47639360a78748e3e1e Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Mon, 20 Jan 2025 14:07:09 +0100 Subject: [PATCH 21/28] add task type tests --- tests/helpers/data.py | 19 ++ tests/integration/test_task_types.py | 278 +++++++++++++++++++++++++++ 2 files changed, 297 insertions(+) create mode 100644 tests/integration/test_task_types.py diff --git a/tests/helpers/data.py b/tests/helpers/data.py index e1efccc1b1..541a787aea 100644 --- a/tests/helpers/data.py +++ b/tests/helpers/data.py @@ -392,6 +392,25 @@ def _generate_dummy_mvtec_dataset( mask_filename = mask_path / f"{i:03}{mask_suffix}{mask_extension}" self.image_generator.generate_image(label, image_filename, mask_filename) + def _generate_dummy_folder_dataset(self) -> None: + """Generate dummy folder dataset in a temporary directory.""" + # folder names + normal_dir = self.root / self.normal_category + abnormal_dir = self.root / self.abnormal_category + mask_dir = self.root / "masks" + + # generate images + for i in range(self.num_train): + label = LabelName.NORMAL + image_filename = normal_dir / f"{self.normal_category}_{i:03}.png" + self.image_generator.generate_image(label, image_filename) + + for i in range(self.num_test): + label = LabelName.ABNORMAL + image_filename = abnormal_dir / f"{self.abnormal_category}_{i:03}.png" + mask_filename = mask_dir / image_filename.name + self.image_generator.generate_image(label, image_filename, mask_filename) + def _generate_dummy_btech_dataset(self) -> None: """Generate dummy BeanTech dataset in directory using the same convention as BeanTech AD.""" # BeanTech AD follows the same convention as MVTec AD. diff --git a/tests/integration/test_task_types.py b/tests/integration/test_task_types.py new file mode 100644 index 0000000000..c0ba23113e --- /dev/null +++ b/tests/integration/test_task_types.py @@ -0,0 +1,278 @@ +"""Tests to check behaviour of the auxiliary components across different task types (classification, segmentation) .""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import copy +from pathlib import Path +from typing import Any + +import pytest +import torch +from torchmetrics import Metric + +from anomalib import LearningType +from anomalib.data import AnomalibDataModule, Batch, Folder, ImageDataFormat +from anomalib.engine import Engine +from anomalib.metrics import AnomalibMetric, Evaluator +from anomalib.models import AnomalibModule +from anomalib.post_processing import OneClassPostProcessor +from anomalib.visualization import ImageVisualizer +from tests.helpers.data import DummyImageDatasetGenerator + + +class DummyBaseModel(AnomalibModule): + """Dummy model for testing. + + No training, and all auxiliary components default to None. This allows testing of the different components + in isolation. + """ + + def training_step(self, *args, **kwargs) -> None: + """Dummy training step.""" + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Run for single epoch.""" + return {"max_epochs": 1} + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model.""" + return LearningType.ONE_CLASS + + def configure_optimizers(self) -> None: + """No optimizers needed.""" + + def configure_preprocessor(self) -> None: + """No default pre-processor needed.""" + + def configure_post_processor(self) -> None: + """No default post-processor needed.""" + + def configure_evaluator(self) -> None: + """No default evaluator needed.""" + + def configure_visualizer(self) -> None: + """No default visualizer needed.""" + + +class DummyClassificationModel(DummyBaseModel): + """Dummy classification model for testing. + + Validation step returns random image-only scores, to simulate a model that performs classification. + """ + + def validation_step(self, batch: Batch, *args, **kwargs) -> Batch: + """Validation steps that returns random image-level scores.""" + del args, kwargs + batch.pred_score = torch.rand(batch.batch_size, device=self.device) + return batch + + +class DummySegmentationModel(DummyBaseModel): + """Dummy segmentation model for testing. + + Validation step returns random image- and pixel-level scores, to simulate a model that performs segmentation. + """ + + def validation_step(self, batch: Batch, *args, **kwargs) -> Batch: + """Validation steps that returns random image- and pixel-level scores.""" + del args, kwargs + batch.pred_score = torch.rand(batch.batch_size, device=self.device) + batch.anomaly_map = torch.rand(batch.batch_size, *batch.image.shape[-2:], device=self.device) + return batch + + +class _DummyMetric(Metric): + """Dummy metric for testing.""" + + def update(self, *args, **kwargs) -> None: + """Dummy update method.""" + + def compute(self) -> None: + """Dummy compute method.""" + assert self.update_called # simulate failure to compute if states are not updated + + +class DummyMetric(AnomalibMetric, _DummyMetric): + """Dummy Anomalib metric for testing.""" + + +@pytest.fixture +def folder_dataset_path(project_path: Path) -> Path: + """Create a dummy folder dataset for testing.""" + data_path = project_path / "dataset" + dataset_generator = DummyImageDatasetGenerator( + data_format=ImageDataFormat.FOLDER, + root=data_path, + num_train=10, + num_test=10, + ) + dataset_generator.generate_dataset() + return data_path + + +@pytest.fixture +def classification_datamodule(folder_dataset_path: Path) -> AnomalibDataModule: + """Create a classification datamodule for testing. + + The datamodule is created with a folder dataset, that does not have a mask directory. + """ + # create the folder datamodule + return Folder( + name="cls_dataset", + root=folder_dataset_path, + normal_dir="good", + abnormal_dir="bad", + train_batch_size=1, + eval_batch_size=1, + num_workers=0, + ) + + +@pytest.fixture +def segmentation_datamodule(folder_dataset_path: Path) -> AnomalibDataModule: + """Create a segmentation datamodule for testing. + + The datamodule is created with a folder dataset, that has a mask directory. + """ + # create the folder datamodule + return Folder( + name="seg_dataset", + root=folder_dataset_path, + normal_dir="good", + abnormal_dir="bad", + mask_dir="masks", # include masks for segmentation dataset + train_batch_size=1, + eval_batch_size=1, + num_workers=0, + ) + + +@pytest.fixture +def image_and_pixel_evaluator() -> Evaluator: + """Create an evaluator with image- and pixel-level metrics for testing.""" + image_metric = DummyMetric(fields=["pred_score", "gt_label"], prefix="image_") + pixel_metric = DummyMetric(fields=["anomaly_map", "gt_mask"], prefix="pixel_", strict=False) + val_metrics = [image_metric, pixel_metric] + test_metrics = copy.deepcopy(val_metrics) + return Evaluator(val_metrics=[image_metric, pixel_metric], test_metrics=test_metrics) + + +@pytest.fixture +def engine(project_path: Path) -> Engine: + """Create an engine for testing. + + Run on cpu to speed up tests. + """ + return Engine(accelerator="cpu", default_root_dir=project_path) + + +class TestEvaluation: + """Test evaluation across task types. + + Tests if image- and/or pixel- metrics are computed without errors for models and datasets with different task types. + """ + + @staticmethod + def test_cls_model_cls_dataset( + engine: Engine, + classification_datamodule: AnomalibDataModule, + image_and_pixel_evaluator: Evaluator, + ) -> None: + """Test classification model with classification dataset.""" + model = DummyClassificationModel(evaluator=image_and_pixel_evaluator) + engine.train(model, datamodule=classification_datamodule) + + @staticmethod + def test_cls_model_seg_dataset( + engine: Engine, + segmentation_datamodule: AnomalibDataModule, + image_and_pixel_evaluator: Evaluator, + ) -> None: + """Test classification model with segmentation dataset.""" + model = DummyClassificationModel(evaluator=image_and_pixel_evaluator) + engine.train(model, datamodule=segmentation_datamodule) + + @staticmethod + def test_seg_model_cls_dataset( + engine: Engine, + classification_datamodule: AnomalibDataModule, + image_and_pixel_evaluator: Evaluator, + ) -> None: + """Test segmentation model with classification dataset.""" + model = DummySegmentationModel(evaluator=image_and_pixel_evaluator) + engine.train(model, datamodule=classification_datamodule) + + @staticmethod + def test_seg_model_seg_dataset( + engine: Engine, + segmentation_datamodule: AnomalibDataModule, + image_and_pixel_evaluator: Evaluator, + ) -> None: + """Test segmentation model with segmentation dataset.""" + model = DummySegmentationModel(evaluator=image_and_pixel_evaluator) + engine.train(model, datamodule=segmentation_datamodule) + + +class TestPostProcessing: + """Tests post-processing across task types. + + Tests if post-processing is applied without errors for models and datasets with different task types. + """ + + @staticmethod + def test_cls_model_cls_dataset(engine: Engine, classification_datamodule: AnomalibDataModule) -> None: + """Test classification model with classification dataset.""" + model = DummyClassificationModel(post_processor=OneClassPostProcessor()) + engine.train(model, datamodule=classification_datamodule) + + @staticmethod + def test_cls_model_seg_dataset(engine: Engine, segmentation_datamodule: AnomalibDataModule) -> None: + """Test classification model with segmentation dataset.""" + model = DummyClassificationModel(post_processor=OneClassPostProcessor()) + engine.train(model, datamodule=segmentation_datamodule) + + @staticmethod + def test_seg_model_cls_dataset(engine: Engine, classification_datamodule: AnomalibDataModule) -> None: + """Test segmentation model with classification dataset.""" + model = DummySegmentationModel(post_processor=OneClassPostProcessor()) + engine.train(model, datamodule=classification_datamodule) + + @staticmethod + def test_seg_model_seg_dataset(engine: Engine, segmentation_datamodule: AnomalibDataModule) -> None: + """Test segmentation model with segmentation dataset.""" + model = DummySegmentationModel(post_processor=OneClassPostProcessor()) + engine.train(model, datamodule=segmentation_datamodule) + + +class TestVisualization: + """Tests visualization across task types. + + Tests if visualizations are created without errors for models and datasets with different task types. + """ + + @staticmethod + def test_cls_model_cls_dataset(engine: Engine, classification_datamodule: AnomalibDataModule) -> None: + """Test classification model with classification dataset.""" + model = DummyClassificationModel(visualizer=ImageVisualizer()) + engine.train(model, datamodule=classification_datamodule) + + @staticmethod + def test_cls_model_seg_dataset(engine: Engine, segmentation_datamodule: AnomalibDataModule) -> None: + """Test classification model with segmentation dataset.""" + model = DummyClassificationModel(visualizer=ImageVisualizer()) + engine.train(model, datamodule=segmentation_datamodule) + + @staticmethod + def test_seg_model_cls_dataset(engine: Engine, classification_datamodule: AnomalibDataModule) -> None: + """Test segmentation model with classification dataset.""" + model = DummySegmentationModel(visualizer=ImageVisualizer()) + engine.train(model, datamodule=classification_datamodule) + + @staticmethod + def test_seg_model_seg_dataset(engine: Engine, segmentation_datamodule: AnomalibDataModule) -> None: + """Test segmentation model with segmentation dataset.""" + model = DummySegmentationModel(visualizer=ImageVisualizer()) + engine.train(model, datamodule=segmentation_datamodule) From af984ea74a5bc43ac3283491e937ba3dbb7f8ac0 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Mon, 20 Jan 2025 14:13:03 +0100 Subject: [PATCH 22/28] update metric unit tests --- tests/unit/metrics/test_anomalib_metric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/metrics/test_anomalib_metric.py b/tests/unit/metrics/test_anomalib_metric.py index 6043a205c4..d198bae15d 100644 --- a/tests/unit/metrics/test_anomalib_metric.py +++ b/tests/unit/metrics/test_anomalib_metric.py @@ -79,7 +79,7 @@ def test_no_error_on_missing_fields() -> None: metric.update(batch) assert metric.compute() is None assert metric._update_count == 0 # noqa: SLF001 - assert metric.update_called is True + assert metric.update_called is False @staticmethod def test_no_error_when_field_is_none() -> None: @@ -91,4 +91,4 @@ def test_no_error_when_field_is_none() -> None: metric.update(batch) assert metric.compute() is None assert metric._update_count == 0 # noqa: SLF001 - assert metric.update_called is True + assert metric.update_called is False From 166125073e2f701b6e539d758509575c07f96f77 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Mon, 20 Jan 2025 15:37:39 +0100 Subject: [PATCH 23/28] update minmax tests after merging main --- tests/unit/metrics/test_min_max.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unit/metrics/test_min_max.py b/tests/unit/metrics/test_min_max.py index 98284e4dc1..847cf5e892 100644 --- a/tests/unit/metrics/test_min_max.py +++ b/tests/unit/metrics/test_min_max.py @@ -5,12 +5,12 @@ import torch -from anomalib.metrics import MinMax # Assuming the metric is part of `anomalib` +from anomalib.metrics import _MinMax # Assuming the metric is part of `anomalib` def test_initialization() -> None: """Test if the metric initializes with correct default values.""" - metric = MinMax() + metric = _MinMax() assert torch.isinf(metric.min), "Initial min should be positive infinity." assert metric.min > 0, "Initial min should be positive infinity." assert torch.isinf(metric.max), "Initial max should be negative infinity." @@ -19,7 +19,7 @@ def test_initialization() -> None: def test_update_single_batch() -> None: """Test updating the metric with a single batch.""" - metric = MinMax() + metric = _MinMax() batch = torch.tensor([1.0, 2.0, 3.0, -1.0]) metric.update(batch) @@ -29,7 +29,7 @@ def test_update_single_batch() -> None: def test_update_multiple_batches() -> None: """Test updating the metric with multiple batches.""" - metric = MinMax() + metric = _MinMax() batch1 = torch.tensor([0.5, 1.5, 3.0]) batch2 = torch.tensor([-0.5, 0.0, 2.5]) @@ -42,7 +42,7 @@ def test_update_multiple_batches() -> None: def test_compute() -> None: """Test computation of the min and max values after updates.""" - metric = MinMax() + metric = _MinMax() batch1 = torch.tensor([1.0, 2.0]) batch2 = torch.tensor([-1.0, 0.0]) @@ -57,7 +57,7 @@ def test_compute() -> None: def test_no_updates() -> None: """Test behavior when no updates are made to the metric.""" - metric = MinMax() + metric = _MinMax() min_val, max_val = metric.compute() From 06120f8f4fa29febff399bc09adcfc305950645a Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Mon, 20 Jan 2025 15:39:42 +0100 Subject: [PATCH 24/28] license header --- src/anomalib/metrics/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/metrics/base.py b/src/anomalib/metrics/base.py index b843cf6857..32edef1b9f 100644 --- a/src/anomalib/metrics/base.py +++ b/src/anomalib/metrics/base.py @@ -62,7 +62,7 @@ >>> f1_score_nonstrict.compute() # None """ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from collections.abc import Sequence From e8666b2e0f6d52f44e1f4596f6071dbffb6a83f3 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Mon, 20 Jan 2025 15:42:49 +0100 Subject: [PATCH 25/28] remove outer update count --- src/anomalib/metrics/base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/anomalib/metrics/base.py b/src/anomalib/metrics/base.py index 32edef1b9f..19feb95b49 100644 --- a/src/anomalib/metrics/base.py +++ b/src/anomalib/metrics/base.py @@ -133,7 +133,6 @@ def __init__( self.fields = fields self.name = prefix + self.__class__.__name__ self.strict = strict - self.__update_count = 0 # keeps track of the update calls of the wrapper class super().__init__(**kwargs) def __init_subclass__(cls, **kwargs) -> None: @@ -155,7 +154,6 @@ def update(self, batch: Batch, *args, **kwargs) -> None: Raises: ValueError: If batch is missing any required fields. """ - self.__update_count += 1 for key in self.fields: if getattr(batch, key, None) is None: # We cannot update the metric if the batch is missing required fields, From fe3c117e9c7784806cbbabfd30d680ab986e6e73 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Mon, 20 Jan 2025 15:56:25 +0100 Subject: [PATCH 26/28] typing --- src/anomalib/metrics/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/metrics/base.py b/src/anomalib/metrics/base.py index 19feb95b49..d603257321 100644 --- a/src/anomalib/metrics/base.py +++ b/src/anomalib/metrics/base.py @@ -178,7 +178,7 @@ def update(self, batch: Batch, *args, **kwargs) -> None: values = [getattr(batch, key) for key in self.fields] super().update(*values, *args, **kwargs) # type: ignore[misc] - def compute(self) -> torch.Tensor: + def compute(self) -> torch.Tensor | None: """Compute the metric value. If the metric has not been updated, and metric is not in strict mode, return None. From f3076387baba25237c92544554541351baa13245 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Mon, 20 Jan 2025 16:06:42 +0100 Subject: [PATCH 27/28] Revert "typing" This reverts commit fe3c117e9c7784806cbbabfd30d680ab986e6e73. --- src/anomalib/metrics/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/metrics/base.py b/src/anomalib/metrics/base.py index d603257321..19feb95b49 100644 --- a/src/anomalib/metrics/base.py +++ b/src/anomalib/metrics/base.py @@ -178,7 +178,7 @@ def update(self, batch: Batch, *args, **kwargs) -> None: values = [getattr(batch, key) for key in self.fields] super().update(*values, *args, **kwargs) # type: ignore[misc] - def compute(self) -> torch.Tensor | None: + def compute(self) -> torch.Tensor: """Compute the metric value. If the metric has not been updated, and metric is not in strict mode, return None. From aced62c5d538712f99259c0326e3b521b2c7e524 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Mon, 20 Jan 2025 16:16:48 +0100 Subject: [PATCH 28/28] fix import --- tests/unit/metrics/test_min_max.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/metrics/test_min_max.py b/tests/unit/metrics/test_min_max.py index 847cf5e892..62e8ff1639 100644 --- a/tests/unit/metrics/test_min_max.py +++ b/tests/unit/metrics/test_min_max.py @@ -5,7 +5,7 @@ import torch -from anomalib.metrics import _MinMax # Assuming the metric is part of `anomalib` +from anomalib.metrics.min_max import _MinMax # Assuming the metric is part of `anomalib` def test_initialization() -> None: