Created
March 6, 2025 21:17
-
-
Save wolfecameron/67851367036bf1cb4e0524607bc90c91 to your computer and use it in GitHub Desktop.
PyTorch implementation of a feed-forward expert layer within an MoE.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Feed-forward mixture-of-experts (MoE) layer, based upon ColossalAI OpenMoE.
"""
import torch
from torch import nn
class MOELayer(nn.Module):
    """Feed-forward mixture-of-experts (MoE) layer.

    A (noisy) top-k router dispatches tokens to a group of MLP experts. Each
    token's output is the router-weighted combination of the expert outputs
    (eq. (2) on page 4 of ST-MoE, https://arxiv.org/abs/2202.08906).

    NOTE(review): ``Router`` and ``MLPExperts`` are not defined in this file.
    They must be available in this module's namespace before the layer is
    constructed.
    """

    def __init__(
        self,
        d: int,
        n_exp: int = 8,
        top_k: int = 2,
        use_noisy_top_k: bool = True,
        capacity_factor: float = 1.25,
        bias: bool = False,
        dropout: float = 0.2,
    ):
        """
        Arguments:
            d: size of embedding dimension
            n_exp: the number of experts to create in the expert layer
            top_k: the number of active experts for each token
            use_noisy_top_k: forwarded to ``Router``; presumably toggles
                noise in the top-k expert selection (confirm in ``Router``)
            capacity_factor: forwarded to ``Router``; used to compute expert
                capacity
            bias: whether or not to use bias in the experts' linear layers
            dropout: dropout probability, forwarded to ``MLPExperts``
        """
        super().__init__()
        self.router = Router(  # (noisy) top k router
            d=d,
            n_exp=n_exp,
            top_k=top_k,
            use_noisy_top_k=use_noisy_top_k,
            capacity_factor=capacity_factor,
        )
        self.experts = MLPExperts(  # group of MLPs (experts)
            d=d,
            n_exp=n_exp,
            bias=bias,
            dropout=dropout,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route tokens to experts and combine the expert outputs.

        Args:
            x: 3-D input tensor, unpacked as ``B, C, d``. Presumably this is
                (batch, context length, embedding dim); confirm against
                callers.

        Returns:
            Tensor of shape ``[B, C, d]``. Each token's row is its
            router-weighted sum over the expert outputs.
        """
        B, C, d = x.size()  # track original shape of input
        num_tokens = B * C

        # Pass each token through the router. The dispatch mask (middle
        # return value) is not needed to aggregate expert outputs here.
        exp_weight, _, exp_batches = self.router(x)

        # compute expert output
        exp_out = self.experts(exp_batches)  # [n_exp, exp_capacity, d]

        # Aggregate expert outputs based on router weights.
        # eq (2) on page 4 of ST-MoE (https://arxiv.org/abs/2202.08906)
        # reshape (not view) so a non-contiguous tensor from the router or
        # experts is accepted; it copies only when a view is impossible.
        exp_weight = exp_weight.reshape(num_tokens, -1)  # [B * C, n_exp * exp_capacity]
        exp_out = exp_out.reshape(-1, d)  # [n_exp * exp_capacity, d]
        output = exp_weight @ exp_out  # [B * C, d]

        # resize output back to the input's [B, C, d] layout
        return output.reshape(B, C, d)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment