Created July 13, 2020 23:53
-
-
Save shashankprasanna/dfdcea9164399eeaf723f03ee91a610d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd
from smdebug.trials import create_trial
def tensor_df(tname, trial=None):
    """Return the recorded values of tensor *tname* as a two-column DataFrame.

    Parameters
    ----------
    tname : str
        Name of the tensor to read from the trial (e.g. ``'val_acc'``).
    trial : optional
        A trial object exposing ``trial.tensor(name).values()`` (e.g. one
        returned by smdebug's ``create_trial``). When omitted, falls back to a
        module-level variable named ``trial`` -- the original behaviour, which
        only works if such a global was defined elsewhere (e.g. in an earlier
        notebook cell).

    Returns
    -------
    pandas.DataFrame
        One row per key of ``values()``, with a ``'steps'`` column holding the
        key and a column named *tname* holding the corresponding value.

    Raises
    ------
    NameError
        If *trial* is not given and no global ``trial`` exists.
    """
    if trial is None:
        # The parameter shadows the global name, so look the global up explicitly.
        try:
            trial = globals()['trial']
        except KeyError:
            raise NameError(
                "tensor_df() needs a trial: pass trial=... or define a "
                "module-level 'trial' (e.g. trial = create_trial(path))"
            ) from None
    # values() is expected to be a {step: value} mapping (smdebug Tensor.values);
    # orient='index' turns each key into a row.
    tval = trial.tensor(tname).values()
    df = pd.DataFrame.from_dict(tval, orient='index', columns=[tname])
    # Move the step keys out of the index into an explicit 'steps' column.
    return df.reset_index().rename(columns={'index': 'steps'})
def trial_perf_curves(job_name, tname, experiment_name, bucket_name=None):
    """Load tensor *tname* recorded by one training job into a DataFrame.

    The job's debug output is read from
    ``s3://{bucket_name}/{experiment_name}/{job_name}/debug-output``.

    Parameters
    ----------
    job_name : str
        Training job name; used as a path component of the S3 location.
    tname : str
        Name of the tensor to load (e.g. ``'val_acc'``).
    experiment_name : str
        Experiment name; used as a path component of the S3 location.
    bucket_name : str, optional
        S3 bucket holding the debug output. When omitted, falls back to a
        module-level ``bucket_name`` variable -- the original behaviour, which
        only works if such a global was defined elsewhere (e.g. in an earlier
        notebook cell).

    Returns
    -------
    pandas.DataFrame
        Indexed by the keys of ``values()`` (presumably training steps), with
        a single column named *tname*.

    Raises
    ------
    NameError
        If *bucket_name* is not given and no global ``bucket_name`` exists.
    """
    if bucket_name is None:
        # The parameter shadows the global name, so look the global up explicitly.
        try:
            bucket_name = globals()['bucket_name']
        except KeyError:
            raise NameError(
                "trial_perf_curves() needs an S3 bucket: pass bucket_name=... "
                "or define a module-level 'bucket_name'"
            ) from None
    debug_data = f's3://{bucket_name}/{experiment_name}/{job_name}/debug-output'
    trial = create_trial(debug_data)
    tval = trial.tensor(tname).values()
    # Keep the keys as the index so callers can align several jobs on them.
    return pd.DataFrame.from_dict(tval, orient='index', columns=[tname])
def get_metric_dataframe(metric, trial_comp_ds, experiment_name):
    """Collect one metric from several training jobs into a single DataFrame.

    Parameters
    ----------
    metric : str
        Name of the tensor to load for every job (e.g. ``'val_acc'``).
    trial_comp_ds : mapping or DataFrame
        Must support ``trial_comp_ds['DisplayName']``; each entry is used as a
        training job name. Presumably a SageMaker Experiments trial-component
        table -- confirm against the caller.
    experiment_name : str
        Forwarded to ``trial_perf_curves`` to locate each job's output.

    Returns
    -------
    pandas.DataFrame
        One column per job, named after the job, with rows aligned (outer
        join) on the index returned by ``trial_perf_curves``. An empty
        DataFrame if ``trial_comp_ds['DisplayName']`` is empty.
    """
    curves = []
    for tc_name in trial_comp_ds['DisplayName']:
        print(f'\nLoading training job: {tc_name}')
        print('--------------------------------\n')
        trial_perf = trial_perf_curves(tc_name, metric, experiment_name)
        # Name the single metric column after the job so columns stay distinct.
        trial_perf.columns = [tc_name]
        curves.append(trial_perf)
    if not curves:
        return pd.DataFrame()
    # Concatenate once: concatenating inside the loop re-copies the growing
    # frame on every iteration (quadratic) and needs an empty seed frame.
    return pd.concat(curves, axis=1)
# Gather the 'val_acc' curve of every job into one frame (one column per job).
# NOTE(review): trial_comp_ds_jobs, experiment_name and plt (presumably
# matplotlib.pyplot) are not defined in this snippet -- they must come from
# elsewhere, e.g. earlier notebook cells.
val_acc_df = get_metric_dataframe('val_acc', trial_comp_ds_jobs, experiment_name)

# A 15 x 10 inch canvas for the comparison plot.
fig = plt.figure()
fig.set_size_inches(15, 10)

# Only the jobs listed below are drawn. Edit the list to compare other jobs,
# or drop the column selection to plot every job in val_acc_df.
val_acc_df[[
    'cifar10-training-adam-custom-120-1594536575',
    'cifar10-training-adam-custom-60-1594536571',
    'cifar10-training-rmsprop-custom-30-1594536622',
]].plot(style='-', ax=fig.gca())
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment