import torch
import unittest
from torch import Tensor
from torch.distributed.tensor import (
    DTensor,
    DeviceMesh,
    distribute_tensor,
    init_device_mesh,
    Partial,
    Replicate,
    Shard,
)
from torch.distributed.tensor.placement_types import _StridedShard
from torch.distributed._local_tensor import (
    local_tensor_mode,
    LocalTensor,
    LocalTensorMode,
)
import traceback
from torch.distributed.tensor._sharding_prop import ShardingPropagator

# Shorthand for placement types
S = Shard
R = Replicate()
_SS = _StridedShard

def product(it):
    """Returns the product of the elements of an iterable."""
    x = 1
    for i in it:
        x *= i
    return x

def arange_nd(*sizes):
    """Returns an arange() tensor reshaped to the given sizes."""
    if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)):
        sizes = sizes[0]
    return torch.arange(product(sizes)).view(sizes)
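# e.g. arange_nd(2, 3) == tensor([[0, 1, 2],
#                                 [3, 4, 5]])
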
def reconcile(l: Tensor):
    """Asserts that a LocalTensor is the same on all ranks, and returns the single Tensor."""
    if isinstance(l, LocalTensor):
        return l.reconcile()
    return l
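# Note (my reading of LocalTensor semantics): under LocalTensorMode,
# ordinary tensor ops produce LocalTensor wrappers holding one value per
# simulated rank; reconcile() collapses such a wrapper back to a plain
# Tensor when all ranks agree, and passes plain Tensors through unchanged.
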
def exit_local_tensor_mode():
    """Exits any active LocalTensorModes, innermost first."""
    from torch.distributed import _local_tensor
    if getattr(_local_tensor, "_LOCAL_TENSOR_MODE", None):
        for lm in list(reversed(_local_tensor._LOCAL_TENSOR_MODE)):
            lm.__exit__(None, None, None)
    elif getattr(_local_tensor, "_GLOBAL_LOCAL_TENSOR_MODE", None):
        for lm in list(reversed(_local_tensor._GLOBAL_LOCAL_TENSOR_MODE)):
            lm.__exit__(None, None, None)

def init_local_tensor_mode(world_size):
    """(Re)initializes a fake process group and enters LocalTensorMode."""
    exit_local_tensor_mode()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        pass  # no process group was initialized yet
    torch.distributed.init_process_group(
        "fake",
        rank=0,
        world_size=world_size,
    )
    lm = LocalTensorMode(world_size)
    lm.__enter__()
    return world_size
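# The "fake" backend registers a process group that performs no real
# communication, and LocalTensorMode then simulates all world_size ranks
# within this single process, so DTensor code can run without spawning
# multiple workers.
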
def init_fake_tensor_mode(world_size):
    """Same setup as init_local_tensor_mode, but without entering LocalTensorMode."""
    exit_local_tensor_mode()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        pass
    torch.distributed.init_process_group(
        "fake",
        rank=0,
        world_size=world_size,
    )
    return world_size

world_size = init_local_tensor_mode(4)
mesh = init_device_mesh("cpu", (4,), mesh_dim_names=("x",))

# Repro: in-place add of a Partial DTensor into a Replicate DTensor.
a = DTensor.from_local(arange_nd(4).float(), mesh, [R])
b = DTensor.from_local(torch.ones(4), mesh, [Partial()])
a += b
print(a)
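# Hedged sanity check (an addition to the original repro; assumes Partial
# placements sum per-rank locals on redistribute): b's logical value is
# 4 * ones(4), so after a += b the full tensor should be [4., 5., 6., 7.].
expected = torch.arange(4, dtype=torch.float) + 4
result = reconcile(a.full_tensor())
assert torch.equal(result, reconcile(expected)), result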
This didn't work on the tippy-top nightly; it's now fixed.