diff --git a/transformer/SubLayers.py b/transformer/SubLayers.py index 0298a19..b023a21 100644 --- a/transformer/SubLayers.py +++ b/transformer/SubLayers.py @@ -50,7 +50,8 @@ def forward(self, q, k, v, mask=None): k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv - mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. + if mask is not None: + mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. output, attn = self.attention(q, k, v, mask=mask) output = output.view(n_head, sz_b, len_q, d_v)