Skip to content

Commit

Permalink
Temporary comment since I have tf2jax library changes on the diffbase.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 555222733
  • Loading branch information
TF2JAXDev authored and TF2JAXDev committed Aug 9, 2023
1 parent 7adfe09 commit 79c2115
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
12 changes: 11 additions & 1 deletion tf2jax/_src/numpy_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,22 @@ def broadcast_to(arr, shape):
flip = lambda arr, axis: _get_np(arr).flip(arr, axis=axis)
roll = lambda arr, shift, axis: _get_np(arr).roll(arr, shift=shift, axis=axis)
split = lambda arr, sections, axis: _get_np(arr).split(arr, sections, axis=axis)
squeeze = lambda arr, axis: _get_np(arr).squeeze(arr, axis=axis)
stack = lambda arrs, axis: _get_np(*arrs).stack(arrs, axis=axis)
tile = lambda arr, reps: _get_np(arr, reps).tile(arr, reps=reps)
where = lambda cond, x, y: _get_np(cond, x, y).where(cond, x, y)


def squeeze(arr, axis):
# tf.squeeze and np/jnp.squeeze have different behaviors when axis=().
# - tf.squeeze will squeeze all dimensions.
# - np/jnp.squeeze will not squeeze any dimensions.
# Here we change () to None to ensure that squeeze has the same behavior
# when converted from tf to np/jnp.
if axis == tuple():
axis = None
return _get_np(arr).squeeze(arr, axis=axis)


def moveaxis(
arr,
source: Union[int, Sequence[int]],
Expand Down
5 changes: 3 additions & 2 deletions tf2jax/_src/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,8 +1458,9 @@ def roll_static():
self._test_convert(roll_static, [])

@chex.variants(with_jit=True, without_jit=True)
def test_squeeze(self):
inputs, dims = np.array([[[42], [47]]]), (0, 2)
@parameterized.parameters(((0, 2),), (tuple(),), (None,))
def test_squeeze(self, dims):
inputs = np.array([[[42], [47]]])

def squeeze(x):
return tf.raw_ops.Squeeze(input=x, axis=dims)
Expand Down

0 comments on commit 79c2115

Please sign in to comment.