Efficient PyTorch Implementation of MoE with Aux loss and Token drop

Last updated on: August 4, 2025

1 Preliminaries

Mixture-of-Experts (MoE) is an essential architecture choice when building LLMs. Since DeepSeekV3 popularized the approach, companies routinely consider whether to use an MoE structure before LLM pretraining.

The concept of MoE is simple: suppose we have $E$ experts. For a token input $X$, we use a router $R$ to decide which $K$ experts it goes through, typically with $K \ll E$. This creates sparsely activated computation, since not all experts are used during the forward pass.

Typically, experts are FFNs that output tokens with the same shape as the input tokens. The router is a linear layer that maps the input dimension to logits representing the probability that a token will be assigned to each expert. We can formulate the computation as follows. Suppose $X \in \mathbb{R}^{D}$; we first calculate the logits

$$\mathrm{logits} = R(X) \in \mathbb{R}^{E}$$

Then, we choose the top-K logits

$$\mathrm{logits}'[e] = \mathrm{TopK}(\mathrm{logits}[e]) = \begin{cases} \mathrm{logits}[e] & e \in \mathrm{TopK} \\ 0 & \text{otherwise} \end{cases}$$

After that, the gated logits are multiplied with the expert outputs

$$X' = \sum_{e=1}^{E} \mathrm{logits}'[e]\,\mathrm{FFN}_e(X)$$

Since $\mathrm{logits}'$ contains zero elements, not all experts contribute to the output.
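To make this concrete, below is a minimal, dense reference sketch of the routing computation above (my own illustration with made-up sizes, not the efficient implementation discussed later): every expert processes every token, and the sparse gate zeroes out the non-selected contributions.

import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveMoE(nn.Module):
    """Dense reference MoE: loops over all experts; correct but inefficient."""
    def __init__(self, dim=16, num_experts=4, topk=2):
        super().__init__()
        self.router = nn.Linear(dim, num_experts)   # R: D -> E logits
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
            for _ in range(num_experts)
        )
        self.topk = topk

    def forward(self, x):  # x: (N, D)
        gates = F.softmax(self.router(x), dim=-1)             # (N, E)
        top_val, top_idx = torch.topk(gates, self.topk, dim=-1)
        # keep only the top-k gate values, zero elsewhere (the logits' above)
        sparse_gates = torch.zeros_like(gates).scatter(-1, top_idx, top_val)
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            out = out + sparse_gates[:, e:e + 1] * expert(x)  # gate * FFN_e(X)
        return out

x = torch.randn(10, 16)
print(NaiveMoE()(x).shape)  # torch.Size([10, 16])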

Our goal is to implement the above computation efficiently, enabling two standard tricks: (1) load balancing loss and (2) token drop.

One major issue in MoE training is that the router may collapse and assign tokens to only a few experts, sometimes even a single one. If that happens, the other experts never receive updates during optimization, and we are not fully utilizing the parameters.

The load balancing loss is an auxiliary loss that is combined with the next-token prediction loss

$$\mathcal{L}_{\text{aux}} = E \sum_e F_e P_e$$

where $F_e$ is the fraction of tokens for which expert $e$ is among the chosen experts, and $P_e$ is the router probability for expert $e$ averaged over tokens.

In addition to the load balancing loss, people may use token drop. A capacity factor $C$ is set for each expert; with $C = 1$, as in the code below, one expert can receive at most

$$T_m = \max\left(K,\ \frac{K \cdot N}{E}\right)$$

where $K$ is the number of experts chosen per token, $N$ is the total number of tokens, and $E$ is the number of experts. If an expert is assigned more than $T_m$ tokens, the excess tokens are dropped (their MoE-layer output is set to zero). However, they are not truly discarded: after being zeroed, the residual connection brings back the token representation from before the MoE layer.
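To make the "brought back by the residual" point concrete, here is a minimal sketch of how a transformer block typically wires in the MoE layer (illustrative names, not the HunYuan code):

# illustrative block wiring (a sketch, not the HunYuan code):
# a dropped token receives a zero vector from the MoE layer, but the
# residual connection preserves its pre-MoE representation
def block_forward(x, attention, moe, norm1, norm2):
    x = x + attention(norm1(x))
    x = x + moe(norm2(x))  # for dropped tokens moe(...) == 0, so x passes through
    return x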

2 Code Walkthrough

Let's use the code from the HunYuan MoE models released by Tencent. The core is the topkgating function.

Suppose we have an input

logits = torch.tensor(
    [[1., 0., 0.],
     [1., 0., 0.],
     ...
     [1., 0., 0.]]
)  # 10, 3

and topk=2

This means we have 10 tokens routed among 3 experts, of which two will be chosen per token. Note that the router is severely biased: it strongly favors expert 0 and gives the other experts equal, lower scores. This setup is deliberate, to demonstrate the token drop mechanism.

The first two lines of code are

logits = logits.float()
gates = F.softmax(logits, dim=1)

For stability, we always convert the input to FP32 before softmax. Then the actual gate probabilities are calculated by the softmax function on dimension 1. The result is

gates
>>> Tensor(
[[0.5761, 0.2119, 0.2119],
 [0.5761, 0.2119, 0.2119],
 ...
 [0.5761, 0.2119, 0.2119]]
)

The next line of code calculates the expert capacity

expert_capacity = max(topk, topk * gates.shape[0] // gates.shape[1])

In our case, $\max(2,\ 2 \cdot 10 \,//\, 3) = 6$, which means that one expert will receive at most 6 tokens.

Next, we calculate the top-k choices

num_experts = int(gates.shape[1])  # 3
expert_gate, expert_index = torch.topk(gates, topk)
expert_mask = F.one_hot(expert_index, num_experts)

expert_gate holds the selected gate values and expert_index the indices of the selected experts, both with shape 10×2, where 10 is the number of tokens and 2 means we choose the experts with the top-2 scores.

However, in our extremely biased case, the expert_index is

expert_index
>>> Tensor(
[[0, 1],
 [0, 1],
 ...
 [0, 1]]
) # 10, 2

which means that for all 10 tokens, only expert 0 and expert 1 are chosen.

Next, we calculate the load balancing loss

expert_mask = F.one_hot(expert_index, num_experts)
expert_mask_aux = expert_mask.max(dim=-2)[0]
tokens_per_group_and_expert = torch.mean(expert_mask_aux.float(), dim=-2)
router_prob_per_group_and_expert = torch.mean(gates.float(), dim=-2)
l_aux = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)

The first line applies the one_hot function, and the indices are encoded as

expert_mask
>>> Tensor(
[[[1, 0, 0],
  [0, 1, 0]],
 [[1, 0, 0],
  [0, 1, 0]],
 ...
 [[1, 0, 0],
  [0, 1, 0]]]
) # 10, 2, 3

This mask tells us, for each token and each of its choices, which expert was chosen; the chosen expert is marked with value 1.

But as you can see, the values don't overlap along dimension 1. This is because the top-k function won't choose the same expert twice, so for each choice the chosen expert index is different.

We can remove the redundant information by

expert_mask_aux = expert_mask.max(dim=-2)[0]
expert_mask_aux
>>> Tensor(
[[1, 1, 0],
 [1, 1, 0],
 ...
 [1, 1, 0]]
) # 10, 3

Now expert_mask_aux represents, for each token, which two experts are chosen. We can then easily calculate $F_e$ by

tokens_per_group_and_expert = torch.mean(expert_mask_aux.float(), dim=-2)

This line averages over the first (token) dimension, which gives the "average chosen frequency" of each expert.

tokens_per_group_and_expert
>>> Tensor([1., 1., 0.]) # 3

The next line of code calculates $P_e$

router_prob_per_group_and_expert = torch.mean(gates.float(), dim=-2)

Remember that gates is just the softmax of logits; it represents, for each token, the probability of assigning that token to each expert, with shape 10×3. Averaging over the first dimension therefore gives the average probability of assigning tokens to each expert.

router_prob_per_group_and_expert
>>> Tensor([0.5761, 0.2119, 0.2119]) # 3

The last line of code applies the formula $E \sum_e F_e P_e$

l_aux = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)

Here, since torch.mean divides by $E$, we are actually computing $\frac{1}{E}\sum_e F_e P_e$, so we need to multiply by $E^2$ to get the original result.

With top-1 routing and ideal balance, $F_e = P_e = 1/E$ and the aux loss attains its minimum of 1. More generally, each token selects $K$ experts, so $\sum_e F_e = K$; under ideal balance $F_e = K/E$, and the minimum is $K$ (2 in our top-2 setting).
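As a quick numeric sanity check (my own sketch, plugging the ideally balanced values into the same formula the code uses):

import torch

E, K = 3, 2
F_e = torch.full((E,), K / E)  # each expert is chosen for K/E of the tokens
P_e = torch.full((E,), 1 / E)  # mean router probability per expert
l_aux = E**2 * torch.mean(F_e * P_e)
print(l_aux)  # tensor(2.) == K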

The next two statements are

gates_s = torch.clamp(
    torch.matmul(
        expert_mask.float(),
        gates.unsqueeze(-1)
    ).sum(dim=1),
    min=torch.finfo(gates.dtype).eps
)
router_probs = gates / gates_s

Let's break down the first statement, considering a single token. The matmul result is

$$\begin{bmatrix} 1 & 0 & 0\\ 0 & 1 & 0 \end{bmatrix}\times \begin{bmatrix} 0.5761 \\ 0.2119 \\ 0.2119 \end{bmatrix}= \begin{bmatrix} 0.5761\\ 0.2119 \end{bmatrix}$$

Note that here we considered just one token, not the whole batch. The subsequent .sum(dim=1) adds up the selected gate values, giving the top-k sum for each token. The clamp avoids a division by a zero (or underflowed) value when training in low precision.

The matmul result is summed and used to normalize the probabilities, so the top-k gate values sum to 1. In our example, each token's top-2 sum is $0.5761 + 0.2119 = 0.7880$, so the normalized weights become $0.5761 / 0.7880 \approx 0.7311$ and $0.2119 / 0.7880 \approx 0.2689$.
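A quick check of the first token (using the values worked out above):

router_probs[0]
>>> Tensor([0.7311, 0.2689, 0.2689])

Only the top-2 entries will survive the later masking, and those two sum to 1.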

Next, we are going to apply token drop.

expert_index = torch.transpose(expert_index, 0, 1)
>>> Tensor(
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
)
expert_index = expert_index.reshape(-1)
>>> Tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

The first line transposes expert_index from (10, 2) to (2, 10): the first row holds the top-1 choices for the 10 tokens, and the second row holds the top-2 choices. In our case, the top-1 choice is always expert 0 and the top-2 choice is always expert 1. The code then flattens the tensor, so the first 10 elements are the top-1 choices and the next 10 elements are the top-2 choices.

The next line of code is

expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
>>> Tensor(
[[1, 0, 0],
 [1, 0, 0],
 ...
 [1, 0, 0],
 [0, 1, 0],
 [0, 1, 0],
 ...
 [0, 1, 0]]
) # 20, 3

This mask indicates which expert is chosen for each choice. The next line of code calculates how many times an expert is chosen.

exp_counts = torch.sum(expert_mask, dim=0).detach()
>>> Tensor([10, 10, 0])

The next line of code is

token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1

The core idea here is to compute a priority for each token within each expert, where top-1 choices always outrank top-2 choices. That is why we flattened the indices with expert_index = expert_index.reshape(-1): the top-1 choices are placed before the top-2 choices.

Also, the cumsum means that tokens at earlier positions receive smaller priority numbers (i.e., higher priority) than tokens at later positions. The result of the cumsum is

>>> Tensor(
[[1, 0, 0],
 [2, 0, 0],
 ...
 [10, 0, 0],
 [10, 1, 0],
 [10, 2, 0],
 ...
 [10, 10, 0]]
) # 20, 3

Then the result is multiplied by expert_mask and 1 is subtracted, giving

Tensor(
[[0, -1, -1],
 [1, -1, -1],
 ...
 [9, -1, -1],
 [-1, 0, -1],
 [-1, 1, -1],
 ...
 [-1, 9, -1]]
) # 20, 3

The next line of code is

token_priority = token_priority.reshape((topk, -1, num_experts))
>>> Tensor(
[[[0, -1, -1],
  [1, -1, -1],
  ...
  [9, -1, -1]],

 [[-1, 0, -1],
  [-1, 1, -1],
  ...
  [-1, 9, -1]]]
) # 2, 10, 3

and subsequently

token_priority = torch.transpose(token_priority, 0, 1)
>>> Tensor(
[[[0, -1, -1],
  [-1, 0, -1]],

 [[1, -1, -1],
  [-1, 1, -1]],
 ...
 [[9, -1, -1],
  [-1, 9, -1]]]
) # 10, 2, 3

This transpose puts the top-1 and top-2 choices back together. The resulting token_priority means: for each token and each choice, the priority of that token within the chosen expert. Again, since the top-k function won't choose the same expert twice, we can remove the redundant information

token_priority = torch.max(token_priority, dim=1)[0]
>>> Tensor(
[[0, 0, -1],
 [1, 1, -1],
 ...
 [9, 9, -1]]
) # 10, 3

This gives the priority of each token within each expert. For instance, token_priority[0, 0] = 0 means token 0 has priority 0 in expert 0; likewise, token_priority[0, 1] = 0 means token 0 has priority 0 in expert 1.

Next, we drop the tokens that exceed the capacity and also zero out the positions with value -1; we will see shortly why the zeroing is needed.

valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)

This line computes the logical AND of token_priority >= 0, which filters out the negative values, and token_priority < expert_capacity, which filters out the tokens that exceed the expert capacity.

valid_mask
>>> Tensor(
[[True, True, False],
 [True, True, False],
 ...
 [False, False, False]]
) # here, the last 4 rows are set to False because the capacity is 6

Then, this mask is applied to token_priority

token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
>>> Tensor(
[[0, 0, 0],
 [1, 1, 0],
 ...
 [5, 5, 0],
 [0, 0, 0],
 ...
 [0, 0, 0]]
) # 10, 3

Now, the most important part.

dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)

Here, we one-hot each priority, creating "slots" for each expert. Imagine that each expert has several slots (6 in our case) and each slot admits exactly one token; the one-hot result does exactly this:

dispatch_mask
>>> Tensor(
[[[True, False, False, False, False, False],
  [True, False, False, False, False, False],
  [True, False, False, False, False, False]],

 [[False, True, False, False, False, False],
  [False, True, False, False, False, False],
  [True, False, False, False, False, False]],

 ...

 [[False, False, False, False, False, True],
  [False, False, False, False, False, True],
  [True, False, False, False, False, False]],

 ...

 [[True, False, False, False, False, False],
  [True, False, False, False, False, False],
  [True, False, False, False, False, False]]]
)

Now we can also see why we zeroed out the negative values: one_hot does not accept negative indices. However, -1 marked experts that were not chosen, and after the fill those positions look like priority 0. We therefore use valid_mask to clear the positions that originally held -1.

valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity)
dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
dispatch_mask
>>> Tensor(
[[[True, False, False, False, False, False],
  [True, False, False, False, False, False],
  [False, False, False, False, False, False]],

 [[False, True, False, False, False, False],
  [False, True, False, False, False, False],
  [False, False, False, False, False, False]],

 ...

 [[False, False, False, False, False, False],
  [False, False, False, False, False, False],
  [False, False, False, False, False, False]]]
) # 10, 3, 6

Observing dispatch_mask, we have effectively created 6 slots for each expert: dispatch_mask[n, e, c] indicates whether token n is assigned to slot c of expert e. For example, token 0 is assigned to slot 0 of expert 0 and of expert 1, token 1 to slot 1 of both, and so on.

This effectively creates expert_capacity slots, each admitting only one token, so each expert can accept at most expert_capacity tokens.

This dispatch_mask will be used to mask the token input, as we will see later.
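As a quick sanity check (my own addition, not part of the original code), the number of kept tokens per expert can be read directly off dispatch_mask:

kept_per_expert = dispatch_mask.sum(dim=(0, 2))  # sum over tokens and slots
>>> Tensor([6, 6, 0])

Compare this with exp_counts = [10, 10, 0] computed before the drop: each used expert lost 4 tokens to the capacity limit.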

The next line calculates weights

combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)

This is just an elementwise product, with router_probs unsqueezed and broadcast along the slot dimension. The result is

combine_weights
>>> Tensor(
[[[0.7311, 0, 0, 0, 0, 0],
  [0.2689, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0]],

 [[0, 0.7311, 0, 0, 0, 0],
  [0, 0.2689, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0]],

 ...

 [[0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 0, 0]]]
) # 10, 3, 6

Now the tensor combine_weights contains the normalized router probability assigned to each expert, placed in each token's slot: because the last dimension is one-hot, the probability can be multiplied directly onto the mask.

The last two lines compute the capacity rate

exp_counts_capacity = torch.sum(dispatch_mask)
exp_capacity_rate = exp_counts_capacity / (logits.shape[0]*topk)

This effectively calculates the fraction of token assignments that survived the drop. When no token is dropped, torch.sum(dispatch_mask) equals the total number of expert assignments, i.e. the number of tokens times topk, so exp_capacity_rate is 1; values smaller than 1 indicate that some tokens were dropped. In our example, torch.sum(dispatch_mask) = 12 (6 kept tokens in each of the two used experts), so exp_capacity_rate = 12 / 20 = 0.6.

Next, let's see how we calculate the expert output. Assume our hidden_states is

hidden_states = torch.arange(20, dtype=torch.float32).reshape(10, 2)
>>> Tensor(
[[0, 1],
 [2, 3],
 [4, 5],
 [6, 7],
 [8, 9],
 [10, 11],
 [12, 13],
 [14, 15],
 [16, 17],
 [18, 19]]
)

The input for experts is calculated by

dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), hidden_states)

The einsum contracts over the token dimension s: for each (expert, slot) pair it sums the embeddings of the tokens assigned there, with the embedding dimension of hidden_states attached after dimension c. Since each slot holds at most one token, this is exactly the insertion operation: token embeddings are inserted into their slots

dispatched_input
>>> Tensor(
[[[0, 1],
  [2, 3],
  [4, 5],
  [6, 7],
  [8, 9],
  [10, 11]],
 [[0, 1],
  [2, 3],
  [4, 5],
  [6, 7],
  [8, 9],
  [10, 11]],
 [[0, 0],
  [0, 0],
  [0, 0],
  [0, 0],
  [0, 0],
  [0, 0]]]
) # 3, 6, 2

which means that for each expert and each slot, we have inserted the corresponding input token. Note that tokens [12, 13], [14, 15], [16, 17], and [18, 19] are dropped.
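If the einsum notation is hard to parse, here is an equivalent (but slow) loop form of the dispatch, a sketch of my own for illustration:

# equivalent loop form of the dispatch einsum (illustration only)
num_tokens, n_experts, capacity = dispatch_mask.shape
dispatched_loop = torch.zeros(n_experts, capacity, hidden_states.shape[1])
for s in range(num_tokens):
    for e in range(n_experts):
        for c in range(capacity):
            if dispatch_mask[s, e, c]:
                # each slot holds at most one token, so assignment == summation
                dispatched_loop[e, c] = hidden_states[s]
# should match the einsum result
assert torch.equal(dispatched_loop, dispatched_input)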

Next comes the expert inference

chunks = dispatched_input.chunk(self.num_experts, dim=0)
expert_outputs = []
for chunk, expert in zip(chunks, self.experts):
    expert_outputs.append(expert(chunk))  # each chunk: (1, 6, 2)

expert_output = torch.cat(expert_outputs, dim=0)  # 3, 6, 2

Each expert processes its own chunk; within each chunk, we provide 6 tokens (one per slot).

After that, the router weights are multiplied back.

# for convenience, we use the input as the expert output
expert_output = dispatched_input
combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)

The einsum multiplies and sums over dimensions e and c; the result is

combined_output
>>> Tensor(
[[0, 1],
 [2, 3],
 [4, 5],
 [6, 7],
 [8, 9],
 [10, 11],
 [0, 0],
 [0, 0],
 [0, 0],
 [0, 0]]
)

You will notice two things. (1) There are four zero vectors; as described before, these tokens were dropped, and the residual connection will bring them back. (2) Although we used dispatched_input as the expert output and multiplied by the router weights, the output exactly equals the input. This is because we normalized the top-2 gate values to sum to 1, so each token computes

$$t_i \cdot g_1 + t_i \cdot g_2 = (g_1 + g_2)\,t_i = t_i$$

where $t_i$ is the $i$'th token and $g_1$, $g_2$ are its two normalized gate values.
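Putting everything together, a condensed end-to-end sketch of the whole layer might look like the following. Note that the return signature of topkgating here (aux loss, combine weights, dispatch mask) and the expert shapes are my assumptions for illustration, not necessarily the exact HunYuan interface.

import torch
import torch.nn as nn

class MoELayer(nn.Module):
    """Condensed sketch; topkgating is the function dissected above."""
    def __init__(self, dim, num_experts, topk):
        super().__init__()
        self.router = nn.Linear(dim, num_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
            for _ in range(num_experts)
        )
        self.num_experts, self.topk = num_experts, topk

    def forward(self, hidden_states):  # (N, D)
        logits = self.router(hidden_states)
        # assumed return signature, for illustration only
        l_aux, combine_weights, dispatch_mask = topkgating(logits, self.topk)
        dispatched = torch.einsum(
            "sec,sm->ecm", dispatch_mask.type_as(hidden_states), hidden_states)
        chunks = dispatched.chunk(self.num_experts, dim=0)
        expert_output = torch.cat(
            [expert(chunk) for chunk, expert in zip(chunks, self.experts)], dim=0)
        combined = torch.einsum(
            "sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)
        return combined, l_aux  # the caller adds the residual: x + combined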

