This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_graph_def_from_saved_model(saved_model_dir):
    """Load a SavedModel's serving MetaGraph and return its GraphDef.

    Args:
        saved_model_dir: Directory containing the exported SavedModel.

    Returns:
        The ``GraphDef`` proto of the MetaGraph loaded with the
        ``tag_constants.SERVING`` tag.
    """
    # Load into a private, freshly created graph rather than the implicit
    # process-wide default graph, so repeated calls do not accumulate nodes
    # there or clash with names already present in it.
    with tf.Session(graph=tf.Graph()) as session:
        meta_graph_def = tf.saved_model.loader.load(
            session,
            tags=[tag_constants.SERVING],
            export_dir=saved_model_dir,
        )
    # The MetaGraphDef is a plain proto, so it stays valid after the session closes.
    return meta_graph_def.graph_def
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def convert_graph_def_to_saved_model(export_dir, graph_filepath): | |
if tf.gfile.Exists(export_dir): | |
tf.gfile.DeleteRecursively(export_dir) | |
graph_def = get_graph_def_from_file(graph_filepath) | |
with tf.Session(graph=tf.Graph()) as session: | |
tf.import_graph_def(graph_def, name='') | |
tf.saved_model.simple_save( | |
session, | |
export_dir, | |
inputs={ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tensorflow.tools.graph_transforms import TransformGraph | |
def get_graph_def_from_file(graph_filepath):
    """Read a binary-serialized GraphDef protobuf from a file.

    Args:
        graph_filepath: Path to a binary GraphDef file (e.g. a frozen ``.pb``);
            any path ``tf.gfile.GFile`` can open.

    Returns:
        The parsed ``tf.GraphDef``.
    """
    # Deserializing the proto only fills in a message object; it does not
    # touch any tf.Graph, so no graph context is required here.
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(graph_filepath, 'rb') as f:
        graph_def.ParseFromString(f.read())
    return graph_def
def optimize_graph(model_dir, graph_filename, transforms, output_node): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def freeze_model(saved_model_dir, output_node_names, output_filename): | |
output_graph_filename = os.path.join(saved_model_dir, output_filename) | |
initializer_nodes = '' | |
freeze_graph.freeze_graph( | |
input_saved_model_dir=saved_model_dir, | |
output_graph=output_graph_filename, | |
saved_model_tags = tag_constants.SERVING, | |
output_node_names=output_node_names, | |
initializer_nodes=initializer_nodes, | |
input_graph=None, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def inference_tfserving(eval_data, batch=BATCH_SIZE, | |
repeat=1000, signature='predict'): | |
url = 'http://localhost:8501/v1/models/mnist_classifier:predict' | |
instances = [[float(i) for i in list(eval_data[img])] for img in range(batch)] | |
request_data = {'signature_name': signature, | |
'instances': instances} | |
time_start = datetime.utcnow() | |
for i in range(repeat): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_size(model_dir, model_file='saved_model.pb'): | |
model_file_path = os.path.join(model_dir, model_file) | |
print(model_file_path, '') | |
pb_size = os.path.getsize(model_file_path) | |
variables_size = 0 | |
if os.path.exists( | |
os.path.join(model_dir,'variables/variables.data-00000-of-00001')): | |
variables_size = os.path.getsize(os.path.join( | |
model_dir,'variables/variables.data-00000-of-00001')) | |
variables_size += os.path.getsize(os.path.join( |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def describe_graph(graph_def, show_nodes=False): | |
print('Input Feature Nodes: {}'.format( | |
[node.name for node in graph_def.node if node.op=='Placeholder'])) | |
print('') | |
print('Unused Nodes: {}'.format( | |
[node.name for node in graph_def.node if 'unused' in node.name])) | |
print('') | |
print('Output Nodes: {}'.format( | |
[node.name for node in graph_def.node if ( |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def run_experiment(hparams, train_data, train_labels, run_config, create_estimator_fn=create_estimator): | |
train_spec = tf.estimator.TrainSpec( | |
input_fn = tf.estimator.inputs.numpy_input_fn( | |
x={'input_image': train_data}, | |
y=train_labels, | |
batch_size=hparams.batch_size, | |
num_epochs=None, | |
shuffle=True), | |
max_steps=hparams.max_training_steps | |
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import os | |
import numpy as np | |
from datetime import datetime | |
import sys | |
import tensorflow as tf | |
from tensorflow import data | |
from tensorflow.python.saved_model import tag_constants |