Added Recurrent Batch Normalization #163

Open · wants to merge 9 commits into base: master
6 changes: 3 additions & 3 deletions Readme.md
@@ -1,13 +1,13 @@

# char-rnn

This code implements **multi-layer Recurrent Neural Network** (RNN, LSTM, and GRU) for training/sampling from character-level language models. In other words the model takes one text file as input and trains a Recurrent Neural Network that learns to predict the next character in a sequence. The RNN can then be used to generate text character by character that will look like the original training data. The context of this code base is described in detail in my [blog post](http://karpathy.github.io/2015/05/21/rnn-effectiveness/).
This code implements **multi-layer Recurrent Neural Network** (RNN, LSTM, BNLSTM, and GRU) for training/sampling from character-level language models. In other words the model takes one text file as input and trains a Recurrent Neural Network that learns to predict the next character in a sequence. The RNN can then be used to generate text character by character that will look like the original training data. The context of this code base is described in detail in my [blog post](http://karpathy.github.io/2015/05/21/rnn-effectiveness/).

If you are new to Torch/Lua/Neural Nets, it might be helpful to know that this code is really just a slightly more fancy version of this [100-line gist](https://gist.github.com/karpathy/d4dee566867f8291f086) that I wrote in Python/numpy. The code in this repo additionally: allows for multiple layers, uses an LSTM instead of a vanilla RNN, has more supporting code for model checkpointing, and is of course much more efficient since it uses mini-batches and can run on a GPU.
If you are new to Torch/Lua/Neural Nets, it might be helpful to know that this code is really just a slightly more fancy version of this [100-line gist](https://gist.github.com/karpathy/d4dee566867f8291f086) that I wrote in Python/numpy. The code in this repo additionally: allows for multiple layers, uses a BNLSTM instead of a vanilla RNN, has more supporting code for model checkpointing, and is of course much more efficient since it uses mini-batches and can run on a GPU.

## Update: torch-rnn

[Justin Johnson](http://cs.stanford.edu/people/jcjohns/) (@jcjohnson) recently re-implemented char-rnn from scratch with a much nicer/smaller/cleaner/faster Torch code base. It's under the name [torch-rnn](https://github.com/jcjohnson/torch-rnn). It uses Adam for optimization and hard-codes the RNN/LSTM forward/backward passes for space/time efficiency. This also avoids headaches with cloning models in this repo. In other words, torch-rnn should be the default char-rnn implementation to use now instead of the one in this code base.
[Justin Johnson](http://cs.stanford.edu/people/jcjohns/) (@jcjohnson) recently re-implemented char-rnn from scratch with a much nicer/smaller/cleaner/faster Torch code base. It's under the name [torch-rnn](https://github.com/jcjohnson/torch-rnn). It uses Adam for optimization and hard-codes the RNN/BNLSTM forward/backward passes for space/time efficiency. This also avoids headaches with cloning models in this repo. In other words, torch-rnn should be the default char-rnn implementation to use now instead of the one in this code base.

## Requirements

30 changes: 26 additions & 4 deletions model/LSTM.lua
@@ -1,6 +1,6 @@

local LSTM = {}
function LSTM.lstm(input_size, rnn_size, n, dropout)
function LSTM.lstm(input_size, rnn_size, n, dropout, bn)
dropout = dropout or 0

-- there will be 2*n+1 inputs
@@ -26,9 +26,31 @@ function LSTM.lstm(input_size, rnn_size, n, dropout)
if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any
input_size_L = rnn_size
end
-- recurrent batch normalization
-- http://arxiv.org/abs/1603.09025
local bn_wx, bn_wh, bn_c
if bn then
bn_wx = nn.BatchNormalization(4 * rnn_size, 1e-5, 0.1, true)
bn_wh = nn.BatchNormalization(4 * rnn_size, 1e-5, 0.1, true)
bn_c = nn.BatchNormalization(rnn_size, 1e-5, 0.1, true)

-- initialise beta=0, gamma=0.1
bn_wx.weight:fill(0.1)
bn_wx.bias:zero()
bn_wh.weight:fill(0.1)
bn_wh.bias:zero()
bn_c.weight:fill(0.1)
bn_c.bias:zero()
else
bn_wx = nn.Identity()
bn_wh = nn.Identity()
bn_c = nn.Identity()
end
-- evaluate the input sums at once for efficiency
local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L}
local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L}
local i2h = bn_wx(nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L}
):annotate{name='bn_wx_'..L}
local h2h = bn_wh(nn.Linear(rnn_size, 4 * rnn_size, false)(prev_h):annotate{name='h2h_'..L}
):annotate{name='bn_wh_'..L}
local all_input_sums = nn.CAddTable()({i2h, h2h})

local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
@@ -45,7 +67,7 @@ function LSTM.lstm(input_size, rnn_size, n, dropout)
nn.CMulTable()({in_gate, in_transform})
})
-- gated cells form the output
local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})
local next_h = nn.CMulTable()({out_gate, nn.Tanh()(bn_c(next_c):annotate{name='bn_c_'..L})})

table.insert(outputs, next_c)
table.insert(outputs, next_h)
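For reference, the batch-normalized LSTM above follows Cooijmans et al., "Recurrent Batch Normalization" (the arXiv link in the comment): the input-to-hidden and hidden-to-hidden projections are normalized separately before being summed, the cell state is normalized before the output tanh, the recurrent nn.Linear drops its bias (the `false` third argument, which needs a reasonably recent nn), and gamma/beta start at 0.1/0. Below is a minimal sketch, not part of the PR, of driving the modified constructor for a single timestep; it assumes a Torch setup like train.lua's (nn, nngraph, and the repo's util.OneHot on the package path) and uses made-up sizes.

```lua
-- Minimal sketch (not part of the PR): build a 2-layer BN-LSTM graph and run
-- one timestep on dummy data. Mirrors the requires train.lua performs.
require 'nn'
require 'nngraph'
require 'util.OneHot'                 -- LSTM.lua one-hot encodes the char index
local LSTM = require 'model.LSTM'

local vocab_size, rnn_size, num_layers, dropout = 65, 128, 2, 0
local rnn = LSTM.lstm(vocab_size, rnn_size, num_layers, dropout, true)  -- bn = true

-- inputs for one timestep: {x, prev_c_1, prev_h_1, prev_c_2, prev_h_2}
local batch_size = 4                  -- BN computes statistics over the batch dimension
local x = torch.Tensor(batch_size):random(1, vocab_size)   -- dummy character indices
local state = {}
for L = 1, num_layers do
  table.insert(state, torch.zeros(batch_size, rnn_size))   -- prev_c for layer L
  table.insert(state, torch.zeros(batch_size, rnn_size))   -- prev_h for layer L
end

local outputs = rnn:forward{x, unpack(state)}
-- outputs = {next_c_1, next_h_1, next_c_2, next_h_2, log_probs}
print(#outputs, outputs[#outputs]:size())                   -- 5, batch_size x vocab_size
```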
12 changes: 9 additions & 3 deletions sample.lua
@@ -84,8 +84,12 @@ if not lfs.attributes(opt.model, 'mode') then
gprint('Error: File ' .. opt.model .. ' does not exist. Are you sure you didn\'t forget to prepend cv/ ?')
end
checkpoint = torch.load(opt.model)
clones = checkpoint.clones
protos = checkpoint.protos
protos.rnn:evaluate() -- put in eval mode so that dropout works properly
for t = 1, #clones.rnn do
clones.rnn[t]:evaluate()
end

-- initialize the vocabulary (and its inverted version)
local vocab = checkpoint.vocab
@@ -102,7 +106,7 @@ for L = 1,checkpoint.opt.num_layers do
if opt.gpuid >= 0 and opt.opencl == 0 then h_init = h_init:cuda() end
if opt.gpuid >= 0 and opt.opencl == 1 then h_init = h_init:cl() end
table.insert(current_state, h_init:clone())
if checkpoint.opt.model == 'lstm' then
if checkpoint.opt.model == 'lstm' or checkpoint.opt.model == 'bnlstm' then
table.insert(current_state, h_init:clone())
end
end
@@ -113,16 +117,18 @@ local seed_text = opt.primetext
if string.len(seed_text) > 0 then
gprint('seeding with ' .. seed_text)
gprint('--------------------------')
local t = 1
for c in seed_text:gmatch'.' do
prev_char = torch.Tensor{vocab[c]}
io.write(ivocab[prev_char[1]])
if opt.gpuid >= 0 and opt.opencl == 0 then prev_char = prev_char:cuda() end
if opt.gpuid >= 0 and opt.opencl == 1 then prev_char = prev_char:cl() end
local lst = protos.rnn:forward{prev_char, unpack(current_state)}
local lst = clones.rnn[(t - 1) % #clones.rnn + 1]:forward{prev_char, unpack(current_state)}
-- lst is a list of [state1,state2,..stateN,output]. We want everything but last piece
current_state = {}
for i=1,state_size do table.insert(current_state, lst[i]) end
prediction = lst[#lst] -- last element holds the log probabilities
t = t + 1
end
else
-- fill with uniform probabilities over characters (? hmm)
@@ -150,7 +156,7 @@ for i=1, opt.length do
end

-- forward the rnn for next character
local lst = protos.rnn:forward{prev_char, unpack(current_state)}
local lst = clones.rnn[(i - 1) % #clones.rnn + 1]:forward{prev_char, unpack(current_state)}
current_state = {}
for i=1,state_size do table.insert(current_state, lst[i]) end
prediction = lst[#lst] -- last element holds the log probabilities
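The sampling changes reflect that the batch-norm statistics are kept per timestep: the running mean/variance estimates live in the per-timestep clones rather than in the single prototype, so sample.lua now loads checkpoint.clones, puts every clone in evaluate() mode, and forwards step t through the matching clone, wrapping with a modulo once the index passes seq_length. (The paper reuses the final timestep's statistics beyond the training length; the modulo here cycles back to earlier clones instead.) A small hypothetical helper, not part of the PR, that captures the indexing:

```lua
-- Hypothetical helper mirroring the clone indexing above: timestep t uses clone
-- ((t-1) mod seq_length) + 1, so steps beyond the training sequence length wrap
-- around and reuse earlier timesteps' BN statistics.
local function clone_for_timestep(clones, t)
  return clones.rnn[(t - 1) % #clones.rnn + 1]
end

-- e.g. inside the generation loop:
--   local lst = clone_for_timestep(clones, i):forward{prev_char, unpack(current_state)}
```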
40 changes: 30 additions & 10 deletions train.lua
@@ -37,7 +37,7 @@ cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain th
-- model params
cmd:option('-rnn_size', 128, 'size of LSTM internal state')
cmd:option('-num_layers', 2, 'number of layers in the LSTM')
cmd:option('-model', 'lstm', 'lstm,gru or rnn')
cmd:option('-model', 'bnlstm', 'lstm,bnlstm,gru or rnn')
-- optimization
cmd:option('-learning_rate',2e-3,'learning rate')
cmd:option('-learning_rate_decay',0.97,'learning rate decay')
@@ -57,7 +57,7 @@ cmd:option('-seed',123,'torch manual random number generator seed')
cmd:option('-print_every',1,'how many steps/minibatches between printing out the loss')
cmd:option('-eval_val_every',1000,'every how many iterations should we evaluate on validation data?')
cmd:option('-checkpoint_dir', 'cv', 'output directory where checkpoints get written')
cmd:option('-savefile','lstm','filename to autosave the checkpoint to. Will be inside checkpoint_dir/')
cmd:option('-savefile','bnlstm','filename to autosave the checkpoint to. Will be inside checkpoint_dir/')
cmd:option('-accurate_gpu_timing',0,'set this flag to 1 to get precise timings when using GPU. Might make code bit slower but reports accurate timings.')
-- GPU/CPU
cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
@@ -120,6 +120,7 @@ local do_random_init = true
if string.len(opt.init_from) > 0 then
print('loading a model from checkpoint ' .. opt.init_from)
local checkpoint = torch.load(opt.init_from)
clones = checkpoint.clones
protos = checkpoint.protos
-- make sure the vocabs are the same
local vocab_compatible = true
@@ -146,6 +147,8 @@ else
protos = {}
if opt.model == 'lstm' then
protos.rnn = LSTM.lstm(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout)
elseif opt.model == 'bnlstm' then
protos.rnn = LSTM.lstm(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout, true)
elseif opt.model == 'gru' then
protos.rnn = GRU.gru(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout)
elseif opt.model == 'rnn' then
@@ -161,7 +164,7 @@ for L=1,opt.num_layers do
if opt.gpuid >=0 and opt.opencl == 0 then h_init = h_init:cuda() end
if opt.gpuid >=0 and opt.opencl == 1 then h_init = h_init:cl() end
table.insert(init_state, h_init:clone())
if opt.model == 'lstm' then
if opt.model == 'lstm' or opt.model == 'bnlstm' then
table.insert(init_state, h_init:clone())
end
end
@@ -178,28 +181,44 @@ end
params, grad_params = model_utils.combine_all_parameters(protos.rnn)

-- initialization
if do_random_init then
if do_random_init and opt.model ~= 'bnlstm' then
params:uniform(-0.08, 0.08) -- small uniform numbers
end
-- initialize the LSTM forget gates with slightly higher biases to encourage remembering in the beginning
if opt.model == 'lstm' then
if opt.model == 'lstm' or opt.model == 'bnlstm' then
for layer_idx = 1, opt.num_layers do
for _,node in ipairs(protos.rnn.forwardnodes) do
if node.data.annotations.name == "i2h_" .. layer_idx then
print('setting forget gate biases to 1 in LSTM layer ' .. layer_idx)
print('setting forget gate biases to 1 in ' .. opt.model:upper() .. ' layer ' .. layer_idx)
-- the gates are, in order, i,f,o,g, so f is the 2nd block of weights
node.data.module.bias[{{opt.rnn_size+1, 2*opt.rnn_size}}]:fill(1.0)
end
end
end
end
-- initialize the BNLSTM gamma and beta parameters with 0.1 and 0 respectively
if opt.model == 'bnlstm' then
for layer_idx = 1, opt.num_layers do
for _,node in ipairs(protos.rnn.forwardnodes) do
if node.data.annotations.name == "bn_wx_" .. layer_idx or
node.data.annotations.name == "bn_wh_" .. layer_idx or
node.data.annotations.name == "bn_c_" .. layer_idx then
print('setting gamma to 0.1 and beta to 0 in ' .. node.data.annotations.name .. ' BNLSTM layer ' .. layer_idx)
node.data.module.weight:fill(0.1)
node.data.module.bias:zero()
end
end
end
end

print('number of parameters in the model: ' .. params:nElement())
-- make a bunch of clones after flattening, as that reallocates memory
clones = {}
for name,proto in pairs(protos) do
print('cloning ' .. name)
clones[name] = model_utils.clone_many_times(proto, opt.seq_length, not proto.parameters)
if clones == nil then
clones = {}
for name,proto in pairs(protos) do
print('cloning ' .. name)
clones[name] = model_utils.clone_many_times(proto, opt.seq_length, not proto.parameters)
end
end

-- preprocessing helper function
@@ -343,6 +362,7 @@ for i = 1, iterations do
local savefile = string.format('%s/lm_%s_epoch%.2f_%.4f.t7', opt.checkpoint_dir, opt.savefile, epoch, val_loss)
print('saving checkpoint to ' .. savefile)
local checkpoint = {}
checkpoint.clones = clones
checkpoint.protos = protos
checkpoint.opt = opt
checkpoint.train_losses = train_losses
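Since the per-timestep BN statistics live in the clones, the checkpoint now carries clones alongside protos, and -init_from restores them rather than re-cloning (the `if clones == nil` guard above). Below is a rough sketch of inspecting those statistics in a saved checkpoint; the filename is made up and the `running_mean` field name is an assumption about recent nn.BatchNormalization, so treat it as illustrative rather than part of the PR.

```lua
-- Hypothetical inspection of a saved checkpoint (filename invented); shows that
-- each per-timestep clone carries its own BatchNormalization running estimates.
local checkpoint = torch.load('cv/lm_bnlstm_epoch10.00_1.4000.t7')

print(checkpoint.opt.model)      -- 'bnlstm'
print(#checkpoint.clones.rnn)    -- opt.seq_length clones, one per timestep
print(checkpoint.protos.rnn)     -- the prototype the clones share parameters with

-- list the BN nodes of the first timestep's clone and their running means
for _, node in ipairs(checkpoint.clones.rnn[1].forwardnodes) do
  local name = node.data.annotations.name
  if name and name:match('^bn_') then
    -- 'running_mean' is the field name used by recent nn.BatchNormalization versions
    print(name, node.data.module.running_mean:mean())
  end
end
```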