diff --git a/src/dxtb/_src/calculators/types/energy.py b/src/dxtb/_src/calculators/types/energy.py index 7eb105ce3..bce69b065 100644 --- a/src/dxtb/_src/calculators/types/energy.py +++ b/src/dxtb/_src/calculators/types/energy.py @@ -325,15 +325,31 @@ def singlepoint( if kwargs.get("store_hcore", copts.hcore): self.cache["hcore"] = self.integrals.hcore + else: + if self.integrals.hcore is not None: + if self.integrals.hcore.requires_grad is False: + self.integrals.hcore.clear() if kwargs.get("store_overlap", copts.overlap): self.cache["overlap"] = self.integrals.overlap + else: + if self.integrals.overlap is not None: + if self.integrals.overlap.requires_grad is False: + self.integrals.overlap.clear() if kwargs.get("store_dipole", copts.dipole): self.cache["dipint"] = self.integrals.dipole + else: + if self.integrals.dipole is not None: + if self.integrals.dipole.requires_grad is False: + self.integrals.dipole.clear() if kwargs.get("store_quadrupole", copts.quadrupole): self.cache["quadint"] = self.integrals.quadrupole + else: + if self.integrals.quadrupole is not None: + if self.integrals.quadrupole.requires_grad is False: + self.integrals.quadrupole.clear() self._ncalcs += 1 return result diff --git a/src/dxtb/_src/components/interactions/coulomb/secondorder.py b/src/dxtb/_src/components/interactions/coulomb/secondorder.py index d638e1b40..5f8c46df3 100644 --- a/src/dxtb/_src/components/interactions/coulomb/secondorder.py +++ b/src/dxtb/_src/components/interactions/coulomb/secondorder.py @@ -32,11 +32,11 @@ # Define atomic numbers, positions, and charges numbers = torch.tensor([14, 1, 1, 1, 1]) positions = torch.tensor([ - [0.00000000000000, -0.00000000000000, 0.00000000000000], - [1.61768389755830, 1.61768389755830, -1.61768389755830], + [+0.00000000000000, -0.00000000000000, +0.00000000000000], + [+1.61768389755830, +1.61768389755830, -1.61768389755830], [-1.61768389755830, -1.61768389755830, -1.61768389755830], - [1.61768389755830, -1.61768389755830, 1.61768389755830], - [-1.61768389755830, 1.61768389755830, 1.61768389755830], + [+1.61768389755830, -1.61768389755830, +1.61768389755830], + [-1.61768389755830, +1.61768389755830, +1.61768389755830], ]) q = torch.tensor([ diff --git a/src/dxtb/_src/components/interactions/coulomb/thirdorder.py b/src/dxtb/_src/components/interactions/coulomb/thirdorder.py index c10a3738c..483d0e8d6 100644 --- a/src/dxtb/_src/components/interactions/coulomb/thirdorder.py +++ b/src/dxtb/_src/components/interactions/coulomb/thirdorder.py @@ -33,11 +33,11 @@ # Define atomic numbers and their positions numbers = torch.tensor([14, 1, 1, 1, 1]) positions = torch.tensor([ - [0.00000000000000, -0.00000000000000, 0.00000000000000], - [1.61768389755830, 1.61768389755830, -1.61768389755830], + [+0.00000000000000, -0.00000000000000, +0.00000000000000], + [+1.61768389755830, +1.61768389755830, -1.61768389755830], [-1.61768389755830, -1.61768389755830, -1.61768389755830], - [1.61768389755830, -1.61768389755830, 1.61768389755830], - [-1.61768389755830, 1.61768389755830, 1.61768389755830], + [+1.61768389755830, -1.61768389755830, +1.61768389755830], + [-1.61768389755830, +1.61768389755830, +1.61768389755830], ]) # Atomic charges diff --git a/src/dxtb/_src/components/interactions/solvation/alpb.py b/src/dxtb/_src/components/interactions/solvation/alpb.py index f92e45899..7682f2d8b 100644 --- a/src/dxtb/_src/components/interactions/solvation/alpb.py +++ b/src/dxtb/_src/components/interactions/solvation/alpb.py @@ -29,11 +29,11 @@ numbers = torch.tensor([14, 1, 1, 1, 1]) positions = torch.tensor([ - [0.00000000000000, -0.00000000000000, 0.00000000000000], - [1.61768389755830, 1.61768389755830, -1.61768389755830], + [+0.00000000000000, -0.00000000000000, +0.00000000000000], + [+1.61768389755830, +1.61768389755830, -1.61768389755830], [-1.61768389755830, -1.61768389755830, -1.61768389755830], - [1.61768389755830, -1.61768389755830, 1.61768389755830], - [-1.61768389755830, 1.61768389755830, 1.61768389755830], + [+1.61768389755830, -1.61768389755830, +1.61768389755830], + [-1.61768389755830, +1.61768389755830, +1.61768389755830], ]) charges = torch.tensor([ -8.41282505804719e-2, diff --git a/src/dxtb/_src/components/interactions/solvation/born.py b/src/dxtb/_src/components/interactions/solvation/born.py index 044bc0ede..c5326f3db 100644 --- a/src/dxtb/_src/components/interactions/solvation/born.py +++ b/src/dxtb/_src/components/interactions/solvation/born.py @@ -31,11 +31,11 @@ # Define atomic numbers and positions of the atoms numbers = torch.tensor([14, 1, 1, 1, 1]) positions = torch.tensor([ - [0.00000000000000, -0.00000000000000, 0.00000000000000], - [1.61768389755830, 1.61768389755830, -1.61768389755830], + [+0.00000000000000, -0.00000000000000, +0.00000000000000], + [+1.61768389755830, +1.61768389755830, -1.61768389755830], [-1.61768389755830, -1.61768389755830, -1.61768389755830], - [1.61768389755830, -1.61768389755830, 1.61768389755830], - [-1.61768389755830, 1.61768389755830, 1.61768389755830], + [+1.61768389755830, -1.61768389755830, +1.61768389755830], + [-1.61768389755830, +1.61768389755830, +1.61768389755830], ]) # Calculate the Born radii for the given atomic configuration diff --git a/src/dxtb/_src/integral/base.py b/src/dxtb/_src/integral/base.py index 8e81cf10f..5a96e5c4f 100644 --- a/src/dxtb/_src/integral/base.py +++ b/src/dxtb/_src/integral/base.py @@ -187,15 +187,12 @@ def setup(self, positions: Tensor, **kwargs) -> None: """ def __str__(self) -> str: - dict_repr = [] - for key, value in self.__dict__.items(): - if isinstance(value, Tensor): - value_repr = f"{value.shape}" - else: - value_repr = repr(value) - dict_repr.append(f" {key}: {value_repr}") - dict_str = "{\n" + ",\n".join(dict_repr) + "\n}" - return f"{self.__class__.__name__}({dict_str})" + return ( + f"{self.__class__.__name__}(" + f"Family: {self.family}, " + f"Number of Atoms: {self.numbers.shape[-1]}, " + f"Setup?: {self.is_setup()})" + ) def __repr__(self) -> str: return str(self) @@ -328,6 +325,22 @@ def clear(self) -> None: self._norm = None self._gradient = None + @property + def requires_grad(self) -> bool: + """ + Check if any field of the integral class is requires gradient. + + Returns + ------- + bool + Flag for gradient requirement. + """ + for field in (self._matrix, self._gradient, self._norm): + if field is not None and field.requires_grad: + return True + + return False + def normalize(self, norm: Tensor | None = None) -> None: """ Normalize the integral (changes ``self.matrix``). @@ -401,6 +414,8 @@ def __str__(self) -> str: d["_matrix"] = self._matrix.shape if self._gradient is not None: d["_gradient"] = self._gradient.shape + if self._norm is not None: + d["_norm"] = self._norm.shape return f"{self.__class__.__name__}({d})" diff --git a/src/dxtb/_src/integral/driver/factory.py b/src/dxtb/_src/integral/driver/factory.py index 62bac47a9..5b4f0e5f8 100644 --- a/src/dxtb/_src/integral/driver/factory.py +++ b/src/dxtb/_src/integral/driver/factory.py @@ -57,7 +57,9 @@ def new_driver( return new_driver_pytorch(numbers, par, device=device, dtype=dtype) if name == labels.INTDRIVER_AUTOGRAD: - return new_driver_pytorch2(numbers, par, device=device, dtype=dtype) + return new_driver_pytorch_no_analytical( + numbers, par, device=device, dtype=dtype + ) if name == labels.INTDRIVER_LEGACY: return new_driver_legacy(numbers, par, device=device, dtype=dtype) @@ -97,7 +99,7 @@ def new_driver_pytorch( return _IntDriver(numbers, par, ihelp, device=device, dtype=dtype) -def new_driver_pytorch2( +def new_driver_pytorch_no_analytical( numbers: Tensor, par: Param, device: torch.device | None = None, diff --git a/src/dxtb/_src/integral/types/quadrupole.py b/src/dxtb/_src/integral/types/quadrupole.py index 820fb3ebf..6d4e8d0f5 100644 --- a/src/dxtb/_src/integral/types/quadrupole.py +++ b/src/dxtb/_src/integral/types/quadrupole.py @@ -68,7 +68,7 @@ def traceless(self) -> Tensor: zx zy zz 6 7 8 6 7 8 """ - if self.matrix.shape[-3] != 9: + if self.matrix.ndim not in (3, 4) or self.matrix.shape[-3] != 9: raise RuntimeError( "Quadrupole integral must be a tensor tensor of shape " f"'(9, nao, nao)' but is {self.matrix.shape}." diff --git a/src/dxtb/_src/integral/utils.py b/src/dxtb/_src/integral/utils.py index d5daa7ffb..db00c899f 100644 --- a/src/dxtb/_src/integral/utils.py +++ b/src/dxtb/_src/integral/utils.py @@ -30,4 +30,5 @@ def snorm(ovlp: Tensor) -> Tensor: d = ovlp.diagonal(dim1=-1, dim2=-2) - return torch.where(d == 0.0, 0.0, torch.pow(d, -0.5)) + zero = torch.tensor(0.0, dtype=d.dtype, device=d.device) + return torch.where(d == 0.0, zero, torch.pow(d, -0.5)) diff --git a/src/dxtb/_src/integral/wrappers.py b/src/dxtb/_src/integral/wrappers.py index c36fc1446..0f80d8ebe 100644 --- a/src/dxtb/_src/integral/wrappers.py +++ b/src/dxtb/_src/integral/wrappers.py @@ -72,10 +72,8 @@ from dxtb._src.xtb.gfn1 import GFN1Hamiltonian from dxtb._src.xtb.gfn2 import GFN2Hamiltonian -from .driver.factory import new_driver from .driver.manager import DriverManager from .factory import new_dipint, new_overlap, new_quadint -from .types import DipoleIntegral, OverlapIntegral, QuadrupoleIntegral __all__ = ["hcore", "overlap", "dipint", "quadint"] @@ -100,20 +98,20 @@ def hcore(numbers: Tensor, positions: Tensor, par: Param, **kwargs: Any) -> Tens Raises ------ - TypeError + ValueError If the Hamiltonian parametrization does not contain meta data or if the meta data does not contain a name to select the correct Hamiltonian. ValueError If the Hamiltonian name is unknown. """ if par.meta is None: - raise TypeError( + raise ValueError( "Meta data of Hamiltonian parametrization must contain a name. " "Otherwise, the correct Hamiltonian cannot be selected internally." ) if par.meta.name is None: - raise TypeError( + raise ValueError( "The name field of the meta data of the Hamiltonian " "parametrization must contain a name. Otherwise, the correct " "Hamiltonian cannot be selected internally." diff --git a/src/dxtb/_src/xtb/base.py b/src/dxtb/_src/xtb/base.py index 31947cbdc..c982c4167 100644 --- a/src/dxtb/_src/xtb/base.py +++ b/src/dxtb/_src/xtb/base.py @@ -131,6 +131,13 @@ def clear(self) -> None: """ self._matrix = None + @property + def requires_grad(self) -> bool: + if self._matrix is None: + return False + + return self._matrix.requires_grad + def get_occupation(self) -> Tensor: """ Obtain the reference occupation numbers for each orbital. diff --git a/test/test_calculator/test_cache/test_integrals.py b/test/test_calculator/test_cache/test_integrals.py new file mode 100644 index 000000000..63d0aafbe --- /dev/null +++ b/test/test_calculator/test_cache/test_integrals.py @@ -0,0 +1,88 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test caching integrals. +""" + +from __future__ import annotations + +import pytest +import torch + +from dxtb._src.typing import DD, Tensor +from dxtb.calculators import GFN1Calculator + +from ...conftest import DEVICE + +opts = {"cache_enabled": True, "verbosity": 0} + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_overlap_deleted(dtype: torch.dtype) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + + numbers = torch.tensor([3, 1], device=DEVICE) + positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd) + + calc = GFN1Calculator(numbers, opts={"verbosity": 0}, **dd) + assert calc._ncalcs == 0 + + # overlap should not be cached + assert calc.opts.cache.store.overlap == False + + # one successful calculation + energy = calc.get_energy(positions) + assert calc._ncalcs == 1 + assert isinstance(energy, Tensor) + + # cache should be empty + assert calc.cache.overlap is None + + # ... but also the tensors in the calculator should be deleted + assert calc.integrals.overlap is not None + assert calc.integrals.overlap._matrix is None + assert calc.integrals.overlap._norm is None + assert calc.integrals.overlap._gradient is None + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_overlap_retained_for_grad(dtype: torch.dtype) -> None: + dd: DD = {"device": DEVICE, "dtype": dtype} + + numbers = torch.tensor([3, 1], device=DEVICE) + positions = torch.tensor( + [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd, requires_grad=True + ) + + calc = GFN1Calculator(numbers, opts={"verbosity": 0}, **dd) + assert calc._ncalcs == 0 + + # overlap should not be cached + assert calc.opts.cache.store.overlap == False + + # one successful calculation + energy = calc.get_energy(positions) + assert calc._ncalcs == 1 + assert isinstance(energy, Tensor) + + # cache should still be empty ... + assert calc.cache.overlap is None + + # ... but the tensors in the calculator should still be there + assert calc.integrals.overlap is not None + assert calc.integrals.overlap._matrix is not None + assert calc.integrals.overlap._norm is not None diff --git a/test/test_indexhelper/test_extra.py b/test/test_indexhelper/test_extra.py index 3455b04de..741242f47 100644 --- a/test/test_indexhelper/test_extra.py +++ b/test/test_indexhelper/test_extra.py @@ -168,7 +168,8 @@ def test_spread_unique_batch() -> None: x = torch.randn((nbatch, nat_u, 3), device=DEVICE) # pollutes CUDA memory - assert False + if DEVICE is not None: + assert False out = ihelp.spread_uspecies_to_atom(x, dim=-2, extra=True) assert out.shape == torch.Size((nbatch, nat, 3)) diff --git a/test/test_integrals/test_driver/__init__.py b/test/test_integrals/test_driver/__init__.py new file mode 100644 index 000000000..15d042be4 --- /dev/null +++ b/test/test_integrals/test_driver/__init__.py @@ -0,0 +1,16 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/test/test_integrals/test_driver/test_factory.py b/test/test_integrals/test_driver/test_factory.py new file mode 100644 index 000000000..0460db86a --- /dev/null +++ b/test/test_integrals/test_driver/test_factory.py @@ -0,0 +1,70 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test factories for integral drivers. +""" + +from __future__ import annotations + +import pytest +import torch + +from dxtb import GFN1_XTB, labels +from dxtb._src.integral.driver import factory +from dxtb._src.integral.driver.libcint import IntDriverLibcint +from dxtb._src.integral.driver.pytorch import ( + IntDriverPytorch, + IntDriverPytorchLegacy, + IntDriverPytorchNoAnalytical, +) + +numbers = torch.tensor([14, 1, 1, 1, 1]) + + +def test_fail() -> None: + with pytest.raises(ValueError): + factory.new_driver(-1, numbers, GFN1_XTB) + + +def test_driver() -> None: + cls = factory.new_driver(labels.INTDRIVER_LIBCINT, numbers, GFN1_XTB) + assert isinstance(cls, IntDriverLibcint) + + cls = factory.new_driver(labels.INTDRIVER_ANALYTICAL, numbers, GFN1_XTB) + assert isinstance(cls, IntDriverPytorch) + + cls = factory.new_driver(labels.INTDRIVER_AUTOGRAD, numbers, GFN1_XTB) + assert isinstance(cls, IntDriverPytorchNoAnalytical) + + cls = factory.new_driver(labels.INTDRIVER_LEGACY, numbers, GFN1_XTB) + assert isinstance(cls, IntDriverPytorchLegacy) + + +def test_libcint() -> None: + cls = factory.new_driver_libcint(numbers, GFN1_XTB) + assert isinstance(cls, IntDriverLibcint) + + +def test_pytorch() -> None: + cls = factory.new_driver_pytorch(numbers, GFN1_XTB) + assert isinstance(cls, IntDriverPytorch) + + cls = factory.new_driver_pytorch_no_analytical(numbers, GFN1_XTB) + assert isinstance(cls, IntDriverPytorchNoAnalytical) + + cls = factory.new_driver_legacy(numbers, GFN1_XTB) + assert isinstance(cls, IntDriverPytorchLegacy) diff --git a/test/test_integrals/test_driver.py b/test/test_integrals/test_driver/test_manager.py similarity index 97% rename from test/test_integrals/test_driver.py rename to test/test_integrals/test_driver/test_manager.py index d08141354..e819de65e 100644 --- a/test/test_integrals/test_driver.py +++ b/test/test_integrals/test_driver/test_manager.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Test overlap build from integral container. +Test the integral driver manager. """ from __future__ import annotations @@ -25,14 +25,13 @@ from dxtb import GFN1_XTB as par from dxtb import IndexHelper -from dxtb import integrals as ints from dxtb._src.constants.labels import INTDRIVER_ANALYTICAL, INTDRIVER_LIBCINT from dxtb._src.integral.driver.libcint import IntDriverLibcint from dxtb._src.integral.driver.manager import DriverManager from dxtb._src.integral.driver.pytorch import IntDriverPytorch from dxtb._src.typing import DD -from ..conftest import DEVICE +from ...conftest import DEVICE @pytest.mark.parametrize("dtype", [torch.float, torch.double]) diff --git a/test/test_integrals/test_driver/test_pytorch.py b/test/test_integrals/test_driver/test_pytorch.py new file mode 100644 index 000000000..d15b6443b --- /dev/null +++ b/test/test_integrals/test_driver/test_pytorch.py @@ -0,0 +1,139 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test the PyTorch integral driver. +""" + +from __future__ import annotations + +import pytest +import torch +from tad_mctc.batch import pack + +from dxtb import GFN1_XTB, IndexHelper +from dxtb._src.integral.driver.pytorch.driver import BaseIntDriverPytorch +from dxtb._src.typing import DD + +from ...conftest import DEVICE + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_single(dtype: torch.dtype): + dd: DD = {"dtype": dtype, "device": DEVICE} + + numbers = torch.tensor([3, 1], device=DEVICE) + positions = torch.zeros((2, 3), **dd) + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) + + drv = BaseIntDriverPytorch(numbers, GFN1_XTB, ihelp, **dd) + drv.setup(positions) + + assert drv._basis is not None + assert drv._positions is not None + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_batch_mode_fail(dtype: torch.dtype) -> None: + dd: DD = {"dtype": dtype, "device": DEVICE} + + numbers = torch.tensor([[3, 1], [1, 0]], device=DEVICE) + positions = torch.zeros((2, 2, 3), **dd) + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) + + # set to invalid value + ihelp.batch_mode = -99 + + drv = BaseIntDriverPytorch(numbers, GFN1_XTB, ihelp, **dd) + + with pytest.raises(ValueError): + drv.setup(positions) + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_batch_mode1(dtype: torch.dtype) -> None: + dd: DD = {"dtype": dtype, "device": DEVICE} + + numbers = torch.tensor([[3, 1], [1, 0]], device=DEVICE) + positions = pack( + [ + torch.tensor([[0.0, 0.0, +1.0], [0.0, 0.0, -1.0]], **dd), + torch.tensor([[0.0, 0.0, 2.0]], **dd), + ], + return_mask=False, + ) + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB, batch_mode=1) + + drv = BaseIntDriverPytorch(numbers, GFN1_XTB, ihelp, **dd) + drv.setup(positions) + + assert drv._basis_batch is not None + assert len(drv._basis_batch) == 2 + + assert drv._positions_batch is not None + assert len(drv._positions_batch) == 2 + + assert drv._positions_batch[0].shape == (2, 3) + assert (drv._positions_batch[0] == positions[0, :, :]).all() + assert drv._positions_batch[1].shape == (1, 3) + assert (drv._positions_batch[1] == positions[1, 0, :]).all() + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_batch_mode1_mask(dtype: torch.dtype) -> None: + dd: DD = {"dtype": dtype, "device": DEVICE} + + numbers = torch.tensor([[3, 1], [1, 0]], device=DEVICE) + positions, mask = pack( + [ + torch.tensor([[0.0, 0.0, +1.0], [0.0, 0.0, -1.0]], **dd), + torch.tensor([[0.0, 0.0, 2.0]], **dd), + ], + return_mask=True, + ) + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB, batch_mode=1) + + drv = BaseIntDriverPytorch(numbers, GFN1_XTB, ihelp, **dd) + drv.setup(positions, mask=mask) + + assert drv._basis_batch is not None + assert len(drv._basis_batch) == 2 + + assert drv._positions_batch is not None + assert len(drv._positions_batch) == 2 + + assert drv._positions_batch[0].shape == (2, 3) + assert (drv._positions_batch[0] == positions[0, :, :]).all() + assert drv._positions_batch[1].shape == (1, 3) + assert (drv._positions_batch[1] == positions[1, 0, :]).all() + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_batch_mode2(dtype: torch.dtype) -> None: + dd: DD = {"dtype": dtype, "device": DEVICE} + + numbers = torch.tensor([[3, 1], [1, 0]], device=DEVICE) + positions = torch.zeros((2, 2, 3), **dd) + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB, batch_mode=2) + + drv = BaseIntDriverPytorch(numbers, GFN1_XTB, ihelp, **dd) + drv.setup(positions) + + assert drv._basis_batch is not None + assert len(drv._basis_batch) == 2 + + assert drv._positions_batch is not None + assert len(drv._positions_batch) == 2 diff --git a/test/test_integrals/test_factory.py b/test/test_integrals/test_factory.py new file mode 100644 index 000000000..17e945a15 --- /dev/null +++ b/test/test_integrals/test_factory.py @@ -0,0 +1,191 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test factories for integral classes. +""" + +from __future__ import annotations + +import pytest +import torch + +from dxtb import GFN1_XTB, GFN2_XTB, IndexHelper, labels +from dxtb._src.integral import factory +from dxtb._src.xtb.gfn1 import GFN1Hamiltonian +from dxtb._src.xtb.gfn2 import GFN2Hamiltonian +from dxtb.integrals import factories, types + +numbers = torch.tensor([14, 1, 1, 1, 1]) +positions = torch.tensor( + [ + [+0.00000000000000, +0.00000000000000, +0.00000000000000], + [+1.61768389755830, +1.61768389755830, -1.61768389755830], + [-1.61768389755830, -1.61768389755830, -1.61768389755830], + [+1.61768389755830, -1.61768389755830, +1.61768389755830], + [-1.61768389755830, +1.61768389755830, +1.61768389755830], + ] +) + + +def test_fail() -> None: + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) + + with pytest.raises(ValueError): + par1 = GFN1_XTB.model_copy(deep=True) + par1.meta = None + factories.new_hcore(numbers, par1, ihelp) + + with pytest.raises(ValueError): + par1 = GFN1_XTB.model_copy(deep=True) + assert par1.meta is not None + + par1.meta.name = None + factories.new_hcore(numbers, par1, ihelp) + + with pytest.raises(ValueError): + par1 = GFN1_XTB.model_copy(deep=True) + assert par1.meta is not None + + par1.meta.name = "fail" + factories.new_hcore(numbers, par1, ihelp) + + +def test_hcore() -> None: + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) + h0_gfn1 = factory.new_hcore(numbers, GFN1_XTB, ihelp) + assert isinstance(h0_gfn1, GFN1Hamiltonian) + + ihelp = IndexHelper.from_numbers(numbers, GFN2_XTB) + h0_gfn2 = factory.new_hcore(numbers, GFN2_XTB, ihelp) + assert isinstance(h0_gfn2, GFN2Hamiltonian) + + +def test_hcore_gfn1() -> None: + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) + + h0 = factory.new_hcore_gfn1(numbers, ihelp) + assert isinstance(h0, GFN1Hamiltonian) + + h0 = factory.new_hcore_gfn1(numbers, ihelp, GFN1_XTB) + assert isinstance(h0, GFN1Hamiltonian) + + +def test_hcore_gfn2() -> None: + ihelp = IndexHelper.from_numbers(numbers, GFN2_XTB) + + h0 = factory.new_hcore_gfn2(numbers, ihelp) + assert isinstance(h0, GFN2Hamiltonian) + + h0 = factory.new_hcore_gfn2(numbers, ihelp, GFN2_XTB) + assert isinstance(h0, GFN2Hamiltonian) + + +################################################################################ + + +def test_overlap_fail() -> None: + with pytest.raises(ValueError): + factory.new_overlap(-1) + + +def test_overlap() -> None: + cls = factory.new_overlap(labels.INTDRIVER_LIBCINT) + assert isinstance(cls, types.OverlapIntegral) + + cls = factory.new_overlap(labels.INTDRIVER_ANALYTICAL) + assert isinstance(cls, types.OverlapIntegral) + + +def test_overlap_libcint() -> None: + cls = factory.new_overlap_libcint() + assert isinstance(cls, types.OverlapIntegral) + assert cls.device == torch.device("cpu") + + cls = factory.new_overlap_libcint(force_cpu_for_libcint=True) + assert isinstance(cls, types.OverlapIntegral) + assert cls.device == torch.device("cpu") + + +def test_overlap_pytorch() -> None: + cls = factory.new_overlap_pytorch() + assert isinstance(cls, types.OverlapIntegral) + + +################################################################################ + + +def test_dipint_fail() -> None: + with pytest.raises(ValueError): + factory.new_dipint(-1) + + +def test_dipint() -> None: + cls = factory.new_dipint(labels.INTDRIVER_LIBCINT) + assert isinstance(cls, types.DipoleIntegral) + + with pytest.raises(NotImplementedError): + cls = factory.new_dipint(labels.INTDRIVER_ANALYTICAL) + assert isinstance(cls, types.DipoleIntegral) + + +def test_dipint_libcint() -> None: + cls = factory.new_dipint_libcint() + assert isinstance(cls, types.DipoleIntegral) + assert cls.device == torch.device("cpu") + + cls = factory.new_dipint_libcint(force_cpu_for_libcint=True) + assert isinstance(cls, types.DipoleIntegral) + assert cls.device == torch.device("cpu") + + +def test_dipint_pytorch() -> None: + with pytest.raises(NotImplementedError): + cls = factory.new_dipint_pytorch() + assert isinstance(cls, types.DipoleIntegral) + + +################################################################################ + + +def test_quadint_fail() -> None: + with pytest.raises(ValueError): + factory.new_quadint(-1) + + +def test_quadint() -> None: + cls = factory.new_quadint(labels.INTDRIVER_LIBCINT) + assert isinstance(cls, types.QuadrupoleIntegral) + + with pytest.raises(NotImplementedError): + cls = factory.new_quadint(labels.INTDRIVER_ANALYTICAL) + assert isinstance(cls, types.QuadrupoleIntegral) + + +def test_quadint_libcint() -> None: + cls = factory.new_quadint_libcint() + assert isinstance(cls, types.QuadrupoleIntegral) + assert cls.device == torch.device("cpu") + + cls = factory.new_quadint_libcint(force_cpu_for_libcint=True) + assert isinstance(cls, types.QuadrupoleIntegral) + assert cls.device == torch.device("cpu") + + +def test_quadint_pytorch() -> None: + with pytest.raises(NotImplementedError): + cls = factory.new_quadint_pytorch() + assert isinstance(cls, types.QuadrupoleIntegral) diff --git a/test/test_integrals/test_types.py b/test/test_integrals/test_types.py index e7b244caa..fb7dbff4a 100644 --- a/test/test_integrals/test_types.py +++ b/test/test_integrals/test_types.py @@ -23,10 +23,19 @@ import pytest import torch -from dxtb import GFN1_XTB, IndexHelper -from dxtb.integrals.factories import new_hcore +from dxtb import GFN1_XTB, IndexHelper, labels +from dxtb.integrals.factories import new_dipint, new_hcore, new_quadint numbers = torch.tensor([14, 1, 1, 1, 1]) +positions = torch.tensor( + [ + [+0.00000000000000, -0.00000000000000, +0.00000000000000], + [+1.61768389755830, +1.61768389755830, -1.61768389755830], + [-1.61768389755830, -1.61768389755830, -1.61768389755830], + [+1.61768389755830, -1.61768389755830, +1.61768389755830], + [-1.61768389755830, +1.61768389755830, +1.61768389755830], + ] +) def test_fail() -> None: @@ -38,3 +47,24 @@ def test_fail() -> None: par1.meta.name = "fail" new_hcore(numbers, par1, ihelp) + + +def test_dipole_fail() -> None: + i = new_dipint(labels.INTDRIVER_LIBCINT) + + with pytest.raises(RuntimeError): + fake_ovlp = torch.eye(3, dtype=torch.float64) + i.shift_r0_rj(fake_ovlp, positions) + + +def test_quadrupole_fail() -> None: + i = new_quadint(labels.INTDRIVER_LIBCINT) + + with pytest.raises(RuntimeError): + fake_ovlp = torch.eye(3, dtype=torch.float64) + fake_r0 = torch.zeros(3, dtype=torch.float64) + i.shift_r0r0_rjrj(fake_r0, fake_ovlp, positions) + + with pytest.raises(RuntimeError): + i._matrix = torch.eye(3, dtype=torch.float64) + i.traceless() diff --git a/test/test_integrals/test_wrappers.py b/test/test_integrals/test_wrappers.py index b29572eeb..4bb734dbf 100644 --- a/test/test_integrals/test_wrappers.py +++ b/test/test_integrals/test_wrappers.py @@ -39,12 +39,12 @@ def test_fail() -> None: - with pytest.raises(TypeError): + with pytest.raises(ValueError): par1 = GFN1_XTB.model_copy(deep=True) par1.meta = None wrappers.hcore(numbers, positions, par1) - with pytest.raises(TypeError): + with pytest.raises(ValueError): par1 = GFN1_XTB.model_copy(deep=True) assert par1.meta is not None