thomwolf/ac7a7da6b1888c2eeac8ac8b9b05d3d3
Last active April 19, 2026 07:30
PyTorch gradient accumulation training loop
```python
# Assumes model, optimizer, loss_function, training_set, accumulation_steps,
# evaluation_steps and evaluate_model() are defined elsewhere.
model.zero_grad()                                   # Reset gradient tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # Forward pass
    loss = loss_function(predictions, labels)       # Compute loss function
    loss = loss / accumulation_steps                # Normalize our loss (if averaged)
    loss.backward()                                 # Backward pass
    if (i+1) % accumulation_steps == 0:             # Wait for several backward steps
        optimizer.step()                            # Now we can do an optimizer step
        model.zero_grad()                           # Reset gradient tensors
        if (i+1) % evaluation_steps == 0:           # Evaluate the model when we...
            evaluate_model()                        # ...have no gradients accumulated
```
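The loop relies on PyTorch objects defined elsewhere. As a minimal, dependency-free sketch (not part of the original gist), the same control flow can be traced in plain Python, assuming a hypothetical one-parameter model y = w * x, a mean-squared-error loss, and plain SGD. The helper names `grad_mse` and `train` are hypothetical; `grad_mse` stands in for `loss.backward()`, and a float `grad` stands in for the `.grad` buffer. The sketch checks that dividing each micro-batch loss by `accumulation_steps` and summing the gradients gives the same update as one step on the combined batch, and records when the nested evaluation would fire.

```python
# Plain-Python sketch of the loop above (no PyTorch). Assumptions: a
# hypothetical 1-parameter model y = w * x, mean-squared-error loss, and
# plain SGD. grad_mse() stands in for loss.backward().


def grad_mse(w, batch):
    """d/dw of the mean squared error of y = w * x over `batch`."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def train(micro_batches, accumulation_steps, evaluation_steps, lr=0.01, w=0.0):
    """Return the final weight plus the iterations that stepped / evaluated."""
    grad = 0.0  # plays the role of the .grad buffer
    steps, evals = [], []
    for i, batch in enumerate(micro_batches):
        # Dividing the loss by accumulation_steps divides its gradient too;
        # backward() adds that gradient into the existing buffer.
        grad += grad_mse(w, batch) / accumulation_steps
        if (i + 1) % accumulation_steps == 0:
            w -= lr * grad  # optimizer.step() with plain SGD
            grad = 0.0      # model.zero_grad()
            steps.append(i)
            if (i + 1) % evaluation_steps == 0:
                evals.append(i)  # evaluate_model() would run here
    return w, steps, evals


data = [(float(x), 2.0 * x) for x in range(1, 9)]        # 8 samples of y = 2x
micro_batches = [data[j:j + 2] for j in range(0, 8, 2)]  # 4 micro-batches of 2

w_acc, steps, evals = train(micro_batches, accumulation_steps=4, evaluation_steps=4)
w_full = 0.0 - 0.01 * grad_mse(0.0, data)  # one SGD step on all 8 samples at once
print(w_acc, w_full, steps, evals)
```

Here the accumulated update and the single full-batch step give the same weight because every micro-batch has the same size and the loss is a mean; with unequal micro-batch sizes, or layers such as BatchNorm that compute statistics per forward pass, the equivalence is only approximate. Because the evaluation check is nested inside the optimizer-step branch, it fires only when `i + 1` is a multiple of both `accumulation_steps` and `evaluation_steps` (with 2 and 3, for example, it fires at `i + 1 = 6`).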