Efficient PyTorch Implementation of MoE with Aux Loss and Token Drop
This article was last updated on August 4, 2025.
1 Preliminaries
Mixture-of-Experts (MoE) has become an essential architectural choice when building LLMs. Since DeepSeek-V3 made it prevalent, companies now routinely consider whether to use an MoE structure before LLM pretraining.
The concept of MoE is simple: suppose we have $E$ experts. For a token input $x$, we use a router to decide which $K$ experts it goes through, and typically $K \ll E$. This creates sparsely activated computation, since not all experts are used during the forward pass.
Typically, experts are FFNs that output tokens with the same shape as the input tokens. The router is a linear layer $W_r$ that transforms the input dimension into $E$ logits, representing how likely a token is to be assigned to each expert. We can formulate the computation as follows. Suppose the input is $x \in \mathbb{R}^{d}$; we first calculate the logits and gate probabilities

$$l = W_r x \in \mathbb{R}^{E}, \qquad p = \mathrm{softmax}(l).$$

Then, we choose the top-K entries

$$g_i = \begin{cases} p_i & \text{if } p_i \in \operatorname{TopK}(p, K), \\ 0 & \text{otherwise.} \end{cases}$$

After that, the gate values are multiplied with the expert outputs

$$y = \sum_{i=1}^{E} g_i \, \mathrm{FFN}_i(x).$$

Since $g$ contains zero elements, not all experts contribute to the output.
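To make the formulation concrete, here is a minimal (and deliberately slow) sketch of such a layer. The class name and layer sizes are illustrative and not taken from any particular model; an efficient implementation is exactly what the rest of this post is about.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveMoE(nn.Module):
    """Reference (slow) MoE layer: route each token to its top-K experts."""
    def __init__(self, d_model: int, d_ff: int, num_experts: int, topk: int):
        super().__init__()
        self.topk = topk
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        )

    def forward(self, x):                      # x: [num_tokens, d_model]
        logits = self.router(x)                # [num_tokens, num_experts]
        gates = F.softmax(logits.float(), dim=-1)
        topk_gates, topk_idx = torch.topk(gates, self.topk, dim=-1)
        out = torch.zeros_like(x)
        # Dense loop over experts: simple but wasteful, since every expert is visited.
        for e, expert in enumerate(self.experts):
            for k in range(self.topk):
                chosen = topk_idx[:, k] == e   # tokens whose k-th choice is expert e
                if chosen.any():
                    out[chosen] += topk_gates[chosen, k].unsqueeze(-1).to(x.dtype) * expert(x[chosen])
        return out
```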
Our goal is to implement the above computation efficiently and to enable two standard tricks: (1) the load-balancing loss and (2) token drop.
One major issue in MoE training is that the router may collapse and assign tokens to only a few experts, sometimes to a single expert. If that happens, the other experts are never updated during optimization, and we do not fully utilize the parameters.
The load-balancing loss introduces an auxiliary loss that is combined with the next-token-prediction loss

$$\mathcal{L}_{\text{aux}} = E \sum_{i=1}^{E} f_i P_i,$$

where $f_i$ is the average frequency with which expert $i$ is chosen, and $P_i$ is the average probability the router assigns to expert $i$.
Besides the load-balancing loss, people may use token drop. A capacity factor $\alpha$ is set for each expert, so one expert can receive at most

$$C = \left\lfloor \alpha \cdot \frac{K N}{E} \right\rfloor$$

tokens, where $K$ is the number of experts chosen per token, $N$ is the total number of tokens, and $E$ is the number of experts. If an expert is assigned more than $C$ tokens, the extra tokens are dropped (their MoE-layer output is set to zero). However, they are not really lost: after being zeroed, the residual connection brings back the token representations from before the MoE layer.
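Using the numbers from the walkthrough in the next section (10 tokens, 3 experts, top-2 routing, capacity factor 1), the capacity works out as follows; the variable names are only for illustration:

```python
num_tokens, num_experts, topk, capacity_factor = 10, 3, 2, 1.0
expert_capacity = int(capacity_factor * topk * num_tokens / num_experts)
print(expert_capacity)  # int(2 * 10 / 3) = 6, so each expert holds at most 6 tokens
```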
2 Code Walkthrough
Let's use the code from the HunYuan MoE models released by Tencent. The core is the `topkgating` function.
Suppose we have an input of router logits for 10 tokens and 3 experts (the exact values here are only illustrative; what matters is that every row is heavily biased the same way):

```python
import torch
import torch.nn.functional as F

# 10 tokens, 3 experts; every row prefers expert 0 first and expert 1 second
logits = torch.tensor([[2.0, 1.0, 0.0]] * 10)
topk = 2
num_experts = logits.shape[1]
```
This means we have 10 tokens routed among 3 experts, of which two are chosen per token. Note that the router is severely biased: it puts almost all of its score on expert 0 and barely any on the other experts. This is deliberate, to demonstrate the token-drop mechanism.
The first two lines of code are
```python
logits = logits.float()           # cast to FP32 for a numerically stable softmax
gates = F.softmax(logits, dim=1)  # [10, 3] routing probabilities per token
```
For stability, we always convert the input to FP32 before the softmax. Then the actual gate probabilities are calculated by the softmax function over dimension 1. The resulting `gates` tensor has shape `[10, 3]`; every row sums to 1, and in our biased example expert 0 always gets the largest probability and expert 1 the second largest.
The next line of code calculates the expert capacity
```python
# at most floor(topk * num_tokens / num_experts) tokens per expert
# (optionally scaled by a capacity factor)
expert_capacity = int(topk * gates.shape[0] / gates.shape[1])
```
In our case, $C = \lfloor 2 \times 10 / 3 \rfloor = 6$, which means that one expert will receive at most 6 tokens.
Next, we calculate the topk
```python
expert_gate, expert_index = torch.topk(gates, topk)  # both of shape [10, 2]
```
Here `expert_gate` holds the selected gate probabilities and `expert_index` holds the indices of the selected experts; both have shape `10x2`, where 10 is the number of tokens and 2 means we choose the experts with the top-2 scores. However, in our extremely biased case, `expert_index` is
```python
tensor([[0, 1],
        [0, 1],
        ...
        [0, 1]])   # shape [10, 2]: every token picks expert 0 first, expert 1 second
```
which means for all 10 tokens, only the first and the second expert are chosen.
Next, we calculate the load balancing loss
```python
expert_mask = F.one_hot(expert_index, num_experts)               # [10, 2, 3]
expert_mask_aux = expert_mask.max(dim=-2)[0]                      # [10, 3]
tokens_per_expert = torch.mean(expert_mask_aux.float(), dim=-2)   # f_i, shape [3]
router_prob_per_expert = torch.mean(gates.float(), dim=-2)        # P_i, shape [3]
l_aux = num_experts ** 2 * torch.mean(tokens_per_expert * router_prob_per_expert)
```
The first line uses the `one_hot` function, and the indices are encoded to

```python
tensor([[[1, 0, 0],
         [0, 1, 0]],

        [[1, 0, 0],
         [0, 1, 0]],

        ...                      # the same pattern repeats for all 10 tokens
       ])                        # shape [10, 2, 3]
```
This mask creates three slots: for each token and each choice, it marks which expert is chosen, and the chosen expert gets the value 1. As you can see, the values do not overlap along dimension 1; this is because the top-k function never chooses the same expert twice, so each choice picks a different expert index.
We can remove the redundant information by
```python
expert_mask_aux = expert_mask.max(dim=-2)[0]   # [10, 3]: which experts each token uses
```
Now, `expert_mask_aux` represents, for each token, which two experts are chosen. We can then easily calculate $f_i$ by
```python
tokens_per_expert = torch.mean(expert_mask_aux.float(), dim=-2)   # f_i for each expert
```
This line averages over the first dimension (the tokens), which means we are calculating the "average number of times chosen", i.e. the frequency $f_i$, for each expert.
```python
tensor([1., 1., 0.])   # experts 0 and 1 are chosen by every token, expert 2 by none
```
The next line of code calculates $P_i$
```python
router_prob_per_expert = torch.mean(gates.float(), dim=-2)   # P_i for each expert
```
Remember that `gates` is just the softmax of the logits; it represents, for each token, the probability of assigning that token to each expert, and it has shape `10x3`. Averaging over the first dimension therefore gives the average probability of assigning tokens to each expert.
```python
tensor([0.6652, 0.2447, 0.0900])   # with the illustrative logits above
```
The last line of code applies the formula:

```python
l_aux = num_experts ** 2 * torch.mean(tokens_per_expert * router_prob_per_expert)
```
Here, since we use `torch.mean`, we are actually computing $\frac{1}{E}\sum_i f_i P_i$, so we need to multiply by $E^2$ to recover $E \sum_i f_i P_i$. When ideal balance is achieved, we have $P_i = 1/E$ and, with top-$K$ routing, $f_i = K/E$, so at perfect balance the aux loss equals $K$ (which reduces to 1 for top-1 routing).
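Plugging the balanced values into the implemented expression makes that constant explicit:

$$E^{2}\cdot\frac{1}{E}\sum_{i=1}^{E} f_i P_i \;=\; E\sum_{i=1}^{E}\frac{K}{E}\cdot\frac{1}{E} \;=\; E\cdot E\cdot\frac{K}{E^{2}} \;=\; K.$$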
The next two lines of code are
```python
gates_s = torch.clamp(
    torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1),
    min=torch.finfo(gates.dtype).eps,
)                               # [10, 1]: sum of the two selected gate values per token
router_probs = gates / gates_s  # [10, 3]: the top-k entries now sum to 1
```
Let's break down the first line by considering a single token. For one token, the `matmul` computes

$$\begin{pmatrix}1 & 0 & 0\\ 0 & 1 & 0\end{pmatrix}\begin{pmatrix}p_0\\ p_1\\ p_2\end{pmatrix} = \begin{pmatrix}p_0\\ p_1\end{pmatrix},$$

i.e. it picks out the gate values of the two selected experts (here we used just one token, not the whole batch). The `clamp` is only there to avoid precision underflow when training in low precision. The result of the `matmul` is then summed together and used to normalize the probabilities. Hence, this code normalizes the top-k gate values so that they sum to 1.
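For example, with top-2 gate values of roughly 0.665 and 0.245 (as in our biased example above), the normalization behaves like this:

```python
top2 = torch.tensor([0.665, 0.245])
print(top2 / top2.sum())   # tensor([0.7308, 0.2692]); the kept weights now sum to 1
```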
Next, we are going to apply token drop.
```python
expert_index = torch.transpose(expert_index, 0, 1)  # [10, 2] -> [2, 10]
expert_index = expert_index.reshape(-1)             # [20]: top-1 choices first, then top-2
```
The first line transposes `expert_index` from `[10, 2]` to `[2, 10]`: the first row holds the top-1 choices of the 10 tokens and the second row holds the top-2 choices. In our case, the top-1 choice is always expert 0 and the top-2 choice is always expert 1. The code then flattens the tensor, so the first 10 elements are the top-1 choices and the last 10 elements are the top-2 choices.
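For our biased example, the flattened index is therefore the ten top-1 choices followed by the ten top-2 choices:

```python
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
```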
The next line of code is
```python
expert_mask = F.one_hot(expert_index, num_experts)   # [20, 3]
```
This mask indicates which expert is chosen for each choice. The next line of code calculates how many times an expert is chosen.
```python
exp_counts = torch.sum(expert_mask, dim=0).detach()  # tensor([10, 10, 0]) in our case
```
The next line of code is
```python
token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1
```
The core idea here is to compute a priority for each token, where top-1 choices always have higher priority than top-2 choices. That is why we flattened the indices with `expert_index = expert_index.reshape(-1)`: the top-1 choices are placed before the top-2 choices. The `cumsum` also means that tokens at earlier positions get higher priority than tokens at later positions. The result of the `cumsum` is
```python
tensor([[ 1,  0,  0],
        [ 2,  0,  0],
        ...
        [10,  0,  0],
        [10,  1,  0],
        [10,  2,  0],
        ...
        [10, 10,  0]])   # shape [20, 3]
```
Then the result is multiplied by `expert_mask` and reduced by 1, giving
```python
tensor([[ 0, -1, -1],
        [ 1, -1, -1],
        ...
        [ 9, -1, -1],
        [-1,  0, -1],
        [-1,  1, -1],
        ...
        [-1,  9, -1]])   # shape [20, 3]
```
The next line of code is
```python
token_priority = token_priority.reshape((topk, -1, num_experts))  # [2, 10, 3]
```
and subsequently
```python
token_priority = torch.transpose(token_priority, 0, 1)  # [10, 2, 3]
```
This transpose puts the top-1 and top-2 choices back together, so `token_priority` now means: for each token and each choice, the priority of that token within the chosen expert. Again, since the top-k function never chooses the same expert twice, we can remove the redundant information with
```python
token_priority = torch.max(token_priority, dim=1)[0]  # [10, 3]
```
This gives the priority of each token within each expert. For instance, `token_priority[0, 0] = 0` means the 0'th token has priority 0 in expert 0. Likewise, the 0'th token has priority 0 in expert 1 (`token_priority[0, 1] = 0`).
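In our running example, the full matrix is

```python
tensor([[ 0,  0, -1],
        [ 1,  1, -1],
        [ 2,  2, -1],
        [ 3,  3, -1],
        [ 4,  4, -1],
        [ 5,  5, -1],
        [ 6,  6, -1],
        [ 7,  7, -1],
        [ 8,  8, -1],
        [ 9,  9, -1]])   # rows 6-9 exceed the capacity of 6 and will be dropped
```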
Next, we drop the tokens that exceed the capacity and also zero out the positions with value -1; we will see later why they are zeroed out.
```python
valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
```
This line takes the logical AND of `token_priority >= 0`, which will be used to zero out the negative values, and `token_priority < expert_capacity`, which filters out the entries that exceed the expert capacity. The resulting mask is
```python
tensor([[ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False]])
```
Then, this mask is applied to `token_priority`

```python
token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
```
Now, the most important part.
```python
dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)  # [10, 3, 6]
```
Here, we one-hot each priority, creating "slots" for each expert. Imagine that each expert has several slots (6 in our case) and each slot can hold exactly one token; the one-hot result does exactly this:
```python
tensor([[[ True, False, False, False, False, False],
         [ True, False, False, False, False, False],
         [ True, False, False, False, False, False]],

        [[False,  True, False, False, False, False],
         [False,  True, False, False, False, False],
         [ True, False, False, False, False, False]],

        ...
       ])   # shape [10, 3, 6]
```
Now we can also see why we zeroed out the negative values: it lets us use the `one_hot` function, which does not accept negative indices. However, -1 meant "expert not chosen", and after the fill it looks like priority 0, so next we use the `valid_mask` to mask out the positions that were originally -1.
```python
valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity)
dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, False)
```
If we now inspect the `dispatch_mask`, we have effectively created 6 slots for each expert, and `dispatch_mask[n, e, c]` tells whether token `n` is assigned to slot `c` of expert `e`. For example, token 0 is assigned to slot 0 in expert 0 and expert 1, token 1 is assigned to slot 1 in expert 0 and expert 1, and so on.
This effectively creates `expert_capacity` slots, each of which can hold only one token, so each expert can accept at most `expert_capacity` tokens.
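We can check both properties directly on the tensors we just built:

```python
# each slot of each expert holds at most one token
assert int(dispatch_mask.sum(dim=0).max()) <= 1
# each expert receives at most expert_capacity tokens (here: 6, 6 and 0)
print(dispatch_mask.sum(dim=(0, 2)))  # tensor([6, 6, 0])
```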
This `dispatch_mask` will be used to mask the token input, as we will see later.
The next line calculates the combine weights
```python
combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask.float())
```
This is just an elementwise product with `router_probs` unsqueezed and broadcast over the slot dimension. For every kept token, `combine_weights[n, e, c]` holds the normalized router probability of token `n` for expert `e` in the slot `c` it occupies, and it is zero everywhere else.
So the tensor `combine_weights` contains the (normalized) router probability assigned to each expert; because the last dimension is one-hot, the probability can be multiplied directly with the mask.
The last two lines compute the capacity rate
```python
exp_counts_capacity = torch.sum(dispatch_mask)
exp_capacity_rate = exp_counts_capacity / (logits.shape[0] * topk)
```
This effectively calculates the fraction of routed assignments that survive the capacity limit. When no token is dropped, `torch.sum(dispatch_mask)` equals the total number of expert assignments, which is the number of tokens times `topk`. Thus, when no token is dropped, `exp_capacity_rate` is 1, and values smaller than 1 indicate that some tokens were dropped.
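In our example, only 12 of the 20 assignments survive the capacity limit:

```python
print(torch.sum(dispatch_mask))   # tensor(12)
print(exp_capacity_rate)          # 12 / (10 * 2) = 0.6
```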
Next, let's see how we calculate the expert output. Assume our `hidden_states` is a toy embedding where token $n$ is the 2-dimensional vector $[2n, 2n+1]$ (chosen so the dropped tokens are easy to spot):

```python
hidden_states = torch.arange(20, dtype=torch.float32).reshape(10, 2)
```
The input for experts is calculated by
```python
dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.to(hidden_states.dtype), hidden_states)
```
The `einsum` contracts (sums over) the token dimension, and the last dimension of `hidden_states` is attached after dimension `c`. This is the insertion operation: the token embeddings are inserted into their slots
```python
tensor([[[ 0.,  1.], [ 2.,  3.], [ 4.,  5.], [ 6.,  7.], [ 8.,  9.], [10., 11.]],
        [[ 0.,  1.], [ 2.,  3.], [ 4.,  5.], [ 6.,  7.], [ 8.,  9.], [10., 11.]],
        [[ 0.,  0.], [ 0.,  0.], [ 0.,  0.], [ 0.,  0.], [ 0.,  0.], [ 0.,  0.]]])
# shape [3, 6, 2]: experts 0 and 1 each hold tokens 0-5, expert 2 is empty
```
which means that for each expert and each slot we have inserted the corresponding input token. Note that the tokens `[12, 13]`, `[14, 15]`, `[16, 17]`, `[18, 19]` are dropped.
Next comes the expert inference
```python
# in the real model, each expert processes its own chunk of dispatched_input:
#   chunks = dispatched_input.chunk(num_experts, dim=0)
#   expert_outputs = [expert(chunk) for expert, chunk in zip(self.experts, chunks)]
#   expert_output = torch.cat(expert_outputs, dim=0)
# for this walkthrough we treat every expert as the identity map
expert_output = dispatched_input
```
In the real model, each expert processes its own chunk, and within each chunk we provide 6 tokens (one per slot); here we keep the values unchanged to make the bookkeeping easy to follow.
After that, the normalized router probabilities are multiplied back in.
```python
combined_output = torch.einsum("sec,ecm->sm", combine_weights, expert_output)
```
The `einsum` multiplies and sums over the dimensions `e` and `c`; the result is
```python
tensor([[ 0.,  1.],
        [ 2.,  3.],
        [ 4.,  5.],
        [ 6.,  7.],
        [ 8.,  9.],
        [10., 11.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.]])
```
You will notice two things. (1) There are four zero vectors; as described before, these tokens were dropped, and the residual connection will bring them back. (2) Although we used `dispatched_input` directly as the expert output, we still multiplied by the router probabilities, so why does the output exactly equal the input? This is because we normalized the top-2 probabilities so that they sum to 1; hence each token computes

$$y_i = (w_{i,1} + w_{i,2})\, x_i = x_i,$$

where $x_i$ is the $i$'th token and $w_{i,1}, w_{i,2}$ are its normalized top-2 router probabilities.
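Putting it all together, here is a short end-to-end check of the walkthrough values (variable names as defined above; the final line is only a sketch of how the residual connection in a real transformer block restores the dropped tokens):

```python
# the first six tokens pass through unchanged, the last four are dropped (zeroed)
assert torch.allclose(combined_output[:6], hidden_states[:6], atol=1e-4)
assert torch.all(combined_output[6:] == 0)

# in the transformer block, the residual add brings the dropped tokens back
moe_layer_output = hidden_states + combined_output
```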