Skip to content

Commit

Permalink
Minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Aug 24, 2024
1 parent 6f805ac commit 054fa2c
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 103 deletions.
15 changes: 11 additions & 4 deletions src/dxtb/_src/calculators/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,14 @@ class **inherits from all types**, i.e., it provides the energy and properties
if we need to differentiate for multiple properties at once (e.g., Hessian and
dipole moment for IR spectra). Hence, the default is ``use_functorch=False``.
"""
from .analytical import *
from .autograd import *
from .energy import *
from .numerical import *
from .analytical import AnalyticalCalculator
from .autograd import AutogradCalculator
from .energy import EnergyCalculator
from .numerical import NumericalCalculator

__all__ = [
"AnalyticalCalculator",
"AutogradCalculator",
"EnergyCalculator",
"NumericalCalculator",
]
1 change: 0 additions & 1 deletion src/dxtb/_src/components/classicals/dispersion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
Dispersion models in the extended tight-binding model.
"""

from .base import Dispersion
from .d3 import DispersionD3
from .d4 import DispersionD4
from .factory import new_dispersion
3 changes: 2 additions & 1 deletion src/dxtb/_src/components/classicals/dispersion/d3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

from dxtb._src.typing import Any, CountingFunction, Tensor

from .base import ClassicalCache, Dispersion
from ..base import ClassicalCache
from .base import Dispersion

__all__ = ["DispersionD3", "DispersionD3Cache"]

Expand Down
3 changes: 2 additions & 1 deletion src/dxtb/_src/components/classicals/dispersion/d4.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@

from dxtb._src.typing import Any, Tensor

from .base import ClassicalCache, Dispersion
from ..base import ClassicalCache
from .base import Dispersion

__all__ = ["DispersionD4", "DispersionD4Cache"]

Expand Down
85 changes: 85 additions & 0 deletions src/dxtb/_src/integral/abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Integrals: Abstract Base Classes
================================
Abstract case class for integrals.
"""

from __future__ import annotations

from abc import ABC, abstractmethod

from dxtb._src.typing import TYPE_CHECKING, Any, Tensor

if TYPE_CHECKING:
from dxtb._src.integral.base import IntDriver


__all__ = ["IntegralABC"]


class IntegralABC(ABC):
"""
Abstract base class for integral implementations.
All integral calculations are executed by this class.
"""

@abstractmethod
def build(self, driver: IntDriver, **kwargs: Any) -> Tensor:
"""
Create the integral matrix.
Parameters
----------
driver : IntDriver
Integral driver for the calculation.
Returns
-------
Tensor
Integral matrix.
"""

@abstractmethod
def get_gradient(self, driver: IntDriver, **kwargs: Any) -> Tensor:
"""
Calculate the full nuclear gradient matrix of the integral.
Parameters
----------
driver : IntDriver
Integral driver for the calculation.
Returns
-------
Tensor
Nuclear integral derivative matrix.
"""

@abstractmethod
def normalize(self, norm: Tensor | None = None, **kwargs: Any) -> None:
"""
Normalize the integral (changes ``self.matrix``).
Parameters
----------
norm : Tensor, optional
Overlap norm to normalize the integral.
"""
99 changes: 5 additions & 94 deletions src/dxtb/_src/integral/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,25 @@
Integrals: Base Classes
=======================
Base class for Integrals classes and their actual implementations.
Base class for integral classes and their actual implementations.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from abc import abstractmethod

import torch
from tad_mctc.math import einsum

from dxtb import IndexHelper
from dxtb._src.basis.bas import Basis
from dxtb._src.param import Param
from dxtb._src.typing import Any, Literal, PathLike, Tensor, TensorLike
from dxtb._src.typing import Literal, PathLike, Tensor, TensorLike

from .abc import IntegralABC
from .utils import snorm

__all__ = [
"BaseIntegral",
"IntDriver",
"IntegralContainer",
]
__all__ = ["BaseIntegral", "IntDriver"]


class IntDriver(TensorLike):
Expand Down Expand Up @@ -201,57 +198,6 @@ def __repr__(self) -> str:
#########################################################


class IntegralABC(ABC):
"""
Abstract base class for integral implementations.
All integral calculations are executed by this class.
"""

@abstractmethod
def build(self, driver: IntDriver, **kwargs: Any) -> Tensor:
"""
Create the integral matrix.
Parameters
----------
driver : IntDriver
Integral driver for the calculation.
Returns
-------
Tensor
Integral matrix.
"""

@abstractmethod
def get_gradient(self, driver: IntDriver, **kwargs: Any) -> Tensor:
"""
Calculate the full nuclear gradient matrix of the integral.
Parameters
----------
driver : IntDriver
Integral driver for the calculation.
Returns
-------
Tensor
Nuclear integral derivative matrix.
"""

@abstractmethod
def normalize(self, norm: Tensor | None = None, **kwargs: Any) -> None:
"""
Normalize the integral (changes ``self.matrix``).
Parameters
----------
norm : Tensor, optional
Overlap norm to normalize the integral.
"""


class BaseIntegral(IntegralABC, TensorLike):
"""
Base class for integral implementations.
Expand Down Expand Up @@ -427,38 +373,3 @@ def __str__(self) -> str:

def __repr__(self) -> str:
return str(self)


#########################################################


class IntegralContainer(TensorLike):
"""
Base class for integral container.
"""

def __init__(
self,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
_run_checks: bool = True,
):
super().__init__(device, dtype)
self._run_checks = _run_checks

@property
def run_checks(self) -> bool:
return self._run_checks

@run_checks.setter
def run_checks(self, run_checks: bool) -> None:
current = self.run_checks
self._run_checks = run_checks

# switching from False to True should automatically run checks
if current is False and run_checks is True:
self.checks()

@abstractmethod
def checks(self) -> None:
"""Run checks for integrals."""
37 changes: 35 additions & 2 deletions src/dxtb/_src/integral/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@
from __future__ import annotations

import logging
from abc import abstractmethod

import torch

from dxtb import labels
from dxtb._src.constants import defaults, labels
from dxtb._src.typing import Any, Tensor
from dxtb._src.typing import Any, Tensor, TensorLike
from dxtb._src.xtb.base import BaseHamiltonian

from .base import BaseIntegral, IntegralContainer
from .base import BaseIntegral
from .driver import DriverManager
from .types import DipoleIntegral, OverlapIntegral, QuadrupoleIntegral

Expand All @@ -41,6 +42,38 @@
logger = logging.getLogger(__name__)


class IntegralContainer(TensorLike):
"""
Base class for integral container.
"""

def __init__(
self,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
_run_checks: bool = True,
):
super().__init__(device, dtype)
self._run_checks = _run_checks

@property
def run_checks(self) -> bool:
return self._run_checks

@run_checks.setter
def run_checks(self, run_checks: bool) -> None:
current = self.run_checks
self._run_checks = run_checks

# switching from False to True should automatically run checks
if current is False and run_checks is True:
self.checks()

@abstractmethod
def checks(self) -> None:
"""Run checks for integrals."""


class Integrals(IntegralContainer):
"""
Integral container.
Expand Down

0 comments on commit 054fa2c

Please sign in to comment.