diff --git a/CHANGELOG.md b/CHANGELOG.md
index 036a2f0e49..a791b41e7f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,8 +8,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Add adversarially trained backbone models by @mzweilin in https://github.com/openvinotoolkit/anomalib/pull/2509
+
 ### Removed
 
+- Remove the `__AT__` token in backbone names in favor of the timm style names in https://github.com/openvinotoolkit/anomalib/pull/2509
+
 ### Changed
 
 ### Deprecated
diff --git a/src/anomalib/models/components/feature_extractors/timm.py b/src/anomalib/models/components/feature_extractors/timm.py
index 2499394190..9687bbe21f 100644
--- a/src/anomalib/models/components/feature_extractors/timm.py
+++ b/src/anomalib/models/components/feature_extractors/timm.py
@@ -25,7 +25,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 
 import timm
 import torch
@@ -34,6 +34,77 @@
 logger = logging.getLogger(__name__)
 
 
+def register_model_with_adv_trained_weights_tags(
+    model_name: str,
+    epsilons: Sequence[float],
+    lp: str,
+    cfg_fn: Callable,
+) -> None:
+    """Register adversarially trained model weights with a URL."""
+    from timm.models._registry import _model_pretrained_cfgs as model_pretrained_cfgs
+    from timm.models._registry import generate_default_cfgs
+
+    origin_url = "https://huggingface.co/madrylab/robust-imagenet-models"
+    paper_ids = "arXiv:2007.08489"
+
+    cfgs = {}
+    for eps in epsilons:
+        url = f"https://huggingface.co/mzweilin/robust-imagenet-models/resolve/main/{model_name}_{lp}_eps{eps}.pth"
+        tag = f"adv_{lp}_{eps}"
+        model_and_tag = f"{model_name}.{tag}"
+        cfgs[model_and_tag] = cfg_fn(
+            url=url,
+            origin_url=origin_url,
+            paper_ids=paper_ids,
+        )
+
+    default_cfgs = generate_default_cfgs(cfgs)
+
+    for model_and_tag in cfgs:
+        tag = model_and_tag[len(model_name) + 1 :]  # Remove "[MODEL NAME]."
+        if model_and_tag in model_pretrained_cfgs:
+            logger.warning(f"Overriding model weights registration in timm: {model_and_tag}")
+        model_pretrained_cfgs[model_and_tag] = default_cfgs[model_name].cfgs[tag]
+        logger.info(f"Register model weights in timm: {model_and_tag}")
+
+
+def register_in_bulk() -> None:
+    """Register adversarially trained model weights in timm."""
+    from timm.models.resnet import _cfg as resnet_cfg_fn
+
+    l2_epsilons = [0, 0.01, 0.03, 0.05, 0.1, 0.25, 0.5, 1, 3, 5]
+    linf_epsilons = [0, 0.5, 1, 2, 4, 8]
+    model_names = ["resnet18", "resnet50", "wide_resnet50_2"]
+    cfg_fn = resnet_cfg_fn
+    for model_name in model_names:
+        register_model_with_adv_trained_weights_tags(
+            model_name=model_name,
+            epsilons=l2_epsilons,
+            lp="l2",
+            cfg_fn=cfg_fn,
+        )
+        register_model_with_adv_trained_weights_tags(
+            model_name=model_name,
+            epsilons=linf_epsilons,
+            lp="linf",
+            cfg_fn=cfg_fn,
+        )
+
+
+def try_register_in_bulk() -> None:
+    """Catch the error in case we cannot register new weights in timm due to changes of internal APIs."""
+    try:
+        register_in_bulk()
+    except ImportError as e:
+        logger.warning(
+            f"Adversarially trained backbones are not available. An error occurred when registering weights: {e}",
+        )
+
+
+# We will register model weights only once even if we import the module repeatedly, because it is a singleton.
+try_register_in_bulk()
+
+
 class TimmFeatureExtractor(nn.Module):
     """Extract intermediate features from timm models.
 
@@ -84,15 +155,6 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        # Extract backbone-name and weight-URI from the backbone string.
-        if "__AT__" in backbone:
-            backbone, uri = backbone.split("__AT__")
-            pretrained_cfg = timm.models.registry.get_pretrained_cfg(backbone)
-            # Override pretrained_cfg["url"] to use different pretrained weights.
-            pretrained_cfg["url"] = uri
-        else:
-            pretrained_cfg = None
-
         self.backbone = backbone
         self.layers = list(layers)
         self.idx = self._map_layer_to_idx()
@@ -100,7 +162,6 @@ def __init__(
         self.feature_extractor = timm.create_model(
             backbone,
             pretrained=pre_trained,
-            pretrained_cfg=pretrained_cfg,
             features_only=True,
             exportable=True,
             out_indices=self.idx,
diff --git a/src/anomalib/models/image/padim/torch_model.py b/src/anomalib/models/image/padim/torch_model.py
index 317299bdc6..47221ad587 100644
--- a/src/anomalib/models/image/padim/torch_model.py
+++ b/src/anomalib/models/image/padim/torch_model.py
@@ -125,7 +125,11 @@ def __init__(
             pre_trained=pre_trained,
         ).eval()
         self.n_features_original = sum(self.feature_extractor.out_dims)
-        self.n_features = n_features or _N_FEATURES_DEFAULTS.get(self.backbone)
+
+        # The backbone tag may include weights, e.g. resnet18.adv_l2_0.1
+        backbone_name = self.backbone.split(".")[0]
+        self.n_features = n_features or _N_FEATURES_DEFAULTS.get(backbone_name)
+
         if self.n_features is None:
             msg = (
                 f"n_features must be specified for backbone {self.backbone}. "