Skip to content

Commit

Permalink
Add Cumprod to tf2jax ops.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 554941388
  • Loading branch information
TF2JAXDev authored and TF2JAXDev committed Aug 8, 2023
1 parent c4063d1 commit 7adfe09
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 4 deletions.
4 changes: 4 additions & 0 deletions tf2jax/_src/numpy_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ def cumsum(arr, axis: int):
return _get_np(arr).cumsum(arr, axis=axis)


def cumprod(arr, axis: int):
return _get_np(arr).cumprod(arr, axis=axis)


def max_(arr, axis: Union[int, Sequence[int]], keepdims: bool):
axis = tuple(axis) if isinstance(axis, (list, tuple)) else (axis,)
return _get_np(arr).max(arr, axis=axis, keepdims=keepdims)
Expand Down
25 changes: 25 additions & 0 deletions tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,31 @@ def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray:
return _func


@register_operation("Cumprod")
def _cumprod(proto):
"""Parse a Cumprod Op."""
_check_attrs(proto, {"T", "Tidx", "exclusive", "reverse"})

exclusive = proto.attr["exclusive"].b
reverse = proto.attr["reverse"].b

def _func(x: jnp.ndarray, axis: jnp.ndarray) -> jnp.ndarray:
axis = axis.item()
if reverse:
x = anp.flip(x, axis=axis)
if exclusive:
pad_shape = list(x.shape)
pad_shape[axis] = 1
x = anp.concatenate([np.ones(pad_shape, dtype=x.dtype), x], axis=axis)
x = x[(slice(None),) * axis + (slice(0, -1), Ellipsis)]
res = anp.cumprod(x, axis=axis)
if reverse:
res = anp.flip(res, axis=axis)
return res

return _func


@register_operation("DepthwiseConv2dNative")
def _depthwise_conv2d(proto):
"""Parse a DepthwiseConv2d Op."""
Expand Down
41 changes: 37 additions & 4 deletions tf2jax/_src/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,11 +697,44 @@ def cumsum_static():
@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters(
chex.params_product(
(("without_explicit_paddings", False),
("with_explicit_paddings", True)),
(("NHWC", "NHWC"), ("NCHW", "NCHW"),),
(
("exclusive", True),
("not_exclusive", False),
),
(("reverse", True), ("forward", False)),
named=True,
))
)
)
def test_cumprod(self, exclusive, reverse):
inputs = np.array(np.reshape(range(24), (4, 3, 2)), dtype=np.int32)

def cumprod_fn(xs):
return tf.raw_ops.Cumprod(
x=xs, axis=1, exclusive=exclusive, reverse=reverse
)

self._test_convert(cumprod_fn, inputs)

# Check static inputs result in static outputs.
def cumprod_static():
return tf.zeros(cumprod_fn(inputs)[0, -1])

self._test_convert(cumprod_static, [])

@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters(
chex.params_product(
(
("without_explicit_paddings", False),
("with_explicit_paddings", True),
),
(
("NHWC", "NHWC"),
("NCHW", "NCHW"),
),
named=True,
)
)
def test_depthwise_conv2d(self, use_explicit_paddings, data_format):
np.random.seed(42)

Expand Down

0 comments on commit 7adfe09

Please sign in to comment.