-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel.py
92 lines (72 loc) · 3.02 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
class DoubleConv(torch.nn.Module):
    """Two stacked 3x3 convolutions, each followed by a ReLU.

    Used as the per-stage building block of the UNet below. With a 3x3
    kernel and ``padding=1`` the spatial size is preserved, so only the
    channel count changes (``in_channels`` -> ``out_channels``).

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            torch.nn.Conv2d(in_channels, out_channels, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, 3, padding=1),
            torch.nn.ReLU(),
        ]
        # Keep the attribute name and Sequential indices (0..3) so that
        # state_dict keys ("step.0.weight", ...) match existing checkpoints.
        self.step = torch.nn.Sequential(*layers)

    def forward(self, X):
        """Apply conv -> ReLU -> conv -> ReLU to ``X``."""
        out = self.step(X)
        return out
class UNet(torch.nn.Module):
    """U-Net for image segmentation.

    Encoder: three stages of DoubleConv followed by 2x max-pooling, then a
    DoubleConv bottleneck (64 -> 128 -> 256 -> 512 channels).
    Decoder: three stages, each bilinearly upsampling, concatenating the
    matching encoder feature map (skip connection) and applying a DoubleConv
    (512+256 -> 256, 256+128 -> 128, 128+64 -> 64 channels).
    A final 1x1 convolution maps the 64 channels to ``out_channels`` raw
    scores; no activation is applied.

    Each upsampling step resizes to the exact spatial size of its skip
    connection, so input height/width need not be divisible by 8 (they must
    be at least 8 so the third pooling step is non-empty). The output has the
    same height and width as the input.

    Args:
        in_channels: Channels of the input image (default 1).
        out_channels: Channels of the predicted map (default 1).
    """

    def __init__(self, in_channels=1, out_channels=1):
        """Set up the U-Net layers."""
        super().__init__()
        # ---------------- Encoder ----------------
        self.layer1 = DoubleConv(in_channels, 64)
        self.layer2 = DoubleConv(64, 128)
        self.layer3 = DoubleConv(128, 256)
        self.layer4 = DoubleConv(256, 512)  # bottleneck
        # ---------------- Decoder ----------------
        # Input channels = upsampled channels + skip-connection channels.
        self.layer5 = DoubleConv(512 + 256, 256)
        self.layer6 = DoubleConv(256 + 128, 128)
        self.layer7 = DoubleConv(128 + 64, 64)
        # 1x1 conv: per-pixel projection to the output channels.
        self.layer8 = torch.nn.Conv2d(64, out_channels, 1)
        # ------------------------------------------
        self.maxpool = torch.nn.MaxPool2d(2)

    def _up_block(self, x, skip, layer):
        """Upsample ``x`` to ``skip``'s size, concatenate along channels, convolve."""
        # Resize to the skip tensor's exact H/W instead of a fixed factor of 2,
        # so odd sizes produced by MaxPool2d's flooring still line up. For even
        # sizes this is numerically identical to scale_factor=2.
        x = torch.nn.functional.interpolate(
            x, size=skip.shape[-2:], mode="bilinear", align_corners=False
        )
        return layer(torch.cat([x, skip], dim=1))  # skip connection

    def forward(self, x):
        """Predict a per-pixel score map.

        Args:
            x: Input batch of shape (N, in_channels, H, W); bilinear
                interpolation requires a 4-D tensor. H and W must be >= 8.

        Returns:
            Tensor of shape (N, out_channels, H, W) with raw scores.
        """
        # Encoder: keep the pre-pooling features for the skip connections.
        x1 = self.layer1(x)
        x2 = self.layer2(self.maxpool(x1))
        x3 = self.layer3(self.maxpool(x2))
        # Bottleneck.
        x4 = self.layer4(self.maxpool(x3))
        # Decoder.
        x5 = self._up_block(x4, x3, self.layer5)
        x6 = self._up_block(x5, x2, self.layer6)
        x7 = self._up_block(x6, x1, self.layer7)
        # Predicted segmentation scores.
        return self.layer8(x7)