diff --git a/tools/onnx2tensorrt.py b/tools/onnx2tensorrt.py index 6e2b3d4cff..6601a1adb5 100644 --- a/tools/onnx2tensorrt.py +++ b/tools/onnx2tensorrt.py @@ -40,7 +40,7 @@ def _prepare_input_img(img_path: str, # prepare data data = dict(img=img_path) data = test_pipeline(data) - imgs = data['img'] + imgs = [img.float() for img in data['img']] img_metas = [i.data for i in data['img_metas']] if rescale_shape is not None: diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index d50a3381f2..e2c63bdaf2 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -90,7 +90,7 @@ def _prepare_input_img(img_path, # prepare data data = dict(img=img_path) data = test_pipeline(data) - imgs = data['img'] + imgs = [img.float() for img in data['img']] img_metas = [i.data for i in data['img_metas']] if rescale_shape is not None: