Credit/source: here
How to use Unsloth gradient checkpointing
To integrate the provided monkey patch for offloaded gradient checkpointing into the Hugging Face transformers library, follow these steps:

- Understand the provided code: The code defines a custom gradient checkpointing function, `Unsloth_Offloaded_Gradient_Checkpointer`, that offloads activation tensors to the CPU to save VRAM. This function is then used in a new method, `new_gradient_checkpointing_enable`, to enable gradient checkpointing with this custom functionality (a rough sketch of what this can look like follows this list).
- Apply the monkey patch: The provided function `apply_unsloth_offloaded_gradient_checkpoint_monkey_patch` modifies the `gradient_checkpointing_enable` method of `transformers` models to use the custom offloaded gradient checkpointing.
- Use the patched method in your model: After applying the monkey patch, you can enable gradient checkpointing on your model as usual, and it will use the custom offloading method.
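For orientation, here is a rough sketch of what a script like `unsloth_offload_gc.py` can contain. This is not Unsloth's exact implementation: the class and function names are taken from the description above, while the CUDA device, the single offloaded `hidden_states` tensor, and the `_set_gradient_checkpointing(enable=..., gradient_checkpointing_func=...)` hook (present in recent transformers releases) are assumptions made for the sketch.

```python
# Sketch of an offloaded gradient checkpointer plus the accompanying monkey patch.
# Assumptions (not from the original source): CUDA is available, the wrapped
# forward_function takes hidden_states as its first argument, and its first
# output is the tensor that receives the incoming gradient.
import torch
from transformers.modeling_utils import PreTrainedModel


class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, forward_function, hidden_states, *args):
        # Keep only a CPU copy of the activation; run the layer without building a graph.
        saved = hidden_states.to("cpu", non_blocking=True)
        with torch.no_grad():
            outputs = forward_function(hidden_states, *args)
        ctx.save_for_backward(saved)
        ctx.forward_function = forward_function
        ctx.args = args
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Bring the activation back to the GPU and recompute the layer with grad enabled.
        (hidden_states,) = ctx.saved_tensors
        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
        hidden_states.requires_grad_(True)
        with torch.enable_grad():
            outputs = ctx.forward_function(hidden_states, *ctx.args)
        output = outputs[0] if isinstance(outputs, tuple) else outputs
        torch.autograd.backward(output, grad_outputs[0])
        # One gradient per forward() input: forward_function, hidden_states, *args.
        return (None, hidden_states.grad) + (None,) * len(ctx.args)


def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
    # Route every checkpointed layer call through the offloading Function above.
    # _set_gradient_checkpointing(enable=..., gradient_checkpointing_func=...) is the
    # hook used by recent transformers versions; older releases differ.
    self._set_gradient_checkpointing(
        enable=True,
        gradient_checkpointing_func=Unsloth_Offloaded_Gradient_Checkpointer.apply,
    )


def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch():
    # Replace the stock method on the base class so every model picks it up.
    PreTrainedModel.gradient_checkpointing_enable = new_gradient_checkpointing_enable
```

The trade-off is deliberate: only a CPU copy of each checkpointed layer input stays alive between the forward and backward passes, so peak GPU memory drops at the cost of an extra forward pass plus host-to-device transfers.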
Here's a step-by-step guide:
- Save the provided code in a Python script, for example, `unsloth_offload_gc.py`.
- In your training script, import the necessary components and apply the monkey patch:
```python
import torch
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

# Import the custom gradient checkpointing patch
from unsloth_offload_gc import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch

# Apply the monkey patch before enabling gradient checkpointing
apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()

# Load your model
model_name = "bert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Enable gradient checkpointing (now backed by the offloaded implementation)
model.gradient_checkpointing_enable()

# Rest of your training code
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
)

# train_dataset and eval_dataset are assumed to be prepared elsewhere
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
```
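The snippet above assumes `train_dataset` and `eval_dataset` already exist. Purely for illustration (this is not part of the original instructions), they could be built with something like SST-2 from the `datasets` library:

```python
# Hypothetical dataset setup so the example above runs end to end.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize(batch):
    return tokenizer(batch["sentence"], truncation=True, padding="max_length", max_length=128)

raw = load_dataset("glue", "sst2")
train_dataset = raw["train"].map(tokenize, batched=True)
eval_dataset = raw["validation"].map(tokenize, batched=True)
```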
- Import the Patch: Ensure that you import the `apply_unsloth_offloaded_gradient_checkpoint_monkey_patch` function from the script where you saved the provided code.
- Apply the Patch: Call `apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()`. This modifies the `gradient_checkpointing_enable` method of the `transformers` models to use the custom gradient checkpointing function (a quick sanity check that the patch took effect is sketched after this list).
- Load and Configure the Model: Load your desired model using `AutoModelForSequenceClassification` or any other relevant class. Then, enable gradient checkpointing by calling `model.gradient_checkpointing_enable()`.
- Training Script: Continue with your usual training script, setting up `TrainingArguments` and `Trainer` as needed.
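Before a long run, one hypothetical sanity check (assuming the patch replaces `PreTrainedModel.gradient_checkpointing_enable` with a function named `new_gradient_checkpointing_enable`, as described above) is:

```python
from transformers.modeling_utils import PreTrainedModel

# After apply_unsloth_offloaded_gradient_checkpoint_monkey_patch() has run, the
# class attribute should point at the replacement function.
print(PreTrainedModel.gradient_checkpointing_enable.__name__)
# expected: "new_gradient_checkpointing_enable"

# Once enabled on a concrete model, recent transformers versions report it as active.
model.gradient_checkpointing_enable()
print(model.is_gradient_checkpointing)  # expected: True
```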
By following these steps, you integrate the custom gradient checkpointing functionality into your model training process, potentially saving VRAM by offloading tensors to RAM during the forward and backward passes.
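If you want to verify the savings, a simple approach is to compare peak GPU memory for a short run with and without the patch applied; a minimal sketch using PyTorch's built-in memory statistics:

```python
import torch

# Run this once with the patch applied and once without, keeping everything else equal.
torch.cuda.reset_peak_memory_stats()
trainer.train()
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"Peak GPU memory during training: {peak_gib:.2f} GiB")
```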