diff --git a/src/dxtb/_src/integral/base.py b/src/dxtb/_src/integral/base.py
index 36bc2be9..1290935c 100644
--- a/src/dxtb/_src/integral/base.py
+++ b/src/dxtb/_src/integral/base.py
@@ -31,7 +31,7 @@
 from dxtb import IndexHelper
 from dxtb._src.basis.bas import Basis
 from dxtb._src.param import Param
-from dxtb._src.typing import Literal, PathLike, Tensor, TensorLike
+from dxtb._src.typing import Literal, PathLike, Self, Tensor, TensorLike
 
 from .abc import IntegralABC
 from .utils import snorm
@@ -311,6 +311,24 @@ def normalize(self, norm: Tensor | None = None) -> None:
 
         self.matrix = einsum(einsum_str, self.matrix, norm, norm)
 
+    def normalize_gradient(self, norm: Tensor | None = None) -> None:
+        """
+        Normalize the gradient (changes ``self.gradient``).
+
+        Parameters
+        ----------
+        norm : Tensor | None
+            Overlap norm to normalize the gradient.
+        """
+        if norm is None:
+            if self.norm is not None:
+                norm = self.norm
+            else:
+                norm = snorm(self.matrix)
+
+        einsum_str = "...ijx,...i,...j->...ijx"
+        self.gradient = einsum(einsum_str, self.gradient, norm, norm)
+
     def to_pt(self, path: PathLike | None = None) -> None:
         """
         Save the integral matrix to a file.
@@ -326,6 +344,35 @@ def to_pt(self, path: PathLike | None = None) -> None:
 
         torch.save(self.matrix, path)
 
+    def to(self, device: torch.device) -> Self:
+        """
+        Returns a copy of the integral on the specified device ``device``.
+
+        This is essentially a wrapper around the :meth:`to` method of the
+        :class:`TensorLike` class, but explicitly also moves the integral
+        matrix, norm, and gradient.
+
+        Parameters
+        ----------
+        device : torch.device
+            Device to which all associated tensors should be moved.
+
+        Returns
+        -------
+        Self
+            A copy of the integral placed on the specified device.
+        """
+        if self._gradient is not None:
+            self._gradient = self._gradient.to(device=device)
+
+        if self._norm is not None:
+            self._norm = self._norm.to(device=device)
+
+        if self._matrix is not None:
+            self._matrix = self._matrix.to(device=device)
+
+        return super().to(device=device)
+
     @property
     def matrix(self) -> Tensor:
         if self._matrix is None:
diff --git a/src/dxtb/_src/integral/container.py b/src/dxtb/_src/integral/container.py
index 104f7bf8..6bff2aa9 100644
--- a/src/dxtb/_src/integral/container.py
+++ b/src/dxtb/_src/integral/container.py
@@ -198,10 +198,29 @@ def build_overlap(self, positions: Tensor, **kwargs: Any) -> Tensor:
 
         # move integral to the correct device...
         if self.mgr.force_cpu_for_libcint is True:
-            # ... but only if no other multipole integrals are required
+            # ...but only if no other multipole integrals are required
            if self.intlevel <= labels.INTLEVEL_HCORE:
                 self.overlap = self.overlap.to(device=self.device)
+                # DEVNOTE: This is a sanity check to avoid the following
+                # scenario: When the overlap is built on CPU (forced by
+                # libcint), it will be moved to the correct device after
+                # the last integral is built. Now, in case of a second
+                # call with an invalid cache, the overlap class is
+                # already on the correct device, but the matrix is not.
+                # Hence, the `to` method must be called on the matrix as
+                # well, which is handled in the custom `to` method of all
+                # integrals.
+                # Also make sure to pass the `force_cpu_for_libcint`
+                # flag when instantiating the integral classes.
+                assert self.overlap is not None
+                if self.overlap.device != self.overlap.matrix.device:
+                    raise RuntimeError(
+                        f"Device of '{self.overlap.label}' integral class "
+                        f"({self.overlap.device}) and its matrix "
+                        f"({self.overlap.matrix.device}) do not match."
+                    )
+
 
         logger.debug("Overlap integral: All finished.")
         return self.overlap.matrix
 
@@ -215,13 +234,22 @@ def grad_overlap(self, positions: Tensor, **kwargs) -> Tensor:
         self.mgr.setup_driver(positions, **kwargs)
 
         if self.overlap is None:
-            raise RuntimeError("No overlap integral provided.")
+            # pylint: disable=import-outside-toplevel
+            from .factory import new_overlap
+
+            self.overlap = new_overlap(
+                self.mgr.driver_type,
+                **self.dd,
+                **kwargs,
+            )
 
         logger.debug("Overlap gradient: Start.")
-        grad = self.overlap.get_gradient(self.mgr.driver, **kwargs)
+        self.overlap.get_gradient(self.mgr.driver, **kwargs)
+        self.overlap.gradient = self.overlap.gradient.to(self.device)
+        self.overlap.normalize_gradient()
         logger.debug("Overlap gradient: All finished.")
-        return grad.to(self.device)
+        return self.overlap.gradient.to(self.device)
 
     # dipole
 
@@ -282,7 +310,10 @@ def build_dipole(self, positions: Tensor, shift: bool = True, **kwargs: Any):
 
         # move integral to the correct device, but only if no other multipole
         # integrals are required
-        if self.mgr.force_cpu_for_libcint and self.intlevel <= labels.INTLEVEL_DIPOLE:
+        if (
+            self.mgr.force_cpu_for_libcint is True
+            and self.intlevel <= labels.INTLEVEL_DIPOLE
+        ):
             self.dipole = self.dipole.to(device=self.device)
             self.overlap = self.overlap.to(device=self.device)
 
@@ -378,7 +409,7 @@ def build_quadrupole(
         # Finally, we move the integral to the correct device, but only if
         # no other multipole integrals are required.
         if (
-            self.mgr.force_cpu_for_libcint
+            self.mgr.force_cpu_for_libcint is True
             and self.intlevel <= labels.INTLEVEL_QUADRUPOLE
         ):
             self.overlap = self.overlap.to(self.device)
@@ -413,11 +444,12 @@ def checks(self) -> None:
                     f"Data type of '{cls.label}' integral ({cls.dtype}) and "
                     f"integral container ({self.dtype}) do not match."
                 )
-            if cls.device != self.device:
-                raise RuntimeError(
-                    f"Device of '{cls.label}' integral ({cls.device}) and "
-                    f"integral container ({self.device}) do not match."
-                )
+            if self.mgr.force_cpu_for_libcint is False:
+                if cls.device != self.device:
+                    raise RuntimeError(
+                        f"Device of '{cls.label}' integral ({cls.device}) and "
+                        f"integral container ({self.device}) do not match."
+                    )
 
             if name != "hcore":
                 assert not isinstance(cls, BaseHamiltonian)
diff --git a/src/dxtb/_src/integral/driver/libcint/overlap.py b/src/dxtb/_src/integral/driver/libcint/overlap.py
index 7ba09c44..28f8ec00 100644
--- a/src/dxtb/_src/integral/driver/libcint/overlap.py
+++ b/src/dxtb/_src/integral/driver/libcint/overlap.py
@@ -23,6 +23,7 @@
 
 from __future__ import annotations
 
+import torch
 from tad_mctc.batch import pack
 from tad_mctc.math import einsum
 
@@ -41,11 +42,11 @@ class OverlapLibcint(OverlapIntegral, IntegralLibcint):
     """
     Overlap integral from atomic orbitals.
-    Use the :meth:`.build` method to calculate the overlap integral. The
+    Use the :meth:`build` method to calculate the overlap integral. The
     returned matrix uses a custom autograd function to calculate the
     backward pass with the analytical gradient.
 
     For the full gradient, i.e., a matrix of shape ``(..., norb, norb, 3)``,
-    the :meth:`.get_gradient` method should be used.
+    the :meth:`get_gradient` method should be used.
""" def build(self, driver: IntDriverLibcint) -> Tensor: @@ -64,23 +65,12 @@ def build(self, driver: IntDriverLibcint) -> Tensor: """ super().checks(driver) - def fcn(driver: libcint.LibcintWrapper) -> tuple[Tensor, Tensor]: - s = libcint.overlap(driver) - norm = snorm(s) - - return s, norm - # batched mode if driver.ihelp.batch_mode > 0: assert isinstance(driver.drv, list) - slist = [] - nlist = [] - - for d in driver.drv: - mat, norm = fcn(d) - slist.append(mat) - nlist.append(norm) + slist = [libcint.overlap(d) for d in driver.drv] + nlist = [snorm(s) for s in slist] self.norm = pack(nlist) self.matrix = pack(slist) @@ -89,7 +79,8 @@ def fcn(driver: libcint.LibcintWrapper) -> tuple[Tensor, Tensor]: # single mode assert isinstance(driver.drv, libcint.LibcintWrapper) - self.matrix, self.norm = fcn(driver.drv) + self.matrix = libcint.overlap(driver.drv) + self.norm = snorm(self.matrix) return self.matrix def get_gradient(self, driver: IntDriverLibcint) -> Tensor: @@ -108,25 +99,17 @@ def get_gradient(self, driver: IntDriverLibcint) -> Tensor: """ super().checks(driver) - def fcn(driver: libcint.LibcintWrapper, norm: Tensor) -> Tensor: + # build norm if not already available + if self.norm is None: + self.build(driver) + + def fcn(driver: libcint.LibcintWrapper) -> Tensor: # (3, norb, norb) grad = libcint.int1e("ipovlp", driver) - if self.normalize is False: - return -einsum("...xij->...ijx", grad) - - # normalize and move xyz dimension to last, which is required for - # the reduction (only works with extra dimension in last) - return -einsum("...xij,...i,...j->...ijx", grad, norm, norm) - - # build norm if not already available - if self.norm is None: - if driver.ihelp.batch_mode > 0: - assert isinstance(driver.drv, list) - self.norm = pack([snorm(libcint.overlap(d)) for d in driver.drv]) - else: - assert isinstance(driver.drv, libcint.LibcintWrapper) - self.norm = snorm(libcint.overlap(driver.drv)) + # Move xyz dimension to last, which is required for the + # reduction (only works with extra dimension in last) + return -einsum("...xij->...ijx", grad) # batched mode if driver.ihelp.batch_mode > 0: @@ -137,24 +120,12 @@ def fcn(driver: libcint.LibcintWrapper, norm: Tensor) -> Tensor: ) if driver.ihelp.batch_mode == 1: - # pylint: disable=import-outside-toplevel - from tad_mctc.batch import deflate - - self.grad = pack( - [ - fcn(driver, deflate(norm)) - for driver, norm in zip(driver.drv, self.norm) - ] - ) - return self.grad + self.gradient = pack([fcn(d) for d in driver.drv]) + return self.gradient + elif driver.ihelp.batch_mode == 2: - self.grad = pack( - [ - fcn(driver, norm) # no deflating here - for driver, norm in zip(driver.drv, self.norm) - ] - ) - return self.grad + self.gradient = torch.stack([fcn(d) for d in driver.drv]) + return self.gradient raise ValueError(f"Unknown batch mode '{driver.ihelp.batch_mode}'.") @@ -165,5 +136,6 @@ def fcn(driver: libcint.LibcintWrapper, norm: Tensor) -> Tensor: "driver instance itself seems to be batched." 
             )
 
-        self.grad = fcn(driver.drv, self.norm)
-        return self.grad
+        self.gradient = fcn(driver.drv)
+        return self.gradient
diff --git a/src/dxtb/_src/integral/driver/manager.py b/src/dxtb/_src/integral/driver/manager.py
index 8db152fc..d7709a74 100644
--- a/src/dxtb/_src/integral/driver/manager.py
+++ b/src/dxtb/_src/integral/driver/manager.py
@@ -28,7 +28,6 @@
 import torch
 
 from dxtb import IndexHelper, labels
-from dxtb._src.constants import labels
 from dxtb._src.param import Param
 from dxtb._src.typing import TYPE_CHECKING, Any, Tensor, TensorLike
 
@@ -56,8 +55,8 @@ def __init__(
         _driver: IntDriver | None = None,
         device: torch.device | None = None,
         dtype: torch.dtype | None = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> None:
         super().__init__(device=device, dtype=dtype)
 
         # per default, libcint is run on the CPU
@@ -87,7 +86,7 @@ def create_driver(self, numbers: Tensor, par: Param, ihelp: IndexHelper) -> None
             # pylint: disable=import-outside-toplevel
             from .libcint import IntDriverLibcint as _IntDriver
 
-            if self.force_cpu_for_libcint:
+            if self.force_cpu_for_libcint is True:
                 device = torch.device("cpu")
                 numbers = numbers.to(device=device)
                 ihelp = ihelp.to(device=device)
@@ -99,11 +98,12 @@ def create_driver(self, numbers: Tensor, par: Param, ihelp: IndexHelper) -> None
         elif self.driver_type == labels.INTDRIVER_AUTOGRAD:
             # pylint: disable=import-outside-toplevel
             from .pytorch import IntDriverPytorchNoAnalytical as _IntDriver
+
         else:
             raise ValueError(f"Unknown integral driver '{self.driver_type}'.")
 
         self.driver = _IntDriver(
-            numbers, par, ihelp, device=self.device, dtype=self.dtype
+            numbers, par, ihelp, device=ihelp.device, dtype=self.dtype
         )
 
     def setup_driver(self, positions: Tensor, **kwargs: Any) -> None:
diff --git a/src/dxtb/_src/integral/driver/pytorch/overlap.py b/src/dxtb/_src/integral/driver/pytorch/overlap.py
index 7511fcad..1a4ceba3 100644
--- a/src/dxtb/_src/integral/driver/pytorch/overlap.py
+++ b/src/dxtb/_src/integral/driver/pytorch/overlap.py
@@ -121,11 +121,11 @@ def get_gradient(self, driver: BaseIntDriverPytorch) -> Tensor:
         super().checks(driver)
 
         if driver.ihelp.batch_mode > 0:
-            self.grad = self._batch(driver.eval_ovlp_grad, driver)
+            self.gradient = self._batch(driver.eval_ovlp_grad, driver)
         else:
-            self.grad = self._single(driver.eval_ovlp_grad, driver)
+            self.gradient = self._single(driver.eval_ovlp_grad, driver)
 
-        return self.grad
+        return self.gradient
 
     def _single(self, fcn: OverlapFunction, driver: BaseIntDriverPytorch) -> Tensor:
         if not isinstance(driver, BaseIntDriverPytorch):
diff --git a/src/dxtb/_src/integral/wrappers.py b/src/dxtb/_src/integral/wrappers.py
index 0f80d8eb..6b97b742 100644
--- a/src/dxtb/_src/integral/wrappers.py
+++ b/src/dxtb/_src/integral/wrappers.py
@@ -129,7 +129,7 @@ def hcore(numbers: Tensor, positions: Tensor, par: Param, **kwargs: Any) -> Tens
         raise ValueError(f"Unknown Hamiltonian type '{name}'.")
 
     ovlp = overlap(numbers, positions, par)
-    return h0.build(positions, ovlp)
+    return h0.build(positions, ovlp.to(h0.device))
 
 
 def overlap(numbers: Tensor, positions: Tensor, par: Param, **kwargs: Any) -> Tensor:
@@ -242,19 +242,19 @@ def _integral(
     driver_name = kwargs.pop("driver", labels.INTDRIVER_LIBCINT)
 
     # setup driver for integral calculation
-    drv_mngr = DriverManager(driver_name, **dd)
-    drv_mngr.create_driver(numbers, par, ihelp)
-    drv_mngr.driver.setup(positions)
+    drv_mgr = DriverManager(driver_name, **dd)
+    drv_mgr.create_driver(numbers, par, ihelp)
+    drv_mgr.driver.setup(positions)
 
     ###########
     # Overlap #
     ###########
 
     if integral_type == "_overlap":
-        integral = new_overlap(drv_mngr.driver_type, **dd, **kwargs)
+        integral = new_overlap(drv_mgr.driver_type, **dd, **kwargs)
 
         # actual integral calculation
-        integral.build(drv_mngr.driver)
+        integral.build(drv_mgr.driver)
 
         if normalize is True:
             integral.normalize(integral.norm)
@@ -266,19 +266,19 @@ def _integral(
     #############
 
     # multipole integrals require the overlap for normalization
-    ovlp = new_overlap(drv_mngr.driver_type, **dd, **kwargs)
+    ovlp = new_overlap(drv_mgr.driver_type, **dd, **kwargs)
     if ovlp._matrix is None or ovlp.norm is None:
-        ovlp.build(drv_mngr.driver)
+        ovlp.build(drv_mgr.driver)
 
     if integral_type == "_dipole":
-        integral = new_dipint(driver=drv_mngr.driver_type, **dd, **kwargs)
+        integral = new_dipint(driver=drv_mgr.driver_type, **dd, **kwargs)
     elif integral_type == "_quadrupole":
-        integral = new_quadint(driver=drv_mngr.driver_type, **dd, **kwargs)
+        integral = new_quadint(driver=drv_mgr.driver_type, **dd, **kwargs)
     else:
         raise ValueError(f"Unknown integral type '{integral_type}'.")
 
     # actual integral calculation
-    integral.build(drv_mngr.driver)
+    integral.build(drv_mgr.driver)
 
     if normalize is True:
         integral.normalize(ovlp.norm)
diff --git a/test/test_interaction/test_grad.py b/test/test_interaction/test_grad.py
index c9d7e8e3..38403a9e 100644
--- a/test/test_interaction/test_grad.py
+++ b/test/test_interaction/test_grad.py
@@ -60,7 +60,7 @@ def gradchecker(
     def func(p: Tensor) -> Tensor:
         icaches = ilist.get_cache(numbers=numbers, positions=p, ihelp=ihelp)
-        charges = get_guess(numbers, positions, chrg, ihelp)
+        charges = get_guess(numbers, p, chrg, ihelp)
         return ilist.get_energy(charges, icaches, ihelp)
 
     return func, pos
@@ -120,10 +120,10 @@ def gradchecker_batch(
 
     def func(p: Tensor) -> Tensor:
         icaches = ilist.get_cache(numbers=numbers, positions=p, ihelp=ihelp)
-        charges = get_guess(numbers, positions, chrg, ihelp)
+        charges = get_guess(numbers, p, chrg, ihelp)
        return ilist.get_energy(charges, icaches, ihelp)
 
-    return func, positions
+    return func, pos
 
 
 @pytest.mark.grad
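
Note on the new `normalize_gradient` method: it applies the same overlap-norm scaling as the matrix normalization, but with the Cartesian dimension kept in the last axis of the ``(..., norb, norb, 3)`` gradient. A minimal standalone sketch of the einsum pattern it uses (plain ``torch`` with dummy shapes, no dxtb imports):

```python
import torch

# Stand-ins for the overlap gradient (norb, norb, 3) and the overlap norm (norb,).
norb = 4
grad = torch.rand(norb, norb, 3)
norm = torch.rand(norb)

# The pattern from `normalize_gradient`: scale each (i, j) block by norm[i] * norm[j],
# leaving the trailing xyz dimension untouched.
normalized = torch.einsum("...ijx,...i,...j->...ijx", grad, norm, norm)

# Equivalent explicit broadcasting over the last (xyz) dimension.
expected = grad * norm[:, None, None] * norm[None, :, None]
assert torch.allclose(normalized, expected)
```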