Compatibility with Keras 2.2.3 and Tensorflow 1.11.0
krasserm committed Oct 6, 2018
1 parent 7d5ad0e commit fd42f2c
Showing 2 changed files with 113 additions and 74 deletions.
151 changes: 96 additions & 55 deletions keras_2/cifar10_cnn.py
100755 → 100644 (file mode changed)
@@ -1,53 +1,56 @@
'''
CIFAR-10 example from https://github.com/fchollet/keras/blob/master/examples/cifar10_cnn.py
Now with weight normalization. Lines 64 and 69 contain the changes w.r.t. original.
CIFAR-10 example from https://github.com/keras-team/keras/blob/master/examples/cifar10_cnn.py
Now with weight normalization. Lines 64-65 and 78-79 contain the changes w.r.t. original.
'''

from __future__ import print_function
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils
from keras.layers import Conv2D, MaxPooling2D
import os

'''Train a simple deep CNN on the CIFAR10 small images dataset.
It gets to 75% validation accuracy in 25 epochs, and 79% after 50 epochs.
(it's still underfitting at that point, though).
With weight normalization, a validation accuracy of 75% is already reached
after 10 epochs.
'''

batch_size = 32
nb_classes = 10
nb_epoch = 200
num_classes = 10
epochs = 100
data_augmentation = True
num_predictions = 20
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'keras_cifar10_trained_model.h5'

# input image dimensions
img_rows, img_cols = 32, 32
# the CIFAR10 images are RGB
img_channels = 3

# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255

# convert class vectors to binary class matrices
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)
# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

model = Sequential()
# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model.add(Convolution2D(32, 3, 3, border_mode='same',
input_shape=X_train.shape[1:]))
model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Convolution2D(32, 3, 3))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Convolution2D(64, 3, 3, border_mode='same'))
model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Convolution2D(64, 3, 3))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
@@ -56,49 +59,87 @@
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(nb_classes))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

# let's train the model using SGD + momentum (how original). EDIT: now with weight normalization, so slightly more original ;-)
from weightnorm import SGDWithWeightnorm
sgd_wn = SGDWithWeightnorm(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',optimizer=sgd_wn,metrics=['accuracy'])
from weightnorm import SGDWithWeightnorm, AdamWithWeightnorm
opt_wn = SGDWithWeightnorm(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
#opt_wn = AdamWithWeightnorm(lr=0.001, decay=1e-6)

# Let's train the model using RMSprop
model.compile(loss='categorical_crossentropy',
optimizer=opt_wn,
metrics=['accuracy'])

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

# data based initialization of parameters
from weightnorm import data_based_init
data_based_init(model, X_train[:100])

data_based_init(model, x_train[:100])

if not data_augmentation:
print('Not using data augmentation.')
model.fit(X_train, Y_train,
model.fit(x_train, y_train,
batch_size=batch_size,
nb_epoch=nb_epoch,
validation_data=(X_test, Y_test),
epochs=epochs,
validation_data=(x_test, y_test),
shuffle=True)
else:
print('Using real-time data augmentation.')

# this will do preprocessing and realtime data augmentation
# This will do preprocessing and realtime data augmentation:
datagen = ImageDataGenerator(
featurewise_center=False, # set input mean to 0 over the dataset
samplewise_center=False, # set each sample mean to 0
featurewise_std_normalization=False, # divide inputs by std of the dataset
samplewise_std_normalization=False, # divide each input by its std
zca_whitening=False, # apply ZCA whitening
zca_epsilon=1e-06, # epsilon for ZCA whitening
rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180)
width_shift_range=0.1, # randomly shift images horizontally (fraction of total width)
height_shift_range=0.1, # randomly shift images vertically (fraction of total height)
# randomly shift images horizontally (fraction of total width)
width_shift_range=0.1,
# randomly shift images vertically (fraction of total height)
height_shift_range=0.1,
shear_range=0., # set range for random shear
zoom_range=0., # set range for random zoom
channel_shift_range=0., # set range for random channel shifts
# set mode for filling points outside the input boundaries
fill_mode='nearest',
cval=0., # value used for fill_mode = "constant"
horizontal_flip=True, # randomly flip images
vertical_flip=False) # randomly flip images

# compute quantities required for featurewise normalization
# (std, mean, and principal components if ZCA whitening is applied)
datagen.fit(X_train)

# fit the model on the batches generated by datagen.flow()
model.fit_generator(datagen.flow(X_train, Y_train,
batch_size=batch_size),
samples_per_epoch=X_train.shape[0],
nb_epoch=nb_epoch,
validation_data=(X_test, Y_test))
vertical_flip=False, # randomly flip images
# set rescaling factor (applied before any other transformation)
rescale=None,
# set function that will be applied on each input
preprocessing_function=None,
# image data format, either "channels_first" or "channels_last"
data_format=None,
# fraction of images reserved for validation (strictly between 0 and 1)
validation_split=0.0)

# Compute quantities required for feature-wise normalization
# (std, mean, and principal components if ZCA whitening is applied).
datagen.fit(x_train)

# Fit the model on the batches generated by datagen.flow().
model.fit_generator(datagen.flow(x_train, y_train,
batch_size=batch_size),
epochs=epochs,
steps_per_epoch=len(x_train)/batch_size,
validation_data=(x_test, y_test),
workers=4)

# Save model and weights
if not os.path.isdir(save_dir):
os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)

# Score trained model.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
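
For quick reference, the weight-normalization-specific pieces of the updated script reduce to two calls: swapping in a weight-normalized optimizer and running the data-based initialization once before training. The sketch below shows them in isolation under stated assumptions — the toy Dense model and the random x_small batch are purely illustrative; SGDWithWeightnorm and data_based_init are the functions patched in keras_2/weightnorm.py below.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from weightnorm import SGDWithWeightnorm, data_based_init

# Hypothetical toy model; the CIFAR-10 CNN above works the same way.
model = Sequential([Dense(10, activation='softmax', input_shape=(32,))])

# Weight-normalized SGD as a drop-in replacement for keras.optimizers.SGD.
opt_wn = SGDWithWeightnorm(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=opt_wn, metrics=['accuracy'])

# Data-based initialization of the weight-norm parameters, run once on a
# small batch of preprocessed inputs before training starts.
x_small = np.random.rand(100, 32).astype('float32')
data_based_init(model, x_small)

# From here on, model.fit(...) or model.fit_generator(...) as usual.
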
36 changes: 17 additions & 19 deletions keras_2/weightnorm.py
@@ -4,13 +4,13 @@

# adapted from keras.optimizers.SGD
class SGDWithWeightnorm(SGD):
def get_updates(self, params, constraints, loss):
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
self.updates = []

lr = self.lr
if self.initial_decay > 0:
lr *= (1. / (1. + self.decay * self.iterations))
lr *= (1. / (1. + self.decay * K.cast(self.iterations, K.floatx())))
self.updates .append(K.update_add(self.iterations, 1))

# momentum
@@ -47,9 +47,8 @@ def get_updates(self, params, constraints, loss):
new_V_param = V + v_v

# if there are constraints we apply them to V, not W
if p in constraints:
c = constraints[p]
new_V_param = c(new_V_param)
if getattr(p, 'constraint', None) is not None:
new_V_param = p.constraint(new_V_param)

# wn param updates --> W updates
add_weightnorm_param_updates(self.updates, new_V_param, new_g_param, p, V_scaler)
@@ -64,24 +63,23 @@ def get_updates(self, params, constraints, loss):
new_p = p + v

# apply constraints
if p in constraints:
c = constraints[p]
new_p = c(new_p)
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)

self.updates.append(K.update(p, new_p))
return self.updates

# adapted from keras.optimizers.Adam
class AdamWithWeightnorm(Adam):
def get_updates(self, params, constraints, loss):
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
self.updates = [K.update_add(self.iterations, 1)]

lr = self.lr
if self.initial_decay > 0:
lr *= (1. / (1. + self.decay * self.iterations))
lr *= (1. / (1. + self.decay * K.cast(self.iterations, K.floatx())))

t = self.iterations + 1
t = K.cast(self.iterations + 1, K.floatx())
lr_t = lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))

shapes = [K.get_variable_shape(p) for p in params]
@@ -119,9 +117,8 @@ def get_updates(self, params, constraints, loss):
self.updates.append(K.update(v, v_t))

# if there are constraints we apply them to V, not W
if p in constraints:
c = constraints[p]
new_V_param = c(new_V_param)
if getattr(p, 'constraint', None) is not None:
new_V_param = p.constraint(new_V_param)

# wn param updates --> W updates
add_weightnorm_param_updates(self.updates, new_V_param, new_g_param, p, V_scaler)
@@ -136,9 +133,8 @@ def get_updates(self, params, constraints, loss):

new_p = p_t
# apply constraints
if p in constraints:
c = constraints[p]
new_p = c(new_p)
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
self.updates.append(K.update(p, new_p))
return self.updates

@@ -196,9 +192,11 @@ def data_based_init(model, input):
# get all layer name, output, weight, bias tuples
layer_output_weight_bias = []
for l in model.layers:
if hasattr(l, 'W') and hasattr(l, 'b'):
trainable_weights = l.trainable_weights
if len(trainable_weights) == 2:
W,b = trainable_weights
assert(l.built)
layer_output_weight_bias.append( (l.name,l.get_output_at(0),l.W,l.b) ) # if more than one node, only use the first
layer_output_weight_bias.append((l.name,l.get_output_at(0),W,b)) # if more than one node, only use the first

# iterate over our list and do data dependent init
sess = K.get_session()
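
For reference, these get_updates changes track the optimizer interface introduced in Keras 2: the method now receives (loss, params) instead of (params, constraints, loss), and per-variable constraints are read from the variable's constraint attribute. A minimal, hypothetical optimizer showing just that interface (PlainSGD is not part of this repository; it only illustrates the signature and the constraint check used above):

from keras import backend as K
from keras.optimizers import Optimizer

class PlainSGD(Optimizer):
    """Hypothetical minimal optimizer illustrating the Keras 2.2 interface."""
    def __init__(self, lr=0.01, **kwargs):
        super(PlainSGD, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')

    def get_updates(self, loss, params):  # Keras 1 signature was (params, constraints, loss)
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]
        for p, g in zip(params, grads):
            new_p = p - self.lr * g
            # Constraints now live on the variable itself rather than in a dict argument.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)
            self.updates.append(K.update(p, new_p))
        return self.updates

The K.cast(self.iterations, K.floatx()) added to the decay term above mirrors what Keras 2.2's own SGD and Adam do, since iterations is now an int64 variable that TensorFlow will not multiply directly with a float decay factor.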
