@vahbuna
Last active June 3, 2025 18:16
Debugging PyTorch memory use with snapshots - Zach's Blog

https://zdevito.github.io/2022/08/16/memory-snapshots.html

# enable the recording of stack frame information for each allocation
import torch
torch.cuda.memory._record_memory_history(True)

from torchvision.models import resnet18
from pprint import pprint

model = resnet18().cuda()
input = torch.rand(1, 3, 224, 224).cuda()
model.train()
output = model(input)
snapshot = torch.cuda.memory._snapshot()
pprint(snapshot['segments'])
# The snapshot is a list of Segment dictionaries with this structure:
from typing import TypedDict, List

class Segment(TypedDict):
    address: int
    total_size: int #  cudaMalloc'd size of segment
    stream: int
    segment_type: str # 'large' (>1MB) or 'small'
    allocated_size: int # size of memory in use
    active_size: int # size of memory in use or in active_awaiting_free state
    blocks : List[Block]

class Block(TypedDict):
    size: int
    state: str # 'active_allocated', used by a tensor
               # 'active_awaiting_free', we are waiting for another stream to finish using
               #                         this, then it will become free
               # 'inactive', free for reuse
    history: List[History]

class History(TypedDict):
    addr: int
    frames : List[Frame] # stack trace when address was last allocated
                         # most recent frame first
    real_size: int # unrounded size requested from the allocator

class Frame(TypedDict):
    filename: str
    line: int
    name: str

class Snapshot(TypedDict):
    segments : List[Segment]

snapshot : Snapshot = torch.cuda.memory._snapshot()
from pickle import dump
dump(snapshot, open('snapshot.pickle', 'wb'))
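
Because the pickle holds only plain dictionaries, it can be inspected later on any machine, with no GPU or torch install needed. A minimal sketch (the `summarize` helper is mine, assuming the `Segment` schema above):

```python
from pickle import load

def summarize(snapshot):
    """Total reserved (cudaMalloc'd) vs. allocated bytes across all segments,
    plus the fraction of reserved memory not currently in use."""
    reserved = sum(seg['total_size'] for seg in snapshot['segments'])
    allocated = sum(seg['allocated_size'] for seg in snapshot['segments'])
    unused_frac = 1 - allocated / reserved if reserved else 0.0
    return reserved, allocated, unused_frac

# usage:
# with open('snapshot.pickle', 'rb') as f:
#     reserved, allocated, unused_frac = summarize(load(f))
```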

https://github.com/pytorch/pytorch/blob/master/torch/cuda/_memory_viz.py

python _memory_viz.py stats snapshot.pickle

Visualizing snapshots

flame graph

python _memory_viz.py memory snapshot.pickle -o memory.svg

The memory view gives a good overview of how the memory is being used. For debugging allocator issues in particular, though, it is useful to first categorize memory into individual Segment objects, which are the individual cudaMalloc segments that the allocator tracks:

python _memory_viz.py segments snapshot.pickle -o segments.svg
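
To put numbers on what the segments view shows, you can tally block states inside each segment yourself. A sketch over the schema above (the function name is mine): 'inactive' bytes here are memory already cudaMalloc'd but not in use, i.e. potential fragmentation.

```python
from collections import Counter

def block_state_bytes(snapshot):
    """Bytes per block state ('active_allocated', 'active_awaiting_free',
    'inactive') summed across all segments."""
    totals = Counter()
    for seg in snapshot['segments']:
        for block in seg['blocks']:
            totals[block['state']] += block['size']
    return dict(totals)
```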

Comparing snapshots

input8 = torch.rand(8, 3, 224, 224, device='cuda')
output = model(input8)
snapshot = torch.cuda.memory._snapshot()
dump(snapshot, open('snapshot2.pickle', 'wb'))

python _memory_viz.py compare snapshot.pickle snapshot2.pickle -o compare.svg

Generating Snapshots when Out of Memory

def oom_observer(device, alloc, device_alloc, device_free):
    # snapshot right after an OOM happened
    print('saving allocated state during OOM')
    snapshot = torch.cuda.memory._snapshot()
    dump(snapshot, open('oom_snapshot.pickle', 'wb'))

torch._C._cuda_attach_out_of_memory_observer(oom_observer)
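
Once `oom_snapshot.pickle` is saved, a quick first look at what filled the GPU is to list the largest in-use blocks with the innermost frame of the stack that allocated them. A sketch (the function name is mine; I assume the last `history` entry is the most recent allocation, per the schema above):

```python
def largest_active_blocks(snapshot, n=3):
    """The n largest 'active_allocated' blocks, each paired with the
    innermost stack frame recorded when it was last allocated."""
    found = []
    for seg in snapshot['segments']:
        for block in seg['blocks']:
            if block['state'] != 'active_allocated':
                continue
            # assumption: history[-1] is the most recent allocation event
            frames = block['history'][-1]['frames'] if block.get('history') else []
            top = frames[0] if frames else None
            where = (f"{top['filename']}:{top['line']} ({top['name']})"
                     if top else '<no stack>')
            found.append((block['size'], where))
    return sorted(found, reverse=True)[:n]
```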
@tmm1 commented Aug 21, 2023:

Thanks for putting this together! Small typo: s/snapshot.py/snapshot.pickle/g

@vahbuna (author) commented Jun 28, 2024:

fixed :) Thanks @tmm1
