Possible mistakes in d_k, d_v of MultiheadAttention #211

Open
SARIHUST opened this issue Jul 4, 2023 · 0 comments
SARIHUST commented Jul 4, 2023

In the __init__ function of the MultiheadAttention class, you use d_k and d_v to denote the dimensions of the keys and values, and you define the projections below:

self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)

However, when d_v is not the same as d_k (both should be d_model // n_head), the shape of the queries changes after the attention operation, which causes problems in multi-layer structures.
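
A minimal shape sketch of this point (my own illustration with hypothetical sizes, not the repository's exact code): after the attention-weighted sum over the values, the per-head feature size is d_v rather than d_k.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes where d_v differs from d_k and from d_model // n_head
d_model, n_head, d_k, d_v = 512, 8, 64, 32
batch, seq_len = 2, 10

w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
w_vs = nn.Linear(d_model, n_head * d_v, bias=False)

x = torch.randn(batch, seq_len, d_model)
q = w_qs(x).view(batch, seq_len, n_head, d_k).transpose(1, 2)  # (2, 8, 10, 64)
k = w_ks(x).view(batch, seq_len, n_head, d_k).transpose(1, 2)  # (2, 8, 10, 64)
v = w_vs(x).view(batch, seq_len, n_head, d_v).transpose(1, 2)  # (2, 8, 10, 32)

attn = F.softmax(q @ k.transpose(-2, -1) / d_k ** 0.5, dim=-1)  # (2, 8, 10, 10)
out = attn @ v  # (2, 8, 10, 32): the per-head size is now d_v, not d_k
print(q.shape, out.shape)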
After going through the official PyTorch MultiheadAttention implementation, I believe you intended a similar parameterization:

self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))

However, the official PyTorch implementation uses raw weight parameters rather than an nn.Linear module, and those weights are used to transform the dimension of the keys from self.kdim to embed_dim, which is the opposite of what your implementation does. So I believe there might be some errors in your code. But overall, thank you for your work; it helped me a lot.
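
For illustration, a minimal sketch of the direction those weights project in (my own example with hypothetical kdim/vdim values, not PyTorch's actual forward code): F.linear with a weight of shape (embed_dim, kdim) maps the key input from self.kdim up to embed_dim, and likewise for the values.

import torch
import torch.nn.functional as F

# Hypothetical sizes
embed_dim, kdim, vdim = 512, 128, 256
batch, seq_len = 2, 10

k_proj_weight = torch.empty(embed_dim, kdim)  # same shape convention as PyTorch's k_proj_weight
v_proj_weight = torch.empty(embed_dim, vdim)  # same shape convention as PyTorch's v_proj_weight
torch.nn.init.xavier_uniform_(k_proj_weight)
torch.nn.init.xavier_uniform_(v_proj_weight)

key = torch.randn(batch, seq_len, kdim)
value = torch.randn(batch, seq_len, vdim)

# F.linear(x, W) computes x @ W.T, so these weights project kdim/vdim -> embed_dim
k = F.linear(key, k_proj_weight)    # (2, 10, 512)
v = F.linear(value, v_proj_weight)  # (2, 10, 512)
print(k.shape, v.shape)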
