diff --git a/sample.lua b/sample.lua index e6d0f0c8..3ddd0e5e 100644 --- a/sample.lua +++ b/sample.lua @@ -28,6 +28,7 @@ cmd:argument('-model','model checkpoint to use for sampling') cmd:option('-seed',123,'random number generator\'s seed') cmd:option('-sample',1,' 0 to use max at each timestep, 1 to sample at each timestep') cmd:option('-primetext',"",'used as a prompt to "seed" the state of the LSTM using a given sequence, before we sample.') +cmd:option('-ignore_bad_primetext',false,'ignore characters in primetext that are not in the trained vocabulary') cmd:option('-length',2000,'number of characters to sample') cmd:option('-temperature',1,'temperature of sampling') cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU') @@ -92,6 +93,9 @@ if string.len(seed_text) > 0 then gprint('seeding with ' .. seed_text) gprint('--------------------------') for c in seed_text:gmatch'.' do + if vocab[c] == nil and opt.ignore_bad_primetext then + goto skip_seed_character + end prev_char = torch.Tensor{vocab[c]} io.write(ivocab[prev_char[1]]) if opt.gpuid >= 0 then prev_char = prev_char:cuda() end @@ -100,6 +104,7 @@ if string.len(seed_text) > 0 then current_state = {} for i=1,state_size do table.insert(current_state, lst[i]) end prediction = lst[#lst] -- last element holds the log probabilities + ::skip_seed_character:: end else -- fill with uniform probabilities over characters (? hmm)