diff --git a/emle/models/_ani.py b/emle/models/_ani.py index 60ba708..619e316 100644 --- a/emle/models/_ani.py +++ b/emle/models/_ani.py @@ -241,7 +241,7 @@ def hook( input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]], output: Tuple[Tensor, Tensor], ): - module._aev = output[1][0] + module._aev = output[1] else: @@ -250,7 +250,7 @@ def hook( input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]], output: _torchani.aev.SpeciesAEV, ): - module._aev = output[1][0] + module._aev = output[1] # Register the hook. self._aev_hook = self._ani2x.aev_computer.register_forward_hook(hook) @@ -351,7 +351,7 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): # Set the AEVs captured by the forward hook as an attribute of the # EMLE model. - self._emle._aev = self._ani2x.aev_computer._aev + self._emle._emle_base._aev = self._ani2x.aev_computer._aev # Get the EMLE energy components. E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 2bbc1d9..fd6c786 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -39,7 +39,7 @@ from typing import Optional, Tuple, List from . import _patches -from . import EMLEBase +from . import EMLEBase as _EMLEBase # Monkey-patch the TorchANI BuiltInModel and BuiltinEnsemble classes so that # they call self.aev_computer using args only to allow forward hooks to work @@ -330,19 +330,8 @@ def __init__( else None ), } - self._emle_base = EMLEBase( - emle_params, - self._aev_computer, - aev_mask, - species, - n_ref, - ref_features, - q_core, - alpha_mode, - device, - dtype, - ) + # Store the total charge. q_total = _torch.tensor( params.get("total_charge", 0), dtype=dtype, device=device ) @@ -359,8 +348,19 @@ def __init__( self.register_buffer("_q_total", q_total) self.register_buffer("_q_core_mm", q_core_mm) - # Initalise an empty AEV tensor to use to store the AEVs in derived classes. - self._aev = _torch.empty(0, dtype=dtype, device=device) + # Create the base EMLE model. + self._emle_base = _EMLEBase( + emle_params, + n_ref, + ref_features, + q_core, + aev_computer=self._aev_computer, + aev_mask=aev_mask, + alpha_mode=self._alpha_mode, + species=self._species, + device=device, + dtype=dtype, + ) def _to_dict(self): """ diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 4725898..5b95f9c 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -1,3 +1,25 @@ +####################################################################### +# EMLE-Engine: https://github.com/chemle/emle-engine +# +# Copyright: 2023-2024 +# +# Authors: Lester Hedges +# Kirill Zinovjev +# +# EMLE-Engine is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# EMLE-Engine is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with EMLE-Engine. If not, see . +##################################################################### + import numpy as _np import torch as _torch @@ -5,6 +27,17 @@ from torch import Tensor from typing import Tuple +import torchani as _torchani + +try: + import NNPOps as _NNPOps + + _NNPOps.OptimizedTorchANI = _patches.OptimizedTorchANI + + _has_nnpops = True +except: + _has_nnpops = False + class EMLEBase(_torch.nn.Module): """ @@ -14,15 +47,18 @@ class EMLEBase(_torch.nn.Module): electrostating embedding energies using the EMLE model. """ + # Store the list of supported species. + _species = [1, 6, 7, 8, 16] + def __init__( self, params, - aev_computer, - aev_mask, - species, n_ref, ref_features, q_core, + aev_computer=None, + aev_mask=None, + species=None, alpha_mode="species", device=None, dtype=None, @@ -34,26 +70,16 @@ def __init__( ---------- params: dict - EMLE model parameters - - aev_computer: torchani.AEVComputer - AEV computer instance used to compute AEVs. - - aev_mask: torch.Tensor - Mask for features coming from aev_computer. - - species: List[int], Tuple[int], numpy.ndarray, torch.Tensor - List of species (atomic numbers) supported by the EMLE model. If - None, then the default species list will be used. + EMLE model parameters. n_ref: torch.Tensor number of GPR references for each element in species list ref_features: torch.Tensor - Feature vectors for GPR references + Feature vectors for GPR references. q_core: torch.Tensor - Core charges for each element in species list + Core charges for each element in species list. alpha_mode: str How atomic polarizabilities are calculated. @@ -63,6 +89,15 @@ def __init__( scaling factors are obtained with GPR using the values learned for each reference environment + aev_computer: torchani.AEVComputer + AEV computer instance used to compute AEVs. + + aev_mask: torch.Tensor + Mask for features coming from aev_computer. + + species: List[int], Tuple[int], numpy.ndarray, torch.Tensor + List of species (atomic numbers) supported by the EMLE model. + device: torch.device The device on which to run the model. @@ -73,6 +108,46 @@ def __init__( # Call the base class constructor. super().__init__() + # Validate the parameters. + if not isinstance(params, dict): + raise TypeError("'params' must be of type 'dict'") + if not all( + k in params + for k in ["a_QEq", "a_Thole", "ref_values_s", "ref_values_chi", "k_Z"] + ): + raise ValueError( + "'params' must contain keys 'a_QEq', 'a_Thole', 'ref_values_s', 'ref_values_chi', and 'k_Z'" + ) + + # Validate the number of references. + if not isinstance(n_ref, _torch.Tensor): + raise TypeError("'n_ref' must be of type 'torch.Tensor'") + if len(n_ref.shape) != 1: + raise ValueError("'n_ref' must be a 1D tensor") + if not n_ref.dtype == _torch.int64: + raise ValueError("'n_ref' must have dtype 'torch.int64'") + + # Validate the reference features. + if not isinstance(ref_features, _torch.Tensor): + raise TypeError("'ref_features' must be of type 'torch.Tensor'") + if len(ref_features.shape) != 3: + raise ValueError("'ref_features' must be a 3D tensor") + if not ref_features.dtype in (_torch.float64, _torch.float32): + raise ValueError( + "'ref_features' must have dtype 'torch.float64' or 'torch.float32'" + ) + + # Validate the core charges. + if not isinstance(q_core, _torch.Tensor): + raise TypeError("'q_core' must be of type 'torch.Tensor'") + if len(q_core.shape) != 1: + raise ValueError("'q_core' must be a 1D tensor") + if not q_core.dtype in (_torch.float64, _torch.float32): + raise ValueError( + "'q_core' must have dtype 'torch.float64' or 'torch.float32'" + ) + + # Validate the alpha mode. if alpha_mode is None: alpha_mode = "species" if not isinstance(alpha_mode, str): @@ -82,6 +157,32 @@ def __init__( raise ValueError("'alpha_mode' must be 'species' or 'reference'") self._alpha_mode = alpha_mode + # Validate the AEV computer. + if aev_computer is not None: + allowed_types = [_torchani.AEVComputer] + if _has_nnpops: + allowed_types.append( + _NNPOps.SymmetryFunctions.TorchANISymmetryFunctions + ) + if not isinstance(aev_computer, tuple(allowed_types)): + raise TypeError( + "'aev_computer' must be of type 'torchani.AEVComputer' or 'NNPOps.SymmetryFunctions.TorchANISymmetryFunctions'" + ) + self._aev_computer = aev_computer + else: + self._aev_computer = None + + # Validate the AEV mask. + if aev_mask is not None: + if not isinstance(aev_mask, _torch.Tensor): + raise TypeError("'aev_mask' must be of type 'torch.Tensor'") + if len(aev_mask.shape) != 1: + raise ValueError("'aev_mask' must be a 1D tensor") + if not aev_mask.dtype == _torch.bool: + raise ValueError("'aev_mask' must have dtype 'torch.bool'") + else: + aev_mask = _torch.ones(ref_features.shape[2], dtype=_torch.bool) + if device is not None: if not isinstance(device, _torch.device): raise TypeError("'device' must be of type 'torch.device'") @@ -94,8 +195,6 @@ def __init__( else: dtype = _torch.get_default_dtype() - self._aev_computer = aev_computer - # Store model parameters as tensors. self.a_QEq = _torch.nn.Parameter(params["a_QEq"]) self.a_Thole = _torch.nn.Parameter(params["a_Thole"]) @@ -113,8 +212,22 @@ def __init__( ) raise ValueError(msg) - # Create a map between species (1, 6, 8) - # and their indices in the model (0, 1, 2). + # Validate the species. + if species is None: + # Use the default species. + species = self._species + if isinstance(species, (_np.ndarray, _torch.Tensor)): + species = species.tolist() + if not isinstance(species, (tuple, list)): + raise TypeError( + "'species' must be of type 'list', 'tuple', or 'numpy.ndarray'" + ) + if not all(isinstance(s, int) for s in species): + raise TypeError("All elements of 'species' must be of type 'int'") + if not all(s > 0 for s in species): + raise ValueError("All elements of 'species' must be greater than zero") + + # Create a map between species and their indices in the model. species_map = _np.full(max(species) + 2, fill_value=-1, dtype=_np.int64) for i, s in enumerate(species): species_map[s] = i @@ -151,10 +264,12 @@ def __init__( self.register_buffer("_c_chi", c_chi) self.register_buffer("_c_sqrtk", c_sqrtk) - # Initalise an empty AEV tensor to use to store the AEVs in derived classes. + # Initalise an empty AEV tensor to use to store the AEVs in parent models. self._aev = _torch.empty(0, dtype=dtype, device=device) def to(self, *args, **kwargs): + if self._aev_computer is not None: + self._aev_computer = self._aev_computer.to(*args, **kwargs) self._species_map = self._species_map.to(*args, **kwargs) self._Kinv = self._Kinv.to(*args, **kwargs) self._aev_mask = self._aev_mask.to(*args, **kwargs) @@ -168,10 +283,18 @@ def to(self, *args, **kwargs): self._c_chi = self._c_chi.to(*args, **kwargs) self._c_sqrtk = self._c_sqrtk.to(*args, **kwargs) + # Check for a device type in args and update the device attribute. + for arg in args: + if isinstance(arg, _torch.device): + self._device = arg + break + def cuda(self, **kwargs): """ Move all model parameters and buffers to CUDA memory. """ + if self._aev_computer is not None: + self._aev_computer = self._aev_computer.cuda(**kwargs) self._species_map = self._species_map.cuda(**kwargs) self._Kinv = self._Kinv.cuda(**kwargs) self._aev_mask = self._aev_mask.cuda(**kwargs) @@ -189,6 +312,8 @@ def cpu(self, **kwargs): """ Move all model parameters and buffers to CPU memory. """ + if self._aev_computer is not None: + self._aev_computer = self._aev_computer.cpu(**kwargs) self._species_map = self._species_map.cpu(**kwargs) self._Kinv = self._Kinv.cpu(**kwargs) self._aev_mask = self._aev_mask.cpu(**kwargs) @@ -206,6 +331,8 @@ def double(self): """ Casts all floating point model parameters and buffers to float64 precision. """ + if self._aev_computer is not None: + self._aev_computer = self._aev_computer.double() self._Kinv = self._Kinv.double() self._q_core = self._q_core.double() self._ref_features = self._ref_features.double() @@ -221,6 +348,8 @@ def float(self): """ Casts all floating point model parameters and buffers to float32 precision. """ + if self._aev_computer is not None: + self._aev_computer = self._aev_computer.float() self._Kinv = self._Kinv.float() self._q_core = self._q_core.float() self._ref_features = self._ref_features.float() @@ -266,7 +395,11 @@ def forward(self, atomic_numbers, xyz_qm, q_total): species_id = self._species_map[atomic_numbers] # Compute the AEVs. - aev = self._aev_computer((species_id, xyz_qm))[1][:, :, self._aev_mask] + if self._aev_computer is not None: + aev = self._aev_computer((species_id, xyz_qm))[1][:, :, self._aev_mask] + # The AEVs have been pre-computed by a parent model. + else: + aev = self._aev[:, :, self._aev_mask] aev = aev / _torch.linalg.norm(aev, ord=2, dim=2, keepdim=True) # Compute the MBIS valence shell widths.