Add ImageNet-R,-A,-200
Co-authored-by: George Pachitariu <[email protected]>
Co-authored-by: Evgenia Rusak <[email protected]>
3 people committed May 15, 2021
1 parent b35eab0 commit df49ebb
Showing 18 changed files with 1,574 additions and 304 deletions.
23 changes: 22 additions & 1 deletion examples/batchnorm/README.md
@@ -5,7 +5,7 @@ Steffen Schneider*, Evgenia Rusak*, Luisa Eck, Oliver Bringmann, Wieland Brendel
Website: [domainadaptation.org/batchnorm](https://domainadaptation.org/batchnorm)

This repository contains evaluation code for the paper *Improving robustness against common corruptions by covariate shift adaptation*.
We will release the code in the upcoming weeks. To get notified, watch and/or star this repository to get notified of updates!
The repository is updated frequently. To get notified, watch and/or star this repository!

Today's state-of-the-art machine vision models are vulnerable to image corruptions like blurring or compression artefacts, limiting their performance in many real-world applications. We here argue that popular benchmarks to measure model robustness against common corruptions (like ImageNet-C) underestimate model robustness in many (but not all) application scenarios. The key insight is that in many scenarios, multiple unlabeled examples of the corruptions are available and can be used for unsupervised online adaptation. Replacing the activation statistics estimated by batch normalization on the training set with the statistics of the corrupted images consistently improves the robustness across 25 different popular computer vision models. Using the corrected statistics, ResNet-50 reaches 62.2% mCE on ImageNet-C compared to 76.7% without adaptation. With the more robust AugMix model, we improve the state of the art from 56.5% mCE to 51.0% mCE. Even adapting to a single sample improves robustness for the ResNet-50 and AugMix models, and 32 samples are sufficient to improve the current state of the art for a ResNet-50 architecture. We argue that results with adapted statistics should be included whenever reporting scores in corruption benchmarks and other out-of-distribution generalization settings.
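The core recipe is to re-estimate the batch-normalization statistics on unlabeled corrupted images and then evaluate with those statistics. The following is a minimal PyTorch sketch of that idea, not the exact code in this repository; the random tensors stand in for real batches of corrupted images:

```python
import torch
import torchvision.models as models

model = models.resnet50(pretrained=True)

# Discard the clean-data BN statistics and re-estimate them
# from unlabeled corrupted images.
for m in model.modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        m.reset_running_stats()
        m.momentum = None  # cumulative moving average over all adaptation batches

# Stand-in for a loader of unlabeled corrupted images.
corrupted_batches = [torch.randn(32, 3, 224, 224) for _ in range(4)]

model.train()          # BN layers update their running stats in train mode
with torch.no_grad():  # no labels and no parameter updates are needed
    for x in corrupted_batches:
        model(x)

model.eval()           # evaluate with the adapted statistics
```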

@@ -26,6 +26,7 @@ With a simple recalculation of batch normalization statistics, we improve the me
| [DeepAugment+AugMix](https://github.com/hendrycks/imagenet-r) | 53.6 | 48.4 |45.4|
| [DeepAug+AM+RNXt101](https://github.com/hendrycks/imagenet-r) | **44.5** |**40.7** | **38.0** |


### Results for models trained with [Fixup](https://github.com/hongyi-zhang/Fixup) and [GroupNorm](https://github.com/ppwwyyxx/GroupNorm-reproduce) on ImageNet-C

Fixup and GN trained models perform better than non-adapted BN models but worse than adapted BN models.
@@ -36,6 +37,26 @@ Fixup and GN trained models perform better than non-adapted BN models but worse
|ResNet-101 |68.2 |67.6 |69.0 |**59.1**|
|ResNet-152 |67.6 |65.4 |69.3 |**58.0**|

### To reproduce the first table above

Run [`scripts/paper/table1.sh`](scripts/paper/table1.sh):
```sh
row="2"  # Row of the table to compute

# Docker bind mounts require absolute paths, so the current directory
# (examples/batchnorm) and its parent are mounted via $(pwd).
docker run -v "$IMAGENET_C_PATH":/ImageNet-C:ro \
    -v "$CHECKPOINT_PATH":/checkpoints:ro \
    -v "$(pwd)":/batchnorm \
    -v "$(pwd)/..":/deps \
    -it georgepachitariu/robustness:latest \
    bash /batchnorm/scripts/paper/table1.sh "$row" 2>&1
```
Before running it, set two environment variables:
1. `IMAGENET_C_PATH`
This is the host path where you store the ImageNet-C dataset; it is mounted read-only at `/ImageNet-C` inside the container. The dataset is described [here](https://github.com/hendrycks/robustness) and you can download it from [here](https://zenodo.org/record/2235448#.YJjcNyaxWcw).

2. `CHECKPOINT_PATH`
This is the host path where you store our checkpoints; it is mounted read-only at `/checkpoints` inside the container.
You can download them from here: TODO.


## News

2 changes: 1 addition & 1 deletion examples/batchnorm/scripts/paper/table1.sbatch
@@ -7,7 +7,7 @@
scontrol show job "$SLURM_JOB_ID"

# The image georgepachitariu/robustness was created using
# the Dockerfile from parent folder.
# the Dockerfile from the main repository folder.
row="2" # This is the row in the table
singularity exec --nv -B /scratch_local \
-B "$IMAGENET_C_PATH":/ImageNet-C:ro \
1 change: 0 additions & 1 deletion robusta/__init__.py
@@ -17,7 +17,6 @@
# This licence notice applies to all originally written code by the
# authors. Code taken from other open-source projects is indicated.
# See NOTICE for a list of all third-party licences used in the project.

"""A package for robustness and adaptation on ImageNet scale."""

from robusta import batchnorm
24 changes: 10 additions & 14 deletions robusta/batchnorm/bn.py
@@ -17,7 +17,6 @@
# This licence notice applies to all originally written code by the
# authors. Code taken from other open-source projects is indicated.
# See NOTICE for a list of all third-party licences used in the project.

""" Batch norm variants
"""

@@ -39,6 +38,7 @@ def adapt_bayesian(model: nn.Module, prior: float):


class PartlyAdaptiveBN(nn.Module):

@staticmethod
def find_bns(parent, estimate_mean, estimate_var):
replace_mods = []
@@ -52,8 +52,7 @@ def find_bns(parent, estimate_mean, estimate_var):
else:
replace_mods.extend(
PartlyAdaptiveBN.find_bns(child, estimate_mean,
estimate_var)
)
estimate_var))

return replace_mods

@@ -129,6 +128,7 @@ def forward(self, input):


class EMABatchNorm(nn.Module):

@staticmethod
def reset_stats(module):
module.reset_running_stats()
@@ -205,23 +205,19 @@ def __init__(self, layer, prior):
self.layer = layer
self.layer.eval()

self.norm = nn.BatchNorm2d(
self.layer.num_features, affine=False, momentum=1.0
)
self.norm = nn.BatchNorm2d(self.layer.num_features,
affine=False,
momentum=1.0)

self.prior = prior

def forward(self, input):
self.norm(input)

running_mean = (
self.prior * self.layer.running_mean
+ (1 - self.prior) * self.norm.running_mean
)
running_var = (
self.prior * self.layer.running_var
+ (1 - self.prior) * self.norm.running_var
)
running_mean = (self.prior * self.layer.running_mean +
(1 - self.prior) * self.norm.running_mean)
running_var = (self.prior * self.layer.running_var +
(1 - self.prior) * self.norm.running_var)

return F.batch_norm(
input,
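For reference, a hypothetical usage sketch of the Bayesian adaptation helper defined in this file, assuming the `robusta` package is installed and that `adapt_bayesian` modifies the model in place; the prior value and the input batch are illustrative:

```python
import torch
import torchvision.models as models
from robusta.batchnorm import bn

model = models.resnet50(pretrained=True)

# Blend the train-time running statistics with the statistics of the
# current test batch, weighted by `prior` (cf. `adapt_bayesian` above).
bn.adapt_bayesian(model, prior=0.5)

model.eval()
x = torch.randn(32, 3, 224, 224)  # stand-in for a batch of corrupted images
with torch.no_grad():
    logits = model(x)
```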
2 changes: 1 addition & 1 deletion robusta/batchnorm/stages.py
@@ -17,12 +17,12 @@
# This licence notice applies to all originally written code by the
# authors. Code taken from other open-source projects is indicated.
# See NOTICE for a list of all third-party licences used in the project.

""" Helper functions for stages ablations """

import torchvision
from torch import nn


def split_model(model):
if not isinstance(model, torchvision.models.ResNet):
print("Only resnet models defined for this analysis so far")
1 change: 0 additions & 1 deletion robusta/datasets/__init__.py
@@ -18,7 +18,6 @@
# authors. Code taken from other open-source projects is indicated.
# See NOTICE for a list of all third-party licences used in the project.


from robusta.datasets import base
from robusta.datasets import imagenet200
from robusta.datasets import imageneta
