SNABSuite  0.x
Spiking Neural Architecture Benchmark Suite
convert_weights.py
Go to the documentation of this file.
1 import argparse
2 
# Command-line interface.  Arguments are parsed before Keras is imported
# further below, presumably so that usage errors / --help do not pay the
# Keras import cost.
parser = argparse.ArgumentParser(
    description='Converts a Keras model to json or msgpack interpreted by the respective SNAB')
parser.add_argument('-a', help="Keras model architecture as JSON", required=True)
parser.add_argument('-w', help="Weights in HDF5", required=True)
parser.add_argument('-o', help="Output file name (must end with .json or .msgpack)",
                    required=True)
args = parser.parse_args()
8 
9 from keras.models import Sequential, model_from_json
10 import json
11 import numpy as np
12 
13 from keras.constraints import Constraint
14 from keras import backend as K
class WeightClip(Constraint):
    """Keras weight constraint that clips every weight into ``[-c, c]``.

    Args:
        c: Clipping bound; weights are clipped element-wise to ``[-c, c]``.
    """

    def __init__(self, c=2, **kwargs):
        # get_config() below stores a 'name' entry next to 'c', and Keras'
        # deserializer rebuilds constraints with ``cls(**config)``.  Extra
        # keys are accepted and ignored so such a config can be loaded again
        # instead of failing with a TypeError.
        self.c = c

    def __call__(self, p):
        """Return the weight tensor ``p`` clipped element-wise to ``[-c, c]``."""
        return K.clip(p, -self.c, self.c)

    def get_config(self):
        """Return the serializable configuration of this constraint."""
        return {'name': self.__class__.__name__,
                'c': self.c}
27 
# Load the architecture once as text: it is both fed to Keras (to rebuild the
# model and load its HDF5 weights) and parsed as plain JSON (to walk the layer
# configs below).
with open(args.a) as architecture_file:
    model_json = architecture_file.read()
# Custom classes referenced in the JSON must be passed explicitly, otherwise
# Keras cannot resolve a WeightClip constraint when rebuilding the model.
model = model_from_json(model_json, custom_objects={'WeightClip': WeightClip})
model.load_weights(args.w)
netw = json.loads(model_json)

data = {}
# Slicing (instead of indexing) avoids an exception on a missing/empty
# version string; such a file simply triggers the warning.
if str(netw.get("keras_version", ""))[:1] != "2":
    print("Warning: script was written for Keras 2.3.0")
37 
def _layer_weights(layer, ind):
    """Return the first weight array (the kernel) of a layer as nested lists.

    The loaded Keras layer is looked up in ``model.layers`` by the name stored
    in the layer's JSON config.  If the config carries no name, or no loaded
    layer has that name, the layer at position ``ind`` of ``model.layers`` is
    used instead.

    Args:
        layer: Layer entry taken from ``netw["config"]["layers"]``.
        ind: Position of ``layer`` within ``netw["config"]["layers"]``.

    Returns:
        ``get_weights()[0]`` of the matched layer, converted via ``tolist()``.
    """
    layer_name = layer["config"].get("name")
    if layer_name is not None:
        for keras_layer in model.layers:
            if keras_layer.name == layer_name:
                return keras_layer.get_weights()[0].tolist()
    # NOTE(review): the positional fallback assumes the JSON layer list and
    # model.layers line up index-for-index — confirm for the models in use.
    return model.layers[ind].get_weights()[0].tolist()


# Translate every supported Keras layer into the plain dict consumed by SNAB;
# layer types without a SNAB counterpart are skipped with a notice.
data["netw"] = []
for ind, layer in enumerate(netw["config"]["layers"]):
    class_name = layer["class_name"]
    config = layer.get("config", {})
    layer_dict = {}
    if class_name == "Dropout":
        print("Ignoring dropout layer")
        continue
    elif class_name == "Flatten":
        print("Ignoring Flatten layer")
        continue
    elif class_name == "AveragePooling2D":
        print("Ignoring AveragePooling2D layer")
        continue
    elif class_name == "InputLayer":
        print("Ignoring InputLayer layer")
        continue
    elif class_name == "Dense":
        layer_dict["class_name"] = "Dense"
        layer_dict["size"] = config["units"]
        # Weight[i][j]: i input, j output
        layer_dict["weights"] = _layer_weights(layer, ind)
    elif class_name == "Conv2D":
        layer_dict["class_name"] = "Conv2D"
        layer_dict["size"] = config["filters"]
        # Only strides[0] is exported; presumably strides are square — confirm.
        layer_dict["stride"] = config["strides"][0]
        layer_dict["padding"] = config["padding"]
        if "batch_input_shape" in config:
            # Element 0 is the batch dimension; 1..3 are the per-sample dims.
            batch_input_shape = config["batch_input_shape"]
            layer_dict["input_shape_x"] = batch_input_shape[1]
            layer_dict["input_shape_y"] = batch_input_shape[2]
            layer_dict["input_shape_z"] = batch_input_shape[3]
        else:
            layer_dict["input_shape_x"] = None
            layer_dict["input_shape_y"] = None
            layer_dict["input_shape_z"] = None
        layer_dict["weights"] = _layer_weights(layer, ind)
    elif class_name == "MaxPooling2D":
        layer_dict["class_name"] = "MaxPooling2D"
        layer_dict["size"] = config["pool_size"]
        layer_dict["stride"] = config["strides"][0]
    else:
        raise RuntimeError("Unknown layer type " + class_name + "!")
    data["netw"].append(layer_dict)
93 
94 
# Serialize the converted network; the output format is chosen by the file
# extension of ``-o``.
out_path = args.o
if not out_path.endswith((".json", ".msgpack")):
    raise RuntimeError("Wrong file name! File must end with either .json or .msgpack!")

if out_path.endswith(".msgpack"):
    # msgpack is only required when that format is actually requested.
    import msgpack
    with open(out_path, 'wb') as out_file:
        msgpack.dump(data, out_file, use_single_float=True)
else:
    with open(out_path, 'w') as out_file:
        json.dump(data, out_file)