@youkaichao
Created November 5, 2024 00:06
cuda ipc
import os
from typing import List
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank = local_rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(local_rank)

def share_tensor(A: torch.Tensor, group=None) -> List[torch.Tensor]:
    from torch.multiprocessing.reductions import reduce_tensor
    # reduce_tensor returns (rebuild_func, args); the args embed a CUDA IPC
    # handle that any other process on the same node can open.
    A_meta = reduce_tensor(A)
    tensor_metas = [None] * world_size
    dist.all_gather_object(tensor_metas, A_meta, group=group)
    rank = dist.get_rank(group)
    all_tensors = []
    for i, obj in enumerate(tensor_metas):
        func = obj[0]
        args = list(obj[1])
        # patch the device index so the tensor is rebuilt on our local device
        args[6] = A.device.index
        if i != rank:
            all_tensors.append(func(*args))
        else:
            all_tensors.append(A)
    return all_tensors

A = torch.ones((10,), device=local_rank) * rank
all_tensors = share_tensor(A)
dist.barrier()
torch.cuda.synchronize()

# rank 0 zeroes every shared tensor; all ranks should observe the change
if rank == 0:
    for x in all_tensors:
        x.zero_()
dist.barrier()
torch.cuda.synchronize()

for i, x in enumerate(all_tensors):
    print(f"{rank=}, {i=}, {x=}")
youkaichao commented Jun 15, 2025
To use the fabric handle, we should follow https://docs.nvidia.com/multi-node-nvlink-systems/imex-guide/imexchannels.html to create IMEX channels (read access is required). Here is a working example of IPC through the fabric handle:
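Before running the example, it can help to confirm that the device actually supports fabric handles. A minimal sketch, assuming a CUDA toolkit recent enough to define CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED:

// fabric_check.cpp -- query whether device 0 supports fabric memory handles
#include <cuda.h>
#include <cstdio>

int main() {
    cuInit(0);
    CUdevice dev;
    cuDeviceGet(&dev, 0);
    int supported = 0;
    cuDeviceGetAttribute(&supported,
                         CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, dev);
    std::printf("fabric handle supported: %d\n", supported);
    return 0;
}

If this prints 0, cuMemCreate with CU_MEM_HANDLE_TYPE_FABRIC will fail regardless of the IMEX setup.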

// sender.cpp
#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>
#include <fstream>
#include <vector>
#include <cstring>
#include <unistd.h>
#include <sys/syscall.h>
#include <sys/prctl.h>

// Define syscall numbers if not available
#ifndef SYS_pidfd_open
#define SYS_pidfd_open 434
#endif

// Helper function to get CUDA error string
const char* getCudaErrorString(CUresult error) {
    const char* errorString = "unknown error";  // fallback if cuGetErrorString rejects the code
    cuGetErrorString(error, &errorString);
    return errorString;
}

int main() {
    prctl(PR_SET_PTRACER, PR_SET_PTRACER_ANY);
    // Initialize CUDA
    CUresult result = cuInit(0);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to initialize CUDA: " << getCudaErrorString(result) << std::endl;
        return 1;
    }

    // Get CUDA device
    CUdevice device;
    result = cuDeviceGet(&device, 0);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to get CUDA device: " << getCudaErrorString(result) << std::endl;
        return 1;
    }

    // Create CUDA context
    CUcontext context;
    result = cuCtxCreate(&context, 0, device);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to create CUDA context: " << getCudaErrorString(result) << std::endl;
        return 1;
    }

    // Allocate memory using VMM API
    const size_t size = 20 * 1024 * 1024; // 20MB
    CUmemGenericAllocationHandle handle;

    // Set up memory allocation properties
    CUmemAllocationProp prop = {};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = 0;  // Use device 0
    prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;  // Use fabric handle type for IPC
    prop.win32HandleMetaData = nullptr;

    // Get the minimum granularity supported for allocation
    size_t granularity = 0;
    result = cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to get allocation granularity: " << getCudaErrorString(result) << std::endl;
        return 1;
    }

    // Ensure size is a multiple of granularity
    if (size % granularity) {
        std::cerr << "Allocation size is not a multiple of minimum supported granularity" << std::endl;
        return 1;
    }

    std::cout << "Creating memory handle with size: " << size << " bytes" << std::endl;
    result = cuMemCreate(&handle, size, &prop, 0);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to create memory handle: " << getCudaErrorString(result) << std::endl;
        return 1;
    }
    std::cout << "Successfully created memory handle" << std::endl;

    // Reserve address range
    CUdeviceptr ptr;
    std::cout << "Reserving address range" << std::endl;
    result = cuMemAddressReserve(&ptr, size, 0, 0, 0);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to reserve address range: " << getCudaErrorString(result) << std::endl;
        cuMemRelease(handle);
        return 1;
    }
    std::cout << "Successfully reserved address range at: " << ptr << std::endl;

    // Map the memory
    std::cout << "Mapping memory" << std::endl;
    result = cuMemMap(ptr, size, 0, handle, 0);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to map memory: " << getCudaErrorString(result) << std::endl;
        cuMemAddressFree(ptr, size);
        cuMemRelease(handle);
        return 1;
    }
    std::cout << "Successfully mapped memory" << std::endl;

    // Set access properties
    CUmemAccessDesc accessDesc = {};
    accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    accessDesc.location.id = 0;  // Use device 0
    accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;

    std::cout << "Setting memory access properties" << std::endl;
    result = cuMemSetAccess(ptr, size, &accessDesc, 1);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to set memory access: " << getCudaErrorString(result) << std::endl;
        cuMemUnmap(ptr, size);
        cuMemAddressFree(ptr, size);
        cuMemRelease(handle);
        return 1;
    }
    std::cout << "Successfully set memory access properties" << std::endl;

    // Export handle to fabric handle
    CUmemFabricHandle_v1 fabricHandle;
    std::cout << "Exporting handle to fabric handle" << std::endl;
    std::cout << "Original handle value: " << handle << std::endl;
    std::cout << "Allocation size: " << size << " bytes" << std::endl;
    result = cuMemExportToShareableHandle(&fabricHandle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to export handle: " << getCudaErrorString(result) << std::endl;
        std::cerr << "Handle value: " << handle << std::endl;
        std::cerr << "Handle type: CU_MEM_HANDLE_TYPE_FABRIC" << std::endl;
        cuMemUnmap(ptr, size);
        cuMemAddressFree(ptr, size);
        cuMemRelease(handle);
        return 1;
    }
    std::cout << "Successfully exported handle to fabric handle" << std::endl;
    std::cout << "Fabric handle value: " << reinterpret_cast<uintptr_t>(&fabricHandle) << std::endl;

    // Write to file
    std::ofstream outfile("data.bin", std::ios::binary);
    if (!outfile) {
        std::cerr << "Failed to open output file: " << strerror(errno) << std::endl;
        cuMemUnmap(ptr, size);
        cuMemAddressFree(ptr, size);
        cuMemRelease(handle);
        return 1;
    }

    // Write 8-byte size header
    outfile.write(reinterpret_cast<const char*>(&size), 8);
    // Write fabric handle
    outfile.write(reinterpret_cast<const char*>(&fabricHandle), sizeof(CUmemFabricHandle_v1));
    outfile.close();

    std::cout << "Data written to data.bin. Press Enter to continue..." << std::endl;
    std::cin.get();

    // Cleanup
    result = cuMemUnmap(ptr, size);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to unmap memory: " << getCudaErrorString(result) << std::endl;
    }

    result = cuMemAddressFree(ptr, size);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to free address range: " << getCudaErrorString(result) << std::endl;
    }

    result = cuMemRelease(handle);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to release memory handle: " << getCudaErrorString(result) << std::endl;
    }
    
    result = cuCtxDestroy(context);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to destroy CUDA context: " << getCudaErrorString(result) << std::endl;
    }

    return 0;
} 
// receiver.cpp
#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>
#include <fstream>
#include <vector>
#include <cstring>

// Helper function to get CUDA error string
const char* getCudaErrorString(CUresult error) {
    const char* errorString = "unknown error";  // fallback if cuGetErrorString rejects the code
    cuGetErrorString(error, &errorString);
    return errorString;
}

int main() {
    // Initialize CUDA
    CUresult result = cuInit(0);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to initialize CUDA: " << getCudaErrorString(result) << std::endl;
        return 1;
    }

    // Get CUDA device
    CUdevice device;
    result = cuDeviceGet(&device, 0);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to get CUDA device: " << getCudaErrorString(result) << std::endl;
        return 1;
    }

    // Create CUDA context
    CUcontext context;
    result = cuCtxCreate(&context, 0, device);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to create CUDA context: " << getCudaErrorString(result) << std::endl;
        return 1;
    }

    // Read from file
    std::ifstream infile("data.bin", std::ios::binary);
    if (!infile) {
        std::cerr << "Failed to open input file: " << strerror(errno) << std::endl;
        return 1;
    }

    // Read 8-byte size header
    size_t size;
    infile.read(reinterpret_cast<char*>(&size), 8);
    std::cout << "Read allocation size: " << size << " bytes" << std::endl;

    // Read fabric handle
    CUmemFabricHandle_v1 fabricHandle;
    infile.read(reinterpret_cast<char*>(&fabricHandle), sizeof(CUmemFabricHandle_v1));
    std::cout << "Read fabric handle value: " << reinterpret_cast<uintptr_t>(&fabricHandle) << std::endl;
    infile.close();

    // Import handle
    CUmemGenericAllocationHandle handle;
    std::cout << "Importing handle..." << std::endl;
    result = cuMemImportFromShareableHandle(
        &handle,
        reinterpret_cast<void*>(&fabricHandle),
        CU_MEM_HANDLE_TYPE_FABRIC
    );
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to import handle: " << getCudaErrorString(result) << std::endl;
        return 1;
    }
    std::cout << "Successfully imported handle: " << handle << std::endl;

    // Reserve address range
    CUdeviceptr ptr;
    std::cout << "Reserving address range of size: " << size << std::endl;
    result = cuMemAddressReserve(&ptr, size, 0, 0, 0);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to reserve address range: " << getCudaErrorString(result) << std::endl;
        return 1;
    }
    std::cout << "Successfully reserved address range at: " << ptr << std::endl;

    // Map the memory
    std::cout << "Mapping memory..." << std::endl;
    result = cuMemMap(ptr, size, 0, handle, 0);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to map memory: " << getCudaErrorString(result) << std::endl;
        cuMemAddressFree(ptr, size);
        return 1;
    }
    std::cout << "Successfully mapped memory" << std::endl;

    // Set access properties
    CUmemAccessDesc accessDesc = {};
    accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    accessDesc.location.id = 0;  // Use device 0
    accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;

    std::cout << "Setting memory access properties..." << std::endl;
    result = cuMemSetAccess(ptr, size, &accessDesc, 1);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to set memory access: " << getCudaErrorString(result) << std::endl;
        cuMemUnmap(ptr, size);
        cuMemAddressFree(ptr, size);
        return 1;
    }
    std::cout << "Successfully set memory access properties" << std::endl;

    std::cout << "Successfully imported and mapped memory at address: " << ptr << std::endl;
    std::cout << "Press Enter to continue..." << std::endl;
    std::cin.get();

    // Cleanup
    result = cuMemUnmap(ptr, size);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to unmap memory: " << getCudaErrorString(result) << std::endl;
    }

    result = cuMemAddressFree(ptr, size);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to free address range: " << getCudaErrorString(result) << std::endl;
    }

    result = cuMemRelease(handle);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to release memory handle: " << getCudaErrorString(result) << std::endl;
    }

    result = cuCtxDestroy(context);
    if (result != CUDA_SUCCESS) {
        std::cerr << "Failed to destroy CUDA context: " << getCudaErrorString(result) << std::endl;
    }

    return 0;
} 

Compile with

$ nvcc receiver.cpp -o receiver -lcuda
$ nvcc sender.cpp -o sender -lcuda

In one shell, run ./sender; in another shell, run ./receiver.

We can see:

(py310) ➜  test_pidfd ./sender
Creating memory handle with size: 20971520 bytes
Successfully created memory handle
Reserving address range
Successfully reserved address range at: 139849545809920
Mapping memory
Successfully mapped memory
Setting memory access properties
Successfully set memory access properties
Exporting handle to fabric handle
Original handle value: 94556404496384
Allocation size: 20971520 bytes
Successfully exported handle to fabric handle
Fabric handle value: 140735970184032
Data written to data.bin. Press Enter to continue...

(py310) ➜  test_pidfd ./receiver
Read allocation size: 20971520 bytes
Read fabric handle value: 140734345409808
Importing handle...
Successfully imported handle: 94798849087488
Reserving address range of size: 20971520
Successfully reserved address range at: 140063992184832
Mapping memory...
Successfully mapped memory
Setting memory access properties...
Successfully set memory access properties
Successfully imported and mapped memory at address: 140063992184832
Press Enter to continue...

The data.bin file is 72 bytes: an 8-byte header (the allocation size) followed by the 64-byte fabric handle.
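A quick sanity check of that layout, assuming an LP64 platform where size_t is 8 bytes (CUmemFabricHandle_v1 is an opaque 64-byte blob in cuda.h):

// layout_check.cpp -- verify the data.bin layout described above
#include <cuda.h>
#include <cstdio>

static_assert(sizeof(CUmemFabricHandle_v1) == 64, "fabric handle is 64 bytes");

int main() {
    std::printf("header: %zu bytes, handle: %zu bytes, total: %zu bytes\n",
                sizeof(size_t), sizeof(CUmemFabricHandle_v1),
                sizeof(size_t) + sizeof(CUmemFabricHandle_v1));
    return 0;
}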
