
Commit

only training noisy
llleohk committed May 2, 2024
1 parent d24a9f4 commit dafe8c0
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion wenet/transformer/positionwise_feed_forward.py
@@ -66,6 +66,8 @@ class MoEFFNLayer(torch.nn.Module):
Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
Noisy-gate reference from https://arxiv.org/pdf/1701.06538.pdf
Args:
n_expert: number of expert.
n_expert_activated: The actual number of experts used for each frame
@@ -112,7 +114,7 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
         router = self.gate(xs)  # (B*L, n_expert)
         if self.gate_type == 'noisy':
             noisy_router = self.noisy_gate(xs)
-            noisy_router = torch.randn_like(router) * F.softplus(noisy_router)
+            noisy_router = torch.randn_like(router) * F.softplus(noisy_router) * self.training
             router = router + noisy_router
         logits, selected_experts = torch.topk(
             router, self.n_expert_activated
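The effect of this change can be illustrated with a small, dependency-free sketch. It is not the wenet implementation: it handles a single frame with plain Python lists instead of tensors, and the helper names `softplus`, `noisy_router`, and `top_k` are hypothetical stand-ins for `F.softplus`, the noisy branch of `forward`, and `torch.topk`. The point it tries to show is that scaling the noise by the training flag leaves the router logits untouched at inference time, so expert selection becomes deterministic.

```python
import math
import random


def softplus(x: float) -> float:
    """Numerically stable softplus: log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def noisy_router(clean_logits, noise_logits, training, rng=random):
    """Add Gaussian noise scaled by softplus(noise_logits), only when training.

    `clean_logits` and `noise_logits` are assumed stand-ins for the outputs of
    `self.gate(xs)` and `self.noisy_gate(xs)` for one frame.
    """
    # Stands in for multiplying by the boolean `self.training` in the diff.
    scale = 1.0 if training else 0.0
    return [
        c + rng.gauss(0.0, 1.0) * softplus(n) * scale
        for c, n in zip(clean_logits, noise_logits)
    ]


def top_k(values, k):
    """Return (values, indices) of the k largest entries, largest first."""
    idx = sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]
    return [values[i] for i in idx], idx


clean = [0.2, 1.5, -0.3, 0.9]   # illustrative gate outputs for 4 experts
noise = [0.0, 0.5, -1.0, 2.0]   # illustrative noisy-gate outputs

# Inference: the noise term is multiplied by zero, so the logits are unchanged.
eval_router = noisy_router(clean, noise, training=False)
eval_logits, eval_experts = top_k(eval_router, k=2)

# Training: the logits are perturbed, which may change the selected experts.
train_router = noisy_router(clean, noise, training=True, rng=random.Random(0))
train_logits, train_experts = top_k(train_router, k=2)
```

In evaluation mode `eval_router` equals `clean`, so the two experts with the largest clean logits are always chosen. In training mode the selection can vary from call to call, which is the exploration behaviour described in the noisy-gate reference.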

