@liangfu
Created January 14, 2025 07:19
import torch
import os
import depyf
import torch_xla.core.xla_model as xm

# Neuron compiler flags and a local compile cache directory.
os.environ["NEURON_CC_FLAGS"] = " --model-type=transformer -O1 "
os.environ["NEURON_COMPILE_CACHE_URL"] = os.path.join(os.getcwd(), "_compile_cache")


@torch.compiler.allow_in_graph
def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    # Mark the caches as buffer donors so XLA can update them in place
    # instead of allocating fresh output buffers.
    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

    # (1, num_tokens, 1, num_kv_heads, head_size) -> (num_tokens, num_kv_heads, head_size)
    key = key.flatten(0, 2)
    value = value.flatten(0, 2)
    # (1, num_blocks, block_size, num_kv_heads, head_size) -> (num_slots, num_kv_heads, head_size)
    key_cache = key_cache.flatten(0, 2)
    value_cache = value_cache.flatten(0, 2)

    # Scatter each token's key/value into its slot in the flattened caches.
    key_cache.index_copy_(0, slot_mapping, key)
    value_cache.index_copy_(0, slot_mapping, value)


if __name__ == '__main__':
    device = xm.xla_device()

    num_blocks = 16
    block_size = 8
    num_kv_heads = 4
    head_size = 4
    kv_cache_shape = (1, num_blocks, block_size, num_kv_heads, head_size)
    key_cache = torch.zeros(kv_cache_shape, dtype=torch.float, device=device)
    value_cache = torch.zeros(kv_cache_shape, dtype=torch.float, device=device)

    # 12 tokens worth of new keys/values, written to slots 0..11.
    num_heads = 4
    key = torch.empty(1, 12, 1, num_kv_heads, head_size, dtype=torch.float, device=device)
    value = torch.empty(1, 12, 1, num_kv_heads, head_size, dtype=torch.float, device=device)
    key.uniform_(-1, 1)
    value.uniform_(-1, 1)
    slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
                                dtype=torch.int32, device=device).long()
    print(f"{key=}, {value=}, {key_cache=}, {value_cache=}, {slot_mapping=}")

    print("before compilation")
    with depyf.prepare_debug("./tmp"):
        compiled_callable = torch.compile(write_to_kv_cache,
                                          backend="openxla",
                                          fullgraph=True)
    print("done compilation")

    compiled_callable(key, value, key_cache, value_cache, slot_mapping)
    print(f"k/v cache use torch compile {key_cache[0][:5]}")

    compiled_callable(key.to(device=device), value.to(device=device),
                      key_cache.to(device=device), value_cache.to(device=device),
                      slot_mapping.to(device=device))
    print(f"k/v cache use torch compile again func {key_cache[0][:5]}")

    write_to_kv_cache(key.to(device=device), value.to(device=device),
                      key_cache.to(device=device), value_cache.to(device=device),
                      slot_mapping.to(device=device))
    print(f"k/v cache use original func {key_cache[0][:5]}")