Created March 18, 2025 11:17
Save qpwo/f5e3928d9719ed2cd2375a993e59f114 to your computer and use it in GitHub Desktop.
Inline CUDA kernel in PyTorch: a minimal example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from torch.utils.cpp_extension import load_inline

# C++ source for the CPU-side function. `add_one` goes through ATen's `+`
# operator rather than a hand-written loop, so it is not tied to one dtype.
cpp_cpu_source = """
#include <torch/extension.h>

torch::Tensor add_one(torch::Tensor input) {
    return input + 1;
}
"""

# Full C++ source used when CUDA is available: the CPU function plus a
# forward declaration of the CUDA entry point defined in cuda_source, so a
# Python binding can be generated for it from the C++ side.
cpp_source = cpp_cpu_source + """
// Forward declaration of the CUDA function defined in cuda_source.
torch::Tensor add_one_cuda(torch::Tensor input);
"""

# CUDA source: an element-wise "+1" kernel and its host-side launcher.
cuda_source = """
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

// One thread per element; threads whose index is past the end do nothing.
// Indexing is 64-bit so tensors with more than 2^31 elements do not overflow.
__global__ void add_one_kernel(const float* __restrict__ input,
                               float* __restrict__ output,
                               int64_t size) {
    const int64_t idx =
        static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < size) {
        output[idx] = input[idx] + 1.0f;
    }
}

torch::Tensor add_one_cuda(torch::Tensor input) {
    // The kernel dereferences raw float pointers with flat indexing, so it
    // needs a float32 CUDA tensor with a contiguous layout.
    TORCH_CHECK(input.is_cuda(), "add_one_cuda: input must be a CUDA tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32,
                "add_one_cuda: input must be float32, got ",
                input.scalar_type());
    const auto src = input.contiguous();
    auto output = torch::empty_like(src);

    const int64_t size = src.numel();
    if (size == 0) {
        // Launching with zero blocks is an invalid configuration.
        return output;
    }

    // Run on the tensor's device and on PyTorch's current stream so the
    // kernel is ordered with the surrounding PyTorch work.
    const c10::cuda::CUDAGuard device_guard(src.device());
    const int threads = 256;
    const int64_t blocks = (size + threads - 1) / threads;
    add_one_kernel<<<static_cast<unsigned int>(blocks), threads, 0,
                     at::cuda::getCurrentCUDAStream()>>>(
        src.data_ptr<float>(),
        output.data_ptr<float>(),
        size
    );
    // Surface launch failures (e.g. bad configuration) as a C++ exception.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
"""

# Compile the extension. Building cuda_sources needs the CUDA toolkit, so the
# CUDA part is only built when CUDA is usable; otherwise a CPU-only extension
# is built so that the CPU demo below still runs.
cuda_ok = torch.cuda.is_available()
if cuda_ok:
    add_module = load_inline(
        name="add_extension",
        cpp_sources=cpp_source,
        cuda_sources=cuda_source,
        functions=["add_one", "add_one_cuda"],
        verbose=True,
    )
else:
    # Separate name so this build does not share a build directory with the
    # CUDA-enabled variant.
    add_module = load_inline(
        name="add_extension_cpu",
        cpp_sources=cpp_cpu_source,
        functions=["add_one"],
        verbose=True,
    )

# Test the CPU extension.
tensor = torch.tensor([1, 2, 3])
result = add_module.add_one(tensor)
torch.testing.assert_close(result, tensor + 1)
print(f"Input: {tensor}")
print(f"Output (CPU): {result}")

# Test the CUDA extension.
if cuda_ok:
    tensor_cuda = torch.tensor([1, 2, 3], device="cuda", dtype=torch.float32)
    result_cuda = add_module.add_one_cuda(tensor_cuda)
    torch.testing.assert_close(result_cuda, tensor_cuda + 1)
    print(f"Input: {tensor_cuda}")
    print(f"Output (CUDA): {result_cuda}")
else:
    print("CUDA not available, skipping CUDA test")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.