Skip to content

Commit

Permalink
Add unravel_index op (keras-team#20559)
Browse files Browse the repository at this point in the history
* Add unravel_index

* Fix Tensorflow

* Fix Tensorflow impl

* fix default np.int64

* fix

* Fix torch

* fix numpy and torch

* api

* fix

* Fix tensorflow impl and docstring

* fix

* shape None case

* shape None case

* fix
  • Loading branch information
IMvision12 authored Dec 2, 2024
1 parent 0b8c6a4 commit ab53ed2
Show file tree
Hide file tree
Showing 10 changed files with 171 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@
from keras.src.ops.numpy import triu
from keras.src.ops.numpy import true_divide
from keras.src.ops.numpy import trunc
from keras.src.ops.numpy import unravel_index
from keras.src.ops.numpy import var
from keras.src.ops.numpy import vdot
from keras.src.ops.numpy import vectorize
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
from keras.src.ops.numpy import triu
from keras.src.ops.numpy import true_divide
from keras.src.ops.numpy import trunc
from keras.src.ops.numpy import unravel_index
from keras.src.ops.numpy import var
from keras.src.ops.numpy import vdot
from keras.src.ops.numpy import vectorize
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@
from keras.src.ops.numpy import triu
from keras.src.ops.numpy import true_divide
from keras.src.ops.numpy import trunc
from keras.src.ops.numpy import unravel_index
from keras.src.ops.numpy import var
from keras.src.ops.numpy import vdot
from keras.src.ops.numpy import vectorize
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
from keras.src.ops.numpy import triu
from keras.src.ops.numpy import true_divide
from keras.src.ops.numpy import trunc
from keras.src.ops.numpy import unravel_index
from keras.src.ops.numpy import var
from keras.src.ops.numpy import vdot
from keras.src.ops.numpy import vectorize
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,11 @@ def ravel(x):
return jnp.ravel(x)


def unravel_index(x, shape):
x = convert_to_tensor(x)
return jnp.unravel_index(x, shape)


@sparse.elementwise_unary(linear=True)
def real(x):
x = convert_to_tensor(x)
Expand Down
7 changes: 7 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,13 @@ def ravel(x):
return np.ravel(x)


def unravel_index(x, shape):
dtype = dtypes.result_type(x.dtype)
return tuple(
indices.astype(dtype) for indices in np.unravel_index(x, shape)
)


def real(x):
return np.real(x)

Expand Down
25 changes: 25 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,6 +1855,31 @@ def ravel(x):
return tf.reshape(x, [-1])


def unravel_index(x, shape):
x = tf.convert_to_tensor(x)
input_dtype = x.dtype

if None in shape:
raise ValueError(
"`shape` argument cannot contain `None`. Received: shape={shape}"
)

if x.ndim == 1:
coords = []
for dim in reversed(shape):
coords.append(tf.cast(x % dim, input_dtype))
x = x // dim
return tuple(reversed(coords))

x_shape = x.shape
coords = []
for dim in shape:
coords.append(tf.reshape(tf.cast(x % dim, input_dtype), x_shape))
x = x // dim

return tuple(reversed(coords))


@sparse.elementwise_unary
def real(x):
x = convert_to_tensor(x)
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,6 +1214,12 @@ def ravel(x):
return torch.ravel(x)


def unravel_index(x, shape):
x = convert_to_tensor(x)
dtype = dtypes.result_type(x.dtype)
return tuple(cast(idx, dtype) for idx in torch.unravel_index(x, shape))


def real(x):
if not isinstance(x, torch.Tensor):
x = torch.from_numpy(x) # needed for complex type conversion
Expand Down
51 changes: 51 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4659,6 +4659,57 @@ def ravel(x):
return backend.numpy.ravel(x)


class UnravelIndex(Operation):
def __init__(self, shape):
self.shape = shape
self._inbound_nodes = []

def call(self, indices):
return backend.numpy.unravel_index(indices, self.shape)

def compute_output_spec(self, indices):
if None in self.shape:
output_shapes = [[None] for _ in self.shape]
else:
if isinstance(indices, int):
output_shapes = [[1] for _ in self.shape]
elif hasattr(indices, "shape"):
output_shapes = [list(indices.shape) for _ in self.shape]
else:
try:
indices_shape = np.shape(indices)
output_shapes = [list(indices_shape) for _ in self.shape]
except Exception:
output_shapes = [[None] for _ in self.shape]

return [
KerasTensor(shape, dtype=indices.dtype) for shape in output_shapes
]


@keras_export(["keras.ops.unravel_index", "keras.ops.numpy.unravel_index"])
def unravel_index(indices, shape):
"""Convert flat indices to coordinate arrays in a given array shape.
Args:
indices: An integer or array of integers representing flat indices.
shape: The shape of the array to unravel into.
Returns:
Tuple of arrays for each dimension with unraveled indices.
Example:
>>> indices = 5
>>> shape = (3, 3)
>>> unravel_index(indices, shape)
(1, 2) # 5 is at row 1, column 2 in a 3x3 array
"""
if any_symbolic_tensors((indices,)):
return UnravelIndex(shape).symbolic_call(indices)

return backend.numpy.unravel_index(indices, shape)


class Real(Operation):
def call(self, x):
return backend.numpy.real(x)
Expand Down
73 changes: 73 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1455,6 +1455,26 @@ def test_ravel(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.ravel(x).shape, (None,))

def test_unravel_index(self):
x = KerasTensor((None,))
indices = knp.unravel_index(x, (2, 3))
self.assertEqual(len(indices), 2)
self.assertEqual(indices[0].shape, (None,))
self.assertEqual(indices[1].shape, (None,))

x = KerasTensor((None, 4))
indices = knp.unravel_index(x, (3, 4))
self.assertEqual(len(indices), 2)
self.assertEqual(indices[0].shape, (None, 4))
self.assertEqual(indices[1].shape, (None, 4))

x = KerasTensor((None, 3, 2))
indices = knp.unravel_index(x, (5, 6, 4))
self.assertEqual(len(indices), 3)
self.assertEqual(indices[0].shape, (None, 3, 2))
self.assertEqual(indices[1].shape, (None, 3, 2))
self.assertEqual(indices[2].shape, (None, 3, 2))

def test_real(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.real(x).shape, (None, 3))
Expand Down Expand Up @@ -1999,6 +2019,19 @@ def test_ravel(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.ravel(x).shape, (6,))

def test_unravel_index(self):
x = KerasTensor((6,))
indices = knp.unravel_index(x, (2, 3))
self.assertEqual(len(indices), 2)
self.assertEqual(indices[0].shape, (6,))
self.assertEqual(indices[1].shape, (6,))

x = KerasTensor((2, 3))
indices = knp.unravel_index(x, (3, 4))
self.assertEqual(len(indices), 2)
self.assertEqual(indices[0].shape, (2, 3))
self.assertEqual(indices[1].shape, (2, 3))

def test_real(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.real(x).shape, (2, 3))
Expand Down Expand Up @@ -4114,6 +4147,19 @@ def test_ravel(self):
self.assertAllClose(knp.ravel(x), np.ravel(x))
self.assertAllClose(knp.Ravel()(x), np.ravel(x))

def test_unravel_index(self):
x = np.array([0, 1, 2, 3])
shape = (2, 2)
self.assertAllClose(
knp.unravel_index(x, shape), np.unravel_index(x, shape)
)

x = np.array([[0, 1], [2, 3]])
shape = (2, 2)
self.assertAllClose(
knp.unravel_index(x, shape), np.unravel_index(x, shape)
)

def test_real(self):
x = np.array([[1, 2, 3 - 3j], [3, 2, 1 + 5j]])
self.assertAllClose(knp.real(x), np.real(x))
Expand Down Expand Up @@ -7789,6 +7835,33 @@ def test_ravel(self, dtype):
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=INT_DTYPES))
def test_unravel_index(self, dtype):
import jax.numpy as jnp

x = knp.ones((3,), dtype=dtype)
x_jax = jnp.ones((3,), dtype=dtype)

indices = knp.array([2, 0], dtype=dtype)
indices_jax = jnp.array([2, 0], dtype=dtype)

unravel_result_knp = knp.unravel_index(indices, x.shape)
unravel_result_jax = jnp.unravel_index(indices_jax, x_jax.shape)

expected_dtype_knp = standardize_dtype(unravel_result_knp[0].dtype)
expected_dtype_jax = standardize_dtype(unravel_result_jax[0].dtype)

self.assertEqual(expected_dtype_knp, expected_dtype_jax)

unravel_result_knp_symbolic = knp.UnravelIndex(x.shape).symbolic_call(
indices
)
expected_dtype_symbolic = standardize_dtype(
unravel_result_knp_symbolic[0].dtype
)

self.assertEqual(expected_dtype_symbolic, expected_dtype_jax)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_repeat(self, dtype):
import jax.numpy as jnp
Expand Down

0 comments on commit ab53ed2

Please sign in to comment.