-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
126 lines (91 loc) · 3.52 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import json
import os, argparse
__all__ = ["ConfLoader", "directory_setter", "config_overwriter", "str2bool"]
class ConfLoader:
    """
    Load a JSON config file into nested dicts that also allow attribute access.

    ``ConfLoader(conf_name).opt`` holds the parsed content of the file; every
    JSON object in it is a ``DictWithAttributeAccess``, so nested values can
    be read and written as ``opt.section.key`` as well as
    ``opt["section"]["key"]``.

    Args:
        conf_name: Path of the JSON config file. It is opened in ``__init__``.
    """

    class _MissingKeyError(AttributeError, KeyError):
        """
        Raised when attribute-style access names a key that is not present.

        It is an AttributeError so that ``hasattr``, ``getattr`` with a default
        and ``copy.deepcopy`` treat the miss as a normal absent attribute, and
        also a KeyError so that callers catching KeyError keep working.
        """

    class DictWithAttributeAccess(dict):
        """
        A dict whose items can be accessed the same way as attributes.
        For example, you can use opt.key instead of opt['key'].
        """

        def __getattr__(self, key):
            # Only reached when normal attribute lookup fails, so dict methods
            # such as .items() are never shadowed by this fallback.
            try:
                return self[key]
            except KeyError:
                raise ConfLoader._MissingKeyError(key) from None

        def __setattr__(self, key, value):
            # Attribute assignment is stored as a dict item, not in __dict__.
            self[key] = value

    def __init__(self, conf_name):
        self.conf_name = conf_name
        self.opt = self.__get_opt()

    def __load_conf(self):
        """Parse the JSON file, wrapping every JSON object on the way."""
        # JSON text is UTF-8 by default; do not depend on the locale encoding.
        with open(self.conf_name, "r", encoding="utf-8") as conf:
            return json.load(conf, object_hook=self.DictWithAttributeAccess)

    def __get_opt(self):
        """Return the loaded config wrapped as a DictWithAttributeAccess."""
        return self.DictWithAttributeAccess(self.__load_conf())
def directory_setter(path="./results", make_dir=False):
    """
    Ensure that ``path`` is a directory, optionally creating it.

    When ``make_dir`` is true and nothing exists at ``path``, the directory
    (including missing parents) is created and a message is printed.

    Raises:
        NotADirectoryError: if, after the optional creation step, ``path`` is
            still not a directory (it is missing, or it is not a directory).
    """
    if make_dir and not os.path.exists(path):
        os.makedirs(path)
        print("directory %s is created" % path)

    if os.path.isdir(path):
        return

    raise NotADirectoryError(
        "%s is not valid. set make_dir=True to make dir." % path
    )
def str2bool(v):
    """
    Convert a truthy/falsy token to ``bool``.

    ``bool`` values are returned unchanged. Any other value is converted with
    ``str()``, stripped of surrounding whitespace and compared
    case-insensitively against the accepted spellings:
    yes/true/t/y/1 give True, no/false/f/n/0 give False.

    Raises:
        argparse.ArgumentTypeError: if the value matches neither set.
    """
    if isinstance(v, bool):
        return v

    # str() keeps non-string input (e.g. the int 1, or None) from failing with
    # AttributeError on .lower(); it is parsed or rejected like any other token.
    token = str(v).strip().lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def config_overwriter(opt, args):
    """
    Apply parsed command-line overrides on top of a loaded configuration.

    Each option of ``args`` whose value is not ``None`` replaces the matching
    nested entry of ``opt``; options left as ``None`` keep the loaded value.
    ``opt`` is modified in place and is also returned.
    """
    # (option name on args, attribute path inside opt), in application order.
    overrides = (
        ("dataset_name", ("data_setups", "dataset_name")),
        ("n_clients", ("data_setups", "n_clients")),
        ("model_name", ("train_setups", "model", "name")),
        ("partition_method", ("data_setups", "partition", "method")),
        ("partition_s", ("data_setups", "partition", "shard_per_client")),
        ("partition_alpha", ("data_setups", "partition", "alpha")),
        ("n_rounds", ("train_setups", "scenario", "n_rounds")),
        ("sample_ratio", ("train_setups", "scenario", "sample_ratio")),
        ("local_epochs", ("train_setups", "scenario", "local_epochs")),
        ("device", ("train_setups", "scenario", "device")),
        ("lr", ("train_setups", "optimizer", "params", "lr")),
        ("rho", ("train_setups", "algo", "params", "rho")),
        ("perturb_body", ("train_setups", "algo", "params", "perturb_body")),
        ("perturb_head", ("train_setups", "algo", "params", "perturb_head")),
        ("algo_name", ("train_setups", "algo", "name")),
        ("seed", ("train_setups", "seed")),
        ("group", ("wandb_setups", "group")),
        ("exp_name", ("wandb_setups", "name")),
    )

    for option, (*parents, leaf) in overrides:
        value = getattr(args, option)
        if value is None:
            continue
        # Walk to the section that owns the entry only when it must be written.
        section = opt
        for step in parents:
            section = getattr(section, step)
        setattr(section, leaf, value)

    return opt