-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevaluate_inference_time.py
53 lines (43 loc) · 1.42 KB
/
evaluate_inference_time.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import argparse
import json
import os
import sys
import time

import numpy as np
import torch
import yaml

import nomenclature
from utils import load_args
# Command-line interface: three mandatory string options.
parser = argparse.ArgumentParser(description='Do stuff.')
for flag in ('--config_file', '--model', '--name'):
    parser.add_argument(flag, type=str, required=True)

# load_args (project helper) takes the parsed namespace and returns a pair
# that is unpacked into the working `args` and a `cfg` object.
args, cfg = load_args(parser.parse_args())
print(args)
def count_parameters(model):
    """Return the total element count of the trainable parameters of ``model``.

    Iterates over ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is truthy. A model with no such
    parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Benchmark the model's forward pass for several period lengths and store the
# per-length mean / standard deviation of the measured times (seconds) as JSON.
output = {
    'period_length': [],
    'time-mean': [],
    'time-std': [],
}

# Untimed passes run first so one-time start-up work is not included in the
# reported statistics.
num_warmup_runs = 10
num_timed_runs = 100

for period_length in [12, 24, 36, 48, 60, 72, 84, 96]:
    mock_input = torch.rand((512, 3, period_length, 18, 1)).cuda()
    args.period_length = period_length

    # A fresh model is built for every period length from the updated args.
    model = nomenclature.MODELS[args.model](args).cuda()
    model.eval()  # equivalent to model.train(False)

    num_params = count_parameters(model)
    print(args.name, 'params:', num_params)

    times = []
    with torch.no_grad():
        for _ in range(num_warmup_runs):
            model(mock_input)
        # Drain the warm-up work so it cannot leak into the first timed run.
        torch.cuda.synchronize()

        for _ in range(num_timed_runs):
            start_time = time.perf_counter()
            model(mock_input)
            # CUDA kernels are launched asynchronously; wait for them to
            # finish, otherwise only the launch overhead would be measured.
            torch.cuda.synchronize()
            times.append(time.perf_counter() - start_time)

    output['period_length'].append(period_length)
    output['time-mean'].append(float(np.mean(times)))
    output['time-std'].append(float(np.std(times)))

# Create the target directory up front so the results of a long benchmark are
# not lost to a missing-directory error.
output_dir = 'experiments/inference'
os.makedirs(output_dir, exist_ok=True)
with open(f'{output_dir}/{args.name}.json', 'wt') as f:
    json.dump(output, f)