-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathindex.html
583 lines (537 loc) · 821 KB
/
index.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN">
<html><head><meta http-equiv="Content-Type" content="text/html; charset=utf-8"><meta http-equiv="X-UA-Compatible" content="IE=edge,IE=9,chrome=1"><meta name="generator" content="MATLAB 2022b"><title>Quantization Aware Training for Transfer Learned MobileNet-v2</title><style type="text/css">.rtcContent { padding: 30px; } .S0 { margin: 3px 10px 5px 4px; padding: 0px; line-height: 28.8px; min-height: 0px; white-space: pre-wrap; color: rgb(192, 76, 11); font-family: Helvetica, Arial, sans-serif; font-style: normal; font-size: 24px; font-weight: 400; text-align: left; }
.S1 { margin: 2px 10px 9px 4px; padding: 0px; line-height: 21px; min-height: 0px; white-space: pre-wrap; color: rgb(33, 33, 33); font-family: Helvetica, Arial, sans-serif; font-style: normal; font-size: 14px; font-weight: 400; text-align: left; }
.S2 { margin: 3px 10px 5px 4px; padding: 0px; line-height: 20px; min-height: 0px; white-space: pre-wrap; color: rgb(33, 33, 33); font-family: Helvetica, Arial, sans-serif; font-style: normal; font-size: 20px; font-weight: 700; text-align: left; }
.CodeBlock { background-color: #F5F5F5; margin: 10px 0 10px 0; }
.S3 { border-left: 0.994318px solid rgb(191, 191, 191); border-right: 0.994318px solid rgb(191, 191, 191); border-top: 0.994318px solid rgb(191, 191, 191); border-bottom: 0.994318px solid rgb(191, 191, 191); border-radius: 4px 4px 0px 0px; padding: 6px 45px 4px 13px; line-height: 18.004px; min-height: 0px; white-space: nowrap; color: rgb(33, 33, 33); font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 14px; }
.S4 { color: rgb(33, 33, 33); padding: 10px 0px 6px 17px; background: rgb(255, 255, 255) none repeat scroll 0% 0% / auto padding-box border-box; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 14px; overflow-x: hidden; line-height: 17.234px; }
/* Styling that is common to warnings and errors is in diagnosticOutput.css */.embeddedOutputsErrorElement { min-height: 18px; max-height: 550px;}
.embeddedOutputsErrorElement .diagnosticMessage-errorType { overflow: auto;}
.embeddedOutputsErrorElement.inlineElement {}
.embeddedOutputsErrorElement.rightPaneElement {}
/* Styling that is common to warnings and errors is in diagnosticOutput.css */.embeddedOutputsWarningElement { min-height: 18px; max-height: 550px;}
.embeddedOutputsWarningElement .diagnosticMessage-warningType { overflow: auto;}
.embeddedOutputsWarningElement.inlineElement {}
.embeddedOutputsWarningElement.rightPaneElement {}
/* Copyright 2015-2019 The MathWorks, Inc. *//* In this file, styles are not scoped to rtcContainer since they could be in the Dojo Tooltip */.diagnosticMessage-wrapper { font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 12px;}
.diagnosticMessage-wrapper.diagnosticMessage-warningType { color: rgb(255,100,0);}
.diagnosticMessage-wrapper.diagnosticMessage-warningType a { color: rgb(255,100,0); text-decoration: underline;}
.diagnosticMessage-wrapper.diagnosticMessage-errorType { color: rgb(230,0,0);}
.diagnosticMessage-wrapper.diagnosticMessage-errorType a { color: rgb(230,0,0); text-decoration: underline;}
.diagnosticMessage-wrapper .diagnosticMessage-messagePart,.diagnosticMessage-wrapper .diagnosticMessage-causePart { white-space: pre-wrap;}
.diagnosticMessage-wrapper .diagnosticMessage-stackPart { white-space: pre;}
.embeddedOutputsTextElement,.embeddedOutputsVariableStringElement { white-space: pre; word-wrap: initial; min-height: 18px; max-height: 550px;}
.embeddedOutputsTextElement .textElement,.embeddedOutputsVariableStringElement .textElement { overflow: auto;}
.textElement,.rtcDataTipElement .textElement { padding-top: 2px;}
.embeddedOutputsTextElement.inlineElement,.embeddedOutputsVariableStringElement.inlineElement {}
.inlineElement .textElement {}
.embeddedOutputsTextElement.rightPaneElement,.embeddedOutputsVariableStringElement.rightPaneElement { min-height: 16px;}
.rightPaneElement .textElement { padding-top: 2px; padding-left: 9px;}
.S5 { border-left: 0.994318px solid rgb(191, 191, 191); border-right: 0.994318px solid rgb(191, 191, 191); border-top: 0.994318px solid rgb(191, 191, 191); border-bottom: 0px none rgb(33, 33, 33); border-radius: 0px; padding: 6px 45px 0px 13px; line-height: 18.004px; min-height: 0px; white-space: nowrap; color: rgb(33, 33, 33); font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 14px; }
.S6 { border-left: 0.994318px solid rgb(191, 191, 191); border-right: 0.994318px solid rgb(191, 191, 191); border-top: 0px none rgb(33, 33, 33); border-bottom: 0px none rgb(33, 33, 33); border-radius: 0px; padding: 0px 45px 0px 13px; line-height: 18.004px; min-height: 0px; white-space: nowrap; color: rgb(33, 33, 33); font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 14px; }
.S7 { border-left: 0.994318px solid rgb(191, 191, 191); border-right: 0.994318px solid rgb(191, 191, 191); border-top: 0px none rgb(33, 33, 33); border-bottom: 0.994318px solid rgb(191, 191, 191); border-radius: 0px 0px 4px 4px; padding: 0px 45px 4px 13px; line-height: 18.004px; min-height: 0px; white-space: nowrap; color: rgb(33, 33, 33); font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 14px; }
.S8 { margin: 10px 10px 9px 4px; padding: 0px; line-height: 21px; min-height: 0px; white-space: pre-wrap; color: rgb(33, 33, 33); font-family: Helvetica, Arial, sans-serif; font-style: normal; font-size: 14px; font-weight: 400; text-align: left; }
.embeddedOutputsMatrixElement,.eoOutputWrapper .matrixElement { min-height: 18px; box-sizing: border-box;}
.embeddedOutputsMatrixElement .matrixElement,.eoOutputWrapper .matrixElement,.rtcDataTipElement .matrixElement { position: relative;}
.matrixElement .variableValue,.rtcDataTipElement .matrixElement .variableValue { white-space: pre; display: inline-block; vertical-align: top; overflow: hidden;}
.embeddedOutputsMatrixElement.inlineElement {}
.embeddedOutputsMatrixElement.inlineElement .topHeaderWrapper { display: none;}
.embeddedOutputsMatrixElement.inlineElement .veTable .body { padding-top: 0 !important; max-height: 100px;}
.inlineElement .matrixElement { max-height: 300px;}
.embeddedOutputsMatrixElement.rightPaneElement {}
.rightPaneElement .matrixElement,.rtcDataTipElement .matrixElement { overflow: hidden; padding-left: 9px;}
.rightPaneElement .matrixElement { margin-bottom: -1px;}
.embeddedOutputsMatrixElement .matrixElement .valueContainer,.eoOutputWrapper .matrixElement .valueContainer,.rtcDataTipElement .matrixElement .valueContainer { white-space: nowrap; margin-bottom: 3px;}
.embeddedOutputsMatrixElement .matrixElement .valueContainer .horizontalEllipsis.hide,.embeddedOutputsMatrixElement .matrixElement .verticalEllipsis.hide,.eoOutputWrapper .matrixElement .valueContainer .horizontalEllipsis.hide,.eoOutputWrapper .matrixElement .verticalEllipsis.hide,.rtcDataTipElement .matrixElement .valueContainer .horizontalEllipsis.hide,.rtcDataTipElement .matrixElement .verticalEllipsis.hide { display: none;}
.embeddedOutputsVariableMatrixElement .matrixElement .valueContainer.hideEllipses .verticalEllipsis, .embeddedOutputsVariableMatrixElement .matrixElement .valueContainer.hideEllipses .horizontalEllipsis { display:none;}
.embeddedOutputsMatrixElement .matrixElement .valueContainer .horizontalEllipsis,.eoOutputWrapper .matrixElement .valueContainer .horizontalEllipsis { margin-bottom: -3px;}
.eoOutputWrapper .embeddedOutputsVariableMatrixElement .matrixElement .valueContainer { cursor: default !important;}
.embeddedOutputsVariableElement { white-space: pre-wrap; word-wrap: break-word; min-height: 18px; max-height: 250px; overflow: auto;}
.variableElement {}
.embeddedOutputsVariableElement.inlineElement {}
.inlineElement .variableElement {}
.embeddedOutputsVariableElement.rightPaneElement { min-height: 16px;}
.rightPaneElement .variableElement { padding-top: 2px; padding-left: 9px;}
.outputsOnRight .embeddedOutputsVariableElement.rightPaneElement .eoOutputContent { /* Remove extra space allocated for navigation border */ margin-top: 0; margin-bottom: 0;}
.variableNameElement { margin-bottom: 3px; display: inline-block;}
/* * Ellipses as base64 for HTML export. */.matrixElement .horizontalEllipsis,.rtcDataTipElement .matrixElement .horizontalEllipsis { display: inline-block; margin-top: 3px; /* base64 encoded version of images-liveeditor/HEllipsis.png */ width: 30px; height: 12px; background-repeat: no-repeat; background-image: url("");}
.matrixElement .verticalEllipsis,.textElement .verticalEllipsis,.rtcDataTipElement .matrixElement .verticalEllipsis,.rtcDataTipElement .textElement .verticalEllipsis { margin-left: 35px; /* base64 encoded version of images-liveeditor/VEllipsis.png */ width: 12px; height: 30px; background-repeat: no-repeat; background-image: url("");}
.S9 { border-left: 0.994318px solid rgb(191, 191, 191); border-right: 0.994318px solid rgb(191, 191, 191); border-top: 0.994318px solid rgb(191, 191, 191); border-bottom: 0.994318px solid rgb(191, 191, 191); border-radius: 4px; padding: 6px 45px 4px 13px; line-height: 18.004px; min-height: 0px; white-space: nowrap; color: rgb(33, 33, 33); font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 14px; }
.S10 { border-left: 0.994318px solid rgb(191, 191, 191); border-right: 0.994318px solid rgb(191, 191, 191); border-top: 0.994318px solid rgb(191, 191, 191); border-bottom: 0px none rgb(33, 33, 33); border-radius: 4px 4px 0px 0px; padding: 6px 45px 0px 13px; line-height: 18.004px; min-height: 0px; white-space: nowrap; color: rgb(33, 33, 33); font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 14px; }
.S11 { border-left: 0.994318px solid rgb(191, 191, 191); border-right: 0.994318px solid rgb(191, 191, 191); border-top: 0px none rgb(33, 33, 33); border-bottom: 0.994318px solid rgb(191, 191, 191); border-radius: 0px; padding: 0px 45px 4px 13px; line-height: 18.004px; min-height: 0px; white-space: nowrap; color: rgb(33, 33, 33); font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 14px; }
.S12 { margin: 10px 0px 20px; padding-left: 0px; font-family: Helvetica, Arial, sans-serif; font-size: 14px; }
.S13 { margin-left: 56px; line-height: 21px; min-height: 0px; text-align: left; white-space: pre-wrap; }
.S14 { margin: 15px 10px 5px 4px; padding: 0px; line-height: 18px; min-height: 0px; white-space: pre-wrap; color: rgb(33, 33, 33); font-family: Helvetica, Arial, sans-serif; font-style: normal; font-size: 17px; font-weight: 700; text-align: left; }
.S15 { margin: 10px 10px 5px 4px; padding: 0px; line-height: 18px; min-height: 0px; white-space: pre-wrap; color: rgb(33, 33, 33); font-family: Helvetica, Arial, sans-serif; font-style: normal; font-size: 15px; font-weight: 700; text-align: left; }</style></head><body><div class = rtcContent><h1 class = 'S0' id = 'T_5E1BFFD0' ><span>Quantization Aware Training for Transfer Learned MobileNet-v2</span></h1><div class = 'S1'><span>This example shows how to perform quantization aware training for a transfer learned MobileNet-v2 network.</span></div><div class = 'S1'><span>Low precision types like </span><span style=' font-family: monospace;'>int8 </span><span>propagate quantization error that may degrade the accuracy of the network. Quantization aware training is a method that introduces quantization error during training, thus giving the network the ability to adapt and ultimately produce a network more robust to quantization. In most cases, a quantized network or integer-arithmetic only network constructed after quantization aware training can produce accuracy on par with the original floating point network.</span></div><div class = 'S1'><span>This example takes you through the quantization workflow of a transfer learned MobileNet-v2 network. 
MobileNet-v2 was chosen for this example because it contains depthwise-separable convolution layers that are especially sensitive to post-training quantization.</span></div><div class = 'S1'><span>The flowchart below highlights the steps necessary to convert a trained network into a quantized one via quantization aware training.</span></div><div class = 'S1'><img class = "imageNode" src = "" width = "292" height = "522" alt = "qat_workflow.png" style = "vertical-align: baseline; width: 292px; height: 522px;"></img></div><div class = 'S1'><span></span></div><h2 class = 'S2' id = 'H_FCCD6D67' ><span>Load Flower Dataset</span></h2><div class = 'S1'><span>Download the flower dataset [</span><a href = "#M_2AA33B76"><span>1</span></a><span>] using the supporting function </span><a href = "#H_D08A0BC5"><span style=' font-family: monospace;'>downloadFlowerDataset</span></a><span>.</span></div><div class="CodeBlock"><div class="inlineWrapper outputs"><div class = 'S3'><span style="white-space: pre"><span >imageFolder = downloadFlowerDataset;</span></span></div><div class = 'S4'><div class="inlineElement eoOutputWrapper embeddedOutputsTextElement" uid="26772109" prevent-scroll="true" data-testid="output_0" style="width: 938.026px; white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"><div class="textElement eoOutputContent" data-width="908" data-height="16" data-hashorizontaloverflow="false" style="max-height: 261px; white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;">Downloading Flower Dataset (218 MB)...</div></div></div></div><div class="inlineWrapper"><div class = 'S5'><span style="white-space: pre"><span >imds = imageDatastore(imageFolder, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > IncludeSubfolders=true, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div 
class="inlineWrapper"><div class = 'S7'><span style="white-space: pre"><span > LabelSource=</span><span style="color: rgb(167, 9, 245);">"foldernames"</span><span >);</span></span></div></div></div><div class = 'S8'><span>Inspect the classes of the data.</span></div><div class="CodeBlock"><div class="inlineWrapper outputs"><div class = 'S3'><span style="white-space: pre"><span >classes = string(categories(imds.Labels))</span></span></div><div class = 'S4'><div class="inlineElement eoOutputWrapper embeddedOutputsTextMatrixElement embeddedOutputsVariableMatrixElement" uid="669818DB" prevent-scroll="true" data-testid="output_1" data-width="908" style="width: 938.026px; white-space: normal; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"><div class="matrixElement veSpecifier saveLoad eoOutputContent" style="white-space: normal; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"><div class="veVariableName variableNameElement" style="width: 908px; white-space: normal; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"><div class="headerElementClickToInteract" style="white-space: normal; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"><span style="white-space: normal; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;">classes = </span><span class="veVariableValueSummary headerElement" style="white-space: normal; font-style: normal; color: rgb(179, 179, 179); font-size: 12px;">5×1 string</span></div></div><div class="valueContainer" data-layout="{&quot;totalRows&quot;:&quot;5&quot;,&quot;totalColumns&quot;:&quot;1&quot;}" style="white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"><div class="variableValue" style="width: 82px; white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;">"daisy" <br>"dandelion" <br>"roses" <br>"sunflowers" <br>"tulips" <br></div><div class="horizontalEllipsis hide" style="white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"></div><div 
class="verticalEllipsis hide" style="white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"></div></div></div><div class="outputLayer selectedOutputDecorationLayer doNotExport" style="white-space: normal; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"></div><div class="outputLayer activeOutputDecorationLayer doNotExport" style="white-space: normal; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"></div><div class="outputLayer scrollableOutputDecorationLayer doNotExport" style="white-space: normal; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"></div></div></div></div></div><h2 class = 'S2' id = 'H_A091A87E' ><span>Perform Transfer Learning on MobileNet-v2</span></h2><div class = 'S1'><span>MobileNet-v2 is a convolutional neural network that is 53 layers deep. The pretrained version of the network is trained on more than a million images from the ImageNet database.</span></div><div class="CodeBlock"><div class="inlineWrapper"><div class = 'S9'><span style="white-space: pre"><span >net = mobilenetv2;</span></span></div></div></div><div class = 'S8'><span>Split the data into training and validation sets and create augmented image datastores that automatically resize the images to the input size of the network.</span></div><div class="CodeBlock"><div class="inlineWrapper"><div class = 'S10'><span style="white-space: pre"><span >[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span >inputSize = net.Layers(1).InputSize;</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span >augimdsTrain = augmentedImageDatastore(inputSize,imdsTrain);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span >augimdsValidation = 
augmentedImageDatastore(inputSize,imdsValidation);</span></span></div></div><div class="inlineWrapper"><div class = 'S7'><span style="white-space: pre"><span >validationActualLabels = imdsValidation.Labels;</span></span></div></div></div><div class = 'S8'><span>Set aside a portion of the training dataset to use during the calibration step of quantization. This datastore should be representative of the data used for training but ideally separate from the one used to validate.</span></div><div class="CodeBlock"><div class="inlineWrapper"><div class = 'S9'><span style="white-space: pre"><span >augimdsCalibration = subset(shuffle(augimdsTrain),1:320);</span></span></div></div></div><div class = 'S8'><span>Perform </span><a href = "#H_C9FB9125"><span>transfer learning</span></a><span> on the network on the flowers image dataset. The learnable parameters of the trained network </span><span style=' font-family: monospace;'>transferNet</span><span> are in </span><span style=' font-family: monospace;'>single</span><span> precision.</span></div><div class="CodeBlock"><div class="inlineWrapper outputs"><div class = 'S3'><span style="white-space: pre"><span >transferNet = createFlowerNetwork(net,augimdsTrain,augimdsValidation,classes);</span></span></div><div class = 'S4'><div class="inlineElement eoOutputWrapper embeddedOutputsFigure" uid="41BDF570" prevent-scroll="true" data-testid="output_2" style="width: 938.026px;"><div class="figureElement eoOutputContent"><img class="figureImage figureContainingNode" src="" style="width: 987px; padding-bottom: 0px;"></div><div class="outputLayer selectedOutputDecorationLayer doNotExport"></div><div class="outputLayer activeOutputDecorationLayer doNotExport"></div><div class="outputLayer scrollableOutputDecorationLayer doNotExport"></div></div></div></div></div><h2 class = 'S2' id = 'H_1B4CE57B' ><span>Evaluate Baseline Network Performance</span></h2><div class = 'S1'><span>Evaluate the performance of the </span><span style=' 
font-family: monospace;'>single</span><span> precision network. Performance in this case is defined as the correct classification rate.</span></div><div class="CodeBlock"><div class="inlineWrapper outputs"><div class = 'S3'><span style="white-space: pre"><span >netCCR = evaluateModelAccuracy(transferNet,augimdsValidation,validationActualLabels) </span></span></div><div class = 'S4'><div class='variableElement' style='font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 12px; '>netCCR = 0.9101</div></div></div></div><div class = 'S8'><span>Quantize the network using the </span><a href = "#H_138C14B1"><span style=' font-family: monospace;'>createQuantizedNetwork</span></a><span> function provided at the end of this example and evaluate the performance of the quantized network. Post-training quantization of the original network yields poor performance due to the range of learnable values in the depthwise separable convolution layers. An accuracy of roughly 20% is equivalent to randomly guessing one of the 5 possible labels for each image.</span></div><div class="CodeBlock"><div class="inlineWrapper"><div class = 'S10'><span style="white-space: pre"><span >originalQuantizedNet = createQuantizedNetwork(transferNet,augimdsCalibration);</span></span></div></div><div class="inlineWrapper outputs"><div class = 'S11'><span style="white-space: pre"><span >originalQuantizedCCR = evaluateModelAccuracy(originalQuantizedNet,augimdsValidation,validationActualLabels)</span></span></div><div class = 'S4'><div class='variableElement' style='font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 12px; '>originalQuantizedCCR = 0.2452</div></div></div><div class="inlineWrapper"><div class = 'S5'><span style="white-space: pre"><span >bar( </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > categorical([</span><span style="color: rgb(167, 9, 
245);">"Original network"</span><span >,</span><span style="color: rgb(167, 9, 245);">"Post-training quantized network"</span><span >]), </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > [netCCR,originalQuantizedCCR] </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > )</span></span></div></div><div class="inlineWrapper outputs"><div class = 'S11'><span style="white-space: pre"><span >ylabel(</span><span style="color: rgb(167, 9, 245);">"Network Accuracy (%)"</span><span >)</span></span></div><div class = 'S4'><div class="inlineElement eoOutputWrapper embeddedOutputsFigure" uid="50F59A12" prevent-scroll="true" data-testid="output_5" style="width: 938.026px;"><div class="figureElement eoOutputContent"><img class="figureImage figureContainingNode" src="" style="width: 616px; padding-bottom: 0px;"></div><div class="outputLayer selectedOutputDecorationLayer doNotExport"></div><div class="outputLayer activeOutputDecorationLayer doNotExport"></div><div class="outputLayer scrollableOutputDecorationLayer doNotExport"></div></div></div></div></div><h2 class = 'S2' id = 'H_07814550' ><span>Replace Network Layers with Quantization Aware Training Layers</span></h2><div class = 'S1'><span>Replace the </span><span style=' font-family: monospace;'>Convolution2D</span><span> and </span><span style=' font-family: monospace;'>GroupedConvolution2D</span><span> layers along with their adjacent </span><span style=' font-family: monospace;'>BatchNormalization</span><span> layers with custom layers that are quantization aware using the </span><a href = "#H_157F752C"><span style=' font-family: monospace;'>makeQuantizationAwareLayers</span></a><span> function provided with this example. 
The quantization aware layers are </span><a href = "./QuantizedConvolutionBatchNormTrainingLayer.m"><span>custom layers</span></a><span> that have modified </span><span style=' font-family: monospace;'>forward</span><span> and </span><span style=' font-family: monospace;'>predict</span><span> behavior that injects quantization error similar to that of post-training quantization. The quantization error comes from the </span><a href = "./quantizeToFloat.m"><span>quantizeToFloat</span></a><span> function that quantizes, then unquantizes a given value using best-precision scaling to </span><span style=' font-family: monospace;'>int8</span><span> precision. </span></div><div class = 'S1'><span></span></div><div class = 'S1'><span>Quantization to float can be expressed as follows.</span></div><div class = 'S1'><span> </span><span mathmlencoding="&lt;math xmlns=&quot;http://www.w3.org/1998/Math/MathML&quot; display=&quot;block&quot;&gt;&lt;mrow&gt;&lt;mtable columnalign=&quot;left&quot;&gt;&lt;mtr&gt;&lt;mtd&gt;&lt;mrow&gt;&lt;mi&gt;𝑥&lt;/mi&gt;&lt;mi&gt;̂&lt;/mi&gt;&lt;mo&gt;=&lt;/mo&gt;&lt;mi mathvariant=&quot;normal&quot;&gt;quantizeToFloat&lt;/mi&gt;&lt;mo&gt;(&lt;/mo&gt;&lt;mi mathvariant=&quot;normal&quot;&gt;𝑥&lt;/mi&gt;&lt;mo&gt;)&lt;/mo&gt;&lt;/mrow&gt;&lt;/mtd&gt;&lt;/mtr&gt;&lt;mtr&gt;&lt;mtd&gt;&lt;mrow&gt;&lt;mtext&gt; &lt;/mtext&gt;&lt;mtext&gt; &lt;/mtext&gt;&lt;mtext&gt; &lt;/mtext&gt;&lt;mo&gt;=&lt;/mo&gt;&lt;mi mathvariant=&quot;normal&quot;&gt;unquantize&lt;/mi&gt;&lt;mo&gt;(&lt;/mo&gt;&lt;mi mathvariant=&quot;normal&quot;&gt;quantize&lt;/mi&gt;&lt;mo&gt;(&lt;/mo&gt;&lt;mi&gt;𝑥&lt;/mi&gt;&lt;mo&gt;)&lt;/mo&gt;&lt;mo&gt;)&lt;/mo&gt;&lt;/mrow&gt;&lt;/mtd&gt;&lt;/mtr&gt;&lt;mtr&gt;&lt;mtd&gt;&lt;mrow&gt;&lt;mtext&gt; &lt;/mtext&gt;&lt;mtext&gt; &lt;/mtext&gt;&lt;mtext&gt; &lt;/mtext&gt;&lt;mo&gt;=&lt;/mo&gt;&lt;mi mathvariant=&quot;normal&quot;&gt;rescale&lt;/mi&gt;&lt;mo&gt;⋅&lt;/mo&gt;&lt;mi mathvariant=&quot;normal&quot;&gt;saturate&lt;/mi&gt;&lt;mo&gt;(&lt;/mo&gt;&lt;mi mathvariant=&quot;normal&quot;&gt;round&lt;/mi&gt;&lt;mo&gt;(&lt;/mo&gt;&lt;mfrac&gt;&lt;mrow&gt;&lt;mi&gt;𝑥&lt;/mi&gt;&lt;/mrow&gt;&lt;mrow&gt;&lt;mi mathvariant=&quot;normal&quot;&gt;scale&lt;/mi&gt;&lt;/mrow&gt;&lt;/mfrac&gt;&lt;mo&gt;)&lt;/mo&gt;&lt;mo&gt;)&lt;/mo&gt;&lt;/mrow&gt;&lt;/mtd&gt;&lt;/mtr&gt;&lt;/mtable&gt;&lt;/mrow&gt;&lt;/math&gt;" style="vertical-align:-34px"><img src="" width="230.5" height="80" /></span></div><div class = 'S1'><span>The quantization step uses a non-differentiable operation </span><span style=' font-family: monospace;'>round</span><span> that would normally break the training workflow by zeroing out the 
gradients. During quantization aware training, bypass the gradient calculations for non-differentiable operations using an identity function. The diagram below [</span><a href = "#M_2AA33B76"><span>2</span></a><span>] shows how the custom layer calculates the gradients for non-differentiable operations with the identity function via straight-through estimation.</span></div><div class = 'S1'><img class = "imageNode" src = "" width = "869" height = "400" alt = "ste.png" style = "vertical-align: baseline; width: 869px; height: 400px;"></img></div><div class = 'S1'><span></span></div><div class = 'S1'><span></span></div><div class = 'S1'><span>For 2-D convolution layers, the weights and biases of the replacement layers include the batch normalization layer statistics. Convolution operations during training use the adjusted and quantized weights [</span><a href = "#M_2AA33B76"><span>3</span></a><span>].</span></div><div class = 'S1'><span>As the batch normalization layer statistics are incorporated into the convolution layers, the </span><span style=' font-family: monospace;'>makeQuantizationAwareLayers</span><span> function replaces each batch normalization layer with an identity layer that returns its input as its output. 
</span></div><div class="CodeBlock"><div class="inlineWrapper"><div class = 'S9'><span style="white-space: pre"><span >quantizationAwareLayerGraph = makeQuantizationAwareLayers(transferNet);</span></span></div></div></div><div class = 'S8'><span>Inspect the layers of the network.</span></div><div class="CodeBlock"><div class="inlineWrapper outputs"><div class = 'S3'><span style="white-space: pre"><span >quantizationAwareLayerGraph.Layers</span></span></div><div class = 'S4'><div class="inlineElement eoOutputWrapper embeddedOutputsVariableStringElement scrollableOutput" uid="569A54B1" prevent-scroll="true" data-testid="output_6" style="width: 938.026px; white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"><div class="textElement eoOutputContent" data-width="908" data-height="2325" data-hashorizontaloverflow="false" style="max-height: 261px; white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"><div style="white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"><span class="variableNameElement" style="white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;">ans = </span></div><div style="white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"> 154×1 Layer array with layers:
1 'input_1' Image Input 224×224×3 images with 'zscore' normalization
2 'Conv1' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
3 'bn_Conv1' Identity Training Layer No operation to forward behavior
4 'Conv1_relu' Clipped ReLU Clipped ReLU with ceiling 6
5 'expanded_conv_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
6 'expanded_conv_depthwise_BN' Identity Training Layer No operation to forward behavior
7 'expanded_conv_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
8 'expanded_conv_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
9 'expanded_conv_project_BN' Identity Training Layer No operation to forward behavior
10 'block_1_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
11 'block_1_expand_BN' Identity Training Layer No operation to forward behavior
12 'block_1_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
13 'block_1_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
14 'block_1_depthwise_BN' Identity Training Layer No operation to forward behavior
15 'block_1_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
16 'block_1_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
17 'block_1_project_BN' Identity Training Layer No operation to forward behavior
18 'block_2_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
19 'block_2_expand_BN' Identity Training Layer No operation to forward behavior
20 'block_2_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
21 'block_2_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
22 'block_2_depthwise_BN' Identity Training Layer No operation to forward behavior
23 'block_2_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
24 'block_2_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
25 'block_2_project_BN' Identity Training Layer No operation to forward behavior
26 'block_2_add' Addition Element-wise addition of 2 inputs
27 'block_3_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
28 'block_3_expand_BN' Identity Training Layer No operation to forward behavior
29 'block_3_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
30 'block_3_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
31 'block_3_depthwise_BN' Identity Training Layer No operation to forward behavior
32 'block_3_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
33 'block_3_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
34 'block_3_project_BN' Identity Training Layer No operation to forward behavior
35 'block_4_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
36 'block_4_expand_BN' Identity Training Layer No operation to forward behavior
37 'block_4_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
38 'block_4_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
39 'block_4_depthwise_BN' Identity Training Layer No operation to forward behavior
40 'block_4_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
41 'block_4_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
42 'block_4_project_BN' Identity Training Layer No operation to forward behavior
43 'block_4_add' Addition Element-wise addition of 2 inputs
44 'block_5_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
45 'block_5_expand_BN' Identity Training Layer No operation to forward behavior
46 'block_5_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
47 'block_5_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
48 'block_5_depthwise_BN' Identity Training Layer No operation to forward behavior
49 'block_5_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
50 'block_5_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
51 'block_5_project_BN' Identity Training Layer No operation to forward behavior
52 'block_5_add' Addition Element-wise addition of 2 inputs
53 'block_6_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
54 'block_6_expand_BN' Identity Training Layer No operation to forward behavior
55 'block_6_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
56 'block_6_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
57 'block_6_depthwise_BN' Identity Training Layer No operation to forward behavior
58 'block_6_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
59 'block_6_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
60 'block_6_project_BN' Identity Training Layer No operation to forward behavior
61 'block_7_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
62 'block_7_expand_BN' Identity Training Layer No operation to forward behavior
63 'block_7_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
64 'block_7_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
65 'block_7_depthwise_BN' Identity Training Layer No operation to forward behavior
66 'block_7_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
67 'block_7_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
68 'block_7_project_BN' Identity Training Layer No operation to forward behavior
69 'block_7_add' Addition Element-wise addition of 2 inputs
70 'block_8_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
71 'block_8_expand_BN' Identity Training Layer No operation to forward behavior
72 'block_8_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
73 'block_8_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
74 'block_8_depthwise_BN' Identity Training Layer No operation to forward behavior
75 'block_8_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
76 'block_8_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
77 'block_8_project_BN' Identity Training Layer No operation to forward behavior
78 'block_8_add' Addition Element-wise addition of 2 inputs
79 'block_9_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
80 'block_9_expand_BN' Identity Training Layer No operation to forward behavior
81 'block_9_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
82 'block_9_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
83 'block_9_depthwise_BN' Identity Training Layer No operation to forward behavior
84 'block_9_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
85 'block_9_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
86 'block_9_project_BN' Identity Training Layer No operation to forward behavior
87 'block_9_add' Addition Element-wise addition of 2 inputs
88 'block_10_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
89 'block_10_expand_BN' Identity Training Layer No operation to forward behavior
90 'block_10_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
91 'block_10_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
92 'block_10_depthwise_BN' Identity Training Layer No operation to forward behavior
93 'block_10_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
94 'block_10_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
95 'block_10_project_BN' Identity Training Layer No operation to forward behavior
96 'block_11_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
97 'block_11_expand_BN' Identity Training Layer No operation to forward behavior
98 'block_11_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
99 'block_11_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
100 'block_11_depthwise_BN' Identity Training Layer No operation to forward behavior
101 'block_11_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
102 'block_11_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
103 'block_11_project_BN' Identity Training Layer No operation to forward behavior
104 'block_11_add' Addition Element-wise addition of 2 inputs
105 'block_12_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
106 'block_12_expand_BN' Identity Training Layer No operation to forward behavior
107 'block_12_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
108 'block_12_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
109 'block_12_depthwise_BN' Identity Training Layer No operation to forward behavior
110 'block_12_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
111 'block_12_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
112 'block_12_project_BN' Identity Training Layer No operation to forward behavior
113 'block_12_add' Addition Element-wise addition of 2 inputs
114 'block_13_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
115 'block_13_expand_BN' Identity Training Layer No operation to forward behavior
116 'block_13_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
117 'block_13_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
118 'block_13_depthwise_BN' Identity Training Layer No operation to forward behavior
119 'block_13_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
120 'block_13_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
121 'block_13_project_BN' Identity Training Layer No operation to forward behavior
122 'block_14_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
123 'block_14_expand_BN' Identity Training Layer No operation to forward behavior
124 'block_14_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
125 'block_14_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
126 'block_14_depthwise_BN' Identity Training Layer No operation to forward behavior
127 'block_14_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
128 'block_14_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
129 'block_14_project_BN' Identity Training Layer No operation to forward behavior
130 'block_14_add' Addition Element-wise addition of 2 inputs
131 'block_15_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
132 'block_15_expand_BN' Identity Training Layer No operation to forward behavior
133 'block_15_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
134 'block_15_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
135 'block_15_depthwise_BN' Identity Training Layer No operation to forward behavior
136 'block_15_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
137 'block_15_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
138 'block_15_project_BN' Identity Training Layer No operation to forward behavior
139 'block_15_add' Addition Element-wise addition of 2 inputs
140 'block_16_expand' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
141 'block_16_expand_BN' Identity Training Layer No operation to forward behavior
142 'block_16_expand_relu' Clipped ReLU Clipped ReLU with ceiling 6
143 'block_16_depthwise' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
144 'block_16_depthwise_BN' Identity Training Layer No operation to forward behavior
145 'block_16_depthwise_relu' Clipped ReLU Clipped ReLU with ceiling 6
146 'block_16_project' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
147 'block_16_project_BN' Identity Training Layer No operation to forward behavior
148 'Conv_1' Quantized Fused Convolution Layer Quantization Aware Conv-BN Layer Group for Training
149 'Conv_1_bn' Identity Training Layer No operation to forward behavior
150 'out_relu' Clipped ReLU Clipped ReLU with ceiling 6
151 'global_average_pooling2d_1' 2-D Global Average Pooling 2-D global average pooling
152 'new_fc' Fully Connected 5 fully connected layer
153 'Logits_softmax' Softmax softmax
 154 'new_classoutput' Classification Output crossentropyex with 'daisy' and 4 other classes</div></div></div></div></div></div><div class = 'S8'><span>To apply quantization aware training to a network that contains convolution layers without an adjacent batch normalization layer, use the </span><a href = "./QuantizedConvolutionTrainingLayer.m"><span style=' font-family: monospace;'>QuantizedConvolutionTrainingLayer</span></a><span> provided with this example instead of </span><a href = "./QuantizedConvolutionBatchNormTrainingLayer.m"><span style=' font-family: monospace;'>QuantizedConvolutionBatchNormTrainingLayer</span></a><span>.</span></div><h2 class = 'S2' id = 'H_9D6B450E' ><span>Do Quantization Aware Training</span></h2><div class = 'S1'><span>Using the layer graph with quantization aware training layers, train the network. Compared to the training of the original network, the training options have been modified to increase the number of </span><span style=' font-family: monospace;'>MaxEpochs</span><span> to 10 and the </span><span style=' font-family: monospace;'>ValidationFrequency</span><span> to every 2 epochs.</span></div><div class="CodeBlock"><div class="inlineWrapper"><div class = 'S10'><span style="white-space: pre"><span >miniBatchSize = 32;</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span >validationFrequencyEpochs = 2;</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span >numObservations = augimdsTrain.NumObservations;</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span >numIterationsPerEpoch = floor(numObservations/miniBatchSize);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span >validationFrequency = validationFrequencyEpochs*numIterationsPerEpoch;</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> 
</div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span >options = trainingOptions(</span><span style="color: rgb(167, 9, 245);">"sgdm"</span><span >, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > MaxEpochs=10, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > MiniBatchSize=miniBatchSize, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > InitialLearnRate=3e-4, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > Shuffle=</span><span style="color: rgb(167, 9, 245);">"every-epoch"</span><span >, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > ValidationData=augimdsValidation, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > ValidationFrequency=validationFrequency, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > Plots=</span><span style="color: rgb(167, 9, 245);">"training-progress"</span><span >, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > Verbose=false);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper outputs"><div class = 'S11'><span style="white-space: pre"><span >quantizationAwareTrainedNet = 
trainNetwork(augimdsTrain,quantizationAwareLayerGraph,options);</span></span></div><div class = 'S4'><div class="inlineElement eoOutputWrapper embeddedOutputsFigure" uid="4E2C58EC" prevent-scroll="true" data-testid="output_7" style="width: 938.026px;"><div class="figureElement eoOutputContent"><img class="figureImage figureContainingNode" src="" style="width: 987px; padding-bottom: 0px;"></div><div class="outputLayer selectedOutputDecorationLayer doNotExport"></div><div class="outputLayer activeOutputDecorationLayer doNotExport"></div><div class="outputLayer scrollableOutputDecorationLayer doNotExport"></div></div></div></div></div><h2 class = 'S2' id = 'H_5AA96E0C' ><span>Quantize the Network</span></h2><div class = 'S1'><span>The network returned by the </span><span style=' font-family: monospace;'>trainNetwork</span><span> function still has quantization aware training layers. The quantization aware training operators need to be replaced with operators that are specific to inference. Whereas training was performed using 32-bit floating-point values, the quantized network must perform inference using 8-bit integer inputs and weights. </span></div><div class = 'S1'><span></span></div><div class = 'S1'><span> </span><img class = "imageNode" src = "" width = "258" height = "401" alt = "quantized_training.png" style = "vertical-align: baseline; width: 258px; height: 401px;"></img><img class = "imageNode" src = "" width = "359" height = "400" alt = "quantized_inference.png" style = "vertical-align: baseline; width: 359px; height: 400px;"></img><span> </span></div><div class = 'S1'><span></span></div><div class = 'S1'><span>Remove the quantization aware layers and replace with underlying learned convolution layers using the </span><a href = "#H_D546E6DD"><span style=' font-family: monospace;'>removeQuantizationAwareLayers</span></a><span> function, provided at the end of this example. 
</span></div><div class="CodeBlock"><div class="inlineWrapper"><div class = 'S9'><span style="white-space: pre"><span >preQuantizedNetwork = removeQuantizationAwareLayers(quantizationAwareTrainedNet);</span></span></div></div></div><div class = 'S8'><span>Perform post-training quantization on the network as normal using the function </span><span style=' font-family: monospace;'>createQuantizedNetwork</span><span>, provided at the end of this example. </span></div><div class="CodeBlock"><div class="inlineWrapper"><div class = 'S10'><span style="white-space: pre"><span >quantizationAwareQuantizedNet = createQuantizedNetwork(preQuantizedNetwork,augimdsCalibration);</span></span></div></div><div class="inlineWrapper outputs"><div class = 'S11'><span style="white-space: pre"><span >quantizedNetworkDetails = quantizationDetails(quantizationAwareQuantizedNet)</span></span></div><div class = 'S4'><div class="inlineElement eoOutputWrapper embeddedOutputsVariableStringElement" uid="C2062F56" prevent-scroll="true" data-testid="output_8" style="width: 938.026px; white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"><div class="textElement eoOutputContent" data-width="908" data-height="79" data-hashorizontaloverflow="false" style="max-height: 261px; white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"><div style="white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"><span class="variableNameElement" style="white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;">quantizedNetworkDetails = <span class="headerElement" style="white-space: pre; font-style: italic; color: rgb(179, 179, 179); font-size: 12px;">struct with fields:</span></span></div><div style="white-space: pre; font-style: normal; color: rgb(33, 33, 33); font-size: 12px;"> IsQuantized: 1
TargetLibrary: "cudnn"
QuantizedLayerNames: [105×1 string]
QuantizedLearnables: [60×3 table]
</div></div></div></div></div></div><h2 class = 'S2' id = 'H_62BFFDFA' ><span>Evaluate the Quantized Network</span></h2><div class = 'S1'><span>Evaluate the performance of the quantized network.</span></div><div class="CodeBlock"><div class="inlineWrapper outputs"><div class = 'S3'><span style="white-space: pre"><span >quantizedNetworkCCR = evaluateModelAccuracy(quantizationAwareQuantizedNet,augimdsValidation,validationActualLabels)</span></span></div><div class = 'S4'><div class='variableElement' style='font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-size: 12px; '>quantizedNetworkCCR = 0.8937</div></div></div><div class="inlineWrapper"><div class = 'S5'><span style="white-space: pre"><span >bar( </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > categorical([</span><span style="color: rgb(167, 9, 245);">"Original network"</span><span >,</span><span style="color: rgb(167, 9, 245);">"Post-training quantized network"</span><span >,</span><span style="color: rgb(167, 9, 245);">"Quantization aware training network"</span><span >]), </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > [netCCR,originalQuantizedCCR,quantizedNetworkCCR] </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > )</span></span></div></div><div class="inlineWrapper outputs"><div class = 'S11'><span style="white-space: pre"><span >ylabel(</span><span style="color: rgb(167, 9, 245);">"Network Accuracy (%)"</span><span >)</span></span></div><div class = 'S4'><div class="inlineElement eoOutputWrapper embeddedOutputsFigure" uid="3B0056AE" prevent-scroll="true" data-testid="output_10" style="width: 938.026px;"><div class="figureElement eoOutputContent"><img 
class="figureImage figureContainingNode" src="" style="width: 616px; padding-bottom: 0px;"></div><div class="outputLayer selectedOutputDecorationLayer doNotExport"></div><div class="outputLayer activeOutputDecorationLayer doNotExport"></div><div class="outputLayer scrollableOutputDecorationLayer doNotExport"></div></div></div></div></div><div class = 'S8'><span>The accuracy for the quantized network after quantization aware training is on par with the accuracy of the original floating-point network.</span></div><h2 class = 'S2' id = 'H_5CF3214F' ><span>References</span></h2><ol class = 'S12' id = 'M_2AA33B76' ><li class = 'S13'><span>The TensorFlow Team. </span><span style=' font-style: italic;'>Flowers</span><span> </span><a href = "http://download.tensorflow.org/example_images/flower_photos.tgz"><span>http://download.tensorflow.org/example_images/flower_photos.tgz</span></a></li><li class = 'S13'><span>Gholami, A., Kim, S., Dong, Z., Mahoney, M., & Keutzer, K. (2021). A Survey of Quantization Methods for Efficient Neural Network Inference. Retrieved from </span><a href = "https://arxiv.org/abs/2103.13630"><span>https://arxiv.org/abs/2103.13630</span></a></li><li class = 'S13'><span>Jacob, B., Kligys, S., Chen, B., Zhu, M., Tang, M., Howard, A., Adam, H., & Kalenichenko, D. (2017). Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. Retrieved from </span><a href = "https://arxiv.org/abs/1712.05877"><span>https://arxiv.org/abs/1712.05877</span></a></li></ol><h2 class = 'S2' id = 'H_592CCEAB' ><span>Supporting Functions</span></h2><h3 class = 'S14' id = 'H_D08A0BC5' ><span>Download Flower Dataset</span></h3><div class = 'S1'><span>The </span><span style=' font-family: monospace;'>downloadFlowerDataset</span><span> function downloads and extracts the flowers dataset, if it is not yet in the current folder. 
</span></div><div class="CodeBlock"><div class="inlineWrapper"><div class = 'S10' id = 'M_67532D8E' ><span style="white-space: pre"><span style="color: rgb(14, 0, 255);">function </span><span >imageFolder = downloadFlowerDataset</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > url = </span><span style="color: rgb(167, 9, 245);">"http://download.tensorflow.org/example_images/flower_photos.tgz"</span><span >;</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > downloadFolder = pwd;</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > filename = fullfile(downloadFolder,</span><span style="color: rgb(167, 9, 245);">"flower_dataset.tgz"</span><span >);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > imageFolder = fullfile(downloadFolder,</span><span style="color: rgb(167, 9, 245);">"flower_photos"</span><span >);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(14, 0, 255);">if </span><span >~exist(imageFolder,</span><span style="color: rgb(167, 9, 245);">"dir"</span><span >)</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > disp(</span><span style="color: rgb(167, 9, 245);">"Downloading Flower Dataset (218 MB)..."</span><span >)</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > websave(filename,url);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > untar(filename,downloadFolder)</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span 
style="white-space: pre"><span > </span><span style="color: rgb(14, 0, 255);">end</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S7'><span style="white-space: pre"><span style="color: rgb(14, 0, 255);">end</span></span></div></div></div><h3 class = 'S14' id = 'H_C9FB9125' ><span>Perform Transfer Learning</span></h3><div class = 'S1'><span>The </span><span style=' font-family: monospace;'>createFlowerNetwork</span><span> function replaces the final fully connected and classification layers of the MobileNet-v2 network and retrains the network to classify flowers.</span></div><div class="CodeBlock"><div class="inlineWrapper"><div class = 'S10'><span style="white-space: pre"><span style="color: rgb(14, 0, 255);">function </span><span >transfer_net = createFlowerNetwork(net,augimdsTrain,augimdsValidation,classes)</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(0, 128, 19);">% Define network architecture.</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(0, 128, 19);">% Find and replace layers to perform transfer learning.</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > lgraph = layerGraph(net);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(0, 128, 19);">% Replace the learnable layer with a new one.</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > learnableLayer = lgraph.Layers(end-2);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > 
numClasses = numel(classes);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > newLearnableLayer = fullyConnectedLayer(numClasses, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > Name=</span><span style="color: rgb(167, 9, 245);">"new_fc"</span><span >, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > WeightLearnRateFactor=10, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > BiasLearnRateFactor=10);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(0, 128, 19);">% Replace the classification layer with a new one specific to the type</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(0, 128, 19);">% classes seen in the flowers dataset.</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > classLayer = lgraph.Layers(end);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > newClassLayer = classificationLayer(Name=</span><span style="color: rgb(167, 9, 245);">"new_classoutput"</span><span >);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > lgraph = 
replaceLayer(lgraph,classLayer.Name,newClassLayer);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(0, 128, 19);">% Specify training options.</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > miniBatchSize = 64;</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > validationFrequencyEpochs = 1;</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > numObservations = augimdsTrain.NumObservations;</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > numIterationsPerEpoch = floor(numObservations/miniBatchSize);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > validationFrequency = validationFrequencyEpochs * numIterationsPerEpoch;</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > options = trainingOptions(</span><span style="color: rgb(167, 9, 245);">"sgdm"</span><span >, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > MaxEpochs=5, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > MiniBatchSize=miniBatchSize, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > InitialLearnRate=3e-4, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span 
style="white-space: pre"><span > Shuffle=</span><span style="color: rgb(167, 9, 245);">"every-epoch"</span><span >, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > ValidationData=augimdsValidation, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > ValidationFrequency=validationFrequency, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > Plots=</span><span style="color: rgb(167, 9, 245);">"training-progress"</span><span >, </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > Verbose=false);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(0, 128, 19);">% Train the network.</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > transfer_net = trainNetwork(augimdsTrain,lgraph,options);</span></span></div></div><div class="inlineWrapper"><div class = 'S7'><span style="white-space: pre"><span style="color: rgb(14, 0, 255);">end</span></span></div></div></div><h3 class = 'S14' id = 'H_EC33AD21' ><span>Evaluate Model Accuracy</span></h3><div class = 'S1'><span>The </span><span style=' font-family: monospace;'>evaluateModelAccuracy</span><span> function compares the classify output of the network with the actual labels.</span></div><div class="CodeBlock"><div class="inlineWrapper"><div class = 'S10'><span style="white-space: pre"><span style="color: rgb(14, 0, 255);">function </span><span >ccr = 
evaluateModelAccuracy(net,valDS,labels)</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > ypred = classify(net,valDS);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > ccr = mean(ypred == labels);</span></span></div></div><div class="inlineWrapper"><div class = 'S7'><span style="white-space: pre"><span style="color: rgb(14, 0, 255);">end</span></span></div></div></div><h3 class = 'S14' id = 'H_138C14B1' ><span>Create Quantized Network</span></h3><div class = 'S1'><span>The </span><span style=' font-family: monospace;'>createQuantizedNetwork</span><span> constructs a </span><span style=' font-family: monospace;'>dlquantizer</span><span> object for </span><span style=' font-family: monospace;'>GPU</span><span> target, simulates and collects ranges of the network with a representative datastore using the </span><span style=' font-family: monospace;'>calibrate</span><span> function, then quantizes the network using the </span><span style=' font-family: monospace;'>quantize</span><span> function.</span></div><div class="CodeBlock"><div class="inlineWrapper"><div class = 'S10'><span style="white-space: pre"><span style="color: rgb(14, 0, 255);">function </span><span >qNet = createQuantizedNetwork(net,calDS)</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > dq = dlquantizer(net,ExecutionEnvironment=</span><span style="color: rgb(167, 9, 245);">"GPU"</span><span >);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > calResults = calibrate(dq,calDS);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > qNet = quantize(dq);</span></span></div></div><div class="inlineWrapper"><div class = 'S7'><span style="white-space: pre"><span style="color: rgb(14, 0, 
255);">end</span></span></div></div></div><h3 class = 'S14' id = 'H_157F752C' ><span>Make Quantization Aware LayerGraph</span></h3><div class = 'S1'><span>The </span><span style=' font-family: monospace;'>makeQuantizationAwareLayers</span><span> function takes a </span><span style=' font-family: monospace;'>DAGNetwork</span><span> object as input and replaces 2-D convolution, grouped 2-D convolution and batch normalization layers with quantization aware versions. The layer replacement works for this particular network where the layers are in topologically sorted order but may not work for other networks. </span></div><div class="CodeBlock"><div class="inlineWrapper"><div class = 'S10'><span style="white-space: pre"><span style="color: rgb(14, 0, 255);">function </span><span >lg = makeQuantizationAwareLayers(net)</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > lg = layerGraph(net);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(14, 0, 255);">for </span><span >idx = 1:numel(lg.Layers) - 1</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > currentLayer = lg.Layers(idx);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > nextLayer = lg.Layers(idx + 1);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(0, 128, 19);">% Find 2-D convolution layers or 2-D grouped convolution layers.</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(14, 0, 255);">if </span><span >(isa(currentLayer,</span><span style="color: rgb(167, 9, 
245);">"nnet.cnn.layer.Convolution2DLayer"</span><span >) </span><span style="color: rgb(14, 0, 255);">...</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > || isa(currentLayer,</span><span style="color: rgb(167, 9, 245);">"nnet.cnn.layer.GroupedConvolution2DLayer"</span><span >))</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(14, 0, 255);">if </span><span >isa(nextLayer,</span><span style="color: rgb(167, 9, 245);">"nnet.cnn.layer.BatchNormalizationLayer"</span><span >)</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(0, 128, 19);">% Replace convolution layer with quantization aware layer.</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > qLayer = QuantizedConvolutionBatchNormTrainingLayer(currentLayer,nextLayer);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > lg = replaceLayer(lg,currentLayer.Name,qLayer);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(0, 128, 19);">% Replace batchNormalizationLayer with identity training</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(0, 128, 19);">% layer.</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > qLayer = IdentityTrainingLayer(nextLayer);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > lg = 
replaceLayer(lg,nextLayer.Name,qLayer);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(14, 0, 255);">else</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(0, 128, 19);">% Replace convolution layer with quantization aware layer.</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > qLayer = QuantizedConvolutionTrainingLayer(currentLayer);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > lg = replaceLayer(lg,currentLayer.Name,qLayer);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(14, 0, 255);">end</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(14, 0, 255);">end</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(14, 0, 255);">end</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S7'><span style="white-space: pre"><span style="color: rgb(14, 0, 255);">end</span></span></div></div></div><h3 class = 'S14' id = 'H_D546E6DD' ><span>Remove Quantization Aware Layers from Network</span></h3><div class = 'S1'><span>The </span><span style=' font-family: monospace;'>removeQuantizationAwareLayers</span><span> function extracts the original layers from the quantization aware network and replaces the quantization aware layers with the original underlying layers.</span></div><div class="CodeBlock"><div class="inlineWrapper"><div 
class = 'S10'><span style="white-space: pre"><span style="color: rgb(14, 0, 255);">function </span><span >net = removeQuantizationAwareLayers(qatNet)</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > lg = layerGraph(qatNet);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(0, 128, 19);">% Find quantization aware training layers and replace with the</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(0, 128, 19);">% underlying layers.</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(14, 0, 255);">for </span><span >idx = 1:numel(lg.Layers)</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > currentLayer = lg.Layers(idx);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(14, 0, 255);">if </span><span >isa(currentLayer,</span><span style="color: rgb(167, 9, 245);">"QuantizedConvolutionBatchNormTrainingLayer"</span><span >)</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > cLayer = currentLayer.Network.Layers(1);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > lg = replaceLayer(lg,cLayer.Name,cLayer);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > bLayer = currentLayer.Network.Layers(2);</span></span></div></div><div class="inlineWrapper"><div class = 
'S6'><span style="white-space: pre"><span > lg = replaceLayer(lg,bLayer.Name,bLayer);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(14, 0, 255);">elseif </span><span >isa(currentLayer,</span><span style="color: rgb(167, 9, 245);">"QuantizedConvolutionTrainingLayer"</span><span >)</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > cLayer = currentLayer.Network.Layers(1);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > lg = replaceLayer(lg,cLayer.Name,cLayer);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(14, 0, 255);">end</span></span></div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > </span><span style="color: rgb(14, 0, 255);">end</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S6'><span style="white-space: pre"><span > net = assembleNetwork(lg);</span></span></div></div><div class="inlineWrapper"><div class = 'S6'> </div></div><div class="inlineWrapper"><div class = 'S7'><span style="white-space: pre"><span style="color: rgb(14, 0, 255);">end</span></span></div></div></div><div class = 'S8'><span style=' font-style: italic;'>Copyright 2023 The MathWorks, Inc.</span></div><h4 class = 'S15' id = 'H_5FC10B62' ></h4>
<br>
<!--
##### SOURCE BEGIN #####
%% Quantization Aware Training for Transfer Learned MobileNet-v2
% This example shows how to perform quantization aware training for transfer
% learned MobileNet-v2 network.
%
% Low precision types like |int8| propagate quantization error that may degrade
% the accuracy of the network. Quantization aware training is a method that introduces
% quantization error at training, thus giving the network the ability to adapt
% and ultimately produce a network more robust to quantization. In most cases,
% a quantized network or integer-arithmetic only network constructed after quantization
% aware training can produce accuracy on par with the original floating point
% network.
%
% This example takes you through the quantization workflow of a transfer learned
% MobileNet-v2 network. MobileNet-v2 was chosen for this example because it contains
% depthwise-separable convolution layers that are especially sensitive to post-training
% quantization.
%
% The flowchart below highlights the steps necessary to convert a trained network
% into a quantized one via quantization aware training.
%
%
%
%
%% Load Flower Dataset
% Download the flower dataset [1] using the supporting function |downloadFlowerDataset|.
imageFolder = downloadFlowerDataset();
imds = imageDatastore(imageFolder,"IncludeSubfolders",true,"LabelSource","foldernames"); % labels are taken from the folder names
%%
% Inspect the classes of the data.
classes = string(categories(imds.Labels))
%% Perform Transfer Learning on MobileNet-v2
% MobileNet-v2 is a convolutional neural network 53 layers deep. The pretrained
% version of the network is trained on more than a million images from the ImageNet
% database.
net = mobilenetv2();
%%
% Split the data into training and validation sets and create augmented image
% datastores that automatically resize the images to the input size of the network.
inputSize = net.Layers(1).InputSize; % InputSize of the first layer of the network
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9); % 90% of each label goes to training
validationActualLabels = imdsValidation.Labels;
augimdsTrain = augmentedImageDatastore(inputSize,imdsTrain);
augimdsValidation = augmentedImageDatastore(inputSize,imdsValidation);
%%
% Set aside a portion of the training dataset to use during the calibration
% step of quantization. This datastore should be representative of the data used
% for training but ideally separate from the one used to validate.
augimdsCalibration = subset(shuffle(augimdsTrain),1:320); % first 320 observations of the shuffled training datastore
%%
% Perform transfer learning on the network on the flowers image dataset. The
% learnable parameters of the trained network |transferNet| are in |single| precision.
transferNet = createFlowerNetwork(net,augimdsTrain,augimdsValidation,classes);
%% Evaluate Baseline Network Performance
% Evaluate the performance of the |single| precision network. Performance in
% this case is defined as the correct classification rate.
netCCR = evaluateModelAccuracy(transferNet,augimdsValidation,validationActualLabels) % fraction of validation labels predicted correctly; displayed because there is no semicolon
%%
% Quantize the network using the |createQuantizedNetwork| function provided
% at the end of this example and evaluate the performance of the quantized network.
% Post-training quantization of the original network yields poor performance due
% to the range of learnable values in the depthwise separable convolution layers.
% An accuracy of roughly 20% is equivalent to randomly guessing one of the 5 possible
% labels for each image.
originalQuantizedNet = createQuantizedNetwork(transferNet,augimdsCalibration);
originalQuantizedCCR = evaluateModelAccuracy(originalQuantizedNet,augimdsValidation,validationActualLabels)
% evaluateModelAccuracy returns a fraction between 0 and 1, so scale the
% values by 100 to match the percent units of the axis label.
bar( ...
    categorical(["Original network","Post-training quantized network"]), ...
    100*[netCCR,originalQuantizedCCR] ...
    )
ylabel("Network Accuracy (%)")
%% Replace Network Layers with Quantization Aware Training Layers
% Replace the |Convolution2D| and |GroupedConvolution2D| layers along with their
% adjacent |BatchNormalization| layers with custom layers that are quantization
% aware using the |makeQuantizationAwareLayers| function provided with this example.
% The quantization aware layers are <./QuantizedConvolutionBatchNormTrainingLayer.m
% custom layers> that have modified |forward| and |predict| behavior that inject
% quantization error similar to that of post-training quantization. The quantization
% error comes from the <./quantizeToFloat.m quantizeToFloat> function that quantizes,
% then unquantizes a given value using best-precision scaling to |int8| precision.
%
%
%
% Quantization to float can be expressed as follows.
%
% $$\begin{array}{l}\hat{x}=\textrm{quantizeToFloat}\left(x\right)\\\;\;\;=\textrm{unquantize}\left(\textrm{quantize}\left(x\right)\right)\\\;\;\;=\textrm{rescale}\cdot
% \textrm{saturate}\left(\textrm{round}\left(\frac{x}{\textrm{scale}}\right)\right)\end{array}$$
%
% The quantization step uses a non-differentiable operation |round| that would
% normally break the training workflow by zeroing out the gradients. During quantization
% aware training, bypass the gradient calculations for non-differentiable operations
% using an identity function. The diagram below [2] shows how the custom layer
% calculates the gradients for non-differentiable operations with the identity
% function via straight-through estimation.
%
%
%
%
%
%
%
% For 2-D convolution layers, the weights and biases of the replacement layers
% include the batch normalization layer statistics. Convolution operations during
% training use the adjusted and quantized weights [3].
%
% As the batch normalization layer statistics are incorporated into the convolution
% layers, the |makeQuantizationAwareLayers| function replaces each batch normalization
% layer with an identity layer that returns its input as its output.
quantizationAwareLayerGraph = makeQuantizationAwareLayers(transferNet); % layer graph of transferNet with the quantization aware replacement layers
%%
% Inspect the layers of the network.
quantizationAwareLayerGraph.Layers % no semicolon, so the layer array is displayed
%%
% To apply quantization aware training to a network that contains convolution
% layers without an adjacent batch normalization layer, use the <./QuantizedConvolutionTrainingLayer.m
% |QuantizedConvolutionTrainingLayer|> provided with this example instead of <./QuantizedConvolutionBatchNormTrainingLayer.m
% |QuantizedConvolutionBatchNormTrainingLayer|>.
%% Do Quantization Aware Training
% Using the layer graph with quantization aware training layers, train the network.
% Compared to the training of the original network, the training options increase
% |MaxEpochs| from 5 to 10, reduce |MiniBatchSize| from 64 to 32, and validate
% every 2 epochs instead of every epoch.
miniBatchSize = 32;
validationFrequencyEpochs = 2;
numObservations = augimdsTrain.NumObservations;
% Express the validation interval in iterations rather than in epochs.
numIterationsPerEpoch = floor(numObservations/miniBatchSize);
validationFrequency = validationFrequencyEpochs*numIterationsPerEpoch;
options = trainingOptions("sgdm", ...
    "MaxEpochs",10, ...
    "MiniBatchSize",miniBatchSize, ...
    "InitialLearnRate",3e-4, ...
    "Shuffle","every-epoch", ...
    "ValidationData",augimdsValidation, ...
    "ValidationFrequency",validationFrequency, ...
    "Plots","training-progress", ...
    "Verbose",false);
quantizationAwareTrainedNet = trainNetwork(augimdsTrain,quantizationAwareLayerGraph,options);
%% Quantize the Network
% The network returned by the |trainNetwork| function still has quantization
% aware training layers. The quantization aware training operators need to be
% replaced with operators that are specific to inference. Whereas training was
% performed using 32-bit floating-point values, the quantized network must perform
% inference using 8-bit integer inputs and weights.
%
%
%
%
%
%
%
% Remove the quantization aware layers and replace with underlying learned convolution
% layers using the |removeQuantizationAwareLayers| function, provided at the end
% of this example.
preQuantizedNetwork = removeQuantizationAwareLayers(quantizationAwareTrainedNet); % network with the quantization aware layers swapped back for their underlying layers
%%
% Perform post-training quantization on the network as normal using the function
% |createQuantizedNetwork|, provided at the end of this example.
quantizationAwareQuantizedNet = createQuantizedNetwork(preQuantizedNetwork,augimdsCalibration); % calibrates on the calibration datastore, then quantizes
quantizedNetworkDetails = quantizationDetails(quantizationAwareQuantizedNet) % no semicolon, so the details are displayed
%% Evaluate the Quantized Network
% Evaluate the performance of the quantized network.
quantizedNetworkCCR = evaluateModelAccuracy(quantizationAwareQuantizedNet,augimdsValidation,validationActualLabels)
% evaluateModelAccuracy returns a fraction between 0 and 1, so scale the
% values by 100 to match the percent units of the axis label.
bar( ...
    categorical(["Original network","Post-training quantized network","Quantization aware training network"]), ...
    100*[netCCR,originalQuantizedCCR,quantizedNetworkCCR] ...
    )
ylabel("Network Accuracy (%)")
%%
% The accuracy for the quantized network after quantization aware training is
% on par with the accuracy of that from the original floating point network.
%% References
%%
% # The TensorFlow Team. _Flowers_ <http://download.tensorflow.org/example_images/flower_photos.tgz
% http://download.tensorflow.org/example_images/flower_photos.tgz>
% # Gholami, A., Kim, S., Dong, Z., Mahoney, M., & Keutzer, K. (2021). A Survey
% of Quantization Methods for Efficient Neural Network Inference. Retrieved from
% <https://arxiv.org/abs/2103.13630 https://arxiv.org/abs/2103.13630>
% # Jacob, B., Kligys, S., Chen, B., Zhu, M., Tang, M., Howard, A., Adam, H.,
% & Kalenichenko, D. (2017). Quantization and Training of Neural Networks for
% Efficient Integer-Arithmetic-Only Inference. Retrieved from <https://arxiv.org/abs/1712.05877
% https://arxiv.org/abs/1712.05877>
%% Supporting Functions
% Download Flower Dataset
% The |downloadFlowerDataset| function downloads and extracts the flowers dataset,
% if it is not yet in the current folder.
function imageFolder = downloadFlowerDataset
% downloadFlowerDataset Download and extract the flowers dataset if needed.
%   imageFolder = downloadFlowerDataset returns the path of the
%   "flower_photos" folder inside the current folder. If that folder does
%   not exist, the archive is downloaded to "flower_dataset.tgz" in the
%   current folder and extracted there.
    url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
    downloadFolder = pwd;
    filename = fullfile(downloadFolder,"flower_dataset.tgz");
    imageFolder = fullfile(downloadFolder,"flower_photos");

    % isfolder is the dedicated folder-existence check and returns a
    % logical, unlike the numeric code returned by exist(...,"dir").
    if ~isfolder(imageFolder)
        disp("Downloading Flower Dataset (218 MB)...")
        websave(filename,url);
        untar(filename,downloadFolder)
    end
end
% Perform Transfer Learning
% The |createFlowerNetwork| function replaces the final fully connected and
% classification layer of the MobileNet-v2 network and retrains the network to
% classify flowers.
function transferNet = createFlowerNetwork(net,augimdsTrain,augimdsValidation,classes)
% createFlowerNetwork Retrain a pretrained network on a new set of classes.
%   transferNet = createFlowerNetwork(net,augimdsTrain,augimdsValidation,classes)
%   replaces the third-from-last layer of net with a new fully connected
%   layer that has numel(classes) outputs, replaces the last layer with a
%   new classification layer, and trains the result on augimdsTrain while
%   validating on augimdsValidation once per epoch.

    % Define network architecture.
    % Find and replace layers to perform transfer learning.
    lgraph = layerGraph(net);

    % Replace the learnable layer with a new one.
    % NOTE(review): assumes the layer at end-2 is the final fully connected
    % layer, as in MobileNet-v2; confirm before using with other networks.
    learnableLayer = lgraph.Layers(end-2);
    numClasses = numel(classes);
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        Name="new_fc", ...
        WeightLearnRateFactor=10, ...
        BiasLearnRateFactor=10);
    lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

    % Replace the classification layer with a new one so that it is not
    % tied to the classes the pretrained network was trained on.
    classLayer = lgraph.Layers(end);
    newClassLayer = classificationLayer(Name="new_classoutput");
    lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);

    % Specify training options.
    miniBatchSize = 64;
    validationFrequencyEpochs = 1;
    numObservations = augimdsTrain.NumObservations;
    % Clamp to at least one iteration per epoch: a datastore smaller than
    % one mini-batch would otherwise give a validation frequency of 0,
    % which is not a valid ValidationFrequency value.
    numIterationsPerEpoch = max(1,floor(numObservations/miniBatchSize));
    validationFrequency = validationFrequencyEpochs * numIterationsPerEpoch;

    options = trainingOptions("sgdm", ...
        MaxEpochs=5, ...
        MiniBatchSize=miniBatchSize, ...
        InitialLearnRate=3e-4, ...
        Shuffle="every-epoch", ...
        ValidationData=augimdsValidation, ...
        ValidationFrequency=validationFrequency, ...
        Plots="training-progress", ...
        Verbose=false);

    % Train the network.
    transferNet = trainNetwork(augimdsTrain,lgraph,options);
end
% Evaluate Model Accuracy
% The |evaluateModelAccuracy| function compares the classify output of the network
% with the actual labels.
function ccr = evaluateModelAccuracy(net,valDS,labels)
% evaluateModelAccuracy Correct classification rate of a network.
%   ccr = evaluateModelAccuracy(net,valDS,labels) classifies valDS with net
%   and returns the mean of the element-wise comparison between the
%   predicted labels and labels.
    predictedLabels = classify(net,valDS);
    isCorrect = (predictedLabels == labels);
    ccr = mean(isCorrect);
end
% Create Quantized Network
% The |createQuantizedNetwork| function constructs a |dlquantizer| object for |GPU| target,
% simulates and collects ranges of the network with a representative datastore
% using the |calibrate| function, then quantizes the network using the |quantize|
% function.
function qNet = createQuantizedNetwork(net,calDS)
% createQuantizedNetwork Calibrate and quantize a network for the GPU target.
%   qNet = createQuantizedNetwork(net,calDS) creates a dlquantizer object
%   for net with the "GPU" execution environment, calibrates it with the
%   datastore calDS, and returns the network produced by quantize.
    quantizer = dlquantizer(net,"ExecutionEnvironment","GPU");
    % The value returned by calibrate is not used here; quantize is called
    % on the dlquantizer object itself.
    [~] = calibrate(quantizer,calDS);
    qNet = quantize(quantizer);
end
% Make Quantization Aware LayerGraph
% The |makeQuantizationAwareLayers| function takes a |DAGNetwork| object as
% input and replaces 2-D convolution, grouped 2-D convolution and batch normalization
% layers with quantization aware versions. The layer replacement works for this
% particular network where the layers are in topologically sorted order but may
% not work for other networks.
function lg = makeQuantizationAwareLayers(net)
% makeQuantizationAwareLayers Swap convolution layers for quantization aware layers.
%   lg = makeQuantizationAwareLayers(net) returns a layer graph of net in
%   which each 2-D convolution or grouped 2-D convolution layer is replaced
%   by a quantization aware training layer. When the next entry in
%   lg.Layers is a batch normalization layer that is the sole consumer of
%   the convolution output, the pair is replaced by a
%   QuantizedConvolutionBatchNormTrainingLayer followed by an
%   IdentityTrainingLayer. Otherwise only the convolution layer is
%   replaced, by a QuantizedConvolutionTrainingLayer.
    lg = layerGraph(net);

    for idx = 1:numel(lg.Layers) - 1
        currentLayer = lg.Layers(idx);
        nextLayer = lg.Layers(idx + 1);

        % Find 2-D convolution layers or 2-D grouped convolution layers.
        if (isa(currentLayer,"nnet.cnn.layer.Convolution2DLayer") ...
                || isa(currentLayer,"nnet.cnn.layer.GroupedConvolution2DLayer"))

            % Being adjacent in lg.Layers does not by itself mean that the
            % two layers are connected, so confirm in lg.Connections that
            % the batch normalization layer is the only layer fed by this
            % convolution before folding the two together.
            foldBatchNorm = isa(nextLayer,"nnet.cnn.layer.BatchNormalizationLayer") ...
                && iIsSoleConsumer(lg,currentLayer.Name,nextLayer.Name);

            if foldBatchNorm
                % Replace convolution layer with quantization aware layer.
                qLayer = QuantizedConvolutionBatchNormTrainingLayer(currentLayer,nextLayer);
                lg = replaceLayer(lg,currentLayer.Name,qLayer);

                % Replace batchNormalizationLayer with identity training
                % layer.
                qLayer = IdentityTrainingLayer(nextLayer);
                lg = replaceLayer(lg,nextLayer.Name,qLayer);
            else
                % Replace convolution layer with quantization aware layer.
                qLayer = QuantizedConvolutionTrainingLayer(currentLayer);
                lg = replaceLayer(lg,currentLayer.Name,qLayer);
            end
        end
    end
end

function tf = iIsSoleConsumer(lg,sourceName,destinationName)
% iIsSoleConsumer True when sourceName has exactly one outgoing connection
% in lg and that connection ends at destinationName.
    sources = string(lg.Connections.Source);
    destinations = string(lg.Connections.Destination);
    sourceName = string(sourceName);
    destinationName = string(destinationName);

    % Connection endpoints are either "layerName" or "layerName/IOName".
    leavesSource = sources == sourceName | startsWith(sources,sourceName + "/");
    entersDestination = destinations == destinationName ...
        | startsWith(destinations,destinationName + "/");

    tf = nnz(leavesSource) == 1 && all(entersDestination(leavesSource));
end
% Remove Quantization Aware Layers from Network
% The |removeQuantizationAwareLayers| function extracts the original layers
% from the quantization aware network and replaces the quantization aware layers
% with the original underlying layers.
function net = removeQuantizationAwareLayers(qatNet)
% removeQuantizationAwareLayers Restore the layers wrapped by quantization aware layers.
%   net = removeQuantizationAwareLayers(qatNet) replaces every
%   QuantizedConvolutionBatchNormTrainingLayer and
%   QuantizedConvolutionTrainingLayer in qatNet with the layers stored in
%   its Network property, then assembles the resulting layer graph.
    lg = layerGraph(qatNet);

    for layerIdx = 1:numel(lg.Layers)
        candidate = lg.Layers(layerIdx);

        wrapsConvAndBatchNorm = isa(candidate,"QuantizedConvolutionBatchNormTrainingLayer");
        wrapsConvOnly = ~wrapsConvAndBatchNorm ...
            && isa(candidate,"QuantizedConvolutionTrainingLayer");
        if ~(wrapsConvAndBatchNorm || wrapsConvOnly)
            continue
        end

        % Both kinds of quantization aware layer keep the convolution layer
        % as the first layer of their Network property.
        wrappedLayers = candidate.Network.Layers;
        convLayer = wrappedLayers(1);
        lg = replaceLayer(lg,convLayer.Name,convLayer);

        if wrapsConvAndBatchNorm
            % The second wrapped layer is put back under its own name.
            batchNormLayer = wrappedLayers(2);
            lg = replaceLayer(lg,batchNormLayer.Name,batchNormLayer);
        end
    end

    net = assembleNetwork(lg);
end
%%
% _Copyright 2023 The MathWorks, Inc._
%
##### SOURCE END #####
-->
</div></body></html>