Skip to content

Instantly share code, notes, and snippets.

@msaroufim
Last active April 16, 2025 21:31
Show Gist options
  • Save msaroufim/305bec86209f61b24f327821e1d6aa18 to your computer and use it in GitHub Desktop.
"""
Limitations
1. Cannot do heavy templating, cannot use thrust for reductions
2. Cannot import any host includes
Thank you @malfet!
"""
import ctypes
import os
import time

import torch

# TODO: Instead initialize cuda context instead of doing this hack
# Allocating a tensor forces PyTorch to create the CUDA context that the
# driver-API calls below (cuModuleLoadData, cuLaunchKernel) require.
torch.randn(1, device="cuda")


def _load_nvrtc():
    """Locate and load libnvrtc.

    Tries, in order: the dynamic-loader search path, the nvidia pip wheel
    bundled in the *active* environment, and finally the original
    hard-coded path so existing setups keep working.
    """
    candidates = ['libnvrtc.so', 'libnvrtc.so.12']
    try:
        import nvidia.cuda_nvrtc  # present when nvidia-cuda-nvrtc wheel is installed
        lib_dir = os.path.join(os.path.dirname(nvidia.cuda_nvrtc.__file__), 'lib')
        candidates += [os.path.join(lib_dir, f)
                       for f in sorted(os.listdir(lib_dir))
                       if f.startswith('libnvrtc.so')]
    except (ImportError, OSError):
        pass
    # Last-resort legacy path from the original script.
    candidates.append('/home/marksaroufim/.conda/envs/nv/lib/python3.10/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so.12')
    for cand in candidates:
        try:
            return ctypes.CDLL(cand)
        except OSError:
            continue
    raise OSError('Could not locate libnvrtc; install nvidia-cuda-nvrtc '
                  'or add libnvrtc.so to LD_LIBRARY_PATH')


# Load CUDA driver and NVRTC
# TODO: Need a better solution for this problem when we ship the right nvrtc
libnvrtc = _load_nvrtc()
libcuda = ctypes.CDLL('libcuda.so')

# CUDA status codes: both NVRTC and the driver API use 0 for success.
NVRTC_SUCCESS = 0
CUDA_SUCCESS = 0
# Helper: check NVRTC errors
def check_nvrtc(result):
    """Raise RuntimeError with NVRTC's error string if `result` != NVRTC_SUCCESS.

    Note: unlike cuGetErrorString, nvrtcGetErrorString has no out-parameter —
    it *returns* a `const char*` directly. The original code passed a byref()
    out-argument (the cuGetErrorString convention), which is wrong for NVRTC;
    we also must declare restype as c_char_p so ctypes does not truncate the
    returned pointer to a C int (the ctypes default).
    """
    if result != NVRTC_SUCCESS:
        libnvrtc.nvrtcGetErrorString.restype = ctypes.c_char_p
        msg = libnvrtc.nvrtcGetErrorString(ctypes.c_int(result))
        raise RuntimeError(f'NVRTC error: {msg.decode()}')
# Helper: check CUDA errors
def check_cuda(result):
    """Raise RuntimeError with the driver's error string if `result` != CUDA_SUCCESS."""
    if result == CUDA_SUCCESS:
        return
    # cuGetErrorString fills an out-parameter with a pointer to a static
    # NUL-terminated message owned by the driver.
    msg = ctypes.c_char_p()
    libcuda.cuGetErrorString(result, ctypes.byref(msg))
    raise RuntimeError(f'CUDA error: {msg.value.decode()}')
# CUDA kernel as string
# NVRTC compiles this source at runtime. `extern "C"` keeps the symbol name
# unmangled so cuModuleGetFunction can look it up as b"vector_add".
# One thread per element; the bounds check guards the final partial block.
kernel = b"""
extern "C"
__global__ void vector_add(const float* a, const float* b, float* c, int n) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < n)
c[i] = a[i] + b[i];
}
"""
# 1. Compile with NVRTC
start_time = time.time()
prog = ctypes.c_void_p()
check_nvrtc(libnvrtc.nvrtcCreateProgram(
    ctypes.byref(prog),
    kernel,              # CUDA C++ source
    b'vector_add.cu',    # name used only in compiler diagnostics
    0, None, None))      # no extra headers / include names

# TODO: We should figure out how to pass in the version of CUDA we want to pass in exactly to compile to the right arch
res = libnvrtc.nvrtcCompileProgram(prog, 0, None)
if res != NVRTC_SUCCESS:
    # Get log: surface the full compiler diagnostics, not just a status code.
    log_size = ctypes.c_size_t()
    libnvrtc.nvrtcGetProgramLogSize(prog, ctypes.byref(log_size))
    log = ctypes.create_string_buffer(log_size.value)
    libnvrtc.nvrtcGetProgramLog(prog, log)
    raise RuntimeError(f"Compilation failed:\n{log.value.decode()}")

# Get PTX (size first, then the text itself)
ptx_size = ctypes.c_size_t()
check_nvrtc(libnvrtc.nvrtcGetPTXSize(prog, ctypes.byref(ptx_size)))
ptx = ctypes.create_string_buffer(ptx_size.value)
check_nvrtc(libnvrtc.nvrtcGetPTX(prog, ptx))
libnvrtc.nvrtcDestroyProgram(ctypes.byref(prog))

# 3. Load PTX module
# NOTE(fix): torch.cuda.Stream is not a context manager; the canonical way
# to make a stream current is torch.cuda.stream(<stream>).
module = ctypes.c_void_p()
with torch.cuda.stream(torch.cuda.default_stream()):
    check_cuda(libcuda.cuModuleLoadData(ctypes.byref(module), ptx))
func = ctypes.c_void_p()
check_cuda(libcuda.cuModuleGetFunction(ctypes.byref(func), module, b"vector_add"))
compile_time = time.time() - start_time
print(f"Kernel compiled in {compile_time:.2f} seconds")
# 4. Prepare data
N = 512
a = torch.rand(N, device="cuda")
b = torch.rand(N, device="cuda")
c = torch.empty_like(a)
# Raw device pointers for the kernel arguments; the tensors above keep the
# underlying allocations alive.
d_a = ctypes.c_void_p(a.data_ptr())
d_b = ctypes.c_void_p(b.data_ptr())
d_c = ctypes.c_void_p(c.data_ptr())

# 5. Launch kernel
threads = 256
blocks = (N + threads - 1) // threads  # ceil-divide so every element is covered
int_arg = ctypes.c_int(N)
# cuLaunchKernel's kernelParams is void**: an array where each entry points
# at the storage holding one argument value.
arg1 = ctypes.byref(d_a)
arg2 = ctypes.byref(d_b)
arg3 = ctypes.byref(d_c)
arg4 = ctypes.byref(int_arg)
args = (ctypes.c_void_p * 4)(
    ctypes.cast(arg1, ctypes.c_void_p),
    ctypes.cast(arg2, ctypes.c_void_p),
    ctypes.cast(arg3, ctypes.c_void_p),
    ctypes.cast(arg4, ctypes.c_void_p),
)
check_cuda(libcuda.cuLaunchKernel(func,
                                  blocks, 1, 1,           # grid dims
                                  threads, 1, 1,          # block dims
                                  0, None, args, None))   # smem, stream, params, extra
# The launch went to the legacy default stream, outside PyTorch's stream
# bookkeeping — synchronize before reading results back on the host.
torch.cuda.synchronize()

# Print a few results
print("Result (first 5):", c[:5])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment