-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest.py
113 lines (88 loc) · 4.14 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import sys
import glob
import numpy as np
import torch
import utils
import logging
import argparse
import torch.nn as nn
import genotypes
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from model import NetworkCIFAR as Network
# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser("cifar")

# Data and loader options.
parser.add_argument(
    '--data', type=str, default='./data',
    help='location of the data corpus')
parser.add_argument(
    '--batch_size', type=int, default=96,
    help='batch size')
# NOTE(review): parsed as a float although infer() uses it as a modulus on
# the integer step index -- confirm whether int was intended.
parser.add_argument(
    '--report_freq', type=float, default=50,
    help='report frequency')
parser.add_argument(
    '--gpus', type=str, default='0,1',
    help='gpu device id')

# Network options.
parser.add_argument(
    '--init_channels', type=int, default=36,
    help='num of init channels')
parser.add_argument(
    '--layers', type=int, default=20,
    help='total number of layers')
parser.add_argument(
    '--model_path', type=str, default='./trained_model/cifar10_model.pt',
    help='path of pretrained model')
parser.add_argument(
    '--auxiliary', action='store_true', default=False,
    help='use auxiliary tower')
parser.add_argument(
    '--cutout', action='store_true', default=False,
    help='use cutout')
parser.add_argument(
    '--cutout_length', type=int, default=16,
    help='cutout length')
parser.add_argument(
    '--drop_path_prob', type=float, default=0.2,
    help='drop path probability')
parser.add_argument(
    '--seed', type=int, default=0,
    help='random seed')
parser.add_argument(
    '--arch', type=str, default='RelativeNAS',
    help='which architecture to use')
parser.add_argument(
    '--set', type=str, default="cifar10",
    help='data set')

args = parser.parse_args()

# Restrict the CUDA devices visible to this process to the requested ids.
os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(args.gpus)

# Send timestamped INFO-level log records to stdout.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)

# Number of target classes: 100 for the 'cifar100' set, 10 for anything else.
CIFAR_CLASSES = 100 if args.set == 'cifar100' else 10
def main():
    """Load a pretrained network and log its accuracy on the CIFAR test split.

    Reads every setting from the module-level ``args``. Exits the process with
    status 1 when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed the NumPy and PyTorch (CPU and CUDA) generators from --seed.
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)

    logging.info('gpu device = %s' % args.gpus)
    logging.info("args = %s", args)

    # NOTE(review): eval() executes whatever expression --arch contains;
    # only pass trusted architecture names on the command line.
    genotype = eval("genotypes.%s" % args.arch)

    network = Network(
        args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
    model = torch.nn.DataParallel(network).cuda()
    utils.load(model, args.model_path)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss().cuda()

    # Choose the transform builder and dataset class that match --set.
    if args.set == 'cifar100':
        build_transforms = utils._data_transforms_cifar100
        dataset_cls = dset.CIFAR100
    else:
        build_transforms = utils._data_transforms_cifar10
        dataset_cls = dset.CIFAR10
    _, test_transform = build_transforms(args)
    test_data = dataset_cls(
        root=args.data, train=False, download=False, transform=test_transform)

    test_queue = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=False,
        num_workers=2,
    )

    # The attribute is set on the network wrapped inside DataParallel.
    model.module.drop_path_prob = args.drop_path_prob

    test_acc, _ = infer(test_queue, model, criterion)
    logging.info('test_acc %f', test_acc)
def infer(test_queue, model, criterion):
    """Evaluate ``model`` on every batch of ``test_queue``.

    Args:
        test_queue: Iterable yielding ``(input, target)`` tensor batches.
        model: Callable returning a 2-tuple; the first element is used as the
            logits and the second is discarded.
        criterion: Loss callable invoked as ``criterion(logits, target)``.

    Returns:
        A ``(top1.avg, objs.avg)`` tuple: the average top-1 precision and the
        average loss as tracked by the ``utils.AvgrageMeter`` instances, each
        updated with the batch size as its count.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()

    # Nothing here calls backward(), so run the forward passes with autograd
    # disabled; the computed logits and loss values are unchanged.
    with torch.no_grad():
        for step, (images, target) in enumerate(test_queue):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            logits, _ = model(images)
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = images.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            # Periodic progress line: step, average loss, top-1, top-5.
            if step % args.report_freq == 0:
                logging.info('test %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, objs.avg
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()