@youkaichao
Created November 5, 2024 00:06
cuda ipc
import os
from typing import List
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank = local_rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(local_rank)

def share_tensor(A: torch.Tensor, group=None) -> List[torch.Tensor]:
    # Serialize the tensor into CUDA IPC metadata (rebuild function + args),
    # exchange the metadata over gloo, then rebuild every remote tensor locally.
    from torch.multiprocessing.reductions import reduce_tensor
    A_meta = reduce_tensor(A)
    tensor_metas = [None] * world_size
    dist.all_gather_object(tensor_metas, A_meta, group=group)
    rank = dist.get_rank(group)
    all_tensors = []
    for i, obj in enumerate(tensor_metas):
        func = obj[0]
        args = list(obj[1])
        # args[6] is the storage device index; open the handle on this rank's device.
        args[6] = A.device.index
        if i != rank:
            all_tensors.append(func(*args))
        else:
            all_tensors.append(A)
    return all_tensors

A = torch.ones((10,), device=local_rank) * rank
all_tensors = share_tensor(A)
dist.barrier()
torch.cuda.synchronize()
if rank == 0:
    # Rank 0 writes through every IPC view; the other ranks should observe it.
    for x in all_tensors:
        x.zero_()
dist.barrier()
torch.cuda.synchronize()
for i, x in enumerate(all_tensors):
    print(f"{rank=}, {i=}, {x=}")
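
For reference, reduce_tensor returns a (rebuild_function, args) pair; for a CUDA tensor the rebuild function reconstructs the tensor from a CUDA IPC handle, and position 6 of args is the storage device index that the function above overwrites. A quick single-process way to inspect the metadata (a sketch; the tuple layout is a PyTorch internal and may differ across versions):

# Inspect the IPC metadata that reduce_tensor produces for a CUDA tensor.
# Assumption: a recent PyTorch release where args[6] is the device index.
import torch
from torch.multiprocessing.reductions import reduce_tensor

t = torch.ones((10,), device="cuda:0")
func, args = reduce_tensor(t)
print(func.__name__)  # the rebuild function, e.g. rebuild_cuda_tensor
print(args[6])        # storage device index, the field patched above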
@youkaichao (Author)

Run it with torchrun --nproc-per-node=4 testg.py. Rank 0 zeroes every tensor through its IPC views, and afterwards all ranks print zeros, confirming that the rebuilt tensors alias the same device memory:

rank=0, i=0, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
rank=1, i=0, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
rank=1, i=1, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
rank=0, i=1, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
rank=0, i=2, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
rank=1, i=2, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
rank=0, i=3, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
rank=1, i=3, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:1')
rank=3, i=0, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:3')
rank=3, i=1, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:3')
rank=3, i=2, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:3')
rank=3, i=3, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:3')
rank=2, i=0, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:2')
rank=2, i=1, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:2')
rank=2, i=2, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:2')
rank=2, i=3, x=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:2')

@youkaichao (Author)

Running with export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True gives an error:

RuntimeError: pidfd_getfd: Operation not permitted

This might be related to https://github.com/pytorch/pytorch/blob/3f248a57353288ac4df3a445ffa3ae0f952a6d33/c10/cuda/CUDACachingAllocator.cpp#L487: with expandable segments, the allocator shares memory as file descriptors rather than classic CUDA IPC handles, and the importing process fetches the exporter's fd with the pidfd_getfd syscall.
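
pidfd_getfd only succeeds when the caller has ptrace permission over the target process (PTRACE_MODE_ATTACH_REALCREDS), so one plausible culprit is the Yama LSM or a container's seccomp/capability profile. A minimal check, assuming a kernel with Yama enabled (the file is absent otherwise):

# Hedged check: on kernels with the Yama LSM, ptrace_scope >= 1 blocks
# pidfd_getfd on processes that are not your descendants, which is exactly
# the torchrun topology (the ranks are siblings under the launcher).
with open("/proc/sys/kernel/yama/ptrace_scope") as f:
    print("yama ptrace_scope:", f.read().strip())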

@youkaichao (Author)

The following code actually compiles and runs:

#define _GNU_SOURCE
#include <stdio.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/syscall.h>
#include <errno.h>

int main() {
    // Step 1: Obtain a pidfd for the current process
    int pidfd = syscall(SYS_pidfd_open, getpid(), 0);
    if (pidfd == -1) {
        perror("pidfd_open failed");
        return 1;
    }

    // Step 2: Open a valid file descriptor, e.g., /proc/self/status
    int fd = open("/proc/self/status", O_RDONLY);
    if (fd == -1) {
        perror("open failed");
        close(pidfd);
        return 1;
    }

    // Step 3: Try to duplicate fd using pidfd_getfd
    int new_fd = syscall(SYS_pidfd_getfd, pidfd, fd, 0);
    if (new_fd == -1) {
        perror("pidfd_getfd failed");
    } else {
        printf("pidfd_getfd succeeded, new_fd: %d\n", new_fd);
        close(new_fd);  // Close the duplicated fd if successful
    }

    // Clean up
    close(fd);
    close(pidfd);
    return 0;
}

compile: gcc test.c -o test

run: ./test

output: pidfd_getfd succeeded, new_fd: 5

Note, however, that this test duplicates a descriptor from the calling process itself. When the target is another process, pidfd_getfd additionally requires ptrace permission over it (PTRACE_MODE_ATTACH_REALCREDS), which Yama's ptrace_scope or a container's security profile can deny. That would explain why this same-process test passes while PyTorch's cross-process use fails, as the probe below makes explicit.
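
The probe is a sketch, assuming Linux 5.6+ on x86-64 (syscall numbers 434 and 438) and CPython on Linux; it forks a child that tries to grab one of its parent's descriptors, putting the caller and target in different processes:

# Cross-process pidfd_getfd probe, the situation the allocator is in. Unlike
# the C test above, the target here is a *different* process, so the kernel
# performs a ptrace-style permission check before duplicating the fd.
import ctypes, os

libc = ctypes.CDLL(None, use_errno=True)
SYS_pidfd_open, SYS_pidfd_getfd = 434, 438  # x86-64 syscall numbers

pid = os.fork()
if pid == 0:
    # Child targets its parent: under Yama ptrace_scope >= 1, only
    # descendants may be targeted, so this fails with EPERM, matching
    # the torchrun case where the ranks are sibling processes.
    pidfd = libc.syscall(SYS_pidfd_open, os.getppid(), 0)
    new_fd = libc.syscall(SYS_pidfd_getfd, pidfd, 1, 0)
    if new_fd == -1:
        print("pidfd_getfd failed:", os.strerror(ctypes.get_errno()))
    else:
        print("pidfd_getfd succeeded, new_fd:", new_fd)
    os._exit(0)
os.waitpid(pid, 0)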

@youkaichao (Author)

Running on 2.6.0.dev20241112+cu124, I still get the same error: RuntimeError: pidfd_getfd: Operation not permitted.
