diff --git a/docs/source/networks.rst b/docs/source/networks.rst index b59c8af5fc..249375dfc1 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -491,6 +491,11 @@ Nets .. autoclass:: ResNet :members: +`ResNetFeatures` +~~~~~~~~~~~~~~~~ +.. autoclass:: ResNetFeatures + :members: + `SENet` ~~~~~~~ .. autoclass:: SENet diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 9247aaee85..de5d1adc7e 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -59,6 +59,8 @@ ResNet, ResNetBlock, ResNetBottleneck, + ResNetEncoder, + ResNetFeatures, get_medicalnet_pretrained_resnet_args, get_pretrained_resnet_medicalnet, resnet10, diff --git a/monai/networks/nets/flexible_unet.py b/monai/networks/nets/flexible_unet.py index ac2124b5f9..c27b0fc17b 100644 --- a/monai/networks/nets/flexible_unet.py +++ b/monai/networks/nets/flexible_unet.py @@ -24,6 +24,7 @@ from monai.networks.layers.utils import get_act_layer from monai.networks.nets import EfficientNetEncoder from monai.networks.nets.basic_unet import UpCat +from monai.networks.nets.resnet import ResNetEncoder from monai.utils import InterpolateMode, optional_import __all__ = ["FlexibleUNet", "FlexUNet", "FLEXUNET_BACKBONE", "FlexUNetEncoderRegister"] @@ -78,6 +79,7 @@ def register_class(self, name: type[Any] | str): FLEXUNET_BACKBONE = FlexUNetEncoderRegister() FLEXUNET_BACKBONE.register_class(EfficientNetEncoder) +FLEXUNET_BACKBONE.register_class(ResNetEncoder) class UNetDecoder(nn.Module): @@ -238,7 +240,7 @@ def __init__( ) -> None: """ A flexible implement of UNet, in which the backbone/encoder can be replaced with - any efficient network. Currently the input must have a 2 or 3 spatial dimension + any efficient or residual network. Currently the input must have a 2 or 3 spatial dimension and the spatial size of each dimension must be a multiple of 32 if is_pad parameter is False. Please notice each output of backbone must be 2x downsample in spatial dimension @@ -248,10 +250,11 @@ def __init__( Args: in_channels: number of input channels. out_channels: number of output channels. - backbone: name of backbones to initialize, only support efficientnet right now, - can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2]. - pretrained: whether to initialize pretrained ImageNet weights, only available - for spatial_dims=2 and batch norm is used, default to False. + backbone: name of backbones to initialize, only support efficientnet and resnet right now, + can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2, resnet10, ..., resnet200]. + pretrained: whether to initialize pretrained weights. ImageNet weights are available for efficient networks + if spatial_dims=2 and batch norm is used. MedicalNet weights are available for residual networks + if spatial_dims=3 and in_channels=1. Default to False. decoder_channels: number of output channels for all feature maps in decoder. `len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default to (256, 128, 64, 32, 16). diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 34a4b7057e..99975271da 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -21,6 +21,7 @@ import torch import torch.nn as nn +from monai.networks.blocks.encoder import BaseEncoder from monai.networks.layers.factories import Conv, Norm, Pool from monai.networks.layers.utils import get_pool_layer from monai.utils import ensure_tuple_rep @@ -45,6 +46,19 @@ "resnet200", ] + +resnet_params = { + # model_name: (block, layers, shortcut_type, bias_downsample, datasets23) + "resnet10": ("basic", [1, 1, 1, 1], "B", False, True), + "resnet18": ("basic", [2, 2, 2, 2], "A", True, True), + "resnet34": ("basic", [3, 4, 6, 3], "A", True, True), + "resnet50": ("bottleneck", [3, 4, 6, 3], "B", False, True), + "resnet101": ("bottleneck", [3, 4, 23, 3], "B", False, False), + "resnet152": ("bottleneck", [3, 8, 36, 3], "B", False, False), + "resnet200": ("bottleneck", [3, 24, 36, 3], "B", False, False), +} + + logger = logging.getLogger(__name__) @@ -335,6 +349,120 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class ResNetFeatures(ResNet): + + def __init__(self, model_name: str, pretrained: bool = True, spatial_dims: int = 3, in_channels: int = 1) -> None: + """Initialize resnet18 to resnet200 models as a backbone, the backbone can be used as an encoder for + segmentation and objection models. + + Compared with the class `ResNet`, the only different place is the forward function. + + Args: + model_name: name of model to initialize, can be from [resnet10, ..., resnet200]. + pretrained: whether to initialize pretrained MedicalNet weights, + only available for spatial_dims=3 and in_channels=1. + spatial_dims: number of spatial dimensions of the input image. + in_channels: number of input channels for first convolutional layer. + """ + if model_name not in resnet_params: + model_name_string = ", ".join(resnet_params.keys()) + raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ") + + block, layers, shortcut_type, bias_downsample, datasets23 = resnet_params[model_name] + + super().__init__( + block=block, + layers=layers, + block_inplanes=get_inplanes(), + spatial_dims=spatial_dims, + n_input_channels=in_channels, + conv1_t_stride=2, + shortcut_type=shortcut_type, + feed_forward=False, + bias_downsample=bias_downsample, + ) + if pretrained: + if spatial_dims == 3 and in_channels == 1: + _load_state_dict(self, model_name, datasets23=datasets23) + else: + raise ValueError("Pretrained resnet models are only available for in_channels=1 and spatial_dims=3.") + + def forward(self, inputs: torch.Tensor): + """ + Args: + inputs: input should have spatially N dimensions + ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`. + + Returns: + a list of torch Tensors. + """ + x = self.conv1(inputs) + x = self.bn1(x) + x = self.relu(x) + + features = [] + features.append(x) + + if not self.no_max_pool: + x = self.maxpool(x) + + x = self.layer1(x) + features.append(x) + + x = self.layer2(x) + features.append(x) + + x = self.layer3(x) + features.append(x) + + x = self.layer4(x) + features.append(x) + + return features + + +class ResNetEncoder(ResNetFeatures, BaseEncoder): + """Wrap the original resnet to an encoder for flexible-unet.""" + + backbone_names = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] + + @classmethod + def get_encoder_parameters(cls) -> list[dict]: + """Get the initialization parameter for resnet backbones.""" + parameter_list = [] + for backbone_name in cls.backbone_names: + parameter_list.append( + {"model_name": backbone_name, "pretrained": True, "spatial_dims": 3, "in_channels": 1} + ) + return parameter_list + + @classmethod + def num_channels_per_output(cls) -> list[tuple[int, ...]]: + """Get number of resnet backbone output feature maps channel.""" + return [ + (64, 64, 128, 256, 512), + (64, 64, 128, 256, 512), + (64, 64, 128, 256, 512), + (64, 256, 512, 1024, 2048), + (64, 256, 512, 1024, 2048), + (64, 256, 512, 1024, 2048), + (64, 256, 512, 1024, 2048), + ] + + @classmethod + def num_outputs(cls) -> list[int]: + """Get number of resnet backbone output feature maps. + + Since every backbone contains the same 5 output feature maps, the number list should be `[5] * 7`. + """ + return [5] * 7 + + @classmethod + def get_encoder_names(cls) -> list[str]: + """Get names of resnet backbones.""" + return cls.backbone_names + + def _resnet( arch: str, block: type[ResNetBlock | ResNetBottleneck], @@ -477,7 +605,7 @@ def resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True): """ - Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet + Download resnet pretrained weights from https://huggingface.co/TencentMedicalNet Args: resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200 @@ -533,7 +661,7 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat def get_medicalnet_pretrained_resnet_args(resnet_depth: int): """ Return correct shortcut_type and bias_downsample - for pretrained MedicalNet weights according to resnet depth + for pretrained MedicalNet weights according to resnet depth. """ # After testing # False: 10, 50, 101, 152, 200 @@ -541,3 +669,16 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int): bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 shortcut_type = "A" if resnet_depth in [18, 34] else "B" return bias_downsample, shortcut_type + + +def _load_state_dict(model: nn.Module, model_name: str, datasets23: bool = True) -> None: + search_res = re.search(r"resnet(\d+)", model_name) + if search_res: + resnet_depth = int(search_res.group(1)) + datasets23 = model_name.endswith("_23datasets") + else: + raise ValueError("model_name argument should contain resnet depth. Example: resnet18 or resnet18_23datasets.") + + model_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device="cpu", datasets23=datasets23) + model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()} + model.load_state_dict(model_state_dict) diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py index 404855c9a8..42baa28b71 100644 --- a/tests/test_flexible_unet.py +++ b/tests/test_flexible_unet.py @@ -23,12 +23,11 @@ EfficientNetBNFeatures, FlexibleUNet, FlexUNetEncoderRegister, - ResNet, - ResNetBlock, - ResNetBottleneck, + ResNetEncoder, + ResNetFeatures, ) from monai.utils import optional_import -from tests.utils import skip_if_downloading_fails, skip_if_quick +from tests.utils import SkipIfNoModule, skip_if_downloading_fails, skip_if_quick torchvision, has_torchvision = optional_import("torchvision") PIL, has_pil = optional_import("PIL") @@ -59,101 +58,6 @@ def get_encoder_names(cls): return ["encoder_wrong_channels", "encoder_no_param1", "encoder_no_param2", "encoder_no_param3"] -class ResNetEncoder(ResNet, BaseEncoder): - backbone_names = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] - output_feature_channels = [(64, 128, 256, 512)] * 3 + [(256, 512, 1024, 2048)] * 4 - parameter_layers = [ - [1, 1, 1, 1], - [2, 2, 2, 2], - [3, 4, 6, 3], - [3, 4, 6, 3], - [3, 4, 23, 3], - [3, 8, 36, 3], - [3, 24, 36, 3], - ] - - def __init__(self, in_channels, pretrained, **kargs): - super().__init__(**kargs, n_input_channels=in_channels) - if pretrained: - # Author of paper zipped the state_dict on googledrive, - # so would need to download, unzip and read (2.8gb file for a ~150mb state dict). - # Would like to load dict from url but need somewhere to save the state dicts. - raise NotImplementedError( - "Currently not implemented. You need to manually download weights provided by the paper's author" - " and load then to the model with `state_dict`. See https://github.com/Tencent/MedicalNet" - ) - - @staticmethod - def get_inplanes(): - return [64, 128, 256, 512] - - @classmethod - def get_encoder_parameters(cls) -> list[dict]: - """ - Get parameter list to initialize encoder networks. - Each parameter dict must have `spatial_dims`, `in_channels` - and `pretrained` parameters. - """ - parameter_list = [] - res_type: type[ResNetBlock] | type[ResNetBottleneck] - for backbone in range(len(cls.backbone_names)): - if backbone < 3: - res_type = ResNetBlock - else: - res_type = ResNetBottleneck - parameter_list.append( - { - "block": res_type, - "layers": cls.parameter_layers[backbone], - "block_inplanes": ResNetEncoder.get_inplanes(), - "spatial_dims": 2, - "in_channels": 3, - "pretrained": False, - } - ) - return parameter_list - - @classmethod - def num_channels_per_output(cls): - """ - Get number of output features' channel. - """ - return cls.output_feature_channels - - @classmethod - def num_outputs(cls): - """ - Get number of output feature. - """ - return [4] * 7 - - @classmethod - def get_encoder_names(cls): - """ - Get the name string of backbones which will be used to initialize flexible unet. - """ - return cls.backbone_names - - def forward(self, x: torch.Tensor): - feature_list = [] - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - if not self.no_max_pool: - x = self.maxpool(x) - x = self.layer1(x) - feature_list.append(x) - x = self.layer2(x) - feature_list.append(x) - x = self.layer3(x) - feature_list.append(x) - x = self.layer4(x) - feature_list.append(x) - - return feature_list - - -FLEXUNET_BACKBONE.register_class(ResNetEncoder) FLEXUNET_BACKBONE.register_class(DummyEncoder) @@ -204,9 +108,7 @@ def make_shape_cases( def make_error_case(): - error_dummy_backbones = DummyEncoder.get_encoder_names() - error_resnet_backbones = ResNetEncoder.get_encoder_names() - error_backbones = error_dummy_backbones + error_resnet_backbones + error_backbones = DummyEncoder.get_encoder_names() error_param_list = [] for backbone in error_backbones: error_param_list.append( @@ -232,7 +134,7 @@ def make_error_case(): norm="instance", ) CASES_3D = make_shape_cases( - models=[SEL_MODELS[0]], + models=[SEL_MODELS[0], SEL_MODELS[2]], spatial_dims=[3], batches=[1], pretrained=[False], @@ -345,6 +247,7 @@ def make_error_case(): "spatial_dims": 2, "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}), }, + EfficientNetBNFeatures, { "in_channels": 3, "num_classes": 10, @@ -354,7 +257,20 @@ def make_error_case(): "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}), }, ["_conv_stem.weight"], - ) + ), + ( + { + "in_channels": 1, + "out_channels": 10, + "backbone": SEL_MODELS[2], + "pretrained": True, + "spatial_dims": 3, + "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}), + }, + ResNetFeatures, + {"model_name": SEL_MODELS[2], "pretrained": True, "spatial_dims": 3, "in_channels": 1}, + ["conv1.weight"], + ), ] CASE_ERRORS = make_error_case() @@ -363,6 +279,7 @@ def make_error_case(): CASE_REGISTER_ENCODER = ["EfficientNetEncoder", "monai.networks.nets.EfficientNetEncoder"] +@SkipIfNoModule("hf_hub_download") @skip_if_quick class TestFLEXIBLEUNET(unittest.TestCase): @@ -381,19 +298,19 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) @parameterized.expand(CASES_PRETRAIN) - def test_pretrain(self, input_param, efficient_input_param, weight_list): + def test_pretrain(self, flexunet_input_param, feature_extractor_class, feature_extractor_input_param, weight_list): device = "cuda" if torch.cuda.is_available() else "cpu" with skip_if_downloading_fails(): - net = FlexibleUNet(**input_param).to(device) + net = FlexibleUNet(**flexunet_input_param).to(device) with skip_if_downloading_fails(): - eff_net = EfficientNetBNFeatures(**efficient_input_param).to(device) + feature_extractor_net = feature_extractor_class(**feature_extractor_input_param).to(device) for weight_name in weight_list: - if weight_name in net.encoder.state_dict() and weight_name in eff_net.state_dict(): + if weight_name in net.encoder.state_dict() and weight_name in feature_extractor_net.state_dict(): net_weight = net.encoder.state_dict()[weight_name] - download_weight = eff_net.state_dict()[weight_name] + download_weight = feature_extractor_net.state_dict()[weight_name] weight_diff = torch.abs(net_weight - download_weight) diff_sum = torch.sum(weight_diff) # check if a weight in weight_list equals to the downloaded weight. diff --git a/tests/test_resnet.py b/tests/test_resnet.py index ad1aad8fc6..449edba4bf 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -24,6 +24,7 @@ from monai.networks import eval_mode from monai.networks.nets import ( ResNet, + ResNetFeatures, get_medicalnet_pretrained_resnet_args, get_pretrained_resnet_medicalnet, resnet10, @@ -36,7 +37,14 @@ ) from monai.networks.nets.resnet import ResNetBlock from monai.utils import optional_import -from tests.utils import equal_state_dict, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, test_script_save +from tests.utils import ( + SkipIfNoModule, + equal_state_dict, + skip_if_downloading_fails, + skip_if_no_cuda, + skip_if_quick, + test_script_save, +) if TYPE_CHECKING: import torchvision @@ -191,6 +199,15 @@ ] +CASE_EXTRACT_FEATURES = [ + ( + {"model_name": "resnet10", "pretrained": True, "spatial_dims": 3, "in_channels": 1}, + [1, 1, 64, 64, 64], + ([1, 64, 32, 32, 32], [1, 64, 16, 16, 16], [1, 128, 8, 8, 8], [1, 256, 4, 4, 4], [1, 512, 2, 2, 2]), + ) +] + + class TestResNet(unittest.TestCase): def setUp(self): @@ -270,5 +287,25 @@ def test_script(self, model, input_param, input_shape, expected_shape): test_script_save(net, test_data) +@SkipIfNoModule("hf_hub_download") +class TestExtractFeatures(unittest.TestCase): + + @parameterized.expand(CASE_EXTRACT_FEATURES) + def test_shape(self, input_param, input_shape, expected_shapes): + device = "cuda" if torch.cuda.is_available() else "cpu" + + with skip_if_downloading_fails(): + net = ResNetFeatures(**input_param).to(device) + + # run inference with random tensor + with eval_mode(net): + features = net(torch.randn(input_shape).to(device)) + + # check output shape + self.assertEqual(len(features), len(expected_shapes)) + for feature, expected_shape in zip(features, expected_shapes): + self.assertEqual(feature.shape, torch.Size(expected_shape)) + + if __name__ == "__main__": unittest.main()