-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
aa2908d
commit df1e46b
Showing
4 changed files
with
216 additions
and
0 deletions.
There are no files selected for viewing
59 changes: 59 additions & 0 deletions
59
src/callbacks/offline_callbacks/offline_classsubset_accuracy_callback.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from functools import partial | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from torchmetrics.functional.classification import multiclass_accuracy | ||
|
||
from callbacks.base.periodic_callback import PeriodicCallback | ||
from metrics.entropy import multiclass_entropy | ||
|
||
|
||
class OfflineClassSubsetAccuracyCallback(PeriodicCallback):
    """Periodically log top-1 accuracy on a dataset that only covers a subset of classes.

    The dataset registered under ``dataset_key`` has to expose
    ``class_subset_indices`` and ``subset_num_classes``. The model's
    predictions are restricted to the columns listed in
    ``class_subset_indices`` before the accuracy is computed.
    """

    def __init__(self, dataset_key, **kwargs):
        """Store the dataset key; everything else is resolved lazily.

        Args:
            dataset_key: key of the evaluation dataset in the data container.
            **kwargs: forwarded unchanged to ``PeriodicCallback``.
        """
        super().__init__(**kwargs)
        self.dataset_key = dataset_key
        # filled in `_before_training` from the dataset
        self.num_classes = None
        self.class_subset_indices = None
        # filled in `_register_sampler_configs`
        self.__config_id = None

    def _before_training(self, model, **kwargs):
        """Read the class-subset description from the evaluation dataset."""
        subset_dataset = self.data_container.get_dataset(self.dataset_key)
        self.num_classes = subset_dataset.subset_num_classes
        self.class_subset_indices = subset_dataset.class_subset_indices

    def _register_sampler_configs(self, trainer):
        """Register the evaluation dataset (mode ``"x class"``) and remember its config id."""
        config_id = self._register_sampler_config_from_key(key=self.dataset_key, mode="x class")
        self.__config_id = config_id

    def _forward(self, batch, model, trainer):
        """Classify one batch and keep only the columns of the class subset.

        Returns:
            A tuple ``(predictions, classes)`` where ``predictions`` maps each
            prediction name returned by ``model.classify`` to its
            subset-restricted output and ``classes`` is a clone of the batch
            targets.
        """
        (images, labels), _ = batch
        images = images.to(model.device, non_blocking=True)
        with trainer.autocast_context:
            full_predictions = model.classify(images)
        # drop every column that does not belong to a class of the subset
        subset_columns = self.class_subset_indices
        subset_predictions = {}
        for name, prediction in full_predictions.items():
            subset_predictions[name] = prediction[:, subset_columns]
        return subset_predictions, labels.clone()

    # noinspection PyMethodOverriding
    def _periodic_callback(self, model, trainer, batch_size, data_iter, **_):
        """Run the model over the whole dataset and log one accuracy per prediction name."""
        forward_fn = partial(self._forward, model=model, trainer=trainer)
        predictions, target = self.iterate_over_dataset(
            forward_fn=forward_fn,
            config_id=self.__config_id,
            batch_size=batch_size,
            data_iter=data_iter,
        )

        # the metric is computed on the model's device
        target = target.to(model.device, non_blocking=True)
        for prediction_name, y_hat in predictions.items():
            acc_key = f"accuracy1/{self.dataset_key}/{prediction_name}"
            acc = multiclass_accuracy(
                preds=y_hat,
                target=target,
                num_classes=self.num_classes,
                average="micro",
            )
            self.writer.add_scalar(acc_key, acc, logger=self.logger, format_str=".6f")
101 changes: 101 additions & 0 deletions
101
src/callbacks/offline_callbacks/offline_imagenet_c_callback.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from collections import defaultdict | ||
from functools import partial | ||
from itertools import product | ||
|
||
import numpy as np | ||
import torch | ||
from kappadata.common.transforms import ImagenetNoaugTransform | ||
from kappadata.wrappers import XTransformWrapper, SubsetWrapper | ||
from torchmetrics.functional.classification import multiclass_accuracy | ||
|
||
from callbacks.base.periodic_callback import PeriodicCallback | ||
from datasets.imagenet import ImageNet | ||
from utils.kappaconfig.testrun_constants import TEST_RUN_EFFECTIVE_BATCH_SIZE | ||
|
||
|
||
class OfflineImagenetCCallback(PeriodicCallback):
    """Periodically evaluate top-1 accuracy on all ImageNet-C distortion/level splits.

    One dataset is used per combination of ``ImageNet.IMAGENET_C_DISTORTIONS``
    and the levels 1 to 5, stored in the data container under the key
    ``imagenet_c_<distortion>_<level>``. Every periodic invocation logs the
    accuracy of each split and, per prediction name, the mean over all splits.
    """

    def __init__(self, resize_size=256, center_crop_size=224, interpolation="bicubic", **kwargs):
        """Create the evaluation transform and the list of dataset keys.

        Args:
            resize_size: forwarded to ``ImagenetNoaugTransform``.
            center_crop_size: forwarded to ``ImagenetNoaugTransform``.
            interpolation: forwarded to ``ImagenetNoaugTransform``.
            **kwargs: forwarded unchanged to ``PeriodicCallback``.
        """
        super().__init__(**kwargs)
        self.transform = ImagenetNoaugTransform(
            resize_size=resize_size,
            center_crop_size=center_crop_size,
            interpolation=interpolation,
        )
        self.dataset_keys = [
            f"imagenet_c_{distortion}_{level}"
            for distortion, level in product(ImageNet.IMAGENET_C_DISTORTIONS, [1, 2, 3, 4, 5])
        ]
        # dataset key -> sampler config id, filled in `_register_sampler_configs`
        self.__config_ids = {}
        # filled in `_before_training`
        self.n_classes = None

    def _before_training(self, model, **kwargs):
        """Check the model output shape and look up the class count of the train set."""
        assert len(model.output_shape) == 1, "expected a model with a 1-dimensional output_shape"
        self.n_classes = self.data_container.get_dataset("train").getdim_class()

    def register_root_datasets(self, dataset_config_provider=None, is_mindatarun=False):
        """Add every ImageNet-C split that is not yet in the data container.

        Args:
            dataset_config_provider: forwarded to the ``ImageNet`` constructor.
            is_mindatarun: if true, each split is reduced to a fixed random
                subset of ``TEST_RUN_EFFECTIVE_BATCH_SIZE`` samples; otherwise
                each split is asserted to contain 50000 samples.
        """
        prefix = "imagenet_c_"
        for key in self.dataset_keys:
            if key in self.data_container.datasets:
                continue
            # keys have the form "imagenet_c_<distortion>_<level>"; splitting at
            # the last underscore keeps underscores inside the distortion name
            distortion, _, level = key[len(prefix):].rpartition("_")
            dataset = ImageNet(
                version="imagenet_c",
                split=f"{distortion}/{level}",
                dataset_config_provider=dataset_config_provider,
            )
            # use the transform configured in __init__ so that resize_size,
            # center_crop_size and interpolation actually take effect
            dataset = XTransformWrapper(dataset=dataset, transform=self.transform)
            if is_mindatarun:
                # fixed seed -> the same subset is selected on every run
                rng = torch.Generator().manual_seed(0)
                dataset = SubsetWrapper(
                    dataset=dataset,
                    indices=torch.randperm(len(dataset), generator=rng)[:TEST_RUN_EFFECTIVE_BATCH_SIZE].tolist(),
                )
            else:
                assert len(dataset) == 50000, f"expected 50000 samples for '{key}' but got {len(dataset)}"
            self.data_container.datasets[key] = dataset

    def _register_sampler_configs(self, trainer):
        """Register one sampler config (mode ``"x class"``) per ImageNet-C split."""
        for key in self.dataset_keys:
            self.__config_ids[key] = self._register_sampler_config_from_key(key=key, mode="x class")

    @staticmethod
    def _forward(batch, model, trainer):
        """Classify one batch.

        Returns:
            A tuple ``(predictions, classes)`` where ``predictions`` maps each
            prediction name returned by ``model.classify`` to its output moved
            to the CPU and ``classes`` is a clone of the batch targets.
        """
        (x, cls), _ = batch
        x = x.to(model.device, non_blocking=True)
        with trainer.autocast_context:
            predictions = model.classify(x)
        predictions = {name: prediction.cpu() for name, prediction in predictions.items()}
        return predictions, cls.clone()

    # noinspection PyMethodOverriding
    def _periodic_callback(self, model, trainer, batch_size, data_iter, **_):
        """Log the accuracy of every split and the mean accuracy over all splits."""
        # prediction name -> dataset key -> accuracy
        all_accuracies = defaultdict(dict)
        forward_fn = partial(self._forward, model=model, trainer=trainer)
        for dataset_key in self.dataset_keys:
            # extract
            predictions, classes = self.iterate_over_dataset(
                forward_fn=forward_fn,
                config_id=self.__config_ids[dataset_key],
                batch_size=batch_size,
                data_iter=data_iter,
            )

            # push to GPU for accuracy calculation
            predictions = {k: v.to(model.device, non_blocking=True) for k, v in predictions.items()}
            classes = classes.to(model.device, non_blocking=True)

            # log
            for name, prediction in predictions.items():
                acc = multiclass_accuracy(
                    preds=prediction,
                    target=classes,
                    num_classes=self.n_classes,
                    average="micro",
                ).item()
                self.writer.add_scalar(f"accuracy1/{dataset_key}/{name}", acc, logger=self.logger, format_str=".4f")
                all_accuracies[name][dataset_key] = acc

        # summarize over all
        for name, accuracies in all_accuracies.items():
            acc = float(np.mean(list(accuracies.values())))
            self.writer.add_scalar(f"accuracy1/imagenet_c_overall/{name}", acc, logger=self.logger, format_str=".4f")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from .imagenet import ImageNet | ||
|
||
class ImageNetA(ImageNet):
    """``ImageNet`` dataset fixed to ``version="imagenet_a"`` and ``split="val"``.

    Besides the regular 1000-way class shape, the dataset exposes the class
    subset it covers via ``class_subset_indices`` / ``subset_num_classes``.
    """

    # 200 sorted class indices in the range [0, 1000); going by the name they
    # map ImageNet-1K class indices to the ImageNet-A classes
    IN1K_TO_INA = [
        6, 11, 13, 15, 17, 22, 23, 27, 30, 37, 39, 42, 47, 50, 57, 70, 71, 76, 79, 89, 90, 94, 96, 97, 99, 105, 107,
        108, 110, 113, 124, 125, 130, 132, 143, 144, 150, 151, 207, 234, 235, 254, 277, 283, 287, 291, 295, 298, 301,
        306, 307, 308, 309, 310, 311, 313, 314, 315, 317, 319, 323, 324, 326, 327, 330, 334, 335, 336, 347, 361, 363,
        372, 378, 386, 397, 400, 401, 402, 404, 407, 411, 416, 417, 420, 425, 428, 430, 437, 438, 445, 456, 457, 461,
        462, 470, 472, 483, 486, 488, 492, 496, 514, 516, 528, 530, 539, 542, 543, 549, 552, 557, 561, 562, 569, 572,
        573, 575, 579, 589, 606, 607, 609, 614, 626, 627, 640, 641, 642, 643, 658, 668, 677, 682, 684, 687, 701, 704,
        719, 736, 746, 749, 752, 758, 763, 765, 768, 773, 774, 776, 779, 780, 786, 792, 797, 802, 803, 804, 813, 815,
        820, 823, 831, 833, 835, 839, 845, 847, 850, 859, 862, 870, 879, 880, 888, 890, 897, 900, 907, 913, 924, 932,
        933, 934, 937, 943, 945, 947, 951, 954, 956, 957, 959, 971, 972, 980, 981, 984, 986, 987, 988,
    ]

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``ImageNet`` with version and split fixed."""
        super().__init__(split="val", version="imagenet_a", **kwargs)

    def getshape_class(self):
        """Return the class shape as a 1-tuple (1000 classes, not the subset size)."""
        return (1000,)

    @property
    def class_subset_indices(self):
        """The list of class indices covered by this dataset (``IN1K_TO_INA``)."""
        return self.IN1K_TO_INA

    @property
    def subset_num_classes(self):
        """Number of classes in the subset."""
        return 200
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from .imagenet import ImageNet | ||
|
||
class ImageNetR(ImageNet):
    """``ImageNet`` dataset fixed to ``version="imagenet_r"`` and ``split="val"``.

    Besides the regular 1000-way class shape, the dataset exposes the class
    subset it covers via ``class_subset_indices`` / ``subset_num_classes``.
    """

    # 200 sorted class indices in the range [0, 1000); going by the name they
    # map ImageNet-1K class indices to the ImageNet-R classes
    IN1K_TO_INR = [
        1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, 90, 94, 96, 97, 99, 100, 105, 107,
        113, 122, 125, 130, 132, 144, 145, 147, 148, 150, 151, 155, 160, 161, 162, 163, 171, 172, 178, 187, 195, 199,
        203, 207, 208, 219, 231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260, 263, 265, 267, 269, 276, 277,
        281, 288, 289, 291, 292, 293, 296, 299, 301, 308, 309, 310, 311, 314, 315, 319, 323, 327, 330, 334, 335, 337,
        338, 340, 341, 344, 347, 353, 355, 361, 362, 365, 366, 367, 368, 372, 388, 390, 393, 397, 401, 407, 413, 414,
        425, 428, 430, 435, 437, 441, 447, 448, 457, 462, 463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558,
        570, 579, 583, 587, 593, 594, 596, 609, 613, 617, 621, 629, 637, 657, 658, 701, 717, 724, 763, 768, 774, 776,
        779, 780, 787, 805, 812, 815, 820, 824, 833, 847, 852, 866, 875, 883, 889, 895, 907, 928, 931, 932, 933, 934,
        936, 937, 943, 945, 947, 948, 949, 951, 953, 954, 957, 963, 965, 967, 980, 981, 983, 988,
    ]

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``ImageNet`` with version and split fixed."""
        super().__init__(split="val", version="imagenet_r", **kwargs)

    def getshape_class(self):
        """Return the class shape as a 1-tuple (1000 classes, not the subset size)."""
        return (1000,)

    @property
    def class_subset_indices(self):
        """The list of class indices covered by this dataset (``IN1K_TO_INR``)."""
        return self.IN1K_TO_INR

    @property
    def subset_num_classes(self):
        """Number of classes in the subset."""
        return 200