Skip to content

Commit

Permalink
More tests
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Jan 8, 2025
1 parent 304b762 commit 284a8e5
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/dxtb/_src/components/interactions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ def get_energy(
self, charges: Charges, cache: InteractionCache, ihelp: IndexHelper
) -> Tensor:
"""
Compute the energy from the charges, all quantities are orbital-resolved.
Compute the energy from the charges, all quantities are
orbital-resolved.
Parameters
----------
Expand Down
16 changes: 16 additions & 0 deletions test/test_classical/test_dispersion/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,19 @@ def test_fail_too_many_parameters() -> None:

with pytest.raises(ValueError):
new_dispersion(torch.tensor(0.0), _par)


def test_d4_cache() -> None:
numbers = torch.tensor([3, 1])

_par2 = GFN2_XTB.model_copy(deep=True)
_par2.dispersion.d4.sc = False # type: ignore

disp = new_dispersion(numbers, _par2, torch.tensor(0.0))
assert disp is not None

_ = disp.get_cache(numbers=numbers)
assert disp.cache_is_latest((numbers.detach().clone(),))

_ = disp.get_cache(numbers=numbers)
assert disp.cache_is_latest((numbers.detach().clone(),))
9 changes: 9 additions & 0 deletions test/test_classical/test_halogen/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,12 @@ def test_change_device_fail() -> None:
# trying to use setter
with pytest.raises(AttributeError):
cls.device = "cpu"


def test_fail_requires_ihelp() -> None:
numbers = torch.tensor([3, 1])
cls = new_halogen(numbers, par)
assert cls is not None

with pytest.raises(ValueError):
cls.get_cache(numbers=numbers, ihelp=None)
9 changes: 9 additions & 0 deletions test/test_classical/test_repulsion/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,12 @@ def test_change_device_fail() -> None:
# trying to use setter
with pytest.raises(AttributeError):
cls.device = "cpu"


def test_fail_requires_ihelp() -> None:
numbers = torch.tensor([3, 1])
cls = new_repulsion(numbers, par)
assert cls is not None

with pytest.raises(ValueError):
cls.get_cache(numbers=numbers, ihelp=None)
2 changes: 1 addition & 1 deletion test/test_cli/test_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,5 @@ def test_entrypoint(

out, err = capsys.readouterr()
assert err == ""
assert out == ""
assert out == "", "No output should be printed. Leftover debug prints?"
assert len(caplog.text) == 0
82 changes: 82 additions & 0 deletions test/test_components/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test caches.
"""

from __future__ import annotations

import pytest
import torch

from dxtb import GFN1_XTB, GFN2_XTB, IndexHelper, Param
from dxtb._src.typing import Callable, Tensor
from dxtb.components.base import Classical, ComponentCache
from dxtb.components.dispersion import new_dispersion
from dxtb.components.halogen import new_halogen
from dxtb.components.repulsion import new_repulsion

from ..conftest import DEVICE


@pytest.mark.parametrize(
"comp_factory_par",
[
(new_repulsion, GFN1_XTB),
(new_halogen, GFN1_XTB),
(new_dispersion, GFN1_XTB),
],
)
def test_fail_overwritten_cache(
comp_factory_par: tuple[Callable[[Tensor, Param], Classical], Param]
) -> None:
numbers = torch.tensor([3, 1], device=DEVICE)
ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB)

comp_factory, par = comp_factory_par
comp = comp_factory(numbers, par)
assert comp is not None

# create cache
comp.cache_enable()
_ = comp.get_cache(numbers=numbers, ihelp=ihelp)

# manually overwrite cache
comp.cache = ComponentCache()

with pytest.raises(TypeError):
comp.get_cache(numbers=numbers, ihelp=ihelp)


def test_fail_overwritten_cache_d4() -> None:
numbers = torch.tensor([3, 1], device=DEVICE)

par = GFN2_XTB.model_copy(deep=True)
par.dispersion.d4.sc = False # type: ignore

d4 = new_dispersion(numbers, par, charge=torch.tensor(0.0, device=DEVICE))
assert d4 is not None

# create cache
d4.cache_enable()
_ = d4.get_cache(numbers=numbers)

# manually overwrite cache
d4.cache = ComponentCache()

with pytest.raises(TypeError):
d4.get_cache(numbers=numbers)
2 changes: 0 additions & 2 deletions test/test_external/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ def test_single(dtype: torch.dtype, name: str) -> None:
calc = Calculator(numbers, GFN1_XTB, interaction=[efield], opts=opts, **dd)

result = calc.singlepoint(positions, charges)
print(result.total)

res = result.total.sum(-1)
assert pytest.approx(ref.cpu(), abs=tol, rel=tol) == res.cpu()

Expand Down

0 comments on commit 284a8e5

Please sign in to comment.