Created
September 13, 2024 13:56
-
-
Save NTT123/76193b931a72e7ab3810f143ac97b020 to your computer and use it in GitHub Desktop.
Inplace RoPE inference kernel
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
In-place RoPE (rotary position embedding) Triton kernel.
""" | |
import triton | |
import triton.language as tl | |
@triton.jit
def _rope_kernel(
    x_ptr, x_row_stride, x_head_stride,
    r_ptr, r_row_stride,
    N, H: tl.constexpr, D: tl.constexpr,
):
    """Rotate each (even, odd) feature pair of ``x`` in place by the angles in ``r``.

    Program ``p`` handles rows ``p, p + num_programs, p + 2 * num_programs, ...``
    up to ``N``. For every row the ``D // 2`` angles are loaded once, converted
    to float32, and reused for all ``H`` heads of that row.

    Addressing uses only the row/head strides passed in; elements along the
    last dimension of both ``x`` and ``r`` are addressed with unit stride.
    No load or store is masked, so each head must hold exactly ``D`` elements
    and each angle row exactly ``D // 2``.
    """
    first_row = tl.program_id(0).to(tl.int64)
    stride_rows = tl.num_programs(0)
    pair_offs = tl.arange(0, D // 2)
    feat_offs = tl.arange(0, D)

    for row in tl.range(first_row, N, step=stride_rows):
        # One angle per feature pair; trig is evaluated in float32.
        theta = tl.load(r_ptr + row * r_row_stride + pair_offs).to(tl.float32)
        cos = tl.cos(theta)
        sin = tl.sin(theta)

        row_ptr = x_ptr + row * x_row_stride
        for head in tl.range(0, H):
            head_ptr = row_ptr + head * x_head_stride

            # Pairs are interleaved in memory: even index first, odd index second.
            even = tl.load(head_ptr + pair_offs * 2)
            odd = tl.load(head_ptr + pair_offs * 2 + 1)

            # 2-D rotation of (even, odd) by theta.
            rot_even = even * cos - odd * sin
            rot_odd = even * sin + odd * cos

            # Re-interleave and write back over the input in its own dtype.
            rotated = tl.interleave(rot_even, rot_odd)
            tl.store(head_ptr + feat_offs, rotated.to(even.dtype))
def rope(x, r):
    """Apply rotary position embedding to ``x`` in place and return it.

    All leading dimensions of ``x`` are flattened so it is treated as
    ``(N, H, D)``; ``r`` is flattened the same way to ``(N, D // 2)`` and holds
    one rotation angle per (even, odd) feature pair of each row. The same
    angles are applied to every head of a row.

    Args:
        x: Input tensor whose last two dimensions are ``H x D``. It is
            overwritten with the rotated values.
        r: Rotation angles whose last dimension is ``D // 2`` and whose
            leading dimensions flatten to the same ``N`` as ``x``.

    Returns:
        A view of ``x`` (sharing its storage) with the original shape.

    Raises:
        AssertionError: If the shapes of ``x`` and ``r`` do not agree.
        ValueError: If the tensors live on different devices, if the last
            dimension of either tensor is not unit-stride, or if ``D // 2``
            is not a power of two.
    """
    shape = x.shape
    x = x.view(-1, shape[-2], shape[-1])
    r = r.view(-1, r.shape[-1])
    N, H, D = x.shape
    N1, D1 = r.shape

    # The kernel is unmasked, so a shape mismatch would read/write out of
    # bounds. Raise explicitly (rather than ``assert``) so the checks survive
    # ``python -O``; AssertionError is kept so existing callers see the same
    # exception type as before.
    if D1 * 2 != D:
        raise AssertionError(
            f"r has {D1} angles per row but x needs D // 2 with D = {D}"
        )
    if N != N1:
        raise AssertionError(f"x has {N} rows but r has {N1}")

    if x.device != r.device:
        raise ValueError(
            f"x and r must be on the same device, got {x.device} and {r.device}"
        )
    # The kernel addresses the last dimension with a hard-coded stride of 1;
    # a strided view would be silently rotated with the wrong elements.
    if x.stride(-1) != 1 or (D1 > 1 and r.stride(-1) != 1):
        raise ValueError("x and r must be contiguous along their last dimension")
    # tl.arange in the kernel only accepts power-of-two extents.
    if D1 < 1 or D1 & (D1 - 1):
        raise ValueError(f"D // 2 must be a power of two, got {D1}")

    # Nothing to rotate: skip the launch entirely.
    if N == 0:
        return x.view(*shape)

    # Each program loops over roughly 32 rows.
    M = max(1, N // 32)
    _rope_kernel[(M,)](
        x, x.stride(0), x.stride(1),
        r, r.stride(0),
        N=N,
        H=H,
        D=D,
    )
    return x.view(*shape)
if __name__ == "__main__":
    import torch

    # Demo setup: new tensors are bfloat16 and allocated on the GPU.
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device('cuda')

    # (rows, heads per row, features per head)
    demo_shape = (128, 32, 128)

    x = torch.randn(*demo_shape)
    # Printed before the call because rope() overwrites x.
    print(f"First few values of input: {x[0, 0, :10]}")

    # One random angle per (even, odd) feature pair of each row.
    r = torch.randn(demo_shape[0], demo_shape[-1] // 2)

    # rotated_x is a view of x, which now holds the rotated values.
    rotated_x = rope(x, r)

    # Wait for outstanding CUDA work before reporting.
    torch.cuda.synchronize()

    print(f"Input shape: {x.shape}")
    print(f"Rotated output shape: {rotated_x.shape}")
    print(f"First few values of rotated output: {rotated_x[0, 0, :10]}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment