added log probability of sample to sample output #151

Open · wants to merge 1 commit into master
24 changes: 23 additions & 1 deletion sample.lua
@@ -91,6 +91,7 @@ protos.rnn:evaluate() -- put in eval mode so that dropout works properly
local vocab = checkpoint.vocab
local ivocab = {}
for c,i in pairs(vocab) do ivocab[i] = c end
local char_log_prob = checkpoint.char_log_prob

-- initialize the rnn state to all zeros
gprint('creating an ' .. checkpoint.opt.model .. '...')
@@ -108,13 +109,24 @@ for L = 1,checkpoint.opt.num_layers do
end
state_size = #current_state

-- keep the log probability of the sampled text
local sample_log_prob = nil

-- do a few seeded timesteps
local seed_text = opt.primetext
if string.len(seed_text) > 0 then
gprint('seeding with ' .. seed_text)
gprint('--------------------------')
for c in seed_text:gmatch'.' do
prev_char = torch.Tensor{vocab[c]}
-- initialize sample_log_prob to the empirical log probability of the first character
if not sample_log_prob then
sample_log_prob = char_log_prob[c]
else
-- use the previous prediction to find the log probability of this character
sample_log_prob = sample_log_prob + prediction[1][vocab[c]]
end

io.write(ivocab[prev_char[1]])
if opt.gpuid >= 0 and opt.opencl == 0 then prev_char = prev_char:cuda() end
if opt.gpuid >= 0 and opt.opencl == 1 then prev_char = prev_char:cl() end
@@ -128,7 +140,7 @@ else
-- fill with uniform probabilities over characters (? hmm)
gprint('missing seed text, using uniform probability over first character')
gprint('--------------------------')
-prediction = torch.Tensor(1, #ivocab):fill(1)/(#ivocab)
+prediction = torch.log(torch.Tensor(1, #ivocab):fill(1)/(#ivocab))
if opt.gpuid >= 0 and opt.opencl == 0 then prediction = prediction:cuda() end
if opt.gpuid >= 0 and opt.opencl == 1 then prediction = prediction:cl() end
end
@@ -141,12 +153,21 @@ for i=1, opt.length do
-- use argmax
local _, prev_char_ = prediction:max(2)
prev_char = prev_char_:resize(1)
prev_char_log_prob = prediction[1][prev_char[1]]
else
-- use sampling
prediction:div(opt.temperature) -- scale by temperature
local probs = torch.exp(prediction):squeeze()
probs:div(torch.sum(probs)) -- renormalize so probs sum to one
prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
prev_char_log_prob = torch.log(probs[prev_char[1]]) -- use log of probs to account for temperature effect
end

-- initialize sample_log_prob (in case there was no seed text) to the empirical log prob of the selected char
if not sample_log_prob then
sample_log_prob = char_log_prob[ivocab[prev_char[1]]]
else
sample_log_prob = sample_log_prob + prev_char_log_prob
end

-- forward the rnn for next character
@@ -159,3 +180,4 @@ for i=1, opt.length do
end
io.write('\n') io.flush()

gprint('\nSample log probability: ' .. sample_log_prob)
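A note for review on the two subtleties in this hunk: the running total implements the chain rule, log P(c_1..c_n) = log P(c_1) + sum_i log P(c_i | c_1..c_{i-1}), and the sampling branch has to take the log of the renormalized, temperature-scaled probs rather than reuse `prediction`, whose entries stop being log probabilities once divided by the temperature. A minimal standalone sketch with toy numbers (plain `math.log`/`math.exp` stand in for the torch calls; none of these values come from the patch):

-- Toy sketch: accumulate a sample's log probability under temperature sampling.
local log_probs = { math.log(0.7), math.log(0.2), math.log(0.1) } -- one step's log-softmax output
local temperature = 0.5

-- scale by temperature and renormalize, as in the sampling branch above
local probs, sum = {}, 0
for i, lp in ipairs(log_probs) do
    probs[i] = math.exp(lp / temperature)
    sum = sum + probs[i]
end
for i = 1, #probs do probs[i] = probs[i] / sum end

local picked = 1                                   -- pretend torch.multinomial drew index 1
local prev_char_log_prob = math.log(probs[picked]) -- a true log probability under the sampling distribution
-- reusing log_probs[picked] / temperature here would be wrong: the
-- temperature-scaled values no longer sum to 1 in exp space

local sample_log_prob = math.log(0.3)                   -- running total so far (toy value)
sample_log_prob = sample_log_prob + prev_char_log_prob  -- chain rule: add the conditional term
print(sample_log_prob)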
1 change: 1 addition & 0 deletions train.lua
@@ -351,6 +351,7 @@ for i = 1, iterations do
checkpoint.i = i
checkpoint.epoch = epoch
checkpoint.vocab = loader.vocab_mapping
checkpoint.char_log_prob = loader.char_log_prob
torch.save(savefile, checkpoint)
end
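
Since `checkpoint` is a plain Lua table, the new field rides along with the existing ones through Torch7 serialization; nothing else in the save path needs to change. A small round-trip sketch (hypothetical file name and values, not part of the diff):

require 'torch'
local checkpoint = { char_log_prob = { a = math.log(0.5), b = math.log(0.5) } } -- hypothetical values
torch.save('checkpoint_demo.t7', checkpoint)   -- serializes the table, tensors and all
local loaded = torch.load('checkpoint_demo.t7')
print(loaded.char_log_prob.a)                  -- same value back: math.log(0.5)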

27 changes: 20 additions & 7 deletions util/CharSplitLMMinibatchLoader.lua
@@ -14,33 +14,36 @@ function CharSplitLMMinibatchLoader.create(data_dir, batch_size, seq_length, spl
local input_file = path.join(data_dir, 'input.txt')
local vocab_file = path.join(data_dir, 'vocab.t7')
local tensor_file = path.join(data_dir, 'data.t7')
local char_log_prob_file = path.join(data_dir, 'char_log_prob.t7')

-- fetch file attributes to determine if we need to rerun preprocessing
local run_prepro = false
-if not (path.exists(vocab_file) or path.exists(tensor_file)) then
+if not (path.exists(vocab_file) or path.exists(tensor_file) or path.exists(char_log_prob_file)) then
-- prepro files do not exist, generate them
-print('vocab.t7 and data.t7 do not exist. Running preprocessing...')
+print('vocab.t7, data.t7, and/or char_log_prob.t7 do not exist. Running preprocessing...')
run_prepro = true
else
-- check if the input file was modified since last time we
-- ran the prepro. if so, we have to rerun the preprocessing
local input_attr = lfs.attributes(input_file)
local vocab_attr = lfs.attributes(vocab_file)
local tensor_attr = lfs.attributes(tensor_file)
-if input_attr.modification > vocab_attr.modification or input_attr.modification > tensor_attr.modification then
-print('vocab.t7 or data.t7 detected as stale. Re-running preprocessing...')
+local char_log_prob_attr = lfs.attributes(char_log_prob_file)
+if input_attr.modification > vocab_attr.modification or input_attr.modification > tensor_attr.modification or input_attr.modification > char_log_prob_attr.modification then
+print('vocab.t7, data.t7, or char_log_prob.t7 detected as stale. Re-running preprocessing...')
run_prepro = true
end
end
if run_prepro then
-- construct a tensor with all the data, and vocab file
print('one-time setup: preprocessing input text file ' .. input_file .. '...')
-CharSplitLMMinibatchLoader.text_to_tensor(input_file, vocab_file, tensor_file)
+CharSplitLMMinibatchLoader.text_to_tensor(input_file, vocab_file, tensor_file, char_log_prob_file)
end

print('loading data files...')
local data = torch.load(tensor_file)
self.vocab_mapping = torch.load(vocab_file)
self.char_log_prob = torch.load(char_log_prob_file)

-- cut off the end so that it divides evenly
local len = data:size(1)
@@ -123,7 +126,7 @@ function CharSplitLMMinibatchLoader:next_batch(split_index)
end

-- *** STATIC method ***
-function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, out_tensorfile)
+function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, out_tensorfile, out_char_log_probfile)
local timer = torch.Timer()

print('loading text file...')
@@ -139,12 +142,20 @@ function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, o
rawdata = f:read(cache_len)
repeat
for char in rawdata:gmatch'.' do
-if not unordered[char] then unordered[char] = true end
+if not unordered[char] then
+unordered[char] = 1
+else
+unordered[char] = unordered[char] + 1
+end
end
tot_len = tot_len + #rawdata
rawdata = f:read(cache_len)
until not rawdata
f:close()
-- construct character log probabilities
local log_tot_len = torch.log(tot_len)
for char, count in pairs(unordered) do unordered[char] = torch.log(count) - log_tot_len end

-- sort into a table (i.e. keys become 1..N)
local ordered = {}
for char in pairs(unordered) do ordered[#ordered + 1] = char end
Expand Down Expand Up @@ -174,6 +185,8 @@ function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, o
torch.save(out_vocabfile, vocab_mapping)
print('saving ' .. out_tensorfile)
torch.save(out_tensorfile, data)
print('saving ' .. out_char_log_probfile)
torch.save(out_char_log_probfile, unordered)
end

return CharSplitLMMinibatchLoader
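
For reference, the estimate the loader now saves is the maximum-likelihood unigram distribution, stored in log space: log p(c) = log count(c) - log total. A quick standalone check on a toy corpus (values invented for illustration):

local counts = { h = 1, e = 1, l = 2, o = 1 } -- character counts for the toy corpus "hello"
local total = 5
local char_log_prob = {}
for c, n in pairs(counts) do char_log_prob[c] = math.log(n) - math.log(total) end
-- sanity check: exp of the stored log probs should sum to 1
local s = 0
for _, lp in pairs(char_log_prob) do s = s + math.exp(lp) end
print(char_log_prob.l, s) -- math.log(0.4), 1 (up to floating point)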