-
Notifications
You must be signed in to change notification settings - Fork 278
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add ParametricAttention.v2 This layer is an extension of the existing `ParametricAttention` layer, adding support for transformations (such as a non-linear layer) of the key representation. This brings the model closer to the paper that suggested it (Yang et al, 2016) and gave slightly better results in experiments. * Use `noop` for when `key_transform` is `None` * Remove stray import * Add constant for key transform ref * Check that we correctly set the key transform * isooooooort * Update citation to ACL link Co-authored-by: Adriane Boyd <[email protected]> --------- Co-authored-by: Sofie Van Landeghem <[email protected]> Co-authored-by: Adriane Boyd <[email protected]>
- Loading branch information
1 parent
c16f552
commit 88dc49d
Showing
6 changed files
with
154 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from typing import Callable, Optional, Tuple, cast | ||
|
||
from ..config import registry | ||
from ..model import Model | ||
from ..types import Floats2d, Ragged | ||
from ..util import get_width | ||
from .noop import noop | ||
|
||
# Input and output are both ragged arrays: a 2d data array plus per-sequence lengths.
InT = Ragged
OutT = Ragged

# Name under which the key-transform sublayer is stored in the model's refs.
KEY_TRANSFORM_REF: str = "key_transform"
|
||
|
||
@registry.layers("ParametricAttention.v2")
def ParametricAttention_v2(
    *,
    key_transform: Optional[Model[Floats2d, Floats2d]] = None,
    nO: Optional[int] = None
) -> Model[InT, OutT]:
    """Weight inputs by similarity to a learned vector.

    key_transform (Optional[Model[Floats2d, Floats2d]]): transformation
        applied to the key representations before computing attention
        weights (e.g. a non-linear layer). Defaults to a no-op.
    nO (Optional[int]): output width; inferred from data in init when None.
    """
    # Note: the docstring must be the function's first statement to become
    # __doc__; it previously sat below this branch as a stray expression.
    if key_transform is None:
        # A no-op key transform reproduces ParametricAttention.v1 behavior.
        key_transform = noop()

    return Model(
        "para-attn",
        forward,
        init=init,
        params={"Q": None},
        dims={"nO": nO},
        refs={KEY_TRANSFORM_REF: key_transform},
        layers=[key_transform],
    )
|
||
|
||
def forward(model: Model[InT, OutT], Xr: InT, is_train: bool) -> Tuple[OutT, Callable]:
    """Apply attention weighting to a ragged batch and return the backprop callback."""
    Q = model.get_param("Q")
    transform = model.get_ref(KEY_TRANSFORM_REF)
    X, lengths = Xr.dataXd, Xr.lengths

    attn, bp_attn = _get_attention(model.ops, Q, transform, X, lengths, is_train)
    Y, bp_apply = _apply_attention(model.ops, attn, X, lengths)

    def backprop(dYr: OutT) -> InT:
        """Route gradients back through the weighting and the attention computation."""
        dX, d_attn = bp_apply(dYr.dataXd)
        dQ, dX_keys = bp_attn(d_attn)
        model.inc_grad("Q", dQ.ravel())
        # X receives gradient both through the weighted output and through the keys.
        dX += dX_keys
        return Ragged(dX, dYr.lengths)

    return Ragged(Y, Xr.lengths), backprop
|
||
|
||
def init(
    model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
    """Initialize the learned vector Q and the key-transform sublayer.

    X/Y are optional sample data; when X is given its width sets "nO".
    """
    key_transform = model.get_ref(KEY_TRANSFORM_REF)
    width = get_width(X) if X is not None else None
    if width:
        model.set_dim("nO", width)
        # NOTE(review): thinc's Model.has_dim is tri-state (True = set,
        # None = declared but unset, False = dim absent). As written this
        # only propagates the width when "nO" is *already* set on the
        # sublayer — confirm whether `is None` (declared-but-unset) was
        # the intended condition here.
        if key_transform.has_dim("nO"):
            key_transform.set_dim("nO", width)

    # Randomly initialize the parameter, as though it were an embedding.
    Q = model.ops.alloc1f(model.get_dim("nO"))
    Q += model.ops.xp.random.uniform(-0.1, 0.1, Q.shape)
    model.set_param("Q", Q)

    # Pass raw arrays (not Ragged) down to the sublayer's initializer.
    X_array = X.dataXd if X is not None else None
    Y_array = Y.dataXd if Y is not None else None

    key_transform.initialize(X_array, Y_array)
|
||
|
||
def _get_attention(ops, Q, key_transform, X, lengths, is_train):
    """Compute per-row attention weights: softmax over Q·transform(X) per sequence."""
    K, bp_keys = key_transform(X, is_train=is_train)

    scores = ops.gemm(K, ops.reshape2f(Q, -1, 1))
    attention = ops.softmax_sequences(scores, lengths)

    def get_attention_bwd(d_attention):
        """Return (dQ, dX) given the gradient of the attention weights."""
        d_scores = ops.backprop_softmax_sequences(d_attention, attention, lengths)
        dQ = ops.gemm(K, d_scores, trans1=True)
        # Gradient w.r.t. the transformed keys, then back through the transform.
        dK = ops.xp.outer(d_scores, Q)
        dX = bp_keys(dK)
        return dQ, dX

    return attention, get_attention_bwd
|
||
|
||
def _apply_attention(ops, attention, X, lengths): | ||
output = X * attention | ||
|
||
def apply_attention_bwd(d_output): | ||
d_attention = (X * d_output).sum(axis=1, keepdims=True) | ||
dX = d_output * attention | ||
return dX, d_attention | ||
|
||
return output, apply_attention_bwd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from thinc.layers.gelu import Gelu | ||
from thinc.layers.parametricattention_v2 import ( | ||
KEY_TRANSFORM_REF, | ||
ParametricAttention_v2, | ||
) | ||
|
||
|
||
def test_key_transform_used():
    """A supplied key transform must be registered under KEY_TRANSFORM_REF."""
    model = ParametricAttention_v2(key_transform=Gelu())
    assert model.get_ref(KEY_TRANSFORM_REF).name == "gelu"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters