-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathtrain_detector.py
executable file
·163 lines (125 loc) · 5.77 KB
/
train_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# FootAndBall: Integrated Player and Ball Detector
# Jacek Komorowski, Grzegorz Kurzejamski, Grzegorz Sarwas
# Copyright (c) 2020 Sport Algorithmics and Gaming
#
# Train FootAndBall detector on ISSIA-CNR Soccer and SoccerPlayerDetection dataset
#
import tqdm
import argparse
import pickle
import numpy as np
import os
import time
import torch
import torch.optim as optim
from network import footandball
from data.data_reader import make_dataloaders
from network.ssd_loss import SSDLoss
from misc.config import Params
MODEL_FOLDER = 'models'
def train_model(model, optimizer, scheduler, num_epochs, dataloaders, device, model_name):
    """Run the training (and optional validation) loop for the FootAndBall detector.

    Args:
        model: FootAndBall detector; must provide ``groundtruth_maps(boxes, labels, (h, w))``.
        optimizer: torch optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch.
        num_epochs: number of epochs to train.
        dataloaders: dict with a 'train' loader and optionally a 'val' loader,
            yielding ``(images, boxes, labels)`` batches.
        device: device string/object the model already lives on.
        model_name: base name used for the saved weights and stats files.

    Returns:
        dict mapping phase name ('train'/'val') to a list (one entry per epoch)
        of per-epoch average loss statistics.

    Side effects:
        Saves final weights to ``MODEL_FOLDER/<model_name>_final.pth`` and pickles
        the statistics to ``training_stats_<model_name>.pickle``.
    """
    # Weights for components of the loss function.
    # Ball-related loss and player-related loss are mean losses (loss per one positive example).
    alpha_l_player = 0.01
    alpha_c_player = 1.
    alpha_c_ball = 5.

    # Normalize weights so they sum to 1.
    total = alpha_l_player + alpha_c_player + alpha_c_ball
    alpha_l_player /= total
    alpha_c_player /= total
    alpha_c_ball /= total

    # Loss function (hard negative mining with 3 negatives per positive).
    criterion = SSDLoss(neg_pos_ratio=3)

    # Run a validation pass only when a validation loader was supplied.
    phases = ['train', 'val'] if 'val' in dataloaders else ['train']

    # Per-phase, per-epoch training statistics.
    training_stats = {'train': [], 'val': []}

    print('Training...')
    for epoch in tqdm.tqdm(range(num_epochs)):
        # Each epoch has a training and (optionally) a validation phase.
        for phase in phases:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            batch_stats = {'loss': [], 'loss_ball_c': [], 'loss_player_c': [], 'loss_player_l': []}

            # Iterate over data.
            for ndx, (images, boxes, labels) in enumerate(dataloaders[phase]):
                images = images.to(device)
                h, w = images.shape[-2], images.shape[-1]
                # Ground-truth confidence/localization maps for this batch.
                gt_maps = model.groundtruth_maps(boxes, labels, (h, w))
                gt_maps = [e.to(device) for e in gt_maps]

                # Gradients are only tracked during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    predictions = model(images)
                    optimizer.zero_grad()
                    loss_l_player, loss_c_player, loss_c_ball = criterion(predictions, gt_maps)
                    loss = alpha_l_player * loss_l_player + alpha_c_player * loss_c_player + alpha_c_ball * loss_c_ball
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Per-batch statistics. NOTE(review): the original code also kept a
                # count_batches counter that was incremented twice per batch and
                # never read; it has been removed.
                batch_stats['loss'].append(loss.item())
                batch_stats['loss_ball_c'].append(loss_c_ball.item())
                batch_stats['loss_player_c'].append(loss_c_player.item())
                batch_stats['loss_player_l'].append(loss_l_player.item())

            # Average stats per batch for this epoch/phase.
            avg_batch_stats = {e: np.mean(batch_stats[e]) for e in batch_stats}
            training_stats[phase].append(avg_batch_stats)
            s = '{} Avg. loss total / ball conf. / player conf. / player loc.: {:.4f} / {:.4f} / {:.4f} / {:.4f}'
            print(s.format(phase, avg_batch_stats['loss'], avg_batch_stats['loss_ball_c'],
                           avg_batch_stats['loss_player_c'], avg_batch_stats['loss_player_l']))

        # Scheduler step (once per epoch).
        scheduler.step()

    print('')
    # Save final model weights and the accumulated training statistics.
    model_filepath = os.path.join(MODEL_FOLDER, model_name + '_final' + '.pth')
    torch.save(model.state_dict(), model_filepath)

    with open('training_stats_{}.pickle'.format(model_name), 'wb') as handle:
        pickle.dump(training_stats, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return training_stats
def train(params: Params):
    """Build the model, optimizer and dataloaders from *params* and train the detector.

    Args:
        params: parsed configuration (model name, learning rate, epoch count, ...).

    Raises:
        RuntimeError: if the folder for trained models cannot be created.
    """
    # Create the output folder for trained models. os.makedirs with exist_ok avoids
    # the exists/mkdir race; an explicit RuntimeError replaces the original assert,
    # which would be stripped when running with python -O.
    try:
        os.makedirs(MODEL_FOLDER, exist_ok=True)
    except OSError as e:
        raise RuntimeError('Cannot create folder to save trained model: {}'.format(MODEL_FOLDER)) from e

    dataloaders = make_dataloaders(params)
    print('Training set: Dataset size: {}'.format(len(dataloaders['train'].dataset)))
    if 'val' in dataloaders:
        print('Validation set: Dataset size: {}'.format(len(dataloaders['val'].dataset)))

    # Create model
    device = "cuda" if torch.cuda.is_available() else 'cpu'
    model = footandball.model_factory(params.model, 'train')
    model.print_summary(show_architecture=True)
    model = model.to(device)

    # Timestamped name keeps successive training runs from overwriting each other.
    model_name = 'model_' + time.strftime("%Y%m%d_%H%M")
    print('Model name: {}'.format(model_name))

    optimizer = optim.Adam(model.parameters(), lr=params.lr)
    # Drop the learning rate by 10x after 75% of the epochs.
    scheduler_milestones = [int(params.epochs * 0.75)]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, scheduler_milestones, gamma=0.1)
    train_model(model, optimizer, scheduler, params.epochs, dataloaders, device, model_name)
if __name__ == '__main__':
    # Fixed typo in the banner: 'FoootAndBall' -> 'FootAndBall'.
    print('Train FootAndBall detector on ISSIA dataset')

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', help='Path to the configuration file', type=str, default='config.txt')
    parser.add_argument('--debug', dest='debug', help='debug mode', action='store_true')
    args = parser.parse_args()

    print('Config path: {}'.format(args.config))
    print('Debug mode: {}'.format(args.debug))

    params = Params(args.config)
    params.print()

    train(params)
'''
NOTES:
In camera 5 some of the player bboxes are moved by a few pixels from the true position.
When evaluating mean precision use smaller IoU ratio, otherwise detection results are poor.
Alternatively add some margin to ISSIA ground truth bboxes.
'''