From e0fdcf3c7d85f3838ac0e091b7fa8e08ea517c71 Mon Sep 17 00:00:00 2001
From: tangy5
Date: Mon, 3 Apr 2023 18:04:14 -0700
Subject: [PATCH 1/8] add text driven segmentor for controllable outputs

Signed-off-by: tangy5
---
 monai/networks/blocks/head_controller.py | 130 +++++++++++++++++++++++
 1 file changed, 130 insertions(+)
 create mode 100644 monai/networks/blocks/head_controller.py

diff --git a/monai/networks/blocks/head_controller.py b/monai/networks/blocks/head_controller.py
new file mode 100644
index 0000000000..9021ec75f5
--- /dev/null
+++ b/monai/networks/blocks/head_controller.py
@@ -0,0 +1,130 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+class HeadController(nn.Module):
+    """
+    Text-based controller for segmentation outputs, the text-driven segmentor enables optional outputs instead of
+    fixed output channels. Users can choose and control the number and name of output channels from a multi-class segmentation
+    model. This can enable incremental learning by adding new classes to an existing pre-trained model without
+    catastrophic forgetting.
+
+    Text-driven segmentor, based on: "Liu et al.,
+    CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection "
+    """
+    def __init__(
+        self,
+        task_encoding: str,
+        out_channels: int,
+        hidden_size: int = 256,
+    ) -> None:
+        """
+        Args:
+            task_encoding: the text embedding features passed. TODO: make optional
+            out_channels: number of output channels, to control text-based embedding for classes.
+            hidden_size: dimension of hidden features, compatible with different vision feature dimensions.
+        """
+
+        weight_nums, bias_nums = [], []
+        weight_nums.append(8*8)
+        weight_nums.append(8*8)
+        weight_nums.append(8*1)
+        bias_nums.append(8)
+        bias_nums.append(8)
+        bias_nums.append(1)
+        self.weight_nums = weight_nums
+        self.bias_nums = bias_nums
+
+        #TODO: parameterize basic kernel size, stride, and padding
+        self.controller = nn.Conv3d(2*hidden_size, sum(weight_nums+bias_nums), kernel_size=1, stride=1, padding=0)
+
+        self.class_num = out_channels
+
+        self.task_encoding = task_encoding
+
+        self.precls_conv = nn.Sequential(
+            nn.GroupNorm(16, 48),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(48, 8, kernel_size=1)
+        )
+        self.GAP = nn.Sequential(
+            nn.GroupNorm(16, 768),
+            nn.ReLU(inplace=True),
+            torch.nn.AdaptiveAvgPool3d((1,1,1)),
+            nn.Conv3d(768, 256, kernel_size=1, stride=1, padding=0)
+        )
+
+    def parse_dynamic_params(self, params, channels, weight_nums, bias_nums):
+        assert params.dim() == 2
+        assert len(weight_nums) == len(bias_nums)
+        assert params.size(1) == sum(weight_nums) + sum(bias_nums)
+
+        num_insts = params.size(0)
+        num_layers = len(weight_nums)
+
+        params_splits = list(torch.split_with_sizes(
+            params, weight_nums + bias_nums, dim=1
+        ))
+
+        weight_splits = params_splits[:num_layers]
+        bias_splits = params_splits[num_layers:]
+
+        for l in range(num_layers):
+            if l < num_layers - 1:
+                weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1, 1)
+                bias_splits[l] = bias_splits[l].reshape(num_insts * channels)
+            else:
+                weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1, 1)
+                bias_splits[l] = bias_splits[l].reshape(num_insts * 1)
+
+        return weight_splits, bias_splits
+
+    def heads_forward(self, features, weights, biases, num_insts):
+        assert features.dim() == 5
+        n_layers = len(weights)
+        x = features
+        for i, (w, b) in enumerate(zip(weights, biases)):
+            x = nn.function.conv3d(
+                x, w, bias=b,
+                stride=1, padding=0,
+                groups=num_insts
+            )
+            if i < n_layers - 1:
+                x = nn.function.relu(x)
+        return x
+
+    def forward(self, x):
+        x_feat = self.GAP(x)
+        b = x_feat.shape[0]
+        logits_array = []
+        for i in range(b):
+            x_cond = torch.cat([x_feat[i].unsqueeze(0).repeat(self.class_num,1,1,1,1), self.task_encoding], 1)
+            params = self.controller(x_cond)
+            params.squeeze_(-1).squeeze_(-1).squeeze_(-1)
+
+            head_inputs = self.precls_conv(out[i].unsqueeze(0))
+            head_inputs = head_inputs.repeat(self.class_num,1,1,1,1)
+            N, _, D, H, W = head_inputs.size()
+            head_inputs = head_inputs.reshape(1, -1, D, H, W)
+            weights, biases = self.parse_dynamic_params(params, 8, self.weight_nums, self.bias_nums)
+
+            logits = self.heads_forward(head_inputs, weights, biases, N)
+            logits_array.append(logits.reshape(1, -1, D, H, W))
+
+        out = torch.cat(logits_array,dim=0)
+        return out
+
+
+
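
The controller above emits one flat parameter vector per class instance; parse_dynamic_params() slices it into per-layer conv weights/biases, and heads_forward() applies them in a single grouped convolution. A minimal standalone sketch of that split-and-group mechanism (not part of the patch; shapes chosen to match the hard-coded 8 -> 8 -> 8 -> 1 head in this commit):

    import torch
    import torch.nn.functional as F

    num_insts = 2                                  # e.g., two target classes
    weight_nums, bias_nums = [8 * 8, 8 * 8, 8 * 1], [8, 8, 1]
    params = torch.randn(num_insts, sum(weight_nums) + sum(bias_nums))

    # slice the flat vector into per-layer weights and biases
    splits = list(torch.split_with_sizes(params, weight_nums + bias_nums, dim=1))
    weights, biases = splits[:3], splits[3:]
    w0 = weights[0].reshape(num_insts * 8, 8, 1, 1, 1)  # first layer: 8 -> 8 per instance
    b0 = biases[0].reshape(num_insts * 8)

    # instances are stacked along the channel axis; groups keeps them independent
    x = torch.randn(1, num_insts * 8, 4, 4, 4)
    y = F.conv3d(x, w0, bias=b0, groups=num_insts)
    print(y.shape)  # torch.Size([1, 16, 4, 4, 4])

The groups=num_insts argument is what lets every class run through its own tiny 1x1x1 conv head in one batched call.
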
From b067e57cc395dcc4454169968aea9625abe44bc3 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 4 Apr 2023 01:09:05 +0000
Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 monai/networks/blocks/head_controller.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/monai/networks/blocks/head_controller.py b/monai/networks/blocks/head_controller.py
index 9021ec75f5..c4b896f308 100644
--- a/monai/networks/blocks/head_controller.py
+++ b/monai/networks/blocks/head_controller.py
@@ -18,12 +18,12 @@ class HeadController(nn.Module):
     """
     Text-based controller for segmentation outputs, the text-driven segmentor enables optional outputs instead of
     fixed output channels. Users can choose and control the number and name of output channels from a multi-class segmentation
-    model. This can enable incremental learning by adding new classes to an existing pre-trained model without
+    model. This can enable incremental learning by adding new classes to an existing pre-trained model without
     catastrophic forgetting.
-
+
     Text-driven segmentor, based on: "Liu et al.,
     CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection "
-    """
+    """
     def __init__(
         self,
         task_encoding: str,
@@ -113,7 +113,7 @@ def forward(self, x):
             x_cond = torch.cat([x_feat[i].unsqueeze(0).repeat(self.class_num,1,1,1,1), self.task_encoding], 1)
             params = self.controller(x_cond)
             params.squeeze_(-1).squeeze_(-1).squeeze_(-1)
-
+
             head_inputs = self.precls_conv(out[i].unsqueeze(0))
             head_inputs = head_inputs.repeat(self.class_num,1,1,1,1)
             N, _, D, H, W = head_inputs.size()
@@ -122,9 +122,6 @@ def forward(self, x):
             logits = self.heads_forward(head_inputs, weights, biases, N)
             logits_array.append(logits.reshape(1, -1, D, H, W))
-
+
         out = torch.cat(logits_array,dim=0)
         return out
-
-
-

From 573c32b2b06e576a192d2250c676606b23f1aedf Mon Sep 17 00:00:00 2001
From: tangy5
Date: Tue, 4 Apr 2023 17:01:10 -0700
Subject: [PATCH 3/8] add universal model, add parameters

Signed-off-by: tangy5
---
 monai/networks/blocks/head_controller.py |  65 ++++++-----
 monai/networks/nets/universal_model.py   | 139 +++++++++++++++++++++++
 2 files changed, 173 insertions(+), 31 deletions(-)
 mode change 100644 => 100755 monai/networks/blocks/head_controller.py
 create mode 100755 monai/networks/nets/universal_model.py

diff --git a/monai/networks/blocks/head_controller.py b/monai/networks/blocks/head_controller.py
old mode 100644
new mode 100755
index c4b896f308..b5fa891c3c
--- a/monai/networks/blocks/head_controller.py
+++ b/monai/networks/blocks/head_controller.py
@@ -13,60 +13,57 @@
 
 import torch
 import torch.nn as nn
+from typing import Optional
 
 class HeadController(nn.Module):
     """
     Text-based controller for segmentation outputs, the text-driven segmentor enables optional outputs instead of
     fixed output channels. Users can choose and control the number and name of output channels from a multi-class segmentation
-    model. This can enable incremental learning by adding new classes to an existing pre-trained model without
+    model. This can enable incremental learning by adding new classes to an existing pre-trained model without
     catastrophic forgetting.
-
+
     Text-driven segmentor, based on: "Liu et al.,
     CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection "
-    """
+    """
     def __init__(
         self,
-        task_encoding: str,
         out_channels: int,
+        weight_nums: list = [64, 64, 8],
+        bias_nums: list = [8, 8, 1],
         hidden_size: int = 256,
+        task_encoding: Optional[torch.Tensor] = None,
     ) -> None:
         """
         Args:
-            task_encoding: the text embedding features passed. TODO: make optional
             out_channels: number of output channels, to control text-based embedding for classes.
+            weight_nums: weight feature size of text-driven segmentor conv layers, len(weight_nums) defines the number of layers.
+            bias_nums: bias size of text-driven segmentor conv layers, len(bias_nums) needs to be consistent with len(weight_nums).
             hidden_size: dimension of hidden features, compatible with different vision feature dimensions.
+            task_encoding: the text embedding features passed. TODO: make optional
         """
-
-        weight_nums, bias_nums = [], []
-        weight_nums.append(8*8)
-        weight_nums.append(8*8)
-        weight_nums.append(8*1)
-        bias_nums.append(8)
-        bias_nums.append(8)
-        bias_nums.append(1)
+        super().__init__()
         self.weight_nums = weight_nums
         self.bias_nums = bias_nums
 
-        #TODO: parameterize basic kernel size, stride, and padding
-        self.controller = nn.Conv3d(2*hidden_size, sum(weight_nums+bias_nums), kernel_size=1, stride=1, padding=0)
 
         self.class_num = out_channels
 
-        self.task_encoding = task_encoding
+        if task_encoding:
+            self.task_encoding = task_encoding
+            self.controller = nn.Conv3d(2*hidden_size, sum(weight_nums+bias_nums), kernel_size=1, stride=1, padding=0)
+        else:
+            self.controller = nn.Conv3d(hidden_size, sum(weight_nums+bias_nums), kernel_size=1, stride=1, padding=0)
 
         self.precls_conv = nn.Sequential(
             nn.GroupNorm(16, 48),
             nn.ReLU(inplace=True),
             nn.Conv3d(48, 8, kernel_size=1)
         )
-        self.GAP = nn.Sequential(
-            nn.GroupNorm(16, 768),
-            nn.ReLU(inplace=True),
-            torch.nn.AdaptiveAvgPool3d((1,1,1)),
-            nn.Conv3d(768, 256, kernel_size=1, stride=1, padding=0)
-        )
 
     def parse_dynamic_params(self, params, channels, weight_nums, bias_nums):
+        """
+        Text-driven segmentor with layers of conv for dynamic outputs
+        """
         assert params.dim() == 2
         assert len(weight_nums) == len(bias_nums)
         assert params.size(1) == sum(weight_nums) + sum(bias_nums)
@@ -96,24 +93,27 @@ def heads_forward(self, features, weights, biases, num_insts):
         n_layers = len(weights)
         x = features
         for i, (w, b) in enumerate(zip(weights, biases)):
-            x = nn.function.conv3d(
+            x = nn.functional.conv3d(
                 x, w, bias=b,
                 stride=1, padding=0,
                 groups=num_insts
             )
             if i < n_layers - 1:
-                x = nn.function.relu(x)
+                x = nn.functional.relu(x)
         return x
 
-    def forward(self, x):
-        x_feat = self.GAP(x)
-        b = x_feat.shape[0]
+    def forward(self, x, out, logits_options=None):
+        logits_options = range(x.shape[0]) if not isinstance(logits_options, list) else logits_options
         logits_array = []
-        for i in range(b):
-            x_cond = torch.cat([x_feat[i].unsqueeze(0).repeat(self.class_num,1,1,1,1), self.task_encoding], 1)
+        for i in logits_options:
+            if self.task_encoding:
+                x_cond = torch.cat([x[i].unsqueeze(0).repeat(self.class_num,1,1,1,1), self.task_encoding], 1)
+            else:
+                x_cond = x[i].unsqueeze(0).repeat(self.class_num,1,1,1,1)
+
             params = self.controller(x_cond)
             params.squeeze_(-1).squeeze_(-1).squeeze_(-1)
-
+
             head_inputs = self.precls_conv(out[i].unsqueeze(0))
             head_inputs = head_inputs.repeat(self.class_num,1,1,1,1)
             N, _, D, H, W = head_inputs.size()
@@ -122,6 +122,9 @@ def forward(self, x):
             head_inputs = head_inputs.reshape(1, -1, D, H, W)
             weights, biases = self.parse_dynamic_params(params, 8, self.weight_nums, self.bias_nums)
 
             logits = self.heads_forward(head_inputs, weights, biases, N)
             logits_array.append(logits.reshape(1, -1, D, H, W))
-
+
         out = torch.cat(logits_array,dim=0)
         return out
+
+
+
diff --git a/monai/networks/nets/universal_model.py b/monai/networks/nets/universal_model.py
new file mode 100755
index 0000000000..251477fd31
--- /dev/null
+++ b/monai/networks/nets/universal_model.py
@@ -0,0 +1,139 @@
+from typing import Sequence, Tuple, Type, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from torch.nn import LayerNorm
+from monai.networks.blocks.text_enbedding import TextEncoder
+from monai.networks.blocks.head_controller import HeadController
+from monai.networks.nets import SwinUNETR
+
+class SwinUNETR_backbone(SwinUNETR):
+    """
+    Universal Model uses SwinUNETR as backbone without the segmentation head based on:
+
+    "Hatamizadeh et al.,
+    Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
+    " and
+
+    "Tang et al.,
+    Self-Supervised Pre-Training of Swin Transformers for 3D Medical Image Analysis
+    "
+    """
+    def __init__(
+        self,
+        img_size: Union[Sequence[int], int],
+        in_channels: int,
+        out_channels: int,
+        depths: Sequence[int] = (2, 2, 2, 2),
+        num_heads: Sequence[int] = (3, 6, 12, 24),
+        feature_size: int = 48,
+        norm_name: Union[Tuple, str] = "instance",
+        drop_rate: float = 0.0,
+        attn_drop_rate: float = 0.0,
+        dropout_path_rate: float = 0.0,
+        normalize: bool = True,
+        use_checkpoint: bool = False,
+        spatial_dims: int = 3,
+    ):
+        super(SwinUNETR_backbone, self).__init__(img_size,in_channels,out_channels,feature_size=48)
+
+    def forward(self, x_in):
+        hidden_states_out = self.swinViT(x_in, self.normalize)
+        enc0 = self.encoder1(x_in)
+        enc1 = self.encoder2(hidden_states_out[0])
+        enc2 = self.encoder3(hidden_states_out[1])
+        enc3 = self.encoder4(hidden_states_out[2])
+        dec4 = self.encoder10(hidden_states_out[4])
+
+        dec3 = self.decoder5(dec4, hidden_states_out[3])
+        dec2 = self.decoder4(dec3, enc3)
+        dec1 = self.decoder3(dec2, enc2)
+        dec0 = self.decoder2(dec1, enc1)
+        out = self.decoder1(dec0, enc0)
+
+        return dec4, out
+
+class Universal_model(nn.Module):
+    """
+    Universal Model for organ and tumor segmentation, based on: "Liu et al.,
+    CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection "
+    """
+    def __init__(
+        self,
+        img_size,
+        in_channels,
+        out_channels,
+        bottleneck_size=768,
+        text_dim=512,
+        hidden_size=256,
+        backbone = 'swinunetr',
+        encoding = 'clip_embedding'
+    ):
+        super().__init__()
+        self.backbone_name = backbone
+        if backbone == 'swinunetr':
+            self.backbone = SwinUNETR_backbone(
+                img_size=img_size,
+                in_channels=in_channels,
+                out_channels=out_channels,
+                feature_size=48,
+                drop_rate=0.0,
+                attn_drop_rate=0.0,
+                dropout_path_rate=0.0,
+                use_checkpoint=False,
+            )
+        else:
+            raise Exception('{} backbone is not implemented, please add your own'.format(backbone))
+        self.class_num = out_channels
+
+        # text encoder
+        self.text_encoder = TextEncoder(
+            out_channels=self.class_num,
+            text_dim=text_dim,
+            hidden_size=hidden_size,
+            encoding=encoding
+        )
+
+        self.head_controller = HeadController(
+            out_channels=out_channels,
+        )
+
+        self.GAP = nn.Sequential(
+            nn.GroupNorm(16, bottleneck_size),
+            nn.ReLU(inplace=True),
+            torch.nn.AdaptiveAvgPool3d((1,1,1)),
+            nn.Conv3d(bottleneck_size, hidden_size, kernel_size=1, stride=1, padding=0)
+        )
+
+    def load_params(self, model_dict):
+        if self.backbone_name == 'swinunetr':
+            store_dict = self.backbone.state_dict()
+            for key in model_dict.keys():
+                if 'out' not in key:
+                    store_dict[key] = model_dict[key]
+
+            self.backbone.load_state_dict(store_dict)
+            print('Use swin unetr pretrained weights')
+        else:
+            raise Exception('{} backbone is not implemented, please add your own'.format(backbone))
+
+    def encoding_task(self, task_id):
+        N = task_id.shape[0]
+        task_encoding = torch.zeros(size=(N, 7))
+        for i in range(N):
+            task_encoding[i, task_id[i]]=1
+        return task_encoding.cuda()
+
+    def forward(self, x_in):
+        # get backbone feature
+        dec4, out = self.backbone(x_in)
+        # get task text encoding
+        task_encoding = self.text_encoder()
+        # text controlled outputs
+        x_feat = self.GAP(dec4)
+        out = self.head_controller(x_feat, out, task_encoding)
+
+        return out
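
With the pieces of this commit in place, the intended end-to-end usage is roughly the following — a sketch only, assuming the TextEncoder block exists at the imported path and ships pre-computed CLIP class embeddings, and that the SwinUNETR decoder returns full-resolution features as in MONAI:

    import torch
    from monai.networks.nets.universal_model import Universal_model

    model = Universal_model(
        img_size=(96, 96, 96),
        in_channels=1,
        out_channels=32,           # e.g., 32 organ/tumor classes
        backbone='swinunetr',
        encoding='clip_embedding',
    )
    x = torch.randn(1, 1, 96, 96, 96)   # one single-channel CT patch
    logits = model(x)                   # expected shape: (1, 32, 96, 96, 96)

The bottleneck feature (dec4) is pooled by GAP to a per-volume vector, combined with each class's text embedding, and turned into per-class dynamic heads that run over the full-resolution decoder output.
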
From 6bb3d12084c855c1df729c4d4c4d07e638b1d76c Mon Sep 17 00:00:00 2001
From: tangy5
Date: Tue, 4 Apr 2023 21:54:27 -0700
Subject: [PATCH 4/8] update configs

Signed-off-by: tangy5
---
 monai/networks/blocks/head_controller.py |  75 +++++++------
 monai/networks/nets/universal_model.py   | 132 +++++++++++------------
 2 files changed, 104 insertions(+), 103 deletions(-)

diff --git a/monai/networks/blocks/head_controller.py b/monai/networks/blocks/head_controller.py
index b5fa891c3c..e0f8dc9c5b 100755
--- a/monai/networks/blocks/head_controller.py
+++ b/monai/networks/blocks/head_controller.py
@@ -13,7 +13,6 @@
 
 import torch
 import torch.nn as nn
-from typing import Optional
 
 class HeadController(nn.Module):
     """
@@ -28,39 +27,44 @@ class HeadController(nn.Module):
     def __init__(
         self,
         out_channels: int,
-        weight_nums: list = [64, 64, 8],
-        bias_nums: list = [8, 8, 1],
+        feature_size: int = 48,
+        head_in_channels: int = 8,
+        head_layers: int = 3,
+        head_hidden_size: int = 8,
         hidden_size: int = 256,
-        task_encoding: Optional[torch.Tensor] = None,
+        text_encoding: bool = True,
     ) -> None:
         """
         Args:
             out_channels: number of output channels, to control text-based embedding for classes.
-            weight_nums: weight feature size of text-driven segmentor conv layers, len(weight_nums) defines the number of layers.
-            bias_nums: bias size of text-driven segmentor conv layers, len(bias_nums) needs to be consistent with len(weight_nums).
-            hidden_size: dimension of hidden features, compatible with different vision feature dimensions.
-            task_encoding: the text embedding features passed. TODO: make optional
+            feature_size: the backbone output feature size before segmentation heads.
+            head_in_channels: number of dynamic segmentor input channels.
+            head_layers: number of conv layers of the dynamic segmentor.
+            head_hidden_size: hidden feature size of the intermediate dynamic segmentor conv layers.
+            hidden_size: dimension of the backbone's bottleneck features.
+            text_encoding: whether text embedding features are passed.
         """
         super().__init__()
-        self.weight_nums = weight_nums
-        self.bias_nums = bias_nums
-
-
+
+        self.head_hidden_size = head_hidden_size
+        self.bias_nums = [head_hidden_size] * (head_layers - 1) + [1] # defined by segmentor head's hidden size, last element of 1.
+        self.weight_nums = [head_in_channels*head_hidden_size] + [head_hidden_size*head_hidden_size]*(head_layers-2) + [head_hidden_size] # first + intermediate + last layer
+
         self.class_num = out_channels
-        self.task_encoding = task_encoding
-        if task_encoding:
-            self.task_encoding = task_encoding
-            self.controller = nn.Conv3d(2*hidden_size, sum(weight_nums+bias_nums), kernel_size=1, stride=1, padding=0)
+        self.text_encoding = text_encoding
+        # text-driven controller: connection of bottleneck feature to segmentor features, e.g., from 256(*2) to weights and bias nums
+        if self.text_encoding:
+            self.controller = nn.Conv3d(2*hidden_size, sum(self.weight_nums+self.bias_nums), kernel_size=1, stride=1, padding=0)
         else:
-            self.controller = nn.Conv3d(hidden_size, sum(weight_nums+bias_nums), kernel_size=1, stride=1, padding=0)
-
+            self.controller = nn.Conv3d(hidden_size, sum(self.weight_nums+self.bias_nums), kernel_size=1, stride=1, padding=0)
+        # convolution layer of backbone output to segmentor head input size, e.g., 48 to 8
         self.precls_conv = nn.Sequential(
-            nn.GroupNorm(16, 48),
+            nn.GroupNorm(16, feature_size),
             nn.ReLU(inplace=True),
-            nn.Conv3d(48, 8, kernel_size=1)
+            nn.Conv3d(feature_size, head_in_channels, kernel_size=1)
         )
 
-    def parse_dynamic_params(self, params, channels, weight_nums, bias_nums):
+    def parse_dynamic_params(self, params, head_hidden_size, weight_nums, bias_nums):
         """
         Text-driven segmentor with layers of conv for dynamic outputs
         """
@@ -80,8 +84,8 @@ def parse_dynamic_params(self, params, head_hidden_size, weight_nums, bias_nums):
         for l in range(num_layers):
             if l < num_layers - 1:
-                weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1, 1)
-                bias_splits[l] = bias_splits[l].reshape(num_insts * channels)
+                weight_splits[l] = weight_splits[l].reshape(num_insts * head_hidden_size, -1, 1, 1, 1)
+                bias_splits[l] = bias_splits[l].reshape(num_insts * head_hidden_size)
             else:
                 weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1, 1)
                 bias_splits[l] = bias_splits[l].reshape(num_insts * 1)
@@ -102,24 +106,26 @@ def heads_forward(self, features, weights, biases, num_insts):
                 x = nn.functional.relu(x)
         return x
 
-    def forward(self, x, out, logits_options=None):
-        logits_options = range(x.shape[0]) if not isinstance(logits_options, list) else logits_options
+    def forward(self, x, out, text_encoding=None, logits_options=None):
+        logits_options = range(self.class_num) if not isinstance(logits_options, list) else logits_options
+        b = x.shape[0]
         logits_array = []
-        for i in logits_options:
-            if self.task_encoding:
-                x_cond = torch.cat([x[i].unsqueeze(0).repeat(self.class_num,1,1,1,1), self.task_encoding], 1)
+        for i in range(b): ## loop in batch size
+            # extract the corresponding text encoding and concatenate with x
+            if self.text_encoding:
+                x_cond = torch.cat([x[i].unsqueeze(0).repeat(len(logits_options),1,1,1,1), text_encoding[logits_options]], 1)
             else:
-                x_cond = x[i].unsqueeze(0).repeat(self.class_num,1,1,1,1)
-
+                x_cond = x[i].unsqueeze(0).repeat(len(logits_options),1,1,1,1)
+            # generate params for the segmentor
             params = self.controller(x_cond)
             params.squeeze_(-1).squeeze_(-1).squeeze_(-1)
-
+            ## dynamic segmentor
             head_inputs = self.precls_conv(out[i].unsqueeze(0))
-            head_inputs = head_inputs.repeat(self.class_num,1,1,1,1)
+            head_inputs = head_inputs.repeat(len(logits_options),1,1,1,1)
             N, _, D, H, W = head_inputs.size()
             head_inputs = head_inputs.reshape(1, -1, D, H, W)
-            weights, biases = self.parse_dynamic_params(params, 8, self.weight_nums, self.bias_nums)
-
+            # conv operation
+            weights, biases = self.parse_dynamic_params(params, self.head_hidden_size, self.weight_nums, self.bias_nums)
             logits = self.heads_forward(head_inputs, weights, biases, N)
             logits_array.append(logits.reshape(1, -1, D, H, W))
 
@@ -127,4 +133,3 @@ def forward(self, x, out, text_encoding=None, logits_options=None):
 
         out = torch.cat(logits_array,dim=0)
         return out
-
diff --git a/monai/networks/nets/universal_model.py b/monai/networks/nets/universal_model.py
index 251477fd31..2bc0505ad6 100755
--- a/monai/networks/nets/universal_model.py
+++ b/monai/networks/nets/universal_model.py
@@ -1,60 +1,14 @@
-from typing import Sequence, Tuple, Type, Union
+from typing import Sequence, Tuple, Union
 
-import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
-from torch.nn import LayerNorm
-from monai.networks.blocks.text_enbedding import TextEncoder
-from monai.networks.blocks.head_controller import HeadController
-from monai.networks.nets import SwinUNETR
+from model.text_embedding import TextEncoder
+from model.head_controller import HeadController
 
-class SwinUNETR_backbone(SwinUNETR):
-    """
-    Universal Model uses SwinUNETR as backbone without the segmentation head based on:
-
-    "Hatamizadeh et al.,
-    Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
-    " and
-
-    "Tang et al.,
-    Self-Supervised Pre-Training of Swin Transformers for 3D Medical Image Analysis
-    "
-    """
-    def __init__(
-        self,
-        img_size: Union[Sequence[int], int],
-        in_channels: int,
-        out_channels: int,
-        depths: Sequence[int] = (2, 2, 2, 2),
-        num_heads: Sequence[int] = (3, 6, 12, 24),
-        feature_size: int = 48,
-        norm_name: Union[Tuple, str] = "instance",
-        drop_rate: float = 0.0,
-        attn_drop_rate: float = 0.0,
-        dropout_path_rate: float = 0.0,
-        normalize: bool = True,
-        use_checkpoint: bool = False,
-        spatial_dims: int = 3,
-    ):
-        super(SwinUNETR_backbone, self).__init__(img_size,in_channels,out_channels,feature_size=48)
-
-    def forward(self, x_in):
-        hidden_states_out = self.swinViT(x_in, self.normalize)
-        enc0 = self.encoder1(x_in)
-        enc1 = self.encoder2(hidden_states_out[0])
-        enc2 = self.encoder3(hidden_states_out[1])
-        enc3 = self.encoder4(hidden_states_out[2])
-        dec4 = self.encoder10(hidden_states_out[4])
+from monai.networks.blocks.text_embedding import TextEncoder
+from monai.networks.blocks.head_controller import HeadController
 
-        dec3 = self.decoder5(dec4, hidden_states_out[3])
-        dec2 = self.decoder4(dec3, enc3)
-        dec1 = self.decoder3(dec2, enc2)
-        dec0 = self.decoder2(dec1, enc1)
-        out = self.decoder1(dec0, enc0)
-
-        return dec4, out
+from monai.networks.nets import SwinUNETR
 
 class Universal_model(nn.Module):
     """
@@ -66,11 +20,12 @@ def __init__(
         img_size,
         in_channels,
         out_channels,
-        bottleneck_size=768,
-        text_dim=512,
-        hidden_size=256,
-        backbone = 'swinunetr',
-        encoding = 'clip_embedding'
+        bottleneck_size: int = 768,
+        text_dim: int = 512,
+        hidden_size: int = 256,
+        backbone: str = 'swinunetr',
+        encoding: str = 'clip_embedding',
+        logits_options: list = None,
     ):
         super().__init__()
         self.backbone_name = backbone
@@ -88,7 +43,7 @@ def __init__(
         else:
             raise Exception('{} backbone is not implemented, please add your own'.format(backbone))
         self.class_num = out_channels
-
+        self.logits_options = logits_options
         # text encoder
         self.text_encoder = TextEncoder(
             out_channels=self.class_num,
@@ -99,6 +54,7 @@ def __init__(
 
         self.head_controller = HeadController(
             out_channels=out_channels,
+            text_encoding=True
         )
 
         self.GAP = nn.Sequential(
@@ -118,22 +74,62 @@ def load_params(self, model_dict):
 
             self.backbone.load_state_dict(store_dict)
             print('Use swin unetr pretrained weights')
         else:
-            raise Exception('{} backbone is not implemented, please add your own'.format(backbone))
-
-    def encoding_task(self, task_id):
-        N = task_id.shape[0]
-        task_encoding = torch.zeros(size=(N, 7))
-        for i in range(N):
-            task_encoding[i, task_id[i]]=1
-        return task_encoding.cuda()
+            raise Exception('{} backbone is not implemented, please add your own'.format(self.backbone_name))
 
     def forward(self, x_in):
         # get backbone feature
         dec4, out = self.backbone(x_in)
         # get task text encoding
-        task_encoding = self.text_encoder()
+        text_encoding = self.text_encoder()
         # text controlled outputs
         x_feat = self.GAP(dec4)
-        out = self.head_controller(x_feat, out, task_encoding)
+        out = self.head_controller(x_feat, out, text_encoding, self.logits_options)
 
         return out
+
+class SwinUNETR_backbone(SwinUNETR):
+    """
+    Universal Model uses SwinUNETR as backbone without the segmentation head based on:
+
+    "Hatamizadeh et al.,
+    Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
+    " and
+
+    "Tang et al.,
+    Self-Supervised Pre-Training of Swin Transformers for 3D Medical Image Analysis
+    "
+    """
+    def __init__(
+        self,
+        img_size: Union[Sequence[int], int],
+        in_channels: int,
+        out_channels: int,
+        depths: Sequence[int] = (2, 2, 2, 2),
+        num_heads: Sequence[int] = (3, 6, 12, 24),
+        feature_size: int = 48,
+        norm_name: Union[Tuple, str] = "instance",
+        drop_rate: float = 0.0,
+        attn_drop_rate: float = 0.0,
+        dropout_path_rate: float = 0.0,
+        normalize: bool = True,
+        use_checkpoint: bool = False,
+        spatial_dims: int = 3,
+    ):
+        super(SwinUNETR_backbone, self).__init__(img_size,in_channels,out_channels,feature_size=48)
+
+    def forward(self, x_in):
+        hidden_states_out = self.swinViT(x_in, self.normalize)
+        enc0 = self.encoder1(x_in)
+        enc1 = self.encoder2(hidden_states_out[0])
+        enc2 = self.encoder3(hidden_states_out[1])
+        enc3 = self.encoder4(hidden_states_out[2])
+        dec4 = self.encoder10(hidden_states_out[4])
+
+        dec3 = self.decoder5(dec4, hidden_states_out[3])
+        dec2 = self.decoder4(dec3, enc3)
+        dec1 = self.decoder3(dec2, enc2)
+        dec0 = self.decoder2(dec1, enc1)
+        out = self.decoder1(dec0, enc0)
+
+        return dec4, out
+
\ No newline at end of file
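
The generalized head sizing introduced in this commit can be sanity-checked against the values that were hard-coded in PATCH 1 — a quick standalone check using the default arguments (head_in_channels=8, head_layers=3, head_hidden_size=8):

    head_in_channels, head_layers, head_hidden_size = 8, 3, 8
    bias_nums = [head_hidden_size] * (head_layers - 1) + [1]
    weight_nums = (
        [head_in_channels * head_hidden_size]                        # first layer: 8*8
        + [head_hidden_size * head_hidden_size] * (head_layers - 2)  # intermediate: 8*8
        + [head_hidden_size]                                         # last layer: 8*1
    )
    assert weight_nums == [64, 64, 8]
    assert bias_nums == [8, 8, 1]

So the defaults reproduce the original fixed head exactly, while head_layers and head_hidden_size now make the segmentor's depth and width configurable.
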
From 2b7b04bfe7d2f68ed314472ae07b0f490908225c Mon Sep 17 00:00:00 2001
From: tangy5
Date: Tue, 4 Apr 2023 22:03:58 -0700
Subject: [PATCH 5/8] change mode

Signed-off-by: tangy5
---
 monai/networks/blocks/head_controller.py | 0
 monai/networks/nets/universal_model.py   | 0
 2 files changed, 0 insertions(+), 0 deletions(-)
 mode change 100755 => 100644 monai/networks/blocks/head_controller.py
 mode change 100755 => 100644 monai/networks/nets/universal_model.py

diff --git a/monai/networks/blocks/head_controller.py b/monai/networks/blocks/head_controller.py
old mode 100755
new mode 100644
diff --git a/monai/networks/nets/universal_model.py b/monai/networks/nets/universal_model.py
old mode 100755
new mode 100644

From 9923686187b3ab42990d2ed8f2fdcf633ccfb1cf Mon Sep 17 00:00:00 2001
From: tangy5
Date: Tue, 4 Apr 2023 22:07:52 -0700
Subject: [PATCH 6/8] remove unused imports

Signed-off-by: tangy5
---
 monai/networks/nets/universal_model.py | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/monai/networks/nets/universal_model.py b/monai/networks/nets/universal_model.py
index 2bc0505ad6..e4e6c39236 100644
--- a/monai/networks/nets/universal_model.py
+++ b/monai/networks/nets/universal_model.py
@@ -2,8 +2,6 @@
 
 import torch
 import torch.nn as nn
-from model.text_embedding import TextEncoder
-from model.head_controller import HeadController
 
 from monai.networks.blocks.text_embedding import TextEncoder
 from monai.networks.blocks.head_controller import HeadController
@@ -104,16 +102,6 @@ def __init__(
         img_size: Union[Sequence[int], int],
         in_channels: int,
         out_channels: int,
-        depths: Sequence[int] = (2, 2, 2, 2),
-        num_heads: Sequence[int] = (3, 6, 12, 24),
-        feature_size: int = 48,
-        norm_name: Union[Tuple, str] = "instance",
-        drop_rate: float = 0.0,
-        attn_drop_rate: float = 0.0,
-        dropout_path_rate: float = 0.0,
-        normalize: bool = True,
-        use_checkpoint: bool = False,
-        spatial_dims: int = 3,
     ):
         super(SwinUNETR_backbone, self).__init__(img_size,in_channels,out_channels,feature_size=48)

From cb1402cfe50aeeb63587d56c12c8a3c73a8bf9d9 Mon Sep 17 00:00:00 2001
From: tangy5
Date: Tue, 4 Apr 2023 23:24:04 -0700
Subject: [PATCH 7/8] add args

Signed-off-by: tangy5
---
 monai/networks/nets/universal_model.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/monai/networks/nets/universal_model.py b/monai/networks/nets/universal_model.py
index e4e6c39236..01d546c084 100644
--- a/monai/networks/nets/universal_model.py
+++ b/monai/networks/nets/universal_model.py
@@ -102,6 +102,16 @@ def __init__(
         img_size: Union[Sequence[int], int],
         in_channels: int,
         out_channels: int,
+        depths: Sequence[int] = (2, 2, 2, 2),
+        num_heads: Sequence[int] = (3, 6, 12, 24),
+        feature_size: int = 48,
+        norm_name: Union[Tuple, str] = "instance",
+        drop_rate: float = 0.0,
+        attn_drop_rate: float = 0.0,
+        dropout_path_rate: float = 0.0,
+        normalize: bool = True,
+        use_checkpoint: bool = False,
+        spatial_dims: int = 3,
     ):
         super(SwinUNETR_backbone, self).__init__(img_size,in_channels,out_channels,feature_size=48)

From fb548045ef37348db4677a9c8f1f8da934518bd8 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 5 Apr 2023 06:24:53 +0000
Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 monai/networks/blocks/head_controller.py | 20 +++++++--------
 monai/networks/nets/universal_model.py   | 31 ++++++++++++------------
 2 files changed, 24 insertions(+), 27 deletions(-)

diff --git a/monai/networks/blocks/head_controller.py b/monai/networks/blocks/head_controller.py
index e0f8dc9c5b..a19e925114 100644
--- a/monai/networks/blocks/head_controller.py
+++ b/monai/networks/blocks/head_controller.py
@@ -18,12 +18,12 @@ class HeadController(nn.Module):
     """
     Text-based controller for segmentation outputs, the text-driven segmentor enables optional outputs instead of
     fixed output channels. Users can choose and control the number and name of output channels from a multi-class segmentation
-    model. This can enable incremental learning by adding new classes to an existing pre-trained model without
+    model. This can enable incremental learning by adding new classes to an existing pre-trained model without
     catastrophic forgetting.
-
+
     Text-driven segmentor, based on: "Liu et al.,
     CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection "
-    """
+    """
     def __init__(
         self,
         out_channels: int,
@@ -45,11 +45,11 @@ def __init__(
             text_encoding: whether text embedding features are passed.
         """
         super().__init__()
-
+
         self.head_hidden_size = head_hidden_size
-        self.bias_nums = [head_hidden_size] * (head_layers - 1) + [1] # defined by segmentor head's hidden size, last element of 1.
+        self.bias_nums = [head_hidden_size] * (head_layers - 1) + [1] # defined by segmentor head's hidden size, last element of 1.
         self.weight_nums = [head_in_channels*head_hidden_size] + [head_hidden_size*head_hidden_size]*(head_layers-2) + [head_hidden_size] # first + intermediate + last layer
-
+
         self.class_num = out_channels
         self.text_encoding = text_encoding
@@ -57,7 +57,7 @@ def __init__(
             self.controller = nn.Conv3d(2*hidden_size, sum(self.weight_nums+self.bias_nums), kernel_size=1, stride=1, padding=0)
         else:
             self.controller = nn.Conv3d(hidden_size, sum(self.weight_nums+self.bias_nums), kernel_size=1, stride=1, padding=0)
-        # convolution layer of backbone output to segmentor head input size, e.g., 48 to 8
+        # convolution layer of backbone output to segmentor head input size, e.g., 48 to 8
         self.precls_conv = nn.Sequential(
             nn.GroupNorm(16, feature_size),
             nn.ReLU(inplace=True),
@@ -106,7 +106,7 @@ def heads_forward(self, features, weights, biases, num_insts):
                 x = nn.functional.relu(x)
         return x
 
-    def forward(self, x, out, text_encoding=None, logits_options=None):
+    def forward(self, x, out, text_encoding=None, logits_options=None):
         logits_options = range(self.class_num) if not isinstance(logits_options, list) else logits_options
         b = x.shape[0]
         logits_array = []
@@ -128,8 +128,6 @@ def forward(self, x, out, text_encoding=None, logits_options=None):
             weights, biases = self.parse_dynamic_params(params, self.head_hidden_size, self.weight_nums, self.bias_nums)
             logits = self.heads_forward(head_inputs, weights, biases, N)
             logits_array.append(logits.reshape(1, -1, D, H, W))
-
+
         out = torch.cat(logits_array,dim=0)
         return out
-
-
diff --git a/monai/networks/nets/universal_model.py b/monai/networks/nets/universal_model.py
index 01d546c084..43f4b446b4 100644
--- a/monai/networks/nets/universal_model.py
+++ b/monai/networks/nets/universal_model.py
@@ -12,16 +12,16 @@ class Universal_model(nn.Module):
     """
     Universal Model for organ and tumor segmentation, based on: "Liu et al.,
     CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection "
-    """
+    """
     def __init__(
-        self,
-        img_size,
-        in_channels,
-        out_channels,
-        bottleneck_size: int = 768,
-        text_dim: int = 512,
-        hidden_size: int = 256,
-        backbone: str = 'swinunetr',
+        self,
+        img_size,
+        in_channels,
+        out_channels,
+        bottleneck_size: int = 768,
+        text_dim: int = 512,
+        hidden_size: int = 256,
+        backbone: str = 'swinunetr',
         encoding: str = 'clip_embedding',
         logits_options: list = None,
     ):
@@ -39,7 +39,7 @@ def __init__(
                 use_checkpoint=False,
             )
         else:
-            raise Exception('{} backbone is not implemented, please add your own'.format(backbone))
+            raise Exception(f'{backbone} backbone is not implemented, please add your own')
         self.class_num = out_channels
         self.logits_options = logits_options
         # text encoder
@@ -72,7 +72,7 @@ def load_params(self, model_dict):
             self.backbone.load_state_dict(store_dict)
             print('Use swin unetr pretrained weights')
         else:
-            raise Exception('{} backbone is not implemented, please add your own'.format(self.backbone_name))
+            raise Exception(f'{self.backbone_name} backbone is not implemented, please add your own')
 
     def forward(self, x_in):
@@ -87,7 +87,7 @@ def forward(self, x_in):
 
 class SwinUNETR_backbone(SwinUNETR):
     """
-    Universal Model uses SwinUNETR as backbone without the segmentation head based on:
+    Universal Model uses SwinUNETR as backbone without the segmentation head based on:
 
     "Hatamizadeh et al.,
     Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
     " and
 
     "Tang et al.,
     Self-Supervised Pre-Training of Swin Transformers for 3D Medical Image Analysis
     "
-    """
-    def __init__(
+    """
+    def __init__(
         self,
         img_size: Union[Sequence[int], int],
         in_channels: int,
         out_channels: int,
@@ -113,7 +113,7 @@ def __init__(
         use_checkpoint: bool = False,
         spatial_dims: int = 3,
     ):
-        super(SwinUNETR_backbone, self).__init__(img_size,in_channels,out_channels,feature_size=48)
+        super().__init__(img_size,in_channels,out_channels,feature_size=48)
 
     def forward(self, x_in):
         hidden_states_out = self.swinViT(x_in, self.normalize)
@@ -130,4 +130,3 @@ def forward(self, x_in):
         out = self.decoder1(dec0, enc0)
 
         return dec4, out
-
\ No newline at end of file
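
Taken together, the final API lets callers restrict which class channels are produced — the controllable-output behaviour this PR is named for. A rough usage sketch (not from the patches; the class indices are made up for illustration):

    import torch
    from monai.networks.nets.universal_model import Universal_model

    # a 32-class model that only emits channels for classes 1, 3 and 7
    model = Universal_model(
        img_size=(96, 96, 96),
        in_channels=1,
        out_channels=32,
        logits_options=[1, 3, 7],
    )
    x = torch.randn(1, 1, 96, 96, 96)
    logits = model(x)   # expected shape: (1, 3, 96, 96, 96), one channel per requested class

Because the per-class heads are generated from text embeddings rather than baked into a fixed output layer, dropping or adding classes does not require changing any trained weights.
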