https://zdevito.github.io/2022/08/16/memory-snapshots.html
# Enable recording of stack-frame information for each CUDA allocation so
# later snapshots can attribute every block to the Python code that made it.
import torch
torch.cuda.memory._record_memory_history(True)

from torchvision.models import resnet18
from pprint import pprint

# Build a small model and run one training-mode forward pass on the GPU so
# the caching allocator has live segments to report.
model = resnet18().cuda()
sample_input = torch.rand(1, 3, 224, 224).cuda()  # renamed: avoid shadowing builtin input()
model.train()
output = model(sample_input)

# Capture the allocator state and show the per-segment breakdown.
snapshot = torch.cuda.memory._snapshot()
pprint(snapshot['segments'])
# The snapshot is a list of Segment dictionaries with this structure:
from typing import TypedDict, List

# Schema of the dict returned by torch.cuda.memory._snapshot().
# Classes are declared leaf-first so that each annotation refers to an
# already-defined name (the original top-down order raises NameError at
# class-creation time, since TypedDict evaluates its annotations eagerly).

class Frame(TypedDict):
    """One Python stack frame recorded at allocation time."""
    filename: str
    line: int
    name: str

class History(TypedDict):
    """Record of the most recent allocation at an address."""
    addr: int
    frames: List[Frame]  # stack trace when address was last allocated, most recent frame first
    real_size: int       # unrounded size requested from the allocator

class Block(TypedDict):
    """A sub-range of a segment handed out (or reusable) by the allocator."""
    size: int
    # state is one of:
    #   'active_allocated'     - used by a tensor
    #   'active_awaiting_free' - we are waiting for another stream to finish
    #                            using this, then it will become free
    #   'inactive'             - free for reuse
    state: str
    history: List[History]

class Segment(TypedDict):
    """One cudaMalloc'd region managed by the caching allocator."""
    address: int
    total_size: int      # cudaMalloc'd size of segment
    stream: int
    segment_type: str    # 'large' (>1MB) or 'small'
    allocated_size: int  # size of memory in use
    active_size: int     # size of memory in use or in active_awaiting_free state
    blocks: List[Block]

class Snapshot(TypedDict):
    """Top-level structure: a list of Segment dictionaries."""
    segments: List[Segment]
# Persist the full snapshot so it can be inspected offline with
# torch/cuda/_memory_viz.py (stats / memory / segments / compare).
snapshot: Snapshot = torch.cuda.memory._snapshot()

from pickle import dump
# Use a context manager so the file handle is flushed and closed instead of
# being leaked by dump(snapshot, open(...)).
with open('snapshot.pickle', 'wb') as f:
    dump(snapshot, f)
https://github.com/pytorch/pytorch/blob/master/torch/cuda/_memory_viz.py
python _memory_viz.py stats snapshot.pickle
flame graph
python _memory_viz.py memory snapshot.pickle -o memory.svg
The memory view gives a good overview of how the memory is being used. For debugging allocator issues in particular, though, it is useful to first categorize memory into individual Segment objects, which are the individual cudaMalloc segments that the allocator tracks:
python _memory_viz.py segments snapshot.pickle -o segments.svg
# Re-run with a larger batch so the allocator has to grow, then take a
# second snapshot to compare against the first one.
input8 = torch.rand(8, 3, 224, 224, device='cuda')
output = model(input8)
snapshot = torch.cuda.memory._snapshot()
# Close the file handle instead of leaking it.
with open('snapshot2.pickle', 'wb') as f:
    dump(snapshot, f)
python _memory_viz.py compare snapshot.pickle snapshot2.pickle -o segments2.svg
python _memory_viz.py compare snapshot.pickle snapshot2.pickle -o compare.svg
def oom_observer(device, alloc, device_alloc, device_free):
    """Dump an allocator snapshot right after an OOM happened.

    Registered via torch._C._cuda_attach_out_of_memory_observer below; the
    arguments describe the failed allocation but are unused here — we only
    need the allocator state, which is still intact at this point.
    """
    print('saving allocated state during OOM')
    snapshot = torch.cuda.memory._snapshot()
    # Close the file handle instead of leaking it via dump(..., open(...)).
    with open('oom_snapshot.pickle', 'wb') as f:
        dump(snapshot, f)

torch._C._cuda_attach_out_of_memory_observer(oom_observer)
Thanks for putting this together! Small typo:
s/snapshot.py/snapshot.pickle/g