-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain-exp.py
68 lines (58 loc) · 2.62 KB
/
train-exp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import json
import os
import torch
import torch.optim as optim
from sacred import Experiment
from sacred.observers import MongoObserver
from ingredients.attacks import attacks_ingredient, pgd_linf
from ingredients.data_loader import data_ingredient, load_cifar10
from ingredients.test import adv_test, test, test_ingredient
from ingredients.train import train, train_ingredient
from model.resnet import resnet20
# Use the first CUDA device when torch reports one, otherwise run on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Sacred experiment assembled from the data, attack, train and test
# ingredients imported above.
ex = Experiment(
    'bitdepth-train',
    ingredients=[
        data_ingredient,
        attacks_ingredient,
        train_ingredient,
        test_ingredient,
    ],
)
# Register config.json as a configuration source for the experiment.
ex.add_config('config.json')
# Attach a MongoDB observer created with its default arguments.
ex.observers.append(MongoObserver.create())
def _train_over_bitdepths(loader, bitdepths, momentum, weight_decay,
                          **train_kwargs):
    """Train a fresh resnet20 for every bit depth and collect the results.

    Args:
        loader: training data loader handed to ``train``.
        bitdepths: iterable of bit depths; one model is trained per value.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        **train_kwargs: extra keyword arguments forwarded unchanged to
            ``train`` (e.g. ``dither=True``).

    Returns:
        A ``(errors, losses)`` pair of lists, one entry per bit depth, holding
        the two values returned by ``train``.
    """
    errors, losses = [], []
    for bitdepth in bitdepths:
        # A new model and optimizer per bit depth so runs do not share state.
        model = resnet20().to(device)
        opt = optim.SGD(model.parameters(), lr=1e-1, momentum=momentum,
                        weight_decay=weight_decay)
        train_err, train_loss = train(model=model, loader=loader,
                                      bitdepth=bitdepth, opt=opt,
                                      model_name='resnet20', **train_kwargs)
        errors.append(train_err)
        losses.append(train_loss)
    return errors, losses


@ex.automain
def run(data_dir, trained_models_dir, momentum, weight_decay, results_dir):
    """Train resnet20 at several bit depths, with and without dithering.

    Ensures the three directories exist, trains one model per bit depth
    without an explicit ``dither`` argument (bit depths 1-8) and one per bit
    depth with ``dither=True`` (bit depths 1-7), then writes the collected
    errors and losses to ``<results_dir>/train-results.json``.

    Args:
        data_dir: directory created if missing (presumably where the dataset
            lives; confirm against the data ingredient).
        trained_models_dir: directory created if missing (presumably where
            ``train`` stores checkpoints; confirm against the train
            ingredient).
        momentum: momentum passed to ``optim.SGD``.
        weight_decay: weight decay passed to ``optim.SGD``.
        results_dir: directory that receives ``train-results.json``.
    """
    # os.makedirs (the original called the non-existent os.makedir) creates
    # missing parents too, and exist_ok=True makes it a no-op for directories
    # that are already there.
    for directory in (data_dir, trained_models_dir, results_dir):
        os.makedirs(directory, exist_ok=True)

    # Only the training split is used here; the test loader is discarded.
    train_loader, _ = load_cifar10()

    # Plain quantization for bit depths 1-8. No ``dither`` argument is passed,
    # so ``train`` uses whatever default it is configured with.
    train_errs, train_losses = _train_over_bitdepths(
        train_loader, range(1, 9), momentum, weight_decay)

    # NOTE(review): the dithered runs cover bit depths 1-7 only, unlike the
    # 1-8 range above. Kept exactly as in the original; confirm it is intended.
    train_dither_errs, train_dither_losses = _train_over_bitdepths(
        train_loader, range(1, 8), momentum, weight_decay, dither=True)

    # Save the results as JSON.
    results = {
        'train_errors': train_errs,
        'train_losses': train_losses,
        'train_dither_errors': train_dither_errs,
        'train_dither_losses': train_dither_losses,
    }
    path = os.path.join(results_dir, 'train-results.json')
    with open(path, 'w') as f:
        json.dump(results, f, indent=4)