-
Notifications
You must be signed in to change notification settings - Fork 176
/
Copy pathgaitpart.py
126 lines (105 loc) · 4.54 KB
/
gaitpart.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
import torch.nn as nn
from ..base_model import BaseModel
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs
from utils import clones
class BasicConv1d(nn.Module):
    """Thin wrapper around a bias-free ``nn.Conv1d``.

    Args:
        in_channels: Number of channels of the input signal.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Size of the 1-D convolution kernel.
        **kwargs: Forwarded unchanged to ``nn.Conv1d`` (e.g. ``padding``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__()
        # The bias term is always disabled; every other option is up to the
        # caller via **kwargs.
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size, bias=False, **kwargs)

    def forward(self, x):
        """Run ``x`` through the wrapped convolution and return the result."""
        return self.conv(x)
class TemporalFeatureAggregator(nn.Module):
    """Aggregate per-part frame features along the temporal axis.

    Two branches (labelled MTB1 and MTB2) are evaluated on the same input.
    Each branch pools the sequence with an average + max window (size 3 for
    MTB1, size 5 for MTB2) and re-weights the pooled features with sigmoid
    scores produced by a per-part stack of 1-D convolutions. The two branch
    outputs are summed and reduced over the temporal axis with ``torch.max``.

    Args:
        in_channels: Channel count of the incoming features.
        squeeze: Divisor applied to ``in_channels`` to obtain the hidden
            width of the per-part convolution stacks.
        parts_num: Number of copies requested from ``clones`` for each
            branch; ``forward`` pairs these copies with the parts of the
            input in order.
    """

    def __init__(self, in_channels, squeeze=4, parts_num=16):
        super().__init__()
        hidden_dim = int(in_channels // squeeze)
        self.parts_num = parts_num

        # MTB1: 3 -> 1 convolution stack, pooling window of 3.
        mtb1_convs = nn.Sequential(
            BasicConv1d(in_channels, hidden_dim, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            BasicConv1d(hidden_dim, in_channels, 1),
        )
        self.conv1d3x1 = clones(mtb1_convs, parts_num)
        self.avg_pool3x1 = nn.AvgPool1d(3, stride=1, padding=1)
        self.max_pool3x1 = nn.MaxPool1d(3, stride=1, padding=1)

        # MTB2: 3 -> 3 convolution stack, pooling window of 5.
        mtb2_convs = nn.Sequential(
            BasicConv1d(in_channels, hidden_dim, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            BasicConv1d(hidden_dim, in_channels, 3, padding=1),
        )
        self.conv1d3x3 = clones(mtb2_convs, parts_num)
        self.avg_pool3x3 = nn.AvgPool1d(5, stride=1, padding=2)
        self.max_pool3x3 = nn.MaxPool1d(5, stride=1, padding=2)

        # Temporal Pooling, TP
        self.TP = torch.max

    @staticmethod
    def _branch(part_convs, avg_pool, max_pool, parts, flat, dims):
        """Return one branch's pooled features weighted by sigmoid scores.

        ``parts`` holds one [n, c, s] tensor per part, ``flat`` is the same
        data viewed as [p * n, c, s], and ``dims`` is ``(p, n, c, s)``.
        """
        # ConvNet1d & Sigmoid: each part goes through its own conv stack.
        scores = torch.sigmoid(torch.stack(
            [conv(part) for conv, part in zip(part_convs, parts)], 0))
        # Template Function: average + max pooling over the sequence axis.
        template = (avg_pool(flat) + max_pool(flat)).view(*dims)
        return template * scores

    def forward(self, x):
        """
          Input:  x,   [n, c, s, p]
          Output: ret, [n, c, p]
        """
        n, c, s, p = x.size()
        by_part = x.permute(3, 0, 1, 2).contiguous()  # [p, n, c, s]
        parts = by_part.unbind(0)                     # p tensors of [n, c, s]
        flat = by_part.view(-1, c, s)                 # [p * n, c, s]
        dims = (p, n, c, s)

        mtb1 = self._branch(self.conv1d3x1, self.avg_pool3x1,
                            self.max_pool3x1, parts, flat, dims)
        mtb2 = self._branch(self.conv1d3x3, self.avg_pool3x3,
                            self.max_pool3x3, parts, flat, dims)

        # Temporal Pooling: keep the values of torch.max, drop the indices.
        pooled = self.TP(mtb1 + mtb2, dim=-1)[0]      # [p, n, c]
        return pooled.permute(1, 2, 0).contiguous()   # [n, c, p]
class GaitPart(BaseModel):
    """
        GaitPart: Temporal Part-based Model for Gait Recognition
        Paper:    https://openaccess.thecvf.com/content_CVPR_2020/papers/Fan_GaitPart_Temporal_Part-Based_Model_for_Gait_Recognition_CVPR_2020_paper.pdf
        Github:   https://github.com/ChaoFan96/GaitPart
    """

    def __init__(self, *args, **kargs):
        # All construction arguments are handed to BaseModel unchanged.
        super().__init__(*args, **kargs)

    def build_network(self, model_cfg):
        """Create the sub-modules from ``model_cfg``.

        Keys read: ``'backbone_cfg'``, ``'SeparateFCs'`` (which must itself
        provide ``'in_channels'`` and ``'parts_num'``) and ``'bin_num'``.
        """
        # NOTE(review): the assignment order below is kept exactly as in the
        # original; if BaseModel is an nn.Module it determines the order in
        # which sub-modules are registered - confirm before reordering.
        self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])

        head_cfg = model_cfg['SeparateFCs']
        self.Head = SeparateFCs(**head_cfg)

        self.Backbone = SetBlockWrapper(self.Backbone)

        pyramid = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])
        self.HPP = SetBlockWrapper(pyramid)

        aggregator = TemporalFeatureAggregator(
            in_channels=head_cfg['in_channels'],
            parts_num=head_cfg['parts_num'],
        )
        self.TFA = PackSequenceWrapper(aggregator)

    def forward(self, inputs):
        """Compute embeddings for a batch.

        ``inputs`` is unpacked into five items; only the first (``ipts``,
        whose element 0 holds the silhouettes), the second (labels) and the
        last (``seqL``, passed on to the TFA module) are used.

        Returns a dict with ``'training_feat'``, ``'visual_summary'`` and
        ``'inference_feat'`` entries.
        """
        ipts, labs, _, _, seqL = inputs
        sils = ipts[0]
        del ipts

        # 4-D silhouettes get an extra axis inserted at position 1 so that
        # the five-way size unpacking further down works.
        if sils.dim() == 4:
            sils = sils.unsqueeze(1)

        # Shape annotations are carried over from the original comments.
        feat = self.Backbone(sils)    # [n, c, s, h, w]
        feat = self.HPP(feat)         # [n, c, s, p]
        feat = self.TFA(feat, seqL)   # [n, c, p]
        embs = self.Head(feat)        # [n, c, p]

        n, _, s, h, w = sils.size()
        training_feat = {
            'triplet': {'embeddings': embs, 'labels': labs}
        }
        visual_summary = {
            'image/sils': sils.view(n*s, 1, h, w)
        }
        inference_feat = {
            'embeddings': embs
        }
        return {
            'training_feat': training_feat,
            'visual_summary': visual_summary,
            'inference_feat': inference_feat,
        }