@msaroufim
Created August 25, 2025 21:49
import torch
# CUDA kernel with inline PTX
kernel_source = """
__global__ void vector_add(const float* a, const float* b, float* c, int n) {
    // Compute the global thread index (blockIdx.x * blockDim.x + threadIdx.x) via PTX
    int idx;
    asm("mov.u32 %0, %%ctaid.x;" : "=r"(idx));
    int tid;
    asm("mov.u32 %0, %%tid.x;" : "=r"(tid));
    int ntid;
    asm("mov.u32 %0, %%ntid.x;" : "=r"(ntid));
    asm("mad.lo.s32 %0, %1, %2, %3;" : "=r"(idx) : "r"(idx), "r"(ntid), "r"(tid));
    if (idx >= n) return;

    float val_a, val_b, result;
    // Load values using PTX
    asm("ld.global.f32 %0, [%1];" : "=f"(val_a) : "l"(&a[idx]));
    asm("ld.global.f32 %0, [%1];" : "=f"(val_b) : "l"(&b[idx]));
    // Add using PTX
    asm("add.f32 %0, %1, %2;" : "=f"(result) : "f"(val_a), "f"(val_b));
    // Store result using PTX
    asm("st.global.f32 [%0], %1;" : : "l"(&c[idx]), "f"(result));
}
"""
# Create tensors
n = 1000
a = torch.ones(n, device='cuda')
b = torch.ones(n, device='cuda') * 2
c = torch.zeros(n, device='cuda')
# Compile kernel
add_kernel = torch.cuda._compile_kernel(kernel_source, "vector_add")
# Launch kernel
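# Ceiling division: enough 256-thread blocks to cover all n elements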
blocks = (n + 255) // 256
add_kernel(
    grid=(blocks, 1, 1),
    block=(256, 1, 1),
    args=[a, b, c, n]
)
# Check result
print(f"a[0] = {a[0].item()}")
print(f"b[0] = {b[0].item()}")
print(f"c[0] = {c[0].item()}") # Should be 3.0