Created
March 11, 2025 21:26
-
-
Save liangfu/f07bed699662fee29ff594dbc6aa208b to your computer and use it in GitHub Desktop.
Evaluate torch.split and slice operator support on openxla backend (with torch_neuronx)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import pytest | |
import torch | |
import torch_neuronx | |
import torch_xla.core.xla_model as xm | |
def _assert_close(actual, expected, *, rtol=1e-5, atol=1e-5):
    """Assert that ``actual`` and ``expected`` are element-wise close via ``torch.allclose``."""
    assert torch.allclose(actual, expected, rtol=rtol, atol=atol)


def _check_split_consistency(batch_size, seq_len, q_size, kv_size, use_torch_compile):
    """Compare slice-based and split-based QKV partitioning against eager CPU references.

    Both implementations are compiled either with ``torch.compile`` on the
    ``openxla`` backend (``use_torch_compile=True``) or with
    ``torch_neuronx.trace`` (``use_torch_compile=False``), run on the same
    random input, and their outputs are checked against explicit CPU slicing
    and against each other.
    """
    # Both implementations pack (q, k, v) by concatenating along dim 0, which
    # only works when the three pieces agree on every other dimension.
    assert q_size == kv_size, (
        f"dim-0 packing requires q_size == kv_size, got q_size={q_size}, kv_size={kv_size}"
    )
    split_sizes = [q_size, kv_size, kv_size]

    # Acquire the XLA device up front, as before, for both compile paths.
    device = xm.xla_device()

    # Random input generated on CPU so it doubles as the reference source.
    qkv_cpu = torch.randn(batch_size, seq_len, q_size + 2 * kv_size)
    qkv_xla = qkv_cpu.to(device=device)

    # Eager CPU references: explicit slicing must agree with torch.split.
    with torch.inference_mode():
        refs = (
            qkv_cpu[:, :, :q_size],
            qkv_cpu[:, :, q_size:q_size + kv_size],
            qkv_cpu[:, :, q_size + kv_size:q_size + 2 * kv_size],
        )
        for ref, split_ref in zip(refs, qkv_cpu.split(split_sizes, dim=-1)):
            _assert_close(ref, split_ref)

    def slice_impl(qkv):
        q = qkv[:, :, :q_size]
        k = qkv[:, :, q_size:q_size + kv_size]
        v = qkv[:, :, q_size + kv_size:q_size + 2 * kv_size]
        # Pack into a single tensor along dim 0 (explicit; was the implicit default).
        return torch.concat([q, k, v], dim=0)

    def split_impl(qkv):
        q, k, v = qkv.split(split_sizes, dim=-1)
        return torch.concat([q, k, v], dim=0)

    if use_torch_compile:
        compiled_slice = torch.compile(slice_impl, backend='openxla', fullgraph=True, dynamic=False)
        compiled_split = torch.compile(split_impl, backend='openxla', fullgraph=True, dynamic=False)
        slice_out = compiled_slice(qkv_xla)
        split_out = compiled_split(qkv_xla)
    else:
        compiled_slice = torch_neuronx.trace(slice_impl, (qkv_cpu,))
        compiled_split = torch_neuronx.trace(split_impl, (qkv_cpu,))
        slice_out = compiled_slice(qkv_cpu)
        split_out = compiled_split(qkv_cpu)

    # Bring results back to CPU and undo the dim-0 packing (one chunk per q/k/v).
    slice_out = slice_out.cpu()
    split_out = split_out.cpu()
    slice_pieces = slice_out.split(batch_size, dim=0)
    split_pieces = split_out.split(batch_size, dim=0)
    assert len(slice_pieces) == len(refs), f"expected 3 packed pieces, got {len(slice_pieces)}"
    assert len(split_pieces) == len(refs), f"expected 3 packed pieces, got {len(split_pieces)}"

    # Each compiled implementation must match the CPU reference.
    for ref, from_slice, from_split in zip(refs, slice_pieces, split_pieces):
        _assert_close(from_slice, ref)
        _assert_close(from_split, ref)

    # ...and the two compiled implementations must match each other.
    _assert_close(split_out, slice_out)


@pytest.mark.parametrize(
    "batch_size,seq_len,q_size,kv_size,use_torch_compile,disable_functionalization",
    [
        (2, 128, 32, 32, False, True),
        (2, 128, 32, 32, True, True),
        (2, 128, 32, 32, False, False),
        (2, 128, 32, 32, True, False),
    ],
)
def test_split_consistency(batch_size, seq_len, q_size, kv_size, use_torch_compile, disable_functionalization):
    """Check that slice- and split-based QKV partitioning agree on the XLA/Neuron backend.

    ``XLA_DISABLE_FUNCTIONALIZATION`` is set to "1" or "0" according to
    ``disable_functionalization`` for the duration of this case, and the
    previous value (or its absence) is restored afterwards so cases do not
    leak configuration into each other.
    """
    env_key = "XLA_DISABLE_FUNCTIONALIZATION"
    previous = os.environ.get(env_key)
    # NOTE(review): torch_xla may read this flag only once when its runtime
    # starts; if so, changing it mid-process has no effect on later cases.
    # Confirm, or run each configuration in its own process.
    os.environ[env_key] = "1" if disable_functionalization else "0"
    try:
        _check_split_consistency(batch_size, seq_len, q_size, kv_size, use_torch_compile)
    finally:
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.