update configs
Signed-off-by: tangy5 <[email protected]>
tangy5 committed Apr 5, 2023
1 parent 573c32b commit 6bb3d12
Showing 2 changed files with 104 additions and 103 deletions.
75 changes: 40 additions & 35 deletions monai/networks/blocks/head_controller.py
@@ -13,7 +13,6 @@

import torch
import torch.nn as nn
from typing import Optional

class HeadController(nn.Module):
"""
@@ -28,39 +27,44 @@ class HeadController(nn.Module):
def __init__(
self,
out_channels: int,
weight_nums: list = [64, 64, 8],
bias_nums: list = [8, 8, 1],
feature_size: int = 48,
head_in_channels:int = 8,
head_layers: int = 3,
head_hidden_size: int = 8,
hidden_size: int = 256,
task_encoding: Optional[torch.Tensor] = None,
text_encoding: bool = True,
) -> None:
"""
Args:
out_channels: number of output channels, controlling the text-based embeddings for the classes.
weight_nums: weight feature sizes of the text-driven segmentor conv layers; len(weight_nums) defines the number of layers.
bias_nums: bias sizes of the text-driven segmentor conv layers; len(bias_nums) needs to be consistent with len(weight_nums).
hidden_size: dimension of hidden features, compatible with different vision feature dimensions.
task_encoding: the text embedding features passed. TODO: make optional
feature_size: the backbone output feature size before segmentation heads.
head_in_channels: number of dynamic segmentor input channels.
head_layers: number of conv layers of the dynamic segmentor.
head_hidden_size: hidden feature size of the intermediate dynamic segmentor conv layers.
hidden_size: dimension of backbone's bottleneck features.
text_encoding: whether text embedding features are used to drive the segmentor controller.
"""
super().__init__()
self.weight_nums = weight_nums
self.bias_nums = bias_nums

self.head_hidden_size = head_hidden_size
self.bias_nums = [head_hidden_size] * (head_layers - 1) + [1] # one hidden-size bias per layer except the final layer, which has a single bias
self.weight_nums = [head_in_channels*head_hidden_size] + [head_hidden_size*head_hidden_size]*(head_layers-2) + [head_hidden_size] # first + intermediate + last layer weights

self.class_num = out_channels
self.task_encoding = task_encoding
if task_encoding:
self.task_encoding = task_encoding
self.controller = nn.Conv3d(2*hidden_size, sum(weight_nums+bias_nums), kernel_size=1, stride=1, padding=0)
self.text_encoding = text_encoding
# text-driven controller: maps the bottleneck feature to the segmentor's weights and biases, e.g., from 256(*2) channels to sum(weight_nums + bias_nums)
if self.text_encoding:
self.controller = nn.Conv3d(2*hidden_size, sum(self.weight_nums+self.bias_nums), kernel_size=1, stride=1, padding=0)
else:
self.controller = nn.Conv3d(hidden_size, sum(weight_nums+bias_nums), kernel_size=1, stride=1, padding=0)

self.controller = nn.Conv3d(hidden_size, sum(self.weight_nums+self.bias_nums), kernel_size=1, stride=1, padding=0)
# convolution layer projecting the backbone output to the segmentor head input size, e.g., 48 -> 8 channels
self.precls_conv = nn.Sequential(
nn.GroupNorm(16, 48),
nn.GroupNorm(16, feature_size),
nn.ReLU(inplace=True),
nn.Conv3d(48, 8, kernel_size=1)
nn.Conv3d(feature_size, head_in_channels, kernel_size=1)
)
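
The new constructor derives the per-layer parameter counts from the head hyperparameters instead of taking them as arguments. A minimal sketch, assuming the defaults in this diff (head_in_channels=8, head_layers=3, head_hidden_size=8), showing that the derived sizes reproduce the hard-coded defaults removed above:

```python
# Minimal sketch: derive the dynamic-head parameter sizes from the new
# hyperparameters (defaults taken from this diff) and check that they
# match the old hard-coded defaults removed by this commit.
head_in_channels, head_layers, head_hidden_size = 8, 3, 8

bias_nums = [head_hidden_size] * (head_layers - 1) + [1]
weight_nums = (
    [head_in_channels * head_hidden_size]
    + [head_hidden_size * head_hidden_size] * (head_layers - 2)
    + [head_hidden_size]
)

assert bias_nums == [8, 8, 1]      # old default bias_nums
assert weight_nums == [64, 64, 8]  # old default weight_nums
```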

def parse_dynamic_params(self, params, channels, weight_nums, bias_nums):
def parse_dynamic_params(self, params, head_hidden_size, weight_nums, bias_nums):
"""
Text-driven segmentor with layers of conv for dynamic outputs
"""
@@ -80,8 +84,8 @@ def parse_dynamic_params(self, params, channels, weight_nums, bias_nums):

for l in range(num_layers):
if l < num_layers - 1:
weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1, 1)
bias_splits[l] = bias_splits[l].reshape(num_insts * channels)
weight_splits[l] = weight_splits[l].reshape(num_insts * head_hidden_size, -1, 1, 1, 1)
bias_splits[l] = bias_splits[l].reshape(num_insts * head_hidden_size)
else:
weight_splits[l] = weight_splits[l].reshape(num_insts * 1, -1, 1, 1, 1)
bias_splits[l] = bias_splits[l].reshape(num_insts * 1)
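
For readers reconstructing the splitting logic that precedes this hunk (it is collapsed in the diff view): the controller emits one flat parameter vector per class instance, which parse_dynamic_params splits into per-layer conv weights and biases. A hedged sketch under the default sizes; the torch.split call and the grouped-conv use in heads_forward are assumptions inferred from the reshape pattern shown above:

```python
import torch

# Assumed shape: params is (num_insts, sum(weight_nums) + sum(bias_nums)),
# one flat parameter vector per class instance emitted by the controller.
num_insts, head_hidden_size = 2, 8
weight_nums, bias_nums = [64, 64, 8], [8, 8, 1]
params = torch.randn(num_insts, sum(weight_nums) + sum(bias_nums))

splits = list(torch.split(params, weight_nums + bias_nums, dim=1))
weight_splits = splits[: len(weight_nums)]
bias_splits = splits[len(weight_nums):]

# Non-final layers: weights for a grouped 1x1x1 conv over all instances.
w0 = weight_splits[0].reshape(num_insts * head_hidden_size, -1, 1, 1, 1)
b0 = bias_splits[0].reshape(num_insts * head_hidden_size)
print(w0.shape, b0.shape)  # (16, 8, 1, 1, 1) (16,)

# Final layer: a single output channel per instance.
w2 = weight_splits[2].reshape(num_insts * 1, -1, 1, 1, 1)
b2 = bias_splits[2].reshape(num_insts * 1)
print(w2.shape, b2.shape)  # (2, 8, 1, 1, 1) (2,)
```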
@@ -102,29 +106,30 @@ def heads_forward(self, features, weights, biases, num_insts):
x = nn.functional.relu(x)
return x

def forward(self, x, out, logits_options=None):
logits_options = range(x.shape[0]) if not isinstance(logits_options, list) else logits_options
def forward(self, x, out, text_encoding=None, logits_options=None):
logits_options = range(self.class_num) if not isinstance(logits_options, list) else logits_options
b = x.shape[0]
logits_array = []
for i in logits_options:
if self.task_encoding:
x_cond = torch.cat([x[i].unsqueeze(0).repeat(self.class_num,1,1,1,1), self.task_encoding], 1)
for i in range(b): ## loop over the batch
# extract the corresponding text encodings and concatenate with x
if self.text_encoding:
x_cond = torch.cat([x[i].unsqueeze(0).repeat(len(logits_options),1,1,1,1), text_encoding[logits_options]], 1)
else:
x_cond = x[i].unsqueeze(0).repeat(self.class_num,1,1,1,1)

x_cond = x[i].unsqueeze(0).repeat(len(logits_options),1,1,1,1)
# generate the dynamic parameters for the segmentor
params = self.controller(x_cond)
params.squeeze_(-1).squeeze_(-1).squeeze_(-1)

## dynamic segmentor
head_inputs = self.precls_conv(out[i].unsqueeze(0))
head_inputs = head_inputs.repeat(self.class_num,1,1,1,1)
head_inputs = head_inputs.repeat(len(logits_options),1,1,1,1)
N, _, D, H, W = head_inputs.size()
head_inputs = head_inputs.reshape(1, -1, D, H, W)
weights, biases = self.parse_dynamic_params(params, 8, self.weight_nums, self.bias_nums)

# conv operation
weights, biases = self.parse_dynamic_params(params, self.head_hidden_size, self.weight_nums, self.bias_nums)
logits = self.heads_forward(head_inputs, weights, biases, N)
logits_array.append(logits.reshape(1, -1, D, H, W))

out = torch.cat(logits_array,dim=0)
return out
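
A hedged usage sketch of the new forward signature, assuming this module is imported from this repository's monai/networks/blocks/head_controller.py, with x the pooled bottleneck feature, out the pre-head decoder feature, and text_encoding one embedding per class (shapes are inferred from the diff, not confirmed against the full file):

```python
import torch
from monai.networks.blocks.head_controller import HeadController

num_classes, hidden_size, feature_size = 4, 256, 48
head = HeadController(
    out_channels=num_classes,
    feature_size=feature_size,
    hidden_size=hidden_size,
    text_encoding=True,
)

b, d, h, w = 2, 32, 32, 32
x = torch.randn(b, hidden_size, 1, 1, 1)        # pooled bottleneck feature (after GAP)
out = torch.randn(b, feature_size, d, h, w)     # decoder feature fed to precls_conv
text = torch.randn(num_classes, hidden_size, 1, 1, 1)  # one embedding per class

logits = head(x, out, text_encoding=text)
print(logits.shape)  # expected: (2, 4, 32, 32, 32) -- one channel per class
```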



132 changes: 64 additions & 68 deletions monai/networks/nets/universal_model.py
@@ -1,60 +1,14 @@
from typing import Sequence, Tuple, Type, Union
from typing import Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.nn import LayerNorm
from monai.networks.blocks.text_enbedding import TextEncoder
from monai.networks.blocks.head_controller import HeadController
from monai.networks.nets import SwinUNETR
from model.text_embedding import TextEncoder
from model.head_controller import HeadController

class SwinUNETR_backbone(SwinUNETR):
"""
Universal Model uses SwinUNETR as backbone without the segmentation head based on:
"Hatamizadeh et al.,
Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
<https://arxiv.org/abs/2201.01266>" and
"Tang et al.,
Self-Supervised Pre-Training of Swin Transformers for 3D Medical Image Analysis
<https://arxiv.org/abs/2111.14791>"
"""
def __init__(
self,
img_size: Union[Sequence[int], int],
in_channels: int,
out_channels: int,
depths: Sequence[int] = (2, 2, 2, 2),
num_heads: Sequence[int] = (3, 6, 12, 24),
feature_size: int = 48,
norm_name: Union[Tuple, str] = "instance",
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
dropout_path_rate: float = 0.0,
normalize: bool = True,
use_checkpoint: bool = False,
spatial_dims: int = 3,
):
super(SwinUNETR_backbone, self).__init__(img_size,in_channels,out_channels,feature_size=48)

def forward(self, x_in):
hidden_states_out = self.swinViT(x_in, self.normalize)
enc0 = self.encoder1(x_in)
enc1 = self.encoder2(hidden_states_out[0])
enc2 = self.encoder3(hidden_states_out[1])
enc3 = self.encoder4(hidden_states_out[2])
dec4 = self.encoder10(hidden_states_out[4])
from monai.networks.blocks.text_embedding import TextEncoder
from monai.networks.blocks.head_controller import HeadController

dec3 = self.decoder5(dec4, hidden_states_out[3])
dec2 = self.decoder4(dec3, enc3)
dec1 = self.decoder3(dec2, enc2)
dec0 = self.decoder2(dec1, enc1)
out = self.decoder1(dec0, enc0)

return dec4, out
from monai.networks.nets import SwinUNETR

class Universal_model(nn.Module):
"""
@@ -66,11 +20,12 @@ def __init__(
img_size,
in_channels,
out_channels,
bottleneck_size=768,
text_dim=512,
hidden_size=256,
backbone = 'swinunetr',
encoding = 'clip_embedding'
bottleneck_size: int = 768,
text_dim: int = 512,
hidden_size: int = 256,
backbone: str = 'swinunetr',
encoding: str = 'clip_embedding',
logits_options: list = None,
):
super().__init__()
self.backbone_name = backbone
@@ -88,7 +43,7 @@ def __init__(
else:
raise Exception('{} backbone is not implemented, please add your own'.format(backbone))
self.class_num = out_channels

self.logits_options = logits_options
# text encoder
self.text_encoder = TextEncoder(
out_channels=self.class_num,
@@ -99,6 +54,7 @@

self.head_controller = HeadController(
out_channels=out_channels,
text_encoding=True
)

self.GAP = nn.Sequential(
@@ -118,22 +74,62 @@ def load_params(self, model_dict):
self.backbone.load_state_dict(store_dict)
print('Use swin unetr pretrained weights')
else:
raise Exception('{} backbone is not implemented, please add your own'.format(backbone))

def encoding_task(self, task_id):
N = task_id.shape[0]
task_encoding = torch.zeros(size=(N, 7))
for i in range(N):
task_encoding[i, task_id[i]]=1
return task_encoding.cuda()
raise Exception('{} backbone is not implemented, please add your own'.format(self.backbone_name))

def forward(self, x_in):
# get backbone feature
dec4, out = self.backbone(x_in)
# get task text encoding
task_encoding = self.text_encoder()
text_encoding = self.text_encoder()
# text controlled outputs
x_feat = self.GAP(dec4)
out = self.head_controller(x_feat, out, task_encoding)
out = self.head_controller(x_feat, out, text_encoding, self.logits_options)

return out
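
A hedged end-to-end sketch of the updated Universal_model (class name and module path taken from this diff; the CLIP text-embedding weights that encoding='clip_embedding' relies on are assumed to be available, and logits_options selecting a subset of classes is inferred from the new forward signature):

```python
import torch
from monai.networks.nets.universal_model import Universal_model

model = Universal_model(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=32,
    backbone='swinunetr',
    encoding='clip_embedding',
    logits_options=list(range(32)),  # all classes; pass a subset to restrict outputs
)

x = torch.randn(1, 1, 96, 96, 96)
logits = model(x)
print(logits.shape)  # expected: (1, 32, 96, 96, 96)
```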

class SwinUNETR_backbone(SwinUNETR):
"""
Universal Model uses SwinUNETR as backbone without the segmentation head based on:
"Hatamizadeh et al.,
Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
<https://arxiv.org/abs/2201.01266>" and
"Tang et al.,
Self-Supervised Pre-Training of Swin Transformers for 3D Medical Image Analysis
<https://arxiv.org/abs/2111.14791>"
"""
def __init__(
self,
img_size: Union[Sequence[int], int],
in_channels: int,
out_channels: int,
depths: Sequence[int] = (2, 2, 2, 2),
num_heads: Sequence[int] = (3, 6, 12, 24),
feature_size: int = 48,
norm_name: Union[Tuple, str] = "instance",
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
dropout_path_rate: float = 0.0,
normalize: bool = True,
use_checkpoint: bool = False,
spatial_dims: int = 3,
):
super(SwinUNETR_backbone, self).__init__(img_size,in_channels,out_channels,feature_size=48)

def forward(self, x_in):
hidden_states_out = self.swinViT(x_in, self.normalize)
enc0 = self.encoder1(x_in)
enc1 = self.encoder2(hidden_states_out[0])
enc2 = self.encoder3(hidden_states_out[1])
enc3 = self.encoder4(hidden_states_out[2])
dec4 = self.encoder10(hidden_states_out[4])

dec3 = self.decoder5(dec4, hidden_states_out[3])
dec2 = self.decoder4(dec3, enc3)
dec1 = self.decoder3(dec2, enc2)
dec0 = self.decoder2(dec1, enc1)
out = self.decoder1(dec0, enc0)

return dec4, out
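
A shape sketch for the backbone, assuming SwinUNETR's standard strides: dec4 comes from encoder10 at 1/32 resolution with 16*feature_size channels, and out is the full-resolution decoder feature with feature_size channels.

```python
import torch

backbone = SwinUNETR_backbone(img_size=(96, 96, 96), in_channels=1, out_channels=32)
dec4, out = backbone(torch.randn(1, 1, 96, 96, 96))
print(dec4.shape)  # expected: (1, 768, 3, 3, 3)   -- pooled by GAP, drives the controller
print(out.shape)   # expected: (1, 48, 96, 96, 96) -- projected by precls_conv
```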
