Skip to content

Commit

Permalink
Simplify subcomponent resolve in base module (#2473)
Browse files Browse the repository at this point in the history
* simplify subcomponent resolve in base module

* add tests for AnomalibModule._resolve_component

* Update src/anomalib/models/components/base/anomalib_module.py

Co-authored-by: Ashwin Vaidya <[email protected]>

* formatting

---------

Co-authored-by: Ashwin Vaidya <[email protected]>
  • Loading branch information
djdameln and ashwinvaidya17 authored Jan 3, 2025
1 parent 03a5309 commit 8a82dc7
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 88 deletions.
127 changes: 39 additions & 88 deletions src/anomalib/models/components/base/anomalib_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import logging
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -136,10 +136,10 @@ def __init__(
self.loss: nn.Module
self.callbacks: list[Callback]

self.pre_processor = self._resolve_pre_processor(pre_processor)
self.post_processor = self._resolve_post_processor(post_processor)
self.evaluator = self._resolve_evaluator(evaluator)
self.visualizer = self._resolve_visualizer(visualizer)
self.pre_processor = self._resolve_component(pre_processor, PreProcessor, self.configure_pre_processor)
self.post_processor = self._resolve_component(post_processor, PostProcessor, self.configure_post_processor)
self.evaluator = self._resolve_component(evaluator, Evaluator, self.configure_evaluator)
self.visualizer = self._resolve_component(visualizer, Visualizer, self.configure_visualizer)

self._input_size: tuple[int, int] | None = None

Expand Down Expand Up @@ -270,34 +270,46 @@ def learning_type(self) -> LearningType:
"""
raise NotImplementedError

def _resolve_pre_processor(self, pre_processor: PreProcessor | bool) -> PreProcessor | None:
"""Resolve and validate the pre-processor configuration.
@staticmethod
def _resolve_component(
component: nn.Module | None,
component_type: type,
default_callable: Callable,
) -> nn.Module | None:
"""Resolve and validate the subcomponent configuration.
This method resolves the configuration for various subcomponents like
pre-processor, post-processor, evaluator and visualizer. It validates
the configuration and returns the configured component. If the component
is a boolean, it uses the default callable to create the component. If
the component is already an instance of the component type, it returns
the component as is.
Args:
pre_processor (PreProcessor | bool): Pre-processor configuration
- ``True`` -> use default pre-processor
- ``False`` -> no pre-processor
- ``PreProcessor`` -> use provided pre-processor
component (object): Component configuration
component_type (Type): Type of the component
default_callable (Callable): Callable to create default component
Returns:
PreProcessor | None: Configured pre-processor
Component | None: Configured component
Raises:
TypeError: If pre_processor is invalid type
TypeError: If component is invalid type
"""
if isinstance(pre_processor, PreProcessor):
return pre_processor
if isinstance(pre_processor, bool):
return self.configure_pre_processor() if pre_processor else None
msg = f"Invalid pre-processor type: {type(pre_processor)}"
if isinstance(component, component_type):
return component
if isinstance(component, bool):
return default_callable() if component else None
msg = f"Passed object should be {component_type} or bool, got: {type(component)}"
raise TypeError(msg)

@classmethod
def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> PreProcessor:
@staticmethod
def configure_pre_processor(image_size: tuple[int, int] | None = None) -> PreProcessor:
"""Configure the default pre-processor.
The default pre-processor resizes images and normalizes using ImageNet
statistics.
statistics. Override this method to provide a custom pre-processor for
the model.
Args:
image_size (tuple[int, int] | None, optional): Target size for
Expand All @@ -319,31 +331,12 @@ def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> P
]),
)

def _resolve_post_processor(self, post_processor: PostProcessor | bool) -> PostProcessor | None:
"""Resolve and validate the post-processor configuration.
Args:
post_processor (PostProcessor | bool): Post-processor configuration
- ``True`` -> use default post-processor
- ``False`` -> no post-processor
- ``PostProcessor`` -> use provided post-processor
Returns:
PostProcessor | None: Configured post-processor
Raises:
TypeError: If post_processor is invalid type
"""
if isinstance(post_processor, PostProcessor):
return post_processor
if isinstance(post_processor, bool):
return self.configure_post_processor() if post_processor else None
msg = f"Invalid post-processor type: {type(post_processor)}"
raise TypeError(msg)

def configure_post_processor(self) -> PostProcessor | None:
"""Configure the default post-processor.
The default post-processor is based on the model's learning type. Override
this method to provide a custom post-processor for the model.
Returns:
PostProcessor | None: Configured post-processor based on learning type
Expand All @@ -365,34 +358,12 @@ def configure_post_processor(self) -> PostProcessor | None:
)
raise NotImplementedError(msg)

def _resolve_evaluator(self, evaluator: Evaluator | bool) -> Evaluator | None:
"""Resolve and validate the evaluator configuration.
Args:
evaluator (Evaluator | bool): Evaluator configuration
- ``True`` -> use default evaluator
- ``False`` -> no evaluator
- ``Evaluator`` -> use provided evaluator
Returns:
Evaluator | None: Configured evaluator
Raises:
TypeError: If evaluator is invalid type
"""
if isinstance(evaluator, Evaluator):
return evaluator
if isinstance(evaluator, bool):
return self.configure_evaluator() if evaluator else None
msg = f"evaluator must be of type Evaluator or bool, got {type(evaluator)}"
raise TypeError(msg)

@staticmethod
def configure_evaluator() -> Evaluator:
"""Configure the default evaluator.
The default evaluator includes metrics for both image-level and
pixel-level evaluation.
pixel-level evaluation. Override this method to provide custom metrics for the model.
Returns:
Evaluator: Configured evaluator with default metrics
Expand All @@ -409,32 +380,12 @@ def configure_evaluator() -> Evaluator:
test_metrics = [image_auroc, image_f1score, pixel_auroc, pixel_f1score]
return Evaluator(test_metrics=test_metrics)

def _resolve_visualizer(self, visualizer: Visualizer | bool) -> Visualizer | None:
"""Resolve and validate the visualizer configuration.
Args:
visualizer (Visualizer | bool): Visualizer configuration
- ``True`` -> use default visualizer
- ``False`` -> no visualizer
- ``Visualizer`` -> use provided visualizer
Returns:
Visualizer | None: Configured visualizer
Raises:
TypeError: If visualizer is invalid type
"""
if isinstance(visualizer, Visualizer):
return visualizer
if isinstance(visualizer, bool):
return self.configure_visualizer() if visualizer else None
msg = f"Visualizer must be of type Visualizer or bool, got {type(visualizer)}"
raise TypeError(msg)

@classmethod
def configure_visualizer(cls) -> ImageVisualizer:
"""Configure the default visualizer.
Override this method to provide a custom visualizer for the model.
Returns:
ImageVisualizer: Default image visualizer instance
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/models/components/base/test_anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path

import pytest
from torch import nn

from anomalib.models.components.base import AnomalibModule

Expand Down Expand Up @@ -57,3 +58,60 @@ def test_from_config(self, model_name: str) -> None:
model = AnomalibModule.from_config(config_path=config_path)
assert model is not None
assert isinstance(model, AnomalibModule)


class TestResolveComponents:
"""Test AnomalibModule._resolve_component."""

class DummyComponent(nn.Module):
"""Dummy component class."""

def __init__(self, value: int) -> None:
self.value = value

@classmethod
def dummy_configure_component(cls) -> DummyComponent:
"""Dummy configure component method, simulates configure_<component> methods in module."""
return cls.DummyComponent(value=1)

def test_component_passed(self) -> None:
"""Test that the component is returned as is if it is an instance of the component type."""
component = self.DummyComponent(value=0)
resolved = AnomalibModule._resolve_component( # noqa: SLF001
component=component,
component_type=self.DummyComponent,
default_callable=self.dummy_configure_component,
)
assert isinstance(resolved, self.DummyComponent)
assert resolved.value == 0

def test_component_true(self) -> None:
"""Test that the default_callable is called if component is True."""
component = True
resolved = AnomalibModule._resolve_component( # noqa: SLF001
component=component,
component_type=self.DummyComponent,
default_callable=self.dummy_configure_component,
)
assert isinstance(resolved, self.DummyComponent)
assert resolved.value == 1

def test_component_false(self) -> None:
"""Test that None is returned if component is False."""
component = False
resolved = AnomalibModule._resolve_component( # noqa: SLF001
component=component,
component_type=self.DummyComponent,
default_callable=self.dummy_configure_component,
)
assert resolved is None

def test_raises_type_error(self) -> None:
"""Test that a TypeError is raised if the component is not of the correct type."""
component = 1
with pytest.raises(TypeError):
AnomalibModule._resolve_component( # noqa: SLF001
component=component,
component_type=self.DummyComponent,
default_callable=self.dummy_configure_component,
)

0 comments on commit 8a82dc7

Please sign in to comment.