From 7d5ad0ebee834da03c83608ce75f486b8ad70cd6 Mon Sep 17 00:00:00 2001 From: Martin Krasser Date: Sat, 6 Oct 2018 09:14:55 +0200 Subject: [PATCH 1/2] Copy keras directory to keras_2 and rename keras to keras_1 --- {keras => keras_1}/README.md | 0 {keras => keras_1}/cifar10_cnn.py | 0 {keras => keras_1}/weightnorm.py | 0 keras_2/README.md | 6 + keras_2/cifar10_cnn.py | 104 +++++++++++++++ keras_2/weightnorm.py | 210 ++++++++++++++++++++++++++++++ 6 files changed, 320 insertions(+) rename {keras => keras_1}/README.md (100%) rename {keras => keras_1}/cifar10_cnn.py (100%) rename {keras => keras_1}/weightnorm.py (100%) create mode 100644 keras_2/README.md create mode 100755 keras_2/cifar10_cnn.py create mode 100644 keras_2/weightnorm.py diff --git a/keras/README.md b/keras_1/README.md similarity index 100% rename from keras/README.md rename to keras_1/README.md diff --git a/keras/cifar10_cnn.py b/keras_1/cifar10_cnn.py similarity index 100% rename from keras/cifar10_cnn.py rename to keras_1/cifar10_cnn.py diff --git a/keras/weightnorm.py b/keras_1/weightnorm.py similarity index 100% rename from keras/weightnorm.py rename to keras_1/weightnorm.py diff --git a/keras_2/README.md b/keras_2/README.md new file mode 100644 index 0000000..b59eeae --- /dev/null +++ b/keras_2/README.md @@ -0,0 +1,6 @@ + +# Weight Normalization using Keras + +Example code for using Weight Normalization using [Keras](https://keras.io). + +```cifar10_cnn.py``` contains the standard CIFAR-10 example from Keras, with lines 64 and 69 edited to include weight normalization and data dependent initialization. \ No newline at end of file diff --git a/keras_2/cifar10_cnn.py b/keras_2/cifar10_cnn.py new file mode 100755 index 0000000..393f0d6 --- /dev/null +++ b/keras_2/cifar10_cnn.py @@ -0,0 +1,104 @@ +''' +CIFAR-10 example from https://github.com/fchollet/keras/blob/master/examples/cifar10_cnn.py +Now with weight normalization. Lines 64 and 69 contain the changes w.r.t. original. 
+''' + +from __future__ import print_function +from keras.datasets import cifar10 +from keras.preprocessing.image import ImageDataGenerator +from keras.models import Sequential +from keras.layers import Dense, Dropout, Activation, Flatten +from keras.layers import Convolution2D, MaxPooling2D +from keras.utils import np_utils + +batch_size = 32 +nb_classes = 10 +nb_epoch = 200 +data_augmentation = True + +# input image dimensions +img_rows, img_cols = 32, 32 +# the CIFAR10 images are RGB +img_channels = 3 + +# the data, shuffled and split between train and test sets +(X_train, y_train), (X_test, y_test) = cifar10.load_data() +print('X_train shape:', X_train.shape) +print(X_train.shape[0], 'train samples') +print(X_test.shape[0], 'test samples') +X_train = X_train.astype('float32') +X_test = X_test.astype('float32') +X_train /= 255 +X_test /= 255 + +# convert class vectors to binary class matrices +Y_train = np_utils.to_categorical(y_train, nb_classes) +Y_test = np_utils.to_categorical(y_test, nb_classes) + +model = Sequential() + +model.add(Convolution2D(32, 3, 3, border_mode='same', + input_shape=X_train.shape[1:])) +model.add(Activation('relu')) +model.add(Convolution2D(32, 3, 3)) +model.add(Activation('relu')) +model.add(MaxPooling2D(pool_size=(2, 2))) +model.add(Dropout(0.25)) + +model.add(Convolution2D(64, 3, 3, border_mode='same')) +model.add(Activation('relu')) +model.add(Convolution2D(64, 3, 3)) +model.add(Activation('relu')) +model.add(MaxPooling2D(pool_size=(2, 2))) +model.add(Dropout(0.25)) + +model.add(Flatten()) +model.add(Dense(512)) +model.add(Activation('relu')) +model.add(Dropout(0.5)) +model.add(Dense(nb_classes)) +model.add(Activation('softmax')) + +# let's train the model using SGD + momentum (how original). EDIT: now with weight normalization, so slightly more original ;-) +from weightnorm import SGDWithWeightnorm +sgd_wn = SGDWithWeightnorm(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) +model.compile(loss='categorical_crossentropy',optimizer=sgd_wn,metrics=['accuracy']) + +# data based initialization of parameters +from weightnorm import data_based_init +data_based_init(model, X_train[:100]) + + +if not data_augmentation: + print('Not using data augmentation.') + model.fit(X_train, Y_train, + batch_size=batch_size, + nb_epoch=nb_epoch, + validation_data=(X_test, Y_test), + shuffle=True) +else: + print('Using real-time data augmentation.') + + # this will do preprocessing and realtime data augmentation + datagen = ImageDataGenerator( + featurewise_center=False, # set input mean to 0 over the dataset + samplewise_center=False, # set each sample mean to 0 + featurewise_std_normalization=False, # divide inputs by std of the dataset + samplewise_std_normalization=False, # divide each input by its std + zca_whitening=False, # apply ZCA whitening + rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180) + width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) + height_shift_range=0.1, # randomly shift images vertically (fraction of total height) + horizontal_flip=True, # randomly flip images + vertical_flip=False) # randomly flip images + + # compute quantities required for featurewise normalization + # (std, mean, and principal components if ZCA whitening is applied) + datagen.fit(X_train) + + # fit the model on the batches generated by datagen.flow() + model.fit_generator(datagen.flow(X_train, Y_train, + batch_size=batch_size), + samples_per_epoch=X_train.shape[0], + nb_epoch=nb_epoch, + validation_data=(X_test, 
Y_test)) diff --git a/keras_2/weightnorm.py b/keras_2/weightnorm.py new file mode 100644 index 0000000..23fb469 --- /dev/null +++ b/keras_2/weightnorm.py @@ -0,0 +1,210 @@ +from keras import backend as K +from keras.optimizers import SGD,Adam +import tensorflow as tf + +# adapted from keras.optimizers.SGD +class SGDWithWeightnorm(SGD): + def get_updates(self, params, constraints, loss): + grads = self.get_gradients(loss, params) + self.updates = [] + + lr = self.lr + if self.initial_decay > 0: + lr *= (1. / (1. + self.decay * self.iterations)) + self.updates .append(K.update_add(self.iterations, 1)) + + # momentum + shapes = [K.get_variable_shape(p) for p in params] + moments = [K.zeros(shape) for shape in shapes] + self.weights = [self.iterations] + moments + for p, g, m in zip(params, grads, moments): + + # if a weight tensor (len > 1) use weight normalized parameterization + ps = K.get_variable_shape(p) + if len(ps) > 1: + + # get weight normalization parameters + V, V_norm, V_scaler, g_param, grad_g, grad_V = get_weightnorm_params_and_grads(p, g) + + # momentum container for the 'g' parameter + V_scaler_shape = K.get_variable_shape(V_scaler) + m_g = K.zeros(V_scaler_shape) + + # update g parameters + v_g = self.momentum * m_g - lr * grad_g # velocity + self.updates.append(K.update(m_g, v_g)) + if self.nesterov: + new_g_param = g_param + self.momentum * v_g - lr * grad_g + else: + new_g_param = g_param + v_g + + # update V parameters + v_v = self.momentum * m - lr * grad_V # velocity + self.updates.append(K.update(m, v_v)) + if self.nesterov: + new_V_param = V + self.momentum * v_v - lr * grad_V + else: + new_V_param = V + v_v + + # if there are constraints we apply them to V, not W + if p in constraints: + c = constraints[p] + new_V_param = c(new_V_param) + + # wn param updates --> W updates + add_weightnorm_param_updates(self.updates, new_V_param, new_g_param, p, V_scaler) + + else: # normal SGD with momentum + v = self.momentum * m - lr * g # velocity + self.updates.append(K.update(m, v)) + + if self.nesterov: + new_p = p + self.momentum * v - lr * g + else: + new_p = p + v + + # apply constraints + if p in constraints: + c = constraints[p] + new_p = c(new_p) + + self.updates.append(K.update(p, new_p)) + return self.updates + +# adapted from keras.optimizers.Adam +class AdamWithWeightnorm(Adam): + def get_updates(self, params, constraints, loss): + grads = self.get_gradients(loss, params) + self.updates = [K.update_add(self.iterations, 1)] + + lr = self.lr + if self.initial_decay > 0: + lr *= (1. / (1. + self.decay * self.iterations)) + + t = self.iterations + 1 + lr_t = lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t)) + + shapes = [K.get_variable_shape(p) for p in params] + ms = [K.zeros(shape) for shape in shapes] + vs = [K.zeros(shape) for shape in shapes] + self.weights = [self.iterations] + ms + vs + + for p, g, m, v in zip(params, grads, ms, vs): + + # if a weight tensor (len > 1) use weight normalized parameterization + # this is the only part changed w.r.t. keras.optimizers.Adam + ps = K.get_variable_shape(p) + if len(ps)>1: + + # get weight normalization parameters + V, V_norm, V_scaler, g_param, grad_g, grad_V = get_weightnorm_params_and_grads(p, g) + + # Adam containers for the 'g' parameter + V_scaler_shape = K.get_variable_shape(V_scaler) + m_g = K.zeros(V_scaler_shape) + v_g = K.zeros(V_scaler_shape) + + # update g parameters + m_g_t = (self.beta_1 * m_g) + (1. - self.beta_1) * grad_g + v_g_t = (self.beta_2 * v_g) + (1. 
- self.beta_2) * K.square(grad_g) + new_g_param = g_param - lr_t * m_g_t / (K.sqrt(v_g_t) + self.epsilon) + self.updates.append(K.update(m_g, m_g_t)) + self.updates.append(K.update(v_g, v_g_t)) + + # update V parameters + m_t = (self.beta_1 * m) + (1. - self.beta_1) * grad_V + v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(grad_V) + new_V_param = V - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) + self.updates.append(K.update(m, m_t)) + self.updates.append(K.update(v, v_t)) + + # if there are constraints we apply them to V, not W + if p in constraints: + c = constraints[p] + new_V_param = c(new_V_param) + + # wn param updates --> W updates + add_weightnorm_param_updates(self.updates, new_V_param, new_g_param, p, V_scaler) + + else: # do optimization normally + m_t = (self.beta_1 * m) + (1. - self.beta_1) * g + v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g) + p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) + + self.updates.append(K.update(m, m_t)) + self.updates.append(K.update(v, v_t)) + + new_p = p_t + # apply constraints + if p in constraints: + c = constraints[p] + new_p = c(new_p) + self.updates.append(K.update(p, new_p)) + return self.updates + + +def get_weightnorm_params_and_grads(p, g): + ps = K.get_variable_shape(p) + + # construct weight scaler: V_scaler = g/||V|| + V_scaler_shape = (ps[-1],) # assumes we're using tensorflow! + V_scaler = K.ones(V_scaler_shape) # init to ones, so effective parameters don't change + + # get V parameters = ||V||/g * W + norm_axes = [i for i in range(len(ps) - 1)] + V = p / tf.reshape(V_scaler, [1] * len(norm_axes) + [-1]) + + # split V_scaler into ||V|| and g parameters + V_norm = tf.sqrt(tf.reduce_sum(tf.square(V), norm_axes)) + g_param = V_scaler * V_norm + + # get grad in V,g parameters + grad_g = tf.reduce_sum(g * V, norm_axes) / V_norm + grad_V = tf.reshape(V_scaler, [1] * len(norm_axes) + [-1]) * \ + (g - tf.reshape(grad_g / V_norm, [1] * len(norm_axes) + [-1]) * V) + + return V, V_norm, V_scaler, g_param, grad_g, grad_V + + +def add_weightnorm_param_updates(updates, new_V_param, new_g_param, W, V_scaler): + ps = K.get_variable_shape(new_V_param) + norm_axes = [i for i in range(len(ps) - 1)] + + # update W and V_scaler + new_V_norm = tf.sqrt(tf.reduce_sum(tf.square(new_V_param), norm_axes)) + new_V_scaler = new_g_param / new_V_norm + new_W = tf.reshape(new_V_scaler, [1] * len(norm_axes) + [-1]) * new_V_param + updates.append(K.update(W, new_W)) + updates.append(K.update(V_scaler, new_V_scaler)) + + +# data based initialization for a given Keras model +def data_based_init(model, input): + + # input can be dict, numpy array, or list of numpy arrays + if type(input) is dict: + feed_dict = input + elif type(input) is list: + feed_dict = {tf_inp: np_inp for tf_inp,np_inp in zip(model.inputs,input)} + else: + feed_dict = {model.inputs[0]: input} + + # add learning phase if required + if model.uses_learning_phase and K.learning_phase() not in feed_dict: + feed_dict.update({K.learning_phase(): 1}) + + # get all layer name, output, weight, bias tuples + layer_output_weight_bias = [] + for l in model.layers: + if hasattr(l, 'W') and hasattr(l, 'b'): + assert(l.built) + layer_output_weight_bias.append( (l.name,l.get_output_at(0),l.W,l.b) ) # if more than one node, only use the first + + # iterate over our list and do data dependent init + sess = K.get_session() + for l,o,W,b in layer_output_weight_bias: + print('Performing data dependent initialization for layer ' + l) + m,v = tf.nn.moments(o, [i for i in 
range(len(o.get_shape())-1)]) + s = tf.sqrt(v + 1e-10) + updates = tf.group(W.assign(W/tf.reshape(s,[1]*(len(W.get_shape())-1)+[-1])), b.assign((b-m)/s)) + sess.run(updates, feed_dict) From fd42f2cfd249141bf073a4a28544b2c481de5ac6 Mon Sep 17 00:00:00 2001 From: Martin Krasser Date: Sat, 6 Oct 2018 09:16:19 +0200 Subject: [PATCH 2/2] Compatibility with Keras 2.2.3 and Tensorflow 1.11.0 --- keras_2/cifar10_cnn.py | 151 ++++++++++++++++++++++++++--------------- keras_2/weightnorm.py | 36 +++++----- 2 files changed, 113 insertions(+), 74 deletions(-) mode change 100755 => 100644 keras_2/cifar10_cnn.py diff --git a/keras_2/cifar10_cnn.py b/keras_2/cifar10_cnn.py old mode 100755 new mode 100644 index 393f0d6..8dd51bc --- a/keras_2/cifar10_cnn.py +++ b/keras_2/cifar10_cnn.py @@ -1,53 +1,56 @@ ''' -CIFAR-10 example from https://github.com/fchollet/keras/blob/master/examples/cifar10_cnn.py -Now with weight normalization. Lines 64 and 69 contain the changes w.r.t. original. +CIFAR-10 example from https://github.com/keras-team/keras/blob/master/examples/cifar10_cnn.py +Now with weight normalization. Lines 64-65 and 78-79 contain the changes w.r.t. original. ''' from __future__ import print_function +import keras from keras.datasets import cifar10 from keras.preprocessing.image import ImageDataGenerator from keras.models import Sequential from keras.layers import Dense, Dropout, Activation, Flatten -from keras.layers import Convolution2D, MaxPooling2D -from keras.utils import np_utils +from keras.layers import Conv2D, MaxPooling2D +import os + +'''Train a simple deep CNN on the CIFAR10 small images dataset. + +It gets to 75% validation accuracy in 25 epochs, and 79% after 50 epochs. +(it's still underfitting at that point, though). + +With weight normalization, a validation accuracy of 75% is already reached +after 10 epochs. +''' batch_size = 32 -nb_classes = 10 -nb_epoch = 200 +num_classes = 10 +epochs = 100 data_augmentation = True +num_predictions = 20 +save_dir = os.path.join(os.getcwd(), 'saved_models') +model_name = 'keras_cifar10_trained_model.h5' -# input image dimensions -img_rows, img_cols = 32, 32 -# the CIFAR10 images are RGB -img_channels = 3 - -# the data, shuffled and split between train and test sets -(X_train, y_train), (X_test, y_test) = cifar10.load_data() -print('X_train shape:', X_train.shape) -print(X_train.shape[0], 'train samples') -print(X_test.shape[0], 'test samples') -X_train = X_train.astype('float32') -X_test = X_test.astype('float32') -X_train /= 255 -X_test /= 255 - -# convert class vectors to binary class matrices -Y_train = np_utils.to_categorical(y_train, nb_classes) -Y_test = np_utils.to_categorical(y_test, nb_classes) +# The data, split between train and test sets: +(x_train, y_train), (x_test, y_test) = cifar10.load_data() +print('x_train shape:', x_train.shape) +print(x_train.shape[0], 'train samples') +print(x_test.shape[0], 'test samples') -model = Sequential() +# Convert class vectors to binary class matrices. 
+y_train = keras.utils.to_categorical(y_train, num_classes) +y_test = keras.utils.to_categorical(y_test, num_classes) -model.add(Convolution2D(32, 3, 3, border_mode='same', - input_shape=X_train.shape[1:])) +model = Sequential() +model.add(Conv2D(32, (3, 3), padding='same', + input_shape=x_train.shape[1:])) model.add(Activation('relu')) -model.add(Convolution2D(32, 3, 3)) +model.add(Conv2D(32, (3, 3))) model.add(Activation('relu')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.25)) -model.add(Convolution2D(64, 3, 3, border_mode='same')) +model.add(Conv2D(64, (3, 3), padding='same')) model.add(Activation('relu')) -model.add(Convolution2D(64, 3, 3)) +model.add(Conv2D(64, (3, 3))) model.add(Activation('relu')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.25)) @@ -56,49 +59,87 @@ model.add(Dense(512)) model.add(Activation('relu')) model.add(Dropout(0.5)) -model.add(Dense(nb_classes)) +model.add(Dense(num_classes)) model.add(Activation('softmax')) # let's train the model using SGD + momentum (how original). EDIT: now with weight normalization, so slightly more original ;-) -from weightnorm import SGDWithWeightnorm -sgd_wn = SGDWithWeightnorm(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) -model.compile(loss='categorical_crossentropy',optimizer=sgd_wn,metrics=['accuracy']) +from weightnorm import SGDWithWeightnorm, AdamWithWeightnorm +opt_wn = SGDWithWeightnorm(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) +#opt_wn = AdamWithWeightnorm(lr=0.001, decay=1e-6) + +# Let's train the model using RMSprop +model.compile(loss='categorical_crossentropy', + optimizer=opt_wn, + metrics=['accuracy']) + +x_train = x_train.astype('float32') +x_test = x_test.astype('float32') +x_train /= 255 +x_test /= 255 # data based initialization of parameters from weightnorm import data_based_init -data_based_init(model, X_train[:100]) - +data_based_init(model, x_train[:100]) if not data_augmentation: print('Not using data augmentation.') - model.fit(X_train, Y_train, + model.fit(x_train, y_train, batch_size=batch_size, - nb_epoch=nb_epoch, - validation_data=(X_test, Y_test), + epochs=epochs, + validation_data=(x_test, y_test), shuffle=True) else: print('Using real-time data augmentation.') - - # this will do preprocessing and realtime data augmentation + # This will do preprocessing and realtime data augmentation: datagen = ImageDataGenerator( featurewise_center=False, # set input mean to 0 over the dataset samplewise_center=False, # set each sample mean to 0 featurewise_std_normalization=False, # divide inputs by std of the dataset samplewise_std_normalization=False, # divide each input by its std zca_whitening=False, # apply ZCA whitening + zca_epsilon=1e-06, # epsilon for ZCA whitening rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180) - width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) - height_shift_range=0.1, # randomly shift images vertically (fraction of total height) + # randomly shift images horizontally (fraction of total width) + width_shift_range=0.1, + # randomly shift images vertically (fraction of total height) + height_shift_range=0.1, + shear_range=0., # set range for random shear + zoom_range=0., # set range for random zoom + channel_shift_range=0., # set range for random channel shifts + # set mode for filling points outside the input boundaries + fill_mode='nearest', + cval=0., # value used for fill_mode = "constant" horizontal_flip=True, # randomly flip images - vertical_flip=False) # randomly flip 
images - - # compute quantities required for featurewise normalization - # (std, mean, and principal components if ZCA whitening is applied) - datagen.fit(X_train) - - # fit the model on the batches generated by datagen.flow() - model.fit_generator(datagen.flow(X_train, Y_train, - batch_size=batch_size), - samples_per_epoch=X_train.shape[0], - nb_epoch=nb_epoch, - validation_data=(X_test, Y_test)) + vertical_flip=False, # randomly flip images + # set rescaling factor (applied before any other transformation) + rescale=None, + # set function that will be applied on each input + preprocessing_function=None, + # image data format, either "channels_first" or "channels_last" + data_format=None, + # fraction of images reserved for validation (strictly between 0 and 1) + validation_split=0.0) + + # Compute quantities required for feature-wise normalization + # (std, mean, and principal components if ZCA whitening is applied). + datagen.fit(x_train) + + # Fit the model on the batches generated by datagen.flow(). + model.fit_generator(datagen.flow(x_train, y_train, + batch_size=batch_size), + epochs=epochs, + steps_per_epoch=len(x_train)/batch_size, + validation_data=(x_test, y_test), + workers=4) + +# Save model and weights +if not os.path.isdir(save_dir): + os.makedirs(save_dir) +model_path = os.path.join(save_dir, model_name) +model.save(model_path) +print('Saved trained model at %s ' % model_path) + +# Score trained model. +scores = model.evaluate(x_test, y_test, verbose=1) +print('Test loss:', scores[0]) +print('Test accuracy:', scores[1]) \ No newline at end of file diff --git a/keras_2/weightnorm.py b/keras_2/weightnorm.py index 23fb469..dec50b1 100644 --- a/keras_2/weightnorm.py +++ b/keras_2/weightnorm.py @@ -4,13 +4,13 @@ # adapted from keras.optimizers.SGD class SGDWithWeightnorm(SGD): - def get_updates(self, params, constraints, loss): + def get_updates(self, loss, params): grads = self.get_gradients(loss, params) self.updates = [] lr = self.lr if self.initial_decay > 0: - lr *= (1. / (1. + self.decay * self.iterations)) + lr *= (1. / (1. + self.decay * K.cast(self.iterations, K.floatx()))) self.updates .append(K.update_add(self.iterations, 1)) # momentum @@ -47,9 +47,8 @@ def get_updates(self, params, constraints, loss): new_V_param = V + v_v # if there are constraints we apply them to V, not W - if p in constraints: - c = constraints[p] - new_V_param = c(new_V_param) + if getattr(p, 'constraint', None) is not None: + new_V_param = p.constraint(new_V_param) # wn param updates --> W updates add_weightnorm_param_updates(self.updates, new_V_param, new_g_param, p, V_scaler) @@ -64,24 +63,23 @@ def get_updates(self, params, constraints, loss): new_p = p + v # apply constraints - if p in constraints: - c = constraints[p] - new_p = c(new_p) + if getattr(p, 'constraint', None) is not None: + new_p = p.constraint(new_p) self.updates.append(K.update(p, new_p)) return self.updates # adapted from keras.optimizers.Adam class AdamWithWeightnorm(Adam): - def get_updates(self, params, constraints, loss): + def get_updates(self, loss, params): grads = self.get_gradients(loss, params) self.updates = [K.update_add(self.iterations, 1)] lr = self.lr if self.initial_decay > 0: - lr *= (1. / (1. + self.decay * self.iterations)) + lr *= (1. / (1. + self.decay * K.cast(self.iterations, K.floatx()))) - t = self.iterations + 1 + t = K.cast(self.iterations + 1, K.floatx()) lr_t = lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. 
- K.pow(self.beta_1, t)) shapes = [K.get_variable_shape(p) for p in params] @@ -119,9 +117,8 @@ def get_updates(self, params, constraints, loss): self.updates.append(K.update(v, v_t)) # if there are constraints we apply them to V, not W - if p in constraints: - c = constraints[p] - new_V_param = c(new_V_param) + if getattr(p, 'constraint', None) is not None: + new_V_param = p.constraint(new_V_param) # wn param updates --> W updates add_weightnorm_param_updates(self.updates, new_V_param, new_g_param, p, V_scaler) @@ -136,9 +133,8 @@ def get_updates(self, params, constraints, loss): new_p = p_t # apply constraints - if p in constraints: - c = constraints[p] - new_p = c(new_p) + if getattr(p, 'constraint', None) is not None: + new_p = p.constraint(new_p) self.updates.append(K.update(p, new_p)) return self.updates @@ -196,9 +192,11 @@ def data_based_init(model, input): # get all layer name, output, weight, bias tuples layer_output_weight_bias = [] for l in model.layers: - if hasattr(l, 'W') and hasattr(l, 'b'): + trainable_weights = l.trainable_weights + if len(trainable_weights) == 2: + W,b = trainable_weights assert(l.built) - layer_output_weight_bias.append( (l.name,l.get_output_at(0),l.W,l.b) ) # if more than one node, only use the first + layer_output_weight_bias.append((l.name,l.get_output_at(0),W,b)) # if more than one node, only use the first # iterate over our list and do data dependent init sess = K.get_session()
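
# ---------------------------------------------------------------------------
# Usage sketch (not part of the patches above): a minimal, self-contained
# example of how the Keras 2 version of keras_2/weightnorm.py is intended to
# be used, mirroring the calls made in keras_2/cifar10_cnn.py after patch 2/2.
# The toy model and the random data below are illustrative assumptions only;
# the real example in the patch trains on CIFAR-10.
# ---------------------------------------------------------------------------
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation

from weightnorm import SGDWithWeightnorm, AdamWithWeightnorm, data_based_init

# Toy data standing in for a real dataset (assumption for illustration only).
x_train = np.random.rand(256, 20).astype('float32')
y_train = np.random.randint(0, 2, size=(256, 1)).astype('float32')

# Any weight tensor of rank > 1 (e.g. Dense kernels, Conv2D kernels) gets the
# weight-normalized V/g parameterization inside the optimizer; rank-1
# parameters such as biases are updated with plain SGD/Adam.
model = Sequential()
model.add(Dense(64, input_shape=(20,)))
model.add(Activation('relu'))
model.add(Dense(1))
model.add(Activation('sigmoid'))

# Drop-in replacement for keras.optimizers.SGD; AdamWithWeightnorm plays the
# same role for keras.optimizers.Adam.
opt_wn = SGDWithWeightnorm(lr=0.01, momentum=0.9, nesterov=True)
model.compile(loss='binary_crossentropy',
              optimizer=opt_wn,
              metrics=['accuracy'])

# Data-dependent initialization: run once on a small batch before training so
# that each layer's output starts roughly zero-mean / unit-variance.
data_based_init(model, x_train[:100])

model.fit(x_train, y_train, batch_size=32, epochs=2)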