Created
May 31, 2025 08:42
-
-
Save pors/79ea28bbce5d25e6b3bdba128aef0533 to your computer and use it in GitHub Desktop.
fast.ai-compatible Training Analysis Dashboard
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _epoch_end_train_losses(recorder, values):
    """Return one training loss per epoch, as plain floats.

    Prefers the last per-batch loss of each epoch taken from
    ``recorder.losses``; falls back to column 0 of ``values`` when the
    per-batch history is missing or shorter than the number of epochs.
    """
    n_epochs = len(values)
    batch_losses = getattr(recorder, 'losses', None)
    # Require at least one batch loss per epoch: otherwise n_batches would be
    # 0 and every epoch would index [-1], silently plotting a flat line.
    if batch_losses is not None and len(batch_losses) >= n_epochs:
        n_batches = len(batch_losses) // n_epochs
        # Index of the LAST batch of epoch i (not the first).
        return [float(batch_losses[(i + 1) * n_batches - 1])
                for i in range(n_epochs)]
    # Fallback: use the training loss recorded in the per-epoch rows.
    return [float(row[0]) for row in values]


def plot_training_dashboard(learn):
    """Create a 2x2 dashboard of training metrics and show it.

    Panels:
        1. Training vs. validation loss per epoch.
        2. Loss gap (valid - train), shaded where the gap is positive.
        3. The first recorded metric (column 2 of each row), if any.
        4. Loss ratio (valid / train) with reference lines at 1 and 2.

    Args:
        learn: Object exposing ``learn.recorder.values``, a non-empty
            sequence with one row per epoch read as
            ``[train_loss, valid_loss, first_metric, ...]``, and optionally
            ``learn.recorder.losses`` (per-batch training losses).
            NOTE(review): presumably a fastai ``Learner`` with the default
            recorder layout -- confirm the column order if training metrics
            are also being recorded.

    Raises:
        ValueError: If no epochs have been recorded yet.
    """
    recorder = learn.recorder
    values = recorder.values
    # Validate before creating the figure so a failed call leaves no empty
    # plot behind. len() is used instead of truthiness because the container
    # type of `values` is not known here.
    n_epochs = len(values)
    if n_epochs == 0:
        raise ValueError(
            "No training history recorded: learn.recorder.values is empty. "
            "Train the model before plotting the dashboard."
        )

    epochs = range(n_epochs)
    train_losses_per_epoch = _epoch_end_train_losses(recorder, values)
    valid_losses = [float(row[1]) for row in values]

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle('Training Analysis Dashboard', fontsize=16)

    # 1. Traditional loss plot
    ax1 = axes[0, 0]
    ax1.plot(train_losses_per_epoch, label='Train', color='blue', marker='o')
    ax1.plot(valid_losses, label='Valid', color='orange', marker='o')
    ax1.set_title('Training vs Validation Loss')
    ax1.set_ylabel('Loss')
    ax1.set_xlabel('Epoch')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # 2. Loss gap; a positive gap means validation loss exceeds training loss
    ax2 = axes[0, 1]
    gaps = [v - t for t, v in zip(train_losses_per_epoch, valid_losses)]
    ax2.plot(gaps, color='red', linewidth=2, marker='o')
    ax2.axhline(y=0, color='black', linestyle='--', alpha=0.3)
    ax2.set_title('Loss Gap (Valid - Train)')
    ax2.set_ylabel('Gap')
    ax2.set_xlabel('Epoch')
    ax2.fill_between(epochs, 0, gaps,
                     where=[g > 0 for g in gaps],
                     alpha=0.2, color='red')
    ax2.grid(True, alpha=0.3)

    # 3. Error rate or first metric (first column after train/valid loss)
    ax3 = axes[1, 0]
    if len(values[0]) > 2:  # Has metrics
        error_rates = [row[2] for row in values]
        ax3.plot(error_rates, color='green', linewidth=2)
        ax3.set_title('Error Rate Evolution')
        ax3.set_ylabel('Error Rate')
    else:
        ax3.text(0.5, 0.5, 'No metrics recorded', ha='center', va='center')
        ax3.set_title('Metrics')
    ax3.set_xlabel('Epoch')
    ax3.grid(True, alpha=0.3)

    # 4. Loss ratio
    ax4 = axes[1, 1]
    # The ratio is undefined for a non-positive training loss; NaN makes
    # matplotlib leave a gap instead of drawing a misleading 0.
    ratios = [v / t if t > 0 else float('nan')
              for t, v in zip(train_losses_per_epoch, valid_losses)]
    ax4.plot(ratios, color='purple', linewidth=2)
    ax4.axhline(y=1, color='black', linestyle='--', alpha=0.3)
    ax4.axhline(y=2, color='orange', linestyle='--', alpha=0.3)
    ax4.set_title('Loss Ratio (Valid/Train)')
    ax4.set_ylabel('Ratio')
    ax4.set_xlabel('Epoch')
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
# Usage: `learn` must already exist and have completed at least one epoch of
# training, so that `learn.recorder` holds the history to plot.
plot_training_dashboard(learn)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment