-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathfile_utils.py
201 lines (183 loc) · 7.64 KB
/
file_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import importlib
import importlib.util
import json as js
import os
import re
import time
import types
from copy import deepcopy

import numpy as np
import torch
class Logger(object):
    """Experiment logger that writes timestamped, tagged entries to a file or stdout.

    Entries are wrapped in XML-like tags (e.g. <params>...</params>) containing
    JSON, so the matching read_* methods can parse them back out of the raw log
    text returned by load_file().
    """
    def __init__(self, filename=None, overwrite=True):
        """
        Inputs:
            filename: [str or None] log file location; if None, log to stdout
            overwrite: [bool] if True truncate an existing log file, otherwise
                open it read/write and seek to the beginning
        """
        self.filename = filename
        if filename is None:
            self.log_to_file = False
        else:
            self.log_to_file = True
            # buffering=1 requests line buffering so entries reach disk promptly
            if overwrite:
                self.file_obj = open(filename, 'w', buffering=1)
            else:
                self.file_obj = open(filename, 'r+', buffering=1)
                self.file_obj.seek(0)

    def js_dumpstring(self, obj):
        """Dump json string with special CustomEncoder"""
        return js.dumps(obj, sort_keys=True, indent=2, cls=CustomEncoder)

    def log_trainable_variables(self, name_list):
        """
        Use logging to write names of trainable variables in model
        Inputs:
            name_list: list containing variable names
        """
        js_str = self.js_dumpstring(name_list)
        self.log_info('<train_vars>'+js_str+'</train_vars>')

    def log_params(self, params):
        """
        Use logging to write model params
        Inputs:
            params: [dict] containing parameter values; an 'ensemble_params'
                entry is flattened into '<index>_<key>' entries before dumping
        """
        out_params = deepcopy(params)
        if 'ensemble_params' in out_params:
            # flatten each sub-model's params, prefixing keys with its index
            for sub_idx, sub_params in enumerate(out_params['ensemble_params']):
                sub_params.set_params()
                for key, value in sub_params.__dict__.items():
                    if key != 'rand_state':  # rand_state is not JSON serializable
                        new_dict_key = f'{sub_idx}_{key}'
                        out_params[new_dict_key] = value
            del out_params['ensemble_params']
        if 'rand_state' in out_params:
            del out_params['rand_state']
        js_str = self.js_dumpstring(out_params)
        self.log_info('<params>'+js_str+'</params>')

    def log_info(self, string):
        """Log input string, prefixed with a timestamp"""
        now = time.localtime(time.time())
        time_str = time.strftime('%m/%d/%y %H:%M:%S', now)
        out_str = '\n' + time_str + ' -- ' + str(string)
        if self.log_to_file:
            self.file_obj.write(out_str)
        else:
            print(out_str)

    def load_file(self, filename=None):
        """
        Load log file into memory
        Outputs:
            text: [str] containing log file text
        Inputs:
            filename: [str or None] if given, read that file instead of the
                current log file (the current handle is closed to avoid a leak)
        """
        if filename is None:
            self.file_obj.seek(0)
        else:
            # close any handle we already own before replacing it (fixes a
            # file-descriptor leak in the original)
            if hasattr(self, 'file_obj') and not self.file_obj.closed:
                self.file_obj.close()
            self.file_obj = open(filename, 'r', buffering=1)
        text = self.file_obj.read()
        return text

    def read_js(self, tokens, text):
        """
        Read js string encased by tokens and convert to python object
        Outputs:
            js_matches: [list] of converted python objects, one per match;
                always a list, and empty when no matches are found
        Inputs:
            tokens: [list] of length 2 with [0] entry indicating start token and [1]
                entry indicating end token
            text: [str] containing text to parse, can be obtained by calling load_file()
        """
        assert type(tokens) == list, ('Input variable tokens must be a list')
        assert len(tokens) == 2, ('Input variable tokens must be a list of length 2')
        matches = re.findall(re.escape(tokens[0])+r'([\s\S]*?)'+re.escape(tokens[1]), text)
        # Always return a list (resolves the old TODO): the previous code
        # special-cased len(matches) == 1 and raised IndexError on zero matches.
        return [js.loads(match) for match in matches]

    def read_params(self, text):
        """
        Read params from text file and return as a list of params objects
        Outputs:
            param_list: [list] of param objects, one per <params> entry
        Inputs:
            text: [str] containing text to parse, can be obtained by calling load_file()
        """
        tokens = ['<params>', '</params>']
        params = self.read_js(tokens, text)
        param_list = []
        for param_dict in params:
            param_obj = type('param_obj', (), {})()
            # hoist the loop-invariant ensemble check out of the key loop
            is_ensemble = param_dict['model_type'] == 'ensemble'
            if is_ensemble:
                param_obj.ensemble_params = []
                ensemble_nums = set()
            for key, value in param_dict.items():
                if is_ensemble:
                    key_split = key.split('_')
                    if key_split[0].isdigit():  # ensemble params are prefaced with ensemble index
                        ens_num = int(key_split[0])
                        if ens_num not in ensemble_nums:
                            # indices appear in increasing order, so append position matches ens_num
                            ensemble_nums.add(ens_num)
                            param_obj.ensemble_params.append(types.SimpleNamespace())
                        setattr(param_obj.ensemble_params[ens_num], '_'.join(key_split[1:]), value)
                    else:  # if it is not a digit then it is a general param
                        setattr(param_obj, key, value)
                else:
                    setattr(param_obj, key, value)
            def optimizer_dict_to_obj(param_obj):
                # convert a logged optimizer dict back into attribute-style access
                if hasattr(param_obj, 'optimizer'):
                    optimizer_dict = deepcopy(param_obj.optimizer)
                    param_obj.optimizer = types.SimpleNamespace()
                    for key, value in optimizer_dict.items():
                        setattr(param_obj.optimizer, key, value)
            if param_obj.model_type == 'ensemble':  # each model in an ensemble could have its own optimizer
                for model_param_obj in param_obj.ensemble_params:
                    optimizer_dict_to_obj(model_param_obj)
            else:
                optimizer_dict_to_obj(param_obj)
            param_list.append(param_obj)
        return param_list

    def read_stats(self, text):
        """
        Generate dictionary of lists that contain stats from log text
        Outputs:
            stats: [dict] containing run statistics
        Inputs:
            text: [str] containing text to parse, can be obtained by calling load_file()
        """
        tokens = ['<stats>', '</stats>']
        js_matches = self.read_js(tokens, text)
        stats = {}
        for js_match in js_matches:
            if type(js_match) is str:
                js_match = {js_match:js_match}
            for key in js_match.keys():
                if key in stats:
                    stats[key].append(js_match[key])
                else:
                    stats[key] = [js_match[key]]
        return stats

    def __del__(self):
        # close whichever handle we own, whether opened for logging or by
        # load_file() (the original skipped the load_file case and leaked it)
        if hasattr(self, 'file_obj') and not self.file_obj.closed:
            self.file_obj.close()
class CustomEncoder(js.JSONEncoder):
    """JSON encoder for the extra types that appear in logged model params.

    Callables serialize to null; numpy scalars/arrays, torch device and dtype
    objects, and SimpleNamespace containers are converted to plain JSON types;
    anything else is deferred to the base JSONEncoder (which raises TypeError).
    """
    def default(self, obj):
        # guard-clause dispatch: first matching conversion wins
        if callable(obj):
            return None
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.device):
            return obj.type
        if isinstance(obj, torch.dtype):
            return str(obj)
        if isinstance(obj, types.SimpleNamespace):
            return obj.__dict__
        return super(CustomEncoder, self).default(obj)
def python_module_from_file(py_module_name, file_name):
    """Import a python module directly from an arbitrary file path.

    Inputs:
        py_module_name: [str] name to register the loaded module under
        file_name: [str] path to the .py file to execute
    Outputs:
        py_module: the executed module object
    Raises:
        FileNotFoundError: if file_name does not exist; raised explicitly
            instead of using assert so the check survives `python -O`
    """
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f'Error: {file_name} does not exist!')
    spec = importlib.util.spec_from_file_location(py_module_name, file_name)
    py_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(py_module)
    return py_module