
@jsta
Forked from vahbuna/pymem.md
Created June 3, 2025 18:16
Debugging PyTorch memory use with snapshots - Zach's Blog

https://zdevito.github.io/2022/08/16/memory-snapshots.html

import torch

# enable the recording of stack frame information for each allocation
torch.cuda.memory._record_memory_history(True)

from torchvision.models import resnet18
from pprint import pprint

model = resnet18().cuda()
input = torch.rand(1, 3, 224, 224).cuda()
model.train()
output = model(input)
snapshot = torch.cuda.memory._snapshot()
pprint(snapshot['segments'])
# snapshot['segments'] is a list of Segment dictionaries with this structure.
# Classes are defined innermost-first so every annotation refers to a name that already exists.
from typing import TypedDict, List

class Frame(TypedDict):
    filename: str
    line: int
    name: str

class History(TypedDict):
    addr: int
    frames: List[Frame]  # stack trace when address was last allocated,
                         # most recent frame first
    real_size: int  # unrounded size requested from the allocator

class Block(TypedDict):
    size: int
    state: str  # 'active_allocated': used by a tensor
                # 'active_awaiting_free': waiting for another stream to finish using
                #                         this block, after which it becomes free
                # 'inactive': free for reuse
    history: List[History]

class Segment(TypedDict):
    address: int
    total_size: int  # cudaMalloc'd size of segment
    stream: int
    segment_type: str  # 'large' (>1MB) or 'small'
    allocated_size: int  # size of memory in use
    active_size: int  # size of memory in use or in active_awaiting_free state
    blocks: List[Block]

class Snapshot(TypedDict):
    segments: List[Segment]

snapshot: Snapshot = torch.cuda.memory._snapshot()
from pickle import dump
with open('snapshot.pickle', 'wb') as f:
    dump(snapshot, f)
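Because the pickled snapshot is plain dictionaries and lists, it can be inspected without a GPU. The sketch below (load_snapshot and bytes_by_state are hypothetical helper names, not PyTorch APIs) loads the file back and sums block sizes per allocator state. It assumes the Segment/Block layout described above; newer PyTorch versions may add fields or use different state names, which the Counter simply reports as-is.

```python
import pickle
from collections import Counter
from typing import Any, Dict


def load_snapshot(path: str) -> Dict[str, Any]:
    # Read back a snapshot that was written with pickle.dump.
    with open(path, "rb") as f:
        return pickle.load(f)


def bytes_by_state(snapshot: Dict[str, Any]) -> Counter:
    # Sum block sizes per allocator state (e.g. 'active_allocated',
    # 'active_awaiting_free', 'inactive') across every segment.
    totals: Counter = Counter()
    for segment in snapshot["segments"]:
        for block in segment["blocks"]:
            totals[block["state"]] += block["size"]
    return totals
```

Usage: print(bytes_by_state(load_snapshot('snapshot.pickle'))).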

The saved snapshot can be analyzed with the _memory_viz.py script from the PyTorch repository:

https://github.com/pytorch/pytorch/blob/master/torch/cuda/_memory_viz.py

python _memory_viz.py stats snapshot.pickle
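For a rough idea of what a summary like this can report, here is a minimal sketch that totals the Segment-level counters per segment_type. It is not the implementation of the stats command; segment_totals and format_report are hypothetical names, and the only inputs are the Segment fields listed earlier. "Reserved but inactive" is total_size minus active_size, i.e. memory cudaMalloc'd by the allocator but currently cached and unused.

```python
from collections import defaultdict
from typing import Any, Dict, List


def segment_totals(snapshot: Dict[str, Any]) -> Dict[str, Dict[str, int]]:
    # Accumulate per-segment counters, grouped by segment_type ('large' or 'small').
    totals: Dict[str, Dict[str, int]] = defaultdict(
        lambda: {"segments": 0, "total_size": 0, "allocated_size": 0, "active_size": 0}
    )
    for seg in snapshot["segments"]:
        t = totals[seg["segment_type"]]
        t["segments"] += 1
        t["total_size"] += seg["total_size"]
        t["allocated_size"] += seg["allocated_size"]
        t["active_size"] += seg["active_size"]
    return dict(totals)


def format_report(snapshot: Dict[str, Any]) -> List[str]:
    lines = []
    for kind, t in sorted(segment_totals(snapshot).items()):
        # Reserved from CUDA but held in inactive (cached, unused) blocks.
        idle = t["total_size"] - t["active_size"]
        lines.append(
            f"{kind}: {t['segments']} segments, reserved {t['total_size']} B, "
            f"allocated {t['allocated_size']} B, reserved but inactive {idle} B"
        )
    return lines
```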

Visualizing snapshots

Flame graph of allocated memory, grouped by the stack trace that allocated it:

python _memory_viz.py memory snapshot.pickle -o memory.svg
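A flame graph tool such as Brendan Gregg's flamegraph.pl consumes "folded stacks": one line per stack, frames joined by ';' from root to leaf, followed by a weight. The sketch below (folded_stacks is a hypothetical name, and this is not necessarily how _memory_viz.py builds its SVG) produces that format from a snapshot, weighting each stack by the bytes of the active block it allocated. It assumes the last History entry of a block is its current allocation; since frames are stored most recent first, they are reversed to put the root first.

```python
from collections import Counter
from typing import Any, Dict, List


def folded_stacks(snapshot: Dict[str, Any]) -> List[str]:
    weights: Counter = Counter()
    for seg in snapshot["segments"]:
        for block in seg["blocks"]:
            # Only memory currently held by tensors contributes to the graph.
            if block["state"] != "active_allocated":
                continue
            history = block.get("history") or []
            if not history:
                weights["<no history>"] += block["size"]
                continue
            # Assumption: history[-1] is the most recent allocation of this block.
            frames = reversed(history[-1]["frames"])  # root frame first
            stack = ";".join(f"{fr['name']} ({fr['filename']}:{fr['line']})" for fr in frames)
            weights[stack] += block["size"]
    # "stack weight" lines, largest first.
    return [f"{stack} {size}" for stack, size in weights.most_common()]
```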

The memory view gives a good overview of how memory is being used. For debugging allocator issues in particular, though, it is useful to first categorize memory into individual Segment objects, which are the individual cudaMalloc segments that the allocator tracks:

python _memory_viz.py segments snapshot.pickle -o segments.svg
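The same per-segment view can be approximated as a table. This sketch (segment_rows is a hypothetical helper) lists each segment's address, size, fraction allocated, and largest inactive block, least-utilized first. The largest inactive block matters because a request can only reuse cached memory if a single free block is big enough, so many small free blocks can still leave an allocation unsatisfied.

```python
from typing import Any, Dict, List, Tuple


def segment_rows(snapshot: Dict[str, Any]) -> List[Tuple[int, int, float, int]]:
    # One row per cudaMalloc segment:
    # (address, total_size, fraction allocated, largest inactive block).
    rows = []
    for seg in snapshot["segments"]:
        free_sizes = [b["size"] for b in seg["blocks"] if b["state"] == "inactive"]
        largest_free = max(free_sizes, default=0)
        used = seg["allocated_size"] / seg["total_size"] if seg["total_size"] else 0.0
        rows.append((seg["address"], seg["total_size"], used, largest_free))
    # Least-utilized segments first: they hold the most reserved-but-idle memory.
    rows.sort(key=lambda r: r[2])
    return rows
```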

Comparing snapshots

input8 = torch.rand(8, 3, 224, 224, device='cuda')
output = model(input8)
snapshot = torch.cuda.memory._snapshot()
with open('snapshot2.pickle', 'wb') as f:
    dump(snapshot, f)

python _memory_viz.py segments snapshot2.pickle -o segments2.svg

python _memory_viz.py compare snapshot.pickle snapshot2.pickle -o compare.svg
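A quick way to see which cudaMalloc segments appeared between the two snapshots is to diff them by address. This is a sketch, not the compare command's logic: new_segments and new_reserved_bytes are hypothetical helpers, and identifying a segment by its start address is an assumption (a segment freed and re-allocated at the same address would not be reported).

```python
from typing import Any, Dict, List


def new_segments(before: Dict[str, Any], after: Dict[str, Any]) -> List[Dict[str, Any]]:
    # Assumption: a segment is identified by its start address.
    seen = {seg["address"] for seg in before["segments"]}
    return [seg for seg in after["segments"] if seg["address"] not in seen]


def new_reserved_bytes(before: Dict[str, Any], after: Dict[str, Any]) -> int:
    # Total cudaMalloc'd size of the segments that only exist in `after`.
    return sum(seg["total_size"] for seg in new_segments(before, after))
```

Both arguments are the dictionaries loaded back from snapshot.pickle and snapshot2.pickle with pickle.load.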

Generating Snapshots when Out of Memory

from pickle import dump
import torch

def oom_observer(device, alloc, device_alloc, device_free):
    # snapshot right after an OOM happened
    print('saving allocated state during OOM')
    snapshot = torch.cuda.memory._snapshot()
    with open('oom_snapshot.pickle', 'wb') as f:
        dump(snapshot, f)

# private API: register the callback with the CUDA caching allocator
torch._C._cuda_attach_out_of_memory_observer(oom_observer)
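Once an OOM snapshot exists, a first question is whether the failure came from fragmentation (enough cached free memory in total, but no single block large enough) or from a working set that simply does not fit. The sketch below is a rough heuristic with a hypothetical name, diagnose_oom. It assumes the observer's alloc argument is the size of the request that failed, and it ignores the allocator's stream and small/large pool rules.

```python
from typing import Any, Dict


def diagnose_oom(snapshot: Dict[str, Any], requested: int) -> str:
    # Sizes of all cached-but-free ('inactive') blocks, across every segment.
    free_blocks = [
        block["size"]
        for seg in snapshot["segments"]
        for block in seg["blocks"]
        if block["state"] == "inactive"
    ]
    cached_free = sum(free_blocks)
    largest = max(free_blocks, default=0)
    if largest >= requested:
        # Heuristic only: reuse is also restricted by stream and by pool type.
        return "cached block large enough exists; check stream/pool restrictions"
    if cached_free >= requested:
        return "fragmentation: enough cached free bytes overall, but no single block fits"
    return "working set too large: not enough cached free memory in total"
```

Inside oom_observer, print(diagnose_oom(snapshot, alloc)) could be called after taking the snapshot.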