Skip to content

Instantly share code, notes, and snippets.

@liangfu
Created March 13, 2025 16:34
Show Gist options
  • Save liangfu/e9c072b21b853975d9dc82647eb05d3d to your computer and use it in GitHub Desktop.
Save liangfu/e9c072b21b853975d9dc82647eb05d3d to your computer and use it in GitHub Desktop.
Evaluate consistency when mixing eager execution with torch.compile()
import torch
import os
import torch_xla.core.xla_model as xm
def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Write rows of ``key``/``value`` into the caches at ``slot_mapping``.

    Each of the four tensors has dims 0..2 collapsed into a single leading
    dim.  The collapsed ``key`` and ``value`` rows are then copied in place,
    via ``index_copy_`` along dim 0, into the collapsed caches at the row
    indices listed in ``slot_mapping``.  Nothing is returned; the in-place
    write is the only effect.

    NOTE(review): ``flatten`` returns a view only when the input's layout
    permits it, so the write reaches the caller's cache tensors only in that
    case -- presumably callers always pass contiguous caches; confirm.
    """
    # Flag both caches through torch_xla's buffer-donor op before they are
    # modified.  Presumably this allows XLA to reuse the cache storage for
    # the updated result -- confirm against the torch_xla documentation.
    for cache in (key_cache, value_cache):
        torch.ops.xla.dynamo_set_buffer_donor_(cache, True)

    # Collapse the three leading dims of the inputs, then of the caches.
    flat_key = key.flatten(start_dim=0, end_dim=2)
    flat_value = value.flatten(start_dim=0, end_dim=2)
    flat_key_cache = key_cache.flatten(start_dim=0, end_dim=2)
    flat_value_cache = value_cache.flatten(start_dim=0, end_dim=2)

    # Row i of the flattened input lands in cache row slot_mapping[i].
    flat_key_cache.index_copy_(0, slot_mapping, flat_key)
    flat_value_cache.index_copy_(0, slot_mapping, flat_value)
if __name__ == '__main__':
    # Repro script: perform the same KV-cache write three times -- twice
    # through torch.compile (openxla backend) and once eagerly -- printing a
    # slice of the key cache after each call so the outputs can be compared.
    device = xm.xla_device()

    # Cache geometry.
    num_blocks = 128
    block_size = 128
    num_kv_heads = 4
    head_size = 64

    # A single zero-initialised buffer; unpacking along the leading dim of
    # size 2 yields the key cache (index 0) and the value cache (index 1).
    kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
    kv_cache = torch.zeros(kv_cache_shape, dtype=torch.float, device=device)
    key_cache, value_cache = kv_cache

    # Random key/value inputs drawn from [-1, 1), split along dim 2.
    kv = torch.empty(
        1, 3, 2, num_kv_heads, head_size, dtype=torch.float, device=device
    )
    kv.uniform_(-1, 1)
    key, value = kv.unbind(dim=2)

    # 12 target rows: one per row of key/value once write_to_kv_cache has
    # flattened dims 0..2 (1 * 3 * num_kv_heads).
    slot_mapping = torch.tensor(
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
        dtype=torch.int32,
        device=device,
    ).long()

    compiled_callable = torch.compile(
        write_to_kv_cache,
        backend="openxla",
        fullgraph=False,
        dynamic=False,
    )

    # First compiled call.
    compiled_callable(key, value, key_cache, value_cache, slot_mapping)
    print(f"k/v cache use torch compile {key_cache[0][:5]}")

    # Second compiled call with the same arguments.
    compiled_callable(key, value, key_cache, value_cache, slot_mapping)
    print(f"k/v cache use torch compile again func {key_cache[0][:5]}")

    # Eager call of the uncompiled function, for comparison.
    write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
    print(f"k/v cache use original func {key_cache[0][:5]}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment