From 07a68ff82618870efa1adf36bff219217a7d48b6 Mon Sep 17 00:00:00 2001 From: alexriedel1 Date: Thu, 5 Dec 2024 18:27:37 +0100 Subject: [PATCH] new perlin --- src/anomalib/data/transforms/__init__.py | 3 +- src/anomalib/data/utils/__init__.py | 7 +- src/anomalib/data/utils/augmenter.py | 171 -------- .../data/utils/generators/__init__.py | 4 +- src/anomalib/data/utils/generators/perlin.py | 401 ++++++++++++------ src/anomalib/data/utils/synthetic.py | 8 +- .../models/image/draem/lightning_model.py | 4 +- .../models/image/dsr/anomaly_generator.py | 1 - .../models/image/dsr/lightning_model.py | 5 +- 9 files changed, 295 insertions(+), 309 deletions(-) delete mode 100644 src/anomalib/data/utils/augmenter.py diff --git a/src/anomalib/data/transforms/__init__.py b/src/anomalib/data/transforms/__init__.py index 146fb19e15..89a5c673d2 100644 --- a/src/anomalib/data/transforms/__init__.py +++ b/src/anomalib/data/transforms/__init__.py @@ -4,5 +4,6 @@ # SPDX-License-Identifier: Apache-2.0 from .center_crop import ExportableCenterCrop +from .multi_random_choice import MultiRandomChoice -__all__ = ["ExportableCenterCrop"] +__all__ = ["ExportableCenterCrop", "MultiRandomChoice"] diff --git a/src/anomalib/data/utils/__init__.py b/src/anomalib/data/utils/__init__.py index e75ba5bf49..eae2ae6bc6 100644 --- a/src/anomalib/data/utils/__init__.py +++ b/src/anomalib/data/utils/__init__.py @@ -3,10 +3,10 @@ # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from .augmenter import Augmenter from .boxes import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes from .download import DownloadInfo, download_and_extract -from .generators import random_2d_perlin +from .generators import generate_perlin_noise + from .image import ( generate_output_image_filename, get_image_filenames, @@ -30,7 +30,7 @@ "generate_output_image_filename", "get_image_filenames", "get_image_height_and_width", - "random_2d_perlin", + "generate_perlin_noise", "read_image", "read_mask", "read_depth_image", @@ -42,7 +42,6 @@ "TestSplitMode", "LabelName", "DirType", - "Augmenter", "masks_to_boxes", "boxes_to_masks", "boxes_to_anomaly_maps", diff --git a/src/anomalib/data/utils/augmenter.py b/src/anomalib/data/utils/augmenter.py deleted file mode 100644 index aa35434773..0000000000 --- a/src/anomalib/data/utils/augmenter.py +++ /dev/null @@ -1,171 +0,0 @@ -"""Augmenter module to generates out-of-distribution samples for the DRAEM implementation.""" - -# Original Code -# Copyright (c) 2021 VitjanZ -# https://github.com/VitjanZ/DRAEM. -# SPDX-License-Identifier: MIT -# -# Modified -# Copyright (C) 2022-2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import math -import random -from pathlib import Path - -import cv2 -import imgaug.augmenters as iaa -import numpy as np -import torch -from PIL import Image -from torchvision.datasets.folder import IMG_EXTENSIONS - -from anomalib.data.utils.generators.perlin import random_2d_perlin - - -def nextpow2(value: int) -> int: - """Return the smallest power of 2 greater than or equal to the input value.""" - return 2 ** (math.ceil(math.log(value, 2))) - - -class Augmenter: - """Class that generates noisy augmentations of input images. - - Args: - anomaly_source_path (str | None): Path to a folder of images that will be used as source of the anomalous - noise. If not specified, random noise will be used instead. - p_anomalous (float): Probability that the anomalous perturbation will be applied to a given image. 
- beta (float): Parameter that determines the opacity of the noise mask. - """ - - def __init__( - self, - anomaly_source_path: str | None = None, - p_anomalous: float = 0.5, - beta: float | tuple[float, float] = (0.2, 1.0), - ) -> None: - self.p_anomalous = p_anomalous - self.beta = beta - - self.anomaly_source_paths: list[Path] = [] - if anomaly_source_path is not None: - for img_ext in IMG_EXTENSIONS: - self.anomaly_source_paths.extend(Path(anomaly_source_path).rglob("*" + img_ext)) - - self.augmenters = [ - iaa.GammaContrast((0.5, 2.0), per_channel=True), - iaa.MultiplyAndAddToBrightness(mul=(0.8, 1.2), add=(-30, 30)), - iaa.pillike.EnhanceSharpness(), - iaa.AddToHueAndSaturation((-50, 50), per_channel=True), - iaa.Solarize(0.5, threshold=(32, 128)), - iaa.Posterize(), - iaa.Invert(), - iaa.pillike.Autocontrast(), - iaa.pillike.Equalize(), - iaa.Affine(rotate=(-45, 45)), - ] - self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))]) - - def rand_augmenter(self) -> iaa.Sequential: - """Select 3 random transforms that will be applied to the anomaly source images. - - Returns: - A selection of 3 transforms. - """ - aug_ind = np.random.default_rng().choice(np.arange(len(self.augmenters)), 3, replace=False) - return iaa.Sequential([self.augmenters[aug_ind[0]], self.augmenters[aug_ind[1]], self.augmenters[aug_ind[2]]]) - - def generate_perturbation( - self, - height: int, - width: int, - anomaly_source_path: Path | str | None = None, - ) -> tuple[np.ndarray, np.ndarray]: - """Generate an image containing a random anomalous perturbation using a source image. - - Args: - height (int): height of the generated image. - width: (int): width of the generated image. - anomaly_source_path (Path | str | None): Path to an image file. If not provided, random noise will be used - instead. - - Returns: - Image containing a random anomalous perturbation, and the corresponding ground truth anomaly mask. - """ - # Generate random perlin noise - perlin_scale = 6 - min_perlin_scale = 0 - - perlin_scalex = 2 ** np.random.default_rng().integers(min_perlin_scale, perlin_scale) - perlin_scaley = 2 ** np.random.default_rng().integers(min_perlin_scale, perlin_scale) - - perlin_noise = random_2d_perlin((nextpow2(height), nextpow2(width)), (perlin_scalex, perlin_scaley))[ - :height, - :width, - ] - perlin_noise = self.rot(image=perlin_noise) - - # Create mask from perlin noise - mask = np.where(perlin_noise > 0.5, np.ones_like(perlin_noise), np.zeros_like(perlin_noise)) - mask = np.expand_dims(mask, axis=2).astype(np.float32) - - # Load anomaly source image - if anomaly_source_path: - anomaly_source_img = np.array(Image.open(anomaly_source_path)) - anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(width, height)) - else: # if no anomaly source is specified, we use the perlin noise as anomalous source - anomaly_source_img = np.expand_dims(perlin_noise, 2).repeat(3, 2) - anomaly_source_img = (anomaly_source_img * 255).astype(np.uint8) - - # Augment anomaly source image - aug = self.rand_augmenter() - anomaly_img_augmented = aug(image=anomaly_source_img) - - # Create anomalous perturbation that we will apply to the image - perturbation = anomaly_img_augmented.astype(np.float32) * mask / 255.0 - - return perturbation, mask - - def augment_batch(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Generate anomalous augmentations for a batch of input images. - - Args: - batch (torch.Tensor): Batch of input images - - Returns: - - Augmented image to which anomalous perturbations have been added. 
- - Ground truth masks corresponding to the anomalous perturbations. - """ - batch_size, channels, height, width = batch.shape - - # Collect perturbations - perturbations_list = [] - masks_list = [] - for _ in range(batch_size): - if torch.rand(1) > self.p_anomalous: # include normal samples - perturbations_list.append(torch.zeros((channels, height, width))) - masks_list.append(torch.zeros((1, height, width))) - else: - anomaly_source_path = ( - random.sample(self.anomaly_source_paths, 1)[0] if len(self.anomaly_source_paths) > 0 else None - ) - perturbation, mask = self.generate_perturbation(height, width, anomaly_source_path) - perturbations_list.append(torch.Tensor(perturbation).permute((2, 0, 1))) - masks_list.append(torch.Tensor(mask).permute((2, 0, 1))) - - perturbations = torch.stack(perturbations_list).to(batch.device) - masks = torch.stack(masks_list).to(batch.device) - - # Apply perturbations batch wise - if isinstance(self.beta, float): - beta = self.beta - elif isinstance(self.beta, tuple): - beta = torch.rand(batch_size) * (self.beta[1] - self.beta[0]) + self.beta[0] - beta = beta.view(batch_size, 1, 1, 1).expand_as(batch).to(batch.device) # type: ignore[attr-defined] - else: - msg = "Beta must be either float or tuple of floats" - raise TypeError(msg) - - augmented_batch = batch * (1 - masks) + (beta) * perturbations + (1 - beta) * batch * (masks) - - return augmented_batch, masks diff --git a/src/anomalib/data/utils/generators/__init__.py b/src/anomalib/data/utils/generators/__init__.py index a79bad9770..ed79a081df 100644 --- a/src/anomalib/data/utils/generators/__init__.py +++ b/src/anomalib/data/utils/generators/__init__.py @@ -3,6 +3,6 @@ # Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from .perlin import random_2d_perlin +from .perlin import PerlinAnomalyGenerator, generate_perlin_noise -__all__ = ["random_2d_perlin"] +__all__ = ["PerlinAnomalyGenerator", "generate_perlin_noise"] \ No newline at end of file diff --git a/src/anomalib/data/utils/generators/perlin.py b/src/anomalib/data/utils/generators/perlin.py index fa683d7546..00fca00c9a 100644 --- a/src/anomalib/data/utils/generators/perlin.py +++ b/src/anomalib/data/utils/generators/perlin.py @@ -1,160 +1,317 @@ -"""Helper functions for generating Perlin noise.""" - -# Original Code -# Copyright (c) 2021 VitjanZ -# https://github.com/VitjanZ/DRAEM. -# SPDX-License-Identifier: MIT -# -# Modified +"""Perlin noise-based synthetic anomaly generator.""" + # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -# ruff: noqa - -import math +from pathlib import Path -import numpy as np import torch +from torchvision import io +from torchvision.datasets.folder import IMG_EXTENSIONS +from torchvision.transforms import v2 +from anomalib.data.transforms import MultiRandomChoice -def lerp_np(x, y, w): - """Helper function.""" - return (y - x) * w + x - - -def rand_perlin_2d_octaves_np(shape, res, octaves=1, persistence=0.5): - """Generate Perlin noise parameterized by the octaves method. Numpy version.""" - noise = np.zeros(shape) - frequency = 1 - amplitude = 1 - for _ in range(octaves): - noise += amplitude * generate_perlin_noise_2d(shape, (frequency * res[0], frequency * res[1])) - frequency *= 2 - amplitude *= persistence - return noise +def generate_perlin_noise( + height: int, + width: int, + scale: tuple[int, int] | None = None, + device: torch.device | None = None, +) -> torch.Tensor: + """Generate a Perlin noise pattern. 
-def generate_perlin_noise_2d(shape, res): - """Fractal perlin noise.""" - - def f(t): - return 6 * t**5 - 15 * t**4 + 10 * t**3 - - delta = (res[0] / shape[0], res[1] / shape[1]) - d = (shape[0] // res[0], shape[1] // res[1]) - grid = np.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(1, 2, 0) % 1 - # Gradients - angles = 2 * np.pi * np.random.default_rng().random(res[0] + 1, res[1] + 1) - gradients = np.dstack((np.cos(angles), np.sin(angles))) - g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1) - g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1) - g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1) - g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1) - # Ramps - n00 = np.sum(grid * g00, 2) - n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2) - n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2) - n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2) - # Interpolation - t = f(grid) - n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10 - n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11 - return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1) - - -def random_2d_perlin( - shape: tuple, - res: tuple[int | torch.Tensor, int | torch.Tensor], - fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3, -) -> np.ndarray | torch.Tensor: - """Returns a random 2d perlin noise array. + This function generates a Perlin noise pattern using a grid-based gradient noise approach. + The noise is generated by interpolating between randomly generated gradient vectors at grid vertices. + The interpolation uses a quintic curve for smooth transitions. Args: - shape (tuple): Shape of the 2d map. - res (tuple[int | torch.Tensor, int | torch.Tensor]): Tuple of scales for perlin noise for height and width dimension. - fade (_type_, optional): Function used for fading the resulting 2d map. - Defaults to equation 6*t**5-15*t**4+10*t**3. + height: Desired height of the noise pattern + width: Desired width of the noise pattern + scale: Tuple of (scale_x, scale_y) for noise granularity. If None, random scales will be used. + Larger scales produce coarser noise patterns, while smaller scales produce finer patterns. + device: Device to generate the noise on. If None, uses current default device Returns: - np.ndarray | torch.Tensor: Random 2d-array/tensor generated using perlin noise. - """ - if isinstance(res[0], int | np.integer): - result = _rand_perlin_2d_np(shape, res, fade) - elif isinstance(res[0], torch.Tensor): - result = _rand_perlin_2d(shape, res, fade) - else: - msg = f"got scales of type {type(res[0])}" - raise TypeError(msg) - return result + Tensor of shape [height, width] containing the noise pattern, with values roughly in [-1, 1] range + Examples: + >>> # Generate 256x256 noise with default random scale + >>> noise = generate_perlin_noise(256, 256) + >>> print(noise.shape) + torch.Size([256, 256]) -def _rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): - """Generate a random image containing Perlin noise. 
Numpy version.""" - delta = (res[0] / shape[0], res[1] / shape[1]) - d = (shape[0] // res[0], shape[1] // res[1]) - grid = np.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(1, 2, 0) % 1 + >>> # Generate 512x512 noise with fixed scale + >>> noise = generate_perlin_noise(512, 512, scale=(8, 8)) + >>> print(noise.shape) + torch.Size([512, 512]) - angles = 2 * math.pi * np.random.default_rng().random((res[0] + 1, res[1] + 1)) - gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) + >>> # Generate noise on GPU if available + >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + >>> noise = generate_perlin_noise(128, 128, device=device) + """ + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - def tile_grads(slice1, slice2): - return np.repeat(np.repeat(gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]], d[0], axis=0), d[1], axis=1) + # Handle scale parameter + if scale is None: + min_scale, max_scale = 0, 6 + scalex = 2 ** torch.randint(min_scale, max_scale, (1,), device=device).item() + scaley = 2 ** torch.randint(min_scale, max_scale, (1,), device=device).item() + else: + scalex, scaley = scale - def dot(grad, shift): - return ( - np.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), axis=-1) - * grad[: shape[0], : shape[1]] - ).sum(axis=-1) + # Ensure dimensions are powers of 2 for proper noise generation + def nextpow2(value: int) -> int: + return int(2 ** torch.ceil(torch.log2(torch.tensor(value))).int().item()) - n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) - n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) - n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) - n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) - t = fade(grid[: shape[0], : shape[1]]) - return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1]) + pad_h = nextpow2(height) + pad_w = nextpow2(width) + # Generate base grid + delta = (scalex / pad_h, scaley / pad_w) + d = (pad_h // scalex, pad_w // scaley) -def _rand_perlin_2d(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): - """Generate a random image containing Perlin noise. 
PyTorch version.""" - delta = (res[0] / shape[0], res[1] / shape[1]) - d = (shape[0] // res[0], shape[1] // res[1]) + grid = ( + torch.stack( + torch.meshgrid( + torch.arange(0, scalex, delta[0], device=device), + torch.arange(0, scaley, delta[1], device=device), + indexing="ij", + ), + dim=-1, + ) + % 1 + ) - grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1 - angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1) + # Generate random gradients + angles = 2 * torch.pi * torch.rand(int(scalex) + 1, int(scaley) + 1, device=device) gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) - def tile_grads(slice1, slice2): + def tile_grads(slice1: list[int | None], slice2: list[int | None]) -> torch.Tensor: return ( gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] - .repeat_interleave(d[0], 0) - .repeat_interleave(d[1], 1) + .repeat_interleave(int(d[0]), 0) + .repeat_interleave(int(d[1]), 1) ) - def dot(grad, shift): + def dot(grad: torch.Tensor, shift: list[float]) -> torch.Tensor: return ( torch.stack( - (grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), + (grid[:pad_h, :pad_w, 0] + shift[0], grid[:pad_h, :pad_w, 1] + shift[1]), dim=-1, ) - * grad[: shape[0], : shape[1]] + * grad[:pad_h, :pad_w] ).sum(dim=-1) + # Calculate noise values at grid points n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) - n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) - t = fade(grid[: shape[0], : shape[1]]) - return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) - - -def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5): - """Generate Perlin noise parameterized by the octaves method. PyTorch version.""" - noise = torch.zeros(shape) - frequency = 1 - amplitude = 1 - for _ in range(octaves): - noise += amplitude * _rand_perlin_2d(shape, (frequency * res[0], frequency * res[1])) - frequency *= 2 - amplitude *= persistence - return noise + + # Interpolate between grid points using quintic curve + def fade(t: torch.Tensor) -> torch.Tensor: + return 6 * t**5 - 15 * t**4 + 10 * t**3 + + t = fade(grid[:pad_h, :pad_w]) + noise = torch.sqrt(torch.tensor(2.0, device=device)) * torch.lerp( + torch.lerp(n00, n10, t[..., 0]), + torch.lerp(n01, n11, t[..., 0]), + t[..., 1], + ) + + # Crop to desired dimensions + return noise[:height, :width] + + +class PerlinAnomalyGenerator(v2.Transform): + """Perlin noise-based synthetic anomaly generator. + + Examples: + >>> # Single image usage with default parameters + >>> transform = PerlinAnomalyGenerator() + >>> image = torch.randn(3, 256, 256) # [C, H, W] + >>> augmented_image, anomaly_mask = transform(image) + >>> print(augmented_image.shape) # [C, H, W] + >>> print(anomaly_mask.shape) # [1, H, W] + + >>> # Batch usage with custom parameters + >>> transform = PerlinAnomalyGenerator( + ... probability=0.8, + ... blend_factor=0.5 + ... ) + >>> batch = torch.randn(4, 3, 256, 256) # [B, C, H, W] + >>> augmented_batch, anomaly_masks = transform(batch) + >>> print(augmented_batch.shape) # [B, C, H, W] + >>> print(anomaly_masks.shape) # [B, 1, H, W] + + >>> # Using anomaly source images + >>> transform = PerlinAnomalyGenerator( + ... anomaly_source_path='path/to/anomaly/images', + ... probability=0.7, + ... blend_factor=(0.3, 0.9), + ... rotation_range=(-45, 45) + ... 
) + >>> augmented_image, anomaly_mask = transform(image) + """ + + def __init__( + self, + anomaly_source_path: str | None = None, + probability: float = 0.5, + blend_factor: float | tuple[float, float] = (0.2, 1.0), + rotation_range: tuple[float, float] = (-90, 90), + ) -> None: + super().__init__() + self.probability = probability + self.blend_factor = blend_factor + + # Load anomaly source paths + self.anomaly_source_paths: list[Path] = [] + if anomaly_source_path is not None: + for img_ext in IMG_EXTENSIONS: + self.anomaly_source_paths.extend(Path(anomaly_source_path).rglob("*" + img_ext)) + + # Initialize perlin rotation transform + self.perlin_rotation_transform = v2.RandomAffine( + degrees=rotation_range, + interpolation=v2.InterpolationMode.BILINEAR, + fill=0, + ) + + # Initialize augmenters + self.augmenters = MultiRandomChoice( + transforms=[ + v2.ColorJitter(contrast=(0.5, 2.0)), + v2.RandomPhotometricDistort( + brightness=(0.8, 1.2), + contrast=(1.0, 1.0), # No contrast change + saturation=(1.0, 1.0), # No saturation change + hue=(0.0, 0.0), # No hue change + p=1.0, + ), + v2.RandomAdjustSharpness(sharpness_factor=2.0, p=1.0), + v2.ColorJitter(hue=[-50 / 360, 50 / 360], saturation=[0.5, 1.5]), + v2.RandomSolarize(threshold=torch.empty(1).uniform_(32 / 255, 128 / 255).item(), p=1.0), + v2.RandomPosterize(bits=4, p=1.0), + v2.RandomInvert(p=1.0), + v2.AutoAugment(), + v2.RandomEqualize(p=1.0), + v2.RandomAffine(degrees=(-45, 45), interpolation=v2.InterpolationMode.BILINEAR, fill=0), + ], + probabilities=None, + num_transforms=3, + fixed_num_transforms=True, + ) + + def generate_perturbation( + self, + height: int, + width: int, + device: torch.device | None = None, + anomaly_source_path: Path | str | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Generate perturbed image and mask. 
+ + Args: + height: Height of the output image + width: Width of the output image + device: Device to generate the perturbation on + anomaly_source_path: Optional path to source image for anomaly + + Returns: + tuple[torch.Tensor, torch.Tensor]: Perturbation and mask tensors + """ + # Generate perlin noise + perlin_noise = generate_perlin_noise(height, width, device=device) + + # Create rotated noise pattern + perlin_noise = perlin_noise.unsqueeze(0) # [1, H, W] + perlin_noise = self.perlin_rotation_transform(perlin_noise).squeeze(0) # [H, W] + + # Generate binary mask from perlin noise + mask = torch.where( + perlin_noise > 0.5, + torch.ones_like(perlin_noise, device=device), + torch.zeros_like(perlin_noise, device=device), + ).unsqueeze(-1) # [H, W, 1] + + # Generate anomaly source image + if anomaly_source_path: + anomaly_source_img = ( + io.read_image(str(anomaly_source_path), mode=io.ImageReadMode.RGB).float().to(device) / 255.0 + ) + if anomaly_source_img.shape[-2:] != (height, width): + anomaly_source_img = v2.functional.resize(anomaly_source_img, [height, width], antialias=True) + anomaly_source_img = anomaly_source_img.permute(1, 2, 0) # [H, W, C] + else: + anomaly_source_img = perlin_noise.unsqueeze(-1).repeat(1, 1, 3) # [H, W, C] + anomaly_source_img = (anomaly_source_img * 0.5) + 0.25 # Adjust intensity range + + # Apply augmentations to source image + anomaly_augmented = self.augmenters(anomaly_source_img.permute(2, 0, 1)) # [C, H, W] + anomaly_augmented = anomaly_augmented.permute(1, 2, 0) # [H, W, C] + + # Create final perturbation by applying mask + perturbation = anomaly_augmented * mask + + return perturbation, mask + + def _transform_image( + self, + img: torch.Tensor, + h: int, + w: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Transform a single image.""" + if torch.rand(1, device=device) > self.probability: + return img, torch.zeros((1, h, w), device=device) + + anomaly_source_path = ( + list(self.anomaly_source_paths)[int(torch.randint(len(self.anomaly_source_paths), (1,)).item())] + if self.anomaly_source_paths + else None + ) + + perturbation, mask = self.generate_perturbation(h, w, device, anomaly_source_path) + perturbation = perturbation.permute(2, 0, 1) + mask = mask.permute(2, 0, 1) + + beta = ( + self.blend_factor + if isinstance(self.blend_factor, float) + else torch.rand(1, device=device) * (self.blend_factor[1] - self.blend_factor[0]) + self.blend_factor[0] + if isinstance(self.blend_factor, tuple) + # Add type guard + else torch.tensor(0.5, device=device) # Fallback value + ) + + if not isinstance(beta, float): + beta = beta.view(-1, 1, 1).expand_as(img) + + augmented_img = img * (1 - mask) + beta * perturbation + (1 - beta) * img * mask + return augmented_img, mask + + def forward(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Apply augmentation using the mask for single image or batch.""" + device = img.device + is_batch = len(img.shape) == 4 + + if is_batch: + batch, _, height, width = img.shape + # Initialize batch outputs + batch_augmented = [] + batch_masks = [] + + for i in range(batch): + # Apply transform to each image in batch + augmented, mask = self._transform_image(img[i], height, width, device) + batch_augmented.append(augmented) + batch_masks.append(mask) + + return torch.stack(batch_augmented), torch.stack(batch_masks) + + # Handle single image + return self._transform_image(img, img.shape[1], img.shape[2], device) \ No newline at end of file diff --git 
a/src/anomalib/data/utils/synthetic.py b/src/anomalib/data/utils/synthetic.py
index 20ba836bee..ce06cdfa66 100644
--- a/src/anomalib/data/utils/synthetic.py
+++ b/src/anomalib/data/utils/synthetic.py
@@ -20,7 +20,8 @@
 from anomalib import TaskType
 from anomalib.data.base.dataset import AnomalibDataset
-from anomalib.data.utils import Augmenter, Split, read_image
+from anomalib.data.utils import Split, read_image
+from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator
 
 logger = logging.getLogger(__name__)
 
@@ -66,7 +67,8 @@ def make_synthetic_dataset(
     anomalous_samples = anomalous_samples.reset_index(drop=True)
 
     # initialize augmenter
-    augmenter = Augmenter("./datasets/dtd", p_anomalous=1.0, beta=(0.01, 0.2))
+    # augmenter = Augmenter("./datasets/dtd", p_anomalous=1.0, beta=(0.01, 0.2))
+    augmenter = PerlinAnomalyGenerator(anomaly_source_path="./datasets/dtd", probability=1.0, blend_factor=(0.01, 0.2))
 
     def augment(sample: Series) -> Series:
         """Apply synthetic anomalous augmentation to a sample from a dataframe.
@@ -83,7 +85,7 @@ def augment(sample: Series) -> Series:
         # read and transform image
         image = read_image(sample.image_path, as_tensor=True)
         # apply anomalous perturbation
-        aug_im, mask = augmenter.augment_batch(image.unsqueeze(0))
+        aug_im, mask = augmenter(image)
         # target file name with leading zeros
         file_name = f"{str(sample.name).zfill(int(math.log10(n_anomalous)) + 1)}.png"
         # write image
diff --git a/src/anomalib/models/image/draem/lightning_model.py b/src/anomalib/models/image/draem/lightning_model.py
index 6eb0e197fc..c7897c141f 100644
--- a/src/anomalib/models/image/draem/lightning_model.py
+++ b/src/anomalib/models/image/draem/lightning_model.py
@@ -15,7 +15,7 @@
 from torchvision.transforms.v2 import Compose, Resize, Transform
 
 from anomalib import LearningType
-from anomalib.data.utils import Augmenter
+from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator
 from anomalib.models.components import AnomalyModule
 
 from .loss import DraemLoss
@@ -46,7 +46,7 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self.augmenter = Augmenter(anomaly_source_path, beta=beta)
+        self.augmenter = PerlinAnomalyGenerator(anomaly_source_path, blend_factor=beta)
         self.model = DraemModel(sspcab=enable_sspcab)
         self.loss = DraemLoss()
         self.sspcab = enable_sspcab
diff --git a/src/anomalib/models/image/dsr/anomaly_generator.py b/src/anomalib/models/image/dsr/anomaly_generator.py
index 396019de39..641f25a692 100644
--- a/src/anomalib/models/image/dsr/anomaly_generator.py
+++ b/src/anomalib/models/image/dsr/anomaly_generator.py
@@ -8,7 +8,6 @@
 import torch
 from torch import Tensor, nn
 
-from anomalib.data.utils.generators.perlin import _rand_perlin_2d_np
 
 class DsrAnomalyGenerator(nn.Module):
diff --git a/src/anomalib/models/image/dsr/lightning_model.py b/src/anomalib/models/image/dsr/lightning_model.py
index 8381fce73d..de11e6e784 100644
--- a/src/anomalib/models/image/dsr/lightning_model.py
+++ b/src/anomalib/models/image/dsr/lightning_model.py
@@ -17,9 +17,8 @@
 from anomalib import LearningType
 from anomalib.data.utils import DownloadInfo, download_and_extract
-from anomalib.data.utils.augmenter import Augmenter
+from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator
 from anomalib.models.components import AnomalyModule
-from anomalib.models.image.dsr.anomaly_generator import DsrAnomalyGenerator
 from anomalib.models.image.dsr.loss import DsrSecondStageLoss, DsrThirdStageLoss
 from anomalib.models.image.dsr.torch_model import DsrModel
 
@@ -49,7 +48,7 @@ def __init__(self, latent_anomaly_strength: float = 0.2, upsampling_train_ratio:
         self.upsampling_train_ratio = upsampling_train_ratio
 
         self.quantized_anomaly_generator = DsrAnomalyGenerator()
-        self.perlin_generator = Augmenter()
+        self.perlin_generator = PerlinAnomalyGenerator()
         self.model = DsrModel(latent_anomaly_strength)
         self.second_stage_loss = DsrSecondStageLoss()
         self.third_stage_loss = DsrThirdStageLoss()
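
A minimal usage sketch of the new API (assuming this branch of anomalib is installed; the input tensors below are made up for illustration). generate_perlin_noise returns a single-channel noise map cropped to the requested size, and PerlinAnomalyGenerator is a torchvision v2 transform that returns the augmented image together with its anomaly mask, replacing the removed Augmenter.augment_batch call:

    >>> import torch
    >>> from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator, generate_perlin_noise
    >>> # standalone noise map, values roughly in [-1, 1]
    >>> noise = generate_perlin_noise(256, 256, scale=(8, 8))
    >>> noise.shape
    torch.Size([256, 256])
    >>> # synthetic anomaly augmentation on a batch; probability=1.0 perturbs every image
    >>> augmenter = PerlinAnomalyGenerator(probability=1.0, blend_factor=(0.2, 1.0))
    >>> batch = torch.rand(4, 3, 256, 256)   # images in [0, 1], shape [B, C, H, W]
    >>> augmented, masks = augmenter(batch)  # was: augmenter.augment_batch(batch)
    >>> augmented.shape, masks.shape
    (torch.Size([4, 3, 256, 256]), torch.Size([4, 1, 256, 256]))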
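
On the model side (DRAEM and DSR above), the removed Augmenter constructor arguments map onto the new keyword names. A sketch of the equivalent call, assuming the same DTD anomaly-source directory that make_synthetic_dataset uses:

    >>> from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator
    >>> augmenter = PerlinAnomalyGenerator(
    ...     anomaly_source_path="./datasets/dtd",  # optional; pure Perlin noise is used when omitted
    ...     probability=0.5,                       # was p_anomalous in the removed Augmenter
    ...     blend_factor=(0.2, 1.0),               # was beta in the removed Augmenter
    ...     rotation_range=(-90, 90),              # rotation applied to the noise pattern
    ... )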