diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7a0536d99..ad9de9469 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -171,8 +171,14 @@ jobs: run: | sudo apt-get update sudo apt-get install -y protobuf-compiler + + pip install tensorflow - pip install . + protoc -I=. --python_out=. $(find orbax/experimental/model/core/ -name "*.proto") + echo "!!! After protoc" + ls -l orbax/experimental/model/core/protos/ + + pip install -e . pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html if [[ "${{ matrix.jax-version }}" == "newest" ]]; then pip install -U jax jaxlib @@ -181,3 +187,6 @@ jobs: else pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" fi + - name: Test with pytest + run: | + pytest orbax/experimental/model/core/python/ diff --git a/model/orbax/experimental/model/core/python/module_test.py b/model/orbax/experimental/model/core/python/module_test.py index fdea2fd45..bfcc9417d 100644 --- a/model/orbax/experimental/model/core/python/module_test.py +++ b/model/orbax/experimental/model/core/python/module_test.py @@ -157,4 +157,4 @@ def fn_2(inputs): if __name__ == "__main__": - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/core/python/save_lib_test.py b/model/orbax/experimental/model/core/python/save_lib_test.py index c2d371236..f29c20c16 100644 --- a/model/orbax/experimental/model/core/python/save_lib_test.py +++ b/model/orbax/experimental/model/core/python/save_lib_test.py @@ -18,6 +18,7 @@ import os from typing import Tuple +from absl.testing import absltest import jax from jax import export as jax_export import jax.numpy as jnp @@ -33,8 +34,6 @@ from orbax.experimental.model.core.python.shlo_function import ShloFunction import tensorflow as tf -from absl.testing import absltest - save = save_lib.save Tensor = concrete_function.Tensor Function = concrete_function.ConcreteFunction @@ -70,10 +69,11 @@ def jax_spec_to_tensor_spec(x: jax.ShapeDtypeStruct) -> TensorSpec: return TensorSpec(shape=x.shape, dtype=dtype_from_np_dtype(x.dtype)) -class SaveTest(googletest.TestCase): +class SaveTest(absltest.TestCase): # TODO(qidichen): We can move relevant parts of test from orbax/experimental/model/integration_tests/orbax_model_test.py here. def test_save(self): pass + if __name__ == '__main__': - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/core/python/saved_model_proto/node_builder_test.py b/model/orbax/experimental/model/core/python/saved_model_proto/node_builder_test.py index 798338fa0..95cd6212d 100644 --- a/model/orbax/experimental/model/core/python/saved_model_proto/node_builder_test.py +++ b/model/orbax/experimental/model/core/python/saved_model_proto/node_builder_test.py @@ -22,7 +22,7 @@ DataType = types_pb2.DataType -class NodeBuilderTest(googletest.TestCase): +class NodeBuilderTest(absltest.TestCase): def test_add_nodes(self): builder = node_builder.NodeBuilder() @@ -61,4 +61,4 @@ def test_add_nodes(self): if __name__ == '__main__': - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/core/python/saved_model_proto/nodes_test.py b/model/orbax/experimental/model/core/python/saved_model_proto/nodes_test.py index ef1872b46..b2d5ce7f2 100644 --- a/model/orbax/experimental/model/core/python/saved_model_proto/nodes_test.py +++ b/model/orbax/experimental/model/core/python/saved_model_proto/nodes_test.py @@ -27,7 +27,7 @@ OpType = nodes.OpType -class NodesTest(googletest.TestCase): +class NodesTest(absltest.TestCase): def assertProtoEqual(self, a, b): compare.assertProto2Equal(self, a, b) @@ -102,4 +102,4 @@ class FakeNode(Node): if __name__ == '__main__': - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/core/python/saved_model_proto/saved_model_builder_test.py b/model/orbax/experimental/model/core/python/saved_model_proto/saved_model_builder_test.py index 96f3d411d..89cb48cdb 100644 --- a/model/orbax/experimental/model/core/python/saved_model_proto/saved_model_builder_test.py +++ b/model/orbax/experimental/model/core/python/saved_model_proto/saved_model_builder_test.py @@ -58,7 +58,7 @@ def jax_spec_to_tensor_spec(x: jax.ShapeDtypeStruct) -> TensorSpec: ) -class SavedModelBuilderTest(googletest.TestCase): +class SavedModelBuilderTest(absltest.TestCase): def test_build_with_function_aliases(self): @jax.jit @@ -110,4 +110,4 @@ def f(xy, captured): if __name__ == "__main__": - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/core/python/saved_model_proto/tensor_proto_test.py b/model/orbax/experimental/model/core/python/saved_model_proto/tensor_proto_test.py index 237b4aff9..c69b5f60e 100644 --- a/model/orbax/experimental/model/core/python/saved_model_proto/tensor_proto_test.py +++ b/model/orbax/experimental/model/core/python/saved_model_proto/tensor_proto_test.py @@ -72,4 +72,4 @@ def test_string_list(self): self.assertEqual(s, back_tensor.np_array[n]) if __name__ == '__main__': - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/core/python/simple_orchestration_test.py b/model/orbax/experimental/model/core/python/simple_orchestration_test.py index fa72cb71f..ec85965ee 100644 --- a/model/orbax/experimental/model/core/python/simple_orchestration_test.py +++ b/model/orbax/experimental/model/core/python/simple_orchestration_test.py @@ -146,4 +146,4 @@ def f(): if __name__ == "__main__": - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/core/python/tree_util_test.py b/model/orbax/experimental/model/core/python/tree_util_test.py index e97fe1806..33fe284a8 100644 --- a/model/orbax/experimental/model/core/python/tree_util_test.py +++ b/model/orbax/experimental/model/core/python/tree_util_test.py @@ -20,7 +20,7 @@ from absl.testing import absltest -class TreeUtilTest(googletest.TestCase): +class TreeUtilTest(absltest.TestCase): def setUp(self): super().setUp() @@ -79,4 +79,4 @@ def assert_int(x: Any) -> None: if __name__ == "__main__": - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/core/python/type_proto_util_test.py b/model/orbax/experimental/model/core/python/type_proto_util_test.py index 300b80ab4..aac9833ee 100644 --- a/model/orbax/experimental/model/core/python/type_proto_util_test.py +++ b/model/orbax/experimental/model/core/python/type_proto_util_test.py @@ -238,4 +238,4 @@ def test_manifest_type_to_shlo_tensor_spec_tree(self): if __name__ == "__main__": - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/core/python/unstructured_data_test.py b/model/orbax/experimental/model/core/python/unstructured_data_test.py index 50a3d77cb..8c9fd04f1 100644 --- a/model/orbax/experimental/model/core/python/unstructured_data_test.py +++ b/model/orbax/experimental/model/core/python/unstructured_data_test.py @@ -19,7 +19,7 @@ from absl.testing import absltest -class UnstructuredDataTest(googletest.TestCase): +class UnstructuredDataTest(absltest.TestCase): def test_write_inlined_string_to_file(self): proto = unstructured_data.UnstructuredData() @@ -119,4 +119,4 @@ def test_maybe_write_location_pointer_to_file(self): if __name__ == "__main__": - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/core/python/util/compat_test.py b/model/orbax/experimental/model/core/python/util/compat_test.py index fa4e86718..085339cb9 100644 --- a/model/orbax/experimental/model/core/python/util/compat_test.py +++ b/model/orbax/experimental/model/core/python/util/compat_test.py @@ -33,7 +33,7 @@ from absl.testing import absltest -class CompatTest(googletest.TestCase): +class CompatTest(absltest.TestCase): def testCompatValidEncoding(self): self.assertEqual(compat.as_bytes("hello", "utf8"), b"hello") @@ -48,4 +48,4 @@ def testCompatInvalidEncoding(self): if __name__ == "__main__": - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/core/python/util/naming_test.py b/model/orbax/experimental/model/core/python/util/naming_test.py index 3241be3c8..01d93b2da 100644 --- a/model/orbax/experimental/model/core/python/util/naming_test.py +++ b/model/orbax/experimental/model/core/python/util/naming_test.py @@ -18,7 +18,7 @@ from absl.testing import absltest -class NamingTest(googletest.TestCase): +class NamingTest(absltest.TestCase): def test_validate_node_name(self): self.assertTrue(naming.is_valid_node_name('correct_name_1534')) @@ -30,4 +30,4 @@ def test_validate_node_name(self): if __name__ == '__main__': - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/core/python/util/object_identity_test.py b/model/orbax/experimental/model/core/python/util/object_identity_test.py index 1bf1173ed..6f9ce7b65 100644 --- a/model/orbax/experimental/model/core/python/util/object_identity_test.py +++ b/model/orbax/experimental/model/core/python/util/object_identity_test.py @@ -32,7 +32,7 @@ from absl.testing import absltest -class ObjectIdentityWrapperTest(googletest.TestCase): +class ObjectIdentityWrapperTest(absltest.TestCase): def testWrapperNotEqualToWrapped(self): class SettableHash(object): @@ -63,7 +63,7 @@ def __hash__(self): bool(o in set([wrap1])) -class ObjectIdentitySetTest(googletest.TestCase): +class ObjectIdentitySetTest(absltest.TestCase): def testDifference(self): class Element(object): @@ -96,4 +96,4 @@ def testClear(self): if __name__ == '__main__': - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/jax2obm/jax_specific_info_test.py b/model/orbax/experimental/model/jax2obm/jax_specific_info_test.py index 29f05be41..0b7031cf1 100644 --- a/model/orbax/experimental/model/jax2obm/jax_specific_info_test.py +++ b/model/orbax/experimental/model/jax2obm/jax_specific_info_test.py @@ -172,4 +172,4 @@ def test_to_shlo_dtype_and_refinement_wrong_type(self, jax_dtype): if __name__ == '__main__': - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/jax2obm/main_lib_test.py b/model/orbax/experimental/model/jax2obm/main_lib_test.py index dacfc7a18..9b2eae59f 100644 --- a/model/orbax/experimental/model/jax2obm/main_lib_test.py +++ b/model/orbax/experimental/model/jax2obm/main_lib_test.py @@ -1205,4 +1205,4 @@ def get_mesh(): if __name__ == '__main__': - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/jax2obm/obm_to_jax_test.py b/model/orbax/experimental/model/jax2obm/obm_to_jax_test.py index 6a6948b42..d3a50c1ac 100644 --- a/model/orbax/experimental/model/jax2obm/obm_to_jax_test.py +++ b/model/orbax/experimental/model/jax2obm/obm_to_jax_test.py @@ -538,4 +538,4 @@ def get_mesh(): if __name__ == '__main__': - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/jax2obm/sharding_test.py b/model/orbax/experimental/model/jax2obm/sharding_test.py index 4ca2f18d9..3670e3587 100644 --- a/model/orbax/experimental/model/jax2obm/sharding_test.py +++ b/model/orbax/experimental/model/jax2obm/sharding_test.py @@ -22,7 +22,7 @@ from absl.testing import absltest -class ShardingTest(googletest.TestCase): +class ShardingTest(absltest.TestCase): def get_mesh(self): def _create_mesh(devices) -> jax.sharding.Mesh: @@ -182,4 +182,4 @@ def test_jax_mesh_to_obm_device_assignment_by_coords(self): if __name__ == '__main__': - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm_test.py b/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm_test.py index b5980851b..2daaef239 100644 --- a/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm_test.py +++ b/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm_test.py @@ -422,4 +422,4 @@ def tf_fn(a): if __name__ == "__main__": - googletest.main() + absltest.main() diff --git a/model/orbax/experimental/model/tf2obm/tf_function_def_to_obm_test.py b/model/orbax/experimental/model/tf2obm/tf_function_def_to_obm_test.py index 6218b7cbf..550c9d99a 100644 --- a/model/orbax/experimental/model/tf2obm/tf_function_def_to_obm_test.py +++ b/model/orbax/experimental/model/tf2obm/tf_function_def_to_obm_test.py @@ -19,7 +19,7 @@ from absl.testing import absltest -class TfFunctionDefToObmTest(googletest.TestCase): +class TfFunctionDefToObmTest(absltest.TestCase): def test_tf_concrete_function_to_unstructured_data_success(self): @@ -56,4 +56,4 @@ def tf_pow(a): if __name__ == "__main__": - googletest.main() + absltest.main()