Skip to content

Commit

Permalink
fix length
Browse files Browse the repository at this point in the history
  • Loading branch information
llleohk committed May 2, 2024
1 parent dafe8c0 commit f5eeb5c
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion wenet/transformer/positionwise_feed_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
router = self.gate(xs) # (B*L, n_expert)
if self.gate_type == 'noisy':
noisy_router = self.noisy_gate(xs)
noisy_router = torch.randn_like(router) * F.softplus(noisy_router) * self.training
noisy_router = (
torch.randn_like(router) * F.softplus(noisy_router)
) * self.training
router = router + noisy_router
logits, selected_experts = torch.topk(
router, self.n_expert_activated
Expand Down

0 comments on commit f5eeb5c

Please sign in to comment.