util.py
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader, Dataset
from typing import Collection, Dict, Union

import torch.backends.cudnn as cudnn

import datasets

if torch.cuda.is_available():
    cudnn.benchmark = True

default_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def get_module_device(module: torch.nn.Module, check=True):
    """Return the device the module's parameters live on.

    With check=True, assert that all parameters share a single device.
    """
    if check:
        assert len(set(param.device for param in module.parameters())) == 1
    return next(module.parameters()).device


def either_dataloader_dataset_to_both(
    data: Union[DataLoader, Dataset], *, batch_size=None, eval=False, **kwargs
):
    """Accept either a DataLoader or a Dataset and return (dataloader, dataset).

    A Dataset is wrapped in a new loader: eval mode uses large, unshuffled
    batches, while train mode uses batch_size=128 with shuffling. An explicit
    batch_size overrides either default.
    """
    if isinstance(data, DataLoader):
        dataloader = data
        dataset = data.dataset
    elif isinstance(data, Dataset):
        dataset = data
        dl_kwargs = {}
        if eval:
            dl_kwargs.update(dict(batch_size=1000, shuffle=False, drop_last=False))
        else:
            dl_kwargs.update(dict(batch_size=128, shuffle=True))
        if batch_size is not None:
            dl_kwargs["batch_size"] = batch_size
        dl_kwargs.update(kwargs)
        dataloader = datasets.make_dataloader(data, **dl_kwargs)
    else:
        raise NotImplementedError()
    return dataloader, dataset
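
# Usage sketch (illustration only; `my_dataset` and `my_loader` are hypothetical
# names, not part of this file):
#
#   # Wrap a Dataset in an eval-mode loader (batch_size=1000, no shuffling):
#   dataloader, dataset = either_dataloader_dataset_to_both(my_dataset, eval=True)
#
#   # Pass an existing DataLoader straight through:
#   dataloader, dataset = either_dataloader_dataset_to_both(my_loader)
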

# Classification loss; reduction="mean" (the default) averages over the batch.
clf_loss = torch.nn.CrossEntropyLoss()


def clf_correct(y_pred: torch.Tensor, y: torch.Tensor):
    """Count how many argmax predictions match the targets."""
    y_hat = y_pred.data.max(1)[1]
    correct = (y_hat == y).long().cpu().sum()
    return correct
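
# Example (illustration only): two samples, two classes; only the first
# prediction's argmax matches its target, so the count is 1.
#
#   y_pred = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
#   y = torch.tensor([1, 1])
#   clf_correct(y_pred, y)  # tensor(1)
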

def clf_eval(model: torch.nn.Module, data: Union[DataLoader, Dataset], tf_writer=None):
    """Evaluate accuracy and mean loss; also report the clean-label rate.

    Batches yield (x, oy, y) where oy is the original label and y the
    (possibly poisoned) training label; the clean data rate is the fraction
    of samples with oy == y.
    """
    device = get_module_device(model)
    dataloader, _ = either_dataloader_dataset_to_both(data, eval=True)
    total_correct, total_loss = 0.0, 0.0
    cl_num = 0
    num = 0
    with torch.no_grad():
        model.eval()
        for x, oy, y in dataloader:
            x, y = x.to(device), y.to(device)
            oy = oy.to(device)
            y_pred = model(x)
            loss = clf_loss(y_pred, y)
            correct = clf_correct(y_pred, y)
            total_correct += correct.item()
            # clf_loss is batch-averaged, so weight it by the batch size
            # to make total_loss / n a true per-sample mean.
            total_loss += loss.item() * x.shape[0]
            # Count samples whose current label matches the original one.
            for i in range(y.shape[0]):
                if oy[i] == y[i]:
                    cl_num += 1
                num += 1
            if tf_writer is not None:
                # Log up to 50 images per batch (guard against smaller batches).
                for i in range(min(50, x.shape[0])):
                    tf_writer.add_image("Images_bd", x[i], global_step=i)
    n = len(dataloader.dataset)
    print("clean data rate of this dataset: {:.4f}".format(cl_num / num))
    total_correct /= n
    total_loss /= n
    return total_correct, total_loss
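
# Usage sketch (illustration only; `model` and `test_set` are hypothetical):
#
#   acc, loss = clf_eval(model, test_set)     # accepts a Dataset ...
#   acc, loss = clf_eval(model, test_loader)  # ... or a DataLoader
#
# Note that the data must yield (x, original_label, label) triples.
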

def get_mean_lr(opt: optim.Optimizer):
    """Average learning rate across the optimizer's parameter groups."""
    return np.mean([group["lr"] for group in opt.param_groups])


def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))
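
# Example values (illustration only) for a rampup of 80 steps: the weight
# rises from exp(-5) ~= 0.0067 toward 1.0 and saturates at the end.
#
#   sigmoid_rampup(0, 80)   # ~0.0067
#   sigmoid_rampup(40, 80)  # ~0.2865
#   sigmoid_rampup(80, 80)  # 1.0
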

def compute_all_reps(
    model: torch.nn.Sequential,
    data: Union[DataLoader, Dataset],
    *,
    layers: Collection[int],
    flat=False,
) -> Dict[int, torch.Tensor]:
    """Collect intermediate representations at the given layer indices.

    Returns a dict mapping each requested layer index to a CPU tensor of
    shape (n, *layer_output_shape); with flat=True each tensor is reshaped
    to (n, -1).
    """
    device = get_module_device(model)
    dataloader, dataset = either_dataloader_dataset_to_both(data, eval=True)
    n = len(dataset)
    max_layer = max(layers)
    assert max_layer < len(model)
    reps = {}
    # Dry run with a single sample to discover each layer's output shape
    # and preallocate the result buffers.
    x = dataset[0][0][None, ...].to(device)
    for i, layer in enumerate(model):
        if i > max_layer:
            break
        x = layer(x)
        if i in layers:
            inner_shape = x.shape[1:]
            reps[i] = torch.empty(n, *inner_shape)

    with torch.no_grad():
        model.eval()
        start_index = 0
        for x, _, _ in dataloader:
            x = x.to(device)
            minibatch_size = len(x)
            for i, layer in enumerate(model):
                if i > max_layer:
                    break
                x = layer(x)
                if i in layers:
                    # Fill the preallocated buffer slice for this minibatch.
                    reps[i][start_index : start_index + minibatch_size] = x.cpu()
            start_index += minibatch_size

    if flat:
        for layer in reps:
            layer_reps = reps[layer]
            reps[layer] = layer_reps.reshape(layer_reps.shape[0], -1)
    return reps
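
# Usage sketch (illustration only; `model` and `train_set` are hypothetical):
#
#   # Flattened activations after layers 2 and 5 of a Sequential model:
#   reps = compute_all_reps(model, train_set, layers=[2, 5], flat=True)
#   reps[2].shape  # (len(train_set), prod of layer-2 output dims)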