train_word.py
# -*- coding: utf-8 -*-
import sys
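# Python 2 workaround: make UTF-8 the default string encoding so the raw
# corpus text can be handled without explicit decode/encode calls.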
reload(sys)
sys.setdefaultencoding('utf8')
from rnn import RNN
import nltk
hidden_size = 100 # Size of hidden layer of neurons (H)
seq_length = 5 # Number of steps to unroll the RNN
learning_rate = 2e-3 # Step size for gradient updates
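
# Read the corpus, collapse newlines to spaces, drop non-ASCII characters,
# and split the text into word tokens (nltk.word_tokenize requires the
# 'punkt' tokenizer models, available via nltk.download('punkt')).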
with open('data/input.txt') as f:
    data = f.read().replace('\n', ' ').encode('ascii', 'ignore')
    data = data.lower()
    data = nltk.word_tokenize(data)
args = {
    'hidden_size': hidden_size,
    'seq_length': seq_length,
    'learning_rate': learning_rate,
    'data': data
}
# Initialize the RNN and run the first training step to get an initial hidden state
rnn = RNN(args)
inputs, hidden, loss = rnn.step()
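
# Train indefinitely, printing the loss and a short generated sample
# every 100 iterations.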
i = 0
while True:
    inputs, hidden, loss = rnn.step(hidden)
    if i % 100 == 0:
        print "Iteration {}:".format(i)
        print "Loss: {}".format(loss)
        print ' '.join(rnn.generate(hidden, inputs[0], 15))
        print ""
    # if i % 10000 == 0:
    #     rnn.save_model()
    #     print "Checkpoint saved!"
    i += 1