Last active
February 7, 2018 21:17
-
-
Save jhihn/8bf951e11c049767a4c90393cb723635 to your computer and use it in GitHub Desktop.
Save and load a Keras model (architecture and weights) as a single JSON file
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json

import numpy as np
def load_keras_model(filename):
    """Rebuild a Keras model (architecture + weights) from a JSON file.

    The file is expected to hold a model's ``to_json()`` output with an extra
    top-level ``"weights"`` key: a list with one nested-list entry per weight
    array, in ``model.get_weights()`` order. Any other extra top-level keys
    (e.g. training stats) are passed through to ``model_from_json`` untouched.

    Args:
        filename: Path to the JSON file to read.

    Returns:
        The reconstructed model with its weights restored via ``set_weights``.

    Raises:
        KeyError: If the JSON has no top-level ``"weights"`` entry.
    """
    # NOTE(review): ``model_from_json`` is not imported in the visible source;
    # presumably it is ``keras.models.model_from_json`` — confirm it is in scope.
    with open(filename, encoding="utf-8") as file:
        json_object = json.load(file)
    # Remove the (potentially very large) weight lists before handing the
    # architecture to Keras, so they are not re-parsed a second time.
    weights = json_object.pop("weights")
    model = model_from_json(json.dumps(json_object))
    model.set_weights([np.array(layer) for layer in weights])
    # The original discarded the model; callers need it back.
    return model
def save_keras_model(filename, model, stats=None):
    """Write a model's architecture, weights and stats to one JSON file.

    The output is the model's ``to_json()`` object with ``stats`` merged into
    its top level and a ``"weights"`` key holding every array from
    ``model.get_weights()`` converted to nested lists.

    Args:
        filename: Destination path; an existing file is overwritten.
        model: Object providing ``to_json()`` and ``get_weights()`` (a Keras
            model); ``get_weights()`` must return objects with ``tolist()``.
        stats: Mapping merged into the top level of the JSON object. When
            omitted or ``None``, the placeholders
            ``{'epochs': -1, 'loss': -1.0, 'accuracy': -1.0}`` are written.
            Because the merge happens after the architecture is loaded, a
            stat key such as ``"config"`` would overwrite the model config;
            a ``"weights"`` key is always replaced by the real weights.
    """
    if stats is None:
        # Build a fresh dict per call rather than sharing a mutable default.
        stats = {'epochs': -1, 'loss': -1.0, 'accuracy': -1.0}
    json_object = json.loads(model.to_json())
    json_object.update(stats)
    # tolist() turns each array into nested native Python numbers, which the
    # json module can serialize directly.
    json_object["weights"] = [layer.tolist() for layer in model.get_weights()]
    with open(filename, "w", encoding="utf-8") as file:
        # indent=1 matches the original indent=True (True == 1) byte-for-byte.
        json.dump(json_object, file, sort_keys=True, indent=1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment