
Commit

Make embed_dim required and fix size mismatch bug when key_dim and val_dim are different
wouterkool committed Nov 20, 2020
1 parent b8c9b8a commit 5fa0b17
Showing 1 changed file with 2 additions and 4 deletions.
nets/graph_encoder.py (6 changes: 2 additions & 4 deletions)
@@ -19,14 +19,13 @@ def __init__(
             self,
             n_heads,
             input_dim,
-            embed_dim=None,
+            embed_dim,
             val_dim=None,
             key_dim=None
     ):
         super(MultiHeadAttention, self).__init__()

         if val_dim is None:
-            assert embed_dim is not None, "Provide either embed_dim or val_dim"
             val_dim = embed_dim // n_heads
         if key_dim is None:
             key_dim = val_dim
@@ -43,8 +42,7 @@ def __init__(
         self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
         self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim))

-        if embed_dim is not None:
-            self.W_out = nn.Parameter(torch.Tensor(n_heads, key_dim, embed_dim))
+        self.W_out = nn.Parameter(torch.Tensor(n_heads, val_dim, embed_dim))

         self.init_parameters()

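For context, the sketch below (not the repository's code; the toy dimensions, tensor names, and plain torch.randn weights are illustrative assumptions) shows why the output projection must be sized by val_dim rather than key_dim: each head returns a weighted sum of value vectors, so the concatenated head output has n_heads * val_dim features, and a W_out shaped (n_heads, key_dim, embed_dim) only fits when key_dim happens to equal val_dim.

# Minimal sketch of multi-head attention with key_dim != val_dim,
# the exact case the old W_out shape broke.
import torch

n_heads, batch_size, graph_size = 2, 3, 5
input_dim, embed_dim = 16, 16
key_dim, val_dim = 4, 8  # deliberately different

h = torch.randn(batch_size, graph_size, input_dim)
hflat = h.contiguous().view(-1, input_dim)  # (batch * graph, input_dim)

W_query = torch.randn(n_heads, input_dim, key_dim)
W_key = torch.randn(n_heads, input_dim, key_dim)
W_val = torch.randn(n_heads, input_dim, val_dim)

shp = (n_heads, batch_size, graph_size, -1)
Q = torch.matmul(hflat, W_query).view(shp)  # (n_heads, batch, graph, key_dim)
K = torch.matmul(hflat, W_key).view(shp)    # (n_heads, batch, graph, key_dim)
V = torch.matmul(hflat, W_val).view(shp)    # (n_heads, batch, graph, val_dim)

compat = torch.matmul(Q, K.transpose(2, 3)) / key_dim ** 0.5
heads = torch.matmul(torch.softmax(compat, dim=-1), V)  # last dim is val_dim

# Heads are concatenated to n_heads * val_dim features, so W_out must map
# val_dim (per head) to embed_dim; a (n_heads, key_dim, embed_dim) tensor
# would make the mm below fail with a size mismatch (8 vs 16 here).
W_out = torch.randn(n_heads, val_dim, embed_dim)
out = torch.mm(
    heads.permute(1, 2, 0, 3).contiguous().view(-1, n_heads * val_dim),
    W_out.view(-1, embed_dim)
).view(batch_size, graph_size, embed_dim)
print(out.shape)  # torch.Size([3, 5, 16])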
