Skip to content

Instantly share code, notes, and snippets.

@ezyang
Last active November 23, 2025 17:16
Show Gist options
  • Select an option

  • Save ezyang/9ef3300cd516150637592af1e8b76d7d to your computer and use it in GitHub Desktop.

Select an option

Save ezyang/9ef3300cd516150637592af1e8b76d7d to your computer and use it in GitHub Desktop.
import math
import traceback
import unittest

import torch
from torch import Tensor
from torch.distributed._local_tensor import (
    local_tensor_mode,
    LocalTensor,
    LocalTensorMode,
)
from torch.distributed.tensor import (
    DTensor,
    DeviceMesh,
    distribute_tensor,
    init_device_mesh,
    Partial,
    Replicate,
    Shard,
)
from torch.distributed.tensor._sharding_prop import ShardingPropagator
from torch.distributed.tensor.placement_types import _StridedShard
# Short aliases for spelling out DTensor placements concisely below.
S = Shard
R = Replicate()  # NOTE: R is an *instance*, while S and _SS alias the classes
_SS = _StridedShard
def product(it):
    """Return the product of all elements of *it* (1 for an empty iterable).

    Uses :func:`math.prod` instead of a hand-rolled accumulation loop.
    """
    return math.prod(it)
def arange_nd(*sizes):
    """Return a tensor of shape *sizes* holding 0..N-1 in row-major order.

    Accepts either ``arange_nd(2, 3)`` or ``arange_nd((2, 3))``.
    """
    if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)):
        sizes = tuple(sizes[0])
    numel = 1
    for dim in sizes:
        numel *= dim
    return torch.arange(numel).view(sizes)
def reconcile(l: Tensor):
"""Asserts that a LocalTensor is the same on all ranks, and returns the single Tensor."""
if isinstance(l, LocalTensor):
return l.reconcile()
return l
def exit_local_tensor_mode():
    """Exit every active LocalTensorMode, innermost first (best-effort).

    Probes both private mode-stack names that have appeared in torch
    nightlies. A torch build without ``_local_tensor`` is treated as
    "nothing to exit" rather than an error.
    """
    try:
        from torch.distributed import _local_tensor
    except ImportError:
        return  # no local-tensor support in this torch build -> nothing active
    if getattr(_local_tensor, "_LOCAL_TENSOR_MODE", None):
        # Unwind in reverse entry order so nested modes exit cleanly.
        for lm in list(reversed(_local_tensor._LOCAL_TENSOR_MODE)):
            lm.__exit__(None, None, None)
    elif getattr(_local_tensor, "_GLOBAL_LOCAL_TENSOR_MODE", None):
        # BUG FIX: the original iterated `_GLOBAL_TENSOR_MODE` (missing the
        # "LOCAL_" segment) while the guard checked `_GLOBAL_LOCAL_TENSOR_MODE`,
        # so this branch always raised AttributeError when taken.
        for lm in list(reversed(_local_tensor._GLOBAL_LOCAL_TENSOR_MODE)):
            lm.__exit__(None, None, None)
def init_local_tensor_mode(world_size):
    """Reset distributed state, start a fake process group with *world_size*
    ranks, and enter a LocalTensorMode. Returns *world_size*.
    """
    exit_local_tensor_mode()
    # destroy_process_group asserts when no group exists; that is fine here.
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        pass
    torch.distributed.init_process_group("fake", rank=0, world_size=world_size)
    mode = LocalTensorMode(world_size)
    mode.__enter__()
    return world_size
def init_fake_tensor_mode(world_size):
    """Reset distributed state and start a fake process group with
    *world_size* ranks, without entering a LocalTensorMode.
    Returns *world_size*.
    """
    exit_local_tensor_mode()
    # destroy_process_group asserts when no group exists; that is fine here.
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        pass
    torch.distributed.init_process_group("fake", rank=0, world_size=world_size)
    return world_size
# Repro: in-place add of a Partial DTensor into a Replicate DTensor,
# simulated on 4 ranks in a single process via LocalTensorMode.
world_size = init_local_tensor_mode(4)
mesh = init_device_mesh("cpu", (4,), mesh_dim_names=("x",))
# a: replicated [0, 1, 2, 3]; b: ones with a Partial (pending-sum) placement.
a = DTensor.from_local(arange_nd(4).float(), mesh, [R])
b = DTensor.from_local(torch.ones(4), mesh, [Partial()])
# NOTE(review): `a += b` mixes Partial and Replicate operands in-place —
# presumably this exercises the sharding-propagation path under test.
a += b
print(a)
@ezyang
Copy link
Author

ezyang commented Nov 20, 2025

This didn't work on the latest nightly; it's now fixed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment