[WIP] 6838 support generative and hovernet with TensorRT #6859

Draft · wants to merge 8 commits into dev
30 changes: 21 additions & 9 deletions monai/networks/nets/hovernet.py
@@ -34,6 +34,7 @@
import warnings
from collections import OrderedDict
from collections.abc import Callable, Sequence
from typing import Union

import torch
import torch.nn as nn
@@ -446,6 +447,10 @@ class HoVerNet(nn.Module):
pretrained_state_dict_key: this arg is used when `pretrained_url` is provided and `adapt_standard_resnet` is True.
It is used to extract the expected state dict.
freeze_encoder: whether to freeze the encoder of the network.
use_list_output: whether the forward method returns a list instead of a dictionary. If False (the default),
the output is a dictionary mapping output names to tensors; if True, it is a list of tensors in
`[HoVerNetBranch.NP, HoVerNetBranch.HV, HoVerNetBranch.NC]` order.
"""

Mode = HoVerNetMode
@@ -465,6 +470,7 @@ def __init__(
adapt_standard_resnet: bool = False,
pretrained_state_dict_key: str | None = None,
freeze_encoder: bool = False,
use_list_output: bool = False,
) -> None:
super().__init__()

@@ -499,6 +505,7 @@ def __init__(

conv_type: type[nn.Conv2d] = Conv[Conv.CONV, 2]

self.use_list_output = use_list_output
self.conv0 = nn.Sequential(
OrderedDict(
[
@@ -574,7 +581,7 @@ def __init__(
weights = _remap_preact_resnet_model(pretrained_url)
_load_pretrained_encoder(self, weights)

def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
def forward(self, x: torch.Tensor) -> Union[dict[str, torch.Tensor], list[torch.Tensor]]:
if self.mode == HoVerNetMode.ORIGINAL.value:
if x.shape[-1] != 270 or x.shape[-2] != 270:
raise ValueError("Input size should be 270 x 270 when using HoVerNetMode.ORIGINAL")
@@ -594,14 +601,19 @@ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
x = self.bottleneck(x)
x = self.upsample(x)

output = {
HoVerNetBranch.NP.value: self.nucleus_prediction(x, short_cuts),
HoVerNetBranch.HV.value: self.horizontal_vertical(x, short_cuts),
}
if self.type_prediction is not None:
output[HoVerNetBranch.NC.value] = self.type_prediction(x, short_cuts)

return output
if self.use_list_output:
list_output = [self.nucleus_prediction(x, short_cuts), self.horizontal_vertical(x, short_cuts)]
if self.type_prediction is not None:
list_output.append(self.type_prediction(x, short_cuts))
return list_output
else:
output = {
HoVerNetBranch.NP.value: self.nucleus_prediction(x, short_cuts),
HoVerNetBranch.HV.value: self.horizontal_vertical(x, short_cuts),
}
if self.type_prediction is not None:
output[HoVerNetBranch.NC.value] = self.type_prediction(x, short_cuts)
return output


def _load_pretrained_encoder(model: nn.Module, state_dict: OrderedDict | dict):
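A minimal usage sketch of the two output modes introduced above (FAST mode, `out_classes=5`, and the input size are illustrative assumptions, not prescribed by this PR):

import torch
from monai.networks.nets import HoVerNet
from monai.utils.enums import HoVerNetBranch

# list output: a fixed [NP, HV, NC] order, which is friendlier to tracing/TensorRT export
net = HoVerNet(mode=HoVerNet.Mode.FAST, out_classes=5, use_list_output=True).eval()
with torch.no_grad():
    np_out, hv_out, nc_out = net(torch.rand(1, 3, 256, 256))

# default dict output, keyed by the HoVerNetBranch values
dict_net = HoVerNet(mode=HoVerNet.Mode.FAST, out_classes=5).eval()
with torch.no_grad():
    np_out = dict_net(torch.rand(1, 3, 256, 256))[HoVerNetBranch.NP.value]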
59 changes: 41 additions & 18 deletions monai/networks/utils.py
@@ -775,9 +775,9 @@ def convert_to_torchscript(

def _onnx_trt_compile(
onnx_model,
min_shape: Sequence[int],
opt_shape: Sequence[int],
max_shape: Sequence[int],
min_shape: Sequence[Any],
opt_shape: Sequence[Any],
max_shape: Sequence[Any],
device: int,
precision: str,
input_names: Sequence[str] | None,
@@ -806,7 +806,7 @@
trt, _ = optional_import("tensorrt", "8.5.3")
torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")

input_shapes = (min_shape, opt_shape, max_shape)
input_shapes = list(zip(min_shape, opt_shape, max_shape))
# default to an empty list to fit the `torch_tensorrt.ts.embed_engine_in_new_module` function.
input_names = [] if not input_names else input_names
output_names = [] if not output_names else output_names
@@ -817,8 +817,8 @@
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
profile = builder.create_optimization_profile()
if input_names:
profile.set_shape(input_names[0], *input_shapes)
for index, input_name in enumerate(input_names):
profile.set_shape(input_name, *(input_shapes[index]))

# parse the ONNX model
parser = trt.OnnxParser(network, logger)
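For reference, a standalone sketch of what the per-input profile loop above amounts to (TensorRT 8.x Python API; the input names and shapes are illustrative, not taken from this PR):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
profile = builder.create_optimization_profile()

# one (min, opt, max) triple per network input, as produced by list(zip(min_shape, opt_shape, max_shape))
input_shapes = [
    ([1, 3, 64, 64], [4, 3, 64, 64], [8, 3, 64, 64]),
    ([1, 1, 64, 64], [4, 1, 64, 64], [8, 1, 64, 64]),
]
for input_name, shapes in zip(["image", "mask"], input_shapes):
    profile.set_shape(input_name, *shapes)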
@@ -849,8 +849,8 @@
def convert_to_trt(
model: nn.Module,
precision: str,
input_shape: Sequence[int],
dynamic_batchsize: Sequence[int] | None = None,
input_shape: Sequence[Any],
dynamic_batchsize: Sequence[Any] | None = None,
use_trace: bool = False,
filename_or_obj: Any | None = None,
verify: bool = False,
@@ -915,26 +915,48 @@
if not dynamic_batchsize:
warnings.warn(f"There is no dynamic batch range. The converted model only takes {input_shape} shape input.")

if (dynamic_batchsize is not None) and (len(dynamic_batchsize) != 3):
warnings.warn(f"The dynamic batch range sequence should have 3 elements, but got {dynamic_batchsize} elements.")
# if (dynamic_batchsize is not None) and (len(dynamic_batchsize) != 3):
# warnings.warn(f"The dynamic batch range sequence should have 3 elements, but got {dynamic_batchsize} elements.")

device = device if device else 0
target_device = torch.device(f"cuda:{device}") if device else torch.device("cuda:0")
convert_precision = torch.float32 if precision == "fp32" else torch.half
inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)]

def scale_batch_size(input_shape: Sequence[int], scale_num: int):
scale_shape = [*input_shape]
scale_shape[0] *= scale_num
return scale_shape

# Use the dynamic batchsize range to generate the min, opt and max model input shape
if dynamic_batchsize:
min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0])
opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1])
max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2])
min_input_shape = []
opt_input_shape = []
max_input_shape = []
if isinstance(input_shape[0], list):
inputs = [torch.rand(ensure_tuple(shape)).to(target_device) for shape in input_shape]
# Use the dynamic batchsize range to generate the min, opt and max model input shape
if dynamic_batchsize:
min_input_shape.extend(
[scale_batch_size(input_shape[i], dynamic_batchsize[i][0]) for i in range(len(input_shape))]
)
opt_input_shape.extend(
[scale_batch_size(input_shape[i], dynamic_batchsize[i][1]) for i in range(len(input_shape))]
)
max_input_shape.extend(
[scale_batch_size(input_shape[i], dynamic_batchsize[i][2]) for i in range(len(input_shape))]
)
else:
max_input_shape.extend(input_shape)
min_input_shape = opt_input_shape = max_input_shape

else:
min_input_shape = opt_input_shape = max_input_shape = input_shape
inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)]
# Use the dynamic batchsize range to generate the min, opt and max model input shape
if dynamic_batchsize:
min_input_shape.append(scale_batch_size(input_shape, dynamic_batchsize[0]))
opt_input_shape.append(scale_batch_size(input_shape, dynamic_batchsize[1]))
max_input_shape.append(scale_batch_size(input_shape, dynamic_batchsize[2]))
else:
max_input_shape.append(input_shape)
min_input_shape = opt_input_shape = max_input_shape

# convert the torch model to a TorchScript model on target device
model = model.eval().to(target_device)
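As a sanity check, the multi-input branch above reduces to the following standalone computation (the two input shapes and batch scales are illustrative):

# hypothetical two-input configuration
input_shape = [[1, 3, 64, 64], [1, 1, 64, 64]]
dynamic_batchsize = [[1, 4, 8], [1, 4, 8]]  # per-input [min, opt, max] batch scales

def scale_batch_size(shape, scale_num):
    scaled = [*shape]
    scaled[0] *= scale_num
    return scaled

min_input_shape = [scale_batch_size(s, d[0]) for s, d in zip(input_shape, dynamic_batchsize)]
opt_input_shape = [scale_batch_size(s, d[1]) for s, d in zip(input_shape, dynamic_batchsize)]
max_input_shape = [scale_batch_size(s, d[2]) for s, d in zip(input_shape, dynamic_batchsize)]
assert min_input_shape == [[1, 3, 64, 64], [1, 1, 64, 64]]
assert opt_input_shape == [[4, 3, 64, 64], [4, 1, 64, 64]]
assert max_input_shape == [[8, 3, 64, 64], [8, 1, 64, 64]]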
@@ -967,8 +989,9 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int):
with torch.cuda.device(device=device):
input_placeholder = [
torch_tensorrt.Input(
min_shape=min_input_shape, opt_shape=opt_input_shape, max_shape=max_input_shape
min_shape=min_input_shape[i], opt_shape=opt_input_shape[i], max_shape=max_input_shape[i]
)
for i in range(len(min_input_shape))
]
trt_model = torch_tensorrt.compile(
ir_model,
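Putting the pieces together, a hypothetical end-to-end call exercising the new multi-input path (the toy model, shapes, and nested `dynamic_batchsize` form are assumptions read off this diff; a CUDA device and torch_tensorrt are required):

import torch
import torch.nn as nn
from monai.networks.utils import convert_to_trt

class TwoInputNet(nn.Module):  # hypothetical toy model with two forward arguments
    def __init__(self):
        super().__init__()
        self.conv_a = nn.Conv2d(3, 4, 1)
        self.conv_b = nn.Conv2d(1, 4, 1)

    def forward(self, a, b):
        return self.conv_a(a) + self.conv_b(b)

trt_model = convert_to_trt(
    TwoInputNet(),
    precision="fp16",
    input_shape=[[1, 3, 64, 64], [1, 1, 64, 64]],  # one shape per forward argument
    dynamic_batchsize=[[1, 4, 8], [1, 4, 8]],      # per-input [min, opt, max] batch scales
    use_trace=True,
    device=0,
)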
3 changes: 3 additions & 0 deletions monai/transforms/__init__.py
@@ -553,6 +553,9 @@
EnsureTyped,
EnsureTypeD,
EnsureTypeDict,
ExtendSubKeysd,
ExtendSubKeysD,
ExtendSubKeysDict,
FgBgToIndicesd,
FgBgToIndicesD,
FgBgToIndicesDict,
42 changes: 42 additions & 0 deletions monai/transforms/utility/dictionary.py
@@ -112,6 +112,9 @@
"EnsureTypeD",
"EnsureTypeDict",
"EnsureTyped",
"ExtendSubKeysD",
"ExtendSubKeysDict",
"ExtendSubKeysd",
"FgBgToIndicesD",
"FgBgToIndicesDict",
"FgBgToIndicesd",
@@ -828,6 +831,44 @@ def __call__(self, data):
return d


class ExtendSubKeysd(MapTransform):
"""
If the value of an item is a list of tensors, map each element of the list to a name, turning the list into
a dictionary: {"pred": [tensor1, tensor2]} --> {"pred": {"mapname1": tensor1, "mapname2": tensor2}}

Args:
keys: keys of the corresponding items to be extended.
map_names: the names to assign to the list elements. Must be at least as long as the list.
prefix: optional prefix to prepend to each map name. By default no prefix is added.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self, keys: KeysCollection, map_names: list[Hashable] | None = None, prefix: str | None = None
) -> None:
super().__init__(keys)
self.map_names = map_names
self.prefix = prefix

def __call__(self, data):
d = dict(data)
for key in self.key_iterator(d):
tensor_list = d[key]
map_names_size = len(self.map_names) if self.map_names is not None else 0
tensor_list_size = len(tensor_list)
if map_names_size < tensor_list_size:
raise ValueError(
f"The number of map names ({map_names_size}) must be at least the number of list elements ({tensor_list_size})."
)
# use a local copy so repeated calls do not truncate or re-prefix self.map_names
map_names = self.map_names[:tensor_list_size]
map_names = [f"{self.prefix}_{x}" for x in map_names] if self.prefix else map_names
extend_dict = dict(zip(map_names, tensor_list))

d[key] = extend_dict
return d
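
A minimal usage sketch of the transform above (the key and map names are illustrative):

import torch
from monai.transforms import ExtendSubKeysd

xform = ExtendSubKeysd(
    keys="pred",
    map_names=["nucleus_prediction", "horizontal_vertical", "type_prediction"],
)
data = {"pred": [torch.rand(1, 2, 8, 8), torch.rand(1, 2, 8, 8)]}
out = xform(data)
# surplus names are dropped:
# out["pred"] == {"nucleus_prediction": <tensor>, "horizontal_vertical": <tensor>}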


class SqueezeDimd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.SqueezeDim`.
@@ -1868,3 +1909,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
RandCuCIMD = RandCuCIMDict = RandCuCIMd
AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd
ExtendSubKeysD = ExtendSubKeysDict = ExtendSubKeysd