Created
July 13, 2020 23:40
-
-
Save shashankprasanna/42fe4bf3903a4ca735dc260df6efca18 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re

# Metric regexes are identical for every trial, so build them once instead of
# on every loop iteration.
# NOTE(review): these patterns are unanchored, so e.g. 'loss: ([0-9\\.]+)' also
# matches the text inside 'val_loss: ...' / 'test_loss: ...'. Whether that
# pollutes the 'loss' / 'acc' metrics depends on the log layout produced by
# cifar10-training-sagemaker.py — confirm against the training script.
metric_definitions = [{'Name': 'loss', 'Regex': 'loss: ([0-9\\.]+)'},
                      {'Name': 'acc', 'Regex': 'acc: ([0-9\\.]+)'},
                      {'Name': 'val_loss', 'Regex': 'val_loss: ([0-9\\.]+)'},
                      {'Name': 'val_acc', 'Regex': 'val_acc: ([0-9\\.]+)'},
                      {'Name': 'test_acc', 'Regex': 'test_acc: ([0-9\\.]+)'},
                      {'Name': 'test_loss', 'Regex': 'test_loss: ([0-9\\.]+)'}]

# SageMaker training-job names may only contain [a-zA-Z0-9-] and are limited
# to 63 characters; trial names use the same character set (up to 120 chars).
_JOB_NAME_MAX_LEN = 63
_JOB_NAME_PREFIX = 'cifar10-training-'


def _name_fragment(values, max_len):
    """Join *values* with '-' into a fragment usable in SageMaker resource names.

    Every run of characters outside [a-zA-Z0-9-] (for example the '.' in a
    value such as 0.001) is replaced by a single '-'. The result is truncated
    to at most *max_len* characters (never negative) so that the full name
    built around it stays within SageMaker's length limit. Values that were
    already valid are returned unchanged.
    """
    joined = "-".join(str(value) for value in values)
    return re.sub(r'[^a-zA-Z0-9-]+', '-', joined)[:max(max_len, 0)]


for trial_hyp in trial_hyperparameter_set:
    # Combine static hyperparameters and trial specific hyperparameters;
    # trial-specific values override static ones on key collisions.
    hyperparams = {**static_hyperparams, **trial_hyp}

    # Create unique job name with hyperparameter values and time. The
    # hyperparameter fragment is sanitized/capped so CreateTrainingJob and
    # CreateTrial do not reject the name.
    time_append = int(time.time())
    name_suffix = f'-{time_append}'
    hyp_append = _name_fragment(
        trial_hyp.values(),
        _JOB_NAME_MAX_LEN - len(_JOB_NAME_PREFIX) - len(name_suffix))
    job_name = f'{_JOB_NAME_PREFIX}{hyp_append}{name_suffix}'

    # Create a Tracker to track Trial specific hyperparameters; its artifacts
    # go under <bucket>/<experiment>/<job_name>.
    with Tracker.create(display_name=f"trial-metadata-{time_append}",
                        artifact_bucket=bucket_name,
                        artifact_prefix=f"{training_experiment.experiment_name}/{job_name}",
                        sagemaker_boto_client=sm) as trial_tracker:
        trial_tracker.log_parameters(hyperparams)

    # Create a new Trial and associate the experiment-level and trial-level
    # Tracker components with it.
    tf_trial = Trial.create(
        trial_name=f'trial-{hyp_append}-{time_append}',
        experiment_name=training_experiment.experiment_name,
        sagemaker_boto_client=sm)
    tf_trial.add_trial_component(exp_tracker.trial_component)
    time.sleep(2)  # To prevent ThrottlingException
    tf_trial.add_trial_component(trial_tracker.trial_component)

    # Create an experiment config that associates the training job to the Trial.
    experiment_config = {"ExperimentName": training_experiment.experiment_name,
                         "TrialName": tf_trial.trial_name,
                         "TrialComponentDisplayName": job_name}

    # Create a TensorFlow Estimator with the Trial specific hyperparameters.
    # NOTE(review): train_instance_count / train_instance_type / script_mode
    # are SageMaker Python SDK v1 argument names; SDK v2 renamed/removed them.
    tf_estimator = TensorFlow(entry_point='cifar10-training-sagemaker.py',
                              source_dir='code',
                              output_path=f's3://{bucket_name}/{training_experiment.experiment_name}/',
                              code_location=f's3://{bucket_name}/{training_experiment.experiment_name}',
                              role=role,
                              train_instance_count=1,
                              train_instance_type='ml.p3.2xlarge',
                              framework_version='1.15',
                              py_version='py3',
                              script_mode=True,
                              metric_definitions=metric_definitions,
                              sagemaker_session=sagemaker_session,
                              hyperparameters=hyperparams,
                              enable_sagemaker_metrics=True)

    # Launch the training job asynchronously (wait=False) so all trials are
    # submitted without blocking on each other.
    tf_estimator.fit({'training': datasets,
                      'validation': datasets,
                      'eval': datasets},
                     job_name=job_name,
                     wait=False,
                     experiment_config=experiment_config)
    time.sleep(3)  # To prevent ThrottlingException
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment