Simple reference functions to estimate memory usage in torch.
import functools

import psutil
import torch


def process_memory() -> float:
    """
    Estimate total memory usage of the current process in MB.
    """
    return psutil.Process().memory_info().rss / (1024 ** 2)


def tensor_memory(tensor: torch.Tensor) -> float:
    """
    Estimate total memory usage of a tensor in MB.
    """
    memory_bytes = tensor.numel() * tensor.element_size()
    return memory_bytes / (1024 ** 2)  # Convert to MB


def module_memory(model: torch.nn.Module) -> float:
    """
    Estimate total memory usage of a torch module in MB, counting both
    parameters and buffers (e.g. BatchNorm running statistics).
    """
    param_memory = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_memory = sum(b.numel() * b.element_size() for b in model.buffers())
    return (param_memory + buffer_memory) / (1024 ** 2)  # Convert to MB
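
# A quick sanity check (illustrative usage added here, not part of the
# original gist): a 1024x1024 float32 tensor occupies 1024 * 1024 * 4 bytes
# = 4 MB, so tensor_memory should report 4.0 for it.
#
# >>> t = torch.zeros(1024, 1024)
# >>> tensor_memory(t)
# 4.0
# >>> module_memory(torch.nn.Linear(1024, 1024))  # float32 weight + bias
# 4.00390625
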

def total_vram_allocation() -> float:
    """
    Estimate peak CUDA VRAM allocated by torch in MB (since the last reset
    of the peak memory stats).
    """
    return torch.cuda.max_memory_allocated() / (1024 ** 2)


def track_peak_memory(device=0):
    """
    Function decorator that prints peak VRAM usage on a specific device
    during the decorated function's call.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            torch.cuda.reset_peak_memory_stats(device)
            result = func(*args, **kwargs)
            peak_memory_allocated = torch.cuda.max_memory_allocated(device)
            print(f"Peak VRAM allocated during {func.__name__}: {peak_memory_allocated / (1024 ** 2):.2f} MB")
            return result
        return wrapper
    return decorator
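

# Minimal usage sketch (an illustration added here, not part of the original
# gist; allocate_on_gpu below is a hypothetical example function). The VRAM
# helpers require a CUDA-capable GPU.
if __name__ == "__main__":
    print(f"Process memory: {process_memory():.2f} MB")

    if torch.cuda.is_available():

        @track_peak_memory(device=0)
        def allocate_on_gpu() -> torch.Tensor:
            # Allocate ~256 MB of float32 on the first GPU.
            return torch.zeros(64, 1024, 1024, device="cuda:0")

        x = allocate_on_gpu()
        print(f"Tensor size: {tensor_memory(x):.2f} MB")
        print(f"Peak VRAM: {total_vram_allocation():.2f} MB")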