Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register adversarially trained backbones in timm #2509

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

- Add adversarially trained backbone models by @mzweilin in https://github.com/openvinotoolkit/anomalib/pull/2509

### Removed

- Remove the `__AT__` token in backbone names in favor of the timm style names in https://github.com/openvinotoolkit/anomalib/pull/2509

### Changed

### Deprecated
Expand Down
83 changes: 72 additions & 11 deletions src/anomalib/models/components/feature_extractors/timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from collections.abc import Sequence
from collections.abc import Callable, Sequence

import timm
import torch
Expand All @@ -34,6 +34,77 @@
logger = logging.getLogger(__name__)


def register_model_with_adv_trained_weights_tags(
model_name: str,
epsilons: Sequence[float],
lp: str,
cfg_fn: Callable,
) -> None:
"""Register adversarially trained model weights with a URL."""
from timm.models._registry import _model_pretrained_cfgs as model_pretrained_cfgs
from timm.models._registry import generate_default_cfgs

origin_url = "https://huggingface.co/madrylab/robust-imagenet-models"
paper_ids = "arXiv:2007.08489"

cfgs = {}
for eps in epsilons:
url = f"https://huggingface.co/mzweilin/robust-imagenet-models/resolve/main/{model_name}_{lp}_eps{eps}.pth"
tag = f"adv_{lp}_{eps}"
model_and_tag = f"{model_name}.{tag}"
cfgs[model_and_tag] = cfg_fn(
url=url,
origin_url=origin_url,
paper_ids=paper_ids,
)

default_cfgs = generate_default_cfgs(cfgs)

for model_and_tag in cfgs:
tag = model_and_tag[len(model_name) + 1 :] # Remove "[MODEL NAME]."
if model_and_tag in model_pretrained_cfgs:
logger.warning(f"Overriding model weights registration in timm: {model_and_tag}")
model_pretrained_cfgs[model_and_tag] = default_cfgs[model_name].cfgs[tag]
logger.info(f"Register model weights in timm: {model_and_tag}")


def register_in_bulk() -> None:
"""Register adversarially trained model weights in timm."""
from timm.models.resnet import _cfg as resnet_cfg_fn

l2_epsilons = [0, 0.01, 0.03, 0.05, 0.1, 0.25, 0.5, 1, 3, 5]
linf_epsilons = [0, 0.5, 1, 2, 4, 8]
model_names = ["resnet18", "resnet50", "wide_resnet50_2"]
cfg_fn = resnet_cfg_fn
for model_name in model_names:
register_model_with_adv_trained_weights_tags(
model_name=model_name,
epsilons=l2_epsilons,
lp="l2",
cfg_fn=cfg_fn,
)
register_model_with_adv_trained_weights_tags(
model_name=model_name,
epsilons=linf_epsilons,
lp="linf",
cfg_fn=cfg_fn,
)


def try_register_in_bulk() -> None:
"""Catch the error in case we cannot register new weights in timm due to changes of internal APIs."""
try:
register_in_bulk()
except ImportError as e:
logger.warning(
f"Adversarially trained backbones are not available. An error occured when registering weights: {e}",
)


# We will register model weights only once even if we import the module repeatedly, because it is a singleton.
try_register_in_bulk()


class TimmFeatureExtractor(nn.Module):
"""Extract intermediate features from timm models.

Expand Down Expand Up @@ -84,23 +155,13 @@ def __init__(
) -> None:
super().__init__()

# Extract backbone-name and weight-URI from the backbone string.
if "__AT__" in backbone:
backbone, uri = backbone.split("__AT__")
pretrained_cfg = timm.models.registry.get_pretrained_cfg(backbone)
# Override pretrained_cfg["url"] to use different pretrained weights.
pretrained_cfg["url"] = uri
else:
pretrained_cfg = None

self.backbone = backbone
self.layers = list(layers)
self.idx = self._map_layer_to_idx()
self.requires_grad = requires_grad
self.feature_extractor = timm.create_model(
backbone,
pretrained=pre_trained,
pretrained_cfg=pretrained_cfg,
features_only=True,
exportable=True,
out_indices=self.idx,
Expand Down
6 changes: 5 additions & 1 deletion src/anomalib/models/image/padim/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,11 @@ def __init__(
pre_trained=pre_trained,
).eval()
self.n_features_original = sum(self.feature_extractor.out_dims)
self.n_features = n_features or _N_FEATURES_DEFAULTS.get(self.backbone)

# The backbone tag may include weights, e.g. resnet18.adv_l2_0.1
backbone_name = self.backbone.split(".")[0]
self.n_features = n_features or _N_FEATURES_DEFAULTS.get(backbone_name)

if self.n_features is None:
msg = (
f"n_features must be specified for backbone {self.backbone}. "
Expand Down