diff --git a/src/callbacks/offline_callbacks/offline_classsubset_accuracy_callback.py b/src/callbacks/offline_callbacks/offline_classsubset_accuracy_callback.py
new file mode 100644
index 0000000..15317d7
--- /dev/null
+++ b/src/callbacks/offline_callbacks/offline_classsubset_accuracy_callback.py
@@ -0,0 +1,59 @@
+from functools import partial
+
+import torch
+import torch.nn.functional as F
+from torchmetrics.functional.classification import multiclass_accuracy
+
+from callbacks.base.periodic_callback import PeriodicCallback
+from metrics.entropy import multiclass_entropy
+
+
+class OfflineClassSubsetAccuracyCallback(PeriodicCallback):
+    def __init__(self, dataset_key, **kwargs):
+        super().__init__(**kwargs)
+        self.dataset_key = dataset_key
+        self.__config_id = None
+        self.class_subset_indices = None
+        self.num_classes = None
+
+    def _before_training(self, model, **kwargs):
+        dataset = self.data_container.get_dataset(self.dataset_key)
+        self.class_subset_indices = dataset.class_subset_indices
+        self.num_classes = dataset.subset_num_classes
+
+    def _register_sampler_configs(self, trainer):
+        self.__config_id = self._register_sampler_config_from_key(key=self.dataset_key, mode="x class")
+
+    def _forward(self, batch, model, trainer):
+        (x, cls), _ = batch
+        x = x.to(model.device, non_blocking=True)
+        with trainer.autocast_context:
+            predictions = model.classify(x)
+        # only use logits for the actual available classes
+        predictions = {
+            name: prediction[:, self.class_subset_indices]
+            for name, prediction in predictions.items()
+        }
+        return predictions, cls.clone()
+
+    # noinspection PyMethodOverriding
+    def _periodic_callback(self, model, trainer, batch_size, data_iter, **_):
+        predictions, target = self.iterate_over_dataset(
+            forward_fn=partial(self._forward, model=model, trainer=trainer),
+            config_id=self.__config_id,
+            batch_size=batch_size,
+            data_iter=data_iter,
+        )
+
+        # log
+        target = target.to(model.device, non_blocking=True)
+        for prediction_name, y_hat in predictions.items():
+            # accuracy
+            acc = multiclass_accuracy(
+                preds=y_hat,
+                target=target,
+                num_classes=self.num_classes,
+                average="micro",
+            )
+            acc_key = f"accuracy1/{self.dataset_key}/{prediction_name}"
+            self.writer.add_scalar(acc_key, acc, logger=self.logger, format_str=".6f")
\ No newline at end of file
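The callback above boils down to slicing the full ImageNet-1K logits down to the subset classes before computing accuracy. A minimal, self-contained sketch (not part of the diff; the tiny index list, batch size, and random tensors are illustrative only):

```python
import torch
from torchmetrics.functional.classification import multiclass_accuracy

# stand-in for e.g. ImageNetA.IN1K_TO_INA (200 entries in the real datasets)
class_subset_indices = [6, 11, 13, 15]
num_subset_classes = len(class_subset_indices)

logits = torch.randn(8, 1000)                     # full 1000-class logits, as from model.classify(x)
target = torch.randint(num_subset_classes, (8,))  # labels already remapped to 0..num_subset_classes-1

# keep only the logits of the classes that exist in the subset dataset
subset_logits = logits[:, class_subset_indices]
acc = multiclass_accuracy(
    preds=subset_logits,
    target=target,
    num_classes=num_subset_classes,
    average="micro",
)
print(f"accuracy1: {acc.item():.6f}")
```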
diff --git a/src/callbacks/offline_callbacks/offline_imagenet_c_callback.py b/src/callbacks/offline_callbacks/offline_imagenet_c_callback.py
new file mode 100644
index 0000000..4bfe385
--- /dev/null
+++ b/src/callbacks/offline_callbacks/offline_imagenet_c_callback.py
@@ -0,0 +1,101 @@
+from collections import defaultdict
+from functools import partial
+from itertools import product
+
+import numpy as np
+import torch
+from kappadata.common.transforms import ImagenetNoaugTransform
+from kappadata.wrappers import XTransformWrapper, SubsetWrapper
+from torchmetrics.functional.classification import multiclass_accuracy
+
+from callbacks.base.periodic_callback import PeriodicCallback
+from datasets.imagenet import ImageNet
+from utils.kappaconfig.testrun_constants import TEST_RUN_EFFECTIVE_BATCH_SIZE
+
+
+class OfflineImagenetCCallback(PeriodicCallback):
+    def __init__(self, resize_size=256, center_crop_size=224, interpolation="bicubic", **kwargs):
+        super().__init__(**kwargs)
+        self.transform = ImagenetNoaugTransform(
+            resize_size=resize_size,
+            center_crop_size=center_crop_size,
+            interpolation=interpolation,
+        )
+        self.dataset_keys = [
+            f"imagenet_c_{distortion}_{level}"
+            for distortion, level in product(ImageNet.IMAGENET_C_DISTORTIONS, [1, 2, 3, 4, 5])
+        ]
+        self.__config_ids = {}
+        self.n_classes = None
+
+    def _before_training(self, model, **kwargs):
+        assert len(model.output_shape) == 1
+        self.n_classes = self.data_container.get_dataset("train").getdim_class()
+
+    def register_root_datasets(self, dataset_config_provider=None, is_mindatarun=False):
+        for key in self.dataset_keys:
+            if key in self.data_container.datasets:
+                continue
+            temp = key.replace("imagenet_c_", "")
+            distortion = temp[:-2]
+            level = temp[-1]
+            dataset = ImageNet(
+                version="imagenet_c",
+                split=f"{distortion}/{level}",
+                dataset_config_provider=dataset_config_provider,
+            )
+            dataset = XTransformWrapper(dataset=dataset, transform=self.transform)
+            if is_mindatarun:
+                rng = torch.Generator().manual_seed(0)
+                dataset = SubsetWrapper(
+                    dataset=dataset,
+                    indices=torch.randperm(len(dataset), generator=rng)[:TEST_RUN_EFFECTIVE_BATCH_SIZE].tolist(),
+                )
+            else:
+                assert len(dataset) == 50000
+            self.data_container.datasets[key] = dataset
+
+    def _register_sampler_configs(self, trainer):
+        for key in self.dataset_keys:
+            self.__config_ids[key] = self._register_sampler_config_from_key(key=key, mode="x class")
+
+    @staticmethod
+    def _forward(batch, model, trainer):
+        (x, cls), _ = batch
+        x = x.to(model.device, non_blocking=True)
+        with trainer.autocast_context:
+            predictions = model.classify(x)
+        predictions = {name: prediction.cpu() for name, prediction in predictions.items()}
+        return predictions, cls.clone()
+
+    # noinspection PyMethodOverriding
+    def _periodic_callback(self, model, trainer, batch_size, data_iter, **_):
+        all_accuracies = defaultdict(dict)
+        for dataset_key in self.dataset_keys:
+            # extract
+            predictions, classes = self.iterate_over_dataset(
+                forward_fn=partial(self._forward, model=model, trainer=trainer),
+                config_id=self.__config_ids[dataset_key],
+                batch_size=batch_size,
+                data_iter=data_iter,
+            )
+
+            # push to GPU for accuracy calculation
+            predictions = {k: v.to(model.device, non_blocking=True) for k, v in predictions.items()}
+            classes = classes.to(model.device, non_blocking=True)
+
+            # log
+            for name, prediction in predictions.items():
+                acc = multiclass_accuracy(
+                    preds=prediction,
+                    target=classes,
+                    num_classes=self.n_classes,
+                    average="micro",
+                ).item()
+                self.writer.add_scalar(f"accuracy1/{dataset_key}/{name}", acc, logger=self.logger, format_str=".4f")
+                all_accuracies[name][dataset_key] = acc
+
+        # summarize over all
+        for name in all_accuracies.keys():
+            acc = float(np.mean(list(all_accuracies[name].values())))
+            self.writer.add_scalar(f"accuracy1/imagenet_c_overall/{name}", acc, logger=self.logger, format_str=".4f")
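Not part of the diff: a sketch of how the callback above derives its dataset keys and the final `imagenet_c_overall` number, with an illustrative stand-in for `ImageNet.IMAGENET_C_DISTORTIONS` and made-up accuracy values.

```python
from itertools import product

import numpy as np

# illustrative stand-in for ImageNet.IMAGENET_C_DISTORTIONS
distortions = ["gaussian_noise", "fog", "jpeg_compression"]
dataset_keys = [f"imagenet_c_{d}_{level}" for d, level in product(distortions, [1, 2, 3, 4, 5])]

# pretend each (distortion, severity) split was already evaluated
accuracies = {key: 0.5 for key in dataset_keys}

# "imagenet_c_overall" is the unweighted mean over all distortion/severity combinations
overall = float(np.mean(list(accuracies.values())))
print(f"accuracy1/imagenet_c_overall: {overall:.4f}")
```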
diff --git a/src/datasets/imagenet_a.py b/src/datasets/imagenet_a.py
new file mode 100644
index 0000000..ba6a14f
--- /dev/null
+++ b/src/datasets/imagenet_a.py
@@ -0,0 +1,28 @@
+from .imagenet import ImageNet
+
+class ImageNetA(ImageNet):
+    IN1K_TO_INA = [
+        6, 11, 13, 15, 17, 22, 23, 27, 30, 37, 39, 42, 47, 50, 57, 70, 71, 76, 79, 89, 90, 94, 96, 97, 99, 105, 107,
+        108, 110, 113, 124, 125, 130, 132, 143, 144, 150, 151, 207, 234, 235, 254, 277, 283, 287, 291, 295, 298, 301,
+        306, 307, 308, 309, 310, 311, 313, 314, 315, 317, 319, 323, 324, 326, 327, 330, 334, 335, 336, 347, 361, 363,
+        372, 378, 386, 397, 400, 401, 402, 404, 407, 411, 416, 417, 420, 425, 428, 430, 437, 438, 445, 456, 457, 461,
+        462, 470, 472, 483, 486, 488, 492, 496, 514, 516, 528, 530, 539, 542, 543, 549, 552, 557, 561, 562, 569, 572,
+        573, 575, 579, 589, 606, 607, 609, 614, 626, 627, 640, 641, 642, 643, 658, 668, 677, 682, 684, 687, 701, 704,
+        719, 736, 746, 749, 752, 758, 763, 765, 768, 773, 774, 776, 779, 780, 786, 792, 797, 802, 803, 804, 813, 815,
+        820, 823, 831, 833, 835, 839, 845, 847, 850, 859, 862, 870, 879, 880, 888, 890, 897, 900, 907, 913, 924, 932,
+        933, 934, 937, 943, 945, 947, 951, 954, 956, 957, 959, 971, 972, 980, 981, 984, 986, 987, 988,
+    ]
+
+    def __init__(self, **kwargs):
+        super().__init__(version="imagenet_a", split="val", **kwargs)
+
+    def getshape_class(self):
+        return 1000,
+
+    @property
+    def subset_num_classes(self):
+        return 200
+
+    @property
+    def class_subset_indices(self):
+        return self.IN1K_TO_INA
diff --git a/src/datasets/imagenet_r.py b/src/datasets/imagenet_r.py
new file mode 100644
index 0000000..17b0869
--- /dev/null
+++ b/src/datasets/imagenet_r.py
@@ -0,0 +1,28 @@
+from .imagenet import ImageNet
+
+class ImageNetR(ImageNet):
+    IN1K_TO_INR = [
+        1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, 90, 94, 96, 97, 99, 100, 105, 107,
+        113, 122, 125, 130, 132, 144, 145, 147, 148, 150, 151, 155, 160, 161, 162, 163, 171, 172, 178, 187, 195, 199,
+        203, 207, 208, 219, 231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260, 263, 265, 267, 269, 276, 277,
+        281, 288, 289, 291, 292, 293, 296, 299, 301, 308, 309, 310, 311, 314, 315, 319, 323, 327, 330, 334, 335, 337,
+        338, 340, 341, 344, 347, 353, 355, 361, 362, 365, 366, 367, 368, 372, 388, 390, 393, 397, 401, 407, 413, 414,
+        425, 428, 430, 435, 437, 441, 447, 448, 457, 462, 463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558,
+        570, 579, 583, 587, 593, 594, 596, 609, 613, 617, 621, 629, 637, 657, 658, 701, 717, 724, 763, 768, 774, 776,
+        779, 780, 787, 805, 812, 815, 820, 824, 833, 847, 852, 866, 875, 883, 889, 895, 907, 928, 931, 932, 933, 934,
+        936, 937, 943, 945, 947, 948, 949, 951, 953, 954, 957, 963, 965, 967, 980, 981, 983, 988,
+    ]
+
+    def __init__(self, **kwargs):
+        super().__init__(version="imagenet_r", split="val", **kwargs)
+
+    def getshape_class(self):
+        return 1000,
+
+    @property
+    def subset_num_classes(self):
+        return 200
+
+    @property
+    def class_subset_indices(self):
+        return self.IN1K_TO_INR
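Not part of the diff: a quick sanity check one might run on the two index lists, assuming they should contain exactly `subset_num_classes` distinct, sorted ImageNet-1K indices (imports assume `src/` is on the Python path).

```python
from datasets.imagenet_a import ImageNetA
from datasets.imagenet_r import ImageNetR

for name, indices in (("ImageNet-A", ImageNetA.IN1K_TO_INA), ("ImageNet-R", ImageNetR.IN1K_TO_INR)):
    assert len(indices) == 200, f"{name}: expected 200 classes, got {len(indices)}"
    assert len(set(indices)) == len(indices), f"{name}: duplicate indices"
    assert indices == sorted(indices), f"{name}: indices not sorted"
    assert 0 <= indices[0] and indices[-1] < 1000, f"{name}: index out of range"
    print(f"{name}: ok ({len(indices)} classes)")
```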