Add back adaptive pool 2d
qihqi committed Jan 2, 2025
1 parent 37164ce commit 637cf2f
Showing 1 changed file with 71 additions and 4 deletions.
75 changes: 71 additions & 4 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1,7 +1,7 @@
"""Torch ops implemented using jax."""

import sys
from typing import Optional, Sequence, Tuple, Union
from typing import Optional, Sequence, Tuple, Union, Callable
import functools

import math
@@ -1889,19 +1889,19 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
    y = jnp.squeeze(y, axis=0)
  return y


@op(torch.ops.aten._adaptive_avg_pool3d)
def _aten_adaptive_avg_pool3d(x, output_shape):
  assert len(x.shape) in (4, 5), f'Expected 4D or 5D input but got {len(x.shape)} dimensions'
  assert len(output_shape) == 3, f'Expected 3D output but got {len(output_shape)} dimensions'

  # Reference PyTorch implementation:
  # https://github.com/pytorch/pytorch/blob/ef4475f9025b3c46a13bdd054b6adfbcb5f8ab8c/aten/src/ATen/native/AdaptiveAveragePooling.cpp
  output_shape = x.shape[:-3] + tuple(output_shape)
  output = jnp.zeros(output_shape, dtype=x.dtype)
  stride_d = x.shape[-3] / output_shape[-3]
  stride_h = x.shape[-2] / output_shape[-2]
  stride_w = x.shape[-1] / output_shape[-1]

  def avg_pool_batch(d, h, w):
    start_d = int(jnp.floor(d * stride_d))
    end_d = int(jnp.ceil((d+1) * stride_d))
@@ -1910,7 +1910,6 @@ def avg_pool_batch(d, h, w):
    start_h = int(jnp.floor(h * stride_h))
    end_h = int(jnp.ceil((h+1) * stride_h))
    start_w = int(jnp.floor(w * stride_w))
    end_w = int(jnp.ceil((w+1) * stride_w))
    return jnp.mean(x[..., start_d:end_d, start_h:end_h, start_w:end_w], axis=(-3, -2, -1))

  # TODO: Replace this with a more performant implementation.
  # Related JAX issue requesting adaptive pooling: https://github.com/jax-ml/jax/issues/20098
  for d in range(output_shape[-3]):
@@ -1919,6 +1918,74 @@
    for h in range(output_shape[-2]):
      for w in range(output_shape[-1]):
        output = output.at[..., d, h, w].set(avg_pool_batch(d, h, w))
  return output
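# For example, pooling a 5D input of shape (1, 2, 8, 9, 10) down to (4, 4, 4)
# pools only the last three dims and keeps the leading dims intact:
#
#   x = jnp.ones((1, 2, 8, 9, 10))
#   y = _aten_adaptive_avg_pool3d(x, (4, 4, 4))
#   y.shape  # (1, 2, 4, 4, 4)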

@op(torch.ops.aten._adaptive_avg_pool2d)
def _aten_adaptive_avg_pool2d(x, output_shape):
  # Adapted from Equinox:
  # https://github.com/patrick-kidger/equinox/blob/fe4121d96e3cb246d8acc1647c9c7a5b73d753dc/equinox/nn/_pool.py#L525
  assert len(x.shape) in (3, 4), f'Expected 3D or 4D input but got {len(x.shape)} dimensions'
  assert len(output_shape) == 2, f'Expected 2D output but got {len(output_shape)} dimensions'

  if np.prod(x.shape) == 0:
    return jnp.zeros((*x.shape[:-len(output_shape)], *output_shape), dtype=x.dtype)

  batch_dims = None
  if x.ndim == 4:
    # Fold (batch, channel) into a single leading dim; _adaptive_poolnd
    # expects exactly one leading non-spatial dim.
    batch_dims = x.shape[:2]
    x = x.reshape(-1, *x.shape[2:])
  res = _adaptive_poolnd(x, output_shape, 2, jnp.mean)
  if batch_dims is not None:
    res = res.reshape((*batch_dims, *output_shape))
  return res
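# For example, pooling a 4D input of shape (2, 3, 10, 14) down to (5, 7):
#
#   x = jnp.ones((2, 3, 10, 14))
#   y = _aten_adaptive_avg_pool2d(x, (5, 7))
#   y.shape  # (2, 3, 5, 7)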

def _adaptive_pool1d(
    x: jax.Array, target_size: int, operation: Callable[[jax.Array], jax.Array]
) -> jax.Array:
"""**Arguments:**
- `x`: The input. Should be a JAX array of shape `(dim,)`.
- `target_size`: The shape of the output after the pooling operation
`(target_size,)`.
- `operation`: The pooling operation to be performed on the input array.
**Returns:**
A JAX array of shape `(1, target_shape)`.
"""
  dims = jnp.size(x)
  num_head_arrays = dims % target_size
  if num_head_arrays != 0:
    # `dims` does not divide evenly into `target_size` windows: the first
    # `num_head_arrays` windows take one extra element each.
    head_end_index = num_head_arrays * (dims // target_size + 1)
    head_op = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1))
    tail_op = jax.vmap(operation)(
        x[head_end_index:].reshape(-1, dims // target_size)
    )
    outputs = jnp.concatenate([head_op, tail_op])
  else:
    # All windows have equal size; a single vmapped reduction suffices.
    outputs = jax.vmap(operation)(x.reshape(-1, dims // target_size))
  return outputs
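# For example, pooling x = jnp.arange(5.0) down to target_size=2 with
# operation=jnp.mean: 5 % 2 == 1, so the first window takes three elements
# and the second takes two, giving
# [mean([0., 1., 2.]), mean([3., 4.])] == [1.0, 3.5].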


def _adaptive_poolnd(
    x: jax.Array, target_size: Sequence[int], num_spatial_dims: int,
    operation: Callable[[jax.Array], jax.Array]):

  if x.ndim - 1 != len(target_size):
    raise ValueError(
        f"Expected input with {len(target_size)} spatial dimensions, "
        f"received {x.ndim - 1} instead."
    )
  # Pool one spatial axis at a time, vmapping _adaptive_pool1d over every
  # other axis so that it only ever sees a 1D slice.
  for i in range(1, x.ndim):
    op = jax.vmap(
        _adaptive_pool1d, (0, None, None), 0
    )  # batching over channels by default
    for j in range(1, x.ndim):
      if i == j:
        continue
      op = jax.vmap(op, in_axes=(j, None, None), out_axes=j)
    x = op(x, target_size[i - 1], operation)
  return x
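# For example, for x of shape (B, H, W) and target_size == (th, tw): the
# i == 1 pass vmaps _adaptive_pool1d over axes 0 and 2 and pools H down to
# th; the i == 2 pass then pools W down to tw, yielding shape (B, th, tw).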


@op(torch.ops.aten.avg_pool1d)
@op(torch.ops.aten.avg_pool2d)
