diff --git a/tests/test_modifications.py b/tests/test_modifications.py
index 0e628f8..d8960aa 100644
--- a/tests/test_modifications.py
+++ b/tests/test_modifications.py
@@ -70,8 +70,9 @@ def _test_modification_consistency(
     our_logits = out.prediction_logits.detach().clone()
     our_grads = [param.grad.detach().clone() for param in model.albert.transformer.parameters()]
     model.zero_grad(set_to_none=True)
-    assert torch.allclose(ref_logits, our_logits, rtol=0, atol=1e-5), abs(ref_logits-our_logits).max()
-    with pytest.raises(AssertionError) if grad_fails else nullcontext():
-        for g_ref, g_our in zip(ref_grads, our_grads):
-            assert torch.allclose(g_ref, g_our, rtol=0, atol=1e-5), abs(g_ref-g_our).max()
+    assert torch.allclose(ref_logits, our_logits), abs(ref_logits-our_logits).max()
+
+    with pytest.raises(AssertionError) if grad_fails else nullcontext():
+        for g_ref, g_our in zip(ref_grads, our_grads):
+            assert torch.allclose(g_ref, g_our), abs(g_ref-g_our).max()