@malfet
Last active April 16, 2025 21:47
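Compile and launch a CUDA kernel at runtime with NVRTC and the CUDA driver API through ctypes, using PyTorch tensors for device memory.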
import ctypes
import torch
import time
def nvrtc_compile(source: bytes) -> str:
    from ctypes import CDLL, c_void_p, c_char_p, c_size_t, byref, create_string_buffer
    libnvrtc = CDLL('libnvrtc.so')
    # nvrtcGetErrorString returns a const char* describing a result code
    libnvrtc.nvrtcGetErrorString.restype = c_char_p

    def get_error_string(res: int) -> str:
        return libnvrtc.nvrtcGetErrorString(res).decode()

    prog = ctypes.c_void_p()
    res = libnvrtc.nvrtcCreateProgram(byref(prog), source, b"nvrtc.cu", 0, None, None)
    if res != 0:
        raise RuntimeError(f"Can't create program: {get_error_string(res)}")
    res = libnvrtc.nvrtcCompileProgram(prog, 0, None)
    if res != 0:
        raise RuntimeError(f"Can't compile: {get_error_string(res)}")
    ptx_size = c_size_t()
    res = libnvrtc.nvrtcGetPTXSize(prog, byref(ptx_size))
    if res != 0:
        raise RuntimeError(f"Can't get PTX size: {get_error_string(res)}")
    ptx = create_string_buffer(ptx_size.value)
    res = libnvrtc.nvrtcGetPTX(prog, ptx)
    if res != 0:
        raise RuntimeError(f"Can't get PTX: {get_error_string(res)}")
    libnvrtc.nvrtcDestroyProgram(ctypes.byref(prog))
    return ptx.value.decode()
# Load the CUDA driver API
libcuda = ctypes.CDLL('libcuda.so')
CUDA_SUCCESS = 0

# Helper: check CUDA driver API results and raise on error
def check_cuda(result):
    if result != CUDA_SUCCESS:
        err_str = ctypes.c_char_p()
        libcuda.cuGetErrorString(result, ctypes.byref(err_str))
        raise RuntimeError(f'CUDA error: {err_str.value.decode()}')
# CUDA kernel as string
kernel = b"""
extern "C"
__global__ void vector_add(const float* a, const float* b, float* c, int n) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < n)
c[i] = a[i] + b[i];
}
"""
# 1. Compile with NVRTC
start_time = time.time()
ptx = nvrtc_compile(kernel)
# 2. Load the PTX module and get the kernel function
module = ctypes.c_void_p()
with torch.cuda.default_stream():
    check_cuda(libcuda.cuModuleLoadData(ctypes.byref(module), ptx.encode("latin-1")))
    func = ctypes.c_void_p()
    check_cuda(libcuda.cuModuleGetFunction(ctypes.byref(func), module, b"vector_add"))
compile_time = time.time() - start_time
print(f"Kernel compiled in {compile_time:.2f} seconds")
# 3. Prepare data
N = 512
a = torch.rand(N, device="cuda")
b = torch.rand(N, device="cuda")
c = torch.empty_like(a)
d_a = ctypes.c_void_p(a.data_ptr())
d_b = ctypes.c_void_p(b.data_ptr())
d_c = ctypes.c_void_p(c.data_ptr())
# 4. Launch kernel
threads = 256
blocks = (N + threads - 1) // threads
# cuLaunchKernel expects kernelParams to be an array of pointers to the actual
# arguments, so take the address of each ctypes value (device pointers and the int)
int_arg = ctypes.c_int(N)
arg1 = ctypes.byref(d_a)
arg2 = ctypes.byref(d_b)
arg3 = ctypes.byref(d_c)
arg4 = ctypes.byref(int_arg)
args = (ctypes.c_void_p * 4)(
ctypes.cast(arg1, ctypes.c_void_p),
ctypes.cast(arg2, ctypes.c_void_p),
ctypes.cast(arg3, ctypes.c_void_p),
ctypes.cast(arg4, ctypes.c_void_p)
)
check_cuda(libcuda.cuLaunchKernel(func,
                                  blocks, 1, 1,   # grid dimensions
                                  threads, 1, 1,  # block dimensions
                                  0, None,        # shared memory bytes, stream (default)
                                  args, None))    # kernel params, extra options
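# Not in the original gist: block until the kernel has finished. The device-to-host
# copy triggered by print() below already serializes on the default stream, but an
# explicit sync makes the ordering obvious.
torch.cuda.synchronize()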
# Print a few results
print("Result (first 5):", c[:5])