Skip to content

Instantly share code, notes, and snippets.

@wolfecameron
Last active March 6, 2025 18:47
Show Gist options
  • Save wolfecameron/46f03d50617f256f4560f299422f7ceb to your computer and use it in GitHub Desktop.
Implementation of a basic softmax routing mechanism for an MoE.
import torch
from torch import nn
from torch.nn import functional as F
class BasicSoftmaxRouter(nn.Module):
    """Basic (noisy) top-k softmax router for a Mixture-of-Experts layer.

    Implements eq. (4) of "Outrageously Large Neural Networks"
    (https://arxiv.org/abs/1701.06538): a bias-free linear projection
    produces one logit per expert, optionally perturbed by learned,
    input-dependent Gaussian noise, and the top-k logits and expert
    indices are returned for expert selection.
    """

    def __init__(
        self,
        d,
        n_exp=8,
        top_k=2,
        use_noisy_top_k=True,
    ):
        """
        Arguments:
            d: size of embedding dimension
            n_exp: the number of experts to create in the expert layer
            top_k: the number of active experts for each token
                (must satisfy 1 <= top_k <= n_exp)
            use_noisy_top_k: whether to add noise when computing expert output
        """
        super().__init__()
        # validate with a real exception: `assert` is stripped under -O
        if not 1 <= top_k <= n_exp:
            raise ValueError(f"top_k must be in [1, {n_exp}], got {top_k}")
        # router settings
        self.top_k = top_k
        self.use_noisy_top_k = use_noisy_top_k
        # linear projection for (noisy) softmax routing
        # no bias used, see page 4 eq (4) in https://arxiv.org/abs/1701.06538
        self.w_g = nn.Linear(d, n_exp, bias=False)
        self.w_noise = nn.Linear(d, n_exp, bias=False) if self.use_noisy_top_k else None

    def forward(self, x):
        """Route each token to its top-k experts.

        Arguments:
            x: input activations of shape [B, C, d]

        Returns:
            (top_k_logits, top_k_indices), each of shape [B, C, k].
        """
        # eq (4) in https://arxiv.org/abs/1701.06538
        logits = self.w_g(x)  # [B, C, d] -> [B, C, n_exp]
        # Noise is a training-time load-balancing/exploration mechanism;
        # routing stays deterministic at eval time (self.training check).
        if self.use_noisy_top_k and self.training:
            noise = F.softplus(self.w_noise(x))
            noise = noise * torch.randn_like(noise)
            logits = logits + noise
        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1)  # [B, C, k]
        return top_k_logits, top_k_indices
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment