# Source: GitHub Gist msaroufim/995d58845f29f76a4479d6c0fad30432
# Author: @msaroufim (last active March 21, 2025 00:02)
"""
Example showing how to use the no_implicit_headers mode with a TensorBase CUDA extension
This example creates a CUDA extension that directly includes ATen/core/TensorBase.h
instead of torch/extension.h or types.h, resulting in faster compilation
"""
import os
import shutil
import time
from datetime import datetime

import torch
import torch.utils.cpp_extension
# CUDA headers shipped inside the "pt" conda environment; adjust for your setup.
# os.path.expanduser("~") is used instead of os.environ['HOME'] so the script
# does not die with a KeyError when HOME is unset (e.g. on Windows).
cuda_include_dir = os.path.join(
    os.path.expanduser("~"), ".conda", "envs", "pt", "targets", "x86_64-linux", "include"
)
BUILD_DIR = os.path.join(os.getcwd(), "custom_extension_build")

# Clear the build directory if it exists, presumably so every run rebuilds
# the extension from scratch instead of reusing cached objects.
if os.path.exists(BUILD_DIR):
    print(f"Clearing existing build directory: {BUILD_DIR}")
    shutil.rmtree(BUILD_DIR)

# Create the (now empty) build directory.
os.makedirs(BUILD_DIR, exist_ok=True)
print(f"Created build directory: {BUILD_DIR}")
# C++ binding source. It includes ATen/core/TensorBase.h (plus narrow CUDA
# helpers) directly instead of torch/extension.h, keeping the host compile light.
cpp_source = """
#include <ATen/core/TensorBase.h>
#include <ATen/cuda/EmptyTensor.h>
#include <c10/cuda/CUDAGuard.h>
#include <pybind11/pybind11.h>

// Defined in cuda_source: out[i] = x[i] + y[i] + 1 for i in [0, num_elements).
void launch_add_kernel(const float *, const float *, float *, int64_t);

// Element-wise x + y + 1 for two contiguous float32 CUDA tensors of equal shape
// on the same device. Returns a new tensor on that device.
// NOTE(review): only pybind11 is included; confirm at::TensorBase <-> torch.Tensor
// conversion works without torch/csrc/utils/pybind.h.
at::TensorBase tensor_base_add(const at::TensorBase& x, const at::TensorBase& y) {
  // Validate inputs
  TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
  TORCH_CHECK(y.is_cuda(), "y must be a CUDA tensor");
  // The kernel dereferences both pointers on one device, so they must agree.
  TORCH_CHECK(x.device() == y.device(), "x and y must be on the same CUDA device");
  TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, "x must be a float tensor");
  TORCH_CHECK(y.scalar_type() == at::ScalarType::Float, "y must be a float tensor");
  TORCH_CHECK(x.sizes() == y.sizes(), "x and y must have the same shape");
  TORCH_CHECK(x.is_contiguous() && y.is_contiguous(), "x and y must be contiguous tensors");

  // Make x's device current before allocating and launching.
  const at::cuda::CUDAGuard device_guard(x.device());
  // x.options() carries x's dtype AND device; passing only the dtype would
  // leave the output's device unset instead of matching x.
  auto output = at::detail::empty_cuda(x.sizes(), x.options());
  if (output.numel() == 0) {
    return output;  // nothing to compute
  }
  launch_add_kernel(x.const_data_ptr<float>(), y.const_data_ptr<float>(),
                    output.data_ptr<float>(), x.numel());
  return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("tensor_base_add", &tensor_base_add, "Add two tensors using TensorBase directly");
}
"""
# CUDA source for the element-wise kernel and its host-side launcher. It needs
# only the CUDA runtime headers (no ATen/TensorBase), keeping the nvcc step cheap.
cuda_source = """
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>

// out[i] = x[i] + y[i] + 1.0f for i in [0, size). The size and index are
// 64-bit so element counts above INT_MAX are not truncated.
__global__ void tensor_base_add_kernel(const float* __restrict__ x,
                                       const float* __restrict__ y,
                                       float* __restrict__ out,
                                       const int64_t size) {
  const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < size) {
    out[idx] = x[idx] + y[idx] + 1.0f;  // Add 1 to make it distinguishable
  }
}

void launch_add_kernel(const float *x_data, const float *y_data, float *output_data, int64_t num_elements) {
  // A zero-block launch is an invalid configuration and would leave a pending
  // CUDA error for an unrelated later call to report.
  if (num_elements <= 0) {
    return;
  }
  const int threads = 1024;
  const int64_t blocks = (num_elements + threads - 1) / threads;
  // NOTE: no stream argument is given, so this runs on the default stream.
  tensor_base_add_kernel<<<static_cast<unsigned int>(blocks), threads>>>(
      x_data, y_data, output_data, num_elements);
}
"""
def main():
    """Build the TensorBase extension inline, time the build, and verify it.

    Compiles ``cpp_source`` and ``cuda_source`` with ``no_implicit_headers=True``
    (torch/extension.h and types.h are not prepended) into ``BUILD_DIR``.
    Prints how long ``load_inline`` took, then compares
    ``tensor_base_add(x, y)`` with the eager reference ``x + y + 1`` on two
    random float32 CUDA tensors and prints PASSED/FAILED.
    """
    # perf_counter is monotonic, unlike wall-clock datetime.now().
    start = time.perf_counter()
    module = torch.utils.cpp_extension.load_inline(
        name="tensor_base_example",
        cpp_sources=cpp_source,
        cuda_sources=cuda_source,
        verbose=True,
        no_implicit_headers=True,  # Skip including torch/extension.h and types.h
        extra_include_paths=[cuda_include_dir],
        build_directory=BUILD_DIR,  # Specify the custom build directory
    )
    elapsed = time.perf_counter() - start
    print(f"Extension compiled successfully in {elapsed:.2f}s!")

    # Test the functionality
    print("Testing on CUDA tensors...")
    x = torch.randn(100, device="cuda", dtype=torch.float32)
    y = torch.randn(100, device="cuda", dtype=torch.float32)
    result = module.tensor_base_add(x, y)

    # The kernel adds an extra 1.0 so its output is distinguishable from x + y.
    expected = x + y + 1.0
    if torch.allclose(result, expected):
        print("Test PASSED! ✓")
    else:
        max_diff = torch.max(torch.abs(result - expected)).item()
        print("Test FAILED!")
        print(f"Maximum difference: {max_diff}")
if __name__ == "__main__":
    # The example compiles and launches a CUDA kernel, so it only runs on a GPU.
    if torch.cuda.is_available():
        main()
    else:
        print("CUDA is not available, this example requires CUDA")
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment