-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathTNAS_Train_tester.py
82 lines (69 loc) · 3.39 KB
/
TNAS_Train_tester.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
from TNAS_model import makenet
import torch.optim as optim
# NOTE(review): placeholder value that nothing in this script reads
# (get_metric takes its own input_string argument); replace it with a real
# architecture string in the format expected by TNAS_model.makenet.
input_string = "1234 mẹ mày gayyyyyyyyyyyy"

# ToTensor maps pixel values to [0, 1]; Normalize with mean 0.5 / std 0.5
# then rescales each RGB channel to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

num_epochs = 2
batch_size = 256

# The CIFAR-10 training split is divided in two: indices [0, 25000) are used
# to train each candidate network, indices [25000, 50000) are held out as a
# validation set.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_indices = range(25000)
# Draw the training subset in a fresh random order every epoch. Passing the
# bare range as the sampler fed the network the same fixed order each epoch.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size,
    sampler=torch.utils.data.SubsetRandomSampler(train_indices))

# Validation order does not affect accuracy, so iterate it sequentially.
val_indices = range(25000, 50000)
valloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                        shuffle=False, sampler=val_indices)

# CIFAR-10 test split, iterated in order in fixed-size batches.
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
def _train_one_epoch(net, loader, criterion, optimizer, device):
    """Run one full pass of gradient descent over ``loader``."""
    net.train()
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()


def _count_correct(net, loader, device):
    """Return ``(correct, total)`` top-1 prediction counts over ``loader``."""
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = net(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Loss is deliberately not tracked here: only accuracy is used as the metric.
    return correct, total


def get_metric(input_string, device=None):
    """Briefly train a candidate network and return its validation accuracy.

    Builds the network described by ``input_string`` with ``makenet``, trains
    it for ``num_epochs`` epochs on ``trainloader`` using Nesterov SGD with a
    cosine-annealed learning rate, then measures top-1 accuracy on
    ``valloader`` and prints it as a whole-number percentage.

    Args:
        input_string: Architecture description passed unchanged to
            ``makenet``; its format is defined in ``TNAS_model``.
        device: Device to train and evaluate on. ``None`` (default) selects
            CUDA when available, otherwise the CPU.

    Returns:
        Fraction of validation samples classified correctly, in [0, 1].
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Move the model before creating the optimizer so the optimizer holds
    # references to the device-resident parameters.
    net = makenet(input_string=input_string).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9,
                          nesterov=True, weight_decay=0.0005)
    # NOTE(review): T_max=200 is far larger than num_epochs, so the learning
    # rate barely decays from 0.1 during this short run — confirm intended.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200,
                                                     eta_min=0)
    for _ in range(num_epochs):
        _train_one_epoch(net, trainloader, criterion, optimizer, device)
        # Step the scheduler after the optimizer (required since PyTorch 1.1);
        # stepping it first skipped the initial learning-rate value.
        scheduler.step()

    correct, total = _count_correct(net, valloader, device)
    acc = correct / total
    print(f'Accuracy of the network on the {total} validation images: {100 * correct // total} %')
    return acc
# NOTE: This script very likely still contains bugs; GPU support had not been added when this note was written.
# with torch.no_grad():
# net.eval()
# for inputs, labels in testset: # TODO: switch to validation data later; iterate testloader, not testset (the dataset yields single unbatched samples)
# outputs = net(inputs)
# loss = criterion(outputs,labels)
# _, predicted = torch.max(outputs.data, 1)
# total += labels.size(0)
# correct += (predicted == labels).sum().item()
# # Loss is not used as the metric, so it is not printed.
# print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')