-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer_cifar10_gep_public_no_gp.py
120 lines (97 loc) · 5.9 KB
/
trainer_cifar10_gep_public_no_gp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
from pathlib import Path
import torch
import wandb
from dataset import gen_random_loaders
from trainer_gep_public_no_gp import train
from utils import set_logger, set_seed, str2bool
def get_dataloaders(args):
    """Build the train, validation and test loaders for this run.

    Forwards the dataset name, dataset path, number of clients, batch size
    and classes-per-client settings from ``args`` (positionally, in that
    order) to ``gen_random_loaders``.

    Args:
        args: Parsed command-line namespace providing ``data_name``,
            ``data_path``, ``num_clients``, ``batch_size`` and
            ``classes_per_client``.

    Returns:
        A ``(train_loaders, val_loaders, test_loaders)`` tuple, in the order
        produced by ``gen_random_loaders``.
    """
    loader_settings = (
        args.data_name,
        args.data_path,
        args.num_clients,
        args.batch_size,
        args.classes_per_client,
    )
    # Unpack explicitly so a result that is not exactly three items fails here.
    train_split, val_split, test_split = gen_random_loaders(*loader_settings)
    return train_split, val_split, test_split
if __name__ == '__main__':
    # Script entry point: declare the command-line interface, validate the
    # requested GPU id, set up logging / seeding / optional Weights & Biases
    # tracking, then hand the parsed arguments and data loaders to `train`.
    parser = argparse.ArgumentParser(description="GEP Public CIFAR10/100 Federated Learning")

    # Default dataset; also used to derive the default CSV file name below.
    data_name = 'cifar10'

    ##################################
    #       Network args        #
    ##################################
    parser.add_argument("--num-blocks", type=int, default=3)
    parser.add_argument("--block-size", type=int, default=3)
    parser.add_argument("--model-name", type=str, choices=['CNNTarget', 'ResNet'], default='ResNet')
    parser.add_argument("--n-kernels", type=int, default=16, help="number of kernels")
    parser.add_argument('--embed-dim', type=int, default=64)
    parser.add_argument('--use-gp', type=str2bool, default=False)

    ##################################
    #       Optimization args        #
    ##################################
    parser.add_argument("--num-steps", type=int, default=100)
    parser.add_argument("--optimizer", type=str, default='adam',
                        choices=['adam', 'sgd'], help="optimizer type")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--inner-steps", type=int, default=1, help="number of inner steps")
    # NOTE(review): "--num-client-agg" used to be registered here (default=10)
    # AND again under "Clients Args" (default=100). argparse rejects a repeated
    # option string with ArgumentError, so the script could not start. Only the
    # "Clients Args" registration is kept; confirm that 100 is the intended
    # default.
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument("--global_lr", type=float, default=0.9, help="server learning rate")
    parser.add_argument("--wd", type=float, default=1e-4, help="weight decay")
    parser.add_argument("--clip", type=float, default=0.1, help="gradient clip")
    parser.add_argument("--clip_residual", type=float, default=0.1, help="residual clip")
    parser.add_argument("--noise-multiplier", type=float, default=1.0, help="gradient dp noise factor"
                                                                            " to be multiplied by clip")
    parser.add_argument("--noise-multiplier-residual", type=float, default=1.0, help="residual part "
                                                                                     "dp noise factor"
                                                                                     " to be multiplied by clip")

    ##################################
    #       GEP args        #
    ##################################
    parser.add_argument("--gradients-history-size", type=int,
                        default=500, help="amount of past gradients participating in embedding subspace computation")
    parser.add_argument("--basis-size", type=int, default=50, help="number of basis vectors")

    #############################
    #       General args        #
    #############################
    parser.add_argument("--num-workers", type=int, default=0, help="number of workers")
    parser.add_argument("--gpus", type=str, default='0', help="gpu device ID")
    parser.add_argument("--exp-name", type=str, default='', help="suffix for exp name")
    parser.add_argument("--save-path", type=str, default=(Path.home() / 'saved_models').as_posix(),
                        help="dir path for saved models")
    parser.add_argument("--seed", type=int, default=42, help="seed value")
    parser.add_argument('--wandb', type=str2bool, default=False)
    parser.add_argument("--gpu", type=int, default=0, help="gpu device ID")
    parser.add_argument("--eval-every", type=int, default=5, help="eval every X selected epochs")
    parser.add_argument("--eval-after", type=int, default=25, help="eval only after X selected epochs")
    parser.add_argument("--log-every", type=int, default=1, help="log every X selected epochs")
    parser.add_argument("--log-dir", type=str, default="./log", help="dir path for logger file")
    # Help texts for the two "*-name" options previously repeated the
    # directory wording of the matching "*-dir" / "*-path" options.
    parser.add_argument("--log-name", type=str, default="gep_private", help="name of the logger file")
    parser.add_argument("--csv-path", type=str, default="./csv", help="dir path for csv file")
    parser.add_argument("--csv-name", type=str, default=f"{data_name}_sgd_dp.csv", help="name of the csv file")

    #############################
    #       Dataset Args        #
    #############################
    parser.add_argument(
        "--data-name", type=str, default=data_name,
        choices=['cifar10', 'cifar100', 'putEMG', 'mnist'], help="dataset"
    )
    parser.add_argument("--data-path", type=str, default="data", help="dir path for dataset")

    #############################
    #       Clients Args        #
    #############################
    parser.add_argument("--num-clients", type=int, default=500, help="total number of clients")
    parser.add_argument("--num-private-clients", type=int, default=490, help="number of private clients")
    parser.add_argument("--num-public-clients", type=int, default=10, help="number of public clients")
    parser.add_argument("--classes-per-client", type=int, default=2, help="number of classes per client")
    parser.add_argument("--num-client-agg", type=int, default=100, help="number of clients per step")

    args = parser.parse_args()

    # Validate the GPU id. The previous `assert args.gpu <= device_count()`
    # accepted an id equal to the device count (one past the last valid id)
    # and would be stripped entirely under `python -O`.
    # Index 0 is still accepted when no GPU is visible, as it was before.
    # NOTE(review): negative ids are now rejected as well — confirm nothing
    # relies on passing one.
    max_gpu_id = max(torch.cuda.device_count(), 1) - 1
    if not 0 <= args.gpu <= max_gpu_id:
        parser.error(f"--gpu flag should be in range [0,{max_gpu_id}]")

    logger = set_logger(args)
    logger.info(f"Args: {args}")

    set_seed(args.seed)

    exp_name = f'GEP_PUBLIC_{args.data_name}_lr_{args.lr}_clip_{args.clip}_noise_{args.noise_multiplier}'

    # Weights & Biases
    if args.wandb:
        wandb.init(project="key_press_emg_toronto", name=exp_name)
        wandb.config.update(args)

    train(args, get_dataloaders(args))