Fixes to allow TorchScript plus refactoring and input validation.
Showing 3 changed files with 173 additions and 40 deletions.
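The headline of this commit is the new input validation: constructor arguments now fail fast with descriptive errors instead of failing deep inside the model. A minimal sketch of the new behaviour (the emle.models import path and the argument values are assumptions for illustration, not part of this commit):

import torch

from emle.models import EMLEBase  # import path assumed

# 'params' is validated first, so a bad type is rejected immediately.
try:
    EMLEBase(params="not a dict", n_ref=None, ref_features=None, q_core=None)
except TypeError as e:
    print(e)  # 'params' must be of type 'dict'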
@@ -1,10 +1,43 @@
#######################################################################
# EMLE-Engine: https://github.com/chemle/emle-engine
#
# Copyright: 2023-2024
#
# Authors: Lester Hedges <[email protected]>
#          Kirill Zinovjev <[email protected]>
#
# EMLE-Engine is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# EMLE-Engine is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with EMLE-Engine. If not, see <http://www.gnu.org/licenses/>.
#######################################################################

import numpy as _np

import torch as _torch

from torch import Tensor
from typing import Tuple

import torchani as _torchani

try:
    import NNPOps as _NNPOps

    _NNPOps.OptimizedTorchANI = _patches.OptimizedTorchANI

    _has_nnpops = True
except:
    _has_nnpops = False


class EMLEBase(_torch.nn.Module):
    """
@@ -14,15 +47,18 @@ class EMLEBase(_torch.nn.Module):
    electrostatic embedding energies using the EMLE model.
    """

    # Store the list of supported species.
    _species = [1, 6, 7, 8, 16]

    def __init__(
        self,
        params,
        aev_computer,
        aev_mask,
        species,
        n_ref,
        ref_features,
        q_core,
        aev_computer=None,
        aev_mask=None,
        species=None,
        alpha_mode="species",
        device=None,
        dtype=None,
@@ -34,26 +70,16 @@ def __init__(
        ----------
        params: dict
            EMLE model parameters
        aev_computer: torchani.AEVComputer
            AEV computer instance used to compute AEVs.
        aev_mask: torch.Tensor
            Mask for features coming from aev_computer.
        species: List[int], Tuple[int], numpy.ndarray, torch.Tensor
            List of species (atomic numbers) supported by the EMLE model. If
            None, then the default species list will be used.
            EMLE model parameters.
        n_ref: torch.Tensor
            number of GPR references for each element in species list
        ref_features: torch.Tensor
            Feature vectors for GPR references
            Feature vectors for GPR references.
        q_core: torch.Tensor
            Core charges for each element in species list
            Core charges for each element in species list.
        alpha_mode: str
            How atomic polarizabilities are calculated.
@@ -63,6 +89,15 @@ def __init__(
                scaling factors are obtained with GPR using the values learned
                for each reference environment
        aev_computer: torchani.AEVComputer
            AEV computer instance used to compute AEVs.
        aev_mask: torch.Tensor
            Mask for features coming from aev_computer.
        species: List[int], Tuple[int], numpy.ndarray, torch.Tensor
            List of species (atomic numbers) supported by the EMLE model.
        device: torch.device
            The device on which to run the model.
@@ -73,6 +108,46 @@ def __init__(
        # Call the base class constructor.
        super().__init__()

        # Validate the parameters.
        if not isinstance(params, dict):
            raise TypeError("'params' must be of type 'dict'")
        if not all(
            k in params
            for k in ["a_QEq", "a_Thole", "ref_values_s", "ref_values_chi", "k_Z"]
        ):
            raise ValueError(
                "'params' must contain keys 'a_QEq', 'a_Thole', 'ref_values_s', 'ref_values_chi', and 'k_Z'"
            )

        # Validate the number of references.
        if not isinstance(n_ref, _torch.Tensor):
            raise TypeError("'n_ref' must be of type 'torch.Tensor'")
        if len(n_ref.shape) != 1:
            raise ValueError("'n_ref' must be a 1D tensor")
        if not n_ref.dtype == _torch.int64:
            raise ValueError("'n_ref' must have dtype 'torch.int64'")

        # Validate the reference features.
        if not isinstance(ref_features, _torch.Tensor):
            raise TypeError("'ref_features' must be of type 'torch.Tensor'")
        if len(ref_features.shape) != 3:
            raise ValueError("'ref_features' must be a 3D tensor")
        if not ref_features.dtype in (_torch.float64, _torch.float32):
            raise ValueError(
                "'ref_features' must have dtype 'torch.float64' or 'torch.float32'"
            )

        # Validate the core charges.
        if not isinstance(q_core, _torch.Tensor):
            raise TypeError("'q_core' must be of type 'torch.Tensor'")
        if len(q_core.shape) != 1:
            raise ValueError("'q_core' must be a 1D tensor")
        if not q_core.dtype in (_torch.float64, _torch.float32):
            raise ValueError(
                "'q_core' must have dtype 'torch.float64' or 'torch.float32'"
            )

        # Validate the alpha mode.
        if alpha_mode is None:
            alpha_mode = "species"
        if not isinstance(alpha_mode, str):
@@ -82,6 +157,32 @@ def __init__(
            raise ValueError("'alpha_mode' must be 'species' or 'reference'")
        self._alpha_mode = alpha_mode

        # Validate the AEV computer.
        if aev_computer is not None:
            allowed_types = [_torchani.AEVComputer]
            if _has_nnpops:
                allowed_types.append(
                    _NNPOps.SymmetryFunctions.TorchANISymmetryFunctions
                )
            if not isinstance(aev_computer, tuple(allowed_types)):
                raise TypeError(
                    "'aev_computer' must be of type 'torchani.AEVComputer' or 'NNPOps.SymmetryFunctions.TorchANISymmetryFunctions'"
                )
            self._aev_computer = aev_computer
        else:
            self._aev_computer = None

        # Validate the AEV mask.
        if aev_mask is not None:
            if not isinstance(aev_mask, _torch.Tensor):
                raise TypeError("'aev_mask' must be of type 'torch.Tensor'")
            if len(aev_mask.shape) != 1:
                raise ValueError("'aev_mask' must be a 1D tensor")
            if not aev_mask.dtype == _torch.bool:
                raise ValueError("'aev_mask' must have dtype 'torch.bool'")
        else:
            aev_mask = _torch.ones(ref_features.shape[2], dtype=_torch.bool)

        if device is not None:
            if not isinstance(device, _torch.device):
                raise TypeError("'device' must be of type 'torch.device'")
@@ -94,8 +195,6 @@ def __init__(
        else:
            dtype = _torch.get_default_dtype()

        self._aev_computer = aev_computer

        # Store model parameters as tensors.
        self.a_QEq = _torch.nn.Parameter(params["a_QEq"])
        self.a_Thole = _torch.nn.Parameter(params["a_Thole"])
@@ -113,8 +212,22 @@ def __init__(
            )
            raise ValueError(msg)

        # Create a map between species (1, 6, 8)
        # and their indices in the model (0, 1, 2).
        # Validate the species.
        if species is None:
            # Use the default species.
            species = self._species
        if isinstance(species, (_np.ndarray, _torch.Tensor)):
            species = species.tolist()
        if not isinstance(species, (tuple, list)):
            raise TypeError(
                "'species' must be of type 'list', 'tuple', or 'numpy.ndarray'"
            )
        if not all(isinstance(s, int) for s in species):
            raise TypeError("All elements of 'species' must be of type 'int'")
        if not all(s > 0 for s in species):
            raise ValueError("All elements of 'species' must be greater than zero")

        # Create a map between species and their indices in the model.
        species_map = _np.full(max(species) + 2, fill_value=-1, dtype=_np.int64)
        for i, s in enumerate(species):
            species_map[s] = i
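The species map built above sends atomic numbers to contiguous model indices, with -1 marking unsupported elements. A standalone sketch of the same logic, using the default species list:

import numpy as np

species = [1, 6, 7, 8, 16]  # default supported species (H, C, N, O, S)
species_map = np.full(max(species) + 2, fill_value=-1, dtype=np.int64)
for i, s in enumerate(species):
    species_map[s] = i

print(species_map[1])   # 0  (hydrogen -> first model index)
print(species_map[16])  # 4  (sulfur -> fifth model index)
print(species_map[2])   # -1 (helium is unsupported)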
@@ -151,10 +264,12 @@ def __init__(
        self.register_buffer("_c_chi", c_chi)
        self.register_buffer("_c_sqrtk", c_sqrtk)

        # Initialise an empty AEV tensor to use to store the AEVs in derived classes.
        # Initialise an empty AEV tensor to use to store the AEVs in parent models.
        self._aev = _torch.empty(0, dtype=dtype, device=device)

    def to(self, *args, **kwargs):
        if self._aev_computer is not None:
            self._aev_computer = self._aev_computer.to(*args, **kwargs)
        self._species_map = self._species_map.to(*args, **kwargs)
        self._Kinv = self._Kinv.to(*args, **kwargs)
        self._aev_mask = self._aev_mask.to(*args, **kwargs)
@@ -168,10 +283,18 @@ def to(self, *args, **kwargs):
        self._c_chi = self._c_chi.to(*args, **kwargs)
        self._c_sqrtk = self._c_sqrtk.to(*args, **kwargs)

        # Check for a device type in args and update the device attribute.
        for arg in args:
            if isinstance(arg, _torch.device):
                self._device = arg
                break

    def cuda(self, **kwargs):
        """
        Move all model parameters and buffers to CUDA memory.
        """
        if self._aev_computer is not None:
            self._aev_computer = self._aev_computer.cuda(**kwargs)
        self._species_map = self._species_map.cuda(**kwargs)
        self._Kinv = self._Kinv.cuda(**kwargs)
        self._aev_mask = self._aev_mask.cuda(**kwargs)
@@ -189,6 +312,8 @@ def cpu(self, **kwargs):
        """
        Move all model parameters and buffers to CPU memory.
        """
        if self._aev_computer is not None:
            self._aev_computer = self._aev_computer.cpu(**kwargs)
        self._species_map = self._species_map.cpu(**kwargs)
        self._Kinv = self._Kinv.cpu(**kwargs)
        self._aev_mask = self._aev_mask.cpu(**kwargs)
@@ -206,6 +331,8 @@ def double(self):
        """
        Casts all floating point model parameters and buffers to float64 precision.
        """
        if self._aev_computer is not None:
            self._aev_computer = self._aev_computer.double()
        self._Kinv = self._Kinv.double()
        self._q_core = self._q_core.double()
        self._ref_features = self._ref_features.double()
@@ -221,6 +348,8 @@ def float(self):
        """
        Casts all floating point model parameters and buffers to float32 precision.
        """
        if self._aev_computer is not None:
            self._aev_computer = self._aev_computer.float()
        self._Kinv = self._Kinv.float()
        self._q_core = self._q_core.float()
        self._ref_features = self._ref_features.float()
@@ -266,7 +395,11 @@ def forward(self, atomic_numbers, xyz_qm, q_total):
        species_id = self._species_map[atomic_numbers]

        # Compute the AEVs.
        aev = self._aev_computer((species_id, xyz_qm))[1][:, :, self._aev_mask]
        if self._aev_computer is not None:
            aev = self._aev_computer((species_id, xyz_qm))[1][:, :, self._aev_mask]
        # The AEVs have been pre-computed by a parent model.
        else:
            aev = self._aev[:, :, self._aev_mask]
        aev = aev / _torch.linalg.norm(aev, ord=2, dim=2, keepdim=True)

        # Compute the MBIS valence shell widths.