-
-
Save stas00/b7e8da20ff999de0eb74f022ccf4fd00 to your computer and use it in GitHub Desktop.
Measure performance difference of `torch.mm` vs `torch.bmm`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Benchmark relative performance of torch.mm and torch.bmm with single batch
import time

import torch


def benchmark_fn(fn, args, warmup=5, cycles=300, use_kineto=False) -> float:
    """Return the average wall-clock time of one ``fn(*args)`` call, in microseconds.

    Args:
        fn: callable to benchmark (e.g. ``torch.mm``).
        args: tuple of positional arguments passed to ``fn``.
        warmup: number of untimed calls before measurement starts.
        cycles: number of timed calls averaged over.
        use_kineto: if True, profile a single call with the kineto profiler
            and return the summed CUDA kernel time instead of wall-clock time.

    Returns:
        Average time per call in microseconds.
    """
    if use_kineto:
        # Profile one call and report total CUDA kernel time across all ops.
        with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as p:
            fn(*args)
        return sum(e.cuda_time for e in p.key_averages())

    for _ in range(warmup):
        fn(*args)
    # CUDA kernels launch asynchronously; synchronize so warmup work does not
    # leak into the timed region. Guarded so CPU-only runs also work.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # perf_counter is monotonic and higher-resolution than time.time(),
    # which can jump with wall-clock adjustments.
    begin = time.perf_counter()
    for _ in range(cycles):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    dt = time.perf_counter() - begin
    # Keep full float precision; the original truncated to whole microseconds
    # before dividing, losing sub-microsecond resolution.
    return dt * 1_000_000 / cycles
if __name__ == "__main__":
    print("torch: ", torch.__version__, " device: ", torch.cuda.get_device_name(0))

    # Shapes (m, n, k) to benchmark, grouped as: dot-product-like rows,
    # square power-of-two, square off-power-of-two, and tall-skinny matrices.
    # NOTE: a previous hand-picked shape list that was immediately overwritten
    # (dead code) has been removed.
    msizes = [(1, 1, 2**x) for x in range(12, 18)]
    msizes += [(2**x, 2**x, 2**x) for x in range(7, 12)]
    msizes += [(2**x + 1, 2**x - 1, 2**x + 1) for x in range(7, 12)]
    msizes += [(2**x + 1, 3, 2**x + 1) for x in range(12, 17)]
    msizes += [(2**x + 1, 5, 2**x + 1) for x in range(12, 17)]
    msizes += [(2**x + 1, 7, 2**x + 1) for x in range(12, 17)]

    print("| Shape | bmm_time | mm_time | slow down (%) |")
    print("| -------------- | --------- | --------- | ------------- |")
    for (m, n, k) in msizes:
        a = torch.rand((m, k), device='cuda')
        b = torch.rand((k, n), device='cuda')
        # bmm with a batch of 1 computes the same product as mm; measure how
        # much overhead the batched path adds.
        bmm_time = benchmark_fn(torch.bmm, (a.unsqueeze(0), b.unsqueeze(0)))
        mm_time = benchmark_fn(torch.mm, (a, b))
        shape_str = f"{m}x{n}x{k}"
        print(f"| {shape_str :^14} | {bmm_time :^9.2f} | {mm_time :^9.2f} | {100.0*(bmm_time-mm_time)/mm_time :^13.2f} |")
        # Sanity check: both code paths compute the same product.
        assert torch.allclose(torch.bmm(a.unsqueeze(0), b.unsqueeze(0)).squeeze(0), torch.mm(a, b))
# Running the above script on an A100 with torch-2.1.1+cu118 produces the following output:
# torch: 2.1.1+cu118 device: NVIDIA A100-SXM4-40GB
# | Shape | bmm_time | mm_time | slow down (%) |
# | -------------- | --------- | --------- | ------------- |
# | 1x1x4096 | 12.38 | 11.96 | 3.48 |
# | 1x1x8192 | 12.26 | 11.84 | 3.55 |
# | 1x1x16384 | 11.81 | 11.66 | 1.29 |
# | 1x1x32768 | 12.00 | 11.81 | 1.61 |
# | 1x1x65536 | 14.82 | 15.05 | -1.48 |
# | 1x1x131072 | 12.02 | 11.77 | 2.15 |
# | 128x128x128 | 9.47 | 9.69 | -2.24 |
# | 256x256x256 | 12.66 | 12.60 | 0.50 |
# | 512x512x512 | 27.34 | 27.31 | 0.10 |
# | 1024x1024x1024 | 129.59 | 129.48 | 0.08 |
# | 2048x2048x2048 | 973.63 | 973.04 | 0.06 |
# | 129x127x129 | 9.56 | 8.97 | 6.62 |
# | 257x255x257 | 12.85 | 12.78 | 0.52 |
# | 513x511x513 | 28.99 | 28.98 | 0.05 |
# | 1025x1023x1025 | 137.92 | 137.76 | 0.11 |
# | 2049x2047x2049 | 982.34 | 982.32 | 0.00 |
# | 4097x3x4097 | 86.94 | 86.91 | 0.03 |
# | 8193x3x8193 | 384.38 | 384.54 | -0.04 |
# | 16385x3x16385 | 1106.25 | 1107.35 | -0.10 |
# | 32769x3x32769 | 4736.79 | 4737.19 | -0.01 |
# | 65537x3x65537 | 17368.65 | 17371.21 | -0.01 |
# | 4097x5x4097 | 87.50 | 87.49 | 0.01 |
# | 8193x5x8193 | 302.27 | 302.29 | -0.00 |
# | 16385x5x16385 | 1107.69 | 1107.65 | 0.00 |
# | 32769x5x32769 | 4743.02 | 4743.13 | -0.00 |
# | 65537x5x65537 | 17393.08 | 17392.32 | 0.00 |
# | 4097x7x4097 | 87.58 | 87.60 | -0.02 |
# | 8193x7x8193 | 302.42 | 302.45 | -0.01 |
# | 16385x7x16385 | 1106.55 | 1107.34 | -0.07 |
# | 32769x7x32769 | 4746.99 | 4746.58 | 0.01 |
# | 65537x7x65537 | 17406.08 | 17424.31 | -0.10 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment