diff --git a/src/dxtb/_src/components/interactions/dispersion/d4sc.py b/src/dxtb/_src/components/interactions/dispersion/d4sc.py index e0fc825c..e75861dd 100644 --- a/src/dxtb/_src/components/interactions/dispersion/d4sc.py +++ b/src/dxtb/_src/components/interactions/dispersion/d4sc.py @@ -42,16 +42,16 @@ from ..base import Interaction, InteractionCache -__all__ = ["D4SC", "LABEL_D4SC", "new_d4sc"] +__all__ = ["DispersionD4SC", "LABEL_DispersionD4SC", "new_d4sc"] -LABEL_D4SC = "D4SC" -"""Label for the :class:`.D4SC` interaction, coinciding with the class name.""" +LABEL_DispersionD4SC = "DispersionD4SC" +"""Label for the :class:`.DispersionD4SC` interaction, coinciding with the class name.""" -class D4SCCache(InteractionCache, TensorLike): +class DispersionDispersionD4SCCache(InteractionCache, TensorLike): """ - Restart data for the :class:`.D4SC` interaction. + Restart data for the :class:`.DispersionD4SC` interaction. Note ---- @@ -127,9 +127,9 @@ def restore(self) -> None: self.dispmat = self.__store.dispmat -class D4SC(Interaction): +class DispersionD4SC(Interaction): """ - Self-consistent D4 dispersion correction (:class:`.D4SC`). + Self-consistent D4 dispersion correction (:class:`.DispersionD4SC`). """ param: dict[str, Tensor] @@ -175,7 +175,7 @@ def get_cache( positions: Tensor | None = None, ihelp: IndexHelper | None = None, **_, - ) -> D4SCCache: + ) -> DispersionDispersionD4SCCache: """ Create restart data for individual interactions. @@ -188,26 +188,28 @@ def get_cache( Returns ------- - D4SCCache + DispersionDispersionD4SCCache Restart data for the interaction. Note ---- - If the :class:`.D4SC` interaction is evaluated within the + If the :class:`.DispersionD4SC` interaction is evaluated within the :class:`dxtb.components.InteractionList`, ``positions`` will be passed as an argument, too. Hence, it is necessary to absorb the ``positions`` in the signature of the function (also see :meth:`dxtb.components.Interaction.get_cache`). """ if numbers is None: - raise ValueError("Atomic numbers are required for D4SC cache.") + raise ValueError( + "Atomic numbers are required for DispersionD4SC cache." + ) if positions is None: raise ValueError("Positions are required for ES2 cache.") cachvars = (numbers.detach().clone(),) if self.cache_is_latest(cachvars) is True: - if not isinstance(self.cache, D4SCCache): + if not isinstance(self.cache, DispersionDispersionD4SCCache): raise TypeError( f"Cache in {self.label} is not of type '{self.label}." "Cache'. This can only happen if you manually manipulate " @@ -241,11 +243,13 @@ def get_cache( ) dispmat = edisp.unsqueeze(-1).unsqueeze(-1) * self.model.rc6 - self.cache = D4SCCache(cn, dispmat) + self.cache = DispersionDispersionD4SCCache(cn, dispmat) return self.cache - def get_atom_energy(self, charges: Tensor, cache: D4SCCache) -> Tensor: + def get_atom_energy( + self, charges: Tensor, cache: DispersionDispersionD4SCCache + ) -> Tensor: """ Calculate the D4 dispersion correction energy. @@ -253,7 +257,7 @@ def get_atom_energy(self, charges: Tensor, cache: D4SCCache) -> Tensor: ---------- charges : Tensor Atomic charges of all atoms. - cache : D4SCCache + cache : DispersionDispersionD4SCCache Restart data for the interaction. Returns @@ -269,7 +273,9 @@ def get_atom_energy(self, charges: Tensor, cache: D4SCCache) -> Tensor: optimize=[(0, 1), (0, 1)], ) - def get_atom_potential(self, charges: Tensor, cache: D4SCCache) -> Tensor: + def get_atom_potential( + self, charges: Tensor, cache: DispersionDispersionD4SCCache + ) -> Tensor: """ Calculate the D4 dispersion correction potential. @@ -277,7 +283,7 @@ def get_atom_potential(self, charges: Tensor, cache: D4SCCache) -> Tensor: ---------- charges : Tensor Atomic charges of all atoms. - cache : D4SCCache + cache : DispersionDispersionD4SCCache Restart data for the interaction. Returns @@ -299,9 +305,9 @@ def new_d4sc( par: Param, device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> D4SC | None: +) -> DispersionD4SC | None: """ - Create new instance of :class:`.D4SC`. + Create new instance of :class:`.DispersionD4SC`. Parameters ---------- @@ -312,8 +318,8 @@ def new_d4sc( Returns ------- - D4SC | None - Instance of the :class:`.D4SC` class or ``None`` if no :class:`.D4SC` is + DispersionD4SC | None + Instance of the :class:`.DispersionD4SC` class or ``None`` if no :class:`.DispersionD4SC` is used. """ @@ -355,4 +361,6 @@ def new_d4sc( model = d4.model.D4Model(numbers, ref_charges="gfn2", **dd) cutoff = d4.cutoff.Cutoff(disp2=50.0, disp3=25.0, **dd) - return D4SC(param, model=model, rcov=rcov, r4r2=r4r2, cutoff=cutoff, **dd) + return DispersionD4SC( + param, model=model, rcov=rcov, r4r2=r4r2, cutoff=cutoff, **dd + ) diff --git a/src/dxtb/_src/components/interactions/list.py b/src/dxtb/_src/components/interactions/list.py index 13918001..7bca82b2 100644 --- a/src/dxtb/_src/components/interactions/list.py +++ b/src/dxtb/_src/components/interactions/list.py @@ -39,7 +39,7 @@ from .container import Charges, Potential from .coulomb.secondorder import ES2, LABEL_ES2 from .coulomb.thirdorder import ES3, LABEL_ES3 -from .dispersion.d4sc import D4SC, LABEL_D4SC +from .dispersion.d4sc import DispersionD4SC, LABEL_DispersionD4SC from .field.efield import LABEL_EFIELD, ElectricField from .field.efieldgrad import LABEL_EFIELD_GRAD, ElectricFieldGrad @@ -273,7 +273,9 @@ def get_potential( ########################################################################### @overload - def get_interaction(self, name: Literal["D4SC"]) -> D4SC: ... + def get_interaction( + self, name: Literal["DispersionD4SC"] + ) -> DispersionD4SC: ... @overload def get_interaction( @@ -300,7 +302,7 @@ def get_interaction(self, name: str) -> Interaction: @_docstring_reset def reset_d4sc(self) -> Interaction: """Reset tensor attributes to a detached clone of the current state.""" - return self.reset(LABEL_D4SC) + return self.reset(LABEL_DispersionD4SC) @_docstring_reset def reset_efield(self) -> Interaction: @@ -326,7 +328,7 @@ def reset_es3(self) -> Interaction: @_docstring_update def update_d4sc(self, **kwargs: Any) -> Interaction: - return self.update(LABEL_D4SC, **kwargs) + return self.update(LABEL_DispersionD4SC, **kwargs) @_docstring_update def update_efield( diff --git a/src/dxtb/_src/components/interactions/solvation/alpb.py b/src/dxtb/_src/components/interactions/solvation/alpb.py index 33cdf68e..c66f1e05 100644 --- a/src/dxtb/_src/components/interactions/solvation/alpb.py +++ b/src/dxtb/_src/components/interactions/solvation/alpb.py @@ -60,6 +60,7 @@ import torch from tad_mctc import storch from tad_mctc.batch import real_pairs +from tad_mctc.convert import any_to_tensor from tad_mctc.data import VDW_D3 from tad_mctc.math import einsum @@ -82,6 +83,8 @@ DEFAULT_KERNEL = "p16" DEFAULT_ALPB = True +DEFAULT_BORN_SCALE = 1.0 +DEFAULT_BORN_OFFSET = 0.0 __all__ = ["GeneralizedBorn", "new_solvation"] @@ -338,6 +341,7 @@ def get_atom_gradient( def new_solvation( numbers: Tensor, par: Param, + dielectric_constant: Tensor | float | int = 80.3, device: torch.device | None = None, dtype: torch.dtype | None = None, ) -> GeneralizedBorn | None: @@ -360,6 +364,9 @@ def new_solvation( if hasattr(par, "solvation") is False or par.solvation is None: return None + if hasattr(par.solvation, "alpb") is False or par.solvation.alpb is None: + return None + if device is not None: if device != numbers.device: raise DeviceError( @@ -372,11 +379,22 @@ def new_solvation( "dtype": dtype if dtype is not None else get_default_dtype(), } - s = par.solvation # type: ignore - epsilon = torch.tensor(s.epsilon.gexp, **dd) + s = par.solvation.alpb # type: ignore alpb = s.alpb if hasattr(s, "alpb") else DEFAULT_ALPB kernel = s.kernel if hasattr(s, "kernel") else DEFAULT_KERNEL + born_scale = ( + s.born_scale if hasattr(s, "born_scale") else DEFAULT_BORN_SCALE + ) + born_offset = ( + s.born_offset if hasattr(s, "born_offset") else DEFAULT_BORN_OFFSET + ) return GeneralizedBorn( - numbers, dielectric_constant=epsilon, alpb=alpb, kernel=kernel, **dd + numbers, + dielectric_constant=any_to_tensor(dielectric_constant), + alpb=alpb, + kernel=kernel, + born_scale=born_scale, + born_offset=born_offset, + **dd, ) diff --git a/src/dxtb/_src/param/base.py b/src/dxtb/_src/param/base.py index a24620e7..aeebf787 100644 --- a/src/dxtb/_src/param/base.py +++ b/src/dxtb/_src/param/base.py @@ -46,6 +46,7 @@ from .meta import Meta from .multipole import Multipole from .repulsion import Repulsion +from .solvation import Solvation from .thirdorder import ThirdOrder __all__ = ["Param"] @@ -91,6 +92,8 @@ class Param(BaseModel): thirdorder: Optional[ThirdOrder] = None """Definition of the isotropic third-order charge interactions.""" + solvation: Optional[Solvation] = None + def clean_model_dump(self) -> dict[str, Any]: """ Clean the model from any `None` values. diff --git a/src/dxtb/_src/param/solvation.py b/src/dxtb/_src/param/solvation.py new file mode 100644 index 00000000..ac484fc8 --- /dev/null +++ b/src/dxtb/_src/param/solvation.py @@ -0,0 +1,60 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Parametrization: Electrostatics (3rd order) +=========================================== + +Definition of the isotropic third-order onsite correction. +""" + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel + +__all__ = ["ALPB", "Solvation"] + + +class ALPB(BaseModel): + """Representation of shell-resolved third-order electrostatics.""" + + alpb: bool + """Use analytical linearized Poisson-Boltzmann model.""" + + kernel: str + """ + Born interaction kernels. Either classical Still kernel or P16 kernel + by Lange (JCTC 2012, 8, 1999-2011). + """ + + born_scale: float + """Scaling factor for Born radii.""" + + born_offset: float + """Offset parameter for Born radii integration.""" + + +class Solvation(BaseModel): + """ + Representation of the isotropic third-order onsite correction. + """ + + alpb: Optional[ALPB] = None + """ + Whether the third order contribution is shell-dependent or only atomwise. + """ diff --git a/src/dxtb/components/dispersion.py b/src/dxtb/components/dispersion.py index b5162683..a0e755c8 100644 --- a/src/dxtb/components/dispersion.py +++ b/src/dxtb/components/dispersion.py @@ -30,5 +30,15 @@ from dxtb._src.components.classicals.dispersion import ( new_dispersion as new_dispersion, ) +from dxtb._src.components.interactions.dispersion import ( + DispersionD4SC as DispersionD4SC, +) +from dxtb._src.components.interactions.dispersion import new_d4sc as new_d4sc -__all__ = ["DispersionD3", "DispersionD4", "new_dispersion"] +__all__ = [ + "DispersionD3", + "DispersionD4", + "new_dispersion", + "new_d4sc", + "DispersionD4SC", +] diff --git a/test/test_classical/test_dispersion/test_d4sc.py b/test/test_classical/test_dispersion/test_d4sc.py index dce77452..570a296c 100644 --- a/test/test_classical/test_dispersion/test_d4sc.py +++ b/test/test_classical/test_dispersion/test_d4sc.py @@ -61,7 +61,7 @@ def test_single(dtype: torch.dtype, name: str): calc = Calculator(numbers, GFN2_XTB, opts=opts, **dd) result = calc.singlepoint(positions, charges) - d4sc = calc.interactions.get_interaction("D4SC") + d4sc = calc.interactions.get_interaction("DispersionD4SC") cache = d4sc.get_cache(numbers=numbers, positions=positions) edisp = d4sc.get_energy(result.charges, cache, calc.ihelp) @@ -99,7 +99,7 @@ def test_batch(dtype: torch.dtype, name1: str, name2: str): calc = Calculator(numbers, GFN2_XTB, opts=opts, **dd) result = calc.singlepoint(positions) - d4sc = calc.interactions.get_interaction("D4SC") + d4sc = calc.interactions.get_interaction("DispersionD4SC") cache = d4sc.get_cache(numbers=numbers, positions=positions) edisp = d4sc.get_energy(result.charges, cache, calc.ihelp) diff --git a/test/test_coulomb/test_es3_general.py b/test/test_coulomb/test_es3_general.py index bddeac01..5aedf98d 100644 --- a/test/test_coulomb/test_es3_general.py +++ b/test/test_coulomb/test_es3_general.py @@ -24,6 +24,8 @@ import pytest import torch from tad_mctc.convert import str_to_device +from tad_mctc.exceptions import DeviceError +from tad_mctc.typing import MockTensor from dxtb import GFN1_XTB, IndexHelper from dxtb._src.components.interactions.coulomb import thirdorder as es3 @@ -117,3 +119,16 @@ def test_change_device_fail() -> None: # trying to use setter with pytest.raises(AttributeError): cls.device = "cpu" + + +def test_device_fail_numbers() -> None: + n = torch.tensor([3, 1], device="cpu") + numbers = MockTensor(n) + numbers.device = "cuda" + + # works + _ = es3.new_es3(n, GFN1_XTB, device=torch.device("cpu")) + + # fails + with pytest.raises(DeviceError): + es3.new_es3(numbers, GFN1_XTB, device=torch.device("cpu")) diff --git a/test/test_interaction/test_base.py b/test/test_interaction/test_base.py index 554f0f90..bb991619 100644 --- a/test/test_interaction/test_base.py +++ b/test/test_interaction/test_base.py @@ -20,6 +20,7 @@ from __future__ import annotations +import pytest import torch from dxtb import IndexHelper @@ -60,3 +61,12 @@ def test_empty() -> None: sg = i.get_shell_gradient(numbers, numbers) assert (sg == torch.zeros(sg.shape, device=DEVICE)).all() + + +def test_energy_fail() -> None: + """Monopolar charges are always required.""" + i = Interaction() + c = Charges() + + with pytest.raises(RuntimeError): + i.get_energy(c, None, None) # type: ignore diff --git a/test/test_interaction/test_cache.py b/test/test_interaction/test_cache.py new file mode 100644 index 00000000..19afc228 --- /dev/null +++ b/test/test_interaction/test_cache.py @@ -0,0 +1,126 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test caches. +""" + +from __future__ import annotations + +import pytest +import torch + +from dxtb import GFN1_XTB, GFN2_XTB, IndexHelper, Param +from dxtb._src.typing import Callable, Tensor +from dxtb.components.base import Interaction, InteractionCache +from dxtb.components.coulomb import new_es2, new_es3 +from dxtb.components.dispersion import new_d4sc +from dxtb.components.field import new_efield, new_efield_grad +from dxtb.components.solvation import new_solvation + +from ..conftest import DEVICE + + +@pytest.mark.parametrize( + "comp_factory_par", + [ + (new_d4sc, GFN2_XTB), + (new_es2, GFN1_XTB), + (new_es3, GFN1_XTB), + ], +) +def test_fail_overwritten_cache( + comp_factory_par: tuple[Callable[[Tensor, Param], Interaction], Param] +) -> None: + numbers = torch.tensor([3, 1], device=DEVICE) + positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], device=DEVICE) + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) + + comp_factory, par = comp_factory_par + comp = comp_factory(numbers, par) + assert comp is not None + + # create cache + comp.cache_enable() + _ = comp.get_cache(numbers=numbers, positions=positions, ihelp=ihelp) + + # manually overwrite cache + comp.cache = InteractionCache() + + with pytest.raises(TypeError): + comp.get_cache(numbers=numbers, positions=positions, ihelp=ihelp) + + +def test_fail_overwritten_cache_ef() -> None: + positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], device=DEVICE) + + ef = new_efield(torch.tensor([0.0, 0.0, 0.0]), device=DEVICE) + assert ef is not None + + # create cache + ef.cache_enable() + _ = ef.get_cache(positions=positions) + + # manually overwrite cache + ef.cache = InteractionCache() + + with pytest.raises(TypeError): + ef.get_cache(positions=positions) + + +def test_fail_overwritten_cache_efg() -> None: + positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], device=DEVICE) + + efg = new_efield_grad(torch.zeros((3, 3)), device=DEVICE) + assert efg is not None + + # create cache + efg.cache_enable() + _ = efg.get_cache(positions=positions) + + # manually overwrite cache + efg.cache = InteractionCache() + + with pytest.raises(TypeError): + efg.get_cache(positions=positions) + + +def test_fail_overwritten_cache_solvation() -> None: + numbers = torch.tensor([3, 1], device=DEVICE) + positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], device=DEVICE) + ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB) + + # manually create solvation + from dxtb._src.param.solvation import ALPB, Solvation + + par = GFN1_XTB.model_copy(deep=True) + par.solvation = Solvation( + alpb=ALPB(alpb=True, kernel="p16", born_scale=1.0, born_offset=0.0) + ) + + solv = new_solvation(numbers, par) + assert solv is not None + + # create cache + solv.cache_enable() + _ = solv.get_cache(numbers=numbers, positions=positions, ihelp=ihelp) + + # manually overwrite cache + + solv.cache = InteractionCache() + + with pytest.raises(TypeError): + solv.get_cache(numbers=numbers, positions=positions, ihelp=ihelp) diff --git a/test/test_scf/test_fenergy.py b/test/test_scf/test_fenergy.py index c8a9bebe..acdb2a2c 100644 --- a/test/test_scf/test_fenergy.py +++ b/test/test_scf/test_fenergy.py @@ -34,11 +34,6 @@ opts = { "fermi_etemp": 300, "fermi_maxiter": 500, - "fermi_thresh": { - # instead of 1e-5 - torch.float32: torch.tensor(1e-4, dtype=torch.float32), - torch.float64: torch.tensor(1e-10, dtype=torch.float64), - }, "scf_mode": labels.SCF_MODE_IMPLICIT_NON_PURE, "scp_mode": "potential", # important for atoms (better convergence) "verbosity": 0, @@ -65,6 +60,7 @@ def test_element(dtype: torch.dtype, partition: str, number: int) -> None: "f_atol": 1e-5 if dtype == torch.float32 else 1e-6, "x_atol": 1e-5 if dtype == torch.float32 else 1e-6, "fermi_partition": partition, + "fermi_thresh": 1e-4 if dtype == torch.float32 else 1e-10, "maxiter": 100, }, ) @@ -105,6 +101,7 @@ def fcn(number): **{ "f_atol": 1e-5 if dtype == torch.float32 else 1e-6, "x_atol": 1e-5 if dtype == torch.float32 else 1e-6, + "fermi_thresh": 1e-4 if dtype == torch.float32 else 1e-10, }, ) calc = Calculator(numbers, par, opts=options, **dd) @@ -133,6 +130,7 @@ def fcn(number): **{ "f_atol": 1e-5, # avoids Jacobian inversion error "x_atol": 1e-5, # avoids Jacobian inversion error + "fermi_thresh": 1e-4 if dtype == torch.float32 else 1e-10, }, ) calc = Calculator(numbers, par, opts=options, **dd) @@ -165,6 +163,7 @@ def fcn(number): **{ "f_atol": 1e-5, # avoid Jacobian inversion error "x_atol": 1e-5, # avoid Jacobian inversion error + "fermi_thresh": 1e-4 if dtype == torch.float32 else 1e-10, }, ) calc = Calculator(numbers, par, opts=options, **dd)