-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_runtime.py
132 lines (125 loc) · 4.1 KB
/
test_runtime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import sys
import caffe
import argparse
import cPickle as pickle
import lib
from lib import trace_pb2 as tpb
from lib import dbg, RuntimeClassifier
# Per-task accuracy-loss budget used to configure the RuntimeClassifier.
# NOTE(review): the face entry is negative, presumably meaning the compact
# model must *exceed* the oracle by that margin — confirm against
# RuntimeClassifier's handling of 'error_tolerance'.
default_error_tolerance = {
    'obj': 0.05,
    'scene': 0.05,
    'face': -0.07
}
# Per-task validation-image directory, resolved from the environment config
# in lib.env.  Image filenames in a trace are joined onto these paths.
default_datadir = {
    'obj': lib.env.imagenet_val_dir,
    'scene': lib.env.mitplaces_val_dir,
    'face': lib.env.msrface_val_dir,
}
def test_synthesized_data(args, cl, outf):
    """Replay a synthesized (pickled) trace through the runtime classifier.

    The pickle at ``args.input`` holds an iterable of
    ``(frame, filename, label)`` tuples; each image is loaded from the
    task's validation directory and classified.  The resulting protobuf
    trace (classification history plus overall accuracy and average
    prepare/forward latencies) is written to ``outf``, or printed to
    stdout when ``outf`` is None.
    """
    datadir = default_datadir[args.task]
    # SECURITY NOTE(review): pickle.load executes arbitrary code from the
    # input file -- only run this on trusted trace files.
    with open(args.input, 'rb') as f:
        trace = pickle.load(f)
    if args.model is not None:
        dbg.dbg("compact model is %s" % args.model)
        cl.set_compact_model(args.model)
    correct = 0
    total = 0
    for frame, fn, lbl in trace:
        total += 1
        lbl = int(lbl)
        img = caffe.io.load_image(os.path.join(datadir, fn))
        predict = cl.predict(frame, img)
        # Record the ground-truth label on the just-appended history point.
        cl.history[-1].label = lbl
        if predict == lbl:
            correct += 1
        # Stop once exactly args.frame frames have been processed.
        # (Previously `total > args.frame`, which ran one extra frame.)
        if args.frame is not None and total >= args.frame:
            break
    if total == 0:
        # Avoid ZeroDivisionError on an empty trace.
        dbg.dbg('no frames processed; nothing to report')
        return
    accuracy = float(correct) / total
    avg_prepare_time, avg_forward_time = cl.get_avg_latency()
    # Assemble and emit the output trace protobuf.
    trace = tpb.Trace()
    trace.points.extend(cl.history)
    trace.accuracy = accuracy
    trace.avg_prepare_time = avg_prepare_time
    trace.avg_forward_time = avg_forward_time
    if outf is None:
        print(str(trace))
    else:
        with open(outf, 'w') as f:
            f.write(str(trace))
def test_video(args, cl, outf):
    """Classify frames listed in a video index file and report accuracy.

    ``args.input`` is a whitespace-separated text file with one frame per
    line: ``<frame_number> <image_path> <label>``.  Each image is classified,
    the protobuf trace (history, accuracy, average latencies) is written to
    ``outf`` (or stdout when ``outf`` is None), and per-class accuracy is
    logged via dbg.
    """
    total = 0
    correct = 0
    # label -> [correct_count, total_count]
    class_acc = {}
    if args.model is not None:
        dbg.dbg("compact model is %s" % args.model)
        cl.set_compact_model(args.model)
    with open(args.input) as f:
        for line in f:
            t = line.strip().split()
            frame = int(t[0])
            #if frame % 4 != 0: continue
            lbl = int(t[2])
            img = caffe.io.load_image(t[1])
            predict = cl.predict(frame, img)
            # Record the ground-truth label on the just-appended history point.
            cl.history[-1].label = lbl
            c = (predict == lbl)
            total += 1
            if lbl not in class_acc:
                class_acc[lbl] = [0, 0]
            class_acc[lbl][1] += 1
            if c:
                correct += 1
                class_acc[lbl][0] += 1
            # Stop once exactly args.frame frames have been processed.
            # (Previously `total > args.frame`, which ran one extra frame.)
            if args.frame is not None and total >= args.frame:
                break
    if total == 0:
        # Avoid ZeroDivisionError on an empty input file.
        dbg.dbg('no frames processed; nothing to report')
        return
    accuracy = float(correct) / total
    avg_prepare_time, avg_forward_time = cl.get_avg_latency()
    # Assemble and emit the output trace protobuf.
    trace = tpb.Trace()
    trace.points.extend(cl.history)
    trace.accuracy = accuracy
    trace.avg_prepare_time = avg_prepare_time
    trace.avg_forward_time = avg_forward_time
    if outf is None:
        print(str(trace))
    else:
        with open(outf, 'w') as f:
            f.write(str(trace))
    # Log the per-class accuracy breakdown.
    dbg.dbg('per-class accuracy')
    for lbl in sorted(class_acc.keys()):
        correct, total = class_acc[lbl]
        acc = float(correct) / total
        dbg.dbg('%s: %s (%s/%s)' % (lbl, acc, correct, total))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test the runtime classifier.')
    parser.add_argument('--type', choices=['syn', 'video'], help='Type of input data: synthesized / video data)')
    parser.add_argument('--task', choices=['obj', 'scene', 'face'],
                        help='Recognition task: obj/scene/face')
    parser.add_argument('-i', '--input', help='Specify the input trace')
    parser.add_argument('-o', '--output', help='Specify the output trace file')
    parser.add_argument('--model', default=None,
                        help="Set the compact model (default: None). " +
                             "If no compact model is specified, runtime only uses " +
                             "the oracle model to classify frames.")
    parser.add_argument('--frame', type=int,
                        help="Set the max number of frames to process; " +
                             "if not set, process all frames in the input")
    parser.add_argument('--cpu', action='store_true', help='Use CPU for classification')
    args = parser.parse_args()
    # Validate required arguments; exit with a non-zero status so shell
    # scripts can detect misuse (bare exit() previously returned status 0).
    if args.type is None:
        print('Missing type')
        sys.exit(1)
    if args.input is None:
        print('Missing input')
        sys.exit(1)
    if args.task is None:
        print('Missing task')
        sys.exit(1)
    # Build the classifier with the task-specific error-tolerance budget.
    config = {
        'cpu': args.cpu,
        'error_tolerance': default_error_tolerance[args.task],
    }
    cl = RuntimeClassifier(args.task, **config)
    if args.type == 'syn':
        test_synthesized_data(args, cl, args.output)
    else:
        test_video(args, cl, args.output)