From f113cd77cf01b01286c60ac7f09ec72e5da75c6d Mon Sep 17 00:00:00 2001 From: "nicholas.bergh" Date: Tue, 28 Sep 2021 17:09:54 -0400 Subject: [PATCH 01/14] add support for normalization in graph --- mmseg/datasets/pipelines/transforms.py | 12 +++++++----- mmseg/models/segmentors/base.py | 9 +++++++++ tools/pytorch2onnx.py | 20 +++++++++++++++++--- 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index f2a642c141..4738dee54e 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -434,10 +434,11 @@ class Normalize(object): default is true. """ - def __init__(self, mean, std, to_rgb=True): + def __init__(self, mean, std, to_rgb=True, normalize_in_graph=False): self.mean = np.array(mean, dtype=np.float32) self.std = np.array(std, dtype=np.float32) self.to_rgb = to_rgb + self.normalize_in_graph = normalize_in_graph def __call__(self, results): """Call function to normalize images. @@ -449,9 +450,9 @@ def __call__(self, results): dict: Normalized results, 'img_norm_cfg' key is added into result dict. """ - - results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std, - self.to_rgb) + if not self.normalize_in_graph: + results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std, + self.to_rgb) results['img_norm_cfg'] = dict( mean=self.mean, std=self.std, to_rgb=self.to_rgb) return results @@ -459,7 +460,8 @@ def __call__(self, results): def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \ - f'{self.to_rgb})' + f'{self.to_rgb})' \ + f', normalize_in_graph={self.normalize_in_graph}' return repr_str diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 906c6fe564..68a2f7c55f 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -104,6 +104,15 @@ def forward(self, img, img_metas, return_loss=True, **kwargs): should be double nested (i.e. List[Tensor], List[List[dict]]), with the outer list indicating test time augmentations. """ + if torch.onnx.is_in_onnx_export(): + assert len(img_metas) == 1 + img_norm_cfg = img_metas[0][0].get('img_norm_cfg', None) + if img_norm_cfg: + mean = torch.tensor(img_norm_cfg['mean'])[None, ..., None, + None] + std = torch.tensor(img_norm_cfg['std'])[None, ..., None, None] + img[0] = (img[0] - mean) / std + if return_loss: return self.forward_train(img, img_metas, **kwargs) else: diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 1751a7b750..6bc52c2fbe 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +# import sys +# sys.path.insert(0, '/nicholasbergh/mmsegmentation/mmseg') + import argparse from functools import partial @@ -76,11 +79,16 @@ def _demo_mm_inputs(input_shape, num_classes): def _prepare_input_img(img_path, test_pipeline, shape=None, - rescale_shape=None): + rescale_shape=None, + normalize_in_graph=False): # build the data pipeline if shape is not None: test_pipeline[1]['img_scale'] = (shape[1], shape[0]) test_pipeline[1]['transforms'][0]['keep_ratio'] = False + if normalize_in_graph: + for transform in test_pipeline[1]['transforms']: + if transform['type'] is 'Normalize': + transform.normalize_in_graph = True test_pipeline = [LoadImage()] + test_pipeline[1:] test_pipeline = Compose(test_pipeline) # prepare data @@ -121,6 +129,8 @@ def _update_input_img(img_list, img_meta_list, update_ori_shape=False): (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2, 'flip': False, + 'img_norm_cfg': + img_meta['img_norm_cfg'] } for _ in range(N)]] return img_list, new_img_meta_list @@ -159,7 +169,6 @@ def pytorch2onnx(model, imgs = mm_inputs.pop('imgs') img_metas = mm_inputs.pop('img_metas') - img_list = [img[None, :] for img in imgs] img_meta_list = [[img_meta] for img_meta in img_metas] # update img_meta img_list, img_meta_list = _update_input_img(img_list, img_meta_list) @@ -322,6 +331,10 @@ def parse_args(): '--dynamic-export', action='store_true', help='Whether to export onnx with dynamic axis.') + parser.add_argument( + '--normalize-in-graph', + action='store_true', + help='Whether to include image normalization in ONNX graph.') args = parser.parse_args() return args @@ -372,7 +385,8 @@ def parse_args(): args.input_img, cfg.data.test.pipeline, shape=preprocess_shape, - rescale_shape=rescale_shape) + rescale_shape=rescale_shape, + normalize_in_graph=args.normalize_in_graph) else: if isinstance(segmentor.decode_head, nn.ModuleList): num_classes = segmentor.decode_head[-1].num_classes From d8f9fe0a7e1d4aa9ecded03f3ad5d9a902f97b66 Mon Sep 17 00:00:00 2001 From: "nicholas.bergh" Date: Tue, 28 Sep 2021 17:10:17 -0400 Subject: [PATCH 02/14] add support for uint8 input --- tools/pytorch2onnx.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 6bc52c2fbe..fce18172a5 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -169,6 +169,7 @@ def pytorch2onnx(model, imgs = mm_inputs.pop('imgs') img_metas = mm_inputs.pop('img_metas') + img_list = [img[None, :].type(torch.uint8) for img in imgs] img_meta_list = [[img_meta] for img_meta in img_metas] # update img_meta img_list, img_meta_list = _update_input_img(img_list, img_meta_list) From 7970216c467dde24d3870baf40d0b6240d585d22 Mon Sep 17 00:00:00 2001 From: "nicholas.bergh" Date: Tue, 28 Sep 2021 17:14:06 -0400 Subject: [PATCH 03/14] add back missing line --- tools/pytorch2onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index fce18172a5..970481ee99 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -169,7 +169,7 @@ def pytorch2onnx(model, imgs = mm_inputs.pop('imgs') img_metas = mm_inputs.pop('img_metas') - img_list = [img[None, :].type(torch.uint8) for img in imgs] + img_list = [img[None, :] for img in imgs] img_meta_list = [[img_meta] for img_meta in img_metas] # update img_meta img_list, img_meta_list = _update_input_img(img_list, img_meta_list) From 93b105193de06b62ac52d97412f50035318a3337 Mon Sep 17 00:00:00 2001 From: "nicholas.bergh" Date: Wed, 29 Sep 2021 16:04:05 -0400 Subject: [PATCH 04/14] ensure model is normalized for verify, ensure norm only happens when specified --- mmseg/datasets/pipelines/transforms.py | 3 ++- mmseg/models/segmentors/base.py | 6 +++--- tools/pytorch2onnx.py | 17 +++++++---------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index 4738dee54e..3967733535 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -454,7 +454,8 @@ def __call__(self, results): results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std, self.to_rgb) results['img_norm_cfg'] = dict( - mean=self.mean, std=self.std, to_rgb=self.to_rgb) + mean=self.mean, std=self.std, to_rgb=self.to_rgb, + normalize_in_graph=self.normalize_in_graph) return results def __repr__(self): diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 68a2f7c55f..836f4d2275 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -104,14 +104,14 @@ def forward(self, img, img_metas, return_loss=True, **kwargs): should be double nested (i.e. List[Tensor], List[List[dict]]), with the outer list indicating test time augmentations. """ - if torch.onnx.is_in_onnx_export(): + if torch.onnx.is_in_onnx_export() or kwargs.pop('do_norm', False): assert len(img_metas) == 1 img_norm_cfg = img_metas[0][0].get('img_norm_cfg', None) - if img_norm_cfg: + if img_norm_cfg and img_norm_cfg['normalize_in_graph']: mean = torch.tensor(img_norm_cfg['mean'])[None, ..., None, None] std = torch.tensor(img_norm_cfg['std'])[None, ..., None, None] - img[0] = (img[0] - mean) / std + img = (img - mean) / std if return_loss: return self.forward_train(img, img_metas, **kwargs) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 970481ee99..39c6544f79 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -85,10 +85,9 @@ def _prepare_input_img(img_path, if shape is not None: test_pipeline[1]['img_scale'] = (shape[1], shape[0]) test_pipeline[1]['transforms'][0]['keep_ratio'] = False - if normalize_in_graph: - for transform in test_pipeline[1]['transforms']: - if transform['type'] is 'Normalize': - transform.normalize_in_graph = True + for transform in test_pipeline[1]['transforms']: + if transform['type'] is 'Normalize': + transform.normalize_in_graph = normalize_in_graph test_pipeline = [LoadImage()] + test_pipeline[1:] test_pipeline = Compose(test_pipeline) # prepare data @@ -200,6 +199,7 @@ def pytorch2onnx(model, } register_extra_symbolics(opset_version) + # import pdb; pdb.set_trace() with torch.no_grad(): torch.onnx.export( model, (img_list, ), @@ -229,15 +229,12 @@ def pytorch2onnx(model, torch.cat((ori_img, flip_img), 0) for ori_img, flip_img in zip(img_list, flip_img_list) ] - - # update img_meta - img_list, img_meta_list = _update_input_img( - img_list, img_meta_list, test_mode == 'whole') - + # import pdb; pdb.set_trace() # check the numerical value # get pytorch output with torch.no_grad(): - pytorch_result = model(img_list, img_meta_list, return_loss=False) + img_list_clone = [tensor.clone() for tensor in img_list] # in case img needs to be normalized + pytorch_result = model(img_list_clone, img_meta_list, return_loss=False, do_norm=args.normalize_in_graph) pytorch_result = np.stack(pytorch_result, 0) # get onnx output From a2ce38b702ff9253ff106d3288405f88291965b0 Mon Sep 17 00:00:00 2001 From: "nicholas.bergh" Date: Wed, 29 Sep 2021 17:08:34 -0400 Subject: [PATCH 05/14] fix img norm bug, disable verify temporarily --- mmseg/models/segmentors/base.py | 2 +- tools/pytorch2onnx.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 836f4d2275..80b19ea767 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -111,7 +111,7 @@ def forward(self, img, img_metas, return_loss=True, **kwargs): mean = torch.tensor(img_norm_cfg['mean'])[None, ..., None, None] std = torch.tensor(img_norm_cfg['std'])[None, ..., None, None] - img = (img - mean) / std + img[0] = (img[0] - mean) / std if return_loss: return self.forward_train(img, img_metas, **kwargs) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 39c6544f79..81aacd7638 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -1,7 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -# import sys -# sys.path.insert(0, '/nicholasbergh/mmsegmentation/mmseg') - import argparse from functools import partial @@ -199,7 +196,6 @@ def pytorch2onnx(model, } register_extra_symbolics(opset_version) - # import pdb; pdb.set_trace() with torch.no_grad(): torch.onnx.export( model, (img_list, ), @@ -215,6 +211,7 @@ def pytorch2onnx(model, model.forward = origin_forward if verify: + assert not args.normalize_in_graph, "verify not supported with normalize_in_graph yet" # check by onnx import onnx onnx_model = onnx.load(output_file) @@ -229,7 +226,6 @@ def pytorch2onnx(model, torch.cat((ori_img, flip_img), 0) for ori_img, flip_img in zip(img_list, flip_img_list) ] - # import pdb; pdb.set_trace() # check the numerical value # get pytorch output with torch.no_grad(): From aa7be96e893f7a272ffc70679d78c10416d17e82 Mon Sep 17 00:00:00 2001 From: "nicholas.bergh" Date: Wed, 13 Oct 2021 11:08:28 -0400 Subject: [PATCH 06/14] fix verify in onnx export --- tools/pytorch2onnx.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 81aacd7638..531b77d6a8 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -211,7 +211,6 @@ def pytorch2onnx(model, model.forward = origin_forward if verify: - assert not args.normalize_in_graph, "verify not supported with normalize_in_graph yet" # check by onnx import onnx onnx_model = onnx.load(output_file) @@ -226,6 +225,10 @@ def pytorch2onnx(model, torch.cat((ori_img, flip_img), 0) for ori_img, flip_img in zip(img_list, flip_img_list) ] + + img_list, img_meta_list = _update_input_img( + img_list, img_meta_list, test_mode == 'whole') + # check the numerical value # get pytorch output with torch.no_grad(): From 3215bb8e6a89b9a8442cceb0c4c74c2fccfd3d24 Mon Sep 17 00:00:00 2001 From: "nicholas.bergh" Date: Wed, 3 Nov 2021 16:28:50 -0400 Subject: [PATCH 07/14] disable constant folding --- tools/pytorch2onnx.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 531b77d6a8..204191bccd 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -206,7 +206,8 @@ def pytorch2onnx(model, keep_initializers_as_inputs=False, verbose=show, opset_version=opset_version, - dynamic_axes=dynamic_axes) + dynamic_axes=dynamic_axes, + do_constant_folding=False) print(f'Successfully exported ONNX model: {output_file}') model.forward = origin_forward From 7c0d7761f2cbd795af611d611ebfb3a6b380c38c Mon Sep 17 00:00:00 2001 From: "nicholas.bergh" Date: Tue, 9 Nov 2021 14:51:54 -0500 Subject: [PATCH 08/14] restore constant folding --- tools/pytorch2onnx.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 204191bccd..531b77d6a8 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -206,8 +206,7 @@ def pytorch2onnx(model, keep_initializers_as_inputs=False, verbose=show, opset_version=opset_version, - dynamic_axes=dynamic_axes, - do_constant_folding=False) + dynamic_axes=dynamic_axes) print(f'Successfully exported ONNX model: {output_file}') model.forward = origin_forward From 9b8aee4765a1f3dbb78b356aaf0c88638fa7577b Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 9 Nov 2021 14:58:19 -0500 Subject: [PATCH 09/14] fix formatting for flake --- mmseg/datasets/pipelines/transforms.py | 8 +++++--- tools/pytorch2onnx.py | 11 ++++++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index 3967733535..302dfb7d74 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -451,10 +451,12 @@ def __call__(self, results): result dict. """ if not self.normalize_in_graph: - results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std, - self.to_rgb) + results['img'] = mmcv.imnormalize(results['img'], self.mean, + self.std, self.to_rgb) results['img_norm_cfg'] = dict( - mean=self.mean, std=self.std, to_rgb=self.to_rgb, + mean=self.mean, + std=self.std, + to_rgb=self.to_rgb, normalize_in_graph=self.normalize_in_graph) return results diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 531b77d6a8..d50a3381f2 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -83,7 +83,7 @@ def _prepare_input_img(img_path, test_pipeline[1]['img_scale'] = (shape[1], shape[0]) test_pipeline[1]['transforms'][0]['keep_ratio'] = False for transform in test_pipeline[1]['transforms']: - if transform['type'] is 'Normalize': + if transform['type'] == 'Normalize': transform.normalize_in_graph = normalize_in_graph test_pipeline = [LoadImage()] + test_pipeline[1:] test_pipeline = Compose(test_pipeline) @@ -232,8 +232,13 @@ def pytorch2onnx(model, # check the numerical value # get pytorch output with torch.no_grad(): - img_list_clone = [tensor.clone() for tensor in img_list] # in case img needs to be normalized - pytorch_result = model(img_list_clone, img_meta_list, return_loss=False, do_norm=args.normalize_in_graph) + img_list_clone = [tensor.clone() for tensor in img_list + ] # in case img needs to be normalized + pytorch_result = model( + img_list_clone, + img_meta_list, + return_loss=False, + do_norm=args.normalize_in_graph) pytorch_result = np.stack(pytorch_result, 0) # get onnx output From 6f142d7b75fcdc6ccee79d53e12912a08a11ba80 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 9 Nov 2021 16:50:12 -0500 Subject: [PATCH 10/14] update docstring in transforms.Normalize --- mmseg/datasets/pipelines/transforms.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index 302dfb7d74..575909205d 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -432,6 +432,9 @@ class Normalize(object): std (sequence): Std values of 3 channels. to_rgb (bool): Whether to convert the image from BGR to RGB, default is true. + normalize_in_graph (bool): Whether to normalize the input in the + graph (mmseg.models.segmentors.base.forward) or here in the + __call__ function (only relevant during onnx export) """ def __init__(self, mean, std, to_rgb=True, normalize_in_graph=False): From cebae08da21997e30a55afc098035d8b7cb1e1af Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 9 Nov 2021 16:54:39 -0500 Subject: [PATCH 11/14] add normalize-in-graph to pytorch2onnx documentation --- docs/useful_tools.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/useful_tools.md b/docs/useful_tools.md index 8cec9b7024..ae2e8e7e56 100644 --- a/docs/useful_tools.md +++ b/docs/useful_tools.md @@ -59,6 +59,7 @@ python tools/pytorch2onnx.py \ --show \ --verify \ --dynamic-export \ + --normalize-in-graph \ --cfg-options \ model.test_cfg.mode="whole" ``` @@ -74,6 +75,7 @@ Description of arguments: - `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`. - `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`. - `--dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`. +- `--normalize-in-graph`: Whether to include image normalization in ONNX graph. This would cause the saved ONNX model to do image normalization as the first step, allowing inference with non-normalized inputs. If not specified, it will be set to `False`. - `--cfg-options`:Update config options. :::{note} From ad82c48f2e1abd46db6938c5d3f907beb9280e31 Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 15 Nov 2021 13:35:07 -0500 Subject: [PATCH 12/14] add img norm for all imgs in forward of BaseSegmentor --- mmseg/models/segmentors/base.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 9c8fd4bd56..6eb546eb3e 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -104,18 +104,19 @@ def forward(self, img, img_metas, return_loss=True, **kwargs): should be double nested (i.e. List[Tensor], List[List[dict]]), with the outer list indicating test time augmentations. """ - if torch.onnx.is_in_onnx_export() or kwargs.pop('do_norm', False): - assert len(img_metas) == 1 - img_norm_cfg = img_metas[0][0].get('img_norm_cfg', None) - if img_norm_cfg and img_norm_cfg['normalize_in_graph']: - mean = torch.tensor(img_norm_cfg['mean'])[None, ..., None, - None] - std = torch.tensor(img_norm_cfg['std'])[None, ..., None, None] - img[0] = (img[0] - mean) / std - if return_loss: return self.forward_train(img, img_metas, **kwargs) else: + if torch.onnx.is_in_onnx_export() or kwargs.pop('do_norm', False): + assert len(img_metas) == 1 + img_norm_cfg = img_metas[0][0].get('img_norm_cfg', None) + if img_norm_cfg and img_norm_cfg['normalize_in_graph']: + mean = torch.tensor(img_norm_cfg['mean'])[None, ..., None, + None] + std = torch.tensor(img_norm_cfg['std'])[None, ..., None, + None] + img = [(i - mean) / std for i in img] + return self.forward_test(img, img_metas, **kwargs) def train_step(self, data_batch, optimizer, **kwargs): From 4ef719f7204828df03b742cef206d825a6b82360 Mon Sep 17 00:00:00 2001 From: Nick Date: Fri, 19 Nov 2021 15:40:46 -0500 Subject: [PATCH 13/14] add img normalization skip arg to onnx2tensorrt --- docs/useful_tools.md | 4 +++- tools/onnx2tensorrt.py | 19 +++++++++++++++---- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/docs/useful_tools.md b/docs/useful_tools.md index ae2e8e7e56..75baa4f7ff 100644 --- a/docs/useful_tools.md +++ b/docs/useful_tools.md @@ -205,7 +205,8 @@ python ${MMSEG_PATH}/tools/onnx2tensorrt.py \ --max-shape ${MAX_SHAPE} \ --input-img ${INPUT_IMG} \ --show \ - --verify + --verify \ + --skip-normalize ``` Description of all arguments @@ -222,6 +223,7 @@ Description of all arguments - `--dataset` : Palette provider, `CityscapesDataset` as default. - `--verify` : Verify the outputs of ONNXRuntime and TensorRT. - `--verbose` : Whether to verbose logging messages while creating TensorRT engine. Defaults to False. +- `--skip-normalize` : Whether to skip image normalization in preprocessing. Defaults to False. :::{note} Only tested on whole mode. diff --git a/tools/onnx2tensorrt.py b/tools/onnx2tensorrt.py index f8a258fc80..6e2b3d4cff 100644 --- a/tools/onnx2tensorrt.py +++ b/tools/onnx2tensorrt.py @@ -26,11 +26,15 @@ def get_GiB(x: int): def _prepare_input_img(img_path: str, test_pipeline: Iterable[dict], shape: Optional[Iterable] = None, - rescale_shape: Optional[Iterable] = None) -> dict: + rescale_shape: Optional[Iterable] = None, + normalize_in_graph: Optional[bool] = False) -> dict: # build the data pipeline if shape is not None: test_pipeline[1]['img_scale'] = (shape[1], shape[0]) test_pipeline[1]['transforms'][0]['keep_ratio'] = False + for transform in test_pipeline[1]['transforms']: + if transform['type'] == 'Normalize': + transform.normalize_in_graph = normalize_in_graph test_pipeline = [LoadImage()] + test_pipeline[1:] test_pipeline = Compose(test_pipeline) # prepare data @@ -113,7 +117,8 @@ def onnx2tensorrt(onnx_file: str, show: bool = False, dataset: str = 'CityscapesDataset', workspace_size: int = 1, - verbose: bool = False): + verbose: bool = False, + skip_normalize: bool = False): import tensorrt as trt min_shape = input_config['min_shape'] max_shape = input_config['max_shape'] @@ -136,7 +141,8 @@ def onnx2tensorrt(onnx_file: str, inputs = _prepare_input_img( input_config['input_path'], config.data.test.pipeline, - shape=min_shape[2:]) + shape=min_shape[2:], + normalize_in_graph=skip_normalize) imgs = inputs['imgs'] img_metas = inputs['img_metas'] @@ -233,6 +239,10 @@ def parse_args(): action='store_true', help='Whether to verbose logging messages while creating \ TensorRT engine.') + parser.add_argument( + '--skip-normalize', + action='store_true', + help='Whether to skip image normalization in preprocessing') args = parser.parse_args() return args @@ -273,4 +283,5 @@ def parse_args(): show=args.show, dataset=args.dataset, workspace_size=args.workspace_size, - verbose=args.verbose) + verbose=args.verbose, + skip_normalize=args.skip_normalize) From f62cf695260e53a986eec9ea2d117e4adaad08d3 Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 30 Dec 2021 10:51:47 -0500 Subject: [PATCH 14/14] ensure inputs to graphs are float in onnx and tensorrt export --- tools/onnx2tensorrt.py | 2 +- tools/pytorch2onnx.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/onnx2tensorrt.py b/tools/onnx2tensorrt.py index 6e2b3d4cff..6601a1adb5 100644 --- a/tools/onnx2tensorrt.py +++ b/tools/onnx2tensorrt.py @@ -40,7 +40,7 @@ def _prepare_input_img(img_path: str, # prepare data data = dict(img=img_path) data = test_pipeline(data) - imgs = data['img'] + imgs = [img.float() for img in data['img']] img_metas = [i.data for i in data['img_metas']] if rescale_shape is not None: diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index d50a3381f2..e2c63bdaf2 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -90,7 +90,7 @@ def _prepare_input_img(img_path, # prepare data data = dict(img=img_path) data = test_pipeline(data) - imgs = data['img'] + imgs = [img.float() for img in data['img']] img_metas = [i.data for i in data['img_metas']] if rescale_shape is not None: