Commit 698c747 (1 parent: 469af9a). Showing 3 changed files with 147 additions and 0 deletions: a new glu (Gated Linear Unit) activation for mygrad.nnet.activations, together with property-based tests.
@@ -0,0 +1,72 @@
from numpy import ndarray

from mygrad import multiply, Tensor
from .sigmoid import sigmoid

def glu(x, axis=-1, constant=False):
    """ Returns the Gated Linear Unit A * σ(B), where A and B are split from `x`.

    Parameters
    ----------
    x : mygrad.Tensor
        The input.

    axis : int, optional (default=-1)
        The axis along which to split the input in half and apply the GLU.

    constant : boolean, optional (default=False)
        If ``True``, the returned tensor is a constant (it
        does not back-propagate a gradient).

    Returns
    -------
    mygrad.Tensor
        The result of applying the Gated Linear Unit elementwise to the input.

    Extended Description
    --------------------
    The Gated Linear Unit was proposed in the paper
        "Language Modeling with Gated Convolutional Networks"
        Yann Dauphin, Angela Fan, Michael Auli, David Grangier
    available at https://arxiv.org/abs/1612.08083

    The GLU operation splits the input `x` in half along `axis`, storing the
    first half in A and the second in B. The return value is then A ⊙ σ(B),
    where ⊙ is elementwise multiplication and σ is the sigmoid function.

    Examples
    --------
    >>> import mygrad as mg
    >>> from mygrad.nnet.activations import glu
    >>> x = mg.arange(-5, 5)
    >>> x
    Tensor([-5, -4, -3, -2, -1,  0,  1,  2,  3,  4])
    >>> y = glu(x); y
    Tensor([-2.5       , -2.92423431, -2.64239123, -1.90514825, -0.98201379])
    >>> y.backward()
    >>> x.grad
    array([ 0,  0,  0,  0,  0, -1,  0,  0,  0,  0])
    """
    # `axis` may arrive as a 0-d ndarray/Tensor; unwrap it to a Python scalar
    if isinstance(axis, (ndarray, Tensor)):
        axis = axis.item()

    if not isinstance(axis, int):
        raise TypeError(
            f"`axis` must be an integer-valued scalar, got {axis} (type {type(axis)})"
        )

    # split `x` into two equal halves, A and B, along `axis`
    first_idx = list(slice(None) for _ in x.shape)
    second_idx = list(slice(None) for _ in x.shape)
    first_idx[axis] = slice(0, x.shape[axis] // 2)
    second_idx[axis] = slice(x.shape[axis] // 2, None)

    first_half = x[tuple(first_idx)]
    second_half = x[tuple(second_idx)]

    # an odd-length `axis` cannot be split into two equal halves
    if first_half.shape != second_half.shape:
        raise ValueError(
            f"The shapes after splitting must be the same but got {first_half.shape} "
            f"and {second_half.shape}"
        )

    # GLU: A ⊙ σ(B)
    return multiply(first_half, sigmoid(second_half), constant=constant)
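
For reference, a minimal standalone NumPy sketch of the computation this file implements (illustrative only, not part of the commit); it reproduces the values shown in the docstring example above:

import numpy as np

x = np.arange(-5.0, 5.0)              # 10 elements -> two halves of length 5
A, B = np.split(x, 2, axis=-1)        # A = x[:5], B = x[5:]
out = A * (1.0 / (1.0 + np.exp(-B)))  # A * sigmoid(B)
print(out)  # ≈ [-2.5, -2.92423431, -2.64239123, -1.90514825, -0.98201379]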
@@ -0,0 +1,73 @@
import sys

from hypothesis import assume, given
import hypothesis.strategies as st
import hypothesis.extra.numpy as hnp
import numpy as np
import pytest

import mygrad as mg
from mygrad.nnet.activations import glu
from tests.wrappers.uber import backprop_test_factory, fwdprop_test_factory


@pytest.mark.parametrize("axis", (None, 1j))
def test_input_validation(axis):
    # non-integer `axis` values should raise a TypeError
    with pytest.raises(TypeError):
        glu(2, axis=axis)


@given(arr=hnp.arrays(dtype=np.float32, shape=hnp.array_shapes()))
def test_bad_shape_dimension(arr):
    # an odd-length axis cannot be split into equal halves: expect ValueError
    assume(any(x % 2 for x in arr.shape))
    idx = np.random.choice([i for i, axis in enumerate(arr.shape) if axis % 2]).item()
    with pytest.raises(ValueError):
        glu(arr, idx)

def _np_glu(x, axis):
    """Pure-NumPy reference implementation of the GLU: A * sigmoid(B)."""
    if isinstance(axis, (np.ndarray, mg.Tensor)):
        axis = axis.item()

    first_idx = list(slice(None) for _ in x.shape)
    second_idx = list(slice(None) for _ in x.shape)
    first_idx[axis] = slice(0, x.shape[axis] // 2)
    second_idx[axis] = slice(x.shape[axis] // 2, None)

    first_half = x[tuple(first_idx)]
    second_half = x[tuple(second_idx)]

    return first_half * (1 / (1 + np.exp(-second_half)))

@st.composite
def _axis_strategy(draw, arr):
    # draw an even-length axis of `arr`, delivered as an int, ndarray, or Tensor
    assume(any(not x % 2 for x in arr.shape))
    val = draw(st.sampled_from([i for i, axis in enumerate(arr.shape) if not axis % 2]))
    dtype = draw(st.sampled_from((np.array, mg.Tensor, int)))
    return dtype(val)

@fwdprop_test_factory(
    mygrad_func=glu,
    true_func=_np_glu,
    num_arrays=1,
    # bound the inputs so that exp(-x) inside sigmoid cannot overflow
    index_to_bnds={0: (-np.log(sys.float_info.max), np.log(sys.float_info.max))},
    kwargs={"axis": lambda x: _axis_strategy(x)},
    assumptions=lambda arr, axis: any(not x % 2 for x in arr.shape),
)
def test_glu_fwd():
    pass

@backprop_test_factory(
    mygrad_func=glu,
    true_func=_np_glu,
    num_arrays=1,
    index_to_bnds={0: (-np.log(sys.float_info.max), np.log(sys.float_info.max))},
    kwargs={"axis": lambda x: _axis_strategy(x)},
    assumptions=lambda arr, axis: any(not x % 2 for x in arr.shape),
    vary_each_element=True,
)
def test_glu_bkwd():
    pass
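
To make the factory-driven tests above concrete, here is a rough standalone equivalent of what the generated forward test checks (an illustrative sketch, not part of the commit; it assumes the `glu` and `_np_glu` defined above):

import numpy as np
import mygrad as mg
import hypothesis.strategies as st
import hypothesis.extra.numpy as hnp
from hypothesis import assume, given
from mygrad.nnet.activations import glu

@given(arr=hnp.arrays(dtype=np.float64,
                      shape=hnp.array_shapes(),
                      elements=st.floats(-10, 10)))
def test_glu_matches_numpy_reference(arr):
    assume(arr.shape[-1] % 2 == 0)  # the split axis must have even length
    desired = _np_glu(arr, axis=-1)          # NumPy reference from above
    actual = glu(mg.Tensor(arr), axis=-1)    # mygrad implementation
    assert np.allclose(actual.data, desired)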