diff --git a/Readme.md b/Readme.md
index a14aa2de..ea88e8ec 100644
--- a/Readme.md
+++ b/Readme.md
@@ -1,13 +1,13 @@
 # char-rnn
 
-This code implements **multi-layer Recurrent Neural Network** (RNN, LSTM, and GRU) for training/sampling from character-level language models. In other words the model takes one text file as input and trains a Recurrent Neural Network that learns to predict the next character in a sequence. The RNN can then be used to generate text character by character that will look like the original training data. The context of this code base is described in detail in my [blog post](http://karpathy.github.io/2015/05/21/rnn-effectiveness/).
+This code implements **multi-layer Recurrent Neural Networks** (RNN, LSTM, BNLSTM, and GRU) for training/sampling from character-level language models. In other words the model takes one text file as input and trains a Recurrent Neural Network that learns to predict the next character in a sequence. The RNN can then be used to generate text character by character that will look like the original training data. The context of this code base is described in detail in my [blog post](http://karpathy.github.io/2015/05/21/rnn-effectiveness/).
 
-If you are new to Torch/Lua/Neural Nets, it might be helpful to know that this code is really just a slightly more fancy version of this [100-line gist](https://gist.github.com/karpathy/d4dee566867f8291f086) that I wrote in Python/numpy. The code in this repo additionally: allows for multiple layers, uses an LSTM instead of a vanilla RNN, has more supporting code for model checkpointing, and is of course much more efficient since it uses mini-batches and can run on a GPU.
+If you are new to Torch/Lua/Neural Nets, it might be helpful to know that this code is really just a slightly more fancy version of this [100-line gist](https://gist.github.com/karpathy/d4dee566867f8291f086) that I wrote in Python/numpy. The code in this repo additionally: allows for multiple layers, uses a BNLSTM instead of a vanilla RNN, has more supporting code for model checkpointing, and is of course much more efficient since it uses mini-batches and can run on a GPU.
 
 ## Update: torch-rnn
 
 [Justin Johnson](http://cs.stanford.edu/people/jcjohns/) (@jcjohnson) recently re-implemented char-rnn from scratch with a much nicer/smaller/cleaner/faster Torch code base. It's under the name [torch-rnn](https://github.com/jcjohnson/torch-rnn). It uses Adam for optimization and hard-codes the RNN/LSTM forward/backward passes for space/time efficiency. This also avoids headaches with cloning models in this repo. In other words, torch-rnn should be the default char-rnn implemention to use now instead of the one in this code base.
 
 ## Requirements
diff --git a/model/LSTM.lua b/model/LSTM.lua
index c9a738bd..bae8b7c6 100644
--- a/model/LSTM.lua
+++ b/model/LSTM.lua
@@ -1,6 +1,6 @@
 local LSTM = {}
 
-function LSTM.lstm(input_size, rnn_size, n, dropout)
+function LSTM.lstm(input_size, rnn_size, n, dropout, bn)
   dropout = dropout or 0
 
   -- there will be 2*n+1 inputs
@@ -26,9 +26,31 @@ function LSTM.lstm(input_size, rnn_size, n, dropout)
       if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any
       input_size_L = rnn_size
     end
+    -- recurrent batch normalization
+    -- http://arxiv.org/abs/1603.09025
+    local bn_wx, bn_wh, bn_c
+    if bn then
+      bn_wx = nn.BatchNormalization(4 * rnn_size, 1e-5, 0.1, true)
+      bn_wh = nn.BatchNormalization(4 * rnn_size, 1e-5, 0.1, true)
+      bn_c = nn.BatchNormalization(rnn_size, 1e-5, 0.1, true)
+
+      -- initialise beta=0, gamma=0.1
+      bn_wx.weight:fill(0.1)
+      bn_wx.bias:zero()
+      bn_wh.weight:fill(0.1)
+      bn_wh.bias:zero()
+      bn_c.weight:fill(0.1)
+      bn_c.bias:zero()
+    else
+      bn_wx = nn.Identity()
+      bn_wh = nn.Identity()
+      bn_c = nn.Identity()
+    end
     -- evaluate the input sums at once for efficiency
-    local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L}
-    local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L}
+    local i2h = bn_wx(nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L}
+      ):annotate{name='bn_wx_'..L}
+    local h2h = bn_wh(nn.Linear(rnn_size, 4 * rnn_size, false)(prev_h):annotate{name='h2h_'..L}
+      ):annotate{name='bn_wh_'..L}
     local all_input_sums = nn.CAddTable()({i2h, h2h})
 
     local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
@@ -45,7 +67,7 @@ function LSTM.lstm(input_size, rnn_size, n, dropout)
       nn.CMulTable()({in_gate, in_transform})
     })
     -- gated cells form the output
-    local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})
+    local next_h = nn.CMulTable()({out_gate, nn.Tanh()(bn_c(next_c):annotate{name='bn_c_'..L})})
 
     table.insert(outputs, next_c)
     table.insert(outputs, next_h)
diff --git a/sample.lua b/sample.lua
index e22ece74..bcb11982 100644
--- a/sample.lua
+++ b/sample.lua
@@ -84,8 +84,12 @@ if not lfs.attributes(opt.model, 'mode') then
     gprint('Error: File ' .. opt.model .. ' does not exist. Are you sure you didn\'t forget to prepend cv/ ?')
 end
 checkpoint = torch.load(opt.model)
+clones = checkpoint.clones
 protos = checkpoint.protos
 protos.rnn:evaluate() -- put in eval mode so that dropout works properly
+for t = 1, #clones.rnn do
+    clones.rnn[t]:evaluate()
+end
 
 -- initialize the vocabulary (and its inverted version)
 local vocab = checkpoint.vocab
@@ -102,7 +106,7 @@ for L = 1,checkpoint.opt.num_layers do
     if opt.gpuid >= 0 and opt.opencl == 0 then h_init = h_init:cuda() end
     if opt.gpuid >= 0 and opt.opencl == 1 then h_init = h_init:cl() end
     table.insert(current_state, h_init:clone())
-    if checkpoint.opt.model == 'lstm' then
+    if checkpoint.opt.model == 'lstm' or checkpoint.opt.model == 'bnlstm' then
         table.insert(current_state, h_init:clone())
     end
 end
@@ -113,16 +117,18 @@ local seed_text = opt.primetext
 if string.len(seed_text) > 0 then
     gprint('seeding with ' .. seed_text)
     gprint('--------------------------')
+    local t = 1
     for c in seed_text:gmatch'.' do
         prev_char = torch.Tensor{vocab[c]}
         io.write(ivocab[prev_char[1]])
         if opt.gpuid >= 0 and opt.opencl == 0 then prev_char = prev_char:cuda() end
         if opt.gpuid >= 0 and opt.opencl == 1 then prev_char = prev_char:cl() end
-        local lst = protos.rnn:forward{prev_char, unpack(current_state)}
+        local lst = clones.rnn[(t - 1) % #clones.rnn + 1]:forward{prev_char, unpack(current_state)}
         -- lst is a list of [state1,state2,..stateN,output]. We want everything but last piece
         current_state = {}
         for i=1,state_size do table.insert(current_state, lst[i]) end
         prediction = lst[#lst] -- last element holds the log probabilities
+        t = t + 1
     end
 else
     -- fill with uniform probabilities over characters (? hmm)
@@ -150,7 +156,7 @@ for i=1, opt.length do
     end
 
     -- forward the rnn for next character
-    local lst = protos.rnn:forward{prev_char, unpack(current_state)}
+    local lst = clones.rnn[(i - 1) % #clones.rnn + 1]:forward{prev_char, unpack(current_state)}
     current_state = {}
     for i=1,state_size do table.insert(current_state, lst[i]) end
     prediction = lst[#lst] -- last element holds the log probabilities
diff --git a/train.lua b/train.lua
index b6a576a9..26446250 100644
--- a/train.lua
+++ b/train.lua
@@ -37,7 +37,7 @@ cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain th
 -- model params
 cmd:option('-rnn_size', 128, 'size of LSTM internal state')
 cmd:option('-num_layers', 2, 'number of layers in the LSTM')
-cmd:option('-model', 'lstm', 'lstm,gru or rnn')
+cmd:option('-model', 'bnlstm', 'lstm, bnlstm, gru or rnn')
 -- optimization
 cmd:option('-learning_rate',2e-3,'learning rate')
 cmd:option('-learning_rate_decay',0.97,'learning rate decay')
@@ -57,7 +57,7 @@ cmd:option('-seed',123,'torch manual random number generator seed')
 cmd:option('-print_every',1,'how many steps/minibatches between printing out the loss')
 cmd:option('-eval_val_every',1000,'every how many iterations should we evaluate on validation data?')
 cmd:option('-checkpoint_dir', 'cv', 'output directory where checkpoints get written')
-cmd:option('-savefile','lstm','filename to autosave the checkpont to. Will be inside checkpoint_dir/')
+cmd:option('-savefile','bnlstm','filename to autosave the checkpoint to. Will be inside checkpoint_dir/')
 cmd:option('-accurate_gpu_timing',0,'set this flag to 1 to get precise timings when using GPU. Might make code bit slower but reports accurate timings.')
 -- GPU/CPU
 cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
@@ -120,6 +120,7 @@ local do_random_init = true
 if string.len(opt.init_from) > 0 then
     print('loading a model from checkpoint ' .. opt.init_from)
     local checkpoint = torch.load(opt.init_from)
+    clones = checkpoint.clones
     protos = checkpoint.protos
     -- make sure the vocabs are the same
     local vocab_compatible = true
@@ -146,6 +147,8 @@ else
     protos = {}
     if opt.model == 'lstm' then
         protos.rnn = LSTM.lstm(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout)
+    elseif opt.model == 'bnlstm' then
+        protos.rnn = LSTM.lstm(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout, true)
     elseif opt.model == 'gru' then
         protos.rnn = GRU.gru(vocab_size, opt.rnn_size, opt.num_layers, opt.dropout)
     elseif opt.model == 'rnn' then
@@ -161,7 +164,7 @@ for L=1,opt.num_layers do
     if opt.gpuid >=0 and opt.opencl == 0 then h_init = h_init:cuda() end
     if opt.gpuid >=0 and opt.opencl == 1 then h_init = h_init:cl() end
     table.insert(init_state, h_init:clone())
-    if opt.model == 'lstm' then
+    if opt.model == 'lstm' or opt.model == 'bnlstm' then
         table.insert(init_state, h_init:clone())
     end
 end
@@ -178,28 +181,44 @@ end
 params, grad_params = model_utils.combine_all_parameters(protos.rnn)
 
 -- initialization
-if do_random_init then
+if do_random_init and opt.model ~= 'bnlstm' then
     params:uniform(-0.08, 0.08) -- small uniform numbers
 end
 -- initialize the LSTM forget gates with slightly higher biases to encourage remembering in the beginning
-if opt.model == 'lstm' then
+if opt.model == 'lstm' or opt.model == 'bnlstm' then
     for layer_idx = 1, opt.num_layers do
         for _,node in ipairs(protos.rnn.forwardnodes) do
             if node.data.annotations.name == "i2h_" .. layer_idx then
-                print('setting forget gate biases to 1 in LSTM layer ' .. layer_idx)
+                print('setting forget gate biases to 1 in ' .. opt.model:upper() .. ' layer ' .. layer_idx)
                 -- the gates are, in order, i,f,o,g, so f is the 2nd block of weights
                 node.data.module.bias[{{opt.rnn_size+1, 2*opt.rnn_size}}]:fill(1.0)
             end
         end
     end
 end
+-- initialize the BNLSTM gamma and beta parameters with 0.1 and 0 respectively
+if opt.model == 'bnlstm' then
+    for layer_idx = 1, opt.num_layers do
+        for _,node in ipairs(protos.rnn.forwardnodes) do
+            if node.data.annotations.name == "bn_wx_" .. layer_idx or
+               node.data.annotations.name == "bn_wh_" .. layer_idx or
+               node.data.annotations.name == "bn_c_" .. layer_idx then
+                print('setting gamma to 0.1 and beta to 0 in ' .. node.data.annotations.name .. ' BNLSTM layer ' .. layer_idx)
+                node.data.module.weight:fill(0.1)
+                node.data.module.bias:zero()
+            end
+        end
+    end
+end
 
 print('number of parameters in the model: ' .. params:nElement())
 -- make a bunch of clones after flattening, as that reallocates memory
-clones = {}
-for name,proto in pairs(protos) do
-    print('cloning ' .. name)
-    clones[name] = model_utils.clone_many_times(proto, opt.seq_length, not proto.parameters)
+if clones == nil then
+    clones = {}
+    for name,proto in pairs(protos) do
+        print('cloning ' .. name)
+        clones[name] = model_utils.clone_many_times(proto, opt.seq_length, not proto.parameters)
+    end
 end
 
 -- preprocessing helper function
@@ -343,6 +362,7 @@ for i = 1, iterations do
         local savefile = string.format('%s/lm_%s_epoch%.2f_%.4f.t7', opt.checkpoint_dir, opt.savefile, epoch, val_loss)
         print('saving checkpoint to ' .. savefile)
         local checkpoint = {}
+        checkpoint.clones = clones
         checkpoint.protos = protos
         checkpoint.opt = opt
         checkpoint.train_losses = train_losses
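For reviewers who want to poke at the new code path in isolation, the sketch below builds the batch-normalized LSTM graph through the new `bn` flag on `LSTM.lstm` and runs a single forward step on dummy data. It is not part of the patch: the vocabulary size, batch size, and layer sizes are arbitrary values chosen for illustration, and it assumes it is run from the repository root with Torch, `nn`, and `nngraph` installed (the same dependencies `train.lua` already uses).

```lua
-- Illustrative only (not part of the patch): exercise LSTM.lstm with bn = true.
-- Assumes the working directory is the char-rnn repository root.
require 'nn'
require 'nngraph'
require 'util.OneHot'        -- registers the OneHot module used inside model/LSTM.lua
local LSTM = require 'model.LSTM'

local vocab_size, rnn_size, num_layers, dropout = 65, 128, 2, 0
local rnn = LSTM.lstm(vocab_size, rnn_size, num_layers, dropout, true) -- bn = true

-- one timestep for a batch of 4 characters, with zero-initialized (c, h) per layer
local batch_size = 4
local inputs = {torch.Tensor(batch_size):random(1, vocab_size)}
for L = 1, num_layers do
  table.insert(inputs, torch.zeros(batch_size, rnn_size)) -- prev_c[L]
  table.insert(inputs, torch.zeros(batch_size, rnn_size)) -- prev_h[L]
end

-- forward returns {next_c_1, next_h_1, ..., next_c_N, next_h_N, log-probabilities}
local outputs = rnn:forward(inputs)
print(#outputs)                  -- 2 * num_layers + 1 = 5
print(outputs[#outputs]:size())  -- batch_size x vocab_size log-probabilities
```

With `bn` left nil or false the same call falls back to `nn.Identity()` wrappers and reproduces the stock LSTM graph, which is how the existing `-model lstm` path stays unchanged; training with the new model type is simply `th train.lua -model bnlstm`.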