Created April 5, 2019 12:44
-
-
Save gabraganca/720f70c2f5fa8857150d67325a3abb38 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Collect the Grid Search results into a DataFrame.
# NOTE(review): `gs` is presumably a fitted surprise GridSearchCV (it exposes
# `cv_results`, without scikit-learn's trailing underscore) — confirm upstream.
df_results = pd.DataFrame.from_dict(gs.cv_results)
# Strip the 'param_' prefix so hyper-parameter columns are named like the keys
# of `param_grid` (e.g. 'param_n_epochs' -> 'n_epochs'). regex=False makes the
# literal (non-regex) match explicit regardless of the pandas default.
df_results.columns = df_results.columns.str.replace('param_', '', regex=False)

# Plot the heat maps: one row per n_epochs value, one column per metric.
n_epochs = len(param_grid['n_epochs'])
# squeeze=False keeps `axes` 2-D even when the grid has a single n_epochs
# value; otherwise subplots returns a 1-D array of Axes and the row/column
# loops below fail with "Axes object is not iterable".
fig, axes = plt.subplots(
    nrows=n_epochs, ncols=3, figsize=(22, 6 * n_epochs), squeeze=False
)
for ax_row, n_epoch in zip(axes, param_grid['n_epochs']):
    for ax, metric in zip(ax_row, ['mae', 'rmse', 'time']):
        # Error metrics use the 'mean_test_*' columns; timing uses 'mean_fit_time'.
        parameter = f'mean_test_{metric}' if metric != 'time' else f'mean_fit_{metric}'
        # NOTE(review): pivot_table aggregates with its default aggfunc ('mean'),
        # so any grid parameter other than n_factors/lr_all would be silently
        # averaged out here — confirm the grid contains no other parameters.
        ax = sns.heatmap(
            df_results.query(f'n_epochs =={n_epoch}')
            .pivot_table(columns='n_factors', index='lr_all', values=parameter),
            annot=True,
            fmt='0.4f',
            # Color limits come from the whole results table, not just this
            # n_epochs slice, so every row shares one scale per metric and the
            # heat maps are directly comparable.
            vmin=df_results[parameter].min(),
            vmax=df_results[parameter].max(),
            ax=ax,
            cmap='viridis',
        )
        # Title label: 'Time' for the timing column, upper case for MAE/RMSE.
        label = metric.capitalize() if metric == 'time' else metric.upper()
        ax.set_title(f'# Epochs: {n_epoch} | metric: {label}')
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.