Commit 2fb5d64
QuantizationDebugger with per-layer debug metrics
The debugger accepts a quantized debug model that contains NumericVerify ops. It runs inference with a given dataset and collects per-layer debug metrics.
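For illustration, a minimal usage sketch of the API added here (assumes `debug_model` is a quantized TFLite flatbuffer containing NumericVerify ops and `representative_gen` is a generator factory yielding lists of np.ndarray inputs; both names are hypothetical):

from tensorflow.lite.experimental.quantization_debugger import debugger

quant_debugger = debugger.QuantizationDebugger(
    quant_debug_model_content=debug_model,  # hypothetical flatbuffer bytes
    debug_dataset=representative_gen)       # hypothetical input generator
quant_debugger.run()

# layer_statistics maps each NumericVerify tensor name to aggregated metrics.
for layer_name, metrics in quant_debugger.layer_statistics.items():
  print(layer_name, metrics['mean_square_error'])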

PiperOrigin-RevId: 353820318
Change-Id: I3d685bad9eff3e0ba762ea1df939c0a5102ba618
teijeong authored and tensorflower-gardener committed Jan 26, 2021
1 parent 0b2207b commit 2fb5d64
Showing 3 changed files with 354 additions and 0 deletions.
37 changes: 37 additions & 0 deletions tensorflow/lite/experimental/quantization_debugger/BUILD
@@ -0,0 +1,37 @@
# QuantizationDebugger for TFLite accuracy tooling.
load("//tensorflow:tensorflow.bzl", "pytype_strict_library")
load("//tensorflow:tensorflow.bzl", "py_strict_test")

package(
    default_visibility = ["//tensorflow:internal"],
    licenses = ["notice"],  # Apache 2.0
)

pytype_strict_library(
    name = "debugger",
    srcs = ["debugger.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python/util:tf_export",
        "//third_party/py/numpy",
    ],
)

py_strict_test(
    name = "debugger_test",
    srcs = [
        "debugger_test.py",
    ],
    python_version = "PY3",
    deps = [
        ":debugger",
        "//tensorflow:tensorflow_py",
        "//tensorflow/lite/python:convert",
        "//tensorflow/lite/python:lite",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python/platform:client_testlib",
        "//tensorflow/python/training/tracking",
        "//third_party/py/numpy",
    ],
)
186 changes: 186 additions & 0 deletions tensorflow/lite/experimental/quantization_debugger/debugger.py
@@ -0,0 +1,186 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python TF-Lite QuantizationDebugger."""
import collections
from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
                    Sequence)

import numpy as np
import tensorflow as tf

from tensorflow.python.util import tf_export

# Returns metrics based on difference of values for quantized/float ops.
_DEFAULT_LAYER_DEBUG_METRICS = {
    'num_elements': lambda diffs: diffs.size,
    'stddev': np.std,
    'mean_error': np.average,
    'max_abs_error': lambda diffs: np.max(np.abs(diffs)),
    'mean_square_error': lambda diffs: np.average(diffs**2),
}
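# Example of a custom metric with the same signature (drawn from the test in
# this commit); it is supplied through QuantizationDebugOptions:
#   QuantizationDebugOptions(
#       layer_debug_metrics={'l1_norm': lambda diffs: np.mean(np.abs(diffs))})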


@tf_export.tf_export(v1=['lite.experimental.QuantizationDebugOptions'])
class QuantizationDebugOptions:
  """Debug options to set up a given QuantizationDebugger."""

  def __init__(
      self,
      layer_debug_metrics: Optional[Mapping[str, Callable[[np.ndarray],
                                                          float]]] = None
  ) -> None:
    """Initializes debugger options.

    Args:
      layer_debug_metrics: a dict to specify layer debug functions
        {function_name_str: function} where the function accepts the result of
        a NumericVerify op, which is the value difference between float and
        dequantized op results. The function returns a single scalar value.
    """
    self.layer_debug_metrics = layer_debug_metrics


@tf_export.tf_export(v1=['lite.experimental.QuantizationDebugger'])
class QuantizationDebugger:
  """Debugger for quantized TensorFlow Lite debug-mode models.

  This can run TensorFlow Lite converted models equipped with debug ops and
  collect debug information. This debugger calculates statistics from
  user-defined post-processing functions as well as default ones.
  """

  def __init__(
      self,
      quant_debug_model_path: Optional[str] = None,
      quant_debug_model_content: Optional[bytes] = None,
      debug_dataset: Optional[Callable[[],
                                       Iterable[Sequence[np.ndarray]]]] = None,
      debug_options: Optional[QuantizationDebugOptions] = None) -> None:
    """Runs the TFLite debugging model with the given debug options.

    Args:
      quant_debug_model_path: Path to the debug-mode TFLite flatbuffer file.
      quant_debug_model_content: Content of the quantized debug model.
      debug_dataset: a factory function that returns a dataset generator which
        is used to generate input samples (list of np.ndarray) for the model.
        The generated elements must have the same types and shapes as the
        inputs to the model.
      debug_options: Debug options to debug the given model.

    Raises:
      ValueError: If the debugger could not be created.

    Attributes:
      layer_statistics: results of error metrics for each NumericVerify op,
        in {layer_name: {metric_name: metric}} format.
    """
    self._data_gen = debug_dataset
    self._debug_options = debug_options or QuantizationDebugOptions()

    input_data = next(iter(self._data_gen()))
    self._quant_interpreter = tf.lite.Interpreter(quant_debug_model_path,
                                                  quant_debug_model_content)

    self._numeric_verify_tensor_details = None
    if not self._get_numeric_verify_tensor_details():
      raise ValueError('Please check if the quantized model is in debug mode')

    self._layer_debug_metrics = _DEFAULT_LAYER_DEBUG_METRICS.copy()
    if self._debug_options.layer_debug_metrics:
      self._layer_debug_metrics.update(self._debug_options.layer_debug_metrics)

    self.layer_statistics = None

  def run(self) -> None:
    """Runs models and gets metrics."""
    self.layer_statistics = self._collect_layer_statistics()

  def _collect_layer_statistics(self) -> Dict[str, Dict[str, float]]:
    """Collects layer statistics by applying layer debug metrics.

    For all data from the given RepresentativeDataset, collect statistics per
    example by getting the NumericVerify op results in _quant_interpreter
    and calculating layer debug metrics on the results.

    Returns:
      aggregated per-layer statistics of NumericVerify results.
      {layer_name: {metric_name: metric}}
    """
    layer_statistics = collections.defaultdict(
        lambda: collections.defaultdict(list))

    initialize = True
    for tensor_data in self._data_gen():
      self._set_input_tensors(self._quant_interpreter, tensor_data, initialize)
      initialize = False

      # Run the model.
      self._quant_interpreter.invoke()

      # Collect the statistics of this invoke result.
      for tensor_details in self._get_numeric_verify_tensor_details():
        tensor_name = tensor_details['name']
        diffs = self._quant_interpreter.get_tensor(tensor_details['index'])
        for metric_name, metric_fn in self._layer_debug_metrics.items():
          layer_statistics[tensor_name][metric_name].append(metric_fn(diffs))

    # Calculate final aggregated metrics for each layer.
    for metrics in layer_statistics.values():
      for metric_name in metrics:
        metrics[metric_name] = np.mean(metrics[metric_name])

    return layer_statistics

  def _set_input_tensors(
      self,
      interpreter: tf.lite.Interpreter,
      tensor_data: Sequence[np.ndarray],
      initialize: bool,
  ) -> None:
    """Sets input tensors into the TFLite model Interpreter.

    Args:
      interpreter: a tf.lite.Interpreter object with allocated tensors.
      tensor_data: a list of Numpy array data.
      initialize: set to true when inputs are set for the interpreter for the
        first time, to set input shapes and allocate tensors.

    Raises:
      ValueError: when inputs can't be set, or the number of provided inputs
        does not match the number of model inputs.
    """
    input_indices = [
        detail['index'] for detail in interpreter.get_input_details()
    ]
    if len(input_indices) != len(tensor_data):
      raise ValueError(
          'Number of inputs provided ({}) does not match number of inputs to '
          'the model ({})'.format(len(tensor_data), len(input_indices)))

    if initialize:
      for input_idx, tensor in zip(input_indices, tensor_data):
        interpreter.resize_tensor_input(input_idx, tensor.shape)
      interpreter.allocate_tensors()

    for input_idx, tensor in zip(input_indices, tensor_data):
      interpreter.set_tensor(input_idx, tensor)

  def _get_numeric_verify_tensor_details(self) -> List[Dict[str, Any]]:
    """Returns tensor details for all tensors produced by NumericVerify ops."""
    if not self._numeric_verify_tensor_details:
      self._numeric_verify_tensor_details = [
          detail for detail in self._quant_interpreter.get_tensor_details()
          if detail['name'].startswith('NumericVerify')
      ]
    return self._numeric_verify_tensor_details
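
For context, a sketch of the raw tf.lite.Interpreter calls that _set_input_tensors and _collect_layer_statistics wrap (`model_content` and `input_array` are hypothetical placeholders for flatbuffer bytes and a np.ndarray input):

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_content=model_content)
input_index = interpreter.get_input_details()[0]['index']
interpreter.resize_tensor_input(input_index, input_array.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_index, input_array)
interpreter.invoke()
for detail in interpreter.get_tensor_details():
  if detail['name'].startswith('NumericVerify'):
    # Elementwise differences between float and dequantized results.
    diffs = interpreter.get_tensor(detail['index'])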
131 changes: 131 additions & 0 deletions tensorflow/lite/experimental/quantization_debugger/debugger_test.py
@@ -0,0 +1,131 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for QuantizationDebugger."""

import numpy as np
import tensorflow as tf

from tensorflow.lite.experimental.quantization_debugger import debugger
from tensorflow.lite.python import convert
from tensorflow.lite.python import lite
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.training.tracking import tracking


def _get_model():
  """Returns a simple model with Conv2D and ReLU ops."""
  root = tracking.AutoTrackable()
  kernel_in = np.array([-2, -1, 1, 2], dtype=np.float32).reshape((2, 2, 1, 1))

  @tf.function(
      input_signature=[tf.TensorSpec(shape=[1, 3, 3, 1], dtype=tf.float32)])
  def func(inp):
    kernel = tf.constant(kernel_in, dtype=tf.float32)
    conv = tf.nn.conv2d(inp, kernel, strides=1, padding='SAME')
    output = tf.nn.relu(conv, name='output')
    return output

  root.f = func
  to_save = root.f.get_concrete_function()
  return to_save


def _calibration_gen():
  for i in range(5):
    yield [np.arange(9).reshape((1, 3, 3, 1)).astype(np.float32) * i]


def _quantize_model(func, calibration_gen, debug=True):
  """Quantizes model, in debug or normal mode."""
  converter = lite.TFLiteConverterV2.from_concrete_functions([func])
  converter.target_spec.supported_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
  converter.representative_dataset = calibration_gen

  # Create a TFLite model with the new quantizer and NumericVerify ops.
  converter.optimizations = [lite.Optimize.DEFAULT]
  converter.experimental_new_quantizer = True
  if debug:
    converter._experimental_calibrate_only = True
    calibrated = converter.convert()
    return convert.mlir_quantize(calibrated, enable_numeric_verify=True)
  else:
    return converter.convert()


class QuantizationDebuggerTest(test_util.TensorFlowTestCase):

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    cls.float_model = _get_model()
    cls.debug_model = _quantize_model(cls.float_model, _calibration_gen)

  @test_util.run_v2_only
  def test_quantization_debugger(self):
    options = debugger.QuantizationDebugOptions(
        layer_debug_metrics={'l1_norm': lambda diffs: np.mean(np.abs(diffs))})
    quant_debugger = debugger.QuantizationDebugger(
        quant_debug_model_content=QuantizationDebuggerTest.debug_model,
        debug_dataset=_calibration_gen,
        debug_options=options)
    quant_debugger.run()

    expected_metrics = {
        'num_elements': 9,
        'stddev': 0.03850026,
        'mean_error': 0.01673192,
        'max_abs_error': 0.10039272,
        'mean_square_error': 0.0027558778,
        'l1_norm': 0.023704167,
    }
    self.assertLen(quant_debugger.layer_statistics, 1)
    actual_metrics = next(iter(quant_debugger.layer_statistics.values()))

    self.assertCountEqual(expected_metrics.keys(), actual_metrics.keys())
    for key, value in expected_metrics.items():
      self.assertAlmostEqual(value, actual_metrics[key])

  @test_util.run_v2_only
  def test_quantization_debugger_wrong_input_raises_ValueError(self):

    def wrong_calibration_gen():
      for _ in range(5):
        yield [
            np.ones((1, 3, 3, 1), dtype=np.float32),
            np.ones((1, 3, 3, 1), dtype=np.float32)
        ]

    quant_debugger = debugger.QuantizationDebugger(
        quant_debug_model_content=QuantizationDebuggerTest.debug_model,
        debug_dataset=wrong_calibration_gen)
    with self.assertRaisesRegex(
        ValueError, r'inputs provided \(2\).+inputs to the model \(1\)'):
      quant_debugger.run()

  @test_util.run_v2_only
  def test_quantization_debugger_non_debug_model_raises_ValueError(self):
    normal_quant_model = _quantize_model(
        QuantizationDebuggerTest.float_model, _calibration_gen, debug=False)

    with self.assertRaisesRegex(
        ValueError, 'Please check if the quantized model is in debug mode'):
      debugger.QuantizationDebugger(
          quant_debug_model_content=normal_quant_model,
          debug_dataset=_calibration_gen)


if __name__ == '__main__':
  test.main()
