-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathbench-mobilenet-quant.py
121 lines (97 loc) · 4.31 KB
/
bench-mobilenet-quant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import numpy as np
import tvm
from PIL import Image
from tvm import te
from tvm.contrib import graph_runtime
from tvm import relay
from tvm.runtime import container
from tvm.runtime import vm as vm_rt
from tvm.relay import testing
from tvm.relay import vm
from tvm.contrib.download import download_testdata
def extract(path):
import tarfile
if path.endswith("tgz") or path.endswith("gz"):
dir_path = os.path.dirname(path)
tar = tarfile.open(path)
tar.extractall(path=dir_path)
tar.close()
else:
raise RuntimeError('Could not decompress the file: ' + path)
def load_test_image(dtype='float32'):
image_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
image_path = download_testdata(image_url, 'cat.png', module='data')
resized_image = Image.open(image_path).resize((224, 224))
#image_data = np.asarray(resized_image).astype("float32")
image_data = np.asarray(resized_image).astype("int8")
# Add a dimension to the image so that we have NHWC format layout
image_data = np.expand_dims(image_data, axis=0)
# Preprocess image as described here:
# https://github.com/tensorflow/models/blob/edb6ed22a801665946c63d650ab9a0b23d98e1b1/research/slim/preprocessing/inception_preprocessing.py#L243
image_data[:, :, :, 0] = 2.0 / 255.0 * image_data[:, :, :, 0] - 1
image_data[:, :, :, 1] = 2.0 / 255.0 * image_data[:, :, :, 1] - 1
image_data[:, :, :, 2] = 2.0 / 255.0 * image_data[:, :, :, 2] - 1
print('input', image_data.shape)
return image_data
model_url = "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz"
# Download model tar file and extract it to get mobilenet_v1_1.0_224.tflite
#model_path = download_testdata(model_url, "mobilenet_v1_1.0_224.tgz", module=['tf', 'official'])
#model_dir = os.path.dirname(model_path)
model_dir = './'
#extract(model_path)
model_name ='mobilenet_v1_0.5_128_quant.tflite'
# Now we can open mobilenet_v1_1.0_224.tflite
#tflite_model_file = os.path.join(model_dir, "mobilenet_v1_1.0_224.tflite")
tflite_model_file = os.path.join(model_dir, model_name)
tflite_model_buf = open(tflite_model_file, "rb").read()
# Get TFLite model from buffer
try:
import tflite
tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
import tflite.Model
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
image_data = load_test_image()
input_tensor = "input"
input_shape = (1, 224, 224, 3)
#input_dtype = "float32"
input_dtype = "int8"
# Parse TFLite model and convert it to a Relay module
mod, params = relay.frontend.from_tflite(tflite_model,
shape_dict={input_tensor: input_shape},
dtype_dict={input_tensor: input_dtype})
# Build the module against to x86 CPU
target = "llvm -mattr=+neon,+vfp4,+thumb2"
ctx = tvm.context(str(target), 0)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, target, params=params)
# Create a runtime executor module
module = graph_runtime.create(graph, lib, tvm.cpu())
# Feed input data
module.set_input(input_tensor, tvm.nd.array(image_data))
# Feed related params
module.set_input(**params)
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10)
prof_res = np.array(ftimer().results) * 1000 # multiply 1000 for converting to millisecond
print("%-20s %-19s (%s)" % (model_name, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
# Run
#module.run()
# Get output
#tvm_output = module.get_output(0).asnumpy()
# Load label file
#label_file_url = ''.join(['https://raw.githubusercontent.com/',
# 'tensorflow/tensorflow/master/tensorflow/lite/java/demo/',
# 'app/src/main/assets/',
# 'labels_mobilenet_quant_v1_224.txt'])
#label_file = "labels_mobilenet_quant_v1_224.txt"
#label_path = download_testdata(label_file_url, label_file, module='data')
# List of 1001 classes
#with open(label_path) as f:
# labels = f.readlines()
# Convert result to 1D data
#predictions = np.squeeze(tvm_output)
# Get top 1 prediction
#prediction = np.argmax(predictions)
# Convert id to class name and show the result
#print("The image prediction result is: id " + str(prediction) + " name: " + labels[prediction])