Fixes to allow TorchScript plus refactoring and input validation.
lohedges committed Oct 22, 2024
1 parent de48b29 commit 756c08c
Showing 3 changed files with 173 additions and 40 deletions.
6 changes: 3 additions & 3 deletions emle/models/_ani.py
@@ -241,7 +241,7 @@ def hook(
input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]],
output: Tuple[Tensor, Tensor],
):
- module._aev = output[1][0]
+ module._aev = output[1]

else:

@@ -250,7 +250,7 @@ def hook(
input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]],
output: _torchani.aev.SpeciesAEV,
):
- module._aev = output[1][0]
+ module._aev = output[1]

# Register the hook.
self._aev_hook = self._ani2x.aev_computer.register_forward_hook(hook)
@@ -351,7 +351,7 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):

# Set the AEVs captured by the forward hook as an attribute of the
# EMLE model.
- self._emle._aev = self._ani2x.aev_computer._aev
+ self._emle._emle_base._aev = self._ani2x.aev_computer._aev

# Get the EMLE energy components.
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
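The hook above is a standard PyTorch forward hook: it caches an intermediate output of a submodule so that a downstream model can reuse it without recomputing. Storing the full batched tensor (output[1]) rather than a single batch element (output[1][0]) is what lets the base model consume 3D AEVs in its forward pass. A minimal standalone sketch of the pattern; the FeatureComputer class and the feature shape are illustrative, not part of emle:

import torch

class FeatureComputer(torch.nn.Module):
    """Stand-in for an AEV computer: returns a (species, features) pair."""

    def forward(self, x):
        return x.long(), x.unsqueeze(-1).expand(-1, -1, 4)

computer = FeatureComputer()

def hook(module, input, output):
    # Cache the full batched feature tensor (output[1]) rather than just
    # the first batch element (output[1][0]), mirroring the fix above.
    module._features = output[1]

handle = computer.register_forward_hook(hook)
species, feats = computer(torch.rand(2, 5))
assert computer._features.shape == (2, 5, 4)
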
30 changes: 15 additions & 15 deletions emle/models/_emle.py
@@ -39,7 +39,7 @@
from typing import Optional, Tuple, List

from . import _patches
- from . import EMLEBase
+ from . import EMLEBase as _EMLEBase

# Monkey-patch the TorchANI BuiltInModel and BuiltinEnsemble classes so that
# they call self.aev_computer using args only to allow forward hooks to work
@@ -330,19 +330,8 @@ def __init__(
else None
),
}
- self._emle_base = EMLEBase(
-     emle_params,
-     self._aev_computer,
-     aev_mask,
-     species,
-     n_ref,
-     ref_features,
-     q_core,
-     alpha_mode,
-     device,
-     dtype,
- )

+ # Store the total charge.
q_total = _torch.tensor(
params.get("total_charge", 0), dtype=dtype, device=device
)
@@ -359,8 +348,19 @@ def __init__(
self.register_buffer("_q_total", q_total)
self.register_buffer("_q_core_mm", q_core_mm)

- # Initialise an empty AEV tensor used to store the AEVs in derived classes.
- self._aev = _torch.empty(0, dtype=dtype, device=device)
+ # Create the base EMLE model.
+ self._emle_base = _EMLEBase(
+     emle_params,
+     n_ref,
+     ref_features,
+     q_core,
+     aev_computer=self._aev_computer,
+     aev_mask=aev_mask,
+     alpha_mode=self._alpha_mode,
+     species=self._species,
+     device=device,
+     dtype=dtype,
+ )

def _to_dict(self):
"""
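The base model is now built with explicit keyword arguments against the reordered EMLEBase signature. A construction sketch with dummy values; the tensor shapes below (per-species reference counts, a (species, references, features) feature array, and per-species parameter vectors) are assumptions for illustration, not shapes mandated by emle:

import torch
from emle.models import EMLEBase

n_species, max_ref, n_feat = 3, 4, 8  # illustrative sizes only

# Hypothetical parameter tensors with plausible shapes.
params = {
    "a_QEq": torch.tensor(1.0),
    "a_Thole": torch.tensor(2.0),
    "ref_values_s": torch.rand(n_species, max_ref),
    "ref_values_chi": torch.rand(n_species, max_ref),
    "k_Z": torch.rand(n_species),
}

base = EMLEBase(
    params,
    n_ref=torch.full((n_species,), max_ref, dtype=torch.int64),
    ref_features=torch.rand(n_species, max_ref, n_feat),
    q_core=torch.rand(n_species),
    alpha_mode="species",
    species=[1, 6, 8],
)

Passing aev_computer=None exercises the new optional path: the AEVs must then be supplied by a parent model, as _ani.py does via its forward hook.
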
177 changes: 155 additions & 22 deletions emle/models/_emle_base.py
@@ -1,10 +1,43 @@
+ #######################################################################
+ # EMLE-Engine: https://github.com/chemle/emle-engine
+ #
+ # Copyright: 2023-2024
+ #
+ # Authors: Lester Hedges <[email protected]>
+ #          Kirill Zinovjev <[email protected]>
+ #
+ # EMLE-Engine is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, either version 2 of the License, or
+ # (at your option) any later version.
+ #
+ # EMLE-Engine is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with EMLE-Engine. If not, see <http://www.gnu.org/licenses/>.
+ #####################################################################

import numpy as _np

import torch as _torch

from torch import Tensor
from typing import Tuple

+ import torchani as _torchani
+
+ from . import _patches
+
+ try:
+     import NNPOps as _NNPOps
+
+     _NNPOps.OptimizedTorchANI = _patches.OptimizedTorchANI
+
+     _has_nnpops = True
+ except:
+     _has_nnpops = False


class EMLEBase(_torch.nn.Module):
"""
@@ -14,15 +47,18 @@ class EMLEBase(_torch.nn.Module):
electrostatic embedding energies using the EMLE model.
"""

+ # Store the list of supported species.
+ _species = [1, 6, 7, 8, 16]

def __init__(
self,
params,
- aev_computer,
- aev_mask,
- species,
n_ref,
ref_features,
q_core,
+ aev_computer=None,
+ aev_mask=None,
+ species=None,
alpha_mode="species",
device=None,
dtype=None,
@@ -34,26 +70,16 @@ def __init__(
----------
params: dict
- EMLE model parameters
- aev_computer: torchani.AEVComputer
- AEV computer instance used to compute AEVs.
- aev_mask: torch.Tensor
- Mask for features coming from aev_computer.
- species: List[int], Tuple[int], numpy.ndarray, torch.Tensor
- List of species (atomic numbers) supported by the EMLE model. If
- None, then the default species list will be used.
+ EMLE model parameters.
n_ref: torch.Tensor
Number of GPR references for each element in the species list.
ref_features: torch.Tensor
- Feature vectors for GPR references
+ Feature vectors for GPR references.
q_core: torch.Tensor
- Core charges for each element in species list
+ Core charges for each element in species list.
alpha_mode: str
How atomic polarizabilities are calculated.
@@ -63,6 +89,15 @@ def __init__(
scaling factors are obtained with GPR using the values learned
for each reference environment.
+ aev_computer: torchani.AEVComputer
+ AEV computer instance used to compute AEVs.
+ aev_mask: torch.Tensor
+ Mask for features coming from aev_computer.
+ species: List[int], Tuple[int], numpy.ndarray, torch.Tensor
+ List of species (atomic numbers) supported by the EMLE model.
device: torch.device
The device on which to run the model.
@@ -73,6 +108,46 @@ def __init__(
# Call the base class constructor.
super().__init__()

+ # Validate the parameters.
+ if not isinstance(params, dict):
+     raise TypeError("'params' must be of type 'dict'")
+ if not all(
+     k in params
+     for k in ["a_QEq", "a_Thole", "ref_values_s", "ref_values_chi", "k_Z"]
+ ):
+     raise ValueError(
+         "'params' must contain keys 'a_QEq', 'a_Thole', 'ref_values_s', 'ref_values_chi', and 'k_Z'"
+     )

+ # Validate the number of references.
+ if not isinstance(n_ref, _torch.Tensor):
+     raise TypeError("'n_ref' must be of type 'torch.Tensor'")
+ if len(n_ref.shape) != 1:
+     raise ValueError("'n_ref' must be a 1D tensor")
+ if not n_ref.dtype == _torch.int64:
+     raise ValueError("'n_ref' must have dtype 'torch.int64'")

+ # Validate the reference features.
+ if not isinstance(ref_features, _torch.Tensor):
+     raise TypeError("'ref_features' must be of type 'torch.Tensor'")
+ if len(ref_features.shape) != 3:
+     raise ValueError("'ref_features' must be a 3D tensor")
+ if not ref_features.dtype in (_torch.float64, _torch.float32):
+     raise ValueError(
+         "'ref_features' must have dtype 'torch.float64' or 'torch.float32'"
+     )

+ # Validate the core charges.
+ if not isinstance(q_core, _torch.Tensor):
+     raise TypeError("'q_core' must be of type 'torch.Tensor'")
+ if len(q_core.shape) != 1:
+     raise ValueError("'q_core' must be a 1D tensor")
+ if not q_core.dtype in (_torch.float64, _torch.float32):
+     raise ValueError(
+         "'q_core' must have dtype 'torch.float64' or 'torch.float32'"
+     )

+ # Validate the alpha mode.
if alpha_mode is None:
alpha_mode = "species"
if not isinstance(alpha_mode, str):
@@ -82,6 +157,32 @@ def __init__(
raise ValueError("'alpha_mode' must be 'species' or 'reference'")
self._alpha_mode = alpha_mode
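
Each tensor argument gets the same three checks: type, then rank, then dtype. Condensed into a hypothetical helper (not part of the emle codebase) to show the shape of the idiom:

import torch

def check_tensor(name, value, ndim, dtypes):
    # Same checks as the validation above: type, then rank, then dtype.
    if not isinstance(value, torch.Tensor):
        raise TypeError(f"'{name}' must be of type 'torch.Tensor'")
    if len(value.shape) != ndim:
        raise ValueError(f"'{name}' must be a {ndim}D tensor")
    if value.dtype not in dtypes:
        raise ValueError(f"'{name}' has unsupported dtype '{value.dtype}'")

check_tensor("n_ref", torch.tensor([4, 4, 4]), 1, (torch.int64,))
check_tensor("q_core", torch.rand(3), 1, (torch.float64, torch.float32))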

+ # Validate the AEV computer.
+ if aev_computer is not None:
+     allowed_types = [_torchani.AEVComputer]
+     if _has_nnpops:
+         allowed_types.append(
+             _NNPOps.SymmetryFunctions.TorchANISymmetryFunctions
+         )
+     if not isinstance(aev_computer, tuple(allowed_types)):
+         raise TypeError(
+             "'aev_computer' must be of type 'torchani.AEVComputer' or 'NNPOps.SymmetryFunctions.TorchANISymmetryFunctions'"
+         )
+     self._aev_computer = aev_computer
+ else:
+     self._aev_computer = None

+ # Validate the AEV mask.
+ if aev_mask is not None:
+     if not isinstance(aev_mask, _torch.Tensor):
+         raise TypeError("'aev_mask' must be of type 'torch.Tensor'")
+     if len(aev_mask.shape) != 1:
+         raise ValueError("'aev_mask' must be a 1D tensor")
+     if not aev_mask.dtype == _torch.bool:
+         raise ValueError("'aev_mask' must have dtype 'torch.bool'")
+ else:
+     aev_mask = _torch.ones(ref_features.shape[2], dtype=_torch.bool)

if device is not None:
if not isinstance(device, _torch.device):
raise TypeError("'device' must be of type 'torch.device'")
@@ -94,8 +195,6 @@ def __init__(
else:
dtype = _torch.get_default_dtype()

- self._aev_computer = aev_computer

# Store model parameters as tensors.
self.a_QEq = _torch.nn.Parameter(params["a_QEq"])
self.a_Thole = _torch.nn.Parameter(params["a_Thole"])
@@ -113,8 +212,22 @@ def __init__(
)
raise ValueError(msg)

- # Create a map between species (1, 6, 8)
- # and their indices in the model (0, 1, 2).
+ # Validate the species.
+ if species is None:
+     # Use the default species.
+     species = self._species
+ if isinstance(species, (_np.ndarray, _torch.Tensor)):
+     species = species.tolist()
+ if not isinstance(species, (tuple, list)):
+     raise TypeError(
+         "'species' must be of type 'list', 'tuple', or 'numpy.ndarray'"
+     )
+ if not all(isinstance(s, int) for s in species):
+     raise TypeError("All elements of 'species' must be of type 'int'")
+ if not all(s > 0 for s in species):
+     raise ValueError("All elements of 'species' must be greater than zero")

+ # Create a map between species and their indices in the model.
species_map = _np.full(max(species) + 2, fill_value=-1, dtype=_np.int64)
for i, s in enumerate(species):
species_map[s] = i
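
The map converts atomic numbers into contiguous model indices, with -1 marking unsupported elements. Sizing the array to max(species) + 2 also leaves a final slot so that padded atomic numbers of -1 resolve back to -1 (an inference from the indexing; the diff does not state it). A small worked example:

import numpy as np

species = [1, 6, 8]  # H, C, O, for illustration only
species_map = np.full(max(species) + 2, fill_value=-1, dtype=np.int64)
for i, s in enumerate(species):
    species_map[s] = i

print(species_map[[1, 6, 8]])  # [0 1 2]
print(species_map[7])          # -1: nitrogen is unsupported here
print(species_map[-1])         # -1: padding entries stay -1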
@@ -151,10 +264,12 @@ def __init__(
self.register_buffer("_c_chi", c_chi)
self.register_buffer("_c_sqrtk", c_sqrtk)

- # Initialise an empty AEV tensor used to store the AEVs in derived classes.
+ # Initialise an empty AEV tensor used to store the AEVs in parent models.
self._aev = _torch.empty(0, dtype=dtype, device=device)

def to(self, *args, **kwargs):
+ if self._aev_computer is not None:
+     self._aev_computer = self._aev_computer.to(*args, **kwargs)
self._species_map = self._species_map.to(*args, **kwargs)
self._Kinv = self._Kinv.to(*args, **kwargs)
self._aev_mask = self._aev_mask.to(*args, **kwargs)
@@ -168,10 +283,18 @@ def to(self, *args, **kwargs):
self._c_chi = self._c_chi.to(*args, **kwargs)
self._c_sqrtk = self._c_sqrtk.to(*args, **kwargs)

+ # Check for a device type in args and update the device attribute.
+ for arg in args:
+     if isinstance(arg, _torch.device):
+         self._device = arg
+         break
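
Because the class keeps tensors in plain attributes and buffers that it moves by hand, every transfer method (to, cuda, cpu, double, float) is overridden, and the loop above additionally records the requested device. A standalone sketch of the idiom; the Tracked class is illustrative only, not the EMLE class itself:

import torch

class Tracked(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._device = torch.device("cpu")
        self._extra = torch.zeros(3)  # plain attribute, not a buffer

    def to(self, *args, **kwargs):
        # Move the hand-managed tensor, then record any target device
        # passed positionally.
        self._extra = self._extra.to(*args, **kwargs)
        for arg in args:
            if isinstance(arg, torch.device):
                self._device = arg
                break
        return super().to(*args, **kwargs)

m = Tracked().to(torch.device("cpu"))
print(m._device)  # cpu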

def cuda(self, **kwargs):
"""
Move all model parameters and buffers to CUDA memory.
"""
+ if self._aev_computer is not None:
+     self._aev_computer = self._aev_computer.cuda(**kwargs)
self._species_map = self._species_map.cuda(**kwargs)
self._Kinv = self._Kinv.cuda(**kwargs)
self._aev_mask = self._aev_mask.cuda(**kwargs)
@@ -189,6 +312,8 @@ def cuda(self, **kwargs):
"""
Move all model parameters and buffers to CPU memory.
"""
+ if self._aev_computer is not None:
+     self._aev_computer = self._aev_computer.cpu(**kwargs)
self._species_map = self._species_map.cpu(**kwargs)
self._Kinv = self._Kinv.cpu(**kwargs)
self._aev_mask = self._aev_mask.cpu(**kwargs)
@@ -206,6 +331,8 @@ def cpu(self, **kwargs):
"""
Casts all floating point model parameters and buffers to float64 precision.
"""
+ if self._aev_computer is not None:
+     self._aev_computer = self._aev_computer.double()
self._Kinv = self._Kinv.double()
self._q_core = self._q_core.double()
self._ref_features = self._ref_features.double()
@@ -221,6 +348,8 @@ def double(self):
"""
Casts all floating point model parameters and buffers to float32 precision.
"""
+ if self._aev_computer is not None:
+     self._aev_computer = self._aev_computer.float()
self._Kinv = self._Kinv.float()
self._q_core = self._q_core.float()
self._ref_features = self._ref_features.float()
@@ -266,7 +395,11 @@ def forward(self, atomic_numbers, xyz_qm, q_total):
species_id = self._species_map[atomic_numbers]

# Compute the AEVs.
- aev = self._aev_computer((species_id, xyz_qm))[1][:, :, self._aev_mask]
+ if self._aev_computer is not None:
+     aev = self._aev_computer((species_id, xyz_qm))[1][:, :, self._aev_mask]
+ # The AEVs have been pre-computed by a parent model.
+ else:
+     aev = self._aev[:, :, self._aev_mask]
aev = aev / _torch.linalg.norm(aev, ord=2, dim=2, keepdim=True)

# Compute the MBIS valence shell widths.
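The guard on self._aev_computer is the change that makes the class scriptable: TorchScript requires an attribute that may be None to be narrowed with an explicit is not None check before it is called. A minimal sketch of the pattern; the Model class, its cached tensor, and the feature shapes are illustrative stand-ins, not the EMLE classes:

import torch
from typing import Optional

class Model(torch.nn.Module):
    def __init__(self, computer: Optional[torch.nn.Module] = None):
        super().__init__()
        self._computer = computer
        self._cached = torch.empty(0)  # stand-in for the hook-filled _aev

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        computer = self._computer
        if computer is not None:
            feats = computer(x)
        else:
            # Features were pre-computed by a parent model.
            feats = self._cached
        return feats.sum()

scripted = torch.jit.script(Model(torch.nn.Linear(4, 4)))
print(scripted(torch.rand(2, 4)))

When the computer is absent, a parent model such as ANI2xEMLE fills the cached tensor through its forward hook before calling forward, which is the division of labour set up in _ani.py above.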
