Fix negative axes
michalk8 committed Oct 15, 2024
1 parent 78003d9 commit 9e1ae03
Showing 2 changed files with 51 additions and 16 deletions.
48 changes: 40 additions & 8 deletions src/ott/utils.py
@@ -230,7 +230,8 @@ def _prepare_info(status: IOStatus) -> Tuple[int, int, int, np.ndarray]:
   return iteration, inner_iterations, total_iter, errors


-def _canonicalize_axis(axis: int, num_dims: int) -> int:
+def _canonicalize_axis(axis: int, num_dims: int, *, has_extra_dim: bool) -> int:
+  num_dims -= has_extra_dim
   if not -num_dims <= axis < num_dims:
     raise ValueError(
         f"axis {axis} is out of bounds for array of dimension {num_dims}"
@@ -241,11 +242,21 @@ def _canonicalize_axis(axis: int, num_dims: int) -> int:


 def _prepare_axes(
-    name: str, leaves: Any, treedef: Any, axes: Any, *, return_flat: bool
+    name: str,
+    leaves: Any,
+    treedef: Any,
+    axes: Any,
+    *,
+    return_flat: bool,
+    has_extra_dim: bool,
 ) -> Any:
   axes = jax.api_util.flatten_axes(name, treedef, axes, kws=False)
   assert len(leaves) == len(axes), (len(leaves), len(axes))
-  # TODO(michalk8): enable negative axes
+  axes = [
+      axis if axis is None else
+      _canonicalize_axis(axis, jnp.ndim(leaf), has_extra_dim=has_extra_dim)
+      for axis, leaf in zip(axes, leaves)
+  ]
   return axes if return_flat else treedef.unflatten(axes)
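Why `has_extra_dim` exists: when `_prepare_axes` is called on batched values, each mapped leaf carries an extra leading dimension (the scan's batch counter), so a user-facing axis must be bounds-checked against the original rank, one lower than `jnp.ndim(leaf)`. A small sketch of that arithmetic, under this reading of the diff rather than as library code:

import jax.numpy as jnp

# A leaf of shape (n, d) is processed as (num_batches, batch_size, d),
# one dimension more than the array the user's axes refer to.
leaf = jnp.zeros((6, 4)).reshape(3, 2, 4)
user_ndim = leaf.ndim - 1  # rank used for bounds-checking when has_extra_dim=True
axis = -1
canonical = axis + user_ndim if axis < 0 else axis
assert canonical == 1  # -1 denotes the feature axis of the original (6, 4) array
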


@@ -255,23 +266,34 @@ def _batch_and_remainder(
     batch_size: int,
     in_axes: Optional[Union[int, Sequence[int], Any]],
 ) -> Tuple[Any, Any]:
-  assert batch_size > 0, f"Batch size must be positive, got {batch_size}. "
+  assert batch_size > 0, f"Batch size must be positive, got {batch_size}."
   leaves, treedef = jax.tree.flatten(args, is_leaf=batching.is_vmappable)
   in_axes = _prepare_axes(
-      "vmap in_axes", leaves, treedef, in_axes, return_flat=True
+      "vmap in_axes",
+      leaves,
+      treedef,
+      in_axes,
+      return_flat=True,
+      has_extra_dim=False,
   )
   assert not all(
       axis is None for axis in in_axes
   ), "vmap must have at least one non-None value in in_axes"

+  num_splits = None
   has_scan, has_remainder = False, False
   scan_leaves, remainder_leaves = [], []

   for leaf, axis in zip(leaves, in_axes):
     if axis is None:
       scan_leaf = remainder_leaf = leaf
     else:
-      num_splits, _ = divmod(leaf.shape[axis], batch_size)
+      if num_splits is None:
+        num_splits, _ = divmod(leaf.shape[axis], batch_size)
+      else:
+        curr_num_splits, _ = divmod(leaf.shape[axis], batch_size)
+        # TODO(michalk8): better error message
+        assert num_splits == curr_num_splits, (num_splits, curr_num_splits)
       num_elems = num_splits * batch_size

       scan_leaf = jax.lax.slice_in_dim(leaf, None, num_elems, axis=axis)
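For intuition, the split that `_batch_and_remainder` computes: each mapped leaf is cut into `num_splits` full batches (consumed by a scan) plus a remainder, and the new code above additionally insists that every mapped leaf agrees on `num_splits`. A self-contained sketch of the slicing:

import jax
import jax.numpy as jnp

x = jnp.arange(10)
batch_size = 4
num_splits, _ = divmod(x.shape[0], batch_size)  # 2 full batches
num_elems = num_splits * batch_size             # 8 elements go through the scan
scan_part = jax.lax.slice_in_dim(x, None, num_elems, axis=0)  # elements 0..7
remainder = jax.lax.slice_in_dim(x, num_elems, None, axis=0)  # elements 8..9
assert scan_part.shape == (8,) and remainder.shape == (2,)
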
@@ -321,7 +343,12 @@ def body_fn(carry: None, index: int) -> Tuple[None, R]:

   leaves, treedef = jax.tree.flatten(args, is_leaf=batching.is_vmappable)
   axes = _prepare_axes(
-      "vmap in_axes", leaves, treedef, in_axes, return_flat=True
+      "vmap in_axes",
+      leaves,
+      treedef,
+      in_axes,
+      return_flat=True,
+      has_extra_dim=True,
   )
   n = num_steps(axes, args)

@@ -370,7 +397,12 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
     batched = batched_fun(*batched, **kwargs)
     leaves, treedef = jax.tree.flatten(batched, is_leaf=batching.is_vmappable)
     out_axes_ = _prepare_axes(
-        "vmap out_axes", leaves, treedef, out_axes, return_flat=False
+        "vmap out_axes",
+        leaves,
+        treedef,
+        out_axes,
+        return_flat=False,
+        has_extra_dim=True,
     )
     batched = jax.tree.map(unbatch, out_axes_, batched)
     if has_remainder:
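Taken together, these changes mean negative `in_axes`/`out_axes` should now be accepted wherever `jax.vmap` accepts them. A hypothetical usage sketch, with `batched_vmap`'s signature taken from this diff:

import jax
import jax.numpy as jnp

from ott import utils

def dot(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  return jnp.dot(x, y)

x = jnp.ones((128, 8))
y = jnp.ones((8, 128))

# Map over axis 0 of x and the last axis of y, in batches of 32.
fn = utils.batched_vmap(dot, batch_size=32, in_axes=(0, -1))
gt = jax.vmap(dot, in_axes=(0, -1))
assert fn(x, y).shape == gt(x, y).shape == (128,)
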
19 changes: 11 additions & 8 deletions tests/utils_test.py
@@ -25,6 +25,7 @@
 from ott import utils


+@pytest.mark.fast()
 class TestBatchedVmap:

   @pytest.mark.parametrize("batch_size", [1, 11, 32, 33])
@@ -56,32 +57,34 @@ def f(x: Any) -> jnp.ndarray:

     np.testing.assert_array_equal(gt_fn(x), fn(x))

-  @pytest.mark.parametrize("in_axes", [0, 1, [0, None]])
-  def test_in_axes(self, rng: jax.Array, in_axes: Any):
+  @pytest.mark.parametrize("batch_size", [1, 7, 67, 133])
+  @pytest.mark.parametrize("in_axes", [0, 1, -1, -2, [0, None], (0, -2)])
+  def test_in_axes(self, rng: jax.Array, in_axes: Any, batch_size: int):

     def f(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
       x = jnp.atleast_2d(x)
       y = jnp.atleast_2d(y)
       return jnp.dot(x, y.T)

     rng1, rng2 = jax.random.split(rng, 2)
-    x = jax.random.normal(rng1, (15, 3)) + 10.0
-    y = jax.random.normal(rng2, (15, 3))
+    x = jax.random.normal(rng1, (133, 71)) + 10.0
+    y = jax.random.normal(rng2, (133, 71))

     gt_fn = jax.jit(jax.vmap(f, in_axes=in_axes))
-    fn = jax.jit(utils.batched_vmap(f, batch_size=5, in_axes=in_axes))
+    fn = jax.jit(utils.batched_vmap(f, batch_size=batch_size, in_axes=in_axes))

-    np.testing.assert_array_equal(gt_fn(x, y), fn(x, y))
+    # TODO(michalk8): check this
+    np.testing.assert_allclose(gt_fn(x, y), fn(x, y), rtol=1e-4, atol=1e-4)

-  @pytest.mark.parametrize("out_axes", [0, 1, 2])
+  @pytest.mark.parametrize("out_axes", [0, 1, 2, -1, -2, -3])
   def test_out_axes(self, rng: jax.Array, out_axes: int):

     def f(x: jnp.ndarray, y: jnp.ndarray) -> Any:
       return (x.sum() + y.sum()).reshape(1, 1)

     rng1, rng2 = jax.random.split(rng, 2)
     x = jax.random.normal(rng1, (31, 13))
-    y = jax.random.normal(rng2, (31, 3))
+    y = jax.random.normal(rng2, (31, 6)) - 15.0

     gt_fn = jax.vmap(f, out_axes=out_axes)
     fn = utils.batched_vmap(f, batch_size=5, out_axes=out_axes)
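A note on the switch from assert_array_equal to assert_allclose in test_in_axes: batching plausibly changes how XLA schedules and fuses the floating-point reductions, so the batched result may differ from plain jax.vmap by rounding error; the tolerances (and the TODO) reflect that. Floating-point addition is not associative, as a tiny NumPy example shows:

import numpy as np

xs = np.float32([1e8, 4.0, 4.0])
left = (xs[0] + xs[1]) + xs[2]   # each 4 is absorbed by rounding: 1e8
right = xs[0] + (xs[1] + xs[2])  # the 4s combine first: 100000008
assert left != right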
