forked from Syyabb/PUD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train_sig.py
147 lines (119 loc) · 4.28 KB
/
train_sig.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
from torch.utils.tensorboard import SummaryWriter
import config
import torch
from classifier_models import PreActResNet18, ResNet18
from networks.models import NetC_MNIST
from utils.dataloader import PostTensorTransform, get_dataloader
from utils.utils import progress_bar
from create_bd import patch, sig
from aft_train import eval, train_eval
from model import SequentialImageNetwork
import torch.nn as nn
from torchvision import models
def get_model(opt):
    """Build the classifier together with its SGD optimizer and LR scheduler.

    The architecture is selected from ``opt.dataset``:

    * ``cifar10`` / ``gtsrb`` -> ``PreActResNet18(num_classes=opt.num_classes)``
    * ``celeba``              -> ``ResNet18()``
    * ``mnist``               -> ``NetC_MNIST()``
    * ``imagenet``            -> torchvision ``resnet18`` with pretrained
      weights, whose final ``fc`` layer is replaced by a fresh ``nn.Linear``
      producing ``opt.num_classes`` logits.

    Args:
        opt: Parsed options. Reads ``dataset``, ``num_classes``, ``device``,
            ``lr_C``, ``schedulerC_milestones`` and ``schedulerC_lambda``.

    Returns:
        Tuple ``(netC, optimizerC, schedulerC)``: the model (moved to
        ``opt.device``), an SGD optimizer (momentum 0.9, weight decay 5e-4)
        over its parameters, and a ``MultiStepLR`` scheduler on that optimizer.

    Raises:
        ValueError: If ``opt.dataset`` has no model mapping.
    """
    if opt.dataset in ("cifar10", "gtsrb"):
        netC = PreActResNet18(num_classes=opt.num_classes)
    elif opt.dataset == "celeba":
        netC = ResNet18()
    elif opt.dataset == "mnist":
        # main() accepts mnist; without this branch netC stayed None and the
        # optimizer construction below crashed on None.parameters().
        netC = NetC_MNIST()
    elif opt.dataset == "imagenet":
        # `pretrained` is a boolean flag; the old string value only worked
        # because a non-empty string is truthy.
        netC = models.resnet18(num_classes=1000, pretrained=True)
        # Swap the 1000-way pretrained head for one sized to this task.
        netC.fc = nn.Linear(netC.fc.in_features, opt.num_classes)
    else:
        raise ValueError(f"No model defined for dataset {opt.dataset!r}")

    # Every architecture goes to the configured device (the imagenet branch
    # previously used .cuda(), which ignores opt.device).
    netC = netC.to(opt.device)

    # Optimizer
    optimizerC = torch.optim.SGD(netC.parameters(), opt.lr_C, momentum=0.9, weight_decay=5e-4)
    # Scheduler
    schedulerC = torch.optim.lr_scheduler.MultiStepLR(
        optimizerC, opt.schedulerC_milestones, opt.schedulerC_lambda
    )
    return netC, optimizerC, schedulerC
def train(netC, optimizerC, schedulerC, train_dl, opt, tf_writer):
    """Run one epoch of training with SIG-poisoned batches.

    For every batch, the inputs are moved to ``opt.device`` and augmented with
    ``PostTensorTransform``. The first ``int(opt.pc * batch_size)`` samples
    (and their targets) are then passed through ``sig``, and the returned
    poisoned samples are concatenated with the untouched remainder of the
    batch. The model is updated with cross-entropy on that mixed batch. The
    scheduler is stepped once, after the last batch.

    Args:
        netC: Classifier being trained; switched to train mode here.
        optimizerC: Optimizer over ``netC``'s parameters.
        schedulerC: LR scheduler, stepped once per call.
        train_dl: Iterable of ``(inputs, targets)`` batches.
        opt: Options. Reads ``device`` and ``pc`` (fraction of each batch to
            poison) here, and is forwarded to ``PostTensorTransform`` and ``sig``.
        tf_writer: Accepted for interface compatibility; not used.
    """
    print(" Train:")
    netC.train()
    criterion_CE = torch.nn.CrossEntropyLoss()
    transforms = PostTensorTransform(opt).to(opt.device)

    for inputs, targets in train_dl:
        optimizerC.zero_grad()
        inputs, targets = inputs.to(opt.device), targets.to(opt.device)
        inputs = transforms(inputs)

        # Building the poisoned batch needs no autograd history; only the
        # forward pass below must be tracked.
        with torch.no_grad():
            num_bd = int(opt.pc * inputs.shape[0])
            inputs_bd, targets_bd = sig(inputs[:num_bd], targets[:num_bd], opt)
            total_inputs = torch.cat((inputs_bd, inputs[num_bd:]), 0)
            total_targets = torch.cat((targets_bd, targets[num_bd:]), 0)

        total_preds = netC(total_inputs)
        loss_ce = criterion_CE(total_preds, total_targets)
        loss_ce.backward()
        optimizerC.step()

    schedulerC.step()
def _configure_dataset(opt):
    """Set ``num_classes`` and input geometry on *opt* for ``opt.dataset``.

    Raises:
        ValueError: If ``opt.dataset`` is not mnist, cifar10, gtsrb or imagenet.
    """
    # dataset -> (num_classes, input_height, input_width, input_channel)
    specs = {
        "mnist": (10, 28, 28, 1),
        "cifar10": (10, 32, 32, 3),
        "gtsrb": (43, 32, 32, 3),
        "imagenet": (10, 224, 224, 3),
    }
    try:
        num_classes, height, width, channels = specs[opt.dataset]
    except KeyError:
        raise ValueError("Invalid Dataset") from None
    opt.num_classes = num_classes
    opt.input_height = height
    opt.input_width = width
    opt.input_channel = channels


def main():
    """Train a classifier with the SIG backdoor and checkpoint its weights.

    Parses options from ``config.get_arguments()`` and fills in the dataset
    geometry. It then trains for 60 epochs and, every 5 epochs:

    * evaluates on the test and train loaders, logging to ``./log/sig_train``;
    * saves the state dict to
      ``./pt/<attack_mode>_sig_<dataset><target_label>N.pt``.
    """
    opt = config.get_arguments().parse_args()
    _configure_dataset(opt)

    # Dataset
    train_dl = get_dataloader(opt, True)
    test_dl = get_dataloader(opt, False)

    # Prepare model. The target label is fixed to 0 here, overwriting any
    # value already present on the parsed options.
    opt.target_label = 0
    netC, optimizerC, schedulerC = get_model(opt)

    pt_name = f"{opt.attack_mode}_sig_{opt.dataset}{opt.target_label}N.pt"
    pt_path = os.path.join("./pt", pt_name)
    # torch.save does not create missing directories; create it up front so a
    # long run does not crash at its first checkpoint.
    os.makedirs(os.path.dirname(pt_path), exist_ok=True)

    # The context manager closes the writer so buffered events are flushed.
    with SummaryWriter(os.path.join("./log", "sig_train")) as writer:
        for epoch in range(60):
            print("Epoch {}:".format(epoch + 1))
            train(netC, optimizerC, schedulerC, train_dl, opt, writer)
            if (epoch + 1) % 5 == 0:
                eval(netC, test_dl, opt, writer, epoch)
                eval(netC, train_dl, opt, writer, epoch)
                # 60 is a multiple of 5, so the final epoch is always saved.
                torch.save(netC.state_dict(), pt_path)
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not on import.
    main()