
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 5, 2023
1 parent 2b7b04b commit 117002a
Showing 2 changed files with 24 additions and 27 deletions.
20 changes: 9 additions & 11 deletions monai/networks/blocks/head_controller.py
@@ -18,12 +18,12 @@ class HeadController(nn.Module):
"""
Text-based controller for segmentation outputs; the text-driven segmentor enables optional outputs instead of
fixed output channels. Users can choose and control the number and name of output channels from a multi-class segmentation
model. This can enable incremental learning by adding new classes to an existing pre-trained model without
catastrophic forgetting.
Text-driven segmentor, based on: "Liu et al.,
CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection <https://arxiv.org/pdf/2301.00785.pdf>"
"""
def __init__(
self,
out_channels: int,
@@ -45,19 +45,19 @@ def __init__(
text_encoding: the text embedding features passed.
"""
super().__init__()

self.head_hidden_size = head_hidden_size
self.bias_nums = [head_hidden_size] * (head_layers - 1) + [1] # defined by segmentor head's hidden size, last element of 1.
self.weight_nums = [head_in_channels*head_hidden_size] + [head_hidden_size*head_hidden_size]*(head_layers-2) + [head_hidden_size] #first+intermediate+last layer

self.class_num = out_channels
self.text_encoding = text_encoding
# text-driven controller: connection of bottleneck feature to segmentor features, e.g., from 256(*2) to weights and bias nums
if self.text_encoding:
self.controller = nn.Conv3d(2*hidden_size, sum(self.weight_nums+self.bias_nums), kernel_size=1, stride=1, padding=0)
else:
self.controller = nn.Conv3d(hidden_size, sum(self.weight_nums+self.bias_nums), kernel_size=1, stride=1, padding=0)
# convolution layer of backbone output to segmentor head input size, e.g., 48 to 8
self.precls_conv = nn.Sequential(
nn.GroupNorm(16, feature_size),
nn.ReLU(inplace=True),
@@ -106,7 +106,7 @@ def heads_forward(self, features, weights, biases, num_insts):
x = nn.functional.relu(x)
return x

def forward(self, x, out, text_encoding=None, logits_options=None):
logits_options = range(self.class_num) if not isinstance(logits_options, list) else logits_options
b = x.shape[0]
logits_array = []
@@ -128,8 +128,6 @@ def forward(self, x, out, text_encoding=None, logits_options=None):
weights, biases = self.parse_dynamic_params(params, self.head_hidden_size, self.weight_nums, self.bias_nums)
logits = self.heads_forward(head_inputs, weights, biases, N)
logits_array.append(logits.reshape(1, -1, D, H, W))

out = torch.cat(logits_array,dim=0)
return out
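A quick aside on the weight/bias bookkeeping in __init__ above: the controller predicts one flat parameter vector per class, which parse_dynamic_params later slices back into per-layer conv weights and biases. Below is a minimal, illustrative Python sketch of that counting; the example sizes (head_in_channels=8, head_hidden_size=8, head_layers=3) are assumptions chosen for demonstration, not the module's documented defaults.

# Illustrative sketch only -- example sizes are assumptions, not HeadController defaults.
head_in_channels, head_hidden_size, head_layers = 8, 8, 3

# first layer + intermediate layers + last layer, mirroring the expressions above
weight_nums = (
    [head_in_channels * head_hidden_size]
    + [head_hidden_size * head_hidden_size] * (head_layers - 2)
    + [head_hidden_size]
)  # [64, 64, 8]
bias_nums = [head_hidden_size] * (head_layers - 1) + [1]  # [8, 8, 1]; the last layer outputs a single channel

total = sum(weight_nums + bias_nums)  # 153 parameters predicted by the controller for each class
print(weight_nums, bias_nums, total)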


31 changes: 15 additions & 16 deletions monai/networks/nets/universal_model.py
@@ -14,16 +14,16 @@ class Universal_model(nn.Module):
"""
Universal Model for organ and tumor segmentation, based on: "Liu et al.,
CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection <https://arxiv.org/pdf/2301.00785.pdf>"
"""
"""
def __init__(
self,
img_size,
in_channels,
out_channels,
bottleneck_size: int = 768,
text_dim: int = 512,
hidden_size: int = 256,
backbone: str = 'swinunetr',
encoding: str = 'clip_embedding',
logits_options: list = None,
):
@@ -41,7 +41,7 @@ def __init__(
use_checkpoint=False,
)
else:
-raise Exception('{} backbone is not implemented, please add your own'.format(backbone))
+raise Exception(f'{backbone} backbone is not implemented, please add your own')
self.class_num = out_channels
self.logits_options = logits_options
# text encoder
@@ -74,7 +74,7 @@ def load_params(self, model_dict):
self.backbone.load_state_dict(store_dict)
print('Use swin unetr pretrained weights')
else:
-raise Exception('{} backbone is not implemented, please add your own'.format(self.backbone_name))
+raise Exception(f'{self.backbone_name} backbone is not implemented, please add your own')

def forward(self, x_in):
# get backbone feature
@@ -89,7 +89,7 @@ def forward(self, x_in):

class SwinUNETR_backbone(SwinUNETR):
"""
Universal Model uses SwinUNETR as backbone without the segmentation head, based on:
"Hatamizadeh et al.,
Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
@@ -98,8 +98,8 @@ class SwinUNETR_backbone(SwinUNETR):
"Tang et al.,
Self-Supervised Pre-Training of Swin Transformers for 3D Medical Image Analysis
<https://arxiv.org/abs/2111.14791>"
"""
def __init__(
"""
def __init__(
self,
img_size: Union[Sequence[int], int],
in_channels: int,
@@ -115,7 +115,7 @@ def __init__(
use_checkpoint: bool = False,
spatial_dims: int = 3,
):
-super(SwinUNETR_backbone, self).__init__(img_size,in_channels,out_channels,feature_size=48)
+super().__init__(img_size,in_channels,out_channels,feature_size=48)

def forward(self, x_in):
hidden_states_out = self.swinViT(x_in, self.normalize)
@@ -132,4 +132,3 @@ def forward(self, x_in):
out = self.decoder1(dec0, enc0)

return dec4, out
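For orientation, here is a rough usage sketch of the model defined in this file. It is illustrative only: the spatial size, channel counts, and backbone/encoding choices below are assumptions rather than recommended settings, no pretrained backbone or CLIP text-embedding weights are loaded, and the output shape is an expectation based on the constructor arguments rather than a verified result.

import torch

# Hypothetical usage sketch; all sizes here are assumptions, not recommended settings.
model = Universal_model(
    img_size=(96, 96, 96),   # SwinUNETR-style 3D input size
    in_channels=1,           # e.g. a single-modality CT volume
    out_channels=32,         # number of selectable classes for the text-driven head
    backbone='swinunetr',
    encoding='clip_embedding',
)
model.eval()

x = torch.randn(1, 1, 96, 96, 96)  # batch of one single-channel volume
with torch.no_grad():
    logits = model(x)  # expected: one logit map per class, e.g. roughly (1, 32, 96, 96, 96)
print(logits.shape)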
