3938 enhance milmodel for torchscript #3939

Draft
wants to merge 8 commits into base branch dev
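The underlying issue (#3938) is that the att_trans_pyramid mode captured intermediate ResNet features through forward hooks writing into a plain Python dict (self.extra_outputs), a pattern torch.jit.script cannot compile; this draft replaces it with an explicit TransPyramidModule whose forward returns everything it needs. A toy illustration of the two patterns (illustrative code only, not part of the diff):

import torch
import torch.nn as nn

class HookBased(nn.Module):
    """Mirrors the old pattern: a hook stashes an intermediate tensor in a dict."""

    def __init__(self) -> None:
        super().__init__()
        self.body = nn.Linear(4, 4)
        self.extra_outputs: dict = {}

        def hook(module, inp, out):
            self.extra_outputs["body"] = out

        self.body.register_forward_hook(hook)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.body(x)
        return y + self.extra_outputs["body"].mean()

class Explicit(nn.Module):
    """Mirrors the new pattern: the intermediate tensor stays inside forward()."""

    def __init__(self) -> None:
        super().__init__()
        self.body = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feat = self.body(x)        # intermediate feature handled explicitly
        return feat + feat.mean()  # no hook, no side-channel dict

torch.jit.script(Explicit())       # scripts cleanly
# torch.jit.script(HookBased())    # expected to fail: the un-annotated hook and the dict
#                                  # side effect are what blocked scripting of MILModel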
monai/networks/nets/milmodel.py (172 changes: 99 additions & 73 deletions)
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Union, cast
from typing import Optional, Union

import torch
import torch.nn as nn
@@ -78,20 +78,15 @@ def __init__(
nfc = net.fc.in_features # save the number of final features
net.fc = torch.nn.Identity() # remove final linear layer

self.extra_outputs = {} # type: Dict[str, torch.Tensor]

if mil_mode == "att_trans_pyramid":
# register hooks to capture outputs of intermediate layers
def forward_hook(layer_name):
def hook(module, input, output):
self.extra_outputs[layer_name] = output

return hook

net.layer1.register_forward_hook(forward_hook("layer1"))
net.layer2.register_forward_hook(forward_hook("layer2"))
net.layer3.register_forward_hook(forward_hook("layer3"))
net.layer4.register_forward_hook(forward_hook("layer4"))
nfc = nfc + 256
self.trans_pyramid_module = TransPyramidModule(
num_classes=num_classes,
backbone=net,
trans_blocks=trans_blocks,
trans_dropout=trans_dropout,
nfc=nfc,
)

elif isinstance(backbone, str):

@@ -124,7 +119,7 @@ def hook(module, input, output):
if backbone is not None and mil_mode not in ["mean", "max", "att", "att_trans"]:
raise ValueError("Custom backbone is not supported for the mode:" + str(mil_mode))

if self.mil_mode in ["mean", "max"]:
if self.mil_mode in ["mean", "max", "att_trans_pyramid"]:
pass
elif self.mil_mode == "att":
self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))
@@ -134,47 +129,15 @@ def hook(module, input, output):
self.transformer = nn.TransformerEncoder(transformer, num_layers=trans_blocks)
self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

elif self.mil_mode == "att_trans_pyramid":

transformer_list = nn.ModuleList(
[
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), num_layers=trans_blocks
),
nn.Sequential(
nn.Linear(768, 256),
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout),
num_layers=trans_blocks,
),
),
nn.Sequential(
nn.Linear(1280, 256),
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout),
num_layers=trans_blocks,
),
),
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=2304, nhead=8, dropout=trans_dropout),
num_layers=trans_blocks,
),
]
)
self.transformer = transformer_list
nfc = nfc + 256
self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

else:
raise ValueError("Unsupported mil_mode: " + str(mil_mode))

self.myfc = nn.Linear(nfc, num_classes)
if not hasattr(self, "myfc"):
self.myfc = nn.Linear(nfc, num_classes)
self.net = net

def calc_head(self, x: torch.Tensor) -> torch.Tensor:

sh = x.shape

if self.mil_mode == "mean":
x = self.myfc(x)
x = torch.mean(x, dim=1)
@@ -203,42 +166,105 @@ def calc_head(self, x: torch.Tensor) -> torch.Tensor:

x = self.myfc(x)

elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None:
else:
raise ValueError("Wrong model mode" + str(self.mil_mode))

l1 = torch.mean(self.extra_outputs["layer1"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
l2 = torch.mean(self.extra_outputs["layer2"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
l3 = torch.mean(self.extra_outputs["layer3"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
l4 = torch.mean(self.extra_outputs["layer4"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
return x

transformer_list = cast(nn.ModuleList, self.transformer)
def forward(self, x: torch.Tensor, no_head: bool = False) -> torch.Tensor:

x = transformer_list[0](l1)
x = transformer_list[1](torch.cat((x, l2), dim=2))
x = transformer_list[2](torch.cat((x, l3), dim=2))
x = transformer_list[3](torch.cat((x, l4), dim=2))
sh = x.shape
x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4])
if hasattr(self, "trans_pyramid_module"):
batch, channel = sh[0], sh[1]
x = self.trans_pyramid_module(x, batch=batch, channel=channel, no_head=no_head)
else:
x = self.net(x)
x = x.reshape(sh[0], sh[1], -1)

x = x.permute(1, 0, 2)
if not no_head:
x = self.calc_head(x)

a = self.attention(x)
a = torch.softmax(a, dim=1)
x = torch.sum(x * a, dim=1)
return x

x = self.myfc(x)

else:
raise ValueError("Wrong model mode" + str(self.mil_mode))
class TransPyramidModule(nn.Module):
def __init__(
self,
num_classes: int,
backbone: nn.Module,
trans_blocks: int,
trans_dropout: float,
nfc: int,
) -> None:

return x
super().__init__()

def forward(self, x: torch.Tensor, no_head: bool = False) -> torch.Tensor:
self.backbone = backbone
transformer_list = nn.ModuleList(
[
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), num_layers=trans_blocks
),
nn.Sequential(
nn.Linear(768, 256),
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout),
num_layers=trans_blocks,
),
),
nn.Sequential(
nn.Linear(1280, 256),
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout),
num_layers=trans_blocks,
),
),
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=2304, nhead=8, dropout=trans_dropout),
num_layers=trans_blocks,
),
]
)
self.transformer = transformer_list
self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))
self.myfc = nn.Linear(nfc, num_classes)

sh = x.shape
x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4])
def forward(self, x: torch.Tensor, batch: int, channel: int, no_head: bool = False):

x = self.backbone.conv1(x) # type: ignore
x = self.backbone.bn1(x) # type: ignore
x = self.backbone.relu(x) # type: ignore
x = self.backbone.maxpool(x) # type: ignore

x = self.net(x)
x = x.reshape(sh[0], sh[1], -1)
x_l1 = self.backbone.layer1(x) # type: ignore
x_l2 = self.backbone.layer2(x_l1) # type: ignore
x_l3 = self.backbone.layer3(x_l2) # type: ignore
x_l4 = self.backbone.layer4(x_l3) # type: ignore

x = self.backbone.avgpool(x_l4) # type: ignore
x = torch.flatten(x, 1)
x = self.backbone.fc(x) # type: ignore

x = x.reshape(batch, channel, -1)

if not no_head:
x = self.calc_head(x)
l1 = torch.mean(x_l1, dim=(2, 3)).reshape(batch, channel, -1).permute(1, 0, 2)
l2 = torch.mean(x_l2, dim=(2, 3)).reshape(batch, channel, -1).permute(1, 0, 2)
l3 = torch.mean(x_l3, dim=(2, 3)).reshape(batch, channel, -1).permute(1, 0, 2)
l4 = torch.mean(x_l4, dim=(2, 3)).reshape(batch, channel, -1).permute(1, 0, 2)

x = self.transformer[0](l1)
x = self.transformer[1](torch.cat((x, l2), dim=2))
x = self.transformer[2](torch.cat((x, l3), dim=2))
x = self.transformer[3](torch.cat((x, l4), dim=2))

x = x.permute(1, 0, 2)

a = self.attention(x)
a = torch.softmax(a, dim=1)
x = torch.sum(x * a, dim=1)

x = self.myfc(x)

return x
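With the pyramid logic expressed as an ordinary submodule rather than hook state, the intent of this draft is that the whole model can be scripted. A minimal sketch of the expected usage on this branch (argument values and the input shape are illustrative; whether scripting already succeeds is exactly what the draft is working toward):

import torch
from monai.networks.nets import MILModel

# pretrained=False avoids downloading ResNet50 weights; all values are illustrative
model = MILModel(num_classes=2, mil_mode="att_trans_pyramid", pretrained=False).eval()

scripted = torch.jit.script(model)    # the goal of the PR: no hooks left to block scripting
bag = torch.randn(1, 2, 3, 224, 224)  # (batch, num_instances, channels, H, W)
with torch.no_grad():
    logits = scripted(bag)
print(logits.shape)                   # expected: torch.Size([1, 2])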
tests/test_milmodel.py (11 changes: 6 additions & 5 deletions)
@@ -80,11 +80,12 @@ def test_ill_args(self):
mil_mode="att_trans_pyramid",
)

def test_script(self):
input_param, input_shape, expected_shape = TEST_CASE_MILMODEL[0]
net = MILModel(**input_param)
test_data = torch.randn(input_shape, dtype=torch.float)
test_script_save(net, test_data)
@parameterized.expand(TEST_CASE_MILMODEL)
def test_script(self, input_param, input_shape, expected_shape):
if "mil_mode" in input_param.keys():
net = MILModel(**input_param)
test_data = torch.randn(input_shape, dtype=torch.float)
test_script_save(net, test_data)


if __name__ == "__main__":
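For context, test_script_save (imported from tests.utils) is the helper the new parameterized test runs for every case that sets mil_mode. A rough sketch of the kind of round trip such a helper performs, written as an assumption rather than the verbatim implementation:

import tempfile

import torch

def script_round_trip(net: torch.nn.Module, data: torch.Tensor) -> None:
    """Hypothetical helper: script, save, reload, and compare against eager output."""
    net.eval()
    scripted = torch.jit.script(net)
    with tempfile.TemporaryDirectory() as tmp:
        path = f"{tmp}/net.ts"
        torch.jit.save(scripted, path)
        reloaded = torch.jit.load(path)
    with torch.no_grad():
        expected = net(data)
        actual = reloaded(data)
    torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)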