-
Notifications
You must be signed in to change notification settings - Fork 313
/
utils.py
94 lines (78 loc) · 3.39 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implements a class for loading the model configuration.
"""
import datetime
import os
import yaml
import random
import shutil
class Config(object):
    """Load a YAML model configuration and optionally prepare output dirs.

    Args:
        config_file(str): path of the configuration file, which is a yaml file.
        isCreate(bool): if true, rewrite the trainer's ``log_dir``,
            ``save_dir`` and ``output_dir`` to per-task subdirectories
            (``<dir>/<task_name>/``) and create those directories.
        isSave(bool): if true, copy ``config_file`` into the trainer's
            ``save_dir`` so the configuration used for the run is recorded.
    """

    # Keys under config['trainer']['args'] that create_necessary_dirs()
    # rewrites to per-task subdirectories and creates on disk.
    _TASK_DIR_KEYS = ('log_dir', 'save_dir', 'output_dir')

    def __init__(self, config_file, isCreate=False, isSave=False):
        self.config_file = config_file
        self.config = self.get_config_from_yaml(config_file)
        if isCreate:
            self.create_necessary_dirs()
        if isSave:
            self.save_config_file()

    def get_config_from_yaml(self, yaml_file):
        """Parse ``yaml_file`` and return the resulting object.

        ``yaml.safe_load`` is used because a bare ``yaml.load(f)`` (no Loader)
        can construct arbitrary Python objects on older PyYAML versions and
        raises ``TypeError`` on PyYAML >= 6.

        Raises:
            IOError: if the file cannot be read or is not valid YAML.
        """
        try:
            # Binary mode lets PyYAML detect the encoding (UTF-8 / UTF-16 BOM)
            # instead of depending on the platform's locale encoding.
            with open(yaml_file, 'rb') as f:
                config = yaml.safe_load(f)
        except (EnvironmentError, yaml.YAMLError) as err:
            raise IOError("Error in parsing config file '%s': %s"
                          % (yaml_file, err))
        return config

    def create_necessary_dirs(self):
        """Redirect the trainer's output directories to per-task subdirs.

        Each of ``log_dir``, ``save_dir`` and ``output_dir`` in
        ``config['trainer']['args']`` becomes ``<dir>/<task_name>/`` (with a
        trailing separator), and every resulting directory is created.
        """
        args = self.config['trainer']['args']
        task_name = self.config['task_name']
        # os.path.join inserts the separator that plain concatenation dropped
        # when a configured base dir had no trailing '/'; the final '' keeps a
        # trailing separator on the result, as the original paths had.
        for key in self._TASK_DIR_KEYS:
            args[key] = os.path.join(args[key], task_name, '')
        for key in self._TASK_DIR_KEYS:
            self.make_dir(args[key])

    def save_config_file(self):
        """Copy the config file into ``save_dir`` to record the run's settings.

        ``save_dir`` is created first if it does not exist yet (for example
        when the instance was built with ``isSave=True`` but
        ``isCreate=False``).
        """
        filename = os.path.basename(self.config_file)
        target_dir = self.config['trainer']['args']['save_dir']
        if target_dir:
            self.make_dir(target_dir)
        shutil.copyfile(self.config_file, os.path.join(target_dir, filename))

    def make_dir(self, path):
        """Create directory ``path`` (and missing parents) if it is absent."""
        # Try the creation and tolerate an already-existing directory; this
        # avoids the race between an exists() check and makedirs().
        try:
            os.makedirs(path)
        except OSError:
            if not os.path.isdir(path):
                raise

    def __getitem__(self, key):
        """Return ``self.config[key]``."""
        return self.config[key]

    def __call__(self):
        """Return the whole parsed configuration object."""
        return self.config