Created May 19, 2022 18:06
-
-
Save KohakuBlueleaf/83b36ba1ba50dd4c0f44e10e6af5b81e to your computer and use it in GitHub Desktop.
PyTorch basic benchmark
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import timeit

import numpy as np
import torch


def pick_device(preferred="mps"):
    """Return *preferred* if PyTorch can allocate on it here, else a fallback.

    Any PyTorch device string may be passed (e.g. "cpu", "cuda", "mps").
    Each candidate is probed by allocating an empty tensor on it. A failed
    probe moves on to "cuda", then "mps", then "cpu". Failure means the
    backend is not compiled in, no hardware is present, or the string is
    not a valid device.

    Returns:
        The first device string that accepted an allocation.
    """
    for name in (preferred, "cuda", "mps", "cpu"):
        try:
            torch.empty(0, device=name)
        except (RuntimeError, AssertionError):
            # CPU-only builds raise AssertionError for "cuda"; a missing
            # backend or unknown device type raises RuntimeError.
            continue
        return name
    return "cpu"


# Preferred backend (e.g. "cpu", "cuda", "mps"). If it is unusable on this
# machine, pick_device falls back instead of crashing at `.to()` below.
DEVICE = pick_device("mps")
ROUNDS = 500
NUM = (2048, 2048)

# Input for benchmarking: a float32 matrix of shape NUM, moved to DEVICE.
data = np.random.randn(*NUM).astype(np.float32)
x = torch.from_numpy(data).to(device=DEVICE)

# --- Benchmark functions ---
def batched_dot_mul_sum(a, b):
    """Batched dot product: elementwise product summed over the last axis.

    Args:
        a, b: tensors whose last axis is the vector dimension.

    Returns:
        A NumPy array holding one dot product per leading-axis position.
        It is copied back to host memory via ``.cpu()``.
    """
    products = a * b
    row_dots = products.sum(dim=-1)
    return row_dots.detach().cpu().numpy()
def batched_dot_bmm(a, b):
    """Batched dot product computed by reducing it to a batch matmul.

    Each length-D vector along the last axis of ``a`` becomes a 1 x D row.
    The matching vector of ``b`` becomes a D x 1 column. ``torch.bmm`` then
    yields one 1 x 1 product per pair.

    Args:
        a, b: tensors whose last axis is the vector dimension. Their leading
            axes must contain the same total number of vectors.

    Returns:
        A NumPy array of shape ``a.shape[:-1]``, copied to host memory.
    """
    batch_shape = a.shape[:-1]
    rows = a.reshape(-1, 1, a.shape[-1])
    cols = b.reshape(-1, b.shape[-1], 1)
    dots = torch.bmm(rows, cols)  # shape (N, 1, 1)
    # Restore the batch shape rather than flattening everything. This keeps
    # the output identical in shape to batched_dot_mul_sum for inputs with
    # more than one leading axis.
    return dots.reshape(batch_shape).detach().cpu().numpy()
# Ensure that both functions compute the same output before timing them.
# np.allclose uses the same default tolerances as torch.allclose
# (rtol=1e-05, atol=1e-08), so the comparison itself is unchanged.
# Raise explicitly because a bare `assert` is stripped under `python -O`.
if not np.allclose(batched_dot_mul_sum(x, x), batched_dot_bmm(x, x)):
    raise AssertionError(
        'batched_dot_mul_sum and batched_dot_bmm produced different results'
    )
def main(rounds=None):
    """Time both batched-dot implementations on ``x`` and print mean us/call.

    Args:
        rounds: number of timed calls per implementation. Defaults to the
            module-level ``ROUNDS``.

    Raises:
        ValueError: if ``rounds`` is not positive (the mean would divide by 0).
    """
    rounds = ROUNDS if rounds is None else rounds
    if rounds <= 0:
        raise ValueError(f'rounds must be positive, got {rounds}')

    # Hand everything to timeit through `globals` instead of
    # `from __main__ import ...`. That import fails whenever this module is
    # imported and main() is called from another entry point.
    namespace = {
        'x': x,
        'batched_dot_mul_sum': batched_dot_mul_sum,
        'batched_dot_bmm': batched_dot_bmm,
    }
    t0 = timeit.Timer(stmt='batched_dot_mul_sum(x, x)', globals=namespace)
    t1 = timeit.Timer(stmt='batched_dot_bmm(x, x)', globals=namespace)
    print(f'mul_sum(x, x): {t0.timeit(rounds) / rounds * 1e6:>5.1f} us')
    print(f'bmm(x, x) : {t1.timeit(rounds) / rounds * 1e6:>5.1f} us')
if __name__ == "__main__":
    # Run the timing only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment