-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathrecorder.py
137 lines (107 loc) · 4.29 KB
/
recorder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import sys
import argparse
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
import yaml
from tqdm import tqdm
from torch.utils.data import DataLoader, random_split
import pytorch_model_summary
import torchvision.transforms as transforms
import glob
from dataset.dataset import FaceDataset, FaceDatasetVal
from utils import *
from model_factory import model_build
from loss import *
# Matplotlib color specifiers handed to ``plt.plot(..., color=...)``.
# Curves pick their color by position in the plotting loop, so the order here
# decides which curve gets which color.
COLOR_LIST = [
    'r',
    'seagreen',
    'deepskyblue',
    'orange',
    'gold',
    'greenyellow',
    'royalblue',
    'indigo',
    'magenta',
    'silver',
    'gray',
    'k',
]
def mse(real, pred):
    """Return the mean squared error between ``real`` and ``pred``.

    The squared differences are summed over every element and divided by
    ``len(pred)``; the result is rounded to four decimal places.

    Args:
        real: Reference values. Anything ``np.asarray`` can convert to a
            float array (list, tuple, ndarray, ...).
        pred: Predicted values, with the same length as ``real``.

    Returns:
        The error as a built-in ``float`` rounded to 4 decimals.

    Raises:
        ValueError: If ``real`` and ``pred`` have different lengths.
    """
    # An explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, after which mismatched inputs could broadcast silently.
    if len(real) != len(pred):
        raise ValueError(
            'mse() needs inputs of equal length, got %d and %d'
            % (len(real), len(pred))
        )

    # Converting up front lets plain Python sequences be subtracted too and
    # keeps integer inputs from overflowing when squared.
    real_arr = np.asarray(real, dtype=float)
    pred_arr = np.asarray(pred, dtype=float)

    squared_error = np.sum(np.square(real_arr - pred_arr))
    return round(float(squared_error / len(pred)), 4)
def _resolve_data_path(relative_path):
    """Return an existing location for ``relative_path`` or raise FileNotFoundError."""
    # First try the path as given (i.e. relative to the current working
    # directory). The fallback is the parent of this script's directory, which
    # is the directory this module also appends to ``sys.path``.
    # NOTE(review): the original fallback was ``os.getcwd() + path`` with no
    # separator, which can never match; joined correctly it would equal the
    # first candidate, so the script-relative root is used instead.
    project_root = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
    candidates = (relative_path, os.path.join(project_root, relative_path))
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(
        'Could not find %r (looked in: %s)' % (relative_path, ', '.join(candidates))
    )


def _load_pickle(pkl_path):
    """Unpickle and return the object stored at ``pkl_path``."""
    # NOTE(review): unpickling can execute arbitrary code; only load files
    # produced by a trusted run.
    with open(pkl_path, 'rb') as f:
        return pickle.load(f)


def _find_result_pkl(checkpoint_dir, file_suffix):
    """Return the first file in ``checkpoint_dir`` whose name ends with ``file_suffix``."""
    # Sorted so the choice is deterministic if several files match.
    matches = sorted(glob.glob(os.path.join(checkpoint_dir, '*' + file_suffix)))
    if not matches:
        raise FileNotFoundError(
            'No *%s file found in checkpoint directory %r' % (file_suffix, checkpoint_dir)
        )
    return matches[0]


def _plot_curves(curves):
    """Draw every ``label -> values`` entry of ``curves`` on the current axes."""
    for i, (label, values) in enumerate(curves.items()):
        # Cycle through the palette so more curves than colors does not crash.
        color = COLOR_LIST[i % len(COLOR_LIST)]
        plt.plot(values, linewidth=2, label=label, color=color)
    plt.legend(fontsize=15)


def _metric_lines(metrics):
    """Format ``metrics`` as one ``'<name> : <value>'`` line per entry."""
    return [key + ' : ' + str(val) + '\n' for key, val in metrics.items()]


def recorder(args):
    """Compare saved model results against the reference curves.

    For every directory in ``args.checkpoints`` the ``*_cos_mean_fix.pkl`` and
    ``*_cos_mean_random.pkl`` results are loaded, plotted next to the reference
    curves from ``data_samples`` and scored against them with ``mse``. The
    figure is saved to ``args.save_path``; the scores are written to a ``.txt``
    file with the same base name and echoed to stdout.

    Args:
        args: Namespace with ``checkpoints`` (iterable of checkpoint directory
            paths) and ``save_path`` (image path for the figure).

    Raises:
        FileNotFoundError: If the sample directory, a reference pickle, or a
            checkpoint's result pickle cannot be found.
    """
    ##########################################################
    #                      configuration                     #
    ##########################################################
    # (0) : Test set
    # The sample directory is only checked for existence here; it is not read.
    _resolve_data_path('data_samples/samples')
    real_mean_random = _load_pickle(_resolve_data_path('data_samples/random_reference.pkl'))
    real_mean_fix = _load_pickle(_resolve_data_path('data_samples/fix_reference.pkl'))

    suffix = '$\\theta$'
    fix_plots = {'Real Fix ' + suffix: real_mean_fix}
    random_plots = {'Real Random ' + suffix: real_mean_random}
    fix_metric = {}
    random_metric = {}

    # (1) : Bring Model results
    for checkpoint in args.checkpoints:
        one_fix_pkl = _find_result_pkl(checkpoint, '_cos_mean_fix.pkl')
        # Bug fix: the random results were previously globbed with the *fix*
        # pattern, so both plots/metrics showed the same data.
        one_random_pkl = _find_result_pkl(checkpoint, '_cos_mean_random.pkl')

        # The experiment name is whatever precedes the fixed file-name suffix.
        exp_name = os.path.basename(one_fix_pkl).replace('_cos_mean_fix.pkl', '')
        label = exp_name + ' ' + suffix

        cos_mean_fix = _load_pickle(one_fix_pkl)
        cos_mean_random = _load_pickle(one_random_pkl)

        fix_plots[label] = cos_mean_fix
        random_plots[label] = cos_mean_random
        fix_metric[label] = mse(real_mean_fix, cos_mean_fix)
        random_metric[label] = mse(real_mean_random, cos_mean_random)

    # (2) : Plot Model Results (fix on the left, random on the right)
    plt.figure(figsize=(24, 7))
    plt.subplot(1, 2, 1)
    _plot_curves(fix_plots)
    plt.subplot(1, 2, 2)
    _plot_curves(random_plots)
    plt.savefig(args.save_path)

    # (3) : Save Metric as Table
    # Swap the extension whatever it is: replacing only '.png' left a '.jpg'
    # path unchanged, so the text report overwrote the image just saved.
    txt_path = os.path.splitext(args.save_path)[0] + '.txt'
    report = ''.join(
        ['-----FIX METRIC-----', '\n']
        + _metric_lines(fix_metric)
        + ['\n', '-----RANDOM METRIC -----', '\n']
        + _metric_lines(random_metric)
    )
    with open(txt_path, 'w') as f:
        f.write(report)
    print(report)
if __name__ == "__main__":
    # args.checkpoints = [dir1, dir2, dir3, ...]
    # The folder structure should look like:
    #
    #   dir1
    #     - {exp_name_1}_cos_mean_fix.pkl
    #     - {exp_name_1}_cos_mean_random.pkl
    #   dir2
    #     - {exp_name_2}_cos_mean_fix.pkl
    #     - {exp_name_2}_cos_mean_random.pkl
    #   ...
    #
    # These *_cos_mean_*.pkl files can be obtained w/ test.py, after training
    # a model w/ train.py.
    parser = argparse.ArgumentParser()
    # Both options are required: without them recorder() would fail later on
    # a None value instead of argparse printing a usage error.
    parser.add_argument('--checkpoints', nargs="+", required=True,
                        help='checkpoint directories holding the *_cos_mean_*.pkl results')
    parser.add_argument('--save_path', type=str, required=True,
                        help='.png or .jpg at last')
    args = parser.parse_args()
    recorder(args)