Created
September 18, 2019 09:12
-
-
Save jiankaiwang/ff63e786162225121040b4090bc015d6 to your computer and use it in GitHub Desktop.
Example of model specifications to inputs and outputs.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: jiankaiwang (https://jiankaiwang.no-ip.biz/)
@version:
    Tensorflow: 1.x (developed >= 1.13.2)
@description:
    Example of specifying explicit, named input and output tensors for a
    frozen model.
@dependency:
    OperateFrozenModel (TF1_FrozenModel.py, https://gist.github.com/jiankaiwang/24cc1bc8b38ce72bba73f7fb326f7b9e)
@changelog (main):
    2019-04: initial commit
    2019-09: released on gist.github.com
"""
import numpy as np
import tensorflow as tf

import OperateFrozenModel
# In[]

# pb_path:        the trained frozen graph (.pb) that is read and re-wired.
# merged_pb_path: where the re-wired ("merged") frozen graph is written.
# NOTE(review): these are machine-specific absolute paths — edit them
# before running on another machine.
pb_path = "/Users/jiankaiwang/Desktop/output_graph.pb"
merged_pb_path = "/Users/jiankaiwang/Desktop/merged_graph.pb"
# In[]

# Clear the global default graph (e.g. one left over from a previous
# interactive run of this cell).
tf.reset_default_graph()

# merged_graph: the trained graph wrapped with an explicitly named input
# placeholder ("input") and an explicitly named output node ("output").
merged_graph = tf.Graph()
with merged_graph.as_default():
    # defined a specific input
    # shape [None, 224, 224, 3] depends on your model's input
    inputs = tf.placeholder(tf.float32, [None, 224, 224, 3], "input")

    # import a trained model graph
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as fin:
        graph_def.ParseFromString(fin.read())

    # Re-route the trained graph's "Placeholder:0" to the new placeholder
    # and fetch its "final_result:0" tensor. Both names must match tensors
    # that exist in the graph stored at pb_path. name="" imports the nodes
    # without adding a name prefix.
    graph_outputs, = tf.import_graph_def(
        graph_def,
        input_map={"Placeholder:0": inputs},
        return_elements=["final_result:0"],
        name="")

    # defined a specific output
    outputs = tf.identity(graph_outputs, name="output")

# exported as a frozen model
# The session is bound to merged_graph explicitly so the export does not
# depend on whichever graph happens to be the default at this point.
with tf.Session(graph=merged_graph) as sess:
    # NOTE(review): presumably freezes the graph up to the "output" node and
    # writes it to merged_pb_path; confirm against OperateFrozenModel.
    state, graph = OperateFrozenModel.save_sess_into_frozen_model(
        sess, ["output"], merged_pb_path)
    print(state, graph)
# In[]

# sample input
# A single random image-shaped batch, matching the [None, 224, 224, 3]
# shape declared for the "input" placeholder of the merged graph.
sampled = np.random.randn(1, 224, 224, 3)

# Reload the exported model; the first returned value is ignored here.
_, merged_graph = OperateFrozenModel.load_frozen_model(merged_pb_path)

# Look the tensors up by the explicit names given at export time.
inputs = merged_graph.get_tensor_by_name("input:0")
outputs = merged_graph.get_tensor_by_name("output:0")

# Bind the session to the loaded graph explicitly rather than relying on
# the ambient default graph.
with tf.Session(graph=merged_graph) as sess:
    print(sess.run(outputs, feed_dict={inputs: sampled}))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment