diff --git a/sample.lua b/sample.lua
index e22ece74..97dc2256 100644
--- a/sample.lua
+++ b/sample.lua
@@ -91,6 +91,7 @@ protos.rnn:evaluate() -- put in eval mode so that dropout works properly
 local vocab = checkpoint.vocab
 local ivocab = {}
 for c,i in pairs(vocab) do ivocab[i] = c end
+local char_log_prob = checkpoint.char_log_prob
 
 -- initialize the rnn state to all zeros
 gprint('creating an ' .. checkpoint.opt.model .. '...')
@@ -108,6 +109,9 @@ for L = 1,checkpoint.opt.num_layers do
 end
 state_size = #current_state
 
+-- keep the log probability of the sampled text
+local sample_log_prob = nil
+
 -- do a few seeded timesteps
 local seed_text = opt.primetext
 if string.len(seed_text) > 0 then
@@ -115,6 +119,14 @@ if string.len(seed_text) > 0 then
     gprint('--------------------------')
     for c in seed_text:gmatch'.' do
         prev_char = torch.Tensor{vocab[c]}
+        -- initialize the sample log probability to the empirical log probability of the first character
+        if not sample_log_prob then
+            sample_log_prob = char_log_prob[c]
+        else
+            -- use the previous prediction to find the log probability of this character
+            sample_log_prob = sample_log_prob + prediction[1][vocab[c]]
+        end
+
         io.write(ivocab[prev_char[1]])
         if opt.gpuid >= 0 and opt.opencl == 0 then prev_char = prev_char:cuda() end
         if opt.gpuid >= 0 and opt.opencl == 1 then prev_char = prev_char:cl() end
@@ -128,7 +140,7 @@ else
     -- fill with uniform probabilities over characters (? hmm)
     gprint('missing seed text, using uniform probability over first character')
     gprint('--------------------------')
-    prediction = torch.Tensor(1, #ivocab):fill(1)/(#ivocab)
+    prediction = torch.log(torch.Tensor(1, #ivocab):fill(1)/(#ivocab))
     if opt.gpuid >= 0 and opt.opencl == 0 then prediction = prediction:cuda() end
     if opt.gpuid >= 0 and opt.opencl == 1 then prediction = prediction:cl() end
 end
@@ -141,12 +153,21 @@ for i=1, opt.length do
         -- use argmax
         local _, prev_char_ = prediction:max(2)
         prev_char = prev_char_:resize(1)
+        prev_char_log_prob = prediction[1][prev_char[1]]
     else
         -- use sampling
         prediction:div(opt.temperature) -- scale by temperature
         local probs = torch.exp(prediction):squeeze()
         probs:div(torch.sum(probs)) -- renormalize so probs sum to one
         prev_char = torch.multinomial(probs:float(), 1):resize(1):float()
+        prev_char_log_prob = math.log(probs[prev_char[1]]) -- log of the temperature-scaled probability
     end
+
+    -- initialize sample_log_prob (in case there was no seed text) to the empirical log prob of the selected char
+    if not sample_log_prob then
+        sample_log_prob = char_log_prob[ivocab[prev_char[1]]]
+    else
+        sample_log_prob = sample_log_prob + prev_char_log_prob
+    end
 
     -- forward the rnn for next character
@@ -159,3 +180,4 @@ for i=1, opt.length do
     io.write(ivocab[prev_char[1]])
 end
 io.write('\n') io.flush()
+gprint('\nSample log probability: ' .. sample_log_prob)
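
Note on interpreting the new output: sample.lua now prints the sum of per-character log probabilities (natural log) of the generated text. Below is a minimal sketch of how that total could be normalized into bits per character and perplexity for comparing samples of different lengths; the helper sample_stats and the example numbers are illustrative, not part of this patch.

-- Hypothetical helper (not part of this patch): turn the total log
-- probability reported by sample.lua into length-normalized metrics.
local function sample_stats(sample_log_prob, num_chars)
    local avg_nats = sample_log_prob / num_chars   -- mean log prob per character (nats)
    local bits_per_char = -avg_nats / math.log(2)  -- convert nats to bits
    local perplexity = math.exp(-avg_nats)         -- per-character perplexity
    return bits_per_char, perplexity
end

-- e.g. a 2000-character sample whose reported log probability is -3200:
local bits, ppl = sample_stats(-3200, 2000)
print(string.format('%.3f bits/char, perplexity %.2f', bits, ppl))
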
diff --git a/train.lua b/train.lua
index b6a576a9..1f5b3825 100644
--- a/train.lua
+++ b/train.lua
@@ -351,6 +351,7 @@ for i = 1, iterations do
         checkpoint.i = i
         checkpoint.epoch = epoch
         checkpoint.vocab = loader.vocab_mapping
+        checkpoint.char_log_prob = loader.char_log_prob
         torch.save(savefile, checkpoint)
     end
 
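
Since checkpoints now carry the char_log_prob table, it can be inspected offline without the loader. A small sketch, assuming Torch7; the checkpoint path cv/some_checkpoint.t7 is a placeholder, not part of this patch.

-- Hypothetical inspection snippet (not part of this patch).
require 'torch'

local checkpoint = torch.load('cv/some_checkpoint.t7')  -- placeholder path
for char, logp in pairs(checkpoint.char_log_prob) do
    -- math.exp(logp) recovers the character's empirical frequency in the training text
    print(string.format('%q  log p = %8.4f  p = %.6f', char, logp, math.exp(logp)))
end
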
diff --git a/util/CharSplitLMMinibatchLoader.lua b/util/CharSplitLMMinibatchLoader.lua
index f3718f3d..8567eae0 100644
--- a/util/CharSplitLMMinibatchLoader.lua
+++ b/util/CharSplitLMMinibatchLoader.lua
@@ -14,12 +14,13 @@ function CharSplitLMMinibatchLoader.create(data_dir, batch_size, seq_length, split_fractions)
     local input_file = path.join(data_dir, 'input.txt')
     local vocab_file = path.join(data_dir, 'vocab.t7')
     local tensor_file = path.join(data_dir, 'data.t7')
+    local char_log_prob_file = path.join(data_dir, 'char_log_prob.t7')
 
     -- fetch file attributes to determine if we need to rerun preprocessing
     local run_prepro = false
-    if not (path.exists(vocab_file) or path.exists(tensor_file)) then
+    if not (path.exists(vocab_file) and path.exists(tensor_file) and path.exists(char_log_prob_file)) then
         -- prepro files do not exist, generate them
-        print('vocab.t7 and data.t7 do not exist. Running preprocessing...')
+        print('vocab.t7, data.t7, and/or char_log_prob.t7 do not exist. Running preprocessing...')
         run_prepro = true
     else
         -- check if the input file was modified since last time we
@@ -27,20 +28,22 @@ function CharSplitLMMinibatchLoader.create(data_dir, batch_size, seq_length, split_fractions)
         local input_attr = lfs.attributes(input_file)
         local vocab_attr = lfs.attributes(vocab_file)
         local tensor_attr = lfs.attributes(tensor_file)
-        if input_attr.modification > vocab_attr.modification or input_attr.modification > tensor_attr.modification then
-            print('vocab.t7 or data.t7 detected as stale. Re-running preprocessing...')
+        local char_log_prob_attr = lfs.attributes(char_log_prob_file)
+        if input_attr.modification > vocab_attr.modification or input_attr.modification > tensor_attr.modification or input_attr.modification > char_log_prob_attr.modification then
+            print('vocab.t7, data.t7, or char_log_prob.t7 detected as stale. Re-running preprocessing...')
             run_prepro = true
         end
     end
     if run_prepro then
         -- construct a tensor with all the data, and vocab file
         print('one-time setup: preprocessing input text file ' .. input_file .. '...')
-        CharSplitLMMinibatchLoader.text_to_tensor(input_file, vocab_file, tensor_file)
+        CharSplitLMMinibatchLoader.text_to_tensor(input_file, vocab_file, tensor_file, char_log_prob_file)
     end
 
     print('loading data files...')
     local data = torch.load(tensor_file)
     self.vocab_mapping = torch.load(vocab_file)
+    self.char_log_prob = torch.load(char_log_prob_file)
 
     -- cut off the end so that it divides evenly
     local len = data:size(1)
@@ -123,7 +126,7 @@ function CharSplitLMMinibatchLoader:next_batch(split_index)
 end
 
 -- *** STATIC method ***
-function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, out_tensorfile)
+function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, out_tensorfile, out_char_log_probfile)
     local timer = torch.Timer()
 
     print('loading text file...')
@@ -139,12 +142,20 @@ function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, out_tensorfile)
     rawdata = f:read(cache_len)
     repeat
         for char in rawdata:gmatch'.' do
-            if not unordered[char] then unordered[char] = true end
+            if not unordered[char] then
+                unordered[char] = 1
+            else
+                unordered[char] = unordered[char] + 1
+            end
         end
         tot_len = tot_len + #rawdata
         rawdata = f:read(cache_len)
     until not rawdata
     f:close()
+    -- convert raw counts into empirical character log probabilities
+    local log_tot_len = math.log(tot_len)
+    for char, count in pairs(unordered) do unordered[char] = math.log(count) - log_tot_len end
+
     -- sort into a table (i.e. keys become 1..N)
     local ordered = {}
     for char in pairs(unordered) do ordered[#ordered + 1] = char end
@@ -174,6 +185,8 @@ function CharSplitLMMinibatchLoader.text_to_tensor(in_textfile, out_vocabfile, out_tensorfile)
     torch.save(out_vocabfile, vocab_mapping)
     print('saving ' .. out_tensorfile)
     torch.save(out_tensorfile, data)
+    print('saving ' .. out_char_log_probfile)
+    torch.save(out_char_log_probfile, unordered)
 end
 
 return CharSplitLMMinibatchLoader
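
As a quick check of the new preprocessing output: since each count is divided by the total text length, the saved empirical log probabilities should exponentiate and sum to one. A sketch under that assumption; the tinyshakespeare data path is illustrative, not part of this patch.

-- Hypothetical sanity check (not part of this patch).
require 'torch'

local char_log_prob = torch.load('data/tinyshakespeare/char_log_prob.t7')  -- illustrative path
local total = 0
for _, logp in pairs(char_log_prob) do
    total = total + math.exp(logp)
end
-- counts were normalized by the full text length, so this should be ~1
assert(math.abs(total - 1) < 1e-6, 'empirical char probabilities do not sum to 1')
print(string.format('sum of empirical character probabilities: %.8f', total))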