import ctypes
import time

import torch


def nvrtc_compile(source: bytes) -> str:
    from ctypes import CDLL, c_void_p, c_char_p, c_size_t, byref, create_string_buffer
    libnvrtc = CDLL('libnvrtc.so')

    def get_error_string(result: int) -> str:
        # nvrtcGetErrorString returns a const char* message for a result code
        libnvrtc.nvrtcGetErrorString.restype = c_char_p
        return libnvrtc.nvrtcGetErrorString(result).decode()

    prog = c_void_p()
    res = libnvrtc.nvrtcCreateProgram(byref(prog), source, b"nvrtc.cu", 0, None, None)
    if res != 0:
        raise RuntimeError(f"Can't create program: {get_error_string(res)}")
    res = libnvrtc.nvrtcCompileProgram(prog, 0, None)
    if res != 0:
        raise RuntimeError(f"Can't compile: {get_error_string(res)}")
    ptx_size = c_size_t()
    res = libnvrtc.nvrtcGetPTXSize(prog, byref(ptx_size))
    if res != 0:
        raise RuntimeError(f"Can't get PTX size: {get_error_string(res)}")
    ptx = create_string_buffer(ptx_size.value)
    res = libnvrtc.nvrtcGetPTX(prog, ptx)
    if res != 0:
        raise RuntimeError(f"Can't get PTX: {get_error_string(res)}")
    libnvrtc.nvrtcDestroyProgram(byref(prog))
    return ptx.value.decode()
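
# Note (not in the original): nvrtcCompileProgram also accepts an array of
# option strings, e.g. to target a specific architecture. A minimal sketch,
# assuming an sm_80 GPU -- adjust or omit for your device:
#   opts = (ctypes.c_char_p * 1)(b"--gpu-architecture=compute_80")
#   res = libnvrtc.nvrtcCompileProgram(prog, 1, opts)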

# Load the CUDA driver API
libcuda = ctypes.CDLL('libcuda.so')
CUDA_SUCCESS = 0


# Helper: check CUDA driver API return codes
def check_cuda(result):
    if result != CUDA_SUCCESS:
        err_str = ctypes.c_char_p()
        libcuda.cuGetErrorString(result, ctypes.byref(err_str))
        raise RuntimeError(f'CUDA error: {err_str.value.decode()}')
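
# Note (not in the original): the CUDA context here comes from PyTorch, which
# initializes CUDA lazily when its API is first used below. A standalone
# driver-API script would instead need something along these lines (a sketch,
# assuming device 0):
#   check_cuda(libcuda.cuInit(0))
#   ctx = ctypes.c_void_p()
#   check_cuda(libcuda.cuCtxCreate(ctypes.byref(ctx), 0, 0))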

# CUDA kernel as a byte string
kernel = b"""
extern "C"
__global__ void vector_add(const float* a, const float* b, float* c, int n) {
    int i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < n)
        c[i] = a[i] + b[i];
}
"""

# 1. Compile with NVRTC
start_time = time.time()
ptx = nvrtc_compile(kernel)

# 2. Load the PTX module and look up the kernel
# (torch.cuda.default_stream() also lazily initializes the CUDA context)
module = ctypes.c_void_p()
with torch.cuda.default_stream():
    check_cuda(libcuda.cuModuleLoadData(ctypes.byref(module), ptx.encode("latin-1")))
func = ctypes.c_void_p()
check_cuda(libcuda.cuModuleGetFunction(ctypes.byref(func), module, b"vector_add"))
compile_time = time.time() - start_time
print(f"Kernel compiled in {compile_time:.2f} seconds")

# 3. Prepare data (PyTorch owns the device memory; data_ptr() gives the raw pointer)
N = 512
a = torch.rand(N, device="cuda")
b = torch.rand(N, device="cuda")
c = torch.empty_like(a)
d_a = ctypes.c_void_p(a.data_ptr())
d_b = ctypes.c_void_p(b.data_ptr())
d_c = ctypes.c_void_p(c.data_ptr())

# 4. Launch the kernel
threads = 256
blocks = (N + threads - 1) // threads
# cuLaunchKernel takes kernelParams as an array of void*, each pointing at the
# storage of one argument (here: three device pointers and one int)
int_arg = ctypes.c_int(N)
arg1 = ctypes.byref(d_a)
arg2 = ctypes.byref(d_b)
arg3 = ctypes.byref(d_c)
arg4 = ctypes.byref(int_arg)
args = (ctypes.c_void_p * 4)(
    ctypes.cast(arg1, ctypes.c_void_p),
    ctypes.cast(arg2, ctypes.c_void_p),
    ctypes.cast(arg3, ctypes.c_void_p),
    ctypes.cast(arg4, ctypes.c_void_p),
)
check_cuda(libcuda.cuLaunchKernel(func,
                                  blocks, 1, 1,           # grid dimensions
                                  threads, 1, 1,          # block dimensions
                                  0, None, args, None))   # shared mem, stream, params, extra

# Print a few results
print("Result (first 5):", c[:5])