A simple Cuda memory tracer for PyTorch
import sys
from pathlib import Path
from functools import partial
from collections import defaultdict
import logging

import torch

logger = logging.getLogger(__name__)

class CudaMemoryTracer:
    """A memory allocation tracer for Cuda devices.

    Args:
        root (Path, str): the root of the hierarchy of files to be traced
    """

    def __init__(self, root=None):
        if root is None:
            root = Path(__file__).parent.resolve()
        self.root = Path(root)
        logger.info(f"Tracing all files below {root}")
        self.register = defaultdict(int)
        sys.settrace(self.trace)

    def make_state(self, frame):
        """Associate the current frame code position with the memory usage.

        Returns:
            tuple(str, int): instruction position and current cuda memory usage
        """
        file = frame.f_code.co_filename
        line = frame.f_lineno
        spot = f"{file}:{line}"
        mem = torch.cuda.memory_allocated()
        return spot, mem

    def trace(self, frame, event, __):
        """Tracing entry point. Filters out files outside of the root folder."""
        file = frame.f_code.co_filename
        if event == 'call':
            try:
                Path(file).resolve().relative_to(self.root)
            except (ValueError, TypeError):
                # This scope does not belong to the traced folder, ignore it
                return None
            # Retrieve the current memory state and start tracing the scope
            return partial(self.trace_scope, self.make_state(frame))
        return self.trace

    def trace_scope(self, prev_state, frame, event, __):
        # This method is called on 'line' and 'return' events, *before* the
        # line executes. Since we are interested in memory usage, it must be
        # given the previous memory state and the file name / line number it
        # was measured at.
        prev_spot, prev_mem = prev_state
        spot, mem = self.make_state(frame)
        mem_diff = mem - prev_mem
        if mem_diff > 0:
            logger.debug(f"{prev_spot:<30} {mem_diff/1024:>7.1f} KiB")
            self.register[prev_spot] += mem_diff
        # Continue tracing with the updated state
        return partial(self.trace_scope, (spot, mem))

    def summary(self, n=10):
        """Log the n call sites that allocated the most Cuda memory."""
        top = sorted(self.register.items(), key=lambda e: e[1], reverse=True)
        logger.info("Top GPU memory consumers:")
        for spot, usage in top[:n]:
            logger.info(f"{spot:<30} {usage/1024:>7.1f} KiB")

if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)

    tracer = CudaMemoryTracer()

    device = torch.device('cuda')
    torch.cuda.set_device(device)

    def foo():
        tensors = []
        for __ in range(1000):
            # Build a CPU tensor, then copy it to the GPU: the copy is the
            # Cuda allocation the tracer will see
            x = torch.Tensor(100, 100)
            x = x.to(device)
            tensors += [x]
        return tensors

    def bar():
        x = torch.Tensor(100, 50)
        x = x.to(device)
        print(f"I'm bar and I made {x.shape}")
        return x

    x = foo()
    y = bar()

    tracer.summary()
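
For reference, here is a minimal sketch of driving the tracer from a separate training script. It assumes the gist has been saved as an importable module (the cuda_memory_tracer name below is hypothetical), and the nn.Linear model, tensor sizes and loop count are illustrative, not part of the gist:

from pathlib import Path
import logging

import torch
import torch.nn as nn

from cuda_memory_tracer import CudaMemoryTracer  # hypothetical module name for this gist

logging.basicConfig(level=logging.INFO)  # use DEBUG to also log per-line allocations

# Point the tracer at this script's directory so its frames pass the root filter
tracer = CudaMemoryTracer(root=Path(__file__).parent.resolve())

device = torch.device('cuda')
model = nn.Linear(256, 128).to(device)

def train_step():
    # Allocations made inside traced function calls are attributed to the
    # file:line that triggered them
    x = torch.randn(32, 256, device=device)
    loss = model(x).sum()
    loss.backward()

for __ in range(10):
    train_step()

tracer.summary(n=5)

Note that sys.settrace only takes effect for frames entered after the tracer is installed, so the work to be measured has to happen inside function calls (as with foo and bar in the gist's __main__ block), not at module level.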