From 2fb5d641d852ff40df2dd85f461115ab9583d851 Mon Sep 17 00:00:00 2001
From: Taehee Jeong
Date: Tue, 26 Jan 2021 01:28:45 -0800
Subject: [PATCH] QuantizationDebugger with per-layer debug metrics

The debugger accepts a quantized debug model, which contains NumericVerify
ops. The debugger runs inference on a given dataset and collects per-layer
debug metrics.

PiperOrigin-RevId: 353820318
Change-Id: I3d685bad9eff3e0ba762ea1df939c0a5102ba618
---
 .../experimental/quantization_debugger/BUILD |  37 ++++
 .../quantization_debugger/debugger.py        | 186 ++++++++++++++++++
 .../quantization_debugger/debugger_test.py   | 131 ++++++++++++
 3 files changed, 354 insertions(+)
 create mode 100644 tensorflow/lite/experimental/quantization_debugger/BUILD
 create mode 100644 tensorflow/lite/experimental/quantization_debugger/debugger.py
 create mode 100644 tensorflow/lite/experimental/quantization_debugger/debugger_test.py

diff --git a/tensorflow/lite/experimental/quantization_debugger/BUILD b/tensorflow/lite/experimental/quantization_debugger/BUILD
new file mode 100644
index 00000000000000..3deb89fa58fedb
--- /dev/null
+++ b/tensorflow/lite/experimental/quantization_debugger/BUILD
@@ -0,0 +1,37 @@
+# QuantizationDebugger for TFLite accuracy tooling.
+load("//tensorflow:tensorflow.bzl", "pytype_strict_library")
+load("//tensorflow:tensorflow.bzl", "py_strict_test")
+
+package(
+    default_visibility = ["//tensorflow:internal"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+pytype_strict_library(
+    name = "debugger",
+    srcs = ["debugger.py"],
+    srcs_version = "PY3",
+    deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python/util:tf_export",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_strict_test(
+    name = "debugger_test",
+    srcs = [
+        "debugger_test.py",
+    ],
+    python_version = "PY3",
+    deps = [
+        ":debugger",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/lite/python:convert",
+        "//tensorflow/lite/python:lite",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python/platform:client_testlib",
+        "//tensorflow/python/training/tracking",
+        "//third_party/py/numpy",
+    ],
+)
diff --git a/tensorflow/lite/experimental/quantization_debugger/debugger.py b/tensorflow/lite/experimental/quantization_debugger/debugger.py
new file mode 100644
index 00000000000000..2d0793d2940900
--- /dev/null
+++ b/tensorflow/lite/experimental/quantization_debugger/debugger.py
@@ -0,0 +1,186 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python TF-Lite QuantizationDebugger."""
+import collections
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.util import tf_export
+
+# Default metrics, computed on the difference of values between quantized and
+# float ops.
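+# Each metric takes the output of one NumericVerify op (an array of
+# elementwise differences between the float and dequantized results of the
+# corresponding quantized op) and reduces it to a single scalar per example.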
+_DEFAULT_LAYER_DEBUG_METRICS = {
+    'num_elements': lambda diffs: diffs.size,
+    'stddev': np.std,
+    'mean_error': np.average,
+    'max_abs_error': lambda diffs: np.max(np.abs(diffs)),
+    'mean_square_error': lambda diffs: np.average(diffs**2),
+}
+
+
+@tf_export.tf_export(v1=['lite.experimental.QuantizationDebugOptions'])
+class QuantizationDebugOptions:
+  """Debug options to set up a given QuantizationDebugger."""
+
+  def __init__(
+      self,
+      layer_debug_metrics: Optional[Mapping[str, Callable[[np.ndarray],
+                                                          float]]] = None
+  ) -> None:
+    """Initializes debugger options.
+
+    Args:
+      layer_debug_metrics: a dict to specify layer debug functions
+        {function_name_str: function} where the function accepts the result of
+        a NumericVerify op, which is the value difference between float and
+        dequantized op results. The function returns a single scalar value.
+    """
+    self.layer_debug_metrics = layer_debug_metrics
+
+
+@tf_export.tf_export(v1=['lite.experimental.QuantizationDebugger'])
+class QuantizationDebugger:
+  """Debugger for quantized TensorFlow Lite debug-mode models.
+
+  This can run TensorFlow Lite converted models equipped with debug ops and
+  collect debug information. This debugger calculates statistics from
+  user-defined post-processing functions as well as default ones.
+  """
+
+  def __init__(
+      self,
+      quant_debug_model_path: Optional[str] = None,
+      quant_debug_model_content: Optional[bytes] = None,
+      debug_dataset: Optional[Callable[[],
+                                       Iterable[Sequence[np.ndarray]]]] = None,
+      debug_options: Optional[QuantizationDebugOptions] = None) -> None:
+    """Initializes the debugger for a TFLite debug model with given options.
+
+    Args:
+      quant_debug_model_path: Path to the debug-mode TF-Lite Flatbuffer file.
+      quant_debug_model_content: Content of the quantized debug model.
+      debug_dataset: a factory function that returns a dataset generator which
+        is used to generate input samples (list of np.ndarray) for the model.
+        The generated elements must have the same types and shapes as the
+        inputs to the model.
+      debug_options: Debug options to debug the given model.
+
+    Raises:
+      ValueError: If the debugger could not be created.
+
+    Attributes:
+      layer_statistics: error metrics for each NumericVerify op result, in
+        {layer_name: {metric_name: metric}} format.
+    """
+    self._data_gen = debug_dataset
+    self._debug_options = debug_options or QuantizationDebugOptions()
+
+    input_data = next(iter(self._data_gen()))
+    self._quant_interpreter = tf.lite.Interpreter(quant_debug_model_path,
+                                                  quant_debug_model_content)
+
+    self._numeric_verify_tensor_details = None
+    if not self._get_numeric_verify_tensor_details():
+      raise ValueError('Please check if the quantized model is in debug mode')
+
+    self._layer_debug_metrics = _DEFAULT_LAYER_DEBUG_METRICS.copy()
+    if self._debug_options.layer_debug_metrics:
+      self._layer_debug_metrics.update(self._debug_options.layer_debug_metrics)
+
+    self.layer_statistics = None
+
+  def run(self) -> None:
+    """Runs the model and collects metrics."""
+    self.layer_statistics = self._collect_layer_statistics()
+
+  def _collect_layer_statistics(self) -> Dict[str, Dict[str, float]]:
+    """Collects layer statistics by applying layer debug metrics.
+
+    For all data from the given RepresentativeDataset, collects statistics per
+    example by reading the NumericVerify op results from _quant_interpreter
+    and applying the layer debug metrics to them.
+
+    Returns:
+      Aggregated per-layer statistics of NumericVerify results, in
+      {layer_name: {metric_name: metric}} format.
+    """
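+    # Nested map of {layer_name: {metric_name: [per-example values]}}; the
+    # per-example lists are reduced to scalars with np.mean once every
+    # example in the dataset has been processed.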
+    layer_statistics = collections.defaultdict(
+        lambda: collections.defaultdict(list))
+
+    initialize = True
+    for tensor_data in self._data_gen():
+      self._set_input_tensors(self._quant_interpreter, tensor_data, initialize)
+      initialize = False
+
+      # Run the model.
+      self._quant_interpreter.invoke()
+
+      # Collect the statistics of this invoke result.
+      for tensor_details in self._get_numeric_verify_tensor_details():
+        tensor_name = tensor_details['name']
+        diffs = self._quant_interpreter.get_tensor(tensor_details['index'])
+        for metric_name, metric_fn in self._layer_debug_metrics.items():
+          layer_statistics[tensor_name][metric_name].append(metric_fn(diffs))
+
+    # Calculate final aggregated metrics for each layer.
+    for metrics in layer_statistics.values():
+      for metric_name in metrics:
+        metrics[metric_name] = np.mean(metrics[metric_name])
+
+    return layer_statistics
+
+  def _set_input_tensors(
+      self,
+      interpreter: tf.lite.Interpreter,
+      tensor_data: Sequence[np.ndarray],
+      initialize: bool,
+  ) -> None:
+    """Sets input tensors into the TFLite model Interpreter.
+
+    Args:
+      interpreter: a tf.lite.Interpreter object with allocated tensors.
+      tensor_data: a list of Numpy array data.
+      initialize: set to true when inputs are set for the interpreter for the
+        first time, to set input shapes and allocate tensors.
+
+    Raises:
+      ValueError: when inputs can't be set, or the number of provided inputs
+        does not match the number of model inputs.
+    """
+    input_indices = [
+        detail['index'] for detail in interpreter.get_input_details()
+    ]
+    if len(input_indices) != len(tensor_data):
+      raise ValueError(
+          'Number of inputs provided ({}) does not match number of inputs to '
+          'the model ({})'.format(len(tensor_data), len(input_indices)))
+
+    if initialize:
+      for input_idx, tensor in zip(input_indices, tensor_data):
+        interpreter.resize_tensor_input(input_idx, tensor.shape)
+      interpreter.allocate_tensors()
+
+    for input_idx, tensor in zip(input_indices, tensor_data):
+      interpreter.set_tensor(input_idx, tensor)
+
+  def _get_numeric_verify_tensor_details(self) -> List[Dict[str, Any]]:
+    """Returns tensor details of all NumericVerify ops, cached on first call."""
+    if not self._numeric_verify_tensor_details:
+      self._numeric_verify_tensor_details = [
+          detail for detail in self._quant_interpreter.get_tensor_details()
+          if detail['name'].startswith('NumericVerify')
+      ]
+    return self._numeric_verify_tensor_details
diff --git a/tensorflow/lite/experimental/quantization_debugger/debugger_test.py b/tensorflow/lite/experimental/quantization_debugger/debugger_test.py
new file mode 100644
index 00000000000000..2b61db3e5ba74c
--- /dev/null
+++ b/tensorflow/lite/experimental/quantization_debugger/debugger_test.py
@@ -0,0 +1,131 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for QuantizationDebugger."""
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.lite.experimental.quantization_debugger import debugger
+from tensorflow.lite.python import convert
+from tensorflow.lite.python import lite
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+from tensorflow.python.training.tracking import tracking
+
+
+def _get_model():
+  """Returns a simple model with a single Conv2D layer."""
+  root = tracking.AutoTrackable()
+  kernel_in = np.array([-2, -1, 1, 2], dtype=np.float32).reshape((2, 2, 1, 1))
+
+  @tf.function(
+      input_signature=[tf.TensorSpec(shape=[1, 3, 3, 1], dtype=tf.float32)])
+  def func(inp):
+    kernel = tf.constant(kernel_in, dtype=tf.float32)
+    conv = tf.nn.conv2d(inp, kernel, strides=1, padding='SAME')
+    output = tf.nn.relu(conv, name='output')
+    return output
+
+  root.f = func
+  to_save = root.f.get_concrete_function()
+  return to_save
+
+
+def _calibration_gen():
+  for i in range(5):
+    yield [np.arange(9).reshape((1, 3, 3, 1)).astype(np.float32) * i]
+
+
+def _quantize_model(func, calibration_gen, debug=True):
+  """Quantizes the model in debug or normal mode."""
+  converter = lite.TFLiteConverterV2.from_concrete_functions([func])
+  converter.target_spec.supported_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
+  converter.representative_dataset = calibration_gen
+
+  # Create a TFLite model with the new quantizer and numeric verify ops.
+  converter.optimizations = [lite.Optimize.DEFAULT]
+  converter.experimental_new_quantizer = True
+  if debug:
+    converter._experimental_calibrate_only = True
+    calibrated = converter.convert()
+    return convert.mlir_quantize(calibrated, enable_numeric_verify=True)
+  else:
+    return converter.convert()
+
+
+class QuantizationDebuggerTest(test_util.TensorFlowTestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+    cls.float_model = _get_model()
+    cls.debug_model = _quantize_model(cls.float_model, _calibration_gen)
+
+  @test_util.run_v2_only
+  def test_quantization_debugger(self):
+    options = debugger.QuantizationDebugOptions(
+        layer_debug_metrics={'l1_norm': lambda diffs: np.mean(np.abs(diffs))})
+    quant_debugger = debugger.QuantizationDebugger(
+        quant_debug_model_content=QuantizationDebuggerTest.debug_model,
+        debug_dataset=_calibration_gen,
+        debug_options=options)
+    quant_debugger.run()
+
+    expected_metrics = {
+        'num_elements': 9,
+        'stddev': 0.03850026,
+        'mean_error': 0.01673192,
+        'max_abs_error': 0.10039272,
+        'mean_square_error': 0.0027558778,
+        'l1_norm': 0.023704167,
+    }
+    self.assertLen(quant_debugger.layer_statistics, 1)
+    actual_metrics = next(iter(quant_debugger.layer_statistics.values()))
+
+    self.assertCountEqual(expected_metrics.keys(), actual_metrics.keys())
+    for key, value in expected_metrics.items():
+      self.assertAlmostEqual(value, actual_metrics[key])
+
+  @test_util.run_v2_only
+  def test_quantization_debugger_wrong_input_raises_ValueError(self):
+
+    def wrong_calibration_gen():
+      for _ in range(5):
+        yield [
+            np.ones((1, 3, 3, 1), dtype=np.float32),
+            np.ones((1, 3, 3, 1), dtype=np.float32)
+        ]
+
+    quant_debugger = debugger.QuantizationDebugger(
+        quant_debug_model_content=QuantizationDebuggerTest.debug_model,
+        debug_dataset=wrong_calibration_gen)
+    with self.assertRaisesRegex(
+        ValueError, r'inputs provided \(2\).+inputs to the model \(1\)'):
+      quant_debugger.run()

+  @test_util.run_v2_only
+  def test_quantization_debugger_non_debug_model_raises_ValueError(self):
+    normal_quant_model = _quantize_model(
+        QuantizationDebuggerTest.float_model, _calibration_gen, debug=False)
+
+    with self.assertRaisesRegex(
+        ValueError, 'Please check if the quantized model is in debug mode'):
+      debugger.QuantizationDebugger(
+          quant_debug_model_content=normal_quant_model,
+          debug_dataset=_calibration_gen)
+
+
+if __name__ == '__main__':
+  test.main()