Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 5, 2023
1 parent 573c32b commit 4f98a2b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 32 deletions.
13 changes: 5 additions & 8 deletions monai/networks/blocks/head_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ class HeadController(nn.Module):
"""
Text-based controller for segmentation outputs, the text-driven segmentor enables for optional outputs instead of
fixed output channels. Users can choose and control the number and name of output channels from a mult-class segmentation
model. This can enabble incremental learning by adding new classes to a existing pre-trained model without
model. This can enabble incremental learning by adding new classes to a existing pre-trained model without
catatrophic forgetting.
Text-dirven segmentor, based on: "Liu et al.,
CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection <https://arxiv.org/pdf/2301.00785.pdf>"
"""
"""
def __init__(
self,
out_channels: int,
Expand Down Expand Up @@ -113,7 +113,7 @@ def forward(self, x, out, logits_options=None):

params = self.controller(x_cond)
params.squeeze_(-1).squeeze_(-1).squeeze_(-1)

head_inputs = self.precls_conv(out[i].unsqueeze(0))
head_inputs = head_inputs.repeat(self.class_num,1,1,1,1)
N, _, D, H, W = head_inputs.size()
Expand All @@ -122,9 +122,6 @@ def forward(self, x, out, logits_options=None):

logits = self.heads_forward(head_inputs, weights, biases, N)
logits_array.append(logits.reshape(1, -1, D, H, W))

out = torch.cat(logits_array,dim=0)
return out



44 changes: 20 additions & 24 deletions monai/networks/nets/universal_model.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
from typing import Sequence, Tuple, Type, Union
from typing import Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.nn import LayerNorm
from monai.networks.blocks.text_enbedding import TextEncoder
from monai.networks.blocks.head_controller import HeadController
from monai.networks.nets import SwinUNETR

class SwinUNETR_backbone(SwinUNETR):
"""
Universal Model uses SwinUNETR as backbone without the segmentation head based on:
Universal Model uses SwinUNETR as backbone without the segmentation head based on:
"Hatamizadeh et al.,
Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
<https://arxiv.org/abs/2201.01266>" and
"Tang et al.,
Self-Supervised Pre-Training of Swin Transformers for 3D Medical Image Analysis
<https://arxiv.org/abs/2111.14791>"
"""
def __init__(
"""
def __init__(
self,
img_size: Union[Sequence[int], int],
in_channels: int,
Expand All @@ -38,8 +34,8 @@ def __init__(
use_checkpoint: bool = False,
spatial_dims: int = 3,
):
super(SwinUNETR_backbone, self).__init__(img_size,in_channels,out_channels,feature_size=48)
super().__init__(img_size,in_channels,out_channels,feature_size=48)

def forward(self, x_in):
hidden_states_out = self.swinViT(x_in, self.normalize)
enc0 = self.encoder1(x_in)
Expand All @@ -53,23 +49,23 @@ def forward(self, x_in):
dec1 = self.decoder3(dec2, enc2)
dec0 = self.decoder2(dec1, enc1)
out = self.decoder1(dec0, enc0)

return dec4, out

class Universal_model(nn.Module):
"""
Universal Model for organ and tumor segmentation, based on: "Liu et al.,
CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection <https://arxiv.org/pdf/2301.00785.pdf>"
"""
"""
def __init__(
self,
img_size,
in_channels,
out_channels,
bottleneck_size=768,
text_dim=512,
hidden_size=256,
backbone = 'swinunetr',
self,
img_size,
in_channels,
out_channels,
bottleneck_size=768,
text_dim=512,
hidden_size=256,
backbone = 'swinunetr',
encoding = 'clip_embedding'
):
super().__init__()
Expand All @@ -86,7 +82,7 @@ def __init__(
use_checkpoint=False,
)
else:
raise Exception('{} backbone is not implemented, please add your own'.format(backbone))
raise Exception(f'{backbone} backbone is not implemented, please add your own')
self.class_num = out_channels

# text encoder
Expand Down Expand Up @@ -118,7 +114,7 @@ def load_params(self, model_dict):
self.backbone.load_state_dict(store_dict)
print('Use swin unetr pretrained weights')
else:
raise Exception('{} backbone is not implemented, please add your own'.format(backbone))
raise Exception(f'{backbone} backbone is not implemented, please add your own')

def encoding_task(self, task_id):
N = task_id.shape[0]
Expand Down

0 comments on commit 4f98a2b

Please sign in to comment.