Skip to content

Commit

Permalink
Disable TF tests on macos
Browse files Browse the repository at this point in the history
  • Loading branch information
romanngg committed Sep 12, 2024
1 parent 471e3a9 commit ac477d8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
python-version: ['3.10', 3.11, 3.12]
JAX_ENABLE_X64: [0]

runs-on: macos-14
runs-on: macos-latest

steps:

Expand Down
9 changes: 9 additions & 0 deletions tests/experimental/empirical_tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Tests for `experimental/empirical_tf/empirical.py`."""

import platform
from absl.testing import absltest
from absl.testing import parameterized
import jax
Expand Down Expand Up @@ -285,6 +286,8 @@ def test_keras_functional(
diagonal_axes,
vmap_axes,
):
if platform.system() == 'Darwin':
self.skipTest('TF <-> JAX fails on MacOS.')
with jax.numpy_rank_promotion("warn"):
f = f(classes=1, input_shape=input_shape, weights=None)
f.build((None, *input_shape))
Expand Down Expand Up @@ -314,6 +317,8 @@ def test_keras_sequential(
diagonal_axes,
vmap_axes,
):
if platform.system() == 'Darwin':
self.skipTest('TF <-> JAX fails on MacOS.')
with jax.numpy_rank_promotion("warn"):
f = keras.Sequential()
f.add(keras.layers.Conv2D(4, (3, 3), activation='relu'))
Expand Down Expand Up @@ -355,6 +360,8 @@ def test_tf_function(
diagonal_axes,
vmap_axes,
):
if platform.system() == 'Darwin':
self.skipTest('TF <-> JAX fails on MacOS.')
f, f_jax = f_f_jax
f = tf.function(f, input_signature=_input_signature)
params = tf.random.normal(params_shape, seed=4)
Expand All @@ -379,6 +386,8 @@ def test_tf_module(
diagonal_axes,
vmap_axes,
):
if platform.system() == 'Darwin':
self.skipTest('TF <-> JAX fails on MacOS.')
f = _MLP(input_size=5, sizes=[4, 6, 3], name='MLP')
f_jax, params = experimental.get_apply_fn_and_params(f)
self._compare_ntks(f, f_jax, params, trace_axes, diagonal_axes, vmap_axes)
Expand Down

0 comments on commit ac477d8

Please sign in to comment.