Skip to content

Commit

Permalink
internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 728826659
  • Loading branch information
qdhack authored and Orbax Authors committed Feb 20, 2025
1 parent 95595be commit d2bb793
Show file tree
Hide file tree
Showing 20 changed files with 43 additions and 34 deletions.
11 changes: 10 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,14 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install -y protobuf-compiler
pip install tensorflow
pip install .
protoc -I=. --python_out=. $(find orbax/experimental/model/core/ -name "*.proto")
echo "!!! After protoc"
ls -l orbax/experimental/model/core/protos/
pip install -e .
pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
pip install -U jax jaxlib
Expand All @@ -181,3 +187,6 @@ jobs:
else
pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
fi
- name: Test with pytest
run: |
pytest orbax/experimental/model/core/python/
2 changes: 1 addition & 1 deletion model/orbax/experimental/model/core/python/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,4 @@ def fn_2(inputs):


if __name__ == "__main__":
googletest.main()
absltest.main()
8 changes: 4 additions & 4 deletions model/orbax/experimental/model/core/python/save_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
from typing import Tuple

from absl.testing import absltest
import jax
from jax import export as jax_export
import jax.numpy as jnp
Expand All @@ -33,8 +34,6 @@
from orbax.experimental.model.core.python.shlo_function import ShloFunction
import tensorflow as tf

from absl.testing import absltest

save = save_lib.save
Tensor = concrete_function.Tensor
Function = concrete_function.ConcreteFunction
Expand Down Expand Up @@ -70,10 +69,11 @@ def jax_spec_to_tensor_spec(x: jax.ShapeDtypeStruct) -> TensorSpec:
return TensorSpec(shape=x.shape, dtype=dtype_from_np_dtype(x.dtype))


class SaveTest(googletest.TestCase):
class SaveTest(absltest.TestCase):
# TODO(qidichen): We can move relevant parts of test from orbax/experimental/model/integration_tests/orbax_model_test.py here.
def test_save(self):
pass


if __name__ == '__main__':
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
DataType = types_pb2.DataType


class NodeBuilderTest(googletest.TestCase):
class NodeBuilderTest(absltest.TestCase):

def test_add_nodes(self):
builder = node_builder.NodeBuilder()
Expand Down Expand Up @@ -61,4 +61,4 @@ def test_add_nodes(self):


if __name__ == '__main__':
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
OpType = nodes.OpType


class NodesTest(googletest.TestCase):
class NodesTest(absltest.TestCase):

def assertProtoEqual(self, a, b):
compare.assertProto2Equal(self, a, b)
Expand Down Expand Up @@ -102,4 +102,4 @@ class FakeNode(Node):


if __name__ == '__main__':
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def jax_spec_to_tensor_spec(x: jax.ShapeDtypeStruct) -> TensorSpec:
)


class SavedModelBuilderTest(googletest.TestCase):
class SavedModelBuilderTest(absltest.TestCase):

def test_build_with_function_aliases(self):
@jax.jit
Expand Down Expand Up @@ -110,4 +110,4 @@ def f(xy, captured):


if __name__ == "__main__":
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,4 @@ def test_string_list(self):
self.assertEqual(s, back_tensor.np_array[n])

if __name__ == '__main__':
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,4 @@ def f():


if __name__ == "__main__":
googletest.main()
absltest.main()
4 changes: 2 additions & 2 deletions model/orbax/experimental/model/core/python/tree_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from absl.testing import absltest


class TreeUtilTest(googletest.TestCase):
class TreeUtilTest(absltest.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -79,4 +79,4 @@ def assert_int(x: Any) -> None:


if __name__ == "__main__":
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -238,4 +238,4 @@ def test_manifest_type_to_shlo_tensor_spec_tree(self):


if __name__ == "__main__":
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from absl.testing import absltest


class UnstructuredDataTest(googletest.TestCase):
class UnstructuredDataTest(absltest.TestCase):

def test_write_inlined_string_to_file(self):
proto = unstructured_data.UnstructuredData()
Expand Down Expand Up @@ -119,4 +119,4 @@ def test_maybe_write_location_pointer_to_file(self):


if __name__ == "__main__":
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from absl.testing import absltest


class CompatTest(googletest.TestCase):
class CompatTest(absltest.TestCase):

def testCompatValidEncoding(self):
self.assertEqual(compat.as_bytes("hello", "utf8"), b"hello")
Expand All @@ -48,4 +48,4 @@ def testCompatInvalidEncoding(self):


if __name__ == "__main__":
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from absl.testing import absltest


class NamingTest(googletest.TestCase):
class NamingTest(absltest.TestCase):

def test_validate_node_name(self):
self.assertTrue(naming.is_valid_node_name('correct_name_1534'))
Expand All @@ -30,4 +30,4 @@ def test_validate_node_name(self):


if __name__ == '__main__':
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from absl.testing import absltest


class ObjectIdentityWrapperTest(googletest.TestCase):
class ObjectIdentityWrapperTest(absltest.TestCase):

def testWrapperNotEqualToWrapped(self):
class SettableHash(object):
Expand Down Expand Up @@ -63,7 +63,7 @@ def __hash__(self):
bool(o in set([wrap1]))


class ObjectIdentitySetTest(googletest.TestCase):
class ObjectIdentitySetTest(absltest.TestCase):

def testDifference(self):
class Element(object):
Expand Down Expand Up @@ -96,4 +96,4 @@ def testClear(self):


if __name__ == '__main__':
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,4 @@ def test_to_shlo_dtype_and_refinement_wrong_type(self, jax_dtype):


if __name__ == '__main__':
googletest.main()
absltest.main()
2 changes: 1 addition & 1 deletion model/orbax/experimental/model/jax2obm/main_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,4 +1205,4 @@ def get_mesh():


if __name__ == '__main__':
googletest.main()
absltest.main()
2 changes: 1 addition & 1 deletion model/orbax/experimental/model/jax2obm/obm_to_jax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,4 +538,4 @@ def get_mesh():


if __name__ == '__main__':
googletest.main()
absltest.main()
4 changes: 2 additions & 2 deletions model/orbax/experimental/model/jax2obm/sharding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from absl.testing import absltest


class ShardingTest(googletest.TestCase):
class ShardingTest(absltest.TestCase):

def get_mesh(self):
def _create_mesh(devices) -> jax.sharding.Mesh:
Expand Down Expand Up @@ -182,4 +182,4 @@ def test_jax_mesh_to_obm_device_assignment_by_coords(self):


if __name__ == '__main__':
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -422,4 +422,4 @@ def tf_fn(a):


if __name__ == "__main__":
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from absl.testing import absltest


class TfFunctionDefToObmTest(googletest.TestCase):
class TfFunctionDefToObmTest(absltest.TestCase):

def test_tf_concrete_function_to_unstructured_data_success(self):

Expand Down Expand Up @@ -56,4 +56,4 @@ def tf_pow(a):


if __name__ == "__main__":
googletest.main()
absltest.main()

0 comments on commit d2bb793

Please sign in to comment.