robustness datasets
BenediktAlkin committed May 23, 2024
1 parent aa2908d commit df1e46b
Showing 4 changed files with 216 additions and 0 deletions.
@@ -0,0 +1,59 @@
from functools import partial

import torch
import torch.nn.functional as F
from torchmetrics.functional.classification import multiclass_accuracy

from callbacks.base.periodic_callback import PeriodicCallback
from metrics.entropy import multiclass_entropy


class OfflineClassSubsetAccuracyCallback(PeriodicCallback):
    def __init__(self, dataset_key, **kwargs):
        super().__init__(**kwargs)
        self.dataset_key = dataset_key
        self.__config_id = None
        self.class_subset_indices = None
        self.num_classes = None

    def _before_training(self, model, **kwargs):
        dataset = self.data_container.get_dataset(self.dataset_key)
        self.class_subset_indices = dataset.class_subset_indices
        self.num_classes = dataset.subset_num_classes

    def _register_sampler_configs(self, trainer):
        self.__config_id = self._register_sampler_config_from_key(key=self.dataset_key, mode="x class")

    def _forward(self, batch, model, trainer):
        (x, cls), _ = batch
        x = x.to(model.device, non_blocking=True)
        with trainer.autocast_context:
            predictions = model.classify(x)
        # only use logits for the actual available classes
        predictions = {
            name: prediction[:, self.class_subset_indices]
            for name, prediction in predictions.items()
        }
        return predictions, cls.clone()

    # noinspection PyMethodOverriding
    def _periodic_callback(self, model, trainer, batch_size, data_iter, **_):
        predictions, target = self.iterate_over_dataset(
            forward_fn=partial(self._forward, model=model, trainer=trainer),
            config_id=self.__config_id,
            batch_size=batch_size,
            data_iter=data_iter,
        )

        # log
        target = target.to(model.device, non_blocking=True)
        for prediction_name, y_hat in predictions.items():
            # accuracy
            acc = multiclass_accuracy(
                preds=y_hat,
                target=target,
                num_classes=self.num_classes,
                average="micro",
            )
            acc_key = f"accuracy1/{self.dataset_key}/{prediction_name}"
            self.writer.add_scalar(acc_key, acc, logger=self.logger, format_str=".6f")
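ImageNet-A and ImageNet-R cover only 200 of the 1000 ImageNet-1K classes, so the callback above slices the model's 1000-way logits down to the dataset's class_subset_indices before scoring. A minimal standalone sketch of that idea (not the callback itself; the tensors, the placeholder index list, and the assumption that targets are already remapped to 0..199 are illustrative):

import torch
from torchmetrics.functional.classification import multiclass_accuracy

# placeholder inputs: 8 samples, a 1000-way classifier, a subset of 200 valid classes
logits = torch.randn(8, 1000)                # stands in for model.classify(x) output
class_subset_indices = list(range(200))      # stands in for e.g. ImageNetA.IN1K_TO_INA
target = torch.randint(0, 200, (8,))         # labels assumed already remapped to 0..199

subset_logits = logits[:, class_subset_indices]  # keep only the classes that can occur
acc = multiclass_accuracy(preds=subset_logits, target=target, num_classes=200, average="micro")
print(f"accuracy1: {acc.item():.6f}")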
101 changes: 101 additions & 0 deletions src/callbacks/offline_callbacks/offline_imagenet_c_callback.py
@@ -0,0 +1,101 @@
from collections import defaultdict
from functools import partial
from itertools import product

import numpy as np
import torch
from kappadata.common.transforms import ImagenetNoaugTransform
from kappadata.wrappers import XTransformWrapper, SubsetWrapper
from torchmetrics.functional.classification import multiclass_accuracy

from callbacks.base.periodic_callback import PeriodicCallback
from datasets.imagenet import ImageNet
from utils.kappaconfig.testrun_constants import TEST_RUN_EFFECTIVE_BATCH_SIZE


class OfflineImagenetCCallback(PeriodicCallback):
    def __init__(self, resize_size=256, center_crop_size=224, interpolation="bicubic", **kwargs):
        super().__init__(**kwargs)
        self.transform = ImagenetNoaugTransform(
            resize_size=resize_size,
            center_crop_size=center_crop_size,
            interpolation=interpolation,
        )
        self.dataset_keys = [
            f"imagenet_c_{distortion}_{level}"
            for distortion, level in product(ImageNet.IMAGENET_C_DISTORTIONS, [1, 2, 3, 4, 5])
        ]
        self.__config_ids = {}
        self.n_classes = None

    def _before_training(self, model, **kwargs):
        assert len(model.output_shape) == 1
        self.n_classes = self.data_container.get_dataset("train").getdim_class()

    def register_root_datasets(self, dataset_config_provider=None, is_mindatarun=False):
        for key in self.dataset_keys:
            if key in self.data_container.datasets:
                continue
            temp = key.replace("imagenet_c_", "")
            distortion = temp[:-2]
            level = temp[-1]
            dataset = ImageNet(
                version="imagenet_c",
                split=f"{distortion}/{level}",
                dataset_config_provider=dataset_config_provider,
            )
            dataset = XTransformWrapper(dataset=dataset, transform=ImagenetNoaugTransform())
            if is_mindatarun:
                rng = torch.Generator().manual_seed(0)
                dataset = SubsetWrapper(
                    dataset=dataset,
                    indices=torch.randperm(len(dataset), generator=rng)[:TEST_RUN_EFFECTIVE_BATCH_SIZE].tolist(),
                )
            else:
                assert len(dataset) == 50000
            self.data_container.datasets[key] = dataset

    def _register_sampler_configs(self, trainer):
        for key in self.dataset_keys:
            self.__config_ids[key] = self._register_sampler_config_from_key(key=key, mode="x class")

    @staticmethod
    def _forward(batch, model, trainer):
        (x, cls), _ = batch
        x = x.to(model.device, non_blocking=True)
        with trainer.autocast_context:
            predictions = model.classify(x)
        predictions = {name: prediction.cpu() for name, prediction in predictions.items()}
        return predictions, cls.clone()

    # noinspection PyMethodOverriding
    def _periodic_callback(self, model, trainer, batch_size, data_iter, **_):
        all_accuracies = defaultdict(dict)
        for dataset_key in self.dataset_keys:
            # extract
            predictions, classes = self.iterate_over_dataset(
                forward_fn=partial(self._forward, model=model, trainer=trainer),
                config_id=self.__config_ids[dataset_key],
                batch_size=batch_size,
                data_iter=data_iter,
            )

            # push to GPU for accuracy calculation
            predictions = {k: v.to(model.device, non_blocking=True) for k, v in predictions.items()}
            classes = classes.to(model.device, non_blocking=True)

            # log
            for name, prediction in predictions.items():
                acc = multiclass_accuracy(
                    preds=prediction,
                    target=classes,
                    num_classes=self.n_classes,
                    average="micro",
                ).item()
                self.writer.add_scalar(f"accuracy1/{dataset_key}/{name}", acc, logger=self.logger, format_str=".4f")
                all_accuracies[name][dataset_key] = acc

        # summarize over all
        for name in all_accuracies.keys():
            acc = float(np.mean(list(all_accuracies[name].values())))
            self.writer.add_scalar(f"accuracy1/imagenet_c_overall/{name}", acc, logger=self.logger, format_str=".4f")
28 changes: 28 additions & 0 deletions src/datasets/imagenet_a.py
@@ -0,0 +1,28 @@
from .imagenet import ImageNet

class ImageNetA(ImageNet):
    IN1K_TO_INA = [
6, 11, 13, 15, 17, 22, 23, 27, 30, 37, 39, 42, 47, 50, 57, 70, 71, 76, 79, 89, 90, 94, 96, 97, 99, 105, 107,
108, 110, 113, 124, 125, 130, 132, 143, 144, 150, 151, 207, 234, 235, 254, 277, 283, 287, 291, 295, 298, 301,
306, 307, 308, 309, 310, 311, 313, 314, 315, 317, 319, 323, 324, 326, 327, 330, 334, 335, 336, 347, 361, 363,
372, 378, 386, 397, 400, 401, 402, 404, 407, 411, 416, 417, 420, 425, 428, 430, 437, 438, 445, 456, 457, 461,
462, 470, 472, 483, 486, 488, 492, 496, 514, 516, 528, 530, 539, 542, 543, 549, 552, 557, 561, 562, 569, 572,
573, 575, 579, 589, 606, 607, 609, 614, 626, 627, 640, 641, 642, 643, 658, 668, 677, 682, 684, 687, 701, 704,
719, 736, 746, 749, 752, 758, 763, 765, 768, 773, 774, 776, 779, 780, 786, 792, 797, 802, 803, 804, 813, 815,
820, 823, 831, 833, 835, 839, 845, 847, 850, 859, 862, 870, 879, 880, 888, 890, 897, 900, 907, 913, 924, 932,
933, 934, 937, 943, 945, 947, 951, 954, 956, 957, 959, 971, 972, 980, 981, 984, 986, 987, 988,
    ]

    def __init__(self, **kwargs):
        super().__init__(version="imagenet_a", split="val", **kwargs)

    def getshape_class(self):
        return 1000,

    @property
    def subset_num_classes(self):
        return 200

    @property
    def class_subset_indices(self):
        return self.IN1K_TO_INA
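IN1K_TO_INA maps each of the 200 ImageNet-A classes to its ImageNet-1K class index; OfflineClassSubsetAccuracyCallback reads it through the class_subset_indices property and uses it to slice the 1000-way logits. A tiny illustration of that slicing (placeholder tensor, truncated index list):

import torch

IN1K_TO_INA = [6, 11, 13, 15, 17]          # first few entries of the full 200-entry table above
logits_1k = torch.randn(4, 1000)           # 1000-way predictions from an ImageNet-1K classifier
logits_sub = logits_1k[:, IN1K_TO_INA]     # column j now corresponds to ImageNet-A label j
assert logits_sub.shape == (4, len(IN1K_TO_INA))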
28 changes: 28 additions & 0 deletions src/datasets/imagenet_r.py
@@ -0,0 +1,28 @@
from .imagenet import ImageNet

class ImageNetR(ImageNet):
    IN1K_TO_INR = [
1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, 90, 94, 96, 97, 99, 100, 105, 107,
113, 122, 125, 130, 132, 144, 145, 147, 148, 150, 151, 155, 160, 161, 162, 163, 171, 172, 178, 187, 195, 199,
203, 207, 208, 219, 231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260, 263, 265, 267, 269, 276, 277,
281, 288, 289, 291, 292, 293, 296, 299, 301, 308, 309, 310, 311, 314, 315, 319, 323, 327, 330, 334, 335, 337,
338, 340, 341, 344, 347, 353, 355, 361, 362, 365, 366, 367, 368, 372, 388, 390, 393, 397, 401, 407, 413, 414,
425, 428, 430, 435, 437, 441, 447, 448, 457, 462, 463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558,
570, 579, 583, 587, 593, 594, 596, 609, 613, 617, 621, 629, 637, 657, 658, 701, 717, 724, 763, 768, 774, 776,
779, 780, 787, 805, 812, 815, 820, 824, 833, 847, 852, 866, 875, 883, 889, 895, 907, 928, 931, 932, 933, 934,
936, 937, 943, 945, 947, 948, 949, 951, 953, 954, 957, 963, 965, 967, 980, 981, 983, 988,
    ]

    def __init__(self, **kwargs):
        super().__init__(version="imagenet_r", split="val", **kwargs)

    def getshape_class(self):
        return 1000,

    @property
    def subset_num_classes(self):
        return 200

    @property
    def class_subset_indices(self):
        return self.IN1K_TO_INR
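ImageNetR follows the same pattern with its own 200-class mapping. A quick, hedged sanity check one could run over either table (it assumes the repo's src directory is on the import path; not part of the commit):

from datasets.imagenet_a import ImageNetA
from datasets.imagenet_r import ImageNetR

# the subset tables are expected to hold 200 unique, sorted ImageNet-1K indices in 0..999
for table in (ImageNetA.IN1K_TO_INA, ImageNetR.IN1K_TO_INR):
    assert len(table) == 200
    assert len(set(table)) == 200                   # no duplicates
    assert table == sorted(table)                   # ascending order
    assert all(0 <= idx < 1000 for idx in table)    # valid ImageNet-1K class indices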
