Plot the gradient flow (PyTorch)
# Based on https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/10
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


def plot_grad_flow(named_parameters):
    '''Plots the gradients flowing through the different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.

    Usage: Plug this function into the Trainer class after loss.backward() as
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow.'''
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        if p.requires_grad and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean().item())
            max_grads.append(p.grad.abs().max().item())
    # One bar per parameter: max gradient (cyan) drawn behind the mean gradient (blue)
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=0.02)  # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)],
               ['max-gradient', 'mean-gradient', 'zero-gradient'])
Thanks, I changed it.
What an amazing little piece of code, many thanks! One suggestion to handle varying gradient magnitudes better: why not make the y-scale logarithmic?
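A sketch of that suggestion (not part of the gist): the hard ylim zoom in plot_grad_flow could be replaced with a logarithmic axis so that very small and very large gradients remain visible in the same plot.

# Suggested change: replace the plt.ylim(...) line with a log-scaled y-axis
plt.yscale("log")  # small and large gradient magnitudes both stay readable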
Thank you for this gist!
A little side note: when trying this function I got a conversion error from Tensor to numpy. This was because my model was running on CUDA, and p.grad.abs().mean() yields a Tensor (the same goes for max()). To fix this, it's enough to add .item() afterwards, like p.grad.abs().mean().item().
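A small illustration of the issue described above (a sketch; assumes a CUDA device is available):

import torch
g = torch.randn(10, device="cuda")
# g.abs().mean() is still a CUDA tensor, so matplotlib cannot convert it to numpy
val = g.abs().mean().item()  # .item() returns a plain Python float, safe to plot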