@a-r-r-o-w
Last active February 15, 2025 21:47
TP on a simple MLP. Applies TP in 4 different ways.
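# Example launch command (hypothetical filename; assumes 2 GPUs and that torchrun sets the
# RANK/WORLD_SIZE/MASTER_ADDR env vars consumed by dist.init_process_group below):
#   torchrun --nproc_per_node=2 tp_mlp.py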
import copy
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor import Replicate, Shard, DTensor
from torch.distributed.tensor.parallel.style import RowwiseParallel, ColwiseParallel, SequenceParallel
from torch.distributed.tensor.parallel.api import parallelize_module
from torch._utils import _get_device_module

DEVICE_TYPE = "cuda"
PG_BACKEND = "nccl"
DEVICE_COUNT = _get_device_module(DEVICE_TYPE).device_count()


class MLPModule(nn.Module):
    def __init__(self, device, bias: bool = True):
        super().__init__()
        torch.manual_seed(42)
        self.net1 = nn.Linear(10, 32, bias=bias, device=device)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(32, 10, bias=bias, device=device)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x
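
# Compare every parameter of the TP module against the reference single-device module.
# DTensor parameters are redistributed to Replicate (i.e. all-gathered) before the comparison.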
def compare_params(
    local_module,
    dist_module,
    rank,
    rank0_only=True,
    skip_rowwise_bias=False,
    compare_grad=False,
):
    replicate = [Replicate()]
    for name, param in local_module.named_parameters():
        dist_param = dist_module.get_parameter(name)
        param = param.grad if compare_grad else param
        dist_param = dist_param.grad if compare_grad else dist_param
        if (
            (not rank0_only)
            or (rank == 0)
            or (
                name not in ["net2.bias"]
                and not skip_rowwise_bias
                or name not in ["bias", "net2.bias"]
            )
        ):
            is_equal = torch.allclose(
                param,
                dist_param.redistribute(device_mesh=dist_param.device_mesh, placements=replicate).to_local(),
            )
            print("Comparing param:", rank, name, is_equal)
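
# Build a reference single-device model and a TP copy, apply the chosen tp_plan to the copy,
# then compare parameters and the forward-pass output on every rank.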
@torch.no_grad()
def main(world_size: int, rank: int):
    # Normal model
    torch.manual_seed(0)
    model = MLPModule(DEVICE_TYPE)

    # TP model
    model_tp = copy.deepcopy(model)
    device_mesh = DeviceMesh(DEVICE_TYPE, torch.arange(world_size))

    rowwise_on_net1 = False
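    # Notation in the plan comments below: R = Replicate, S(-1) = Shard along the last dim;
    # "converted to" means the layer's output_layouts redistributes its output back to Replicate.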
    tp_plan = {
        # Type 1 - net1_INPUT_R---OUTPUT_S(-1) converted to net1_OUTPUT_R (net2_INPUT_R). net2_INPUT_R---OUTPUT_S(-1) converted to net2_OUTPUT_R
        # "net1": ColwiseParallel(output_layouts=Replicate()),
        # "net2": ColwiseParallel(output_layouts=Replicate()),
        # Type 2 - net1_INPUT_R---OUTPUT_S(-1) (net2_INPUT_S(-1)). net2_INPUT_S(-1)---OUTPUT_R
        # "net1": ColwiseParallel(),
        # "net2": RowwiseParallel(),
        # Type 3 - net1_INPUT_S(-1)---OUTPUT_R (net2_INPUT_R). net2_INPUT_R---OUTPUT_S(-1) converted to OUTPUT_R
        # "net1": RowwiseParallel(),
        # "net2": ColwiseParallel(output_layouts=Replicate()),
        # Type 4 - net1_INPUT_S(-1)---OUTPUT_R (net2_INPUT_R). net2_INPUT_R---OUTPUT_R
        "net1": RowwiseParallel(),
        "net2": RowwiseParallel(input_layouts=Replicate()),
    }
    if isinstance(tp_plan["net1"], RowwiseParallel):
        rowwise_on_net1 = True
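
    # Note: RowwiseParallel shards the linear weight along the input dimension, so by default it
    # expects its input sharded on the last dim (Shard(-1)); the reference input is chunked
    # manually below so each rank feeds net1 its shard.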
    model_tp = parallelize_module(model_tp, device_mesh, tp_plan)
    compare_params(model, model_tp, rank, rank0_only=False)
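    # Parameters should match on every rank: parallelize_module shards the TP copy's existing
    # weights in place rather than re-initializing them.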
    input_shape = (16, 10)
    input_tensor = torch.randn(input_shape, device=DEVICE_TYPE)

    print("Forward pass on normal model:", input_tensor.shape)
    output = model(input_tensor)

    if rowwise_on_net1:
        input_tensor_tp = input_tensor.chunk(world_size, dim=-1)[rank]
    else:
        input_tensor_tp = input_tensor

    print("Forward pass on TP model:", input_tensor_tp.shape)
    output_tp = model_tp(input_tensor_tp)
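    # Depending on the plan, output_tp is either a plain local tensor (use_local_output=True is the
    # default for ColwiseParallel/RowwiseParallel) or a DTensor; the isinstance check below handles both.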
    print("Before redistribute:", rank, output.shape, output_tp.shape)
    if isinstance(output_tp, DTensor):
        output_tp = output_tp.redistribute(output_tp.device_mesh, [Replicate()]).to_local()
    print("Output shapes:", output.shape, output_tp.shape)
    print("Comparing output:", rank, torch.allclose(output, output_tp))
dist.init_process_group(PG_BACKEND)
WORLD_SIZE = dist.get_world_size()
RANK = dist.get_rank()
torch.cuda.set_device(RANK)
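# Each process drives the GPU matching its rank (assumes a single-node launch where the
# global rank equals the local rank).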
print(f"World size: {WORLD_SIZE}")
print(f"Rank: {RANK}")
try:
    main(WORLD_SIZE, RANK)
finally:
    dist.destroy_process_group()