Skip to content

Instantly share code, notes, and snippets.

@ezyang
Created July 8, 2025 20:50
Show Gist options
  • Save ezyang/dbb48a060630143634eb1e07cb92da16 to your computer and use it in GitHub Desktop.
"""Minimal single-process DTensor demo using a fake process group.

Shards a 4x4 tensor over the "dp" dimension of a 2-D device mesh, prints
the DTensor, then gathers it back with full_tensor() and prints the result.
The fake backend lets the collective APIs run without spawning real workers.
"""
import torch
from torch import nn  # NOTE(review): unused here; kept in case the file grows
from torch.distributed.tensor.placement_types import Replicate, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor

world_size = 4

# FakeStore + the "fake" backend simulate a world_size-rank job in one process.
fake_store = FakeStore()
dist.init_process_group(
    "fake", store=fake_store, rank=0, world_size=world_size
)

# 2-D mesh: (world_size // 2) data-parallel groups x 2 tensor-parallel ranks.
mesh = init_device_mesh(
    "cuda",
    (world_size // 2, 2),
    mesh_dim_names=(
        "dp",
        "tp",
    ),
)

local_tensor = torch.arange(4 * 4, device="cuda").view(4, 4)

# BUG FIX: the mesh is 2-D, so `placements` must supply one placement per
# mesh dimension. The original passed only [Shard(0)], which DTensor.from_local
# rejects. Shard rows across "dp" and replicate across "tp".
dt = DTensor.from_local(local_tensor, mesh, [Shard(0), Replicate()])
print(dt)

full = dt.full_tensor()
print(full)

# Tear down the (fake) process group so repeated runs in one interpreter work.
dist.destroy_process_group()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment