-
Notifications
You must be signed in to change notification settings - Fork 0
/
baffon.py
51 lines (37 loc) · 1.05 KB
/
baffon.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import random
from pathlib import Path

from matplotlib import pyplot as plt

import dtorch as dt
import dtorchvision.models
from dtorchvision.datasets import MNISTDataset
from dtorchvision.models import AutoEncoder
# Train or inspect an MNIST autoencoder.
#   'n' / 'no' at the prompt -> evaluate: print the MSE reconstruction loss and
#                               show one random digit next to its reconstruction.
#   anything else            -> resume training from MODEL_PATH (when it exists)
#                               and save the weights back to MODEL_PATH.

MODEL_PATH = 'model.jt'
EPOCHS = 1000
LEARNING_RATE = 0.001

# Previously used construction: AutoEncoder(784, [128, 32])
autoencoder = dtorchvision.models.MNISTAutoEncoder_128_32()

# Normalize the answer so "N", " n" or "no" do not accidentally start a long
# training run.
answer = input("Train? (y/n) ").strip().lower()

# Both modes load the data the same way, so training no longer depends on an
# earlier evaluation run having fetched it.
# NOTE(review): assumes download=True only fetches when the data is missing — confirm.
dataset = MNISTDataset(download=True)
# NOTE(review): dataset.data presumably unpacks as ((inputs, labels), other_split);
# only the first split's inputs are used, for both training and evaluation.
(x, _), _ = dataset.data
loss = dt.loss.MSELoss()

if answer in ('n', 'no'):
    # Evaluation only.
    # NOTE(review): MODEL_PATH is not loaded on this path, so the weights come
    # from MNISTAutoEncoder_128_32() itself — confirm this is intended.
    reconstruction = autoencoder(x)
    print(loss(reconstruction, x))

    img = random.randrange(len(reconstruction))
    # Show the reconstruction and the original side by side for comparison.
    fig, (ax_rec, ax_orig) = plt.subplots(1, 2)
    ax_rec.imshow(reconstruction[img].reshape(28, 28), cmap='gray')
    ax_rec.set_title("Reconstruction")
    ax_orig.imshow(x[img].reshape(28, 28), cmap='gray')
    ax_orig.set_title("Original")
    plt.show()
else:
    # Resume from the last checkpoint when there is one; a first run has
    # nothing to load yet.
    if Path(MODEL_PATH).exists():
        autoencoder.load(MODEL_PATH)
    autoencoder.train()

    optimizer = dt.optim.Adam(autoencoder.parameters(), lr=LEARNING_RATE)
    print(autoencoder)

    try:
        # Full-batch training: every epoch feeds all of x through the network.
        for epoch in range(EPOCHS):
            y_pred = autoencoder(x)
            optimizer.zero_grad()
            res = loss(y_pred, x)
            print(f"Epoch {epoch + 1}/{EPOCHS} loss", res[0])
            res.backward()
            optimizer.step()
    except KeyboardInterrupt:
        # Keep the progress made so far instead of discarding every epoch.
        print("Training interrupted; saving current weights.")

    autoencoder.save(MODEL_PATH)