Skip to content

Instantly share code, notes, and snippets.

@pors
Created May 31, 2025 08:42
Show Gist options
  • Save pors/79ea28bbce5d25e6b3bdba128aef0533 to your computer and use it in GitHub Desktop.
Save pors/79ea28bbce5d25e6b3bdba128aef0533 to your computer and use it in GitHub Desktop.
fastai-compatible Training Analysis Dashboard
def _train_losses_per_epoch(recorder, values):
    """Return one training-loss value (as float) per row of ``values``.

    Sources, in order of preference:

    1. ``recorder.iters`` indexing into ``recorder.losses``. This assumes
       ``iters[i]`` is the number of training losses recorded by the end of
       epoch ``i`` (fastai's Recorder appears to keep it that way — TODO
       confirm for other recorders). Unlike an even split, this stays aligned
       when epochs differ in length or a fit was interrupted mid-epoch,
       leaving extra batch losses with no matching row in ``values``.
    2. The previous heuristic: split ``recorder.losses`` evenly across epochs
       and take the last loss of each chunk. It is used only when every epoch
       gets at least one batch, so an empty or too-short ``losses`` list can no
       longer index position -1.
    3. The training loss stored in each row (``row[0]``).
    """
    n_epochs = len(values)
    losses = getattr(recorder, 'losses', None)
    losses = list(losses) if losses is not None else []
    if losses and n_epochs:
        iters = getattr(recorder, 'iters', None)
        iters = [int(it) for it in iters] if iters is not None else []
        if len(iters) == n_epochs and all(0 < it <= len(losses) for it in iters):
            return [float(losses[it - 1]) for it in iters]
        n_batches = len(losses) // n_epochs
        if n_batches > 0:
            # Last batch loss of each equally sized chunk (the LAST, not the first).
            return [float(losses[(i + 1) * n_batches - 1]) for i in range(n_epochs)]
    return [float(row[0]) for row in values]


def _first_metric_label(recorder):
    """Return a display label for the first metric column (``row[2]``).

    Reads ``recorder.metric_names`` when present. Assumes those names carry one
    extra leading entry (the epoch column) compared with each ``values`` row,
    so ``row[2]`` corresponds to ``metric_names[3]`` — TODO confirm for
    non-fastai recorders. Falls back to 'error_rate', the label this dashboard
    previously hard-coded.
    """
    names = getattr(recorder, 'metric_names', None)
    names = list(names) if names is not None else []
    raw = str(names[3]) if len(names) > 3 else 'error_rate'
    # 'error_rate' -> 'Error Rate'; internal capitals (e.g. 'RocAuc') are kept.
    return ' '.join(word[:1].upper() + word[1:] for word in raw.split('_'))


def plot_training_dashboard(learn):
    """Create a comprehensive 2x2 dashboard of per-epoch training metrics.

    Panels: train vs. validation loss, loss gap (valid - train), the first
    recorded metric (if any), and the loss ratio (valid / train).

    Args:
        learn: a fastai ``Learner`` (or any object) whose ``recorder.values``
            holds one row per completed epoch, laid out as
            ``[train_loss, valid_loss, first_metric, ...]``. When available,
            ``recorder.losses``, ``recorder.iters`` and
            ``recorder.metric_names`` are used as well.

    Raises:
        ValueError: if ``learn.recorder.values`` is empty (no finished epoch).

    Note:
        ``plt`` (matplotlib.pyplot) must already be in scope — presumably via
        a fastai star import in the original notebook; TODO confirm.
    """
    recorder = learn.recorder
    values = list(recorder.values)
    if not values:
        # Previously surfaced as ZeroDivisionError / IndexError deep inside.
        raise ValueError(
            'learn.recorder.values is empty; train for at least one epoch first'
        )

    train_losses_per_epoch = _train_losses_per_epoch(recorder, values)
    valid_losses = [float(row[1]) for row in values]

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle('Training Analysis Dashboard', fontsize=16)

    # 1. Traditional loss plot
    ax1 = axes[0, 0]
    ax1.plot(train_losses_per_epoch, label='Train', color='blue', marker='o')
    ax1.plot(valid_losses, label='Valid', color='orange', marker='o')
    ax1.set_title('Training vs Validation Loss')
    ax1.set_ylabel('Loss')
    ax1.set_xlabel('Epoch')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # 2. Loss gap; positive values (valid worse than train) are shaded.
    ax2 = axes[0, 1]
    gaps = [v - t for t, v in zip(train_losses_per_epoch, valid_losses)]
    ax2.plot(gaps, color='red', linewidth=2, marker='o')
    ax2.axhline(y=0, color='black', linestyle='--', alpha=0.3)
    ax2.set_title('Loss Gap (Valid - Train)')
    ax2.set_ylabel('Gap')
    ax2.set_xlabel('Epoch')
    # interpolate=True extends the shading to the exact zero crossings instead
    # of clipping it at the nearest data point.
    ax2.fill_between(range(len(gaps)), 0, gaps,
                     where=[g > 0 for g in gaps],
                     interpolate=True,
                     alpha=0.2, color='red')
    ax2.grid(True, alpha=0.3)

    # 3. First metric after train/valid loss (often error_rate)
    ax3 = axes[1, 0]
    if len(values[0]) > 2:
        metric_label = _first_metric_label(recorder)
        ax3.plot([row[2] for row in values], color='green', linewidth=2)
        ax3.set_title(f'{metric_label} Evolution')
        ax3.set_ylabel(metric_label)
    else:
        # Axes coordinates keep the message centred regardless of data limits.
        ax3.text(0.5, 0.5, 'No metrics recorded', ha='center', va='center',
                 transform=ax3.transAxes)
        ax3.set_title('Metrics')
    ax3.set_xlabel('Epoch')
    ax3.grid(True, alpha=0.3)

    # 4. Loss ratio; reference lines at 1x and 2x.
    ax4 = axes[1, 1]
    # An undefined ratio becomes NaN (a gap in the line) rather than 0, which
    # would look like a real, suspiciously good value.
    ratios = [v / t if t > 0 else float('nan')
              for t, v in zip(train_losses_per_epoch, valid_losses)]
    ax4.plot(ratios, color='purple', linewidth=2)
    ax4.axhline(y=1, color='black', linestyle='--', alpha=0.3)
    ax4.axhline(y=2, color='orange', linestyle='--', alpha=0.3)
    ax4.set_title('Loss Ratio (Valid/Train)')
    ax4.set_ylabel('Ratio')
    ax4.set_xlabel('Epoch')
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
# Usage: render the dashboard for an already-trained Learner bound to `learn`.
plot_training_dashboard(learn=learn)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment