Skip to content

Commit

Permalink
check exact match
Browse files Browse the repository at this point in the history
  • Loading branch information
justheuristic committed Mar 17, 2022
1 parent 4945cac commit e737a8f
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/test_modifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ def _test_modification_consistency(
our_logits = out.prediction_logits.detach().clone()
our_grads = [param.grad.detach().clone() for param in model.albert.transformer.parameters()]
model.zero_grad(set_to_none=True)
assert torch.allclose(ref_logits, our_logits, rtol=0, atol=1e-5), abs(ref_logits-our_logits).max()

with pytest.raises(AssertionError) if grad_fails else nullcontext():
for g_ref, g_our in zip(ref_grads, our_grads):
assert torch.allclose(g_ref, g_our, rtol=0, atol=1e-5), abs(g_ref-g_our).max()
assert torch.allclose(ref_logits, our_logits), abs(ref_logits-our_logits).max()

with pytest.raises(AssertionError) if grad_fails else nullcontext():
for g_ref, g_our in zip(ref_grads, our_grads):
assert torch.allclose(g_ref, g_our), abs(g_ref-g_our).max()

0 comments on commit e737a8f

Please sign in to comment.