Skip to content

Commit

Permalink
Stop writing msgpack file for new checkpoints and update empty nodes …
Browse files Browse the repository at this point in the history
…handling so that it no longer depends on this file.

PiperOrigin-RevId: 636665054
  • Loading branch information
cpgaffney1 authored and copybara-github committed Jun 28, 2024
1 parent d94550f commit 91b219a
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions init2winit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,18 @@ def testAppendPytree(self):
latest = checkpoint.load_latest_checkpoint(pytree_path, prefix='')
saved_pytrees = latest['pytree'] if latest else []
self.assertEqual(
pytrees, [saved_pytrees[str(i)] for i in range(len(saved_pytrees))])
pytrees, [saved_pytrees[str(i)] for i in range(len(saved_pytrees))]
)

def testArrayAppend(self):
"""Test appending to an array."""
np.testing.assert_allclose(
utils.array_append(jnp.array([1, 2, 3]), 4), jnp.array([1, 2, 3, 4]))
utils.array_append(jnp.array([1, 2, 3]), 4), jnp.array([1, 2, 3, 4])
)
np.testing.assert_allclose(
utils.array_append(jnp.array([[1, 2], [3, 4]]), jnp.array([5, 6])),
jnp.array([[1, 2], [3, 4], [5, 6]]))
jnp.array([[1, 2], [3, 4], [5, 6]]),
)

def testTreeNormSqL2(self):
"""Test computing the squared L2 norm of a pytree."""
Expand All @@ -115,9 +118,9 @@ def testTreeNormSqL2(self):

def testTreeSum(self):
"""Test computing the sum of a pytree."""
pytree = {'foo': 2*jnp.ones(10), 'baz': jnp.ones(20)}
pytree = {'foo': 2 * jnp.ones(10), 'baz': jnp.ones(20)}
self.assertEqual(utils.total_tree_sum(pytree), 40)


if __name__ == '__main__':
absltest.main()

0 comments on commit 91b219a

Please sign in to comment.