#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Script for learning opitimal wavelet bases using a neural network approach.
"""
# Basic import(s)
import time
import json
import numpy as np

# Pytorch import(s)
import torch
from torch.autograd import Variable

# Project import(s)
from wavenet.loss import *
from wavenet.utils import *
from wavenet.modules import *
from wavenet.generators import *

# Type definitions
dtype = torch.FloatTensor

# General definitions
seed = 22
num_params = 16
batch_size = 1
input_shape = (64,)
generator_name = 'spikes'
gen_opts = dict(input_shape=input_shape)
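
# Each generator (imported from wavenet.generators) is assumed to yield an
# endless stream of numpy arrays of shape `input_shape`; the training loop
# below consumes one array per step.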


# Main function definition.
def main():

    # Generator
    if generator_name == 'sine':
        generator = generate_sine
    elif generator_name == 'spikes':
        generator = generate_spikes
    else:
        raise ValueError("Generator '{}' not supported.".format(generator_name))

    # Reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Learnable, universal filter coefficients
    params = np.random.randn(num_params)
    params = transform_params(params)

    # Pytorch variables
    params = Variable(torch.FloatTensor(params), requires_grad=True)
    indices = Variable(torch.arange(num_params).type(dtype) / float(num_params - 1), requires_grad=False)
    x = Variable(torch.randn(input_shape).type(dtype), requires_grad=False)
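
    # Note: torch.autograd.Variable has been deprecated since PyTorch 0.4;
    # the calls above still work, but each simply returns a Tensor with the
    # requested requires_grad flag.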

    # Wavenet instance
    w = Wavenet(params, input_shape)

    # Optimiser
    optimiser = torch.optim.Adam([params], lr=1e-02)
    lambda_reg = 1.0E+02
    num_steps = 5000

    # Regularisation
    reg = Regularisation(params)
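
    # Regularisation comes from the wavenet package; it is assumed here that
    # its forward() call returns a scalar penalty on `params`, which is
    # weighted by lambda_reg in the combined loss below.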

    # Training loop
    print("Initial parameters:", params.data.numpy())
    loss_dict = lambda: {'sparsity': 0, 'regularisation': 0, 'combined': 0, 'compactness': 0}
    losses = {'sparsity': [0], 'regularisation': [0], 'combined': [0], 'compactness': [0]}
    print("=" * 80)
    print("START RUNNING")
    print("-" * 80)
    start = time.time()

    for step, x_ in enumerate(generator(**gen_opts)):

        # Stop condition
        if step >= num_steps:
            break

        # Set input
        x.data = torch.from_numpy(x_).type(dtype)

        # Get wavelet coefficients
        c = w.forward(x)
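
        # gini() comes from the wavenet package; it is assumed to return the
        # Gini coefficient of the coefficient vector, which approaches 1 for
        # a maximally sparse vector, so minimising 1 - gini(c) drives the
        # representation towards sparsity.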
        # Sparsity loss
        sparsity = 1. - gini(c)

        # Regularisation loss
        regularisation = reg.forward()
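
        # The compactness term below is the centroid of the normalised
        # |params| profile, measured relative to the midpoint of the filter
        # support (indices run from 0 to 1); minimising it pulls filter
        # weight towards the leading coefficients.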
        # Compactness loss
        compactness = torch.dot(indices - 0.5, params.abs() / params.abs().sum())

        # Combined loss
        combined = sparsity + lambda_reg * regularisation + compactness

        # Perform backpropagation
        combined.backward()
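
        # backward() is called every step, but the optimiser only steps (and
        # gradients are only cleared) every batch_size-th iteration, so
        # gradients accumulate across steps, emulating a mini-batch of size
        # batch_size (trivially 1 here).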

        # Parameter update
        if step % batch_size == 0:
            optimiser.step()
            optimiser.zero_grad()
            pass

        # Non-essential stuff below
        # ---------------------------------------------------------------------

        # Log
        if step % 1000 == 0:
            print("Step {}/{}".format(step, num_steps))
            pass

        # Logging loss history
        losses['sparsity'][-1] += float(sparsity)
        losses['regularisation'][-1] += float(regularisation)
        losses['compactness'][-1] += float(compactness)
        losses['combined'][-1] += float(combined)

        if step % batch_size == 0:
            for key in losses:
                losses[key][-1] /= float(batch_size)
                losses[key].append(0.)
                pass
            pass

        # Draw model diagram
        if step == 0:
            from torchviz import make_dot
            dot = make_dot(sparsity, params={'params': params, 'input': x})
            dot.format = 'pdf'
            dot.render('output/model')
            pass
        pass

    end = time.time()
    print("-" * 80)
    print("Took {:.1f} sec.".format(end - start))
    print("=" * 80)

    # Clean-up: drop the trailing, partially filled entry from each loss series
    for key in losses:
        losses[key].pop(-1)
        pass

    print("Final parameters:", params.data.numpy())

    # Save to file
    tag = '{}__N{}__{}'.format('x'.join(map(str, input_shape)), num_params, generator_name)

    # -- Model
    torch.save(w, 'output/model__{}.pt'.format(tag))
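    # (Saving the full module pickles the Wavenet class itself, so loading
    # this .pt file later requires the wavenet package to be importable; the
    # common alternative is to save w.state_dict() instead.)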

    # -- Loss
    with open('output/loss__{}.json'.format(tag), 'w') as f:
        json.dump(losses, f)
        pass

    return


# Main function call.
if __name__ == '__main__':
    main()
    pass
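
# Example usage (assumes the wavenet package is importable and that an
# output/ directory exists for the diagram, model, and loss files):
#   $ mkdir -p output
#   $ python train.py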