Add ParametricAttention.v2 (#913)
* Add ParametricAttention.v2

This layer is an extension of the existing `ParametricAttention` layer,
adding support for transformations (such as a non-linear layer) of the
key representation. This brings the model closer to the paper that
introduced it (Yang et al., 2016), and the transformation gave slightly
better results in experiments.

* Use `noop` for when `key_transform` is `None`

* Remove stray import

* Add constant for key transform ref

* Check that we correctly set the key transform

* isooooooort

* Update citation to ACL link

Co-authored-by: Adriane Boyd <[email protected]>

---------

Co-authored-by: Sofie Van Landeghem <[email protected]>
Co-authored-by: Adriane Boyd <[email protected]>
3 people authored Dec 14, 2023
1 parent c16f552 commit 88dc49d
Showing 6 changed files with 154 additions and 1 deletion.
3 changes: 2 additions & 1 deletion thinc/api.py
@@ -41,6 +41,7 @@
MultiSoftmax,
MXNetWrapper,
ParametricAttention,
ParametricAttention_v2,
PyTorchLSTM,
PyTorchRNNWrapper,
PyTorchWrapper,
@@ -207,7 +208,7 @@
"PyTorchWrapper", "PyTorchRNNWrapper", "PyTorchLSTM",
"TensorFlowWrapper", "keras_subclass", "MXNetWrapper",
"PyTorchWrapper_v2", "Softmax_v2", "PyTorchWrapper_v3",
"SparseLinear_v2", "TorchScriptWrapper_v1",
"SparseLinear_v2", "TorchScriptWrapper_v1", "ParametricAttention_v2",

"add", "bidirectional", "chain", "clone", "concatenate", "noop",
"residual", "uniqued", "siamese", "list2ragged", "ragged2list",
2 changes: 2 additions & 0 deletions thinc/layers/__init__.py
@@ -35,6 +35,7 @@
from .noop import noop
from .padded2list import padded2list
from .parametricattention import ParametricAttention
from .parametricattention_v2 import ParametricAttention_v2
from .premap_ids import premap_ids
from .pytorchwrapper import (
PyTorchRNNWrapper,
@@ -94,6 +95,7 @@
"Mish",
"MultiSoftmax",
"ParametricAttention",
"ParametricAttention_v2",
"PyTorchLSTM",
"PyTorchWrapper",
"PyTorchWrapper_v2",
100 changes: 100 additions & 0 deletions thinc/layers/parametricattention_v2.py
@@ -0,0 +1,100 @@
from typing import Callable, Optional, Tuple, cast

from ..config import registry
from ..model import Model
from ..types import Floats2d, Ragged
from ..util import get_width
from .noop import noop

InT = Ragged
OutT = Ragged

KEY_TRANSFORM_REF: str = "key_transform"


@registry.layers("ParametricAttention.v2")
def ParametricAttention_v2(
*,
key_transform: Optional[Model[Floats2d, Floats2d]] = None,
nO: Optional[int] = None
) -> Model[InT, OutT]:
if key_transform is None:
key_transform = noop()

"""Weight inputs by similarity to a learned vector"""
return Model(
"para-attn",
forward,
init=init,
params={"Q": None},
dims={"nO": nO},
refs={KEY_TRANSFORM_REF: key_transform},
layers=[key_transform],
)


def forward(model: Model[InT, OutT], Xr: InT, is_train: bool) -> Tuple[OutT, Callable]:
    Q = model.get_param("Q")
    key_transform = model.get_ref(KEY_TRANSFORM_REF)

    attention, bp_attention = _get_attention(
        model.ops, Q, key_transform, Xr.dataXd, Xr.lengths, is_train
    )
    output, bp_output = _apply_attention(model.ops, attention, Xr.dataXd, Xr.lengths)

    def backprop(dYr: OutT) -> InT:
        dX, d_attention = bp_output(dYr.dataXd)
        dQ, dX2 = bp_attention(d_attention)
        model.inc_grad("Q", dQ.ravel())
        dX += dX2
        return Ragged(dX, dYr.lengths)

    return Ragged(output, Xr.lengths), backprop


def init(
    model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
    key_transform = model.get_ref(KEY_TRANSFORM_REF)
    width = get_width(X) if X is not None else None
    if width:
        model.set_dim("nO", width)
        if key_transform.has_dim("nO"):
            key_transform.set_dim("nO", width)

    # Randomly initialize the parameter, as though it were an embedding.
    Q = model.ops.alloc1f(model.get_dim("nO"))
    Q += model.ops.xp.random.uniform(-0.1, 0.1, Q.shape)
    model.set_param("Q", Q)

    X_array = X.dataXd if X is not None else None
    Y_array = Y.dataXd if Y is not None else None

    key_transform.initialize(X_array, Y_array)


def _get_attention(ops, Q, key_transform, X, lengths, is_train):
    K, K_bp = key_transform(X, is_train=is_train)

    attention = ops.gemm(K, ops.reshape2f(Q, -1, 1))
    attention = ops.softmax_sequences(attention, lengths)

    def get_attention_bwd(d_attention):
        d_attention = ops.backprop_softmax_sequences(d_attention, attention, lengths)
        dQ = ops.gemm(K, d_attention, trans1=True)
        dY = ops.xp.outer(d_attention, Q)
        dX = K_bp(dY)
        return dQ, dX

    return attention, get_attention_bwd


def _apply_attention(ops, attention, X, lengths):
    output = X * attention

    def apply_attention_bwd(d_output):
        d_attention = (X * d_output).sum(axis=1, keepdims=True)
        dX = d_output * attention
        return dX, d_attention

    return output, apply_attention_bwd
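
For readers skimming the diff: the forward pass reduces to a per-sequence softmax over dot products between the (optionally transformed) keys and the learned vector `Q`, followed by re-weighting the input rows. A minimal NumPy sketch of that computation for a single sequence, assuming an identity `key_transform` (the variable names below are illustrative, not part of the thinc API):

```python
import numpy

X = numpy.random.uniform(-1, 1, (3, 4)).astype("float32")   # one sequence: 3 tokens, width 4
Q = numpy.random.uniform(-0.1, 0.1, (4,)).astype("float32")  # the learned parameter vector

scores = X @ Q                         # one attention score per token (cf. ops.gemm(K, Q))
weights = numpy.exp(scores - scores.max())
weights /= weights.sum()               # softmax over the sequence (cf. ops.softmax_sequences)
output = X * weights[:, None]          # attention-weighted token vectors (cf. _apply_attention)
```
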
2 changes: 2 additions & 0 deletions thinc/tests/layers/test_layers_api.py
@@ -129,6 +129,8 @@ def assert_data_match(Y, out_data):
("MultiSoftmax.v1", {"nOs": (1, 3)}, array2d, array2d),
# ("CauchySimilarity.v1", {}, (array2d, array2d), array1d),
("ParametricAttention.v1", {}, ragged, ragged),
("ParametricAttention.v2", {}, ragged, ragged),
("ParametricAttention.v2", {"key_transform": {"@layers": "Gelu.v1"}}, ragged, ragged),
("SparseLinear.v1", {}, (numpy.asarray([1, 2, 3], dtype="uint64"), array1d, numpy.asarray([1, 1], dtype="i")), array2d),
("SparseLinear.v2", {}, (numpy.asarray([1, 2, 3], dtype="uint64"), array1d, numpy.asarray([1, 1], dtype="i")), array2d),
("remap_ids.v1", {"dtype": "f"}, ["a", 1, 5.0], array2dint),
10 changes: 10 additions & 0 deletions thinc/tests/layers/test_parametric_attention_v2.py
@@ -0,0 +1,10 @@
from thinc.layers.gelu import Gelu
from thinc.layers.parametricattention_v2 import (
    KEY_TRANSFORM_REF,
    ParametricAttention_v2,
)


def test_key_transform_used():
    attn = ParametricAttention_v2(key_transform=Gelu())
    assert attn.get_ref(KEY_TRANSFORM_REF).name == "gelu"
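
Beyond the direct-construction test above, the `key_transform` can also be supplied through thinc's config system, which is what the `{"@layers": "Gelu.v1"}` entry added to `test_layers_api.py` exercises. A sketch of that pattern (the config string is illustrative, not taken from the diff):

```python
from thinc.api import Config, registry

CONFIG = """
[model]
@layers = "ParametricAttention.v2"

[model.key_transform]
@layers = "Gelu.v1"
"""

# Resolve the config into a real Model; the nested section becomes the key transform.
model = registry.resolve(Config().from_str(CONFIG))["model"]
assert model.get_ref("key_transform").name == "gelu"
```
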
38 changes: 38 additions & 0 deletions website/docs/api-layers.md
@@ -686,6 +686,44 @@ attention mechanism.
https://github.com/explosion/thinc/blob/master/thinc/layers/parametricattention.py
```

### ParametricAttention_v2 {#parametricattention_v2 tag="function"}

<inline-list>

- **Input:** <ndarray>Ragged</ndarray>
- **Output:** <ndarray>Ragged</ndarray>
- **Parameters:** <ndarray shape="nO,">Q</ndarray>

</inline-list>

A layer that uses the parametric attention scheme described by
[Yang et al. (2016)](https://aclanthology.org/N16-1174). The layer learns a
parameter vector that is used as the query in a single-headed attention
mechanism.

<infobox variant="warning">

The original `ParametricAttention` layer uses the hidden representation as-is
for the keys in the attention mechanism. This differs from the paper that
introduces parametric attention (Equation 5), where the keys are first passed
through a learned transformation. `ParametricAttention_v2` adds the option to
transform the key representation in line with the paper by passing such a
transformation as the `key_transform` argument.

</infobox>


| Argument        | Type                                         | Description                                                                                  |
| --------------- | -------------------------------------------- | -------------------------------------------------------------------------------------------- |
| `key_transform` | <tt>Optional[Model[Floats2d, Floats2d]]</tt> | Transformation to apply to the key representations. Defaults to `None` (no transformation). |
| `nO`            | <tt>Optional[int]</tt>                       | The size of the output vectors.                                                              |
| **RETURNS**     | <tt>Model[Ragged, Ragged]</tt>               | The created attention layer.                                                                 |

```python
https://github.com/explosion/thinc/blob/master/thinc/layers/parametricattention_v2.py
```
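
As a quick usage illustration (not part of this diff), the layer slots into a pooling network much like the original. A minimal sketch, assuming 4-dimensional inputs and an explicitly sized `Gelu` key transform:

```python
import numpy
from thinc.api import Gelu, ParametricAttention_v2, chain, reduce_mean
from thinc.types import Ragged

# Two sequences (lengths 3 and 2) of 4-dimensional token vectors.
X = Ragged(
    numpy.random.uniform(-1, 1, (5, 4)).astype("float32"),
    numpy.asarray([3, 2], dtype="int32"),
)

# Attention-weight the tokens, then pool each sequence to a single vector.
model = chain(ParametricAttention_v2(key_transform=Gelu(nO=4)), reduce_mean())
model.initialize(X=X)
Y = model.predict(X)  # one pooled vector per sequence, shape (2, 4)
```
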



### Relu {#relu tag="function"}

<inline-list>
