# Gist dsevero/8a885dbe0a547507a8e20ba922ffdbd6: time profiling with contextmanager
# Last active: March 25, 2021 05:13
# (GitHub notice: this file may contain hidden or bidirectional Unicode text
# that could be interpreted or compiled differently than it appears; review it
# in an editor that reveals hidden Unicode characters.)
import json
import logging
import sys
from contextlib import contextmanager
from time import perf_counter, time

import torch
# Send log records from the root logger to stdout at INFO level, so the
# JSON timing lines emitted by the helpers below are visible when the script runs.
# NOTE: logging.basicConfig does nothing if the root logger already has
# handlers (e.g. configured by an embedding application).
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s %(name)s %(levelname)s:%(message)s',
)
# Root logger (getLogger() with no name); every timing record goes through it.
logger = logging.getLogger()
def log(data: dict):
    """Serialize *data* to JSON and emit it as a single INFO record.

    Args:
        data: Mapping describing the event. It must be JSON-serializable;
            ``json.dumps`` raises ``TypeError`` for values it cannot encode.
    """
    message = json.dumps(data)
    logger.info(message)
@contextmanager
def log_runtime(**kwargs):
    """Time the enclosed block on the host and log the result as JSON.

    Use it as ``with log_runtime(...):``. It also works as a decorator,
    because ``contextlib.contextmanager`` objects support both forms.

    Args:
        **kwargs: JSON-serializable labels copied into the log record. The keys
            ``'dt'`` and ``'cuda'``, if present, are overwritten.

    After the block finishes normally, logs
    ``{**kwargs, 'dt': <elapsed seconds>, 'cuda': False}``. Nothing is logged
    if the block raises.
    """
    # perf_counter is monotonic and high-resolution. Wall-clock time.time()
    # can jump backwards or forwards when the system clock is adjusted,
    # which corrupts measured durations.
    start = perf_counter()
    yield
    log({**kwargs, 'dt': perf_counter() - start, 'cuda': False})
@contextmanager
def log_cuda_runtime(**kwargs):
    """Time the enclosed block on the GPU using CUDA events and log the result.

    Use it as ``with log_cuda_runtime(...):`` or as a decorator. Events are
    recorded on the current CUDA stream, so the measurement covers work queued
    on that stream between entering and leaving the block.

    Args:
        **kwargs: JSON-serializable labels copied into the log record. The keys
            ``'dt'`` and ``'cuda'``, if present, are overwritten.

    After the block finishes normally, logs
    ``{**kwargs, 'dt': <elapsed seconds>, 'cuda': True}``. ``dt`` is in
    seconds, matching the unit ``log_runtime`` uses for the same key.
    """
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    yield
    end_event.record()
    # record() is asynchronous. Wait until the GPU has actually reached both
    # events before querying the elapsed time.
    torch.cuda.synchronize()
    # Event.elapsed_time() returns milliseconds. Convert to seconds so 'dt'
    # has the same unit in CPU and CUDA records.
    elapsed_s = start_event.elapsed_time(end_event) / 1000.0
    log({**kwargs, 'dt': elapsed_s, 'cuda': True})
# Examples. They run only when this file is executed directly, so importing
# the helpers above does not emit log lines or rebind names.
if __name__ == "__main__":
    # Time an arbitrary host-side block.
    with log_runtime(foo='bar', something='else'):
        ...  # your code here

    # Time every call to a function on the GPU. contextmanager objects also
    # work as decorators, and each call gets a fresh timing context.
    @log_cuda_runtime(func='my_cuda_func')
    def my_cuda_func(*args, **kwargs):
        ...

    model = ...  # trained model
    # NOTE: this rebinds `model` to a plain wrapper function around the
    # original callable. If model is a torch nn.Module (assumed), its methods
    # (e.g. .eval(), .parameters()) are not reachable through the wrapper, so
    # keep a reference to the original object if you still need them.
    model = log_cuda_runtime(func='model')(model)
# Sign up for free to join this conversation on GitHub. Already have an
# account? Sign in to comment.