-
Notifications
You must be signed in to change notification settings - Fork 0
/
ConvNeXt.py
149 lines (122 loc) · 4.83 KB
/
ConvNeXt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import torch
import torchvision
class ConvNeXtBlock(torch.nn.Module):
    """Inverted-bottleneck residual block in the ConvNeXt style.

    Residual branch:
        7x7 depthwise conv + BatchNorm2d
        -> 1x1 conv expanding to ``in_channels * expansion_ratio`` + GELU
        -> 1x1 conv projecting to ``out_channels``
        -> torchvision ``StochasticDepth(p=dropout, mode="batch")``.

    The branch output is added to a shortcut of ``x``. The shortcut is the
    identity when ``in_channels == out_channels``. Otherwise it is a 1x1
    convolution, so that the residual sum has matching channel counts.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels of the output feature map.
        expansion_ratio: multiplier giving the hidden width of the pointwise MLP.
        dropout: stochastic-depth probability passed to ``StochasticDepth``.
    """

    def __init__(self, in_channels, out_channels, expansion_ratio, dropout=0.1):
        super().__init__()
        hidden_dim = in_channels * expansion_ratio
        # Depthwise convolution (groups == in_channels). With kernel 7 and
        # padding 3, the spatial size is preserved.
        self.spatial_mixing = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels),
            torch.nn.BatchNorm2d(in_channels),
        )
        # Pointwise convolution: expand channels to hidden_dim.
        self.feature_mixing = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, hidden_dim, kernel_size=1, stride=1, padding=0),
            torch.nn.GELU(),
        )
        # Pointwise convolution: project hidden_dim down to out_channels.
        self.bottleneck = torch.nn.Sequential(
            torch.nn.Conv2d(hidden_dim, out_channels, kernel_size=1, stride=1, padding=0),
        )
        # Drop path / stochastic depth applied to the residual branch only.
        self.stochastic_depth = torchvision.ops.StochasticDepth(p=dropout, mode="batch")
        # Shortcut path. Identity is parameter-free, so blocks with equal in/out
        # widths keep the same state_dict keys and initialisation as before.
        # When the widths differ, a 1x1 projection makes `shortcut + branch`
        # well-defined; without it, the addition raises a shape error.
        if in_channels == out_channels:
            self.shortcut = torch.nn.Identity()
        else:
            self.shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the block to ``x``.

        ``x`` is expected to be a 4-D (N, in_channels, H, W) tensor, as required
        by the BatchNorm2d layer. The output has ``out_channels`` channels and
        the same H and W.
        """
        out = self.spatial_mixing(x)
        out = self.feature_mixing(out)
        out = self.bottleneck(out)
        # Residual connection: identity, or a 1x1 projection when widths differ.
        return self.shortcut(x) + self.stochastic_depth(out)
class ConvNeXt(torch.nn.Module):
    """ConvNeXt-style image classifier assembled from a per-stage configuration.

    Network layout:
        stem (4x4 conv, stride 4, + BatchNorm2d)
        -> for each stage: ``num_blocks`` ConvNeXtBlocks at that stage's width,
           followed (except after the last stage) by a 2x2 stride-2 conv
           + BatchNorm2d that changes the width to the next stage's channels
        -> BatchNorm2d + global average pool + flatten
        -> linear head.

    Args:
        num_classes: number of outputs of the final linear layer.
        in_channels: number of channels of the input image tensor.
        stage_config: sequence of per-stage entries
            ``[expansion_ratio, channels, num_blocks, dropout]``.
    """

    def __init__(self, num_classes, in_channels, stage_config):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.stage_config = stage_config

        stem_channels = self.stage_config[0][1]
        layers = [
            # Stem: kernel size 4, stride 4.
            torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, stem_channels, kernel_size=4, stride=4),
                # Sized from the config rather than a fixed 96, so configs whose
                # first stage is wider (e.g. 128/192/256) do not crash.
                torch.nn.BatchNorm2d(stem_channels),
            )
        ]

        last_stage = len(self.stage_config) - 1
        for i, (expansion_ratio, channels, num_blocks, dropout) in enumerate(self.stage_config):
            # Blocks keep the channel count, so each uses an identity residual.
            layers.extend(
                ConvNeXtBlock(in_channels=channels, out_channels=channels,
                              expansion_ratio=expansion_ratio, dropout=dropout)
                for _ in range(num_blocks)
            )
            # Downsampling layer between consecutive stages (none after the last).
            if i < last_stage:
                next_channels = self.stage_config[i + 1][1]
                layers.append(
                    torch.nn.Sequential(
                        torch.nn.Conv2d(channels, next_channels, kernel_size=2, stride=2),
                        torch.nn.BatchNorm2d(next_channels),
                    )
                )
        self.layers = torch.nn.Sequential(*layers)

        final_channels = self.stage_config[-1][1]
        # Embeddings: normalise, global-average-pool to 1x1, flatten to (N, C).
        # NOTE: an activation (e.g. GELU) could optionally be appended here.
        self.embeddings = torch.nn.Sequential(
            torch.nn.BatchNorm2d(final_channels),
            torch.nn.AdaptiveAvgPool2d(1),
            torch.nn.Flatten(),
        )
        # Classification layer.
        self.head = torch.nn.Linear(final_channels, self.num_classes)

    def forward(self, x):
        """Return unnormalised class scores of shape (N, num_classes).

        ``x`` is expected to be a 4-D (N, in_channels, H, W) image batch.
        """
        out = self.layers(x)
        out = self.embeddings(out)
        out = self.head(out)
        return out
def ConvNeXt_T(in_channels, num_classes):
    """Build the "T" ConvNeXt variant.

    Stage widths are 96/192/384/768 and depths 3/3/9/3. Every stage uses
    expansion ratio 4 and stochastic-depth probability 0.0.
    """
    widths = (96, 192, 384, 768)
    depths = (3, 3, 9, 3)
    expansion, drop = 4, 0.0
    # Each entry: [expansion ratio, channels, number of blocks, dropout].
    stage_config = [[expansion, width, depth, drop] for width, depth in zip(widths, depths)]
    return ConvNeXt(num_classes=num_classes, in_channels=in_channels, stage_config=stage_config)
def ConvNeXt_S(in_channels, num_classes):
    """Build the "S" ConvNeXt variant.

    Stage widths are 96/192/384/768 and depths 3/3/27/3. Every stage uses
    expansion ratio 4 and stochastic-depth probability 0.1.
    """
    widths = (96, 192, 384, 768)
    depths = (3, 3, 27, 3)
    expansion, drop = 4, 0.1
    # Each entry: [expansion ratio, channels, number of blocks, dropout].
    stage_config = [[expansion, width, depth, drop] for width, depth in zip(widths, depths)]
    return ConvNeXt(num_classes=num_classes, in_channels=in_channels, stage_config=stage_config)
def ConvNeXt_B(in_channels, num_classes):
    """Build the "B" ConvNeXt variant.

    Stage widths are 128/256/512/1024 and depths 3/3/27/3. Every stage uses
    expansion ratio 4 and stochastic-depth probability 0.2.
    """
    widths = (128, 256, 512, 1024)
    depths = (3, 3, 27, 3)
    expansion, drop = 4, 0.2
    # Each entry: [expansion ratio, channels, number of blocks, dropout].
    stage_config = [[expansion, width, depth, drop] for width, depth in zip(widths, depths)]
    return ConvNeXt(num_classes=num_classes, in_channels=in_channels, stage_config=stage_config)
def ConvNeXt_L(in_channels, num_classes):
    """Build the "L" ConvNeXt variant.

    Stage widths are 192/384/768/1536 and depths 3/3/27/3. Every stage uses
    expansion ratio 4 and stochastic-depth probability 0.3.
    """
    widths = (192, 384, 768, 1536)
    depths = (3, 3, 27, 3)
    expansion, drop = 4, 0.3
    # Each entry: [expansion ratio, channels, number of blocks, dropout].
    stage_config = [[expansion, width, depth, drop] for width, depth in zip(widths, depths)]
    return ConvNeXt(num_classes=num_classes, in_channels=in_channels, stage_config=stage_config)
def ConvNeXt_XL(in_channels, num_classes):
    """Build the "XL" ConvNeXt variant.

    Stage widths are 256/512/1024/2048 and depths 3/3/27/3. Every stage uses
    expansion ratio 4 and stochastic-depth probability 0.4.
    """
    widths = (256, 512, 1024, 2048)
    depths = (3, 3, 27, 3)
    expansion, drop = 4, 0.4
    # Each entry: [expansion ratio, channels, number of blocks, dropout].
    stage_config = [[expansion, width, depth, drop] for width, depth in zip(widths, depths)]
    return ConvNeXt(num_classes=num_classes, in_channels=in_channels, stage_config=stage_config)