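"""Training script for the batch autoencoder defined in models.autoencoder.

Loads MIDI-derived data via midi_converter.importer.get_data, trains with a
BCE reconstruction loss and RMSprop, and saves checkpoints plus a loss history.
"""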
from datetime import datetime

import torch
from torch import nn

from midi_converter.importer import get_data
from logger import log_info, log_ok
from models.autoencoder import Autoencoder, build_model

# training schedule
NUM_EPOCHS = 2000
SAVE_MODEL_EACH = 500  # checkpoint every N epochs
SHOW_EXACT_LOSS_EACH = 9999999  # large on purpose: the per-batch loss dump is effectively disabled

# data
DATASET = 'ninsheetmusic_trans'
DATA_LENGTH = 4016  # 4016 = 251 * 2^4 is the max
BATCH_SIZE = 251

# model
MODEL = 'autoencoder_batch'
CUDA = True

# optimizer
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5

# the dataset must split into whole batches for torch.chunk below
assert DATA_LENGTH % BATCH_SIZE == 0

def train():
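    """Load the dataset, build the autoencoder, and run the training loop,
    checkpointing the model every SAVE_MODEL_EACH epochs."""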
    log_info('Loading data...')
    dataset = get_data(DATASET)[:DATA_LENGTH]
    dataset_length = dataset.size()[0]
    log_info('Data loaded! Dataset of size {}'.format(dataset_length))

    log_info('Building model...')
    model: Autoencoder = build_model(CUDA)
    if CUDA:
        model = model.cuda()
    criterion = nn.BCELoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    log_info('Model built! Starting training...')

    start_time = datetime.now()
    loss_global_history = []
    # train loop
    for epoch in range(NUM_EPOCHS):
        loss_local_history = []
        # prepare batches: shuffle, then split into DATA_LENGTH // BATCH_SIZE equal chunks
        perm = torch.randperm(DATA_LENGTH)
        shuffled_data = dataset[perm]
        batched_data = torch.chunk(shuffled_data, DATA_LENGTH // BATCH_SIZE)
        for i, data in enumerate(batched_data):
            if CUDA:
                data = data.cuda()
            # forward: reconstruct the batch and measure the reconstruction error
            output = model(data)
            loss = criterion(output, data)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_local_history.append(loss.item())
            # logging
            # if not CUDA:
            #     if i in [10, 25, 50] or i % 100 == 0:
            #         log_ok('Processed {}/{} songs.'.format(i + 1, dataset_length))
            # else:
            #     if i == 100 or i % 1000 == 0:
            #         log_ok('Processed {}/{} songs.'.format(i + 1, dataset_length))
        # after each epoch
        mean_loss = sum(loss_local_history) / len(loss_local_history)
        loss_global_history.append(mean_loss)
        if epoch != 0 and epoch % SAVE_MODEL_EACH == 0:
            torch.save(model.state_dict(),
                       'export/models/{}/{}_d{}_e{}_l{:.16f}.pt'.format(MODEL, start_time, dataset_length,
                                                                        epoch, mean_loss))
        if epoch != 0 and epoch % SHOW_EXACT_LOSS_EACH == 0:
            log_info('loss: {}'.format(loss_local_history))
        # ===================log========================
        log_info('Epoch [{}/{}], loss: {:.16f}. Time elapsed: {}'
                 .format(epoch + 1, NUM_EPOCHS, mean_loss, datetime.now() - start_time))

    log_info('Finished training in {}'.format(datetime.now() - start_time))
    torch.save(model.state_dict(),
               'export/models/{}/{}_fin_d{}_e{}_l{:.16f}.pt'.format(MODEL, start_time, dataset_length,
                                                                    NUM_EPOCHS, mean_loss))
    log_ok('Model saved.')
    torch.save(loss_global_history, 'loss.pt')


if __name__ == '__main__':
    if torch.cuda.is_available():
        log_info('{} CUDA device(s) available.'.format(torch.cuda.device_count()))
    else:
        log_info('No CUDA available.')
    if CUDA:
        torch.cuda.empty_cache()
    train()
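
# Usage: python train.py
# Note: checkpoints are written to export/models/<MODEL>/, which is assumed to
# already exist (torch.save does not create missing directories).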