Commit

Fix moving devices
marvinfriede committed Aug 25, 2024
1 parent e76a4c5 commit d087c30
Showing 7 changed files with 136 additions and 85 deletions.
49 changes: 48 additions & 1 deletion src/dxtb/_src/integral/base.py
@@ -31,7 +31,7 @@
from dxtb import IndexHelper
from dxtb._src.basis.bas import Basis
from dxtb._src.param import Param
from dxtb._src.typing import Literal, PathLike, Tensor, TensorLike
from dxtb._src.typing import Literal, PathLike, Self, Tensor, TensorLike

from .abc import IntegralABC
from .utils import snorm
@@ -311,6 +311,24 @@ def normalize(self, norm: Tensor | None = None) -> None:

self.matrix = einsum(einsum_str, self.matrix, norm, norm)

def normalize_gradient(self, norm: Tensor | None = None) -> None:
"""
Normalize the gradient (changes ``self.gradient``).

Parameters
----------
norm : Tensor
Overlap norm to normalize the integral.
"""
if norm is None:
if self.norm is not None:
norm = self.norm
else:
norm = snorm(self.matrix)

einsum_str = "...ijx,...i,...j->...ijx"
self.gradient = einsum(einsum_str, self.gradient, norm, norm)
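
For reference, the contraction used here can be reproduced in isolation. A minimal sketch with illustrative names (not dxtb code), assuming a gradient of shape (norb, norb, 3) and a per-orbital norm vector of length norb:

import torch

norb = 4
grad = torch.rand(norb, norb, 3)  # overlap gradient, Cartesian axis last
norm = torch.rand(norb)           # per-orbital normalization factors

# scale every (i, j) block by norm[i] * norm[j]; the xyz axis is untouched
normalized = torch.einsum("...ijx,...i,...j->...ijx", grad, norm, norm)

# equivalent explicit broadcasting
reference = grad * norm[:, None, None] * norm[None, :, None]
assert torch.allclose(normalized, reference)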

def to_pt(self, path: PathLike | None = None) -> None:
"""
Save the integral matrix to a file.
@@ -326,6 +344,35 @@ def to_pt(self, path: PathLike | None = None) -> None:

torch.save(self.matrix, path)

def to(self, device: torch.device) -> Self:
"""
Returns a copy of the integral on the specified device ``device``.

This is essentially a wrapper around the :meth:`to` method of the
:class:`TensorLike` class, but explicitly also moves the integral
matrix.

Parameters
----------
device : torch.device
Device to which all associated tensors should be moved.

Returns
-------
Self
A copy of the integral placed on the specified device.
"""
if self._gradient is not None:
self._gradient = self._gradient.to(device=device)

if self._norm is not None:
self._norm = self._norm.to(device=device)

if self._matrix is not None:
self._matrix = self._matrix.to(device=device)

return super().to(device=device)
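
The pattern this override follows can be sketched with a small stand-in class (purely illustrative, not the dxtb implementation): every cached tensor is moved explicitly before the bookkeeping device is updated, so the class-level device and its tensors cannot drift apart.

import torch

class CachedTensorHolder:
    """Minimal stand-in for a container with cached tensors."""

    def __init__(self, matrix: torch.Tensor) -> None:
        self.device = matrix.device
        self._matrix = matrix
        self._norm = None
        self._gradient = None

    def to(self, device: torch.device) -> "CachedTensorHolder":
        # move every cached tensor that actually exists
        if self._matrix is not None:
            self._matrix = self._matrix.to(device=device)
        if self._norm is not None:
            self._norm = self._norm.to(device=device)
        if self._gradient is not None:
            self._gradient = self._gradient.to(device=device)
        # update the bookkeeping device last
        self.device = device
        return self

holder = CachedTensorHolder(torch.eye(3)).to(torch.device("cpu"))
assert holder.device == holder._matrix.device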

@property
def matrix(self) -> Tensor:
if self._matrix is None:
54 changes: 43 additions & 11 deletions src/dxtb/_src/integral/container.py
@@ -198,10 +198,29 @@ def build_overlap(self, positions: Tensor, **kwargs: Any) -> Tensor:

# move integral to the correct device...
if self.mgr.force_cpu_for_libcint is True:
# ... but only if no other multipole integrals are required
# ...but only if no other multipole integrals are required
if self.intlevel <= labels.INTLEVEL_HCORE:
self.overlap = self.overlap.to(device=self.device)

# DEVNOTE: This is a sanity check to avoid the following
# scenario: When the overlap is built on CPU (forced by
# libcint), it will be moved to the correct device after
# the last integral is built. Now, in case of a second
# call with an invalid cache, the overlap class already
# is on the correct device, but the matrix is not. Hence,
# the `to` method must be called on the matrix as well,
# which is handled in the custom `to` method of all
# integrals.
# Also make sure to pass the `force_cpu_for_libcint`
# flag when instantiating the integral classes.
assert self.overlap is not None
if self.overlap.device != self.overlap.matrix.device:
raise RuntimeError(
f"Device of '{self.overlap.label}' integral class "
f"({self.overlap.device}) and its matrix "
f"({self.overlap.matrix.device}) do not match."
)

logger.debug("Overlap integral: All finished.")

return self.overlap.matrix
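
To make the DEVNOTE scenario concrete: a self-contained sketch (illustrative only, not dxtb code) of how the class-level device can drift away from a cached matrix when a `to` method does not move the cache, which is exactly what the check above guards against:

import torch

class Cached:
    """A container whose `to` forgets the cached matrix."""

    def __init__(self) -> None:
        self.device = torch.device("cpu")
        self.matrix = torch.eye(2)  # built on the CPU, e.g. forced by libcint

    def to(self, device: torch.device) -> "Cached":
        self.device = device  # bookkeeping device is updated ...
        return self           # ... but self.matrix stays behind

# "meta" is only used here as a second device that needs no GPU
c = Cached().to(torch.device("meta"))
assert c.device != c.matrix.device  # the mismatch the RuntimeError reports
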
@@ -215,13 +234,22 @@ def grad_overlap(self, positions: Tensor, **kwargs) -> Tensor:
self.mgr.setup_driver(positions, **kwargs)

if self.overlap is None:
raise RuntimeError("No overlap integral provided.")
# pylint: disable=import-outside-toplevel
from .factory import new_overlap

self.overlap = new_overlap(
self.mgr.driver_type,
**self.dd,
**kwargs,
)

logger.debug("Overlap gradient: Start.")
grad = self.overlap.get_gradient(self.mgr.driver, **kwargs)
self.overlap.get_gradient(self.mgr.driver, **kwargs)
self.overlap.gradient = self.overlap.gradient.to(self.device)
self.overlap.normalize_gradient()
logger.debug("Overlap gradient: All finished.")

return grad.to(self.device)
return self.overlap.gradient.to(self.device)

# dipole

@@ -282,7 +310,10 @@ def build_dipole(self, positions: Tensor, shift: bool = True, **kwargs: Any):

# move integral to the correct device, but only if no other multipole
# integrals are required
if self.mgr.force_cpu_for_libcint and self.intlevel <= labels.INTLEVEL_DIPOLE:
if (
self.mgr.force_cpu_for_libcint is True
and self.intlevel <= labels.INTLEVEL_DIPOLE
):
self.dipole = self.dipole.to(device=self.device)
self.overlap = self.overlap.to(device=self.device)

@@ -378,7 +409,7 @@ def build_quadrupole(
# Finally, we move the integral to the correct device, but only if
# no other multipole integrals are required.
if (
self.mgr.force_cpu_for_libcint
self.mgr.force_cpu_for_libcint is True
and self.intlevel <= labels.INTLEVEL_QUADRUPOLE
):
self.overlap = self.overlap.to(self.device)
@@ -413,11 +444,12 @@ def checks(self) -> None:
f"Data type of '{cls.label}' integral ({cls.dtype}) and "
f"integral container ({self.dtype}) do not match."
)
if cls.device != self.device:
raise RuntimeError(
f"Device of '{cls.label}' integral ({cls.device}) and "
f"integral container ({self.device}) do not match."
)
if self.mgr.force_cpu_for_libcint is False:
if cls.device != self.device:
raise RuntimeError(
f"Device of '{cls.label}' integral ({cls.device}) and "
f"integral container ({self.device}) do not match."
)

if name != "hcore":
assert not isinstance(cls, BaseHamiltonian)
74 changes: 23 additions & 51 deletions src/dxtb/_src/integral/driver/libcint/overlap.py
@@ -23,6 +23,7 @@

from __future__ import annotations

import torch
from tad_mctc.batch import pack
from tad_mctc.math import einsum

@@ -41,11 +42,11 @@ class OverlapLibcint(OverlapIntegral, IntegralLibcint):
"""
Overlap integral from atomic orbitals.
Use the :meth:`.build` method to calculate the overlap integral. The
Use the :meth:`build` method to calculate the overlap integral. The
returned matrix uses a custom autograd function to calculate the
backward pass with the analytical gradient.
For the full gradient, i.e., a matrix of shape ``(..., norb, norb, 3)``,
the :meth:`.get_gradient` method should be used.
the :meth:`get_gradient` method should be used.
"""

def build(self, driver: IntDriverLibcint) -> Tensor:
@@ -64,23 +65,12 @@ def build(self, driver: IntDriverLibcint) -> Tensor:
"""
super().checks(driver)

def fcn(driver: libcint.LibcintWrapper) -> tuple[Tensor, Tensor]:
s = libcint.overlap(driver)
norm = snorm(s)

return s, norm

# batched mode
if driver.ihelp.batch_mode > 0:
assert isinstance(driver.drv, list)

slist = []
nlist = []

for d in driver.drv:
mat, norm = fcn(d)
slist.append(mat)
nlist.append(norm)
slist = [libcint.overlap(d) for d in driver.drv]
nlist = [snorm(s) for s in slist]

self.norm = pack(nlist)
self.matrix = pack(slist)
Expand All @@ -89,7 +79,8 @@ def fcn(driver: libcint.LibcintWrapper) -> tuple[Tensor, Tensor]:
# single mode
assert isinstance(driver.drv, libcint.LibcintWrapper)

self.matrix, self.norm = fcn(driver.drv)
self.matrix = libcint.overlap(driver.drv)
self.norm = snorm(self.matrix)
return self.matrix

def get_gradient(self, driver: IntDriverLibcint) -> Tensor:
@@ -108,25 +99,17 @@ def get_gradient(self, driver: IntDriverLibcint) -> Tensor:
"""
super().checks(driver)

def fcn(driver: libcint.LibcintWrapper, norm: Tensor) -> Tensor:
# build norm if not already available
if self.norm is None:
self.build(driver)

def fcn(driver: libcint.LibcintWrapper) -> Tensor:
# (3, norb, norb)
grad = libcint.int1e("ipovlp", driver)

if self.normalize is False:
return -einsum("...xij->...ijx", grad)

# normalize and move xyz dimension to last, which is required for
# the reduction (only works with extra dimension in last)
return -einsum("...xij,...i,...j->...ijx", grad, norm, norm)

# build norm if not already available
if self.norm is None:
if driver.ihelp.batch_mode > 0:
assert isinstance(driver.drv, list)
self.norm = pack([snorm(libcint.overlap(d)) for d in driver.drv])
else:
assert isinstance(driver.drv, libcint.LibcintWrapper)
self.norm = snorm(libcint.overlap(driver.drv))
# Move xyz dimension to last, which is required for the
# reduction (only works with extra dimension in last)
return -einsum("...xij->...ijx", grad)

# batched mode
if driver.ihelp.batch_mode > 0:
@@ -137,24 +120,12 @@ def fcn(driver: libcint.LibcintWrapper, norm: Tensor) -> Tensor:
)

if driver.ihelp.batch_mode == 1:
# pylint: disable=import-outside-toplevel
from tad_mctc.batch import deflate

self.grad = pack(
[
fcn(driver, deflate(norm))
for driver, norm in zip(driver.drv, self.norm)
]
)
return self.grad
self.gradient = pack([fcn(d) for d in driver.drv])
return self.gradient

elif driver.ihelp.batch_mode == 2:
self.grad = pack(
[
fcn(driver, norm) # no deflating here
for driver, norm in zip(driver.drv, self.norm)
]
)
return self.grad
self.gradient = torch.stack([fcn(d) for d in driver.drv])
return self.gradient

raise ValueError(f"Unknown batch mode '{driver.ihelp.batch_mode}'.")

@@ -165,5 +136,6 @@ def fcn(driver: libcint.LibcintWrapper, norm: Tensor) -> Tensor:
"driver instance itself seems to be batched."
)

self.grad = fcn(driver.drv, self.norm)
return self.grad
print("aksdjkasd")
self.gradient = fcn(driver.drv)
return self.gradient
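
The `...xij->...ijx` permutation used in `get_gradient` above can be checked in isolation; a small sketch assuming the (3, norb, norb) layout of the libcint "ipovlp" derivative:

import torch

norb = 4
grad_xij = torch.rand(3, norb, norb)  # derivative with the Cartesian axis first

# move the xyz axis to the last position, as required for the later reduction
grad_ijx = torch.einsum("...xij->...ijx", grad_xij)

assert grad_ijx.shape == (norb, norb, 3)
assert torch.equal(grad_ijx, grad_xij.permute(1, 2, 0))
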
10 changes: 5 additions & 5 deletions src/dxtb/_src/integral/driver/manager.py
@@ -28,7 +28,6 @@
import torch

from dxtb import IndexHelper, labels
from dxtb._src.constants import labels
from dxtb._src.param import Param
from dxtb._src.typing import TYPE_CHECKING, Any, Tensor, TensorLike

@@ -56,8 +55,8 @@ def __init__(
_driver: IntDriver | None = None,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
**kwargs,
):
**kwargs: Any,
) -> None:
super().__init__(device=device, dtype=dtype)

# per default, libcint is run on the CPU
@@ -87,7 +86,7 @@ def create_driver(self, numbers: Tensor, par: Param, ihelp: IndexHelper) -> None
# pylint: disable=import-outside-toplevel
from .libcint import IntDriverLibcint as _IntDriver

if self.force_cpu_for_libcint:
if self.force_cpu_for_libcint is True:
device = torch.device("cpu")
numbers = numbers.to(device=device)
ihelp = ihelp.to(device=device)
@@ -99,11 +98,12 @@ def create_driver(self, numbers: Tensor, par: Param, ihelp: IndexHelper) -> None
elif self.driver_type == labels.INTDRIVER_AUTOGRAD:
# pylint: disable=import-outside-toplevel
from .pytorch import IntDriverPytorchNoAnalytical as _IntDriver

else:
raise ValueError(f"Unknown integral driver '{self.driver_type}'.")

self.driver = _IntDriver(
numbers, par, ihelp, device=self.device, dtype=self.dtype
numbers, par, ihelp, device=ihelp.device, dtype=self.dtype
)

def setup_driver(self, positions: Tensor, **kwargs: Any) -> None:
6 changes: 3 additions & 3 deletions src/dxtb/_src/integral/driver/pytorch/overlap.py
@@ -121,11 +121,11 @@ def get_gradient(self, driver: BaseIntDriverPytorch) -> Tensor:
super().checks(driver)

if driver.ihelp.batch_mode > 0:
self.grad = self._batch(driver.eval_ovlp_grad, driver)
self.gradient = self._batch(driver.eval_ovlp_grad, driver)
else:
self.grad = self._single(driver.eval_ovlp_grad, driver)
self.gradient = self._single(driver.eval_ovlp_grad, driver)

return self.grad
return self.gradient

def _single(self, fcn: OverlapFunction, driver: BaseIntDriverPytorch) -> Tensor:
if not isinstance(driver, BaseIntDriverPytorch):