forked from gwtaylor/theano-hf
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hf.py
330 lines (271 loc) · 12.6 KB
/
hf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
# Author: Nicolas Boulanger-Lewandowski
# University of Montreal, 2012
import numpy, sys
import theano
import theano.tensor as T
import cPickle
import os
class hf_optimizer:
    '''Black-box Theano-based Hessian-free optimizer.

    See (Martens, ICML 2010) and (Martens & Sutskever, ICML 2011) for details.

    Useful functions:
    __init__ :
        Compiles necessary Theano functions from symbolic expressions.
    train :
        Performs HF optimization following the above references.'''

    def __init__(self, p, inputs, s, costs, h=None):
        '''Constructs and compiles the necessary Theano functions.

        p : list of Theano shared variables
            Parameters of the model to be optimized.
        inputs : list of Theano variables
            Symbolic variables that are inputs to your graph (they should also
            include your model 'output'). Your training examples must fit these.
        s : Theano variable
            Symbolic variable with respect to which the Hessian of the objective is
            positive-definite, implicitly defining the Gauss-Newton matrix. Typically,
            it is the activation of the output layer.
        costs : list of Theano variables
            Monitoring costs, the first of which will be the optimized objective.
            The caller's list is not modified.
        h: Theano variable or None
            Structural damping is applied to this variable (typically the hidden units
            of an RNN).'''
        self.p = p
        self.shapes = [param.get_value().shape for param in p]
        # integer element counts / start offsets of each parameter in flat vectors
        self.sizes, self.positions = self._flat_layout(self.shapes)
        inputs = list(inputs)
        costs = list(costs)  # local copy: never mutate the caller's list

        g = [T.as_tensor_variable(grad) for grad in T.grad(costs[0], p)]  # for CudaNdarray
        self.f_gc = theano.function(inputs, g + costs)  # during gradient computation
        self.f_cost = theano.function(inputs, costs)  # for quick cost evaluation

        symbolic_types = T.scalar, T.vector, T.matrix, T.tensor3, T.tensor4
        coefficient = T.scalar()  # this is lambda*mu
        objective = costs[0]
        if h is not None:  # structural damping with cross-entropy
            # T.Rop does not support `consider_constant` yet, so use `givens`
            h_constant = symbolic_types[h.ndim]()
            structural_damping = coefficient * (-h_constant*T.log(h) - (1-h_constant)*T.log(1-h)).sum() / h.shape[0]
            objective = objective + structural_damping
            givens = {h_constant: h}
        else:
            givens = {}

        # this computes the product Gv = J'HJv (G is the Gauss-Newton matrix)
        v = [symbolic_types[len(shape)]() for shape in self.shapes]
        Jv = T.Rop(s, p, v)
        HJv = T.grad(T.sum(T.grad(objective, s)*Jv), s, consider_constant=[Jv])
        Gv = T.grad(T.sum(HJv*s), p, consider_constant=[HJv, Jv])
        Gv = [T.as_tensor_variable(product) for product in Gv]  # for CudaNdarray
        self.function_Gv = theano.function(inputs + v + [coefficient], Gv, givens=givens,
                                           on_unused_input='ignore')

    @staticmethod
    def _flat_layout(shapes):
        '''Return (sizes, positions) for parameters of the given shapes.

        sizes[k] is the number of elements of shapes[k] (a Python int) and
        positions[k] its start offset in the flat concatenation. The int()
        matters: numpy.prod(()) is the float 1.0 for scalar parameters, and
        float offsets cannot be used as slice bounds.'''
        sizes = [int(numpy.prod(shape)) for shape in shapes]
        positions = numpy.cumsum([0] + sizes)[:-1]
        return sizes, positions

    @staticmethod
    def _log(text):
        '''Write `text` to stdout (no newline is added) and flush immediately.'''
        sys.stdout.write(text)
        sys.stdout.flush()

    def quick_cost(self, delta=0):
        '''Evaluate the objective (costs[0]) averaged over the CG batch at
        `current params + delta`.

        delta : flat numpy vector, list/tuple of per-parameter arrays, or any
            other value (e.g. the default 0), in which case the current
            parameters are used unchanged.

        The parameters are restored to their exact original values afterwards,
        even if the cost evaluation raises (e.g. KeyboardInterrupt).'''
        if isinstance(delta, numpy.ndarray):
            delta = self.flat_to_list(delta)
        perturb = isinstance(delta, (list, tuple))
        saved = []
        if perturb:
            # get_value() returns a copy, so `saved` is unaffected by set_value
            saved = [param.get_value() for param in self.p]
            for param, value, step in zip(self.p, saved, delta):
                param.set_value(value + step)
        try:
            return numpy.mean([self.f_cost(*inputs)[0] for inputs in self.cg_dataset.iterate(update=False)])
        finally:
            # exact restore (value + d - d would accumulate rounding error)
            for param, value in zip(self.p, saved):
                param.set_value(value)

    def cg(self, b):
        '''Approximately solve batch_Gv(x) = b with (optionally preconditioned)
        conjugate gradient, warm-started from the previous solution.

        Candidate iterates are recorded at iterations i >= ceil(1.3**k) and the
        returned one is chosen by backtracking on quick_cost.

        Returns (cost, x, iteration_of_x, total_iterations).
        Raises ValueError if max_cg_iterations < 1.'''
        if self.max_cg_iterations < 1:
            raise ValueError('max_cg_iterations must be at least 1')
        if self.preconditioner:
            # elementwise (lambda + sum of squared per-batch gradients)**-0.75
            M = self.lambda_ * numpy.ones_like(b)
            for inputs in self.cg_dataset.iterate(update=False):
                M += self.list_to_flat(self.f_gc(*inputs)[:len(self.p)])**2
            M **= -0.75  # actually 1/M
            sys.stdout.flush()
        else:
            M = 1.0
        x = self.cg_last_x if hasattr(self, 'cg_last_x') else numpy.zeros_like(b)  # sharing information between CG runs
        r = b - self.batch_Gv(x)
        d = M*r
        delta_new = numpy.dot(r, d)
        phi = []
        backtracking = []
        backspaces = 0
        for i in range(1, 1 + self.max_cg_iterations):
            # adapted from http://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf (p.51)
            q = self.batch_Gv(d)
            dq = numpy.dot(d, q)
            alpha = delta_new / dq
            x = x + alpha*d
            r = r - alpha*q
            s = M*r
            delta_old = delta_new
            delta_new = numpy.dot(r, s)
            d = s + (delta_new / delta_old) * d

            if i >= int(numpy.ceil(1.3**len(backtracking))):
                backtracking.append((self.quick_cost(x), x.copy(), i))

            # value of the quadratic model at x, used by Martens' stopping rule
            phi_i = -0.5 * numpy.dot(x, r + b)
            phi.append(phi_i)

            progress = ' [CG iter %i, phi=%+.5f, cost=%.5f]' % (i, phi_i, backtracking[-1][0])
            self._log('\b'*backspaces + progress)
            backspaces = len(progress)

            k = max(10, i // 10)  # integer: used as a list index below
            if i > k and phi_i < 0 and (phi_i - phi[-k-1]) / phi_i < k*0.0005:
                break

        self.cg_last_x = x.copy()

        if self.global_backtracking:
            j = int(numpy.argmin([entry[0] for entry in backtracking]))
        else:
            # Martens' heuristic: step back while the previous candidate was better
            j = len(backtracking) - 1
            while j > 0 and backtracking[j-1][0] < backtracking[j][0]:
                j -= 1
        self._log(' backtracked %i/%i' % (backtracking[j][2], i) + ' ')
        return backtracking[j] + (i,)

    def flat_to_list(self, vector):
        '''Split a flat vector into arrays shaped like the parameters.'''
        return [vector[position:position + size].reshape(shape)
                for shape, size, position in zip(self.shapes, self.sizes, self.positions)]

    def list_to_flat(self, l):
        '''Concatenate a list of arrays into a single flat vector.'''
        return numpy.concatenate([numpy.ravel(i) for i in l])

    def batch_Gv(self, vector, lambda_=None):
        '''Return (G + lambda_*I) vector averaged over the CG batches.

        lambda_ defaults to self.lambda_; the structural damping coefficient
        passed to function_Gv is lambda_*self.mu, so lambda_=0 disables both.
        Batches may be yielded as lists or tuples.'''
        v = self.flat_to_list(vector)
        if lambda_ is None:
            lambda_ = self.lambda_
        result = lambda_*vector  # Tikhonov damping
        for inputs in self.cg_dataset.iterate(False):
            result += self.list_to_flat(self.function_Gv(*(list(inputs) + v + [lambda_*self.mu]))) / self.cg_dataset.number_batches
        return result

    def train(self, gradient_dataset, cg_dataset, initial_lambda=0.1, mu=0.03, global_backtracking=False, preconditioner=False, max_cg_iterations=250, num_updates=100, validation=None, validation_frequency=1, patience=numpy.inf, save_progress=None):
        '''Performs HF training.

        gradient_dataset : SequenceDataset-like object
            Defines batches used to compute the gradient.
            The `iterate(update=True)` method should yield shuffled training examples
            (tuples of variables matching your graph inputs).
            The same examples MUST be returned between multiple calls to iterate(),
            unless update is True, in which case the next batch should be different.
        cg_dataset : SequenceDataset-like object
            Defines batches used to compute CG iterations.
        initial_lambda : float
            Initial value of the Tikhonov damping coefficient.
        mu : float
            Coefficient for structural damping.
        global_backtracking : Boolean
            If True, backtracks as much as necessary to find the global minimum among
            all CG iterates. Else, Martens' heuristic is used.
        preconditioner : Boolean
            Whether to use Martens' preconditioner.
        max_cg_iterations : int
            CG stops after this many iterations regardless of the stopping criterion.
        num_updates : int
            Training stops after this many parameter updates regardless of `patience`.
        validation: SequenceDataset object, (lambda : tuple) callback, or None
            If a SequenceDataset object is provided, the training monitoring costs
            will be evaluated on that validation dataset.
            If a callback is provided, it should return a list of validation costs
            for monitoring, the first of which is also used for early stopping.
            If None, no early stopping nor validation monitoring is performed.
            Any other value raises TypeError.
        validation_frequency: int
            Validation is performed every `validation_frequency` updates.
        patience: int
            Training stops after `patience` updates without improvement in validation
            cost.
        save_progress: string or None
            A checkpoint is automatically saved at this location after each update.
            Call the `train` function again with the same parameters to resume
            training. Checkpoints are pickles: only resume from files you trust.

        Returns the best parameter values found (or the current ones if no
        validation improvement was recorded).'''
        try:
            import cPickle as pickle  # Python 2
        except ImportError:
            import pickle  # Python 3

        is_dataset = validation is not None and validation.__class__.__name__ == 'SequenceDataset'
        if validation is not None and not is_dataset and not callable(validation):
            raise TypeError('validation must be a SequenceDataset, a callable or None')

        self.lambda_ = initial_lambda
        self.mu = mu
        self.global_backtracking = global_backtracking
        self.cg_dataset = cg_dataset
        self.preconditioner = preconditioner
        self.max_cg_iterations = max_cg_iterations

        best = [0, numpy.inf, None]  # iteration, cost, params
        first_iteration = 1

        if isinstance(save_progress, str) and os.path.isfile(save_progress):
            # NOTE: unpickling can execute arbitrary code; only load trusted checkpoints.
            with open(save_progress, 'rb') as checkpoint:  # written with a binary protocol
                save = pickle.load(checkpoint)
            self.cg_last_x, best, self.lambda_, first_iteration, init_p = save
            best = list(best)
            first_iteration += 1
            for param, value in zip(self.p, init_p):
                param.set_value(value)
            self._log('* recovered saved model\n')

        try:
            for u in range(first_iteration, 1 + num_updates):
                self._log('update %i/%i, ' % (u, num_updates))

                gradient = numpy.zeros(sum(self.sizes), dtype=theano.config.floatX)
                costs = []
                for inputs in gradient_dataset.iterate(update=True):
                    result = self.f_gc(*inputs)
                    gradient += self.list_to_flat(result[:len(self.p)]) / gradient_dataset.number_batches
                    costs.append(result[len(self.p):])
                self._log('cost= %s ' % (numpy.mean(costs, axis=0),))
                self._log('lambda=%.5f,' % self.lambda_)

                after_cost, flat_delta, backtracking, num_cg_iterations = self.cg(-gradient)
                # predicted change g'd + 0.5 d'Gd (lambda_=0 disables all damping)
                delta_cost = numpy.dot(flat_delta, gradient + 0.5*self.batch_Gv(flat_delta, lambda_=0))
                before_cost = self.quick_cost()
                for param, delta in zip(self.p, self.flat_to_list(flat_delta)):
                    param.set_value(param.get_value() + delta)
                cg_dataset.update()

                rho = (after_cost - before_cost) / delta_cost  # Levenberg-Marquardt
                if rho < 0.25:
                    self.lambda_ *= 1.5
                elif rho > 0.75:
                    self.lambda_ /= 1.5

                if validation is not None and u % validation_frequency == 0:
                    if is_dataset:
                        costs = numpy.mean([self.f_cost(*inputs) for inputs in validation.iterate()], axis=0)
                    else:
                        costs = validation()
                    self._log('validation= %s ' % (costs,))
                    if costs[0] < best[1]:
                        best = [u, costs[0], [param.get_value().copy() for param in self.p]]
                        self._log('*NEW BEST ')

                if isinstance(save_progress, str):
                    # do not save dataset states
                    save = self.cg_last_x, best, self.lambda_, u, [param.get_value().copy() for param in self.p]
                    with open(save_progress, 'wb') as checkpoint:
                        pickle.dump(save, checkpoint, pickle.HIGHEST_PROTOCOL)

                if u - best[0] > patience:
                    self._log('PATIENCE ELAPSED, BAILING OUT\n')
                    break

                self._log('\n')
        except KeyboardInterrupt:
            self._log('Interrupted by user.\n')

        if best[2] is None:
            best[2] = [param.get_value().copy() for param in self.p]
        return best[2]
class SequenceDataset:
    '''Slices, shuffles and manages a small dataset for the HF optimizer.

    Mini-batches are stored in `self.items` (each one a list with one entry
    per graph input). `iterate` yields `number_batches` consecutive items
    starting at `current_batch`, wrapping around the end of the list, and
    `update` moves that window forward, reshuffling when it runs out.'''

    def __init__(self, data, batch_size, number_batches, minimum_size=10):
        '''SequenceDataset __init__

        data : list of lists of numpy arrays
            Your dataset will be provided as a list (one list for each graph input) of
            variable-length tensors that will be used as mini-batches. Typically, each
            tensor is a sequence or a set of examples.
        batch_size : int or None
            If an int, the mini-batches will be further split in chunks of length
            `batch_size`. This is useful for slicing subsequences or provide the full
            dataset in a single tensor to be split here. All tensors in `data` must
            then have the same leading dimension.
        number_batches : int
            Number of mini-batches over which you iterate to compute a gradient or
            Gauss-Newton matrix product.
        minimum_size : int
            Only used when `batch_size` is an int: a chunk is created only if it
            starts at least `minimum_size` elements before the end of its tensor,
            so with batch_size >= minimum_size no chunk is shorter than this.

        Raises ValueError if no mini-batch is produced, since iterating an empty
        dataset would otherwise fail later with a modulo-by-zero error.'''
        self.current_batch = 0
        self.number_batches = number_batches
        self.items = []
        num_inputs = len(data)
        for i_sequence in range(len(data[0])):
            if batch_size is None:
                # keep each tensor whole as its own mini-batch
                self.items.append([data[i][i_sequence] for i in range(num_inputs)])
                continue
            # the first input's leading dimension decides where chunks may start
            last_start = len(data[0][i_sequence]) - minimum_size
            for i_step in range(0, last_start + 1, batch_size):
                self.items.append([data[i][i_sequence][i_step:i_step + batch_size]
                                   for i in range(num_inputs)])
        if not self.items:
            raise ValueError('SequenceDataset: no mini-batches could be built from `data` '
                             '(check batch_size / minimum_size)')
        self.shuffle()

    def shuffle(self):
        '''Shuffle the stored mini-batches in place using numpy's global RNG.'''
        numpy.random.shuffle(self.items)

    def iterate(self, update=True):
        '''Yield `number_batches` mini-batches starting at `current_batch`.

        Indices wrap around modulo len(self.items). If `update` is True, the
        window is advanced (see `update`) once the generator is exhausted.'''
        for b in range(self.number_batches):
            yield self.items[(self.current_batch + b) % len(self.items)]
        if update:
            self.update()

    def update(self):
        '''Advance the window by `number_batches`; reshuffle and restart at 0
        when the next window would reach the end of `self.items`.'''
        if self.current_batch + self.number_batches >= len(self.items):
            self.shuffle()
            self.current_batch = 0
        else:
            self.current_batch += self.number_batches