From f62cf695260e53a986eec9ea2d117e4adaad08d3 Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 30 Dec 2021 10:51:47 -0500 Subject: [PATCH] ensure inputs to graphs are float in onnx and tensorrt export --- tools/onnx2tensorrt.py | 2 +- tools/pytorch2onnx.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/onnx2tensorrt.py b/tools/onnx2tensorrt.py index 6e2b3d4cff..6601a1adb5 100644 --- a/tools/onnx2tensorrt.py +++ b/tools/onnx2tensorrt.py @@ -40,7 +40,7 @@ def _prepare_input_img(img_path: str, # prepare data data = dict(img=img_path) data = test_pipeline(data) - imgs = data['img'] + imgs = [img.float() for img in data['img']] img_metas = [i.data for i in data['img_metas']] if rescale_shape is not None: diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index d50a3381f2..e2c63bdaf2 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -90,7 +90,7 @@ def _prepare_input_img(img_path, # prepare data data = dict(img=img_path) data = test_pipeline(data) - imgs = data['img'] + imgs = [img.float() for img in data['img']] img_metas = [i.data for i in data['img_metas']] if rescale_shape is not None: