Skip to content

Instantly share code, notes, and snippets.

@KohakuBlueleaf
Created May 19, 2022 18:06
Show Gist options
  • Save KohakuBlueleaf/83b36ba1ba50dd4c0f44e10e6af5b81e to your computer and use it in GitHub Desktop.
Save KohakuBlueleaf/83b36ba1ba50dd4c0f44e10e6af5b81e to your computer and use it in GitHub Desktop.
PyTorch basic benchmark
import torch
import numpy as np
import timeit
# Benchmark settings.
# DEVICE may name any backend, e.g. "cpu", "cuda", "opencl", "mps", ...
DEVICE = "mps"
ROUNDS = 500
NUM = (2048, 2048)

# Benchmark input: float32 standard-normal samples of shape NUM, placed on DEVICE.
data = np.random.standard_normal(NUM).astype("float32")
x = torch.from_numpy(data).to(torch.device(DEVICE))
'''
Benchmark functions
'''
def batched_dot_mul_sum(a: torch.Tensor, b: torch.Tensor) -> np.ndarray:
    '''Computes batched dot by multiplying and summing.

    Multiplies ``a`` and ``b`` elementwise, sums over the last dimension and
    returns the result as a NumPy array in host memory.

    Args:
        a: Left operand.
        b: Right operand, multiplied elementwise with ``a``.

    Returns:
        NumPy array holding the sum over the last dimension of ``a * b``.
    '''
    dots = a.mul(b).sum(-1)
    # Detach before the host copy so the copy itself is not tracked by autograd.
    return dots.detach().cpu().numpy()
def batched_dot_bmm(a: torch.Tensor, b: torch.Tensor) -> np.ndarray:
    '''Computes batched dot by reducing to bmm.

    Each row of ``a`` is viewed as a ``1 x D`` matrix and each row of ``b`` as
    a ``D x 1`` matrix, so one ``torch.bmm`` call yields one ``1 x 1`` product
    per row. The result is flattened and returned as a NumPy array in host
    memory.

    Args:
        a: Left operand; its last dimension is the one being reduced.
        b: Right operand; must reshape to the same number of rows as ``a``.

    Returns:
        One-dimensional NumPy array with one dot product per row.
    '''
    row_vectors = a.reshape(-1, 1, a.shape[-1])  # (N, 1, D)
    col_vectors = b.reshape(-1, b.shape[-1], 1)  # (N, D, 1)
    # bmm gives (N, 1, 1); flatten(-3) collapses that to (N,).
    dots = torch.bmm(row_vectors, col_vectors).flatten(-3)
    # Detach before the host copy so the copy itself is not tracked by autograd.
    return dots.detach().cpu().numpy()
# Ensure that both functions compute the same output.
# An explicit raise is used instead of a bare ``assert`` because asserts are
# stripped under ``python -O``, which would silently skip this check.
if not torch.allclose(
    torch.from_numpy(batched_dot_mul_sum(x, x)),
    torch.from_numpy(batched_dot_bmm(x, x)),
):
    raise AssertionError('batched_dot_mul_sum and batched_dot_bmm disagree on x')
def main():
    '''Time both batched-dot implementations on ``x`` and print the results.

    Each implementation is run ``ROUNDS`` times with ``timeit`` and the mean
    time per call is printed in microseconds.
    '''
    # The benchmarked callables are handed to timeit through ``globals``
    # instead of ``from __main__ import ...``, so this also works when the
    # module is imported rather than executed as the entry script.
    t0 = timeit.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        globals={'x': x, 'batched_dot_mul_sum': batched_dot_mul_sum},
    )
    t1 = timeit.Timer(
        stmt='batched_dot_bmm(x, x)',
        globals={'x': x, 'batched_dot_bmm': batched_dot_bmm},
    )
    # timeit() returns total seconds for ROUNDS runs; convert to us per call.
    print(f'mul_sum(x, x): {t0.timeit(ROUNDS) / ROUNDS * 1e6:>5.1f} us')
    print(f'bmm(x, x) : {t1.timeit(ROUNDS) / ROUNDS * 1e6:>5.1f} us')
# Run the benchmark only when this file is executed as a script.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment