forked from braun-steven/simple-einet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: test_normflows.py
108 lines (86 loc) · 3 KB
/
test_normflows.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import torchvision as tv
import numpy as np
import normflows as nf
from matplotlib import pyplot as plt
from tqdm import tqdm
# Build a class-conditional multi-scale Glow model for CIFAR-10.
# Architecture hyperparameters.
L = 3                      # number of scale levels
K = 16                     # Glow blocks per level
torch.manual_seed(0)
input_shape = (3, 32, 32)  # CIFAR-10 images: C x H x W
n_dims = np.prod(input_shape)
channels = 3
hidden_channels = 256
split_mode = 'channel'
scale = True
num_classes = 10

# Per-level base distributions, merge operations and flow stacks.
q0 = []
merges = []
flows = []
for i in range(L):
    # Each level: K Glow blocks followed by a squeeze.
    level_flows = [
        nf.flows.GlowBlock(channels * 2 ** (L + 1 - i), hidden_channels,
                           split_mode=split_mode, scale=scale)
        for _ in range(K)
    ]
    level_flows.append(nf.flows.Squeeze())
    flows.append(level_flows)
    if i > 0:
        # Intermediate levels split off part of the channels via a merge.
        merges.append(nf.flows.Merge())
        latent_shape = (input_shape[0] * 2 ** (L - i),
                        input_shape[1] // 2 ** (L - i),
                        input_shape[2] // 2 ** (L - i))
    else:
        # Deepest level keeps all remaining channels.
        latent_shape = (input_shape[0] * 2 ** (L + 1),
                        input_shape[1] // 2 ** L,
                        input_shape[2] // 2 ** L)
    # Class-conditional diagonal Gaussian base distribution per level.
    q0.append(nf.distributions.ClassCondDiagGaussian(latent_shape, num_classes))
# Construct flow model with the multiscale architecture
model = nf.MultiscaleFlow(q0, flows, merges)
# Select the compute device: CUDA when available and enabled, else CPU.
enable_cuda = True
use_gpu = enable_cuda and torch.cuda.is_available()
device = torch.device('cuda' if use_gpu else 'cpu')
model = model.to(device)
# Prepare CIFAR-10 data loaders.
batch_size = 128
# ToTensor gives [0, 1]; Scale maps to [0, 255/256] and Jitter adds
# uniform dequantization noise of width 1/256.
transform = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    nf.utils.Scale(255. / 256.),
    nf.utils.Jitter(1 / 256.),
])
train_data = tv.datasets.CIFAR10('datasets/', train=True, download=True,
                                 transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           shuffle=True, drop_last=True)
test_data = tv.datasets.CIFAR10('datasets/', train=False, download=True,
                                transform=transform)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
train_iter = iter(train_loader)
# Train model
max_iter = 20000
# Accumulate losses in a Python list: np.append re-allocates the whole
# array on every call, which is quadratic over 20k iterations.
loss_values = []
optimizer = torch.optim.Adamax(model.parameters(), lr=1e-3, weight_decay=1e-5)
for i in tqdm(range(max_iter)):
    # Cycle through the training loader indefinitely.
    try:
        x, y = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        x, y = next(train_iter)
    optimizer.zero_grad()
    loss = model.forward_kld(x.to(device), y.to(device))
    # Skip the parameter update on numerical blow-up; torch.isfinite is
    # the idiomatic equivalent of ~(isnan | isinf) on a scalar tensor.
    if torch.isfinite(loss):
        loss.backward()
        optimizer.step()
    # Record the loss (finite or not) for plotting.
    loss_values.append(loss.detach().cpu().item())
loss_hist = np.array(loss_values)
# Plot the per-iteration training loss curve.
plt.figure(figsize=(10, 10))
plt.plot(loss_hist, label='loss')
plt.legend()
plt.show()
# Get bits per dim
# bpd = nll / (log(2) * n_dims) + 8; the +8 offsets the 1/256 input
# scaling applied by the transform (8 bits per colour channel).
n = 0          # count of test samples with a non-NaN NLL
bpd_cum = 0.0  # accumulated bits-per-dim over those samples
with torch.no_grad():
    for x, y in test_loader:  # a DataLoader is already iterable; no iter() needed
        nll = model(x.to(device), y.to(device))
        nll_np = nll.cpu().numpy()
        # NaN NLLs are excluded from both the sum and the count.
        bpd_cum += np.nansum(nll_np / np.log(2) / n_dims + 8)
        n += len(x) - np.sum(np.isnan(nll_np))
# Guard against division by zero when every NLL came back NaN.
if n > 0:
    print('Bits per dim: ', bpd_cum / n)
else:
    print('Bits per dim: undefined (no finite NLL values)')