Tensor parallelism (TP) on a simple MLP, applied in four different ways.
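The script below builds a two-layer MLP, parallelizes a copy of it with PyTorch's DTensor tensor-parallel API (`parallelize_module` with `ColwiseParallel` / `RowwiseParallel` styles), and checks that both the parameters and the forward output match the plain single-device model. The four variants correspond to the commented Type 1-4 options of `tp_plan` inside `main`. Launch one process per GPU, e.g. `torchrun --nproc_per_node=2 tp_mlp.py` (the file name and GPU count here are assumptions; the script expects a multi-GPU CUDA host and the NCCL backend).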
import copy

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor import Replicate, Shard, DTensor
from torch.distributed.tensor.parallel.style import RowwiseParallel, ColwiseParallel, SequenceParallel
from torch.distributed.tensor.parallel.api import parallelize_module
from torch._utils import _get_device_module

DEVICE_TYPE = "cuda"
PG_BACKEND = "nccl"
DEVICE_COUNT = _get_device_module(DEVICE_TYPE).device_count()

class MLPModule(nn.Module):
    def __init__(self, device, bias: bool = True):
        super().__init__()
        torch.manual_seed(42)
        self.net1 = nn.Linear(10, 32, bias=bias, device=device)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(32, 10, bias=bias, device=device)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x
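
# compare_params: redistribute each (DTensor) parameter of the TP model to Replicate and
# check that it matches the corresponding parameter of the plain single-device model.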
def compare_params(
    local_module,
    dist_module,
    rank,
    rank0_only=True,
    skip_rowwise_bias=False,
    compare_grad=False,
):
    replicate = [Replicate()]
    for name, param in local_module.named_parameters():
        dist_param = dist_module.get_parameter(name)
        param = param.grad if compare_grad else param
        dist_param = dist_param.grad if compare_grad else dist_param
        if (
            (not rank0_only)
            or (rank == 0)
            or (
                name not in ["net2.bias"]
                and not skip_rowwise_bias
                or name not in ["bias", "net2.bias"]
            )
        ):
            dist_param_full = dist_param.redistribute(
                device_mesh=dist_param.device_mesh, placements=replicate
            ).to_local()
            print("Comparing param:", rank, name, torch.allclose(param, dist_param_full))

@torch.no_grad()
def main(world_size: int, rank: int):
    # Normal model
    torch.manual_seed(0)
    model = MLPModule(DEVICE_TYPE)

    # TP model
    model_tp = copy.deepcopy(model)
    device_mesh = DeviceMesh(DEVICE_TYPE, torch.arange(world_size))

    rowwise_on_net1 = False
    tp_plan = {
        # Type 1 - net1_INPUT_R---OUTPUT_S(-1) converted to net1_OUTPUT_R (net2_INPUT_R). net2_INPUT_R---OUTPUT_S(-1) converted to net2_OUTPUT_R
        # "net1": ColwiseParallel(output_layouts=Replicate()),
        # "net2": ColwiseParallel(output_layouts=Replicate()),
        # Type 2 - net1_INPUT_R---OUTPUT_S(-1) (net2_INPUT_S(-1)). net2_INPUT_S(-1)---OUTPUT_R
        # "net1": ColwiseParallel(),
        # "net2": RowwiseParallel(),
        # Type 3 - net1_INPUT_S(-1)---OUTPUT_R (net2_INPUT_R). net2_INPUT_R---OUTPUT_S(-1) converted to OUTPUT_R
        # "net1": RowwiseParallel(),
        # "net2": ColwiseParallel(output_layouts=Replicate()),
        # Type 4 - net1_INPUT_S(-1)---OUTPUT_R (net2_INPUT_R). net2_INPUT_R---OUTPUT_R
        "net1": RowwiseParallel(),
        "net2": RowwiseParallel(input_layouts=Replicate()),
    }
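
    # For reference (not used below): ColwiseParallel shards the nn.Linear weight along
    # dim 0 (output features), giving each rank an (out_features // world_size, in_features)
    # slice, while RowwiseParallel shards along dim 1 (input features), i.e. an
    # (out_features, in_features // world_size) slice - roughly what
    # distribute_tensor(weight, device_mesh, [Shard(0)]) vs. [Shard(1)] would produce.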
    if isinstance(tp_plan["net1"], RowwiseParallel):
        rowwise_on_net1 = True

    model_tp = parallelize_module(model_tp, device_mesh, tp_plan)
    compare_params(model, model_tp, rank, rank0_only=False)

    input_shape = (16, 10)
    input_tensor = torch.randn(input_shape, device=DEVICE_TYPE)
    print("Forward pass on normal model:", input_tensor.shape)
    output = model(input_tensor)

    if rowwise_on_net1:
        # RowwiseParallel expects its input sharded on the last dim, so each rank feeds its own chunk.
        input_tensor_tp = input_tensor.chunk(world_size, dim=-1)[rank]
    else:
        input_tensor_tp = input_tensor
    print("Forward pass on TP model:", input_tensor_tp.shape)
    output_tp = model_tp(input_tensor_tp)

    print("Before redistribute:", rank, output.shape, output_tp.shape)
    if isinstance(output_tp, DTensor):
        output_tp = output_tp.redistribute(output_tp.device_mesh, [Replicate()]).to_local()
    print("Output shapes:", output.shape, output_tp.shape)
    print("Comparing output:", rank, torch.allclose(output, output_tp))
dist.init_process_group(PG_BACKEND)
WORLD_SIZE = dist.get_world_size()
RANK = dist.get_rank()
torch.cuda.set_device(RANK)
print(f"World size: {WORLD_SIZE}")
print(f"Rank: {RANK}")

try:
    main(WORLD_SIZE, RANK)
finally:
    dist.destroy_process_group()
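When launched under torchrun with the NCCL backend (one process per GPU), each rank prints the shape bookkeeping plus the parameter and output comparisons; for whichever of the four `tp_plan` variants is left uncommented, the `allclose` checks are expected to print `True`.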