Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

internal #1620

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,14 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install -y protobuf-compiler

pip install tensorflow

pip install .
protoc -I=. --python_out=. $(find orbax/experimental/model/core/ -name "*.proto")
echo "!!! After protoc"
ls -l orbax/experimental/model/core/protos/

pip install -e .
pip install .[testing] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
pip install -U jax jaxlib
Expand All @@ -181,3 +187,6 @@ jobs:
else
pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
fi
- name: Test with pytest
run: |
pytest orbax/experimental/model/core/python/
2 changes: 1 addition & 1 deletion model/orbax/experimental/model/core/python/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,4 @@ def fn_2(inputs):


if __name__ == "__main__":
googletest.main()
absltest.main()
8 changes: 4 additions & 4 deletions model/orbax/experimental/model/core/python/save_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
from typing import Tuple

from absl.testing import absltest
import jax
from jax import export as jax_export
import jax.numpy as jnp
Expand All @@ -33,8 +34,6 @@
from orbax.experimental.model.core.python.shlo_function import ShloFunction
import tensorflow as tf

from absl.testing import absltest

save = save_lib.save
Tensor = concrete_function.Tensor
Function = concrete_function.ConcreteFunction
Expand Down Expand Up @@ -70,10 +69,11 @@ def jax_spec_to_tensor_spec(x: jax.ShapeDtypeStruct) -> TensorSpec:
return TensorSpec(shape=x.shape, dtype=dtype_from_np_dtype(x.dtype))


class SaveTest(googletest.TestCase):
class SaveTest(absltest.TestCase):
# TODO(qidichen): We can move relevant parts of test from orbax/experimental/model/integration_tests/orbax_model_test.py here.
def test_save(self):
pass


if __name__ == '__main__':
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
DataType = types_pb2.DataType


class NodeBuilderTest(googletest.TestCase):
class NodeBuilderTest(absltest.TestCase):

def test_add_nodes(self):
builder = node_builder.NodeBuilder()
Expand Down Expand Up @@ -61,4 +61,4 @@ def test_add_nodes(self):


if __name__ == '__main__':
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
OpType = nodes.OpType


class NodesTest(googletest.TestCase):
class NodesTest(absltest.TestCase):

def assertProtoEqual(self, a, b):
compare.assertProto2Equal(self, a, b)
Expand Down Expand Up @@ -102,4 +102,4 @@ class FakeNode(Node):


if __name__ == '__main__':
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def jax_spec_to_tensor_spec(x: jax.ShapeDtypeStruct) -> TensorSpec:
)


class SavedModelBuilderTest(googletest.TestCase):
class SavedModelBuilderTest(absltest.TestCase):

def test_build_with_function_aliases(self):
@jax.jit
Expand Down Expand Up @@ -110,4 +110,4 @@ def f(xy, captured):


if __name__ == "__main__":
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,4 @@ def test_string_list(self):
self.assertEqual(s, back_tensor.np_array[n])

if __name__ == '__main__':
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,4 @@ def f():


if __name__ == "__main__":
googletest.main()
absltest.main()
4 changes: 2 additions & 2 deletions model/orbax/experimental/model/core/python/tree_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from absl.testing import absltest


class TreeUtilTest(googletest.TestCase):
class TreeUtilTest(absltest.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -79,4 +79,4 @@ def assert_int(x: Any) -> None:


if __name__ == "__main__":
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -238,4 +238,4 @@ def test_manifest_type_to_shlo_tensor_spec_tree(self):


if __name__ == "__main__":
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from absl.testing import absltest


class UnstructuredDataTest(googletest.TestCase):
class UnstructuredDataTest(absltest.TestCase):

def test_write_inlined_string_to_file(self):
proto = unstructured_data.UnstructuredData()
Expand Down Expand Up @@ -119,4 +119,4 @@ def test_maybe_write_location_pointer_to_file(self):


if __name__ == "__main__":
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from absl.testing import absltest


class CompatTest(googletest.TestCase):
class CompatTest(absltest.TestCase):

def testCompatValidEncoding(self):
self.assertEqual(compat.as_bytes("hello", "utf8"), b"hello")
Expand All @@ -48,4 +48,4 @@ def testCompatInvalidEncoding(self):


if __name__ == "__main__":
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from absl.testing import absltest


class NamingTest(googletest.TestCase):
class NamingTest(absltest.TestCase):

def test_validate_node_name(self):
self.assertTrue(naming.is_valid_node_name('correct_name_1534'))
Expand All @@ -30,4 +30,4 @@ def test_validate_node_name(self):


if __name__ == '__main__':
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from absl.testing import absltest


class ObjectIdentityWrapperTest(googletest.TestCase):
class ObjectIdentityWrapperTest(absltest.TestCase):

def testWrapperNotEqualToWrapped(self):
class SettableHash(object):
Expand Down Expand Up @@ -63,7 +63,7 @@ def __hash__(self):
bool(o in set([wrap1]))


class ObjectIdentitySetTest(googletest.TestCase):
class ObjectIdentitySetTest(absltest.TestCase):

def testDifference(self):
class Element(object):
Expand Down Expand Up @@ -96,4 +96,4 @@ def testClear(self):


if __name__ == '__main__':
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,4 @@ def test_to_shlo_dtype_and_refinement_wrong_type(self, jax_dtype):


if __name__ == '__main__':
googletest.main()
absltest.main()
2 changes: 1 addition & 1 deletion model/orbax/experimental/model/jax2obm/main_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,4 +1205,4 @@ def get_mesh():


if __name__ == '__main__':
googletest.main()
absltest.main()
2 changes: 1 addition & 1 deletion model/orbax/experimental/model/jax2obm/obm_to_jax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,4 +538,4 @@ def get_mesh():


if __name__ == '__main__':
googletest.main()
absltest.main()
4 changes: 2 additions & 2 deletions model/orbax/experimental/model/jax2obm/sharding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from absl.testing import absltest


class ShardingTest(googletest.TestCase):
class ShardingTest(absltest.TestCase):

def get_mesh(self):
def _create_mesh(devices) -> jax.sharding.Mesh:
Expand Down Expand Up @@ -182,4 +182,4 @@ def test_jax_mesh_to_obm_device_assignment_by_coords(self):


if __name__ == '__main__':
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -422,4 +422,4 @@ def tf_fn(a):


if __name__ == "__main__":
googletest.main()
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from absl.testing import absltest


class TfFunctionDefToObmTest(googletest.TestCase):
class TfFunctionDefToObmTest(absltest.TestCase):

def test_tf_concrete_function_to_unstructured_data_success(self):

Expand Down Expand Up @@ -56,4 +56,4 @@ def tf_pow(a):


if __name__ == "__main__":
googletest.main()
absltest.main()
Loading