PyTorch collective communication
# encoding: utf8

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def cm_broadcast_object_demo(rank: int, world_size: int):
    """Broadcast a list of picklable Python objects from rank 0 to all ranks."""
    dist.init_process_group("gloo", world_size=world_size, rank=rank)

    if rank == 0:
        objects = ["foo", 12, {"key": "value"}]
    else:
        objects = [None, None, None]  # placeholders, overwritten by the broadcast

    dist.broadcast_object_list(objects, src=0, device=torch.device("cpu"))
    print(f"rank: {rank}, objects: {objects}")


def cm_broadcast_demo(rank: int, world_size: int):
    """Broadcast a tensor from rank 0 to all ranks."""
    dist.init_process_group("gloo", world_size=world_size, rank=rank)

    if dist.get_rank() == 0:
        tensor = torch.arange(10)
    else:
        tensor = torch.zeros(10, dtype=torch.int64)  # receive buffer, same shape and dtype as the source

    dist.broadcast(tensor, src=0)
    print(tensor)


def cm_allreduce_demo(rank: int, world_size: int):
    """Ring allreduce: element-wise SUM across all ranks; every rank receives the result."""
    dist.init_process_group("gloo", world_size=world_size, rank=rank)

    rank = dist.get_rank()
    tensor = torch.arange(2) + 2 * rank
    print(f"before allreduce, rank: {rank}, tensor: {tensor}")

    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    print(f"after allreduce, rank: {rank}, tensor: {tensor}")


def cm_scatter_demo(rank: int, world_size: int):
    """Scatter: rank 0 sends one tensor from scatter_list to each rank."""
    dist.init_process_group("gloo", world_size=world_size, rank=rank)

    scatter_list = [torch.ones(2), torch.ones(2) * 3]
    output_tensor = torch.zeros_like(scatter_list[0])
    dist.scatter(
        output_tensor,
        scatter_list if rank == 0 else None,  # scatter_list is only required on the source rank
        src=0,
    )
    print(f"scatter_list: {scatter_list}, rank: {rank}, output: {output_tensor}")


def cm_gather_demo(rank: int, world_size: int):
    """Gather: every rank sends its tensor to rank 0."""
    dist.init_process_group("gloo", world_size=world_size, rank=rank)

    rank = dist.get_rank()
    tensor = torch.tensor([rank])
    print(f"before gather, rank: {rank}, tensor: {tensor}")

    output = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.gather(
        tensor,
        output if rank == 0 else None,  # gather_list must NOT be specified on non-destination ranks
        dst=0)

    if rank == 0:
        concat_output = torch.concat(output)
        print(f"after gather, rank: {rank}, output: {output}, concat output: {concat_output}")


collection_methods = [
    cm_broadcast_object_demo,
    cm_broadcast_demo,
    cm_allreduce_demo,
    cm_gather_demo,
    cm_scatter_demo,
]


def main():
    world_size = 2
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"

    for method in collection_methods:
        print(f"collective communication {method.__name__}")
        mp.spawn(method, (world_size,), nprocs=world_size, join=True)
        print("")


if __name__ == "__main__":
    main()
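A minimal sketch in the same style, assuming the same gloo setup as the demos above: unlike gather, all_gather delivers the full list of tensors to every rank, not just to the destination rank. The function name cm_allgather_demo is illustrative; appending it to collection_methods would run it alongside the other demos.

# Illustrative sketch (name and placement assumed, not from the gist above).
def cm_allgather_demo(rank: int, world_size: int):
    """All-gather: collect each rank's tensor into a list on every rank."""
    dist.init_process_group("gloo", world_size=world_size, rank=rank)

    tensor = torch.tensor([rank])
    output = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(output, tensor)
    print(f"allgather, rank: {rank}, output: {output}")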
Topics: allreduce, broadcast, gather, scatter