Skip to content

Commit

Permalink
Clean up tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 382711807
  • Loading branch information
hbq1 authored and OptaxDev committed Jul 2, 2021
1 parent 8f21d37 commit d2954c2
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions optax/_src/constrain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,42 +72,43 @@ def test_zero_nans(self):
opt_state = self.variant(opt.init)(params)
update_fn = self.variant(opt.update)

equality_comp = lambda a, b: bool(jnp.all(jnp.equal(a, b)))
chex.assert_tree_all_equal_comparator(equality_comp, opt_state,
(jnp.array(False),) * 3)
chex.assert_tree_all_close(opt_state,
constrain.ZeroNansState((jnp.array(False),) * 3))

# Check an upate with nans
grads_with_nans = (jnp.ones([3]),
jnp.array([1., float('nan'), float('nan')]),
jnp.array([float('nan'), 1., 1.]))
updates, opt_state = update_fn(grads_with_nans, opt_state)
chex.assert_tree_all_equal_comparator(
equality_comp, opt_state,
(jnp.array(False), jnp.array(True), jnp.array(True)))
chex.assert_tree_all_equal_comparator(
equality_comp, updates,
chex.assert_tree_all_close(
opt_state,
constrain.ZeroNansState(
(jnp.array(False), jnp.array(True), jnp.array(True))))
chex.assert_tree_all_close(
updates,
(jnp.ones([3]), jnp.array([1., 0., 0.]), jnp.array([0., 1., 1.])))

# Check an upate with nans and infs
grads_with_nans_infs = (jnp.ones([3]),
jnp.array([1., float('nan'), float('nan')]),
jnp.array([1., float('nan'),
float('nan')]),
jnp.array([float('inf'), 1., 1.]))
updates, opt_state = update_fn(grads_with_nans_infs, opt_state)
chex.assert_tree_all_equal_comparator(
equality_comp, opt_state,
(jnp.array(False), jnp.array(True), jnp.array(False)))
chex.assert_tree_all_equal_comparator(
equality_comp, updates,
(jnp.ones([3]), jnp.array([1., 0., 0.]),
jnp.array([float('inf'), 1., 1.])))
chex.assert_tree_all_close(
opt_state,
constrain.ZeroNansState(
(jnp.array(False), jnp.array(True), jnp.array(False))))
chex.assert_tree_all_close(updates, (jnp.ones([3]), jnp.array(
[1., 0., 0.]), jnp.array([float('inf'), 1., 1.])))

# Check an upate with only good values
grads = (jnp.ones([3]), jnp.ones([3]), jnp.ones([3]))
updates, opt_state = update_fn(grads, opt_state)
chex.assert_tree_all_equal_comparator(
equality_comp, opt_state,
(jnp.array(False), jnp.array(False), jnp.array(False)))
chex.assert_tree_all_equal_comparator(equality_comp, updates, grads)
chex.assert_tree_all_close(
opt_state,
constrain.ZeroNansState(
(jnp.array(False), jnp.array(False), jnp.array(False))))
chex.assert_tree_all_close(updates, grads)


if __name__ == '__main__':
Expand Down

0 comments on commit d2954c2

Please sign in to comment.