Skip to content

Instantly share code, notes, and snippets.

@liangfu
Created January 9, 2025 04:34
Show Gist options
  • Save liangfu/83af1873c652f26431bc9026eaa6a616 to your computer and use it in GitHub Desktop.
Save liangfu/83af1873c652f26431bc9026eaa6a616 to your computer and use it in GitHub Desktop.
Benchmark xla scatter with torch-xla
import time
import torch
import torch_xla.core.xla_model as xm
N = 128
n_iters = 100
def main():
device = xm.xla_device()
src = torch.arange(1, 2*N+1).reshape((2, N)).to(device=device)
print(src)
# V1
index = torch.tensor([[0,1]]).to(device=device).reshape(2,1).expand(2,N)
input = torch.zeros(3, N, dtype=src.dtype).to(device=device)
input = torch.scatter(input, 0, index, src) # warmup
print(f"V1 {index=}")
print(f"V1 {input=}")
with TimedRegion("V1"):
for _ in range(n_iters):
input = torch.scatter(input, 0, index, src)
# V2
index = torch.tensor([0,1]).to(device=device)
input = torch.zeros(3, N, dtype=src.dtype).to(device=device)
input[index,:] = src # warmup
print(f"V2 {index=}")
print(f"V2 {input=}")
with TimedRegion("V2"):
for _ in range(n_iters):
input[index,:] = src
# V3
index = torch.tensor([0,1]).to(device=device)
input = torch.zeros(3, N, dtype=src.dtype).to(device=device)
input.index_copy_(0, index, src) # warmup
print(f"V3 {index=}")
print(f"V3 {input=}")
with TimedRegion("V3"):
for _ in range(n_iters):
input.index_copy_(0, index, src)
class TimedRegion:
def __init__(self, name):
self.name = name
self.start_time = None
def __enter__(self):
self.start_time = time.time()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
end_time = time.time()
elapsed_time = end_time - self.start_time
print(f"{self.name}: {elapsed_time*1000.0:.1f} ms")
if __name__=="__main__":
main()
@liangfu
Copy link
Author

liangfu commented Jan 9, 2025

V1: 0.4 ms
V2: 3.5 ms
V3: 2.4 ms

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment