Skip to content

Commit

Permalink
Fix multi-GPU support for post-processor (openvinotoolkit#2499)
Browse files Browse the repository at this point in the history
* examples/notebooks/500_use_cases/501_dobot/501a_training_a_model_with_cubes_from_a_robotic_arm.ipynb: convert to Git LFS

* make post-processor compatible with multi-gpu

* remove obsolete repr method, fix adap thresh test

* use clamp instead of min/max

---------

Co-authored-by: Samet Akcay <[email protected]>
  • Loading branch information
djdameln and samet-akcay authored Jan 14, 2025
1 parent c587c5c commit 8d14ae1
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 98 deletions.
6 changes: 3 additions & 3 deletions src/anomalib/metrics/min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ class MinMax(Metric):

def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.add_state("min", torch.tensor(float("inf")), persistent=True)
self.add_state("max", torch.tensor(float("-inf")), persistent=True)
self.add_state("min", torch.tensor(float("inf")), persistent=True, dist_reduce_fx="min")
self.add_state("max", torch.tensor(float("-inf")), persistent=True, dist_reduce_fx="max")

self.min = torch.tensor(float("inf"))
self.max = torch.tensor(float("-inf"))
Expand All @@ -84,8 +84,8 @@ def update(self, predictions: torch.Tensor, *args, **kwargs) -> None:
"""
del args, kwargs # These variables are not used.

self.max = torch.max(self.max, torch.max(predictions))
self.min = torch.min(self.min, torch.min(predictions))
self.max = torch.max(self.min, torch.max(predictions))

def compute(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute final minimum and maximum values.
Expand Down
31 changes: 3 additions & 28 deletions src/anomalib/metrics/threshold/f1_adaptive_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,6 @@ class F1AdaptiveThreshold(BinaryPrecisionRecallCurve, Threshold):
This class computes and stores the optimal threshold for converting anomaly
scores to binary predictions by maximizing the F1 score on validation data.
Args:
default_value: Initial threshold value used before computation.
Defaults to ``0.5``.
**kwargs: Additional arguments passed to parent classes.
Attributes:
value (torch.Tensor): Current threshold value.
Example:
>>> from anomalib.metrics import F1AdaptiveThreshold
>>> import torch
Expand All @@ -68,12 +60,6 @@ class F1AdaptiveThreshold(BinaryPrecisionRecallCurve, Threshold):
Optimal threshold: 0.5000
"""

def __init__(self, default_value: float = 0.5, **kwargs) -> None:
super().__init__(**kwargs)

self.add_state("value", default=torch.tensor(default_value), persistent=True)
self.value = torch.tensor(default_value)

def compute(self) -> torch.Tensor:
"""Compute optimal threshold by maximizing F1 score.
Expand Down Expand Up @@ -105,18 +91,7 @@ def compute(self) -> torch.Tensor:

precision, recall, thresholds = super().compute()
f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
if thresholds.dim() == 0:
# special case where recall is 1.0 even for the highest threshold.
# In this case 'thresholds' will be scalar.
self.value = thresholds
else:
self.value = thresholds[torch.argmax(f1_score)]
return self.value

def __repr__(self) -> str:
"""Return string representation including current threshold value.
Returns:
str: String in format "ClassName(value=X.XX)"
"""
return f"{super().__repr__()} (value={self.value:.2f})"
# account for special case where recall is 1.0 even for the highest threshold.
# In this case 'thresholds' will be scalar.
return thresholds if thresholds.dim() == 0 else thresholds[torch.argmax(f1_score)]
94 changes: 28 additions & 66 deletions src/anomalib/post_processing/one_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,30 @@ def __init__(
) -> None:
super().__init__(**kwargs)

# configure sensitivity values
self.image_sensitivity = image_sensitivity
self.pixel_sensitivity = pixel_sensitivity

# initialize threshold and normalization metrics
self._image_threshold = F1AdaptiveThreshold()
self._pixel_threshold = F1AdaptiveThreshold()
self._image_normalization_stats = MinMax()
self._pixel_normalization_stats = MinMax()

self.image_sensitivity = image_sensitivity
self.pixel_sensitivity = pixel_sensitivity
# register buffers to persist threshold and normalization values
self.register_buffer("image_threshold", torch.tensor(0))
self.register_buffer("pixel_threshold", torch.tensor(0))
self.register_buffer("image_min", torch.tensor(0))
self.register_buffer("image_max", torch.tensor(1))
self.register_buffer("pixel_min", torch.tensor(0))
self.register_buffer("pixel_max", torch.tensor(1))

self.image_threshold: torch.Tensor
self.pixel_threshold: torch.Tensor
self.image_min: torch.Tensor
self.image_max: torch.Tensor
self.pixel_min: torch.Tensor
self.pixel_max: torch.Tensor

def on_validation_batch_end(
self,
Expand Down Expand Up @@ -103,13 +120,13 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule)
"""
del trainer, pl_module
if self._image_threshold.update_called:
self._image_threshold.compute()
self.image_threshold = self._image_threshold.compute()
if self._pixel_threshold.update_called:
self._pixel_threshold.compute()
self.pixel_threshold = self._pixel_threshold.compute()
if self._image_normalization_stats.update_called:
self._image_normalization_stats.compute()
self.image_min, self.image_max = self._image_normalization_stats.compute()
if self._pixel_normalization_stats.update_called:
self._pixel_normalization_stats.compute()
self.pixel_min, self.pixel_max = self._pixel_normalization_stats.compute()

def on_test_batch_end(
self,
Expand Down Expand Up @@ -168,8 +185,8 @@ def forward(self, predictions: InferenceBatch) -> InferenceBatch:
msg = "At least one of pred_score or anomaly_map must be provided."
raise ValueError(msg)
pred_score = predictions.pred_score or torch.amax(predictions.anomaly_map, dim=(-2, -1))
pred_score = self._normalize(pred_score, self.image_min, self.image_max, self.raw_image_threshold)
anomaly_map = self._normalize(predictions.anomaly_map, self.pixel_min, self.pixel_max, self.raw_pixel_threshold)
pred_score = self._normalize(pred_score, self.image_min, self.image_max, self.image_threshold)
anomaly_map = self._normalize(predictions.anomaly_map, self.pixel_min, self.pixel_max, self.pixel_threshold)
pred_label = self._threshold(pred_score, self.normalized_image_threshold)
pred_mask = self._threshold(anomaly_map, self.normalized_pixel_threshold)
return InferenceBatch(
Expand Down Expand Up @@ -216,9 +233,9 @@ def normalize_batch(self, batch: Batch) -> None:
batch (Batch): Batch containing model predictions.
"""
# normalize pixel-level predictions
batch.anomaly_map = self._normalize(batch.anomaly_map, self.pixel_min, self.pixel_max, self.raw_pixel_threshold)
batch.anomaly_map = self._normalize(batch.anomaly_map, self.pixel_min, self.pixel_max, self.pixel_threshold)
# normalize image-level predictions
batch.pred_score = self._normalize(batch.pred_score, self.image_min, self.image_max, self.raw_image_threshold)
batch.pred_score = self._normalize(batch.pred_score, self.image_min, self.image_max, self.image_threshold)

@staticmethod
def _threshold(preds: torch.Tensor | None, threshold: float) -> torch.Tensor | None:
Expand Down Expand Up @@ -256,26 +273,7 @@ def _normalize(
if preds is None:
return None
preds = ((preds - threshold) / (norm_max - norm_min)) + 0.5
preds = torch.minimum(preds, torch.tensor(1))
return torch.maximum(preds, torch.tensor(0))

@property
def raw_image_threshold(self) -> float:
"""Get the raw image-level threshold.
Returns:
float: Raw image-level threshold value.
"""
return self._image_threshold.value

@property
def raw_pixel_threshold(self) -> float:
"""Get the raw pixel-level threshold.
Returns:
float: Raw pixel-level threshold value.
"""
return self._pixel_threshold.value
return preds.clamp(min=0, max=1)

@property
def normalized_image_threshold(self) -> float:
Expand All @@ -298,39 +296,3 @@ def normalized_pixel_threshold(self) -> float:
if self.pixel_sensitivity is not None:
return 1 - self.pixel_sensitivity
return 0.5

@property
def image_min(self) -> float:
"""Get the minimum value for image-level normalization.
Returns:
float: Minimum image-level value.
"""
return self._image_normalization_stats.min

@property
def image_max(self) -> float:
"""Get the maximum value for image-level normalization.
Returns:
float: Maximum image-level value.
"""
return self._image_normalization_stats.max

@property
def pixel_min(self) -> float:
"""Get the minimum value for pixel-level normalization.
Returns:
float: Minimum pixel-level value.
"""
return self._pixel_normalization_stats.min

@property
def pixel_max(self) -> float:
"""Get the maximum value for pixel-level normalization.
Returns:
float: Maximum pixel-level value.
"""
return self._pixel_normalization_stats.max
2 changes: 1 addition & 1 deletion tests/unit/metrics/test_adaptive_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
def test_adaptive_threshold(labels: torch.Tensor, preds: torch.Tensor, target_threshold: int | float) -> None:
"""Test if the adaptive threshold computation returns the desired value."""
adaptive_threshold = F1AdaptiveThreshold(default_value=0.5)
adaptive_threshold = F1AdaptiveThreshold()
adaptive_threshold.update(preds, labels)
threshold_value = adaptive_threshold.compute()

Expand Down

0 comments on commit 8d14ae1

Please sign in to comment.