@liangfu
Created March 11, 2025 21:26
Evaluate torch.split and slice operator support on the openxla backend (with torch_neuronx).
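At its core, the test checks that two ways of carving q, k, and v out of a fused QKV projection agree: direct slicing along the last dimension versus torch.split. A minimal CPU-only illustration (shapes taken from the parametrization below):

    import torch

    qkv = torch.randn(2, 128, 96)  # batch=2, seq_len=128, q_size + 2 * kv_size = 96
    q1, k1, v1 = qkv[:, :, :32], qkv[:, :, 32:64], qkv[:, :, 64:96]  # slice
    q2, k2, v2 = qkv.split([32, 32, 32], dim=-1)                     # split
    for a, b in ((q1, q2), (k1, k2), (v1, v2)):
        assert torch.equal(a, b)

The full parametrized test follows.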
import os

import pytest
import torch
import torch_neuronx
import torch_xla.core.xla_model as xm


@pytest.mark.parametrize("batch_size,seq_len,q_size,kv_size,use_torch_compile,disable_functionalization", [
    (2, 128, 32, 32, False, True),
    (2, 128, 32, 32, True, True),
    (2, 128, 32, 32, False, False),
    (2, 128, 32, 32, True, False),
])
def test_split_consistency(batch_size, seq_len, q_size, kv_size, use_torch_compile, disable_functionalization):
    # Set the functionalization flag explicitly for both cases so one
    # parametrization cannot leak its setting into the next. (Depending on
    # the torch_xla version, this flag may need to be set before torch_xla
    # is imported to take effect.)
    os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1" if disable_functionalization else "0"

    # Get the XLA device
    device = xm.xla_device()

    # Generate random input on CPU first for reference
    qkv_cpu = torch.randn(batch_size, seq_len, q_size + 2 * kv_size)

    # Create the XLA device tensor directly
    qkv_xla = qkv_cpu.to(device=device)

    # Generate reference outputs on CPU using both implementations
    with torch.inference_mode():
        # First implementation (slice)
        q_ref = qkv_cpu[:, :, :q_size]
        k_ref = qkv_cpu[:, :, q_size:q_size + kv_size]
        v_ref = qkv_cpu[:, :, q_size + kv_size:q_size + (2 * kv_size)]

        # Second implementation (split)
        q_split_ref, k_split_ref, v_split_ref = qkv_cpu.split([q_size, kv_size, kv_size], dim=-1)

        # Verify the CPU implementations match
        assert torch.allclose(q_ref, q_split_ref, rtol=1e-5, atol=1e-5)
        assert torch.allclose(k_ref, k_split_ref, rtol=1e-5, atol=1e-5)
        assert torch.allclose(v_ref, v_split_ref, rtol=1e-5, atol=1e-5)

    # Define the two implementations as functions. Both re-pack q, k, and v by
    # concatenating along dim 0, which requires q_size == kv_size (true for
    # every parametrization above) so the result can be split back apart by
    # batch_size later.
    def slice_impl(qkv):
        q = qkv[:, :, :q_size]
        k = qkv[:, :, q_size:q_size + kv_size]
        v = qkv[:, :, q_size + kv_size:q_size + (2 * kv_size)]
        return torch.concat([q, k, v])

    def split_impl(qkv):
        q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
        return torch.concat([q, k, v])

    if use_torch_compile:
        # Compile both implementations with the openxla backend
        compiled_slice = torch.compile(slice_impl, backend='openxla', fullgraph=True, dynamic=False)
        compiled_split = torch.compile(split_impl, backend='openxla', fullgraph=True, dynamic=False)
        # Run both compiled implementations on the XLA tensor
        qkv_slice_xla = compiled_slice(qkv_xla)
        qkv_split_xla = compiled_split(qkv_xla)
    else:
        # Trace both implementations with torch_neuronx
        compiled_slice = torch_neuronx.trace(slice_impl, (qkv_cpu,))
        compiled_split = torch_neuronx.trace(split_impl, (qkv_cpu,))
        # Run both traced implementations on the CPU tensor
        qkv_slice_xla = compiled_slice(qkv_cpu)
        qkv_split_xla = compiled_split(qkv_cpu)

    # Unpack the concatenated outputs (moved to CPU for comparison)
    q_xla_cpu, k_xla_cpu, v_xla_cpu = qkv_slice_xla.cpu().split(batch_size)
    q_split_xla_cpu, k_split_xla_cpu, v_split_xla_cpu = qkv_split_xla.cpu().split(batch_size)

    # Compare the compiled slice implementation with the CPU reference
    assert torch.allclose(q_xla_cpu, q_ref, rtol=1e-5, atol=1e-5)
    assert torch.allclose(k_xla_cpu, k_ref, rtol=1e-5, atol=1e-5)
    assert torch.allclose(v_xla_cpu, v_ref, rtol=1e-5, atol=1e-5)

    # Compare the compiled split implementation with the CPU reference
    assert torch.allclose(q_ref, q_split_xla_cpu, rtol=1e-5, atol=1e-5)
    assert torch.allclose(k_ref, k_split_xla_cpu, rtol=1e-5, atol=1e-5)
    assert torch.allclose(v_ref, v_split_xla_cpu, rtol=1e-5, atol=1e-5)
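To make the module directly executable, a __main__ guard can dispatch to pytest. A convenience sketch (the filename test_split_consistency.py is an assumption; any pytest-discoverable name works):

    if __name__ == "__main__":
        # Equivalent to running `pytest -v test_split_consistency.py` from the shell.
        raise SystemExit(pytest.main(["-v", __file__]))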