SNABSuite  0.x
Spiking Neural Architecture Benchmark Suite
mnist_dnn_spikey.py
Go to the documentation of this file.
1 '''
2 Trains a simple deep NN on the downscaled MNIST dataset.
3 Adapted from
4 https://raw.githubusercontent.com/keras-team/keras/master/examples/mnist_mlp.py
5 '''
6 
7 
8 from __future__ import print_function
9 
10 import keras
11 from keras.datasets import mnist
12 from keras.models import Sequential
13 from keras.layers import Dense, Dropout, MaxPooling2D, Flatten, AveragePooling2D
14 from keras.optimizers import RMSprop, SGD, Adam
15 from keras import regularizers
16 
17 import numpy as np
# Training hyper-parameters.
batch_size = 128
num_classes = 10
epochs = 25

# The data, split between train and test sets.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add a trailing channel axis (28x28 greyscale -> 28x28x1) and rescale the
# 8-bit pixel values to [0, 1].  Using -1 instead of hard-coded sample
# counts (60000/10000) keeps this correct for any dataset size.
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255
print(x_train.shape, 'train samples')
print(x_test.shape, 'test samples')

# Convert integer class labels to one-hot ("binary class") matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
37 
# A manual 2x2 average down-scaling of the images used to live here; it has
# been superseded by the AveragePooling2D layer at the front of the model,
# so the images are now passed through unchanged.  The *_new aliases are
# kept because the training code below refers to them.
x_train_new = x_train
x_test_new = x_test
60 
# Network topology: a 3x3 average-pooling layer in front (down-scales the
# 28x28 input), then two fully connected layers.  Both Dense layers are
# bias-free with non-negative, L2-regularised weights -- presumably to ease
# mapping the trained net onto spiking hardware (see filename).


def _spiking_dense(units):
    """Bias-free Dense layer with non-negative, L2-regularised weights."""
    return Dense(units,
                 activation='relu',
                 use_bias=False,
                 kernel_constraint=keras.constraints.NonNeg(),
                 kernel_regularizer=regularizers.l2(0.001))


model = Sequential([
    AveragePooling2D(pool_size=(3, 3),
                     input_shape=(28, 28, 1),
                     data_format="channels_last"),
    Flatten(),
    _spiking_dense(100),
    _spiking_dense(num_classes),
])
model.summary()

# Hinge loss on the one-hot targets, optimised with Adam.
model.compile(loss='categorical_hinge',
              optimizer=Adam(lr=0.001),
              metrics=['accuracy'])
77 
# Train the network, validating against the test set after every epoch.
history = model.fit(x_train_new, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test_new, y_test))

# Final held-out evaluation; score is [loss, accuracy].
score = model.evaluate(x_test_new, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

# Persist weights and architecture separately so downstream tooling can
# rebuild the model from the JSON description alone.
model.save_weights('dnn_spikey.h5')
json_string = model.to_json()
# Renamed handle from 'file' to 'f': 'file' shadows the Python 2 builtin,
# and this script still targets Py2 (see the __future__ import).
with open('dnn_spikey.json', 'w') as f:
    f.write(json_string)