-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
97 lines (82 loc) · 2.63 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from torchvision import datasets, transforms
from model.pixel_lstm.pixel_lstm import DiagonalPixelLSTM
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
# Number of images per mini-batch for both the training and test loaders.
BATCH_SIZE = 600
# Number of full passes over the training set.
EPOCHS = 10
# Prefer the GPU, but fall back to the CPU so the script still runs on hosts
# without CUDA instead of failing at the first ``.to(DEVICE)`` call.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def _mnist_split(train):
    """Return the MNIST split selected by *train*, stored under ``./MNIST``.

    Each image is converted to a tensor and then normalised with mean 0.1307
    and standard deviation 0.3081. ``download=True`` is always passed.
    """
    preprocessing = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    return datasets.MNIST(
        "./MNIST", train=train, download=True, transform=preprocessing
    )


train_set = _mnist_split(True)
test_set = _mnist_split(False)

# Only the training loader shuffles and uses worker processes.
train_loader = DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4
)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        """Return *x* reshaped to ``(x.shape[0], -1)``.

        ``reshape`` is used rather than ``view`` so that non-contiguous
        inputs (for example the result of a transpose) are accepted too;
        inputs that ``view`` could handle are still returned as a view.
        """
        return x.reshape(x.shape[0], -1)
# Classifier stack: 3x3 convolution (1 -> 10 channels), ReLU, a diagonal
# pixel LSTM (constructed with 10 and 50), a 28x28 max-pool, flattening, and
# a final linear layer mapping 50 features to 10 outputs.
# NOTE(review): the 28x28 pool presumably reduces the LSTM output to 1x1 so
# that Flatten yields 50 features -- confirm against DiagonalPixelLSTM.
layers = [
    nn.Conv2d(1, 10, kernel_size=[3, 3], padding=1),
    nn.ReLU(),
    DiagonalPixelLSTM(10, 50),
    nn.MaxPool2d(kernel_size=[28, 28]),
    Flatten(),
    nn.Linear(in_features=50, out_features=10),
]
model = nn.Sequential(*layers)
model.to(DEVICE)

optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.5)
def train_epoch(model, dataloader, optimizer, device=None):
    """Train *model* for one pass over *dataloader*.

    Args:
        model: Module called on each input batch; it is put into training
            mode before the first batch.
        dataloader: Iterable yielding ``(inputs, targets)`` batches. It does
            not need to support ``len()``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to. Defaults to the
            module-level ``DEVICE``.

    Returns:
        A ``(mean_loss, accuracy)`` tuple, both averaged over every example
        seen during the epoch.

    Raises:
        ZeroDivisionError: If *dataloader* yields no examples.
    """
    if device is None:
        device = DEVICE
    model.train()
    criterion = nn.CrossEntropyLoss()
    total_correct = 0
    total_loss = 0.0
    total_examples = 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)
        model.zero_grad()
        loss.backward()
        optimizer.step()
        batch_size = y.size(0)
        total_examples += batch_size
        # The criterion returns the batch mean, so weight it by the batch
        # size; otherwise a short final batch counts as much as a full one.
        total_loss += loss.item() * batch_size
        total_correct += (torch.argmax(y_hat, 1) == y).sum().item()
    return total_loss / total_examples, total_correct / total_examples
def test_epoch(model, dataloader):
model.eval()
criterion = nn.CrossEntropyLoss()
total_correct = 0
total_examples = 0
total_loss = 0.0
for i, data in enumerate(dataloader):
x, y = data
x, y = x.to(DEVICE), y.to(DEVICE)
y_hat = model(x)
loss = criterion(y_hat, y)
total_loss += loss.item()
total_correct += (torch.argmax(y_hat, 1) == y).sum().item()
total_examples += y.size(0)
return total_loss / len(dataloader), total_correct / total_examples
# Alternate one training pass and one evaluation pass per epoch, printing the
# loss and accuracy returned by each.
for epoch in range(EPOCHS):
    print("Epoch {:02d}".format(epoch))

    # Wrap the training loader in a progress bar; evaluation runs without one.
    progress = tqdm(train_loader, desc="Training", unit="batch")
    loss, acc = train_epoch(model, progress, optimizer)
    print("Train Loss: {:.04f} Accuracy: {:.04f}".format(loss, acc))

    loss, acc = test_epoch(model, test_loader)
    print("Test Loss: {:.04f} Accuracy: {:.04f}".format(loss, acc))

print("Finished.")