Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Print each x #206

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,8 @@
*.t7
*.svg
*.dot
*.png

# temp files
_*

6 changes: 6 additions & 0 deletions data/dict1/input.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
supercalifragilisticexpialidocious n. secret. 2 (foll. by on) formidious going leaves. 2 breat which the being hand. [old english]
healte v. (-ling) drinking or esp. clowel or armitic take away or causing someting. 6 sippossion of algeratous. [latin: related to *tan-1 a deer-notic mutder maddly lowy, a restinatiun]
candrious adj. 1 suchering, years. personist adj. disensentionist n. [french]
rescacabole n. urless skoiling a band bexope out in farehind earen day-deaseding. [latin gonar]
repipt n. don-if for a not actuous listing.
steatshide v. 1 (brit. abshit hair). containented oneself trryable to and branging, propession. [french honic]
1 change: 1 addition & 0 deletions data/fox/input.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The quick brown fox jumps over the lazy dog
1 change: 1 addition & 0 deletions data/simple/input.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
abcde cab bad ace add ebb deed dead cede
32 changes: 31 additions & 1 deletion train.lua
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ end
local loader = CharSplitLMMinibatchLoader.create(opt.data_dir, opt.batch_size, opt.seq_length, split_sizes)
local vocab_size = loader.vocab_size -- the number of distinct characters
local vocab = loader.vocab_mapping
-- Invert the vocabulary mapping (character -> index) into index -> character,
-- so tensors of indices can be rendered back as text.
local vocab_inv = {}
for char, idx in pairs(loader.vocab_mapping) do
  vocab_inv[idx] = char
end
print('vocab size: ' .. vocab_size)
-- make sure output directory exists
if not path.exists(opt.checkpoint_dir) then lfs.mkdir(opt.checkpoint_dir) end
Expand Down Expand Up @@ -212,6 +217,24 @@ function eval_split(split_index, max_batches)
return loss
end

--- Render a tensor of vocabulary indices as a human-readable string.
-- Works on a 2-D tensor (batch x seq; rows are joined with '|') or a
-- 1-D tensor of indices. Looks characters up in the file-level
-- `vocab_inv` (index -> character) table.
-- @param sample torch Tensor of vocab indices (1-D or 2-D)
-- @return string decoded text
function sampleToString(sample)
  local sample_copy = sample:clone():int()
  -- Fix: the original assigned to a bare `str`, creating an accidental
  -- global; also buffer pieces and table.concat them instead of O(n^2)
  -- repeated `..` concatenation inside the loops.
  local parts = {}
  local dims = sample_copy:size():size()
  if dims == 2 then
    for j = 1, sample_copy:size()[1] do
      for i = 1, sample_copy:size()[2] do
        parts[#parts + 1] = vocab_inv[sample_copy[j][i]]
      end
      parts[#parts + 1] = '|'
    end
  elseif dims == 1 then
    for i = 1, sample_copy:size()[1] do
      parts[#parts + 1] = vocab_inv[sample_copy[i]]
    end
  end
  return table.concat(parts)
end

-- do fwd/bwd and return loss, grad_params
local init_state_global = clone_list(init_state)
function feval(x)
Expand All @@ -222,6 +245,7 @@ function feval(x)

------------------ get minibatch -------------------
local x, y = loader:next_batch(1)
print('x', x:size(), sampleToString(x), 'y', sampleToString(y))
if opt.gpuid >= 0 and opt.opencl == 0 then -- ship the input arrays to GPU
-- have to convert to float because integers can't be cuda()'d
x = x:float():cuda()
Expand All @@ -237,7 +261,9 @@ function feval(x)
local loss = 0
for t=1,opt.seq_length do
clones.rnn[t]:training() -- make sure we are in correct mode (this is cheap, sets flag)
local lst = clones.rnn[t]:forward{x[{{}, t}], unpack(rnn_state[t-1])}
local thisinput = {x[{{}, t}], unpack(rnn_state[t-1])}
print('t=' .. t .. ' thisinput[1]', sampleToString(thisinput[1]))
local lst = clones.rnn[t]:forward(thisinput)
rnn_state[t] = {}
for i=1,#init_state do table.insert(rnn_state[t], lst[i]) end -- extract the state, without output
predictions[t] = lst[#lst] -- last element is the prediction
Expand Down Expand Up @@ -278,6 +304,9 @@ local iterations_per_epoch = loader.ntrain
local loss0 = nil
for i = 1, iterations do
local epoch = i / loader.ntrain
if epoch > 1 then
os.exit(0)
end

local timer = torch.Timer()
local _, loss = optim.rmsprop(feval, params, optim_state)
Expand Down Expand Up @@ -331,6 +360,7 @@ for i = 1, iterations do
print('loss is exploding, aborting.')
break -- halt
end

end