-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtrain_corrector.py
98 lines (81 loc) · 2.78 KB
/
train_corrector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import time
import torch
from fluidestimator.config import get_cfg
from fluidestimator.modeling import PwcNet as UnPwcNet
from fluidestimator.engine import DefaultTrainer, SimpleTrainer, default_argument_parser, default_setup, launch
class SequenceTrainer(SimpleTrainer):
    """
    A trainer to handle the sequence data.

    Every frame of a sequence is passed through ``predictor``. From the
    second frame on, ``self.model`` is called with the previous state and the
    current prediction and returns ``(loss_dict, corrected)``; ``corrected``
    then becomes the previous state for the next frame. Losses are summed per
    name over the sequence and one optimizer step is taken per batch.
    """

    def __init__(self, model, data_loader, optimizer, predictor):
        """
        Args:
            model: module being trained; called as ``model(prev, pred)``.
            data_loader: forwarded unchanged to ``SimpleTrainer``.
            optimizer: forwarded unchanged to ``SimpleTrainer``.
            predictor: callable applied to each frame. It is only ever run
                without gradient tracking, i.e. it is treated as frozen.
        """
        super().__init__(model, data_loader, optimizer)
        self.predictor = predictor

    def run_step(self):
        """
        Run one training iteration over a batch of sequences.

        Raises:
            ValueError: if no loss was produced for the batch, which happens
                when a sequence holds fewer than two frames.
        """
        assert self.model.training, "[SimpleTrainer] modeling was changed to eval mode!"

        start = time.perf_counter()
        data = next(self._data_loader_iter)
        # Stacking over the batch elements puts the batch on dim 0; dim 1 is
        # then indexed below as the position inside the sequence.
        images = torch.stack([x["image"] for x in data], dim=0)
        data_time = time.perf_counter() - start

        seq_len = images.shape[1]
        prev = None
        corrected = None
        loss_dict = {}
        for num in range(seq_len):
            # The predictor is not updated by self.optimizer, so recording its
            # autograd graph would only cost memory and backward time.
            with torch.no_grad():
                pred = self.predictor(images[:, num, ...])

            if prev is not None:
                sub_loss_dict, corrected = self.model(prev, pred)
                for name, value in sub_loss_dict.items():
                    if name in loss_dict:
                        # Out-of-place add: an in-place ``+=`` would mutate the
                        # tensor returned by the model, which autograd may
                        # still need unchanged for the backward pass.
                        loss_dict[name] = loss_dict[name] + value
                    else:
                        loss_dict[name] = value

            # Until the first correction exists, the raw prediction seeds the
            # chain; afterwards the corrected output is propagated.
            prev = pred if corrected is None else corrected

        if not loss_dict:
            # Without this check ``sum()`` yields the int 0 and ``.backward()``
            # fails with an unhelpful AttributeError.
            raise ValueError(
                "no losses were produced for this batch "
                f"(seq_len={seq_len}); at least two frames per sequence are required"
            )

        losses = sum(loss_dict.values())

        self.optimizer.zero_grad()
        losses.backward()
        self._write_metrics(loss_dict, data_time)
        self.optimizer.step()
class Trainer(DefaultTrainer):
    """
    DefaultTrainer whose inner ``_trainer`` is a SequenceTrainer.
    """

    def __init__(self, cfg, predictor):
        """
        Run the default construction from ``cfg``, then swap the inner trainer.

        Args:
            cfg: config forwarded to ``DefaultTrainer.__init__``.
            predictor: handed to SequenceTrainer as its ``predictor``.
        """
        super().__init__(cfg)
        # Reuse the model, data loader and optimizer the base class exposes.
        sequence_trainer = SequenceTrainer(
            model=self.model,
            data_loader=self.data_loader,
            optimizer=self.optimizer,
            predictor=predictor,
        )
        self._trainer = sequence_trainer
def setup(args):
    """
    Build the run configuration and perform the basic setup.

    Starts from ``get_cfg()``, merges the file at ``args.config_file``, then
    merges the ``args.opts`` list, freezes the result and hands it to
    ``default_setup`` together with ``args``.

    Args:
        args: parsed command-line arguments; ``config_file`` and ``opts``
            are read here.

    Returns:
        The frozen config.
    """
    config = get_cfg()

    # The file is merged first and the option list second.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()

    default_setup(config, args)
    return config
def main(args):
    """
    Train the corrector on top of a pre-trained upstream predictor.

    Builds the config, restores the predictor weights from
    ``cfg.MODEL.CORRECTOR.UPSTREAM_PREDICTOR``, and runs ``Trainer`` with the
    predictor kept fixed.

    Args:
        args: parsed command-line arguments; ``resume`` is forwarded to
            ``Trainer.resume_or_load``.

    Returns:
        Whatever ``trainer.train()`` returns.
    """
    cfg = setup(args)
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    # Setup predictor
    un_pwcnet = UnPwcNet(cfg).to(cfg.DEVICE)
    # Deserialize on the CPU so the checkpoint loads regardless of the device
    # it was saved from; load_state_dict then copies the values into the
    # parameters that already live on cfg.DEVICE.
    checkpoint = torch.load(cfg.MODEL.CORRECTOR.UPSTREAM_PREDICTOR, map_location="cpu")
    un_pwcnet.load_state_dict(checkpoint["model"])
    un_pwcnet.eval()
    # The predictor is never handed to an optimizer here, so switch off
    # gradient tracking for its weights instead of accumulating unused grads.
    for param in un_pwcnet.parameters():
        param.requires_grad_(False)

    trainer = Trainer(cfg, un_pwcnet)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()
if __name__ == "__main__":
    # Corrector training
    parser = default_argument_parser()
    args = parser.parse_args()
    # Overrides the parsed value: main() always calls resume_or_load with
    # resume=True.
    args.resume = True
    print("Command Line Args:", args)

    # Keyword options for launch(), gathered in one place.
    launch_options = dict(
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
    launch(main, args.num_gpus, **launch_options)