Skip to content

Commit

Permalink
More tests
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Jan 6, 2025
1 parent 9ff869b commit d1351e3
Show file tree
Hide file tree
Showing 11 changed files with 288 additions and 37 deletions.
52 changes: 30 additions & 22 deletions src/dxtb/_src/components/interactions/dispersion/d4sc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@

from ..base import Interaction, InteractionCache

__all__ = ["D4SC", "LABEL_D4SC", "new_d4sc"]
__all__ = ["DispersionD4SC", "LABEL_DispersionD4SC", "new_d4sc"]


LABEL_D4SC = "D4SC"
"""Label for the :class:`.D4SC` interaction, coinciding with the class name."""
LABEL_DispersionD4SC = "DispersionD4SC"
"""Label for the :class:`.DispersionD4SC` interaction, coinciding with the class name."""


class D4SCCache(InteractionCache, TensorLike):
class DispersionDispersionD4SCCache(InteractionCache, TensorLike):
"""
Restart data for the :class:`.D4SC` interaction.
Restart data for the :class:`.DispersionD4SC` interaction.
Note
----
Expand Down Expand Up @@ -127,9 +127,9 @@ def restore(self) -> None:
self.dispmat = self.__store.dispmat


class D4SC(Interaction):
class DispersionD4SC(Interaction):
"""
Self-consistent D4 dispersion correction (:class:`.D4SC`).
Self-consistent D4 dispersion correction (:class:`.DispersionD4SC`).
"""

param: dict[str, Tensor]
Expand Down Expand Up @@ -175,7 +175,7 @@ def get_cache(
positions: Tensor | None = None,
ihelp: IndexHelper | None = None,
**_,
) -> D4SCCache:
) -> DispersionDispersionD4SCCache:
"""
Create restart data for individual interactions.
Expand All @@ -188,26 +188,28 @@ def get_cache(
Returns
-------
D4SCCache
DispersionDispersionD4SCCache
Restart data for the interaction.
Note
----
If the :class:`.D4SC` interaction is evaluated within the
If the :class:`.DispersionD4SC` interaction is evaluated within the
:class:`dxtb.components.InteractionList`, ``positions`` will be passed
as an argument, too. Hence, it is necessary to absorb the ``positions``
in the signature of the function (also see
:meth:`dxtb.components.Interaction.get_cache`).
"""
if numbers is None:
raise ValueError("Atomic numbers are required for D4SC cache.")
raise ValueError(
"Atomic numbers are required for DispersionD4SC cache."
)
if positions is None:
raise ValueError("Positions are required for ES2 cache.")

cachvars = (numbers.detach().clone(),)

if self.cache_is_latest(cachvars) is True:
if not isinstance(self.cache, D4SCCache):
if not isinstance(self.cache, DispersionDispersionD4SCCache):
raise TypeError(
f"Cache in {self.label} is not of type '{self.label}."
"Cache'. This can only happen if you manually manipulate "
Expand Down Expand Up @@ -241,19 +243,21 @@ def get_cache(
)
dispmat = edisp.unsqueeze(-1).unsqueeze(-1) * self.model.rc6

self.cache = D4SCCache(cn, dispmat)
self.cache = DispersionDispersionD4SCCache(cn, dispmat)

return self.cache

def get_atom_energy(self, charges: Tensor, cache: D4SCCache) -> Tensor:
def get_atom_energy(
self, charges: Tensor, cache: DispersionDispersionD4SCCache
) -> Tensor:
"""
Calculate the D4 dispersion correction energy.
Parameters
----------
charges : Tensor
Atomic charges of all atoms.
cache : D4SCCache
cache : DispersionDispersionD4SCCache
Restart data for the interaction.
Returns
Expand All @@ -269,15 +273,17 @@ def get_atom_energy(self, charges: Tensor, cache: D4SCCache) -> Tensor:
optimize=[(0, 1), (0, 1)],
)

def get_atom_potential(self, charges: Tensor, cache: D4SCCache) -> Tensor:
def get_atom_potential(
self, charges: Tensor, cache: DispersionDispersionD4SCCache
) -> Tensor:
"""
Calculate the D4 dispersion correction potential.
Parameters
----------
charges : Tensor
Atomic charges of all atoms.
cache : D4SCCache
cache : DispersionDispersionD4SCCache
Restart data for the interaction.
Returns
Expand All @@ -299,9 +305,9 @@ def new_d4sc(
par: Param,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> D4SC | None:
) -> DispersionD4SC | None:
"""
Create new instance of :class:`.D4SC`.
Create new instance of :class:`.DispersionD4SC`.
Parameters
----------
Expand All @@ -312,8 +318,8 @@ def new_d4sc(
Returns
-------
D4SC | None
Instance of the :class:`.D4SC` class or ``None`` if no :class:`.D4SC` is
DispersionD4SC | None
Instance of the :class:`.DispersionD4SC` class or ``None`` if no :class:`.DispersionD4SC` is
used.
"""

Expand Down Expand Up @@ -355,4 +361,6 @@ def new_d4sc(
model = d4.model.D4Model(numbers, ref_charges="gfn2", **dd)
cutoff = d4.cutoff.Cutoff(disp2=50.0, disp3=25.0, **dd)

return D4SC(param, model=model, rcov=rcov, r4r2=r4r2, cutoff=cutoff, **dd)
return DispersionD4SC(
param, model=model, rcov=rcov, r4r2=r4r2, cutoff=cutoff, **dd
)
10 changes: 6 additions & 4 deletions src/dxtb/_src/components/interactions/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from .container import Charges, Potential
from .coulomb.secondorder import ES2, LABEL_ES2
from .coulomb.thirdorder import ES3, LABEL_ES3
from .dispersion.d4sc import D4SC, LABEL_D4SC
from .dispersion.d4sc import DispersionD4SC, LABEL_DispersionD4SC
from .field.efield import LABEL_EFIELD, ElectricField
from .field.efieldgrad import LABEL_EFIELD_GRAD, ElectricFieldGrad

Expand Down Expand Up @@ -273,7 +273,9 @@ def get_potential(
###########################################################################

@overload
def get_interaction(self, name: Literal["D4SC"]) -> D4SC: ...
def get_interaction(
self, name: Literal["DispersionD4SC"]
) -> DispersionD4SC: ...

@overload
def get_interaction(
Expand All @@ -300,7 +302,7 @@ def get_interaction(self, name: str) -> Interaction:
@_docstring_reset
def reset_d4sc(self) -> Interaction:
"""Reset tensor attributes to a detached clone of the current state."""
return self.reset(LABEL_D4SC)
return self.reset(LABEL_DispersionD4SC)

@_docstring_reset
def reset_efield(self) -> Interaction:
Expand All @@ -326,7 +328,7 @@ def reset_es3(self) -> Interaction:

@_docstring_update
def update_d4sc(self, **kwargs: Any) -> Interaction:
return self.update(LABEL_D4SC, **kwargs)
return self.update(LABEL_DispersionD4SC, **kwargs)

@_docstring_update
def update_efield(
Expand Down
24 changes: 21 additions & 3 deletions src/dxtb/_src/components/interactions/solvation/alpb.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import torch
from tad_mctc import storch
from tad_mctc.batch import real_pairs
from tad_mctc.convert import any_to_tensor
from tad_mctc.data import VDW_D3
from tad_mctc.math import einsum

Expand All @@ -82,6 +83,8 @@

DEFAULT_KERNEL = "p16"
DEFAULT_ALPB = True
DEFAULT_BORN_SCALE = 1.0
DEFAULT_BORN_OFFSET = 0.0

__all__ = ["GeneralizedBorn", "new_solvation"]

Expand Down Expand Up @@ -338,6 +341,7 @@ def get_atom_gradient(
def new_solvation(
numbers: Tensor,
par: Param,
dielectric_constant: Tensor | float | int = 80.3,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> GeneralizedBorn | None:
Expand All @@ -360,6 +364,9 @@ def new_solvation(
if hasattr(par, "solvation") is False or par.solvation is None:
return None

if hasattr(par.solvation, "alpb") is False or par.solvation.alpb is None:
return None

if device is not None:
if device != numbers.device:
raise DeviceError(
Expand All @@ -372,11 +379,22 @@ def new_solvation(
"dtype": dtype if dtype is not None else get_default_dtype(),
}

s = par.solvation # type: ignore
epsilon = torch.tensor(s.epsilon.gexp, **dd)
s = par.solvation.alpb # type: ignore
alpb = s.alpb if hasattr(s, "alpb") else DEFAULT_ALPB
kernel = s.kernel if hasattr(s, "kernel") else DEFAULT_KERNEL
born_scale = (
s.born_scale if hasattr(s, "born_scale") else DEFAULT_BORN_SCALE
)
born_offset = (
s.born_offset if hasattr(s, "born_offset") else DEFAULT_BORN_OFFSET
)

return GeneralizedBorn(
numbers, dielectric_constant=epsilon, alpb=alpb, kernel=kernel, **dd
numbers,
dielectric_constant=any_to_tensor(dielectric_constant),
alpb=alpb,
kernel=kernel,
born_scale=born_scale,
born_offset=born_offset,
**dd,
)
3 changes: 3 additions & 0 deletions src/dxtb/_src/param/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .meta import Meta
from .multipole import Multipole
from .repulsion import Repulsion
from .solvation import Solvation
from .thirdorder import ThirdOrder

__all__ = ["Param"]
Expand Down Expand Up @@ -91,6 +92,8 @@ class Param(BaseModel):
thirdorder: Optional[ThirdOrder] = None
"""Definition of the isotropic third-order charge interactions."""

solvation: Optional[Solvation] = None

def clean_model_dump(self) -> dict[str, Any]:
"""
Clean the model from any `None` values.
Expand Down
60 changes: 60 additions & 0 deletions src/dxtb/_src/param/solvation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Parametrization: Electrostatics (3rd order)
===========================================
Definition of the isotropic third-order onsite correction.
"""

from __future__ import annotations

from typing import Optional

from pydantic import BaseModel

__all__ = ["ALPB", "Solvation"]


class ALPB(BaseModel):
"""Representation of shell-resolved third-order electrostatics."""

alpb: bool
"""Use analytical linearized Poisson-Boltzmann model."""

kernel: str
"""
Born interaction kernels. Either classical Still kernel or P16 kernel
by Lange (JCTC 2012, 8, 1999-2011).
"""

born_scale: float
"""Scaling factor for Born radii."""

born_offset: float
"""Offset parameter for Born radii integration."""


class Solvation(BaseModel):
"""
Representation of the isotropic third-order onsite correction.
"""

alpb: Optional[ALPB] = None
"""
Whether the third order contribution is shell-dependent or only atomwise.
"""
12 changes: 11 additions & 1 deletion src/dxtb/components/dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,15 @@
from dxtb._src.components.classicals.dispersion import (
new_dispersion as new_dispersion,
)
from dxtb._src.components.interactions.dispersion import (
DispersionD4SC as DispersionD4SC,
)
from dxtb._src.components.interactions.dispersion import new_d4sc as new_d4sc

__all__ = ["DispersionD3", "DispersionD4", "new_dispersion"]
__all__ = [
"DispersionD3",
"DispersionD4",
"new_dispersion",
"new_d4sc",
"DispersionD4SC",
]
4 changes: 2 additions & 2 deletions test/test_classical/test_dispersion/test_d4sc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_single(dtype: torch.dtype, name: str):
calc = Calculator(numbers, GFN2_XTB, opts=opts, **dd)

result = calc.singlepoint(positions, charges)
d4sc = calc.interactions.get_interaction("D4SC")
d4sc = calc.interactions.get_interaction("DispersionD4SC")
cache = d4sc.get_cache(numbers=numbers, positions=positions)

edisp = d4sc.get_energy(result.charges, cache, calc.ihelp)
Expand Down Expand Up @@ -99,7 +99,7 @@ def test_batch(dtype: torch.dtype, name1: str, name2: str):
calc = Calculator(numbers, GFN2_XTB, opts=opts, **dd)

result = calc.singlepoint(positions)
d4sc = calc.interactions.get_interaction("D4SC")
d4sc = calc.interactions.get_interaction("DispersionD4SC")
cache = d4sc.get_cache(numbers=numbers, positions=positions)

edisp = d4sc.get_energy(result.charges, cache, calc.ihelp)
Expand Down
15 changes: 15 additions & 0 deletions test/test_coulomb/test_es3_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import pytest
import torch
from tad_mctc.convert import str_to_device
from tad_mctc.exceptions import DeviceError
from tad_mctc.typing import MockTensor

from dxtb import GFN1_XTB, IndexHelper
from dxtb._src.components.interactions.coulomb import thirdorder as es3
Expand Down Expand Up @@ -117,3 +119,16 @@ def test_change_device_fail() -> None:
# trying to use setter
with pytest.raises(AttributeError):
cls.device = "cpu"


def test_device_fail_numbers() -> None:
n = torch.tensor([3, 1], device="cpu")
numbers = MockTensor(n)
numbers.device = "cuda"

# works
_ = es3.new_es3(n, GFN1_XTB, device=torch.device("cpu"))

# fails
with pytest.raises(DeviceError):
es3.new_es3(numbers, GFN1_XTB, device=torch.device("cpu"))
Loading

0 comments on commit d1351e3

Please sign in to comment.