-
Notifications
You must be signed in to change notification settings - Fork 176
/
Copy pathgaitgl.py
199 lines (170 loc) · 7.39 KB
/
gaitgl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base_model import BaseModel
from ..modules import SeparateFCs, BasicConv3d, PackSequenceWrapper, SeparateBNNecks
class GLConv(nn.Module):
    """Global-local 3D convolution block.

    Two ``BasicConv3d`` layers with identical constructor arguments are
    created: one is applied to the whole input ("global"), the other is
    applied to horizontal strips of the input ("local").

    Args:
        in_channels, out_channels, kernel_size, stride, padding, bias,
            **kwargs: forwarded unchanged to both ``BasicConv3d`` layers.
        halving: 0 applies the local conv to the full input; otherwise the
            input is cut along dim 3 into strips of ``h // 2**halving`` rows
            and the same local conv is applied to every strip.
        fm_sign: False -> sum of the two leaky-ReLU activated branches;
            True -> the two branches are concatenated along dim 3 and then
            activated.
    """

    def __init__(self, in_channels, out_channels, halving, fm_sign=False, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False, **kwargs):
        super(GLConv, self).__init__()
        self.halving = halving
        self.fm_sign = fm_sign
        # Both branches are built from exactly the same arguments
        # (global first, then local, as before).
        conv_args = (in_channels, out_channels,
                     kernel_size, stride, padding, bias)
        self.global_conv3d = BasicConv3d(*conv_args, **kwargs)
        self.local_conv3d = BasicConv3d(*conv_args, **kwargs)

    def forward(self, x):
        '''
            x: [n, c, s, h, w]
        '''
        global_out = self.global_conv3d(x)

        if self.halving == 0:
            local_out = self.local_conv3d(x)
        else:
            # Strip height along dim 3; every strip shares local_conv3d.
            strip_h = int(x.size(3) // 2 ** self.halving)
            strips = torch.split(x, strip_h, 3)
            local_out = torch.cat(
                [self.local_conv3d(strip) for strip in strips], 3)

        if self.fm_sign:
            # Stack global and local maps along dim 3, then activate.
            return F.leaky_relu(torch.cat([global_out, local_out], dim=3))
        # Element-wise fusion of the two activated branches.
        return F.leaky_relu(global_out) + F.leaky_relu(local_out)
class GeMHPP(nn.Module):
    """Generalized-mean (GeM) pooling over horizontal bins.

    For every bin count ``b`` in ``bin_num`` the input ``[n, c, h, w]`` is
    reshaped to ``[n, c, b, h*w/b]`` and each bin is reduced with
    ``(mean(clamp(x, eps) ** p)) ** (1 / p)``. The results for all bin
    counts are concatenated on the last dimension.

    Args:
        bin_num: iterable of bin counts; each must divide ``h * w``.
            ``None`` (the default) means ``[64]``.
        p: initial value of the learnable GeM exponent.
        eps: lower clamp applied to the input before raising it to ``p``.
    """

    def __init__(self, bin_num=None, p=6.5, eps=1.0e-6):
        super(GeMHPP, self).__init__()
        # A fresh list per instance: a literal ``[64]`` default in the
        # signature would be shared (and mutable) across all instances.
        self.bin_num = [64] if bin_num is None else bin_num
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def gem(self, ipts):
        """GeM-pool the whole last dimension of ``ipts`` down to size 1."""
        powered = ipts.clamp(min=self.eps).pow(self.p)
        pooled = F.avg_pool2d(powered, (1, ipts.size(-1)))
        return pooled.pow(1. / self.p)

    def forward(self, x):
        """
            x  : [n, c, h, w]
            ret: [n, c, p]
        """
        n, c = x.size()[:2]
        features = []
        for b in self.bin_num:
            # reshape (not view) so non-contiguous inputs, e.g. transposed
            # tensors, are accepted; it is still a view whenever possible.
            z = x.reshape(n, c, b, -1)
            features.append(self.gem(z).squeeze(-1))
        return torch.cat(features, -1)
class GaitGL(BaseModel):
    """
        GaitGL: Gait Recognition via Effective Global-Local Feature Representation and Local Temporal Aggregation
        Arxiv : https://arxiv.org/pdf/2011.01461.pdf

        Two backbone layouts are built, selected by
        ``cfgs['data_cfg']['dataset_name']``: a deeper one for
        'OUMVLP' / 'GREW' and a lighter one for any other dataset name.
    """

    def __init__(self, *args, **kargs):
        super(GaitGL, self).__init__(*args, **kargs)

    def build_network(self, model_cfg):
        """Create all sub-modules.

        Reads ``model_cfg['channels']`` (list of channel widths) and
        ``model_cfg['class_num']``. If ``model_cfg`` has a
        'SeparateBNNecks' entry it is used as keyword arguments for a
        ``SeparateBNNecks`` head; otherwise a BatchNorm1d + SeparateFCs
        head is built.
        """
        chans = model_cfg['channels']
        num_classes = model_cfg['class_num']
        dataset_name = self.cfgs['data_cfg']['dataset_name']
        # Deeper layout for OUMVLP and GREW; lighter one for CASIA-B or
        # other unstated datasets.
        is_large = dataset_name in ['OUMVLP', 'GREW']

        # Every 3x3x3 convolution in the backbone shares these settings.
        conv_kw = dict(kernel_size=(3, 3, 3),
                       stride=(1, 1, 1), padding=(1, 1, 1))

        def gl_stage(c_in, c_out, fm_sign):
            # Large layout: two stacked GLConv (halving=1), only the second
            # one may use fm_sign. Small layout: a single GLConv (halving=3).
            if is_large:
                return nn.Sequential(
                    GLConv(c_in, c_out, halving=1, fm_sign=False, **conv_kw),
                    GLConv(c_out, c_out, halving=1,
                           fm_sign=fm_sign, **conv_kw),
                )
            return GLConv(c_in, c_out, halving=3, fm_sign=fm_sign, **conv_kw)

        # Stem: one conv + activation, plus a second pair for the large layout.
        stem = [BasicConv3d(1, chans[0], **conv_kw),
                nn.LeakyReLU(inplace=True)]
        if is_large:
            stem.extend([BasicConv3d(chans[0], chans[0], **conv_kw),
                         nn.LeakyReLU(inplace=True)])
        self.conv3d = nn.Sequential(*stem)

        # Temporal conv with kernel 3 / stride 3 along the frame dimension.
        self.LTA = nn.Sequential(
            BasicConv3d(chans[0], chans[0], kernel_size=(3, 1, 1),
                        stride=(3, 1, 1), padding=(0, 0, 0)),
            nn.LeakyReLU(inplace=True),
        )

        self.GLConvA0 = gl_stage(chans[0], chans[1], fm_sign=False)
        self.MaxPool0 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.GLConvA1 = gl_stage(chans[1], chans[2], fm_sign=False)
        # The large layout widens to chans[3]; the small one stays at chans[2].
        self.GLConvB2 = gl_stage(
            chans[2], chans[3] if is_large else chans[2], fm_sign=True)

        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = GeMHPP()

        self.Head0 = SeparateFCs(64, chans[-1], chans[-1])

        if 'SeparateBNNecks' not in model_cfg.keys():
            # Original GaitGL head.
            self.Bn = nn.BatchNorm1d(chans[-1])
            self.Head1 = SeparateFCs(64, chans[-1], num_classes)
            self.Bn_head = True
        else:
            self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
            self.Bn_head = False

    def forward(self, inputs):
        """Run the network on one batch.

        ``inputs`` is unpacked as ``(ipts, labs, _, _, seqL)``; only
        ``ipts[0]`` is used as the silhouette tensor. Returns a dict with
        'training_feat', 'visual_summary' and 'inference_feat' entries.

        Raises:
            ValueError: in eval mode when ``len(labs) != 1``.
        """
        ipts, labs, _, _, seqL = inputs

        if not self.training:
            # Sequence lengths are ignored outside training.
            seqL = None
            if len(labs) != 1:
                raise ValueError(
                    'The input size of each GPU must be 1 in testing mode, but got {}!'.format(len(labs)))

        sils = ipts[0].unsqueeze(1)
        del ipts

        # Tile very short sequences along the frame dimension (dim 2) so
        # that at least 3 frames reach the stride-3 temporal conv.
        _, _, num_frames, _, _ = sils.size()
        if num_frames < 3:
            times = 3 if num_frames == 1 else 2
            sils = sils.repeat(1, 1, times, 1, 1)

        feat = sils
        backbone = (self.conv3d, self.LTA, self.GLConvA0,
                    self.MaxPool0, self.GLConvA1, self.GLConvB2)
        for stage in backbone:
            feat = stage(feat)  # ends as [n, c, s, h, w]

        feat = self.TP(feat, seqL=seqL, options={"dim": 2})[0]  # [n, c, h, w]
        feat = self.HPP(feat)  # [n, c, p]
        gait = self.Head0(feat)  # [n, c, p]

        if self.Bn_head:
            # Original GaitGL head: the batch-normed feature is the embedding.
            embed = self.Bn(gait)  # [n, c, p]
            logi = self.Head1(embed)  # [n, c, p]
        else:
            # BNNecks as head: the pre-BN feature is the embedding.
            _, logi = self.BNNecks(gait)  # [n, c, p]
            embed = gait

        # Re-read the sizes: the frame count may have changed by the tiling.
        n, _, s, h, w = sils.size()
        return {
            'training_feat': {
                'triplet': {'embeddings': embed, 'labels': labs},
                'softmax': {'logits': logi, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': embed
            }
        }