[WIP]6577-Support-torch-trt-dynamo #6578

Draft · wants to merge 4 commits into dev
13 changes: 11 additions & 2 deletions monai/bundle/scripts.py
@@ -1251,6 +1251,7 @@ def trt_export(
     key_in_ckpt: str | None = None,
     precision: str | None = None,
     input_shape: Sequence[int] | None = None,
+    use_torchscript: bool | None = None,
     use_trace: bool | None = None,
     dynamic_batchsize: Sequence[int] | None = None,
     device: int | None = None,
@@ -1265,15 +1266,17 @@ def trt_export(
     Export the model checkpoint to the given filepath as a TensorRT engine-based TorchScript.
     Currently, this API only supports converting models whose inputs are all tensors.

-    There are two ways to export a model:
+    There are three ways to export a model:
     1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript.
     2, ONNX-TensorRT way: PyTorch module ---> TorchScript module ---> ONNX model ---> TensorRT engine --->
     TensorRT engine-based TorchScript.
+    3, Torch-TensorRT Dynamo way: PyTorch module ---> TensorRT engine-based TorchScript.

     When exporting through the first way, some models suffer from the slowdown problem, since Torch-TensorRT
     may only convert a little part of the PyTorch model to the TensorRT engine. However when exporting through
     the second way, some Python data structures like `dict` are not supported. And some TorchScript models are
-    not supported by the ONNX if exported through `torch.jit.script`.
+    not supported by ONNX if exported through `torch.jit.script`. When exporting through the Dynamo way,
+    the `converter_kwargs` parameter must contain {'ir': 'dynamo_compile'}.

     Typical usage examples:

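The concrete invocations live in the collapsed remainder of this docstring. For the new Dynamo way, a minimal Python sketch of the intended call follows (a hedged illustration: the two new arguments come from this diff, while the bundle file names and input shape are placeholders):

from monai.bundle import trt_export

trt_export(
    net_id="network_def",
    filepath="models/model_trt.ts",
    ckpt_file="models/model.pt",
    meta_file="configs/metadata.json",
    config_file="configs/inference.json",
    precision="fp16",
    input_shape=[1, 1, 96, 96, 96],
    use_torchscript=False,  # the Dynamo way compiles the eager module directly
    converter_kwargs={"ir": "dynamo_compile"},  # required for the Dynamo way, per this docstring
)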
@@ -1296,6 +1299,8 @@ def trt_export(
         precision: the weight precision of the converted TensorRT engine based TorchScript models. Should be 'fp32' or 'fp16'.
         input_shape: the input shape that is used to convert the model. Should be a list like [N, C, H, W] or
             [N, C, H, W, D]. If not given, will try to parse from the `metadata` config.
+        use_torchscript: whether to convert the PyTorch model to a TorchScript model before compiling it with
+            torch_tensorrt, default to True.
         use_trace: whether using `torch.jit.trace` to convert the PyTorch model to a TorchScript model and then convert to
             a TensorRT engine based TorchScript model or an ONNX model (if `use_onnx` is True).
         dynamic_batchsize: a sequence with three elements to define the batch size range of the input for the model to be
@@ -1329,6 +1334,7 @@ def trt_export(
        key_in_ckpt=key_in_ckpt,
        precision=precision,
        input_shape=input_shape,
+       use_torchscript=use_torchscript,
        use_trace=use_trace,
        dynamic_batchsize=dynamic_batchsize,
        device=device,
@@ -1348,6 +1354,7 @@ def trt_export(
        key_in_ckpt_,
        precision_,
        input_shape_,
+       use_torchscript_,
        use_trace_,
        dynamic_batchsize_,
        device_,
@@ -1365,6 +1372,7 @@ def trt_export(
        key_in_ckpt="",
        precision="fp32",
        input_shape=[],
+       use_torchscript=True,
        use_trace=False,
        dynamic_batchsize=None,
        device=None,
@@ -1393,6 +1401,7 @@ def trt_export(
        "precision": precision_,
        "input_shape": input_shape_,
        "dynamic_batchsize": dynamic_batchsize_,
+       "use_torchscript": use_torchscript_,
        "use_trace": use_trace_,
        "device": device_,
        "use_onnx": use_onnx_,
17 changes: 13 additions & 4 deletions monai/networks/utils.py
@@ -851,6 +851,7 @@ def convert_to_trt(
     precision: str,
     input_shape: Sequence[int],
     dynamic_batchsize: Sequence[int] | None = None,
+    use_torchscript: bool = True,
     use_trace: bool = False,
     filename_or_obj: Any | None = None,
     verify: bool = False,
@@ -865,15 +866,17 @@
     """
     Utility to export a model into a TensorRT engine-based TorchScript model with optional input / output data verification.

-    There are two ways to export a model:
+    There are three ways to export a model:
     1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript.
     2, ONNX-TensorRT way: PyTorch module ---> TorchScript module ---> ONNX model ---> TensorRT engine --->
     TensorRT engine-based TorchScript.
+    3, Torch-TensorRT Dynamo way: PyTorch module ---> TensorRT engine-based TorchScript.

     When exporting through the first way, some models suffer from the slowdown problem, since Torch-TensorRT
     may only convert a little part of the PyTorch model to the TensorRT engine. However when exporting through
     the second way, some Python data structures like `dict` are not supported. And some TorchScript models are
-    not supported by the ONNX if exported through `torch.jit.script`.
+    not supported by ONNX if exported through `torch.jit.script`. When exporting through the Dynamo way,
+    the `converter_kwargs` parameter must contain {'ir': 'dynamo_compile'}.

     Args:
         model: a source PyTorch model to convert.
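A hedged sketch of the utility-level call for the Dynamo way (the UNet configuration is illustrative, a CUDA device is assumed, and passing `ir` assumes the collapsed tail of this signature forwards converter keyword arguments to torch_tensorrt.compile, per the converter_kwargs note above):

from monai.networks.nets import UNet
from monai.networks.utils import convert_to_trt

model = UNet(spatial_dims=2, in_channels=1, out_channels=2, channels=(4, 8), strides=(2,))
trt_model = convert_to_trt(
    model,
    precision="fp32",
    input_shape=[1, 1, 96, 96],
    use_torchscript=False,  # skip the TorchScript step, i.e. the Dynamo way
    device=0,
    ir="dynamo_compile",  # assumed to be forwarded to torch_tensorrt.compile
)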
@@ -885,6 +888,8 @@ def convert_to_trt(
            input should between `MIN_BATCH` and `MAX_BATCH` and the `OPT_BATCH` is the best performance batchsize that the
            TensorRT tries to fit. The `OPT_BATCH` should be the most frequently used input batchsize in the application,
            default to None.
+        use_torchscript: whether to convert the PyTorch model to a TorchScript model before compiling it with
+            torch_tensorrt, default to True.
        use_trace: whether using `torch.jit.trace` to convert the PyTorch model to a TorchScript model and then convert to
            a TensorRT engine based TorchScript model or an ONNX model (if `use_onnx` is True), default to False.
        filename_or_obj: if not None, specify a file-like object (has to implement write and flush) or a string containing a
@@ -920,7 +925,7 @@

     device = device if device else 0
     target_device = torch.device(f"cuda:{device}") if device else torch.device("cuda:0")
-    convert_precision = torch.float32 if precision == "fp32" else torch.half
+    convert_precision = {torch.float32} if precision == "fp32" else {torch.half}
     inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)]

     def scale_batch_size(input_shape: Sequence[int], scale_num: int):
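The convert_precision change above aligns the value with what torch_tensorrt.compile expects: its enabled_precisions parameter takes a collection of dtypes, not a bare dtype. A self-contained sketch (the toy module is illustrative; a CUDA device and torch_tensorrt are assumed):

import torch
import torch_tensorrt

model = torch.nn.Conv2d(1, 1, 3).eval().cuda()
inputs = [torch.rand(1, 1, 32, 32).cuda()]

# enabled_precisions takes a set of dtypes, hence {torch.half} rather than torch.half.
trt_module = torch_tensorrt.compile(model, inputs=inputs, enabled_precisions={torch.half})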
@@ -938,7 +943,11 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int):

     # convert the torch model to a TorchScript model on target device
     model = model.eval().to(target_device)
-    ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace)
+    ir_model = (
+        convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace)
+        if use_torchscript
+        else model
+    )
     ir_model.eval()

     if use_onnx:
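The conditional above is the crux of the PR: with use_torchscript=False the eager nn.Module reaches the compiler unchanged, which is what the Dynamo frontend consumes, whereas the TorchScript frontend needs a scripted or traced module first. A self-contained sketch of the two paths (the toy module and shapes are illustrative, and the accepted `ir` strings vary across torch_tensorrt versions; 'dynamo_compile' is the one this PR targets):

import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Conv2d(1, 4, 3), torch.nn.ReLU()).eval().cuda()
inputs = [torch.rand(1, 1, 64, 64).cuda()]

# use_torchscript=True (the default): the TorchScript frontend compiles a traced module.
ts_model = torch.jit.trace(model, inputs[0])
trt_ts = torch_tensorrt.compile(ts_model, inputs=inputs, enabled_precisions={torch.float32})

# use_torchscript=False: the Dynamo frontend compiles the eager module directly.
trt_dynamo = torch_tensorrt.compile(
    model, inputs=inputs, enabled_precisions={torch.float32}, ir="dynamo_compile"
)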