Last active
October 29, 2018 17:03
-
-
Save lukmanr/d3efc692da61a5f48f856bbde644ce1e to your computer and use it in GitHub Desktop.
TF Model Optimization code 2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def run_experiment(hparams, train_data, train_labels, run_config,
                   create_estimator_fn=create_estimator,
                   eval_data=None, eval_labels=None):
    """Train and evaluate an estimator with numpy inputs, then return it.

    Builds numpy-backed input functions for training and evaluation, creates
    the estimator via ``create_estimator_fn(hparams, run_config)`` and runs
    ``tf.estimator.train_and_evaluate``. UTC start/end times and the elapsed
    seconds are printed.

    Args:
      hparams: Hyperparameter object. This function reads ``batch_size``,
        ``max_training_steps`` and ``eval_throttle_secs``; any other fields
        are consumed by ``create_estimator_fn``.
      train_data: Training features, fed under the ``'input_image'`` key.
      train_labels: Labels matching ``train_data``.
      run_config: Run configuration forwarded to ``create_estimator_fn``.
      create_estimator_fn: Callable ``(hparams, run_config) -> estimator``.
      eval_data: Optional held-out features used for evaluation. When
        omitted, evaluation falls back to ``train_data`` (the historical
        behavior), so the reported metrics then reflect training fit only.
      eval_labels: Labels matching ``eval_data``; must be given together
        with ``eval_data``.

    Returns:
      The trained estimator.

    Raises:
      ValueError: If only one of ``eval_data`` / ``eval_labels`` is given.
    """
    if (eval_data is None) != (eval_labels is None):
        raise ValueError('eval_data and eval_labels must be provided together.')
    if eval_data is None:
        # Backward-compatible fallback: with no held-out set, evaluate on the
        # training data (metrics then do not measure generalization).
        eval_data, eval_labels = train_data, train_labels

    train_spec = tf.estimator.TrainSpec(
        input_fn=tf.estimator.inputs.numpy_input_fn(
            x={'input_image': train_data},
            y=train_labels,
            batch_size=hparams.batch_size,
            num_epochs=None,  # repeat indefinitely; max_steps bounds training
            shuffle=True),
        max_steps=hparams.max_training_steps,
    )

    eval_spec = tf.estimator.EvalSpec(
        input_fn=tf.estimator.inputs.numpy_input_fn(
            x={'input_image': eval_data},
            y=eval_labels,
            batch_size=hparams.batch_size,
            num_epochs=1,  # single deterministic pass over the eval set
            shuffle=False),
        steps=None,  # evaluate until the input_fn is exhausted
        throttle_secs=hparams.eval_throttle_secs,
    )

    tf.logging.set_verbosity(tf.logging.INFO)

    time_start = datetime.utcnow()
    print('Experiment started at {}'.format(time_start.strftime('%H:%M:%S')))
    print('.......................................')

    estimator = create_estimator_fn(hparams, run_config)
    tf.estimator.train_and_evaluate(
        estimator=estimator,
        train_spec=train_spec,
        eval_spec=eval_spec,
    )

    time_end = datetime.utcnow()
    print('.......................................')
    print('Experiment finished at {}'.format(time_end.strftime('%H:%M:%S')))
    print('')
    time_elapsed = time_end - time_start
    print('Experiment elapsed time: {} seconds'.format(time_elapsed.total_seconds()))

    return estimator
def train_and_export_model(train_data, train_labels):
    """Train a fresh Keras-backed estimator and export it as a SavedModel.

    Any previous artifacts under ``os.path.join(MODELS_LOCATION, MODEL_NAME)``
    are deleted first, so every call trains from scratch with the fixed
    hyperparameters and run configuration below.

    Args:
      train_data: Training features forwarded to ``run_experiment``.
        Presumably 28x28 float images, matching the serving placeholder
        shape below — TODO confirm against the caller.
      train_labels: Labels matching ``train_data``.

    Returns:
      The export base directory (``<model_dir>/export``). Note that
      ``export_savedmodel`` writes the SavedModel into a timestamped
      subdirectory of this path (per the TF Estimator docs), not into the
      returned directory itself.
    """
    model_dir = os.path.join(MODELS_LOCATION, MODEL_NAME)

    hparams = tf.contrib.training.HParams(
        batch_size=100,
        hidden_units=[1024],
        num_conv_layers=2,
        init_filters=64,
        # NOTE(review): confirm whether the estimator treats this as a drop
        # rate or a keep probability; 0.85 is unusual as a drop rate.
        dropout=0.85,
        max_training_steps=50,
        eval_throttle_secs=10,
        learning_rate=1e-3,
        debug=True,
    )

    run_config = tf.estimator.RunConfig(
        tf_random_seed=19830610,
        save_checkpoints_steps=1000,
        keep_checkpoint_max=3,
        model_dir=model_dir,
    )

    # Wipe old artifacts so training starts from scratch, not old checkpoints.
    if tf.gfile.Exists(model_dir):
        print('Removing previous artifacts...')
        tf.gfile.DeleteRecursively(model_dir)
    # Use tf.gfile rather than os.makedirs so that every filesystem operation
    # here goes through the same API; os.makedirs cannot create directories
    # on non-local paths (e.g. gs://) that tf.gfile supports.
    tf.gfile.MakeDirs(model_dir)

    estimator = run_experiment(
        hparams, train_data, train_labels, run_config, create_estimator_keras)

    def make_serving_input_receiver_fn():
        """Return a raw serving input receiver fn for [batch, 28, 28] images."""
        inputs = {'input_image': tf.placeholder(
            shape=[None, 28, 28], dtype=tf.float32, name='serving_input_image')}
        return tf.estimator.export.build_raw_serving_input_receiver_fn(inputs)

    export_dir = os.path.join(model_dir, 'export')
    # Defensive: model_dir was recreated above, so this normally never fires.
    if tf.gfile.Exists(export_dir):
        tf.gfile.DeleteRecursively(export_dir)
    estimator.export_savedmodel(
        export_dir_base=export_dir,
        serving_input_receiver_fn=make_serving_input_receiver_fn(),
    )
    return export_dir
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment