How to visualise the attention weights of the inputs #106
The general way to do this is to log the intermediates of the network computation, and then recompute the attention mask. Here's a brief sketch of how you would do this:

```python
import jax
import flax.linen as nn


@jax.jit
def get_attention_mask(model, observation, task):
    _, intermediates = model.module.apply(
        {'params': model.params},
        observation,
        task,
        observation['timestep_pad_mask'],
        train=False,
        method="octo_transformer",
        mutable=['intermediates'],
        capture_intermediates=True,
    )
    # `intermediates` holds literally the output of every submodule run in the
    # network. As an example, pull out the last Transformer MHA layer.
    outs = intermediates['intermediates']['octo_transformer']['BlockTransformer_0'][
        'Transformer_0']['encoderblock_11']['MultiHeadDotProductAttention_0']
    # Flax captures intermediates as tuples, hence the trailing [0].
    key = outs['key']['__call__'][0]
    query = outs['query']['__call__'][0]
    attention_weights = nn.dot_product_attention_weights(query, key)
    # Get the attention weights corresponding to the readout token (last token).
    return attention_weights[..., -1, :]  # shape: (batch_size, n_heads, n_tokens)
```
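For reference, here is a minimal usage sketch. It assumes you load a pretrained checkpoint with `OctoModel.load_pretrained` and build inputs the way the Octo README does; the checkpoint name, image key, and shapes below are illustrative and should be matched to your own model config.

```python
# Hypothetical usage example; checkpoint name and input shapes are assumptions.
import numpy as np
from octo.model.octo_model import OctoModel

model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small")

# Dummy single-timestep observation; real code would use actual camera frames.
observation = {
    'image_primary': np.zeros((1, 1, 256, 256, 3), dtype=np.uint8),
    'timestep_pad_mask': np.ones((1, 1), dtype=bool),
}
task = model.create_tasks(texts=["pick up the spoon"])

weights = get_attention_mask(model, observation, task)
print(weights.shape)  # (batch_size, n_heads, n_tokens)
```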
Got it! Thanks for your kind response! I will try this right now.
I am trying to do the same thing. @oym1994, did you manage to solve this? One thing I am struggling with is figuring out which tokens correspond to images, language, etc. If you solved it, could you share your code?
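One way to attribute tokens to modalities, as a hedged sketch: by default the Octo transformer lays out its input sequence as task (language) tokens, then per-timestep observation (image) tokens, then readout tokens, so the attention vector returned above can be sliced by token counts. The counts below (16 language tokens from the T5 tokenizer, 256 image tokens from a 256x256 image split into 16x16 patches) are assumptions that match the default configs; verify them against your own `model.config` before trusting the slices.

```python
import numpy as np

# Assumed token layout for a single-timestep forward pass; check these counts
# against your model config before relying on them.
N_LANG = 16   # task language tokens (default T5 tokenizer max length)
N_IMG = 256   # primary-camera image tokens (256x256 image, 16x16 patches)


def split_by_modality(readout_attention):
    """Slice a (batch, n_heads, n_tokens) attention vector into modality chunks."""
    lang = readout_attention[..., :N_LANG]
    img = readout_attention[..., N_LANG:N_LANG + N_IMG]
    rest = readout_attention[..., N_LANG + N_IMG:]  # readout / remaining tokens
    return lang, img, rest


def image_attention_map(img_attention, patches_per_side=16):
    """Average image-token attention over heads and reshape it into a 2D grid
    that can be upsampled and overlaid on the input image as a heatmap."""
    per_patch = np.asarray(img_attention).mean(axis=1)  # (batch, N_IMG)
    return per_patch.reshape(-1, patches_per_side, patches_per_side)
```

The resulting (16, 16) grid can then be resized to the image resolution (e.g. with `jax.image.resize`) and overlaid on the frame, while the language slice gives one weight per prompt token.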
Hello,
Thanks for your great work. We would like a more detailed explanation of the outputs: how can we visualise the attention weights over the inputs (including image and language)?
Thanks for your attention; we look forward to your kind response!