Skip to content

Commit

Permalink
Add ResNet backbones for FlexibleUNet (Project-MONAI#7571)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#7570.

### Description

Add ResNet backbones (with option to use pretrained Med3D weights) for
FlexibleUNet.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Konstantin Sukharev <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
k-sukharev and KumoLiu authored Apr 23, 2024
1 parent ec6aa33 commit dc58e5c
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 117 deletions.
5 changes: 5 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,11 @@ Nets
.. autoclass:: ResNet
:members:

`ResNetFeatures`
~~~~~~~~~~~~~~~~
.. autoclass:: ResNetFeatures
:members:

`SENet`
~~~~~~~
.. autoclass:: SENet
Expand Down
2 changes: 2 additions & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
ResNet,
ResNetBlock,
ResNetBottleneck,
ResNetEncoder,
ResNetFeatures,
get_medicalnet_pretrained_resnet_args,
get_pretrained_resnet_medicalnet,
resnet10,
Expand Down
13 changes: 8 additions & 5 deletions monai/networks/nets/flexible_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from monai.networks.layers.utils import get_act_layer
from monai.networks.nets import EfficientNetEncoder
from monai.networks.nets.basic_unet import UpCat
from monai.networks.nets.resnet import ResNetEncoder
from monai.utils import InterpolateMode, optional_import

__all__ = ["FlexibleUNet", "FlexUNet", "FLEXUNET_BACKBONE", "FlexUNetEncoderRegister"]
Expand Down Expand Up @@ -78,6 +79,7 @@ def register_class(self, name: type[Any] | str):

FLEXUNET_BACKBONE = FlexUNetEncoderRegister()
FLEXUNET_BACKBONE.register_class(EfficientNetEncoder)
FLEXUNET_BACKBONE.register_class(ResNetEncoder)


class UNetDecoder(nn.Module):
Expand Down Expand Up @@ -238,7 +240,7 @@ def __init__(
) -> None:
"""
A flexible implement of UNet, in which the backbone/encoder can be replaced with
any efficient network. Currently the input must have a 2 or 3 spatial dimension
any efficient or residual network. Currently the input must have a 2 or 3 spatial dimension
and the spatial size of each dimension must be a multiple of 32 if is_pad parameter
is False.
Please notice each output of backbone must be 2x downsample in spatial dimension
Expand All @@ -248,10 +250,11 @@ def __init__(
Args:
in_channels: number of input channels.
out_channels: number of output channels.
backbone: name of backbones to initialize, only support efficientnet right now,
can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2].
pretrained: whether to initialize pretrained ImageNet weights, only available
for spatial_dims=2 and batch norm is used, default to False.
backbone: name of backbones to initialize, only support efficientnet and resnet right now,
can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2, resnet10, ..., resnet200].
pretrained: whether to initialize pretrained weights. ImageNet weights are available for efficient networks
if spatial_dims=2 and batch norm is used. MedicalNet weights are available for residual networks
if spatial_dims=3 and in_channels=1. Default to False.
decoder_channels: number of output channels for all feature maps in decoder.
`len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default
to (256, 128, 64, 32, 16).
Expand Down
145 changes: 143 additions & 2 deletions monai/networks/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
import torch.nn as nn

from monai.networks.blocks.encoder import BaseEncoder
from monai.networks.layers.factories import Conv, Norm, Pool
from monai.networks.layers.utils import get_pool_layer
from monai.utils import ensure_tuple_rep
Expand All @@ -45,6 +46,19 @@
"resnet200",
]


resnet_params = {
# model_name: (block, layers, shortcut_type, bias_downsample, datasets23)
"resnet10": ("basic", [1, 1, 1, 1], "B", False, True),
"resnet18": ("basic", [2, 2, 2, 2], "A", True, True),
"resnet34": ("basic", [3, 4, 6, 3], "A", True, True),
"resnet50": ("bottleneck", [3, 4, 6, 3], "B", False, True),
"resnet101": ("bottleneck", [3, 4, 23, 3], "B", False, False),
"resnet152": ("bottleneck", [3, 8, 36, 3], "B", False, False),
"resnet200": ("bottleneck", [3, 24, 36, 3], "B", False, False),
}


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -335,6 +349,120 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class ResNetFeatures(ResNet):

def __init__(self, model_name: str, pretrained: bool = True, spatial_dims: int = 3, in_channels: int = 1) -> None:
"""Initialize resnet18 to resnet200 models as a backbone, the backbone can be used as an encoder for
segmentation and objection models.
Compared with the class `ResNet`, the only different place is the forward function.
Args:
model_name: name of model to initialize, can be from [resnet10, ..., resnet200].
pretrained: whether to initialize pretrained MedicalNet weights,
only available for spatial_dims=3 and in_channels=1.
spatial_dims: number of spatial dimensions of the input image.
in_channels: number of input channels for first convolutional layer.
"""
if model_name not in resnet_params:
model_name_string = ", ".join(resnet_params.keys())
raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ")

block, layers, shortcut_type, bias_downsample, datasets23 = resnet_params[model_name]

super().__init__(
block=block,
layers=layers,
block_inplanes=get_inplanes(),
spatial_dims=spatial_dims,
n_input_channels=in_channels,
conv1_t_stride=2,
shortcut_type=shortcut_type,
feed_forward=False,
bias_downsample=bias_downsample,
)
if pretrained:
if spatial_dims == 3 and in_channels == 1:
_load_state_dict(self, model_name, datasets23=datasets23)
else:
raise ValueError("Pretrained resnet models are only available for in_channels=1 and spatial_dims=3.")

def forward(self, inputs: torch.Tensor):
"""
Args:
inputs: input should have spatially N dimensions
``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`.
Returns:
a list of torch Tensors.
"""
x = self.conv1(inputs)
x = self.bn1(x)
x = self.relu(x)

features = []
features.append(x)

if not self.no_max_pool:
x = self.maxpool(x)

x = self.layer1(x)
features.append(x)

x = self.layer2(x)
features.append(x)

x = self.layer3(x)
features.append(x)

x = self.layer4(x)
features.append(x)

return features


class ResNetEncoder(ResNetFeatures, BaseEncoder):
"""Wrap the original resnet to an encoder for flexible-unet."""

backbone_names = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"]

@classmethod
def get_encoder_parameters(cls) -> list[dict]:
"""Get the initialization parameter for resnet backbones."""
parameter_list = []
for backbone_name in cls.backbone_names:
parameter_list.append(
{"model_name": backbone_name, "pretrained": True, "spatial_dims": 3, "in_channels": 1}
)
return parameter_list

@classmethod
def num_channels_per_output(cls) -> list[tuple[int, ...]]:
"""Get number of resnet backbone output feature maps channel."""
return [
(64, 64, 128, 256, 512),
(64, 64, 128, 256, 512),
(64, 64, 128, 256, 512),
(64, 256, 512, 1024, 2048),
(64, 256, 512, 1024, 2048),
(64, 256, 512, 1024, 2048),
(64, 256, 512, 1024, 2048),
]

@classmethod
def num_outputs(cls) -> list[int]:
"""Get number of resnet backbone output feature maps.
Since every backbone contains the same 5 output feature maps, the number list should be `[5] * 7`.
"""
return [5] * 7

@classmethod
def get_encoder_names(cls) -> list[str]:
"""Get names of resnet backbones."""
return cls.backbone_names


def _resnet(
arch: str,
block: type[ResNetBlock | ResNetBottleneck],
Expand Down Expand Up @@ -477,7 +605,7 @@ def resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->

def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True):
"""
Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet
Download resnet pretrained weights from https://huggingface.co/TencentMedicalNet
Args:
resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200
Expand Down Expand Up @@ -533,11 +661,24 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat
def get_medicalnet_pretrained_resnet_args(resnet_depth: int):
"""
Return correct shortcut_type and bias_downsample
for pretrained MedicalNet weights according to resnet depth
for pretrained MedicalNet weights according to resnet depth.
"""
# After testing
# False: 10, 50, 101, 152, 200
# Any: 18, 34
bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34
shortcut_type = "A" if resnet_depth in [18, 34] else "B"
return bias_downsample, shortcut_type


def _load_state_dict(model: nn.Module, model_name: str, datasets23: bool = True) -> None:
search_res = re.search(r"resnet(\d+)", model_name)
if search_res:
resnet_depth = int(search_res.group(1))
datasets23 = model_name.endswith("_23datasets")
else:
raise ValueError("model_name argument should contain resnet depth. Example: resnet18 or resnet18_23datasets.")

model_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device="cpu", datasets23=datasets23)
model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()}
model.load_state_dict(model_state_dict)
Loading

0 comments on commit dc58e5c

Please sign in to comment.