@wolfecameron
Created March 6, 2025 18:31
An implementation of the MoE load balancing loss in PyTorch.
"""
Computes Switch Transformer auxiliary loss (https://arxiv.org/abs/2101.03961)
See equations (4)-(6) on page 7
"""
import torch
import torch.nn.functional as F
# constants
B = 16 # batch size
C = 256 # sequence length
n_exp = 8 # number of experts
K = 2 # number of active experts

# define tensors needed to compute load balancing loss
# NOTE: expert indices must lie in [0, n_exp) for F.one_hot below
indices = torch.randint(0, n_exp, (B, C, K)) # top-K indices ([B, C, K])
expert_probs = F.softmax(torch.rand(B, C, n_exp), dim=2) # expert probabilities ([B, C, n_exp])

# equation (5): compute ratio of tokens allocated to each expert
# total number of tokens is defined as total tokens in batch * K
with torch.no_grad():
    one_hot_indices = F.one_hot(indices, num_classes=n_exp) # [B, C, K, n_exp]
    one_hot_indices = torch.sum(one_hot_indices.float(), dim=2) # [B, C, n_exp] (sum over K dimension)
    tokens_per_expert = torch.mean(one_hot_indices, dim=(0, 1))

# equation (6): compute ratio of router probability allocated to each expert
prob_per_expert = torch.mean(expert_probs.float(), dim=(0, 1))

# equation (4): take a scaled dot product between prob / token allocation vectors
# multiply the result by the number of experts
load_balance_loss = n_exp * torch.sum(prob_per_expert * tokens_per_expert)
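
The same computation can be packaged as a reusable function. The sketch below is illustrative (the name load_balancing_loss and the round-robin sanity check are not part of the original gist): with perfectly uniform router probabilities and perfectly balanced assignments, the loss evaluates to exactly K (the paper's value of 1 for K = 1), which makes for an easy correctness check.

def load_balancing_loss(indices: torch.Tensor, expert_probs: torch.Tensor) -> torch.Tensor:
    """indices: [B, C, K] integer expert assignments in [0, n_exp);
    expert_probs: [B, C, n_exp] router probabilities (rows sum to 1)."""
    n_exp = expert_probs.shape[-1]
    with torch.no_grad():
        one_hot = F.one_hot(indices, num_classes=n_exp).float()  # [B, C, K, n_exp]
        tokens_per_expert = one_hot.sum(dim=2).mean(dim=(0, 1))  # [n_exp], sums to K
    prob_per_expert = expert_probs.mean(dim=(0, 1))  # [n_exp], sums to 1
    return n_exp * torch.sum(prob_per_expert * tokens_per_expert)

# sanity check: uniform probabilities + balanced (round-robin) assignments give loss == K
uniform_probs = torch.full((B, C, n_exp), 1.0 / n_exp)
round_robin = torch.arange(B * C * K).reshape(B, C, K) % n_exp
assert torch.isclose(load_balancing_loss(round_robin, uniform_probs), torch.tensor(float(K)))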
@ChanduTadanki

I get the following error:

RuntimeError                              Traceback (most recent call last)
Cell In[91], line 22
     19 # equation (5): compute ratio of tokens allocated to each expert
     20 # total number of tokens is defined as total tokens in batch * K
     21 with torch.no_grad():
---> 22     one_hot_indices = F.one_hot(indices, num_classes=n_exp)  # [B, C, K, n_exp]
     23     one_hot_indices = torch.sum(one_hot_indices.float(), dim=2)  # [B, C, n_exp] (sum over K dimension)
     24     tokens_per_expert = torch.mean(one_hot_indices.float(), dim=(0, 1))

RuntimeError: Class values must be smaller than num_classes.

What am I missing?
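
(The cause: torch.randint(1, n_exp + 1, (B, C, K)) samples indices in [1, n_exp], but F.one_hot requires every class value to be strictly less than num_classes, i.e. in [0, n_exp). Sampling with torch.randint(0, n_exp, (B, C, K)), as in the listing above, avoids the error.)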
