forked from DataXujing/wenet_trt8
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfix_quant_model.py
executable file
·65 lines (49 loc) · 2.5 KB
/
fix_quant_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import sys
import onnx
import onnx_graphsurgeon as gs
import numpy as np
# Nodes whose representation is printed while they are processed (debug aid).
DEBUG_NODE_NAMES = (
    "1006_DequantizeLinear",
    "encoder.encoders.1.norm_ff.weight_DequantizeLinear",
)


def dequantize_bias(quantized, scale):
    """Return ``quantized * scale`` as a float32 array.

    NOTE(review): no zero point is subtracted, which matches a symmetric
    quantization scheme -- confirm this holds for the models being fixed.
    """
    return np.array(quantized * scale, dtype=np.float32)


def as_rank1(values):
    """Return *values* reshaped to ``(1,)`` when it is rank 0, else unchanged.

    ``reshape`` is used instead of assigning to ``values.shape`` so the input
    array object is never mutated in place.
    """
    return values.reshape(1) if values.ndim == 0 else values


def fold_bias_dequantizers(graph):
    """Replace computed Conv/Gemm biases with precomputed float32 constants.

    For every ``Gemm`` node, and every ``Conv`` node whose ``group`` is 1,
    that has a third (bias) input produced by a node with two constant
    leading inputs, the product of those two constants becomes the new bias
    constant. The constant is named after the producer node. Nodes that do
    not match this pattern are left untouched.
    """
    for node in graph.nodes:
        if node.op == "Conv":
            # ``group`` is optional in ONNX and defaults to 1 when omitted.
            if node.attrs.get("group", 1) != 1:
                continue
        elif node.op != "Gemm":
            continue

        if len(node.inputs) < 3:
            continue  # No bias input at all.

        producers = node.inputs[2].inputs
        if not producers:
            continue  # Bias has no producer node (e.g. already a constant).

        producer = producers[0]
        if len(producer.inputs) < 2:
            continue
        quantized, scale = producer.inputs[0], producer.inputs[1]
        if not (isinstance(quantized, gs.Constant)
                and isinstance(scale, gs.Constant)):
            continue  # Cannot be folded ahead of time.

        bias = dequantize_bias(quantized.values, scale.values)
        node.inputs[2] = gs.Constant(producer.name, bias)


def promote_scalar_qdq_constants(graph):
    """Give rank-0 constant inputs of Q/DQ nodes the shape ``(1,)``.

    Only the first input of ``QuantizeLinear`` / ``DequantizeLinear`` nodes
    is considered, and only when it is a constant tensor.
    NOTE(review): presumably the downstream consumer rejects rank-0
    constants here -- confirm the motivation.
    """
    for node in graph.nodes:
        if node.op not in ("DequantizeLinear", "QuantizeLinear"):
            continue
        if not node.inputs:
            continue
        tensor = node.inputs[0]
        if not isinstance(tensor, gs.Constant):
            continue

        # Debug output is emitted before the tensor is modified.
        if node.op == "DequantizeLinear" and node.name in DEBUG_NODE_NAMES:
            print(node)

        values = tensor.values
        if values.ndim == 0:
            tensor.values = as_rank1(values)


def main(argv=None):
    """Load an ONNX model, apply both fix-ups and save the result.

    Args:
        argv: Argument vector shaped like ``sys.argv`` (program name, input
            model path, output model path). Defaults to ``sys.argv``.
            Arguments beyond the third are ignored.

    Raises:
        SystemExit: If fewer than two paths are supplied.
    """
    args = sys.argv if argv is None else argv
    if len(args) < 3:
        sys.exit("usage: fix_quant_model.py <input.onnx> <output.onnx>")
    input_model, out_model = args[1], args[2]

    graph = gs.import_onnx(onnx.load(input_model))
    fold_bias_dequantizers(graph)
    promote_scalar_qdq_constants(graph)

    # The bypassed bias producers are now dangling: remove them and restore
    # topological order before exporting.
    graph.cleanup().toposort()
    onnx.save(gs.export_onnx(graph), out_model)


if __name__ == "__main__":
    main()