forked from ivanzzh/admm_uav_regression
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunit.py
73 lines (50 loc) · 2.26 KB
/
unit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn as nn
class Unit(nn.Module):
    """Three-stage 2D convolutional feature extractor.

    Each stage is ``Conv2d -> BatchNorm2d -> ReLU``, and a 2x2 max-pool sits
    between the second and third stage. Channel widths progress
    ``channel -> 64 -> 128 -> 256``.

    Args:
        channel: number of input channels expected by the first convolution.

    Input must be a 4-D tensor ``(N, channel, H, W)``; ``BatchNorm2d`` rejects
    other ranks. Per spatial dimension, the size evolves as:

    * conv1 (kernel 4, stride 2): ``(H - 4) // 2 + 1``
    * conv2 (kernel 2, stride 1): ``H - 1``
    * pool  (kernel 2, stride 2): ``H // 2``
    * conv3 (kernel 2, stride 1): ``H - 1``

    For example, a 100x100 input yields an output of shape ``(N, 256, 23, 23)``.

    Note: the ``bn3d_*`` / ``max_pool_3d1`` attribute names are leftovers from
    an earlier Conv3d/BatchNorm3d variant. They hold 2D layers, but are kept
    unchanged because renaming them would change the module's submodule names
    and ``state_dict`` keys.
    """

    def __init__(self, channel=10):
        super().__init__()
        # Output widths of the three convolution stages.
        self.out1, self.out2, self.out3 = 64, 128, 256

        # ReLU has no state, so one instance is reused after every BatchNorm.
        self.relu = nn.ReLU(inplace=True)

        # Convolutions are created first and in stage order. Keep this order:
        # it fixes how the RNG is consumed during weight initialisation.
        self.sub_conv1 = nn.Conv2d(channel, self.out1, kernel_size=4, stride=2)
        self.sub_conv2 = nn.Conv2d(self.out1, self.out2, kernel_size=2, stride=1)
        self.sub_conv3 = nn.Conv2d(self.out2, self.out3, kernel_size=2, stride=1)

        # One batch-norm per stage (2D, despite the historical "3d" names).
        self.bn3d_sub_1 = nn.BatchNorm2d(self.out1)
        self.bn3d_sub_2 = nn.BatchNorm2d(self.out2)
        self.bn3d_sub_3 = nn.BatchNorm2d(self.out3)

        # Halves the spatial size between stage 2 and stage 3.
        self.max_pool_3d1 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Run the three conv stages on ``x`` of shape ``(N, channel, H, W)``."""
        features = self.relu(self.bn3d_sub_1(self.sub_conv1(x)))
        features = self.relu(self.bn3d_sub_2(self.sub_conv2(features)))
        features = self.max_pool_3d1(features)
        return self.relu(self.bn3d_sub_3(self.sub_conv3(features)))
if __name__ == "__main__":
    # Smoke test. Unit is built from 2D layers (Conv2d/BatchNorm2d), so its
    # input must be a 4-D (batch, channels, height, width) tensor whose channel
    # count matches the first convolution.
    # The previous 5-D input (4, 64, 100, 100, 6) was a leftover from the
    # commented-out Conv3d variant and made conv2d raise a RuntimeError.
    net = Unit()
    x = torch.rand(4, net.sub_conv1.in_channels, 100, 100)
    output = net(x)
    print(output.shape)  # expected: torch.Size([4, 256, 23, 23])