-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
47 lines (34 loc) · 942 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# -*- coding: utf-8 -*-
# Python 2-only workaround: the site module removes sys.setdefaultencoding
# during interpreter start-up, so reload(sys) is used to bring it back, and
# the default codec for implicit str <-> unicode conversions is then switched
# to UTF-8 for the whole process.
# NOTE(review): reload() is not a builtin on Python 3 and
# sys.setdefaultencoding does not exist there, so these lines pin the script
# to Python 2.
import sys
reload(sys)
sys.setdefaultencoding('utf8')
from rnn import RNN
from lstm import LSTM
import io

# Model and optimisation hyperparameters.
hidden_size = 300  # Size of hidden layer of neurons (H)
seq_length = 50  # Number of steps to unroll the RNN
learning_rate = 2e-3

# Load the training text: newlines become spaces and every non-ASCII
# character is dropped, leaving an ASCII-only byte string.
#
# The file is decoded explicitly as UTF-8 instead of relying on the
# interpreter's default codec to decode the raw bytes implicitly inside
# .encode(); that implicit step raises UnicodeDecodeError on non-ASCII input
# unless the process-wide default encoding has been overridden.
# io.open reads in universal-newline mode, so '\r\n' and '\r' line endings
# are folded to a single space as well.
with io.open('data/input.txt', encoding='utf-8') as f:
    data = f.read().replace('\n', ' ').encode('ascii', 'ignore')

# Settings bundle passed to the model constructor.
args = {
    'hidden_size': hidden_size,
    'seq_length': seq_length,
    'learning_rate': learning_rate,
    'data': data,
}
# Build the model and take one step with no prior hidden state; the hidden
# state returned here seeds the first pass through the loop below.
rnn = RNN(args)
inputs, hidden, loss = rnn.step()

# Train indefinitely: the loop has no exit condition and only stops when the
# process is interrupted or an exception propagates.
i = 0
while True:
    inputs, hidden, loss = rnn.step(hidden)

    report_due = (i % 100 == 0)
    checkpoint_due = (i % 10000 == 0)

    if report_due:
        # Progress report: step counter, latest loss, then a sample generated
        # from the current hidden state and the first element of the latest
        # inputs. 140 is presumably the sample length -- confirm against
        # RNN.generate.
        print("Iteration {}:".format(i))
        print("Loss: {}".format(loss))
        sample = rnn.generate(hidden, inputs[0], 140)
        print(''.join(sample))
        print("")

    if checkpoint_due:
        # Fires on step 0 as well, so a checkpoint is written right away.
        rnn.save_model()
        print("Checkpoint saved!")

    i += 1
# args = {
# }