diff --git a/src/anomalib/metrics/min_max.py b/src/anomalib/metrics/min_max.py index c071597bf5..ce04484603 100644 --- a/src/anomalib/metrics/min_max.py +++ b/src/anomalib/metrics/min_max.py @@ -67,8 +67,8 @@ class MinMax(Metric): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - self.add_state("min", torch.tensor(float("inf")), persistent=True) - self.add_state("max", torch.tensor(float("-inf")), persistent=True) + self.add_state("min", torch.tensor(float("inf")), persistent=True, dist_reduce_fx="min") + self.add_state("max", torch.tensor(float("-inf")), persistent=True, dist_reduce_fx="max") self.min = torch.tensor(float("inf")) self.max = torch.tensor(float("-inf")) @@ -84,8 +84,8 @@ def update(self, predictions: torch.Tensor, *args, **kwargs) -> None: """ del args, kwargs # These variables are not used. - self.max = torch.max(self.max, torch.max(predictions)) self.min = torch.min(self.min, torch.min(predictions)) + self.max = torch.max(self.min, torch.max(predictions)) def compute(self) -> tuple[torch.Tensor, torch.Tensor]: """Compute final minimum and maximum values. diff --git a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py index 1d1461ddaf..80e6445d9f 100644 --- a/src/anomalib/metrics/threshold/f1_adaptive_threshold.py +++ b/src/anomalib/metrics/threshold/f1_adaptive_threshold.py @@ -46,14 +46,6 @@ class F1AdaptiveThreshold(BinaryPrecisionRecallCurve, Threshold): This class computes and stores the optimal threshold for converting anomaly scores to binary predictions by maximizing the F1 score on validation data. - Args: - default_value: Initial threshold value used before computation. - Defaults to ``0.5``. - **kwargs: Additional arguments passed to parent classes. - - Attributes: - value (torch.Tensor): Current threshold value. - Example: >>> from anomalib.metrics import F1AdaptiveThreshold >>> import torch @@ -68,12 +60,6 @@ class F1AdaptiveThreshold(BinaryPrecisionRecallCurve, Threshold): Optimal threshold: 0.5000 """ - def __init__(self, default_value: float = 0.5, **kwargs) -> None: - super().__init__(**kwargs) - - self.add_state("value", default=torch.tensor(default_value), persistent=True) - self.value = torch.tensor(default_value) - def compute(self) -> torch.Tensor: """Compute optimal threshold by maximizing F1 score. @@ -105,18 +91,7 @@ def compute(self) -> torch.Tensor: precision, recall, thresholds = super().compute() f1_score = (2 * precision * recall) / (precision + recall + 1e-10) - if thresholds.dim() == 0: - # special case where recall is 1.0 even for the highest threshold. - # In this case 'thresholds' will be scalar. - self.value = thresholds - else: - self.value = thresholds[torch.argmax(f1_score)] - return self.value - def __repr__(self) -> str: - """Return string representation including current threshold value. - - Returns: - str: String in format "ClassName(value=X.XX)" - """ - return f"{super().__repr__()} (value={self.value:.2f})" + # account for special case where recall is 1.0 even for the highest threshold. + # In this case 'thresholds' will be scalar. + return thresholds if thresholds.dim() == 0 else thresholds[torch.argmax(f1_score)] diff --git a/src/anomalib/post_processing/one_class.py b/src/anomalib/post_processing/one_class.py index ca89ba4df5..b7e135f1e8 100644 --- a/src/anomalib/post_processing/one_class.py +++ b/src/anomalib/post_processing/one_class.py @@ -59,13 +59,30 @@ def __init__( ) -> None: super().__init__(**kwargs) + # configure sensitivity values + self.image_sensitivity = image_sensitivity + self.pixel_sensitivity = pixel_sensitivity + + # initialize threshold and normalization metrics self._image_threshold = F1AdaptiveThreshold() self._pixel_threshold = F1AdaptiveThreshold() self._image_normalization_stats = MinMax() self._pixel_normalization_stats = MinMax() - self.image_sensitivity = image_sensitivity - self.pixel_sensitivity = pixel_sensitivity + # register buffers to persist threshold and normalization values + self.register_buffer("image_threshold", torch.tensor(0)) + self.register_buffer("pixel_threshold", torch.tensor(0)) + self.register_buffer("image_min", torch.tensor(0)) + self.register_buffer("image_max", torch.tensor(1)) + self.register_buffer("pixel_min", torch.tensor(0)) + self.register_buffer("pixel_max", torch.tensor(1)) + + self.image_threshold: torch.Tensor + self.pixel_threshold: torch.Tensor + self.image_min: torch.Tensor + self.image_max: torch.Tensor + self.pixel_min: torch.Tensor + self.pixel_max: torch.Tensor def on_validation_batch_end( self, @@ -103,13 +120,13 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) """ del trainer, pl_module if self._image_threshold.update_called: - self._image_threshold.compute() + self.image_threshold = self._image_threshold.compute() if self._pixel_threshold.update_called: - self._pixel_threshold.compute() + self.pixel_threshold = self._pixel_threshold.compute() if self._image_normalization_stats.update_called: - self._image_normalization_stats.compute() + self.image_min, self.image_max = self._image_normalization_stats.compute() if self._pixel_normalization_stats.update_called: - self._pixel_normalization_stats.compute() + self.pixel_min, self.pixel_max = self._pixel_normalization_stats.compute() def on_test_batch_end( self, @@ -168,8 +185,8 @@ def forward(self, predictions: InferenceBatch) -> InferenceBatch: msg = "At least one of pred_score or anomaly_map must be provided." raise ValueError(msg) pred_score = predictions.pred_score or torch.amax(predictions.anomaly_map, dim=(-2, -1)) - pred_score = self._normalize(pred_score, self.image_min, self.image_max, self.raw_image_threshold) - anomaly_map = self._normalize(predictions.anomaly_map, self.pixel_min, self.pixel_max, self.raw_pixel_threshold) + pred_score = self._normalize(pred_score, self.image_min, self.image_max, self.image_threshold) + anomaly_map = self._normalize(predictions.anomaly_map, self.pixel_min, self.pixel_max, self.pixel_threshold) pred_label = self._threshold(pred_score, self.normalized_image_threshold) pred_mask = self._threshold(anomaly_map, self.normalized_pixel_threshold) return InferenceBatch( @@ -216,9 +233,9 @@ def normalize_batch(self, batch: Batch) -> None: batch (Batch): Batch containing model predictions. """ # normalize pixel-level predictions - batch.anomaly_map = self._normalize(batch.anomaly_map, self.pixel_min, self.pixel_max, self.raw_pixel_threshold) + batch.anomaly_map = self._normalize(batch.anomaly_map, self.pixel_min, self.pixel_max, self.pixel_threshold) # normalize image-level predictions - batch.pred_score = self._normalize(batch.pred_score, self.image_min, self.image_max, self.raw_image_threshold) + batch.pred_score = self._normalize(batch.pred_score, self.image_min, self.image_max, self.image_threshold) @staticmethod def _threshold(preds: torch.Tensor | None, threshold: float) -> torch.Tensor | None: @@ -256,26 +273,7 @@ def _normalize( if preds is None: return None preds = ((preds - threshold) / (norm_max - norm_min)) + 0.5 - preds = torch.minimum(preds, torch.tensor(1)) - return torch.maximum(preds, torch.tensor(0)) - - @property - def raw_image_threshold(self) -> float: - """Get the raw image-level threshold. - - Returns: - float: Raw image-level threshold value. - """ - return self._image_threshold.value - - @property - def raw_pixel_threshold(self) -> float: - """Get the raw pixel-level threshold. - - Returns: - float: Raw pixel-level threshold value. - """ - return self._pixel_threshold.value + return preds.clamp(min=0, max=1) @property def normalized_image_threshold(self) -> float: @@ -298,39 +296,3 @@ def normalized_pixel_threshold(self) -> float: if self.pixel_sensitivity is not None: return 1 - self.pixel_sensitivity return 0.5 - - @property - def image_min(self) -> float: - """Get the minimum value for image-level normalization. - - Returns: - float: Minimum image-level value. - """ - return self._image_normalization_stats.min - - @property - def image_max(self) -> float: - """Get the maximum value for image-level normalization. - - Returns: - float: Maximum image-level value. - """ - return self._image_normalization_stats.max - - @property - def pixel_min(self) -> float: - """Get the minimum value for pixel-level normalization. - - Returns: - float: Minimum pixel-level value. - """ - return self._pixel_normalization_stats.min - - @property - def pixel_max(self) -> float: - """Get the maximum value for pixel-level normalization. - - Returns: - float: Maximum pixel-level value. - """ - return self._pixel_normalization_stats.max diff --git a/tests/unit/metrics/test_adaptive_threshold.py b/tests/unit/metrics/test_adaptive_threshold.py index 7f7dbf1796..e76a7effa2 100644 --- a/tests/unit/metrics/test_adaptive_threshold.py +++ b/tests/unit/metrics/test_adaptive_threshold.py @@ -18,7 +18,7 @@ ) def test_adaptive_threshold(labels: torch.Tensor, preds: torch.Tensor, target_threshold: int | float) -> None: """Test if the adaptive threshold computation returns the desired value.""" - adaptive_threshold = F1AdaptiveThreshold(default_value=0.5) + adaptive_threshold = F1AdaptiveThreshold() adaptive_threshold.update(preds, labels) threshold_value = adaptive_threshold.compute()