
Commit

merge branch into master
Harmit Minhas committed Dec 20, 2022
2 parents 4f394b0 + f62cf69 commit da8e457
Showing 5 changed files with 66 additions and 17 deletions.
6 changes: 5 additions & 1 deletion docs/en/useful_tools.md
@@ -61,6 +61,7 @@ python tools/pytorch2onnx.py \
--show \
--verify \
--dynamic-export \
--normalize-in-graph \
--cfg-options \
model.test_cfg.mode="whole"
```
@@ -76,6 +77,7 @@ Description of arguments:
- `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`.
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
- `--dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`.
- `--normalize-in-graph`: Whether to include image normalization in the ONNX graph. This makes the saved ONNX model perform image normalization as its first step, allowing inference with non-normalized inputs (see the example after this list). If not specified, it will be set to `False`.
- `--cfg-options`: Update config options.
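As a rough illustration of what in-graph normalization enables (this snippet is not part of the repository's docs: the model path, image path, and tensor layout are assumptions), the exported model can then be fed raw, un-normalized pixels directly:

```python
import cv2
import numpy as np
import onnxruntime as ort

# Hypothetical paths; substitute your own exported model and image.
sess = ort.InferenceSession('model_with_norm.onnx',
                            providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name

# Raw pixels as float32 NCHW -- no mean/std subtraction in client code,
# because the graph itself performs the normalization as its first step.
img = cv2.imread('demo.png').astype(np.float32)
img = img.transpose(2, 0, 1)[None]  # HWC -> 1 x C x H x W

seg_map = sess.run(None, {input_name: img})[0]
```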

:::{note}
@@ -205,7 +207,8 @@ python ${MMSEG_PATH}/tools/onnx2tensorrt.py \
--max-shape ${MAX_SHAPE} \
--input-img ${INPUT_IMG} \
--show \
--verify
--verify \
--skip-normalize
```

Description of all arguments
@@ -222,6 +225,7 @@ Description of all arguments
- `--dataset` : Palette provider, `CityscapesDataset` as default.
- `--verify` : Verify the outputs of ONNXRuntime and TensorRT.
- `--verbose` : Whether to print verbose logging messages while creating the TensorRT engine. Defaults to False.
- `--skip-normalize` : Whether to skip image normalization in preprocessing. Defaults to False.
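Under the hood this flag simply flips `normalize_in_graph` on the pipeline's `Normalize` step before the input image is prepared, mirroring what `_prepare_input_img` does further down this page. A minimal sketch of the equivalent config tweak (the config path is hypothetical):

```python
from mmcv import Config

cfg = Config.fromfile('configs/some_config.py')  # hypothetical path
for transform in cfg.data.test.pipeline[1]['transforms']:
    if transform['type'] == 'Normalize':
        # defer normalization to the ONNX/TensorRT graph; the data
        # pipeline then only loads, resizes and converts the image
        transform['normalize_in_graph'] = True
```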

:::{note}
Only tested on whole mode.
20 changes: 14 additions & 6 deletions mmseg/datasets/pipelines/transforms.py
@@ -458,12 +458,16 @@ class Normalize(object):
std (sequence): Std values of 3 channels.
to_rgb (bool): Whether to convert the image from BGR to RGB,
default is true.
normalize_in_graph (bool): Whether to normalize the input in the
model graph (mmseg.models.segmentors.base.forward) instead of in
this __call__ function; only relevant during ONNX export.
Default is false.
"""

def __init__(self, mean, std, to_rgb=True):
def __init__(self, mean, std, to_rgb=True, normalize_in_graph=False):
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.to_rgb = to_rgb
self.normalize_in_graph = normalize_in_graph

def __call__(self, results):
"""Call function to normalize images.
@@ -475,17 +479,21 @@ def __call__(self, results):
dict: Normalized results, 'img_norm_cfg' key is added into
result dict.
"""

results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std,
self.to_rgb)
if not self.normalize_in_graph:
results['img'] = mmcv.imnormalize(results['img'], self.mean,
self.std, self.to_rgb)
results['img_norm_cfg'] = dict(
mean=self.mean, std=self.std, to_rgb=self.to_rgb)
mean=self.mean,
std=self.std,
to_rgb=self.to_rgb,
normalize_in_graph=self.normalize_in_graph)
return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \
f'{self.to_rgb})'
f'{self.to_rgb}, ' \
f'normalize_in_graph={self.normalize_in_graph})'
return repr_str
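As a quick usage sketch (assuming `Normalize` is importable from `mmseg.datasets.pipelines`, which is how the pipeline registry exposes it), `normalize_in_graph=True` leaves the pixel values untouched and only records the parameters for the model to apply later:

```python
import numpy as np
from mmseg.datasets.pipelines import Normalize

results = dict(img=np.full((4, 4, 3), 255, dtype=np.uint8))
norm = Normalize(
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_rgb=True,
    normalize_in_graph=True)
out = norm(results)

# The image passes through unchanged ...
assert out['img'].dtype == np.uint8 and out['img'].max() == 255
# ... while the normalization parameters are still recorded for the model.
assert out['img_norm_cfg']['normalize_in_graph'] is True
```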


10 changes: 10 additions & 0 deletions mmseg/models/segmentors/base.py
@@ -107,6 +107,16 @@ def forward(self, img, img_metas, return_loss=True, **kwargs):
if return_loss:
return self.forward_train(img, img_metas, **kwargs)
else:
if torch.onnx.is_in_onnx_export() or kwargs.pop('do_norm', False):
assert len(img_metas) == 1
img_norm_cfg = img_metas[0][0].get('img_norm_cfg', None)
if img_norm_cfg and img_norm_cfg['normalize_in_graph']:
mean = torch.tensor(img_norm_cfg['mean'])[None, ..., None,
None]
std = torch.tensor(img_norm_cfg['std'])[None, ..., None,
None]
img = [(i - mean) / std for i in img]

return self.forward_test(img, img_metas, **kwargs)

def train_step(self, data_batch, optimizer, **kwargs):
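For readers skimming the indexing above: `[None, ..., None, None]` reshapes the per-channel mean/std from `(C,)` to `(1, C, 1, 1)` so the subtraction and division broadcast over an NCHW image tensor. A standalone sketch (the values are the usual ImageNet statistics, used here only as an example):

```python
import torch

img = torch.rand(1, 3, 64, 64) * 255  # NCHW batch with raw pixel values
mean = torch.tensor([123.675, 116.28, 103.53])[None, ..., None, None]  # (1, 3, 1, 1)
std = torch.tensor([58.395, 57.12, 57.375])[None, ..., None, None]     # (1, 3, 1, 1)

normalized = (img - mean) / std  # per-channel broadcast, as in forward()
print(normalized.shape)  # torch.Size([1, 3, 64, 64])
```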
21 changes: 16 additions & 5 deletions tools/onnx2tensorrt.py
@@ -27,17 +27,21 @@ def get_GiB(x: int):
def _prepare_input_img(img_path: str,
test_pipeline: Iterable[dict],
shape: Optional[Iterable] = None,
rescale_shape: Optional[Iterable] = None) -> dict:
rescale_shape: Optional[Iterable] = None,
normalize_in_graph: Optional[bool] = False) -> dict:
# build the data pipeline
if shape is not None:
test_pipeline[1]['img_scale'] = (shape[1], shape[0])
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
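# propagate the flag so the Normalize step can defer normalization to the graph when requested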
for transform in test_pipeline[1]['transforms']:
if transform['type'] == 'Normalize':
transform['normalize_in_graph'] = normalize_in_graph
test_pipeline = [LoadImage()] + test_pipeline[1:]
test_pipeline = Compose(test_pipeline)
# prepare data
data = dict(img=img_path)
data = test_pipeline(data)
imgs = data['img']
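# cast explicitly: when normalization is deferred to the graph, the pipeline may return integer-typed images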
imgs = [img.float() for img in data['img']]
img_metas = [i.data for i in data['img_metas']]

if rescale_shape is not None:
@@ -114,7 +118,8 @@ def onnx2tensorrt(onnx_file: str,
show: bool = False,
dataset: str = 'CityscapesDataset',
workspace_size: int = 1,
verbose: bool = False):
verbose: bool = False,
skip_normalize: bool = False):
import tensorrt as trt
min_shape = input_config['min_shape']
max_shape = input_config['max_shape']
@@ -137,7 +142,8 @@ def onnx2tensorrt(onnx_file: str,
inputs = _prepare_input_img(
input_config['input_path'],
config.data.test.pipeline,
shape=min_shape[2:])
shape=min_shape[2:],
normalize_in_graph=skip_normalize)

imgs = inputs['imgs']
img_metas = inputs['img_metas']
@@ -234,6 +240,10 @@ def parse_args():
action='store_true',
help='Whether to verbose logging messages while creating \
TensorRT engine.')
parser.add_argument(
'--skip-normalize',
action='store_true',
help='Whether to skip image normalization in preprocessing')
args = parser.parse_args()
return args

@@ -274,7 +284,8 @@ def parse_args():
show=args.show,
dataset=args.dataset,
workspace_size=args.workspace_size,
verbose=args.verbose)
verbose=args.verbose,
skip_normalize=args.skip_normalize)

# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
26 changes: 21 additions & 5 deletions tools/pytorch2onnx.py
@@ -77,17 +77,21 @@ def _demo_mm_inputs(input_shape, num_classes):
def _prepare_input_img(img_path,
test_pipeline,
shape=None,
rescale_shape=None):
rescale_shape=None,
normalize_in_graph=False):
# build the data pipeline
if shape is not None:
test_pipeline[1]['img_scale'] = (shape[1], shape[0])
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
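# hand the flag to the Normalize step so normalization can be deferred to the ONNX graph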
for transform in test_pipeline[1]['transforms']:
if transform['type'] == 'Normalize':
transform['normalize_in_graph'] = normalize_in_graph
test_pipeline = [LoadImage()] + test_pipeline[1:]
test_pipeline = Compose(test_pipeline)
# prepare data
data = dict(img=img_path)
data = test_pipeline(data)
imgs = data['img']
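# without an eager Normalize step the images may still be integer-typed, so cast to float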
imgs = [img.float() for img in data['img']]
img_metas = [i.data for i in data['img_metas']]

if rescale_shape is not None:
@@ -122,6 +126,8 @@ def _update_input_img(img_list, img_meta_list, update_ori_shape=False):
(img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
'flip':
False,
'img_norm_cfg':
img_meta['img_norm_cfg']
} for _ in range(N)]]

return img_list, new_img_meta_list
@@ -221,14 +227,19 @@ def pytorch2onnx(model,
for ori_img, flip_img in zip(img_list, flip_img_list)
]

# update img_meta
img_list, img_meta_list = _update_input_img(
img_list, img_meta_list, test_mode == 'whole')

# check the numerical value
# get pytorch output
with torch.no_grad():
pytorch_result = model(img_list, img_meta_list, return_loss=False)
img_list_clone = [tensor.clone() for tensor in img_list
] # in case img needs to be normalized
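# `do_norm` asks BaseSegmentor.forward to run its in-graph normalization path,
# so the PyTorch reference output matches what the exported ONNX graph computes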
pytorch_result = model(
img_list_clone,
img_meta_list,
return_loss=False,
do_norm=args.normalize_in_graph)
pytorch_result = np.stack(pytorch_result, 0)

# get onnx output
@@ -324,6 +335,10 @@ def parse_args():
'--dynamic-export',
action='store_true',
help='Whether to export onnx with dynamic axis.')
parser.add_argument(
'--normalize-in-graph',
action='store_true',
help='Whether to include image normalization in ONNX graph.')
args = parser.parse_args()
return args

@@ -374,7 +389,8 @@ def parse_args():
args.input_img,
cfg.data.test.pipeline,
shape=preprocess_shape,
rescale_shape=rescale_shape)
rescale_shape=rescale_shape,
normalize_in_graph=args.normalize_in_graph)
else:
if isinstance(segmentor.decode_head, nn.ModuleList):
num_classes = segmentor.decode_head[-1].num_classes

