Skip to content

Instantly share code, notes, and snippets.

@vadimkantorov
Created July 10, 2025 18:03
Show Gist options
  • Save vadimkantorov/5969b8308e5cb09e3421e225baaa2dbc to your computer and use it in GitHub Desktop.
Save vadimkantorov/5969b8308e5cb09e3421e225baaa2dbc to your computer and use it in GitHub Desktop.
Install an OOM hook in PyTorch
# PYTHONPATH=. python ...
import os
import torch
def cuda_oom_hook(device, alloc, device_alloc, device_free, info = dict(oom_counter = 0), snapshot_dump_file_pattern = './memory_snapshot_{pid}_{oom_counter}.pt'):
    """Report a CUDA out-of-memory event and dump a memory snapshot to disk.

    Prints the callback arguments and ``torch.cuda.memory_summary`` for
    ``device`` (every line prefixed with the device, the OOM ordinal and the
    process id), then saves ``torch.cuda.memory._snapshot`` with ``torch.save``
    to a path built from ``snapshot_dump_file_pattern``.

    Args:
        device: Device passed through to ``torch.cuda.memory_summary`` and
            ``torch.cuda.memory._snapshot``.
        alloc, device_alloc, device_free: Values supplied by whoever invokes
            the hook; they are only printed here.
            NOTE(review): presumably byte counts reported by the CUDA caching
            allocator -- confirm against the installed torch version.
        info: Mutable state holding the ``'oom_counter'`` key. The shared
            default dict is intentional: it keeps the counter alive across
            invocations, so successive OOMs get distinct ordinals/file names.
        snapshot_dump_file_pattern: ``str.format`` template for the snapshot
            path; receives ``pid`` and ``oom_counter``.
    """
    # The default used to be ``dict(counter = 0)``, which made every lookup of
    # ``info['oom_counter']`` below raise KeyError. ``setdefault`` additionally
    # tolerates a caller-supplied dict that does not carry the key yet.
    oom_counter = info.setdefault('oom_counter', 0)

    memory_summary = torch.cuda.memory_summary(device = device)
    memory_snapshot = torch.cuda.memory._snapshot(device = device)
    pid = os.getpid()

    print('device:', device, 'oom#:', oom_counter, 'pid:', pid, 'alloc:', alloc, 'device_alloc:', device_alloc, 'device_free:', device_free)
    # Prefix each summary line so output from several processes can be told apart.
    for line in memory_summary.splitlines():
        print('device:', device, 'oom#:', oom_counter, 'pid:', pid, line)

    torch.save(memory_snapshot, snapshot_dump_file_pattern.format(pid = pid, oom_counter = oom_counter))
    info['oom_counter'] = oom_counter + 1
# Register cuda_oom_hook as an out-of-memory observer as soon as this module is imported.
# NOTE(review): `torch._C._cuda_attach_out_of_memory_observer` is a private torch API;
# presumably the hook is called as (device, alloc, device_alloc, device_free) on each
# CUDA OOM -- confirm against the installed torch version.
torch._C._cuda_attach_out_of_memory_observer(cuda_oom_hook)
#torch._C._cuda_attach_out_of_memory_observer(lambda device, alloc, device_alloc, device_free, info = dict(oom_counter = 0), os = __import__('os'): print(os.getpid(), device, alloc, device_alloc, device_free, '\n', torch.cuda.memory_summary(device = device)) or torch.save(torch.cuda.memory._snapshot(device = device), 'memory_snapshot_{pid}_{oom_counter}.pt'.format(pid = os.getpid(), **info)) or info.update(dict(oom_counter = 1 + info['oom_counter'])))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment