NONDET_TOL (for CUDA tests)

marvinfriede committed Aug 25, 2024
1 parent d087c30 commit 9847814
Showing 5 changed files with 31 additions and 26 deletions.
3 changes: 3 additions & 0 deletions test/conftest.py
@@ -38,6 +38,9 @@
DEVICE: torch.device | None = None
"""Name of Device."""

NONDET_TOL = 1e-7
"""Tolerance for non-deterministic tests."""


# A bug in PyTorch 2.3.0 and 2.3.1 somehow requires manual import of
# `torch._dynamo` to avoid errors with functorch in custom backward
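Note (not part of the commit): `nondet_tol` is a keyword argument of `torch.autograd.gradcheck` and `gradgradcheck`. Its default of 0.0 requires repeated evaluations of the function at the same input to match exactly, which fails on CUDA when kernels use non-deterministic reductions; a small tolerance such as 1e-7 absorbs that jitter while still catching wrong analytical gradients. Below is a minimal, self-contained sketch of the mechanism; the toy `energy_like` function and the random positions are purely illustrative and not taken from dxtb.

```python
# Minimal sketch: relaxing gradcheck for non-deterministic (e.g. CUDA) kernels.
import torch
from torch.autograd import gradcheck

NONDET_TOL = 1e-7  # same value the commit adds to test/conftest.py


def energy_like(pos: torch.Tensor) -> torch.Tensor:
    # Smooth pairwise term built from squared distances (differentiable everywhere).
    r2 = (pos.unsqueeze(0) - pos.unsqueeze(1)).pow(2).sum(-1)
    return torch.exp(-r2).sum()


pos = torch.rand(4, 3, dtype=torch.double, requires_grad=True)

# nondet_tol bounds the allowed difference between repeated evaluations at the
# same input; 0.0 (the default) demands bit-identical results.
assert gradcheck(energy_like, (pos,), atol=1e-6, nondet_tol=NONDET_TOL)
```

In the hunks below, the same keyword is simply threaded through the existing `dgradcheck`/`dgradgradcheck` calls of the test suite.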
6 changes: 3 additions & 3 deletions test/test_coulomb/test_es2_atom.py
@@ -34,7 +34,7 @@
from dxtb._src.param.utils import get_elem_param
from dxtb._src.typing import DD, Tensor

from ..conftest import DEVICE
from ..conftest import DEVICE, NONDET_TOL
from .samples import samples

sample_list = ["MB16_43_01", "MB16_43_02", "SiH4_atom"]
@@ -127,7 +127,7 @@ def func(p: Tensor):
cache = es.get_cache(numbers, p, ihelp)
return es.get_atom_energy(qat, cache)

assert dgradcheck(func, pos)
assert dgradcheck(func, pos, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -157,4 +157,4 @@ def func(gexp: Tensor, hubbard: Tensor):
cache = es.get_cache(numbers, positions, ihelp)
return es.get_atom_energy(qat, cache)

assert dgradcheck(func, (gexp, hubbard))
assert dgradcheck(func, (gexp, hubbard), nondet_tol=NONDET_TOL)
14 changes: 7 additions & 7 deletions test/test_coulomb/test_grad_shell_param.py
@@ -31,7 +31,7 @@
from dxtb._src.param import get_elem_param
from dxtb._src.typing import DD, Callable, Tensor

from ..conftest import DEVICE
from ..conftest import DEVICE, NONDET_TOL
from .samples import samples

sample_list = ["LiH", "SiH4"] # "MB16_43_01" requires a lot of RAM
@@ -90,7 +90,7 @@ def test_grad_param(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradcheck_param(dtype, name)
assert dgradcheck(func, diffvars, atol=tol)
assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -103,7 +103,7 @@ def test_grad_param_large(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradcheck_param(dtype, name)
assert dgradcheck(func, diffvars, atol=tol)
assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -115,7 +115,7 @@ def test_gradgrad_param(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradcheck_param(dtype, name)
assert dgradgradcheck(func, diffvars, atol=tol)
assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -128,7 +128,7 @@ def test_gradgrad_param_large(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradcheck_param(dtype, name)
assert dgradgradcheck(func, diffvars, atol=tol)
assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


def gradcheck_param_batch(dtype: torch.dtype, name1: str, name2: str) -> tuple[
@@ -194,7 +194,7 @@ def test_grad_param_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradcheck_param_batch(dtype, name1, name2)
assert dgradcheck(func, diffvars, atol=tol)
assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -207,4 +207,4 @@ def test_gradgrad_param_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradcheck_param_batch(dtype, name1, name2)
assert dgradgradcheck(func, diffvars, atol=tol)
assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)
16 changes: 9 additions & 7 deletions test/test_coulomb/test_grad_shell_pos.py
@@ -31,7 +31,7 @@
from dxtb._src.param import get_elem_param
from dxtb._src.typing import DD, Callable, Tensor

from ..conftest import DEVICE
from ..conftest import DEVICE, NONDET_TOL
from .samples import samples

sample_list = ["LiH", "SiH4"]
@@ -92,7 +92,7 @@ def test_grad(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradchecker(dtype, name)
assert dgradcheck(func, diffvars, atol=tol)
assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -105,7 +105,7 @@ def test_grad_large(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradchecker(dtype, name)
assert dgradcheck(func, diffvars, atol=tol, fast_mode=True)
assert dgradcheck(func, diffvars, atol=tol, fast_mode=True, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -118,7 +118,7 @@ def test_gradgrad(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradchecker(dtype, name)
assert dgradgradcheck(func, diffvars, atol=tol)
assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -130,7 +130,9 @@ def test_gradgrad_large(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradchecker(dtype, name)
assert dgradgradcheck(func, diffvars, atol=tol, fast_mode=True)
assert dgradgradcheck(
func, diffvars, atol=tol, fast_mode=True, nondet_tol=NONDET_TOL
)


def gradchecker_batch(
@@ -196,7 +198,7 @@ def test_grad_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradchecker_batch(dtype, name1, name2)
assert dgradcheck(func, diffvars, atol=tol, nondet_tol=1e-7)
assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -209,4 +211,4 @@ def test_gradgrad_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradchecker_batch(dtype, name1, name2)
assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=1e-7)
assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)
18 changes: 9 additions & 9 deletions test/test_overlap/test_grad_pos.py
@@ -31,7 +31,7 @@
from dxtb._src.integral.driver.pytorch import OverlapPytorch as Overlap
from dxtb._src.typing import DD, Callable, Tensor

from ..conftest import DEVICE
from ..conftest import DEVICE, NONDET_TOL
from .samples import samples

slist = ["LiH", "H2O"]
@@ -73,7 +73,7 @@ def test_grad(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradchecker(dtype, name)
assert dgradcheck(func, diffvars, atol=tol)
assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -85,7 +85,7 @@ def test_grad_large(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradchecker(dtype, name)
assert dgradcheck(func, diffvars, atol=tol)
assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -97,7 +97,7 @@ def test_gradgrad(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradchecker(dtype, name)
assert dgradgradcheck(func, diffvars, atol=tol)
assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -109,7 +109,7 @@ def test_gradgrad_large(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradchecker(dtype, name)
assert dgradgradcheck(func, diffvars, atol=tol)
assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


def gradchecker_batch(
@@ -161,7 +161,7 @@ def test_grad_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradchecker_batch(dtype, name1, name2)
assert dgradcheck(func, diffvars, atol=tol)
assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -175,7 +175,7 @@ def test_grad_batch_large(dtype: torch.dtype, name1: str, name2: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradchecker_batch(dtype, name1, name2)
assert dgradcheck(func, diffvars, atol=tol)
assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -189,7 +189,7 @@ def test_gradgrad_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradchecker_batch(dtype, name1, name2)
assert dgradgradcheck(func, diffvars, atol=tol)
assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -203,4 +203,4 @@ def test_gradgrad_batch_large(dtype: torch.dtype, name1: str, name2: str) -> None:
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradchecker_batch(dtype, name1, name2)
assert dgradgradcheck(func, diffvars, atol=tol)
assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)
