'''Trains a simple convnet on the MNIST dataset.

Gets to 99.25% test accuracy after 12 epochs
(there is still a lot of margin for parameter tuning).
16 seconds per epoch on a GRID K520 GPU.
'''

from __future__ import print_function

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
# Number of target classes: MNIST labels are the digits 0-9.
num_classes = 10

# Input image dimensions.
img_rows, img_cols = 28, 28

# The data, split between train and test sets.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add an explicit single-channel axis, placed where the active backend
# expects it; input_shape must agree with that layout for the first layer.
if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Convert the uint8 pixel values (0-255) to float32 and scale them to [0, 1];
# without the division the float cast alone has no effect on training.
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert integer class vectors to one-hot (binary class) matrices, as
# required by the categorical_crossentropy loss.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
# Weight initializer shared by the conv and hidden dense layers.
kernel_init = 'he_uniform'

model = Sequential()
# 16 3x3 filters. He initialization assumes ReLU units, so the conv layer
# uses a ReLU activation. No layer in this network has bias terms.
model.add(Conv2D(16, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape,
                 use_bias=False,
                 kernel_initializer=kernel_init))
model.add(MaxPooling2D(pool_size=(2, 2)))
# Collapse the pooled feature maps into a vector so the dense head produces
# one prediction per sample (matching the one-hot targets), rather than one
# per spatial position.
model.add(Flatten())
model.add(Dense(128, activation='relu', use_bias=False,
                kernel_initializer=kernel_init))
# Softmax over the classes for use with categorical cross-entropy.
model.add(Dense(num_classes, activation='softmax', use_bias=False))
# NOTE(review): batch_size was not defined anywhere in the visible script;
# 128 is an assumed value - confirm against the intended experiment.
batch_size = 128
# Follows the "12 epochs" figure stated in the module docstring.
epochs = 12

# metrics=['accuracy'] makes evaluate() return [loss, accuracy]; without it,
# evaluate() returns a bare scalar loss and score[1] below would fail.
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])

# NOTE: the test split doubles as validation data here, so the reported
# validation and test metrics are not independent estimates.
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
# Persist the trained weights to cnn_pool_he_100.h5 and the model
# architecture (the JSON string from model.to_json()) to
# cnn_pool_he_100.json.
model.save_weights('cnn_pool_he_100.h5')
architecture_json = model.to_json()
with open('cnn_pool_he_100.json', 'w') as json_out:
    json_out.write(architecture_json)