From 7797084abaee464e4dbf7eab2b139cc90f0e4376 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 12:38:57 +0200 Subject: [PATCH 01/39] Import logger in models/_emle.py --- emle/models/_emle.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 5dc8a37..3169c2f 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -27,6 +27,8 @@ __all__ = ["EMLE"] +from loguru import logger as _logger + import numpy as _np import os as _os import scipy.io as _scipy_io From 885f99c257262fcd2eeea4c8105ecadee9008ab9 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 12:39:58 +0200 Subject: [PATCH 02/39] Refactor EMLE._get_mu_ind to receive A_thole as argument --- emle/models/_emle.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 3169c2f..9c71abb 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -642,7 +642,8 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): # Compute the induced energy. if self._method == "electrostatic": - mu_ind = self._get_mu_ind(r_data, mesh_data, charges_mm, s, q_val, k) + A_thole = self._get_A_thole(r_data, s, q_val, k) + mu_ind = self._get_mu_ind(A_thole, mesh_data, charges_mm, s) vpot_ind = self._get_vpot_mu(mu_ind, mesh_data[2]) E_ind = _torch.sum(vpot_ind @ charges_mm) * 0.5 else: @@ -795,12 +796,10 @@ def _get_A_QEq(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s): def _get_mu_ind( self, - r_data: Tuple[Tensor, Tensor, Tensor, Tensor], + A, mesh_data: Tuple[Tensor, Tensor, Tensor], q, s, - q_val, - k, ): """ Internal method, calculates induced atomic dipoles @@ -809,7 +808,8 @@ def _get_mu_ind( Parameters ---------- - r_data: r_data object (output of self._get_r_data) + A: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) + The A matrix for induced dipoles prediction. 
mesh_data: mesh_data object (output of self._get_mesh_data) @@ -822,16 +822,12 @@ def _get_mu_ind( q_val: torch.Tensor (N_QM_ATOMS,) MBIS valence charges. - k: torch.Tensor (N_Z) - Scaling factors for polarizabilities. - Returns ------- result: torch.Tensor (N_ATOMS, 3) Array of induced dipoles """ - A = self._get_A_thole(r_data, s, q_val, k) r = 1.0 / mesh_data[0] f1 = self._get_f1_slater(r, s[:, None] * 2.0) From 0bd8b8c1540f3ee3604154f0fb245728b1a30520 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 15:29:24 +0200 Subject: [PATCH 03/39] Move core EMLE model functionality to EMLEBase class --- emle/models/__init__.py | 1 + emle/models/_emle.py | 499 +-------------------------------- emle/models/_emle_base.py | 570 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 585 insertions(+), 485 deletions(-) create mode 100644 emle/models/_emle_base.py diff --git a/emle/models/__init__.py b/emle/models/__init__.py index e3cf193..9bb6ea3 100644 --- a/emle/models/__init__.py +++ b/emle/models/__init__.py @@ -25,6 +25,7 @@ # avoid severe module import overheads when running the client code, # which requires no EMLE functionality. +from ._emle_base import EMLEBase from ._emle import EMLE from ._ani import ANI2xEMLE from ._mace import MACEEMLE diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 9c71abb..b808c1c 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -39,6 +39,7 @@ from typing import Optional, Tuple, List from . import _patches +from . import EMLEBase # Monkey-patch the TorchANI BuiltInModel and BuiltinEnsemble classes so that # they call self.aev_computer using args only to allow forward hooks to work @@ -306,100 +307,19 @@ def __init__( except: raise IOError(f"Unable to load model parameters from: '{model}'") - # Create a map between species and their indices. 
- species_map = _np.full(max(species) + 1, fill_value=-1, dtype=_np.int64) - for i, s in enumerate(species): - species_map[s] = i + self._emle_base = EMLEBase(params, self._aev_computer, species, + alpha_mode, device, dtype) - # Convert to a tensor. - species_map = _torch.tensor(species_map, dtype=_torch.int64, device=device) - - # Store model parameters as tensors. - aev_mask = _torch.tensor(params["aev_mask"], dtype=_torch.bool, device=device) - q_core = _torch.tensor(params["q_core"], dtype=dtype, device=device) if method == "mm": q_core_mm = _torch.tensor(mm_charges, dtype=dtype, device=device) else: q_core_mm = _torch.empty(0, dtype=dtype, device=device) - a_QEq = _torch.tensor(params["a_QEq"], dtype=dtype, device=device) - a_Thole = _torch.tensor(params["a_Thole"], dtype=dtype, device=device) - if self._alpha_mode == "species": - try: - k = _torch.tensor(params["k_Z"], dtype=dtype, device=device) - except: - msg = ( - "Missing 'k_Z' key in model. This is required when " - "using 'species' alpha mode." - ) - raise ValueError(msg) - else: - try: - k = _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device) - except: - msg = ( - "Missing 'sqrtk_ref' key in model. This is required when " - "using 'reference' alpha mode." - ) - raise ValueError(msg) - - q_total = _torch.tensor( - params.get("total_charge", 0), dtype=dtype, device=device - ) - - # Extract the reference features. - ref_features = _torch.tensor(params["ref_aev"], dtype=dtype, device=device) - - # Extract the reference values for the MBIS valence shell widths. - ref_values_s = _torch.tensor(params["s_ref"], dtype=dtype, device=device) - - # Compute the inverse of the K matrix. - Kinv = self._get_Kinv(ref_features, 1e-3) - - # Store additional attributes for the MBIS GPR model. 
- n_ref = _torch.tensor(params["n_ref"], dtype=_torch.int64, device=device) - ref_mean_s = _torch.sum(ref_values_s, dim=1) / n_ref - ref_shifted = ref_values_s - ref_mean_s[:, None] - c_s = (Kinv @ ref_shifted[:, :, None]).squeeze() - - # Extract the reference values for the electronegativities. - ref_values_chi = _torch.tensor(params["chi_ref"], dtype=dtype, device=device) - - # Store additional attributes for the electronegativity GPR model. - ref_mean_chi = _torch.sum(ref_values_chi, dim=1) / n_ref - ref_shifted = ref_values_chi - ref_mean_chi[:, None] - c_chi = (Kinv @ ref_shifted[:, :, None]).squeeze() - - # Extract the reference values for the polarizabilities. - if self._alpha_mode == "reference": - ref_mean_k = _torch.sum(k, dim=1) / n_ref - ref_shifted = k - ref_mean_k[:, None] - c_k = (Kinv @ ref_shifted[:, :, None]).squeeze() - else: - ref_mean_k = _torch.empty(0, dtype=dtype, device=device) - c_k = _torch.empty(0, dtype=dtype, device=device) # Store the current device. self._device = device # Register constants as buffers. - self.register_buffer("_species_map", species_map) - self.register_buffer("_aev_mask", aev_mask) - self.register_buffer("_q_core", q_core) self.register_buffer("_q_core_mm", q_core_mm) - self.register_buffer("_a_QEq", a_QEq) - self.register_buffer("_a_Thole", a_Thole) - self.register_buffer("_k", k) - self.register_buffer("_q_total", q_total) - self.register_buffer("_ref_features", ref_features) - self.register_buffer("_n_ref", n_ref) - self.register_buffer("_ref_values_s", ref_values_s) - self.register_buffer("_ref_values_chi", ref_values_chi) - self.register_buffer("_ref_mean_s", ref_mean_s) - self.register_buffer("_ref_mean_chi", ref_mean_chi) - self.register_buffer("_c_s", c_s) - self.register_buffer("_c_chi", c_chi) - self.register_buffer("_ref_mean_k", ref_mean_k) - self.register_buffer("_c_k", c_k) # Initalise an empty AEV tensor to use to store the AEVs in derived classes. 
self._aev = _torch.empty(0, dtype=dtype, device=device) @@ -421,24 +341,7 @@ def to(self, *args, **kwargs): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.to(*args, **kwargs) - self._species_map = self._species_map.to(*args, **kwargs) - self._aev_mask = self._aev_mask.to(*args, **kwargs) - self._q_core = self._q_core.to(*args, **kwargs) - self._q_core_mm = self._q_core_mm.to(*args, **kwargs) - self._a_QEq = self._a_QEq.to(*args, **kwargs) - self._a_Thole = self._a_Thole.to(*args, **kwargs) - self._k = self._k.to(*args, **kwargs) - self._q_total = self._q_total.to(*args, **kwargs) - self._ref_features = self._ref_features.to(*args, **kwargs) - self._n_ref = self._n_ref.to(*args, **kwargs) - self._ref_values_s = self._ref_values_s.to(*args, **kwargs) - self._ref_values_chi = self._ref_values_chi.to(*args, **kwargs) - self._ref_mean_s = self._ref_mean_s.to(*args, **kwargs) - self._ref_mean_chi = self._ref_mean_chi.to(*args, **kwargs) - self._c_s = self._c_s.to(*args, **kwargs) - self._c_chi = self._c_chi.to(*args, **kwargs) - self._ref_mean_k = self._ref_mean_k.to(*args, **kwargs) - self._c_k = self._c_k.to(*args, **kwargs) + self._emle_base = self._emle_base.to(*args, **kwargs) # Check for a device type in args and update the device attribute. 
for arg in args: @@ -454,24 +357,7 @@ def cuda(self, **kwargs): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.cuda(**kwargs) - self._species_map = self._species_map.cuda(**kwargs) - self._aev_mask = self._aev_mask.cuda(**kwargs) - self._q_core = self._q_core.cuda(**kwargs) - self._q_core_mm = self._q_core_mm.cuda(**kwargs) - self._a_QEq = self._a_QEq.cuda(**kwargs) - self._a_Thole = self._a_Thole.cuda(**kwargs) - self._k = self._k.cuda(**kwargs) - self._q_total = self._q_total.cuda(**kwargs) - self._ref_features = self._ref_features.cuda(**kwargs) - self._n_ref = self._n_ref.cuda(**kwargs) - self._ref_values_s = self._ref_values_s.cuda(**kwargs) - self._ref_values_chi = self._ref_values_chi.cuda(**kwargs) - self._ref_mean_s = self._ref_mean_s.cuda(**kwargs) - self._ref_mean_chi = self._ref_mean_chi.cuda(**kwargs) - self._c_s = self._c_s.cuda(**kwargs) - self._c_chi = self._c_chi.cuda(**kwargs) - self._ref_mean_k = self._ref_mean_k.cuda(**kwargs) - self._c_k = self._c_k.cuda(**kwargs) + self._emle_base = self._emle_base.cuda(**kwargs) # Update the device attribute. 
self._device = self._species_map.device @@ -484,24 +370,7 @@ def cpu(self, **kwargs): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.cpu(**kwargs) - self._species_map = self._species_map.cpu(**kwargs) - self._aev_mask = self._aev_mask.cpu(**kwargs) - self._q_core = self._q_core.cpu(**kwargs) - self._q_core_mm = self._q_core_mm.cpu(**kwargs) - self._a_QEq = self._a_QEq.cpu(**kwargs) - self._a_Thole = self._a_Thole.cpu(**kwargs) - self._k = self._k.cpu(**kwargs) - self._q_total = self._q_total.cpu(**kwargs) - self._ref_features = self._ref_features.cpu(**kwargs) - self._n_ref = self._n_ref.cpu(**kwargs) - self._ref_values_s = self._ref_values_s.cpu(**kwargs) - self._ref_values_chi = self._ref_values_chi.cpu(**kwargs) - self._ref_mean_s = self._ref_mean_s.cpu(**kwargs) - self._ref_mean_chi = self._ref_mean_chi.cpu(**kwargs) - self._c_s = self._c_s.cpu(**kwargs) - self._c_chi = self._c_chi.cpu(**kwargs) - self._ref_mean_k = self._ref_mean_k.cpu(**kwargs) - self._c_k = self._c_k.cpu(**kwargs) + self._emle_base = self._emle_base.cpu() # Update the device attribute. 
self._device = self._species_map.device @@ -514,21 +383,7 @@ def double(self): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.double() - self._q_core = self._q_core.double() - self._q_core_mm = self._q_core_mm.double() - self._a_QEq = self._a_QEq.double() - self._a_Thole = self._a_Thole.double() - self._k = self._k.double() - self._q_total = self._q_total.double() - self._ref_features = self._ref_features.double() - self._ref_values_s = self._ref_values_s.double() - self._ref_values_chi = self._ref_values_chi.double() - self._ref_mean_s = self._ref_mean_s.double() - self._ref_mean_chi = self._ref_mean_chi.double() - self._c_s = self._c_s.double() - self._c_chi = self._c_chi.double() - self._ref_mean_k = self._ref_mean_k.double() - self._c_k = self._c_k.double() + self._emle_base = self._emle_base.double() return self def float(self): @@ -537,21 +392,7 @@ def float(self): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.float() - self._q_core = self._q_core.float() - self._q_core_mm = self._q_core_mm.float() - self._a_QEq = self._a_QEq.float() - self._a_Thole = self._a_Thole.float() - self._k = self._k.float() - self._q_total = self._q_total.float() - self._ref_features = self._ref_features.float() - self._ref_values_s = self._ref_values_s.float() - self._ref_values_chi = self._ref_values_chi.float() - self._ref_mean_s = self._ref_mean_s.float() - self._ref_mean_chi = self._ref_mean_chi.float() - self._c_s = self._c_s.float() - self._c_chi = self._c_chi.float() - self._ref_mean_k = self._ref_mean_k.float() - self._c_k = self._c_k.float() + self._emle_base = self._emle_base.float() return self def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): @@ -584,28 +425,7 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): if len(xyz_mm) == 0: return _torch.zeros(2, dtype=xyz_qm.dtype, device=xyz_qm.device) - # Convert the atomic numbers to species IDs. 
- species_id = self._species_map[atomic_numbers] - - # Reshape the IDs. - zid = species_id.unsqueeze(0) - - # Reshape the atomic positions. - xyz = xyz_qm.unsqueeze(0) - - # Compute the AEVs. - if self._aev_computer is not None: - aev = self._aev_computer((zid, xyz))[1][0][:, self._aev_mask] - # The AEVs have been pre-computed by a derived class. - else: - aev = self._aev[:, self._aev_mask] - aev = aev / _torch.linalg.norm(aev, ord=2, dim=1, keepdim=True) - - # Compute the MBIS valence shell widths. - s = self._gpr(aev, self._ref_mean_s, self._c_s, species_id) - - # Compute the electronegativities. - chi = self._gpr(aev, self._ref_mean_chi, self._c_chi, species_id) + s, q_core, q_val, A_thole = self._emle_base(atomic_numbers, xyz_qm) # Convert coordinates to Bohr. ANGSTROM_TO_BOHR = 1.8897261258369282 @@ -613,25 +433,15 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): xyz_mm_bohr = xyz_mm * ANGSTROM_TO_BOHR # Compute the static energy. - if self._method != "mm": - q_core = self._q_core[species_id] - else: + if self._method == "mm": q_core = self._q_core_mm - if self._alpha_mode == "species": - k = self._k[species_id] - else: - k = self._gpr(aev, self._ref_mean_k, self._c_k, species_id) ** 2 - r_data = self._get_r_data(xyz_qm_bohr) - mesh_data = self._get_mesh_data(xyz_qm_bohr, xyz_mm_bohr, s) - if self._method in ["electrostatic", "nonpol"]: - q = self._get_q(r_data, s, chi) - q_val = q - q_core - elif self._method == "mechanical": - q_core = self._get_q(r_data, s, chi) q_val = _torch.zeros_like( q_core, dtype=charges_mm.dtype, device=self._device ) - else: + + mesh_data = self._get_mesh_data(xyz_qm_bohr, xyz_mm_bohr, s) + if self._method == "mechanical": + q_core = q_core + q_val q_val = _torch.zeros_like( q_core, dtype=charges_mm.dtype, device=self._device ) @@ -642,7 +452,6 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): # Compute the induced energy. 
if self._method == "electrostatic": - A_thole = self._get_A_thole(r_data, s, q_val, k) mu_ind = self._get_mu_ind(A_thole, mesh_data, charges_mm, s) vpot_ind = self._get_vpot_mu(mu_ind, mesh_data[2]) E_ind = _torch.sum(vpot_ind @ charges_mm) * 0.5 @@ -651,32 +460,6 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): return _torch.stack([E_static, E_ind]) - @classmethod - def _get_Kinv(cls, ref_features, sigma): - """ - Internal function to compute the inverse of the K matrix for GPR. - - Parameters - ---------- - - ref_features: torch.Tensor (N_Z, MAX_N_REF, N_FEAT) - The basis feature vectors for each species. - - sigma: float - The uncertainty of the observations (regularizer). - - Returns - ------- - - result: torch.Tensor (MAX_N_REF, MAX_N_REF) - The inverse of the K matrix. - """ - n = ref_features.shape[1] - K = (ref_features @ ref_features.swapaxes(1, 2)) ** 2 - return _torch.linalg.inv( - K + sigma**2 * _torch.eye(n, dtype=ref_features.dtype, device=K.device) - ) - def _gpr(self, mol_features, ref_mean, c, zid): """ Internal method to predict a property using Gaussian Process Regression. @@ -717,83 +500,6 @@ def _gpr(self, mol_features, ref_mean, c, zid): return result - def _get_q(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, chi): - """ - Internal method that predicts MBIS charges - (Eq. 16 in 10.1021/acs.jctc.2c00914) - - Parameters - ---------- - - r_data: r_data object (output of self._get_r_data) - - s: torch.Tensor (N_ATOMS,) - MBIS valence shell widths. - - chi: torch.Tensor (N_ATOMS,) - Electronegativities. - - Returns - ------- - - result: torch.Tensor (N_ATOMS,) - Predicted MBIS charges. - """ - A = self._get_A_QEq(r_data, s) - b = _torch.hstack([-chi, self._q_total]) - return _torch.linalg.solve(A, b)[:-1] - - def _get_A_QEq(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s): - """ - Internal method, generates A matrix for charge prediction - (Eq. 
16 in 10.1021/acs.jctc.2c00914) - - Parameters - ---------- - - r_data: r_data object (output of self._get_r_data) - - s: torch.Tensor (N_ATOMS,) - MBIS valence shell widths. - - Returns - ------- - - result: torch.Tensor (N_ATOMS + 1, N_ATOMS + 1) - """ - s_gauss = s * self._a_QEq - s2 = s_gauss**2 - s_mat = _torch.sqrt(s2[:, None] + s2[None, :]) - - device = r_data[0].device - dtype = r_data[0].dtype - - A = self._get_T0_gaussian(r_data[1], r_data[0], s_mat) - - new_diag = _torch.ones_like(A.diagonal(), dtype=dtype, device=device) * ( - 1.0 - / ( - s_gauss - * _torch.sqrt(_torch.tensor([_torch.pi], dtype=dtype, device=device)) - ) - ) - mask = _torch.diag(_torch.ones_like(new_diag, dtype=dtype, device=device)) - A = mask * _torch.diag(new_diag) + (1.0 - mask) * A - - # Store the dimensions of A. - x, y = A.shape - - # Create an tensor of ones with one more row and column than A. - B = _torch.ones(x + 1, y + 1, dtype=dtype, device=device) - - # Copy A into B. - B[:x, :y] = A - - # Set the final entry on the diagonal to zero. - B[-1, -1] = 0.0 - - return B - def _get_mu_ind( self, A, @@ -836,48 +542,6 @@ def _get_mu_ind( mu_ind = _torch.linalg.solve(A, fields) return mu_ind.reshape((-1, 3)) - def _get_A_thole(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, q_val, k): - """ - Internal method, generates A matrix for induced dipoles prediction - (Eq. 20 in 10.1021/acs.jctc.2c00914) - - Parameters - ---------- - - r_data: r_data object (output of self._get_r_data) - - s: torch.Tensor (N_ATOMS,) - MBIS valence shell widths. - - q_val: torch.Tensor (N_ATOMS,) - MBIS charges. - - k: torch.Tensor (N_Z) - Scaling factors for polarizabilities. - - Returns - ------- - - result: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) - The A matrix for induced dipoles prediction. 
- """ - v = -60 * q_val * s**3 - alpha = v * k - - alphap = alpha * self._a_Thole - alphap_mat = alphap[:, None] * alphap[None, :] - - au3 = r_data[0] ** 3 / _torch.sqrt(alphap_mat) - au31 = au3.repeat_interleave(3, dim=1) - au32 = au31.repeat_interleave(3, dim=0) - - A = -self._get_T2_thole(r_data[2], r_data[3], au32) - - new_diag = 1.0 / alpha.repeat_interleave(3) - mask = _torch.diag(_torch.ones_like(new_diag, dtype=A.dtype, device=A.device)) - A = mask * _torch.diag(new_diag) + (1.0 - mask) * A - - return A @staticmethod def _get_vpot_q(q, T0): @@ -924,49 +588,6 @@ def _get_vpot_mu(mu, T1): """ return -_torch.tensordot(T1, mu, ((0, 2), (0, 1))) - @classmethod - def _get_r_data(cls, xyz): - """ - Internal method to calculate r_data object. - - Parameters - ---------- - - xyz: torch.Tensor (N_ATOMS, 3) - Atomic positions. - - Returns - ------- - - result: r_data object - """ - n_atoms = len(xyz) - - rr_mat = xyz[:, None, :] - xyz[None, :, :] - r_mat = _torch.cdist(xyz, xyz) - r_inv = _torch.where(r_mat == 0.0, 0.0, 1.0 / r_mat) - - r_inv1 = r_inv.repeat_interleave(3, dim=1) - r_inv2 = r_inv1.repeat_interleave(3, dim=0) - - # Get a stacked matrix of outer products over the rr_mat tensors. - outer = _torch.einsum("bik,bij->bjik", rr_mat, rr_mat).reshape( - (n_atoms * 3, n_atoms * 3) - ) - - id2 = _torch.tile( - _torch.tile( - _torch.eye(3, dtype=xyz.dtype, device=xyz.device).T, (1, n_atoms) - ).T, - (1, n_atoms), - ) - - t01 = r_inv - t21 = -id2 * r_inv2**3 - t22 = 3 * outer * r_inv2**5 - - return (r_mat, t01, t21, t22) - @classmethod def _get_mesh_data(cls, xyz, xyz_mesh, s): """ @@ -1033,95 +654,3 @@ def _get_T0_slater(r, s): results: torch.Tensor (N_ATOMS, max_mm_atoms) """ return (1 - (1 + r / (s * 2)) * _torch.exp(-r / s)) / r - - @staticmethod - def _get_T0_gaussian(t01, r, s_mat): - """ - Internal method, calculates T0 tensor for Gaussian densities (for QEq). - - Parameters - ---------- - - t01: torch.Tensor (N_ATOMS, N_ATOMS) - T0 tensor for QM atoms. 
- - r: torch.Tensor (N_ATOMS, N_ATOMS) - Distance matrix for QM atoms. - - s_mat: torch.Tensor (N_ATOMS, N_ATOMS) - Matrix of Gaussian sigmas for QM atoms. - - Returns - ------- - - results: torch.Tensor (N_ATOMS, N_ATOMS) - """ - return t01 * _torch.erf( - r - / ( - s_mat - * _torch.sqrt(_torch.tensor([2.0], dtype=r.dtype, device=r.device)) - ) - ) - - @classmethod - def _get_T2_thole(cls, tr21, tr22, au3): - """ - Internal method, calculates T2 tensor with Thole damping. - - Parameters - ---------- - - tr21: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) - r_data[2] - - tr21: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) - r_data[3] - - au3: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) - Scaled distance matrix (see _get_A_thole). - - Returns - ------- - - result: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) - """ - return cls._lambda3(au3) * tr21 + cls._lambda5(au3) * tr22 - - @staticmethod - def _lambda3(au3): - """ - Internal method, calculates r^3 component of T2 tensor with Thole - damping. - - Parameters - ---------- - - au3: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) - Scaled distance matrix (see _get_A_thole). - - Returns - ------- - - result: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) - """ - return 1 - _torch.exp(-au3) - - @staticmethod - def _lambda5(au3): - """ - Internal method, calculates r^5 component of T2 tensor with Thole - damping. - - Parameters - ---------- - - au3: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) - Scaled distance matrix (see _get_A_thole). 
- - Returns - ------- - - result: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) - """ - return 1 - (1 + au3) * _torch.exp(-au3) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py new file mode 100644 index 0000000..2fd82ca --- /dev/null +++ b/emle/models/_emle_base.py @@ -0,0 +1,570 @@ +import numpy as _np + +import torch as _torch + +from torch import Tensor +from typing import Tuple + +ANGSTROM_TO_BOHR = 1.8897261258369282 + + +class EMLEBase(_torch.nn.Module): + + def __init__( + self, + params, + aev_computer, + # method="electrostatic", # Not used here, always electrostatic + species=None, + alpha_mode="species", + # atomic_numbers=None, # Not used here, since aev_computer is provided + device=None, + dtype=None, + ): + """ + Constructor. + + Parameters + ---------- + + params: dict + EMLE model parameters + + aev_computer: AEVComputer instance (torchani/NNPOps) + + method: str + The desired embedding method. Options are: + "electrostatic": + Full ML electrostatic embedding. + "mechanical": + ML predicted charges for the core, but zero valence charge. + "nonpol": + Non-polarisable ML embedding. Here the induced component of + the potential is zeroed. + "mm": + MM charges are used for the core charge and valence charges + are set to zero. If this option is specified then the user + should also specify the MM charges for atoms in the QM + region. + + species: List[int], Tuple[int], numpy.ndarray, torch.Tensor + List of species (atomic numbers) supported by the EMLE model. If + None, then the default species list will be used. + + alpha_mode: str + How atomic polarizabilities are calculated. + "species": + one volume scaling factor is used for each species + "reference": + scaling factors are obtained with GPR using the values learned + for each reference environment + + mm_charges: List[float], Tuple[Float], numpy.ndarray, torch.Tensor + List of MM charges for atoms in the QM region in units of mod + electron charge. 
This is required if the 'mm' method is specified. + + device: torch.device + The device on which to run the model. + + dtype: torch.dtype + The data type to use for the models floating point tensors. + """ + + # Call the base class constructor. + super().__init__() + + if alpha_mode is None: + alpha_mode = "species" + if not isinstance(alpha_mode, str): + raise TypeError("'alpha_mode' must be of type 'str'") + alpha_mode = alpha_mode.lower().replace(" ", "") + if alpha_mode not in ["species", "reference"]: + raise ValueError("'alpha_mode' must be 'species' or 'reference'") + self._alpha_mode = alpha_mode + + if device is not None: + if not isinstance(device, _torch.device): + raise TypeError("'device' must be of type 'torch.device'") + else: + device = _torch.get_default_device() + + if dtype is not None: + if not isinstance(dtype, _torch.dtype): + raise TypeError("'dtype' must be of type 'torch.dtype'") + else: + dtype = _torch.get_default_dtype() + + self._aev_computer = aev_computer + + # Create a map between species and their indices. + species_map = _np.full(max(species) + 1, fill_value=-1, dtype=_np.int64) + for i, s in enumerate(species): + species_map[s] = i + + # Convert to a tensor. + species_map = _torch.tensor(species_map, dtype=_torch.int64, device=device) + + # Store model parameters as tensors. + aev_mask = _torch.tensor(params["aev_mask"], dtype=_torch.bool, device=device) + q_core = _torch.tensor(params["q_core"], dtype=dtype, device=device) + a_QEq = _torch.tensor(params["a_QEq"], dtype=dtype, device=device) + a_Thole = _torch.tensor(params["a_Thole"], dtype=dtype, device=device) + if self._alpha_mode == "species": + try: + k = _torch.tensor(params["k_Z"], dtype=dtype, device=device) + except: + msg = ( + "Missing 'k_Z' key in params. This is required when " + "using 'species' alpha mode." 
+ ) + raise ValueError(msg) + else: + try: + k = _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device) + except: + msg = ( + "Missing 'sqrtk_ref' key in params. This is required when " + "using 'reference' alpha mode." + ) + raise ValueError(msg) + + q_total = _torch.tensor( + params.get("total_charge", 0), dtype=dtype, device=device + ) + + # Extract the reference features. + ref_features = _torch.tensor(params["ref_aev"], dtype=dtype, device=device) + + # Extract the reference values for the MBIS valence shell widths. + ref_values_s = _torch.tensor(params["s_ref"], dtype=dtype, device=device) + + # Compute the inverse of the K matrix. + Kinv = self._get_Kinv(ref_features, 1e-3) + + # Store additional attributes for the MBIS GPR model. + n_ref = _torch.tensor(params["n_ref"], dtype=_torch.int64, device=device) + ref_mean_s = _torch.sum(ref_values_s, dim=1) / n_ref + ref_shifted = ref_values_s - ref_mean_s[:, None] + c_s = (Kinv @ ref_shifted[:, :, None]).squeeze() + + # Extract the reference values for the electronegativities. + ref_values_chi = _torch.tensor(params["chi_ref"], dtype=dtype, device=device) + + # Store additional attributes for the electronegativity GPR model. + ref_mean_chi = _torch.sum(ref_values_chi, dim=1) / n_ref + ref_shifted = ref_values_chi - ref_mean_chi[:, None] + c_chi = (Kinv @ ref_shifted[:, :, None]).squeeze() + + # Extract the reference values for the polarizabilities. + if self._alpha_mode == "reference": + ref_mean_k = _torch.sum(k, dim=1) / n_ref + ref_shifted = k - ref_mean_k[:, None] + c_k = (Kinv @ ref_shifted[:, :, None]).squeeze() + else: + ref_mean_k = _torch.empty(0, dtype=dtype, device=device) + c_k = _torch.empty(0, dtype=dtype, device=device) + + # Store the current device. + self._device = device + + # Register constants as buffers. 
+ self.register_buffer("_species_map", species_map) + self.register_buffer("_aev_mask", aev_mask) + self.register_buffer("_q_core", q_core) + self.register_buffer("_a_QEq", a_QEq) + self.register_buffer("_a_Thole", a_Thole) + self.register_buffer("_k", k) + self.register_buffer("_q_total", q_total) + self.register_buffer("_ref_features", ref_features) + self.register_buffer("_n_ref", n_ref) + self.register_buffer("_ref_values_s", ref_values_s) + self.register_buffer("_ref_values_chi", ref_values_chi) + self.register_buffer("_ref_mean_s", ref_mean_s) + self.register_buffer("_ref_mean_chi", ref_mean_chi) + self.register_buffer("_c_s", c_s) + self.register_buffer("_c_chi", c_chi) + self.register_buffer("_ref_mean_k", ref_mean_k) + self.register_buffer("_c_k", c_k) + + # Initialise an empty AEV tensor to use to store the AEVs in derived classes. + self._aev = _torch.empty(0, dtype=dtype, device=device) + + def forward(self, atomic_numbers, xyz_qm): + """ + Computes the valence widths, core charges, valence charges and A_thole tensor. + + Parameters + ---------- + + atomic_numbers: torch.Tensor (N_QM_ATOMS,) + Atomic numbers of QM atoms. + + xyz_qm: torch.Tensor (N_QM_ATOMS, 3) + Positions of QM atoms in Angstrom. + + Returns + ------- + + result: (torch.Tensor (N_QM_ATOMS,), + torch.Tensor (N_QM_ATOMS,), + torch.Tensor (N_QM_ATOMS,), + torch.Tensor (N_QM_ATOMS * 3, N_QM_ATOMS * 3,)) + Valence widths, core charges, valence charges, A_thole tensor + """ + + # Convert the atomic numbers to species IDs. + species_id = self._species_map[atomic_numbers] + + # Reshape the IDs. + zid = species_id.unsqueeze(0) + + # Reshape the atomic positions. + xyz = xyz_qm.unsqueeze(0) + + # Compute the AEVs. + aev = self._aev_computer((zid, xyz))[1][0][:, self._aev_mask] + aev = aev / _torch.linalg.norm(aev, ord=2, dim=1, keepdim=True) + + # Compute the MBIS valence shell widths. + s = self._gpr(aev, self._ref_mean_s, self._c_s, species_id) + + # Compute the electronegativities. 
+ chi = self._gpr(aev, self._ref_mean_chi, self._c_chi, species_id) + + xyz_qm_bohr = xyz_qm * ANGSTROM_TO_BOHR + + r_data = self._get_r_data(xyz_qm_bohr) + + q_core = self._q_core[species_id] + q = self._get_q(r_data, s, chi) + q_val = q - q_core + + if self._alpha_mode == "species": + k = self._k[species_id] + else: + k = self._gpr(aev, self._ref_mean_k, self._c_k, species_id) ** 2 + + A_thole = self._get_A_thole(r_data, s, q_val, k) + + return s, q_core, q_val, A_thole + + @classmethod + def _get_Kinv(cls, ref_features, sigma): + """ + Internal function to compute the inverse of the K matrix for GPR. + + Parameters + ---------- + + ref_features: torch.Tensor (N_Z, MAX_N_REF, N_FEAT) + The basis feature vectors for each species. + + sigma: float + The uncertainty of the observations (regularizer). + + Returns + ------- + + result: torch.Tensor (MAX_N_REF, MAX_N_REF) + The inverse of the K matrix. + """ + n = ref_features.shape[1] + K = (ref_features @ ref_features.swapaxes(1, 2)) ** 2 + return _torch.linalg.inv( + K + sigma**2 * _torch.eye(n, dtype=ref_features.dtype, device=K.device) + ) + + def _gpr(self, mol_features, ref_mean, c, zid): + """ + Internal method to predict a property using Gaussian Process Regression. + + Parameters + ---------- + + mol_features: torch.Tensor (N_ATOMS, N_FEAT) + The feature vectors for each atom. + + ref_mean: torch.Tensor (N_Z,) + The mean of the reference values for each species. + + c: torch.Tensor (N_Z, MAX_N_REF) + The coefficients of the GPR model. + + zid: torch.Tensor (N_ATOMS,) + The species identity value of each atom. + + Returns + ------- + + result: torch.Tensor (N_ATOMS) + The values of the predicted property for each atom. 
+ """ + + result = _torch.zeros( + len(zid), dtype=mol_features.dtype, device=mol_features.device + ) + for i in range(len(self._n_ref)): + n_ref = self._n_ref[i] + ref_features_z = self._ref_features[i, :n_ref] + mol_features_z = mol_features[zid == i, :, None] + + K_mol_ref2 = (ref_features_z @ mol_features_z) ** 2 + K_mol_ref2 = K_mol_ref2.reshape(K_mol_ref2.shape[:-1]) + result[zid == i] = K_mol_ref2 @ c[i, :n_ref] + ref_mean[i] + + return result + + @classmethod + def _get_r_data(cls, xyz): + """ + Internal method to calculate r_data object. + + Parameters + ---------- + + xyz: torch.Tensor (N_ATOMS, 3) + Atomic positions. + + Returns + ------- + + result: r_data object + """ + n_atoms = len(xyz) + + rr_mat = xyz[:, None, :] - xyz[None, :, :] + r_mat = _torch.cdist(xyz, xyz) + r_inv = _torch.where(r_mat == 0.0, 0.0, 1.0 / r_mat) + + r_inv1 = r_inv.repeat_interleave(3, dim=1) + r_inv2 = r_inv1.repeat_interleave(3, dim=0) + + # Get a stacked matrix of outer products over the rr_mat tensors. + outer = _torch.einsum("bik,bij->bjik", rr_mat, rr_mat).reshape( + (n_atoms * 3, n_atoms * 3) + ) + + id2 = _torch.tile( + _torch.tile( + _torch.eye(3, dtype=xyz.dtype, device=xyz.device).T, (1, n_atoms) + ).T, + (1, n_atoms), + ) + + t01 = r_inv + t21 = -id2 * r_inv2 ** 3 + t22 = 3 * outer * r_inv2 ** 5 + + return (r_mat, t01, t21, t22) + + def _get_q(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, chi): + """ + Internal method that predicts MBIS charges + (Eq. 16 in 10.1021/acs.jctc.2c00914) + + Parameters + ---------- + + r_data: r_data object (output of self._get_r_data) + + s: torch.Tensor (N_ATOMS,) + MBIS valence shell widths. + + chi: torch.Tensor (N_ATOMS,) + Electronegativities. + + Returns + ------- + + result: torch.Tensor (N_ATOMS,) + Predicted MBIS charges. 
+ """ + A = self._get_A_QEq(r_data, s) + b = _torch.hstack([-chi, self._q_total]) + return _torch.linalg.solve(A, b)[:-1] + + def _get_A_QEq(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s): + """ + Internal method, generates A matrix for charge prediction + (Eq. 16 in 10.1021/acs.jctc.2c00914) + + Parameters + ---------- + + r_data: r_data object (output of self._get_r_data) + + s: torch.Tensor (N_ATOMS,) + MBIS valence shell widths. + + Returns + ------- + + result: torch.Tensor (N_ATOMS + 1, N_ATOMS + 1) + """ + s_gauss = s * self._a_QEq + s2 = s_gauss**2 + s_mat = _torch.sqrt(s2[:, None] + s2[None, :]) + + device = r_data[0].device + dtype = r_data[0].dtype + + A = self._get_T0_gaussian(r_data[1], r_data[0], s_mat) + + new_diag = _torch.ones_like(A.diagonal(), dtype=dtype, device=device) * ( + 1.0 + / ( + s_gauss + * _torch.sqrt(_torch.tensor([_torch.pi], dtype=dtype, device=device)) + ) + ) + mask = _torch.diag(_torch.ones_like(new_diag, dtype=dtype, device=device)) + A = mask * _torch.diag(new_diag) + (1.0 - mask) * A + + # Store the dimensions of A. + x, y = A.shape + + # Create an tensor of ones with one more row and column than A. + B = _torch.ones(x + 1, y + 1, dtype=dtype, device=device) + + # Copy A into B. + B[:x, :y] = A + + # Set the final entry on the diagonal to zero. + B[-1, -1] = 0.0 + + return B + + @staticmethod + def _get_T0_gaussian(t01, r, s_mat): + """ + Internal method, calculates T0 tensor for Gaussian densities (for QEq). + + Parameters + ---------- + + t01: torch.Tensor (N_ATOMS, N_ATOMS) + T0 tensor for QM atoms. + + r: torch.Tensor (N_ATOMS, N_ATOMS) + Distance matrix for QM atoms. + + s_mat: torch.Tensor (N_ATOMS, N_ATOMS) + Matrix of Gaussian sigmas for QM atoms. 
+ + Returns + ------- + + results: torch.Tensor (N_ATOMS, N_ATOMS) + """ + return t01 * _torch.erf( + r + / ( + s_mat + * _torch.sqrt(_torch.tensor([2.0], dtype=r.dtype, device=r.device)) + ) + ) + + def _get_A_thole(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, q_val, k): + """ + Internal method, generates A matrix for induced dipoles prediction + (Eq. 20 in 10.1021/acs.jctc.2c00914) + + Parameters + ---------- + + r_data: r_data object (output of self._get_r_data) + + s: torch.Tensor (N_ATOMS,) + MBIS valence shell widths. + + q_val: torch.Tensor (N_ATOMS,) + MBIS charges. + + k: torch.Tensor (N_Z) + Scaling factors for polarizabilities. + + Returns + ------- + + result: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) + The A matrix for induced dipoles prediction. + """ + v = -60 * q_val * s**3 + alpha = v * k + + alphap = alpha * self._a_Thole + alphap_mat = alphap[:, None] * alphap[None, :] + + au3 = r_data[0] ** 3 / _torch.sqrt(alphap_mat) + au31 = au3.repeat_interleave(3, dim=1) + au32 = au31.repeat_interleave(3, dim=0) + + A = -self._get_T2_thole(r_data[2], r_data[3], au32) + + new_diag = 1.0 / alpha.repeat_interleave(3) + mask = _torch.diag(_torch.ones_like(new_diag, dtype=A.dtype, device=A.device)) + A = mask * _torch.diag(new_diag) + (1.0 - mask) * A + + return A + + @classmethod + def _get_T2_thole(cls, tr21, tr22, au3): + """ + Internal method, calculates T2 tensor with Thole damping. + + Parameters + ---------- + + tr21: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) + r_data[2] + + tr21: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) + r_data[3] + + au3: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) + Scaled distance matrix (see _get_A_thole). + + Returns + ------- + + result: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) + """ + return cls._lambda3(au3) * tr21 + cls._lambda5(au3) * tr22 + + @staticmethod + def _lambda3(au3): + """ + Internal method, calculates r^3 component of T2 tensor with Thole + damping. 
+ + Parameters + ---------- + + au3: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) + Scaled distance matrix (see _get_A_thole). + + Returns + ------- + + result: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) + """ + return 1 - _torch.exp(-au3) + + @staticmethod + def _lambda5(au3): + """ + Internal method, calculates r^5 component of T2 tensor with Thole + damping. + + Parameters + ---------- + + au3: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) + Scaled distance matrix (see _get_A_thole). + + Returns + ------- + + result: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) + """ + return 1 - (1 + au3) * _torch.exp(-au3) From 870bf9e0e38aa6bcec210edd00b2cd317a1a7895 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 15:33:22 +0200 Subject: [PATCH 04/39] Cleanup to/cuda/cpu/double/float methods in EMLE and EMLEBase --- emle/models/_emle.py | 5 ++ emle/models/_emle_base.py | 103 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index b808c1c..9eea015 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -341,6 +341,7 @@ def to(self, *args, **kwargs): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.to(*args, **kwargs) + self._q_core_mm = self._q_core_mm.to(*args, **kwargs) self._emle_base = self._emle_base.to(*args, **kwargs) # Check for a device type in args and update the device attribute. @@ -357,6 +358,7 @@ def cuda(self, **kwargs): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.cuda(**kwargs) + self._q_core_mm = self._q_core_mm.cuda(**kwargs) self._emle_base = self._emle_base.cuda(**kwargs) # Update the device attribute. @@ -370,6 +372,7 @@ def cpu(self, **kwargs): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.cpu(**kwargs) + self._q_core_mm = self._q_core_mm.cpu(**kwargs) self._emle_base = self._emle_base.cpu() # Update the device attribute. 
@@ -383,6 +386,7 @@ def double(self): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.double() + self._q_core_mm = self._q_core_mm.double() self._emle_base = self._emle_base.double() return self @@ -392,6 +396,7 @@ def float(self): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.float() + self._q_core_mm = self._q_core_mm.float() self._emle_base = self._emle_base.float() return self diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 2fd82ca..86f9cf4 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -189,6 +189,109 @@ def __init__( # Initalise an empty AEV tensor to use to store the AEVs in derived classes. self._aev = _torch.empty(0, dtype=dtype, device=device) + def to(self, *args, **kwargs): + self._species_map = self._species_map.to(*args, **kwargs) + self._aev_mask = self._aev_mask.to(*args, **kwargs) + self._q_core = self._q_core.to(*args, **kwargs) + self._a_QEq = self._a_QEq.to(*args, **kwargs) + self._a_Thole = self._a_Thole.to(*args, **kwargs) + self._k = self._k.to(*args, **kwargs) + self._q_total = self._q_total.to(*args, **kwargs) + self._ref_features = self._ref_features.to(*args, **kwargs) + self._n_ref = self._n_ref.to(*args, **kwargs) + self._ref_values_s = self._ref_values_s.to(*args, **kwargs) + self._ref_values_chi = self._ref_values_chi.to(*args, **kwargs) + self._ref_mean_s = self._ref_mean_s.to(*args, **kwargs) + self._ref_mean_chi = self._ref_mean_chi.to(*args, **kwargs) + self._c_s = self._c_s.to(*args, **kwargs) + self._c_chi = self._c_chi.to(*args, **kwargs) + self._ref_mean_k = self._ref_mean_k.to(*args, **kwargs) + self._c_k = self._c_k.to(*args, **kwargs) + + def cuda(self, **kwargs): + """ + Move all model parameters and buffers to CUDA memory. 
+ """ + self._species_map = self._species_map.cuda(**kwargs) + self._aev_mask = self._aev_mask.cuda(**kwargs) + self._q_core = self._q_core.cuda(**kwargs) + self._a_QEq = self._a_QEq.cuda(**kwargs) + self._a_Thole = self._a_Thole.cuda(**kwargs) + self._k = self._k.cuda(**kwargs) + self._q_total = self._q_total.cuda(**kwargs) + self._ref_features = self._ref_features.cuda(**kwargs) + self._n_ref = self._n_ref.cuda(**kwargs) + self._ref_values_s = self._ref_values_s.cuda(**kwargs) + self._ref_values_chi = self._ref_values_chi.cuda(**kwargs) + self._ref_mean_s = self._ref_mean_s.cuda(**kwargs) + self._ref_mean_chi = self._ref_mean_chi.cuda(**kwargs) + self._c_s = self._c_s.cuda(**kwargs) + self._c_chi = self._c_chi.cuda(**kwargs) + self._ref_mean_k = self._ref_mean_k.cuda(**kwargs) + self._c_k = self._c_k.cuda(**kwargs) + + def cpu(self, **kwargs): + """ + Move all model parameters and buffers to CPU memory. + """ + self._species_map = self._species_map.cpu(**kwargs) + self._aev_mask = self._aev_mask.cpu(**kwargs) + self._q_core = self._q_core.cpu(**kwargs) + self._a_QEq = self._a_QEq.cpu(**kwargs) + self._a_Thole = self._a_Thole.cpu(**kwargs) + self._k = self._k.cpu(**kwargs) + self._q_total = self._q_total.cpu(**kwargs) + self._ref_features = self._ref_features.cpu(**kwargs) + self._n_ref = self._n_ref.cpu(**kwargs) + self._ref_values_s = self._ref_values_s.cpu(**kwargs) + self._ref_values_chi = self._ref_values_chi.cpu(**kwargs) + self._ref_mean_s = self._ref_mean_s.cpu(**kwargs) + self._ref_mean_chi = self._ref_mean_chi.cpu(**kwargs) + self._c_s = self._c_s.cpu(**kwargs) + self._c_chi = self._c_chi.cpu(**kwargs) + self._ref_mean_k = self._ref_mean_k.cpu(**kwargs) + self._c_k = self._c_k.cpu(**kwargs) + + def double(self): + """ + Casts all floating point model parameters and buffers to float64 precision. 
+ """ + self._q_core = self._q_core.double() + self._a_QEq = self._a_QEq.double() + self._a_Thole = self._a_Thole.double() + self._k = self._k.double() + self._q_total = self._q_total.double() + self._ref_features = self._ref_features.double() + self._ref_values_s = self._ref_values_s.double() + self._ref_values_chi = self._ref_values_chi.double() + self._ref_mean_s = self._ref_mean_s.double() + self._ref_mean_chi = self._ref_mean_chi.double() + self._c_s = self._c_s.double() + self._c_chi = self._c_chi.double() + self._ref_mean_k = self._ref_mean_k.double() + self._c_k = self._c_k.double() + return self + + def float(self): + """ + Casts all floating point model parameters and buffers to float32 precision. + """ + self._q_core = self._q_core.float() + self._a_QEq = self._a_QEq.float() + self._a_Thole = self._a_Thole.float() + self._k = self._k.float() + self._q_total = self._q_total.float() + self._ref_features = self._ref_features.float() + self._ref_values_s = self._ref_values_s.float() + self._ref_values_chi = self._ref_values_chi.float() + self._ref_mean_s = self._ref_mean_s.float() + self._ref_mean_chi = self._ref_mean_chi.float() + self._c_s = self._c_s.float() + self._c_chi = self._c_chi.float() + self._ref_mean_k = self._ref_mean_k.float() + self._c_k = self._c_k.float() + return self + def forward(self, atomic_numbers, xyz_qm): """ Computes the static and induced EMLE energy components. 
From 293cf51f419834645d64d39e1df41eef797dfbdc Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 15:50:05 +0200 Subject: [PATCH 05/39] Pass q_total as argument to EMLEBase.forward --- emle/models/_emle.py | 12 +++++++++++- emle/models/_emle_base.py | 24 ++++++++++-------------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 9eea015..e2e395f 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -310,6 +310,10 @@ def __init__( self._emle_base = EMLEBase(params, self._aev_computer, species, alpha_mode, device, dtype) + q_total = _torch.tensor( + params.get("total_charge", 0), dtype=dtype, device=device + ) + if method == "mm": q_core_mm = _torch.tensor(mm_charges, dtype=dtype, device=device) else: @@ -319,6 +323,7 @@ def __init__( self._device = device # Register constants as buffers. + self.register_buffer("_q_total", q_total) self.register_buffer("_q_core_mm", q_core_mm) # Initalise an empty AEV tensor to use to store the AEVs in derived classes. 
@@ -341,6 +346,7 @@ def to(self, *args, **kwargs): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.to(*args, **kwargs) + self._q_total = self._q_total.to(*args, **kwargs) self._q_core_mm = self._q_core_mm.to(*args, **kwargs) self._emle_base = self._emle_base.to(*args, **kwargs) @@ -358,6 +364,7 @@ def cuda(self, **kwargs): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.cuda(**kwargs) + self._q_total = self._q_total.cuda(**kwargs) self._q_core_mm = self._q_core_mm.cuda(**kwargs) self._emle_base = self._emle_base.cuda(**kwargs) @@ -372,6 +379,7 @@ def cpu(self, **kwargs): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.cpu(**kwargs) + self._q_total = self._q_total.cpu(**kwargs) self._q_core_mm = self._q_core_mm.cpu(**kwargs) self._emle_base = self._emle_base.cpu() @@ -386,6 +394,7 @@ def double(self): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.double() + self._q_total = self._q_total.double() self._q_core_mm = self._q_core_mm.double() self._emle_base = self._emle_base.double() return self @@ -396,6 +405,7 @@ def float(self): """ if self._aev_computer is not None: self._aev_computer = self._aev_computer.float() + self._q_total = self._q_total.float() self._q_core_mm = self._q_core_mm.float() self._emle_base = self._emle_base.float() return self @@ -430,7 +440,7 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): if len(xyz_mm) == 0: return _torch.zeros(2, dtype=xyz_qm.dtype, device=xyz_qm.device) - s, q_core, q_val, A_thole = self._emle_base(atomic_numbers, xyz_qm) + s, q_core, q_val, A_thole = self._emle_base(atomic_numbers, xyz_qm, self._q_total) # Convert coordinates to Bohr. 
ANGSTROM_TO_BOHR = 1.8897261258369282 diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 86f9cf4..c52226f 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -128,10 +128,6 @@ def __init__( ) raise ValueError(msg) - q_total = _torch.tensor( - params.get("total_charge", 0), dtype=dtype, device=device - ) - # Extract the reference features. ref_features = _torch.tensor(params["ref_aev"], dtype=dtype, device=device) @@ -174,7 +170,6 @@ def __init__( self.register_buffer("_a_QEq", a_QEq) self.register_buffer("_a_Thole", a_Thole) self.register_buffer("_k", k) - self.register_buffer("_q_total", q_total) self.register_buffer("_ref_features", ref_features) self.register_buffer("_n_ref", n_ref) self.register_buffer("_ref_values_s", ref_values_s) @@ -196,7 +191,6 @@ def to(self, *args, **kwargs): self._a_QEq = self._a_QEq.to(*args, **kwargs) self._a_Thole = self._a_Thole.to(*args, **kwargs) self._k = self._k.to(*args, **kwargs) - self._q_total = self._q_total.to(*args, **kwargs) self._ref_features = self._ref_features.to(*args, **kwargs) self._n_ref = self._n_ref.to(*args, **kwargs) self._ref_values_s = self._ref_values_s.to(*args, **kwargs) @@ -218,7 +212,6 @@ def cuda(self, **kwargs): self._a_QEq = self._a_QEq.cuda(**kwargs) self._a_Thole = self._a_Thole.cuda(**kwargs) self._k = self._k.cuda(**kwargs) - self._q_total = self._q_total.cuda(**kwargs) self._ref_features = self._ref_features.cuda(**kwargs) self._n_ref = self._n_ref.cuda(**kwargs) self._ref_values_s = self._ref_values_s.cuda(**kwargs) @@ -240,7 +233,6 @@ def cpu(self, **kwargs): self._a_QEq = self._a_QEq.cpu(**kwargs) self._a_Thole = self._a_Thole.cpu(**kwargs) self._k = self._k.cpu(**kwargs) - self._q_total = self._q_total.cpu(**kwargs) self._ref_features = self._ref_features.cpu(**kwargs) self._n_ref = self._n_ref.cpu(**kwargs) self._ref_values_s = self._ref_values_s.cpu(**kwargs) @@ -260,7 +252,6 @@ def double(self): self._a_QEq = self._a_QEq.double() self._a_Thole 
= self._a_Thole.double() self._k = self._k.double() - self._q_total = self._q_total.double() self._ref_features = self._ref_features.double() self._ref_values_s = self._ref_values_s.double() self._ref_values_chi = self._ref_values_chi.double() @@ -280,7 +271,6 @@ def float(self): self._a_QEq = self._a_QEq.float() self._a_Thole = self._a_Thole.float() self._k = self._k.float() - self._q_total = self._q_total.float() self._ref_features = self._ref_features.float() self._ref_values_s = self._ref_values_s.float() self._ref_values_chi = self._ref_values_chi.float() @@ -292,7 +282,7 @@ def float(self): self._c_k = self._c_k.float() return self - def forward(self, atomic_numbers, xyz_qm): + def forward(self, atomic_numbers, xyz_qm, q_total): """ Computes the static and induced EMLE energy components. @@ -305,6 +295,9 @@ def forward(self, atomic_numbers, xyz_qm): xyz_qm: torch.Tensor (N_QM_ATOMS, 3) Positions of QM atoms in Angstrom. + q_total: torch.Tensor (1,) + Total charge + Returns ------- @@ -339,7 +332,7 @@ def forward(self, atomic_numbers, xyz_qm): r_data = self._get_r_data(xyz_qm_bohr) q_core = self._q_core[species_id] - q = self._get_q(r_data, s, chi) + q = self._get_q(r_data, s, chi, q_total) q_val = q - q_core if self._alpha_mode == "species": @@ -460,7 +453,7 @@ def _get_r_data(cls, xyz): return (r_mat, t01, t21, t22) - def _get_q(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, chi): + def _get_q(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, chi, q_total): """ Internal method that predicts MBIS charges (Eq. 16 in 10.1021/acs.jctc.2c00914) @@ -476,6 +469,9 @@ def _get_q(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, chi): chi: torch.Tensor (N_ATOMS,) Electronegativities. + q_total: torch.Tensor (1,) + Total charge + Returns ------- @@ -483,7 +479,7 @@ def _get_q(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, chi): Predicted MBIS charges. 
""" A = self._get_A_QEq(r_data, s) - b = _torch.hstack([-chi, self._q_total]) + b = _torch.hstack([-chi, q_total]) return _torch.linalg.solve(A, b)[:-1] def _get_A_QEq(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s): From cb1945817fc467f2becec8095b2158eb470e58ad Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 18:48:27 +0200 Subject: [PATCH 06/39] Remove _gpr method from EMLE class --- emle/models/_emle.py | 40 ---------------------------------------- 1 file changed, 40 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index e2e395f..ecb3acc 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -475,46 +475,6 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): return _torch.stack([E_static, E_ind]) - def _gpr(self, mol_features, ref_mean, c, zid): - """ - Internal method to predict a property using Gaussian Process Regression. - - Parameters - ---------- - - mol_features: torch.Tensor (N_ATOMS, N_FEAT) - The feature vectors for each atom. - - ref_mean: torch.Tensor (N_Z,) - The mean of the reference values for each species. - - c: torch.Tensor (N_Z, MAX_N_REF) - The coefficients of the GPR model. - - zid: torch.Tensor (N_ATOMS,) - The species identity value of each atom. - - Returns - ------- - - result: torch.Tensor (N_ATOMS) - The values of the predicted property for each atom. 
- """ - - result = _torch.zeros( - len(zid), dtype=mol_features.dtype, device=mol_features.device - ) - for i in range(len(self._n_ref)): - n_ref = self._n_ref[i] - ref_features_z = self._ref_features[i, :n_ref] - mol_features_z = mol_features[zid == i, :, None] - - K_mol_ref2 = (ref_features_z @ mol_features_z) ** 2 - K_mol_ref2 = K_mol_ref2.reshape(K_mol_ref2.shape[:-1]) - result[zid == i] = K_mol_ref2 @ c[i, :n_ref] + ref_mean[i] - - return result - def _get_mu_ind( self, A, From a1fc1ff47b101e0ba8db27377422de7560fdc53f Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 19:05:59 +0200 Subject: [PATCH 07/39] Refactor EMLEBase to work with batches --- emle/models/_emle.py | 5 +- emle/models/_emle_base.py | 134 ++++++++++++++++++-------------------- 2 files changed, 68 insertions(+), 71 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index ecb3acc..e8ca8eb 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -440,7 +440,10 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): if len(xyz_mm) == 0: return _torch.zeros(2, dtype=xyz_qm.dtype, device=xyz_qm.device) - s, q_core, q_val, A_thole = self._emle_base(atomic_numbers, xyz_qm, self._q_total) + s, q_core, q_val, A_thole = self._emle_base(atomic_numbers[None, :], + xyz_qm[None, :, :], + self._q_total[None]) + s, q_core, q_val, A_thole = s[0], q_core[0], q_val[0], A_thole[0] # Convert coordinates to Bohr. ANGSTROM_TO_BOHR = 1.8897261258369282 diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index c52226f..4c1f687 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -289,10 +289,10 @@ def forward(self, atomic_numbers, xyz_qm, q_total): Parameters ---------- - atomic_numbers: torch.Tensor (N_QM_ATOMS,) + atomic_numbers: torch.Tensor (N_BATCH, N_QM_ATOMS,) Atomic numbers of QM atoms. - xyz_qm: torch.Tensor (N_QM_ATOMS, 3) + xyz_qm: torch.Tensor (N_BATCH, N_QM_ATOMS, 3) Positions of QM atoms in Angstrom. 
q_total: torch.Tensor (1,) @@ -301,25 +301,21 @@ def forward(self, atomic_numbers, xyz_qm, q_total): Returns ------- - result: (torch.Tensor (N_QM_ATOMS,), - torch.Tensor (N_QM_ATOMS,), - torch.Tensor (N_QM_ATOMS,), - torch.Tensor (N_QM_ATOMS * 3, N_QM_ATOMS * 3,)) + result: (torch.Tensor (N_BATCH, N_QM_ATOMS,), + torch.Tensor (N_BATCH, N_QM_ATOMS,), + torch.Tensor (N_BATCH, N_QM_ATOMS,), + torch.Tensor (N_BATCH, N_QM_ATOMS * 3, N_QM_ATOMS * 3,)) Valence widths, core charges, valence charges, A_thole tensor """ + mask = atomic_numbers > 0 + # Convert the atomic numbers to species IDs. species_id = self._species_map[atomic_numbers] - # Reshape the IDs. - zid = species_id.unsqueeze(0) - - # Reshape the atomic positions. - xyz = xyz_qm.unsqueeze(0) - # Compute the AEVs. - aev = self._aev_computer((zid, xyz))[1][0][:, self._aev_mask] - aev = aev / _torch.linalg.norm(aev, ord=2, dim=1, keepdim=True) + aev = self._aev_computer((species_id, xyz_qm))[1][:, :, self._aev_mask] + aev = aev / _torch.linalg.norm(aev, ord=2, dim=2, keepdim=True) # Compute the MBIS valence shell widths. s = self._gpr(aev, self._ref_mean_s, self._c_s, species_id) @@ -331,7 +327,7 @@ def forward(self, atomic_numbers, xyz_qm, q_total): r_data = self._get_r_data(xyz_qm_bohr) - q_core = self._q_core[species_id] + q_core = self._q_core[species_id] * mask q = self._get_q(r_data, s, chi, q_total) q_val = q - q_core @@ -377,7 +373,7 @@ def _gpr(self, mol_features, ref_mean, c, zid): Parameters ---------- - mol_features: torch.Tensor (N_ATOMS, N_FEAT) + mol_features: torch.Tensor (N_BATCH, N_ATOMS, N_FEAT) The feature vectors for each atom. ref_mean: torch.Tensor (N_Z,) @@ -386,26 +382,26 @@ def _gpr(self, mol_features, ref_mean, c, zid): c: torch.Tensor (N_Z, MAX_N_REF) The coefficients of the GPR model. - zid: torch.Tensor (N_ATOMS,) + zid: torch.Tensor (N_BATCH, N_ATOMS,) The species identity value of each atom. 
Returns ------- - result: torch.Tensor (N_ATOMS) + result: torch.Tensor (N_BATCH, N_ATOMS) The values of the predicted property for each atom. """ result = _torch.zeros( - len(zid), dtype=mol_features.dtype, device=mol_features.device + zid.shape, dtype=mol_features.dtype, device=mol_features.device ) for i in range(len(self._n_ref)): n_ref = self._n_ref[i] ref_features_z = self._ref_features[i, :n_ref] - mol_features_z = mol_features[zid == i, :, None] + mol_features_z = mol_features[zid == i] - K_mol_ref2 = (ref_features_z @ mol_features_z) ** 2 - K_mol_ref2 = K_mol_ref2.reshape(K_mol_ref2.shape[:-1]) + K_mol_ref2 = (mol_features_z @ ref_features_z.T) ** 2 + # K_mol_ref2 = K_mol_ref2.reshape(K_mol_ref2.shape[:-1]) result[zid == i] = K_mol_ref2 @ c[i, :n_ref] + ref_mean[i] return result @@ -418,7 +414,7 @@ def _get_r_data(cls, xyz): Parameters ---------- - xyz: torch.Tensor (N_ATOMS, 3) + xyz: torch.Tensor (N_BATCH, N_ATOMS, 3) Atomic positions. Returns @@ -426,32 +422,30 @@ def _get_r_data(cls, xyz): result: r_data object """ - n_atoms = len(xyz) + n_batch, n_atoms_max = xyz.shape[:2] - rr_mat = xyz[:, None, :] - xyz[None, :, :] + rr_mat = xyz[:, :, None, :] - xyz[:, None, :, :] r_mat = _torch.cdist(xyz, xyz) r_inv = _torch.where(r_mat == 0.0, 0.0, 1.0 / r_mat) - r_inv1 = r_inv.repeat_interleave(3, dim=1) - r_inv2 = r_inv1.repeat_interleave(3, dim=0) + r_inv1 = r_inv.repeat_interleave(3, dim=2) + r_inv2 = r_inv1.repeat_interleave(3, dim=1) # Get a stacked matrix of outer products over the rr_mat tensors. 
- outer = _torch.einsum("bik,bij->bjik", rr_mat, rr_mat).reshape( - (n_atoms * 3, n_atoms * 3) + outer = _torch.einsum("bnik,bnij->bnjik", rr_mat, rr_mat).reshape( + (n_batch, n_atoms_max * 3, n_atoms_max * 3) ) id2 = _torch.tile( - _torch.tile( - _torch.eye(3, dtype=xyz.dtype, device=xyz.device).T, (1, n_atoms) - ).T, - (1, n_atoms), + _torch.eye(3, dtype=xyz.dtype, device=xyz.device).T, + (1, n_atoms_max, n_atoms_max) ) t01 = r_inv t21 = -id2 * r_inv2 ** 3 t22 = 3 * outer * r_inv2 ** 5 - return (r_mat, t01, t21, t22) + return r_mat, t01, t21, t22 def _get_q(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, chi, q_total): """ @@ -463,24 +457,24 @@ def _get_q(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, chi, q_total) r_data: r_data object (output of self._get_r_data) - s: torch.Tensor (N_ATOMS,) + s: torch.Tensor (N_BATCH, N_ATOMS,) MBIS valence shell widths. - chi: torch.Tensor (N_ATOMS,) + chi: torch.Tensor (N_BATCH, N_ATOMS,) Electronegativities. - q_total: torch.Tensor (1,) + q_total: torch.Tensor (N_BATCH,) Total charge Returns ------- - result: torch.Tensor (N_ATOMS,) + result: torch.Tensor (N_BATCH, N_ATOMS,) Predicted MBIS charges. """ A = self._get_A_QEq(r_data, s) - b = _torch.hstack([-chi, q_total]) - return _torch.linalg.solve(A, b)[:-1] + b = _torch.hstack([-chi, q_total[:, None]]) + return _torch.linalg.solve(A, b)[:, :-1] def _get_A_QEq(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s): """ @@ -492,44 +486,42 @@ def _get_A_QEq(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s): r_data: r_data object (output of self._get_r_data) - s: torch.Tensor (N_ATOMS,) + s: torch.Tensor (N_BATCH, N_ATOMS,) MBIS valence shell widths. 
Returns ------- - result: torch.Tensor (N_ATOMS + 1, N_ATOMS + 1) + result: torch.Tensor (N_BATCH, N_ATOMS + 1, N_ATOMS + 1) """ s_gauss = s * self._a_QEq s2 = s_gauss**2 - s_mat = _torch.sqrt(s2[:, None] + s2[None, :]) + s_mat = _torch.sqrt(s2[:, :, None] + s2[:, None, :]) device = r_data[0].device dtype = r_data[0].dtype A = self._get_T0_gaussian(r_data[1], r_data[0], s_mat) - new_diag = _torch.ones_like(A.diagonal(), dtype=dtype, device=device) * ( - 1.0 - / ( - s_gauss - * _torch.sqrt(_torch.tensor([_torch.pi], dtype=dtype, device=device)) - ) - ) - mask = _torch.diag(_torch.ones_like(new_diag, dtype=dtype, device=device)) - A = mask * _torch.diag(new_diag) + (1.0 - mask) * A + diag_ones = _torch.ones_like(A.diagonal(dim1=-2, dim2=-1), + dtype=dtype, device=device) + pi = _torch.sqrt(_torch.tensor([_torch.pi], dtype=dtype, device=device)) + new_diag = diag_ones * _torch.where(s2 > 0, 1.0 / (s_gauss * pi), 0) + + mask = _torch.diag_embed(diag_ones) + A = mask * _torch.diag_embed(new_diag) + (1.0 - mask) * A # Store the dimensions of A. - x, y = A.shape + x, y = A.shape[1:] # Create an tensor of ones with one more row and column than A. - B = _torch.ones(x + 1, y + 1, dtype=dtype, device=device) + B = _torch.ones(len(A), x + 1, y + 1, dtype=dtype, device=device) # Copy A into B. - B[:x, :y] = A + B[:, :x, :y] = A # Set the final entry on the diagonal to zero. - B[-1, -1] = 0.0 + B[:, -1, -1] = 0.0 return B @@ -541,19 +533,19 @@ def _get_T0_gaussian(t01, r, s_mat): Parameters ---------- - t01: torch.Tensor (N_ATOMS, N_ATOMS) + t01: torch.Tensor (N_BATCH, N_ATOMS, N_ATOMS) T0 tensor for QM atoms. - r: torch.Tensor (N_ATOMS, N_ATOMS) + r: torch.Tensor (N_BATCH, N_ATOMS, N_ATOMS) Distance matrix for QM atoms. - s_mat: torch.Tensor (N_ATOMS, N_ATOMS) + s_mat: torch.Tensor (N_BATCH, N_ATOMS, N_ATOMS) Matrix of Gaussian sigmas for QM atoms. 
Returns ------- - results: torch.Tensor (N_ATOMS, N_ATOMS) + results: torch.Tensor (N_BATCH, N_ATOMS, N_ATOMS) """ return t01 * _torch.erf( r @@ -573,36 +565,38 @@ def _get_A_thole(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, q_val, r_data: r_data object (output of self._get_r_data) - s: torch.Tensor (N_ATOMS,) + s: torch.Tensor (N_BATCH, N_ATOMS,) MBIS valence shell widths. - q_val: torch.Tensor (N_ATOMS,) + q_val: torch.Tensor (N_BATCH, N_ATOMS,) MBIS charges. - k: torch.Tensor (N_Z) + k: torch.Tensor (N_BATCH, N_ATOMS,) Scaling factors for polarizabilities. Returns ------- - result: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3) + result: torch.Tensor (N_BATCH, N_ATOMS * 3, N_ATOMS * 3) The A matrix for induced dipoles prediction. """ v = -60 * q_val * s**3 alpha = v * k alphap = alpha * self._a_Thole - alphap_mat = alphap[:, None] * alphap[None, :] + alphap_mat = alphap[:, :, None] * alphap[:, None, :] au3 = r_data[0] ** 3 / _torch.sqrt(alphap_mat) - au31 = au3.repeat_interleave(3, dim=1) - au32 = au31.repeat_interleave(3, dim=0) + au31 = au3.repeat_interleave(3, dim=2) + au32 = au31.repeat_interleave(3, dim=1) A = -self._get_T2_thole(r_data[2], r_data[3], au32) - new_diag = 1.0 / alpha.repeat_interleave(3) - mask = _torch.diag(_torch.ones_like(new_diag, dtype=A.dtype, device=A.device)) - A = mask * _torch.diag(new_diag) + (1.0 - mask) * A + alpha3 = alpha.repeat_interleave(3, dim=1) + new_diag = _torch.where(alpha3 > 0, 1.0 / alpha3, 1.) 
+ diag_ones = _torch.ones_like(new_diag, dtype=A.dtype, device=A.device) + mask = _torch.diag_embed(diag_ones) + A = mask * _torch.diag_embed(new_diag) + (1.0 - mask) * A return A From fc55b1e81ba4b9cedfcbd51d160474e6429084bd Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 19:45:26 +0200 Subject: [PATCH 08/39] Refactor calculation of GPR coefficients (EMLEBase._get_c method) --- emle/models/_emle_base.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 4c1f687..60bcf2c 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -131,31 +131,23 @@ def __init__( # Extract the reference features. ref_features = _torch.tensor(params["ref_aev"], dtype=dtype, device=device) - # Extract the reference values for the MBIS valence shell widths. - ref_values_s = _torch.tensor(params["s_ref"], dtype=dtype, device=device) - # Compute the inverse of the K matrix. Kinv = self._get_Kinv(ref_features, 1e-3) - # Store additional attributes for the MBIS GPR model. + # Extract number of references per element n_ref = _torch.tensor(params["n_ref"], dtype=_torch.int64, device=device) - ref_mean_s = _torch.sum(ref_values_s, dim=1) / n_ref - ref_shifted = ref_values_s - ref_mean_s[:, None] - c_s = (Kinv @ ref_shifted[:, :, None]).squeeze() - # Extract the reference values for the electronegativities. - ref_values_chi = _torch.tensor(params["chi_ref"], dtype=dtype, device=device) + # Extract the reference values and GPR coefficients for the valence shell widths. + ref_values_s = _torch.tensor(params["s_ref"], dtype=dtype, device=device) + ref_mean_s, c_s = self._get_c(n_ref, ref_values_s, Kinv) - # Store additional attributes for the electronegativity GPR model. 
- ref_mean_chi = _torch.sum(ref_values_chi, dim=1) / n_ref - ref_shifted = ref_values_chi - ref_mean_chi[:, None] - c_chi = (Kinv @ ref_shifted[:, :, None]).squeeze() + # Extract the reference values and GPR coefficients for the electronegativities. + ref_values_chi = _torch.tensor(params["chi_ref"], dtype=dtype, device=device) + ref_mean_chi, c_chi = self._get_c(n_ref, ref_values_chi, Kinv) - # Extract the reference values for the polarizabilities. + # Extract the reference values and GPR coefficients for the polarizabilities. if self._alpha_mode == "reference": - ref_mean_k = _torch.sum(k, dim=1) / n_ref - ref_shifted = k - ref_mean_k[:, None] - c_k = (Kinv @ ref_shifted[:, :, None]).squeeze() + ref_mean_k, c_k = self._get_c(n_ref, k, Kinv) else: ref_mean_k = _torch.empty(0, dtype=dtype, device=device) c_k = _torch.empty(0, dtype=dtype, device=device) @@ -366,6 +358,12 @@ def _get_Kinv(cls, ref_features, sigma): K + sigma**2 * _torch.eye(n, dtype=ref_features.dtype, device=K.device) ) + @classmethod + def _get_c(cls, n_ref, ref, Kinv): + ref_mean = _torch.sum(ref, dim=1) / n_ref + ref_shifted = ref - ref_mean[:, None] + return ref_mean, (Kinv @ ref_shifted[:, :, None]).squeeze() + def _gpr(self, mol_features, ref_mean, c, zid): """ Internal method to predict a property using Gaussian Process Regression. 
From 3e7ed58b90d938d250ace33f628007c6c6c57477 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 20:06:15 +0200 Subject: [PATCH 09/39] Explicit k_Z/sqrtk variables for species/reference models --- emle/models/_emle_base.py | 83 ++++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 60bcf2c..521b169 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -109,24 +109,6 @@ def __init__( q_core = _torch.tensor(params["q_core"], dtype=dtype, device=device) a_QEq = _torch.tensor(params["a_QEq"], dtype=dtype, device=device) a_Thole = _torch.tensor(params["a_Thole"], dtype=dtype, device=device) - if self._alpha_mode == "species": - try: - k = _torch.tensor(params["k_Z"], dtype=dtype, device=device) - except: - msg = ( - "Missing 'k_Z' key in params. This is required when " - "using 'species' alpha mode." - ) - raise ValueError(msg) - else: - try: - k = _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device) - except: - msg = ( - "Missing 'sqrtk_ref' key in params. This is required when " - "using 'reference' alpha mode." - ) - raise ValueError(msg) # Extract the reference features. ref_features = _torch.tensor(params["ref_aev"], dtype=dtype, device=device) @@ -145,12 +127,30 @@ def __init__( ref_values_chi = _torch.tensor(params["chi_ref"], dtype=dtype, device=device) ref_mean_chi, c_chi = self._get_c(n_ref, ref_values_chi, Kinv) - # Extract the reference values and GPR coefficients for the polarizabilities. 
- if self._alpha_mode == "reference": - ref_mean_k, c_k = self._get_c(n_ref, k, Kinv) + if self._alpha_mode == "species": + try: + k_Z = _torch.tensor(params["k_Z"], dtype=dtype, device=device) + ref_values_sqrtk = _torch.empty(0, dtype=dtype, device=device) + ref_mean_sqrtk = _torch.empty(0, dtype=dtype, device=device) + c_sqrtk = _torch.empty(0, dtype=dtype, device=device) + except: + msg = ( + "Missing 'k_Z' key in params. This is required when " + "using 'species' alpha mode." + ) + raise ValueError(msg) else: - ref_mean_k = _torch.empty(0, dtype=dtype, device=device) - c_k = _torch.empty(0, dtype=dtype, device=device) + try: + k_Z = torch.empty(0, dtype=dtype, device=device) + ref_values_sqrtk = _torch.tensor(params["sqrtk_ref"], + dtype=dtype, device=device) + ref_mean_sqrtk, c_sqrtk = self._get_c(n_ref, ref_values_sqrtk, Kinv) + except: + msg = ( + "Missing 'sqrtk_ref' key in params. This is required when " + "using 'reference' alpha mode." + ) + raise ValueError(msg) # Store the current device. self._device = device @@ -161,17 +161,18 @@ def __init__( self.register_buffer("_q_core", q_core) self.register_buffer("_a_QEq", a_QEq) self.register_buffer("_a_Thole", a_Thole) - self.register_buffer("_k", k) + self.register_buffer("_k_Z", k_Z) self.register_buffer("_ref_features", ref_features) self.register_buffer("_n_ref", n_ref) self.register_buffer("_ref_values_s", ref_values_s) self.register_buffer("_ref_values_chi", ref_values_chi) + self.register_buffer("_ref_values_sqrtk", ref_values_sqrtk) self.register_buffer("_ref_mean_s", ref_mean_s) self.register_buffer("_ref_mean_chi", ref_mean_chi) + self.register_buffer("_ref_mean_sqrtk", ref_mean_sqrtk) self.register_buffer("_c_s", c_s) self.register_buffer("_c_chi", c_chi) - self.register_buffer("_ref_mean_k", ref_mean_k) - self.register_buffer("_c_k", c_k) + self.register_buffer("_c_sqrtk", c_sqrtk) # Initalise an empty AEV tensor to use to store the AEVs in derived classes. 
self._aev = _torch.empty(0, dtype=dtype, device=device) @@ -182,16 +183,16 @@ def to(self, *args, **kwargs): self._q_core = self._q_core.to(*args, **kwargs) self._a_QEq = self._a_QEq.to(*args, **kwargs) self._a_Thole = self._a_Thole.to(*args, **kwargs) - self._k = self._k.to(*args, **kwargs) + self._k_Z = self._k_Z.to(*args, **kwargs) self._ref_features = self._ref_features.to(*args, **kwargs) self._n_ref = self._n_ref.to(*args, **kwargs) self._ref_values_s = self._ref_values_s.to(*args, **kwargs) self._ref_values_chi = self._ref_values_chi.to(*args, **kwargs) self._ref_mean_s = self._ref_mean_s.to(*args, **kwargs) self._ref_mean_chi = self._ref_mean_chi.to(*args, **kwargs) + self._ref_mean_sqrtk = self._ref_mean_sqrtk.to(*args, **kwargs) self._c_s = self._c_s.to(*args, **kwargs) self._c_chi = self._c_chi.to(*args, **kwargs) - self._ref_mean_k = self._ref_mean_k.to(*args, **kwargs) self._c_k = self._c_k.to(*args, **kwargs) def cuda(self, **kwargs): @@ -203,17 +204,17 @@ def cuda(self, **kwargs): self._q_core = self._q_core.cuda(**kwargs) self._a_QEq = self._a_QEq.cuda(**kwargs) self._a_Thole = self._a_Thole.cuda(**kwargs) - self._k = self._k.cuda(**kwargs) + self._k_Z = self._k_Z.cuda(**kwargs) self._ref_features = self._ref_features.cuda(**kwargs) self._n_ref = self._n_ref.cuda(**kwargs) self._ref_values_s = self._ref_values_s.cuda(**kwargs) self._ref_values_chi = self._ref_values_chi.cuda(**kwargs) self._ref_mean_s = self._ref_mean_s.cuda(**kwargs) self._ref_mean_chi = self._ref_mean_chi.cuda(**kwargs) + self._ref_mean_sqrtk = self._ref_mean_sqrtk.cuda(**kwargs) self._c_s = self._c_s.cuda(**kwargs) self._c_chi = self._c_chi.cuda(**kwargs) - self._ref_mean_k = self._ref_mean_k.cuda(**kwargs) - self._c_k = self._c_k.cuda(**kwargs) + self._c_sqrtk = self._c_sqrtk.cuda(**kwargs) def cpu(self, **kwargs): """ @@ -224,17 +225,17 @@ def cpu(self, **kwargs): self._q_core = self._q_core.cpu(**kwargs) self._a_QEq = self._a_QEq.cpu(**kwargs) self._a_Thole = 
self._a_Thole.cpu(**kwargs) - self._k = self._k.cpu(**kwargs) + self._k_Z = self._k_Z.cpu(**kwargs) self._ref_features = self._ref_features.cpu(**kwargs) self._n_ref = self._n_ref.cpu(**kwargs) self._ref_values_s = self._ref_values_s.cpu(**kwargs) self._ref_values_chi = self._ref_values_chi.cpu(**kwargs) self._ref_mean_s = self._ref_mean_s.cpu(**kwargs) self._ref_mean_chi = self._ref_mean_chi.cpu(**kwargs) + self._ref_mean_sqrtk = self._ref_mean_sqrtk.to(**kwargs) self._c_s = self._c_s.cpu(**kwargs) self._c_chi = self._c_chi.cpu(**kwargs) - self._ref_mean_k = self._ref_mean_k.cpu(**kwargs) - self._c_k = self._c_k.cpu(**kwargs) + self._c_sqrtk = self._c_sqrtk.cpu(**kwargs) def double(self): """ @@ -243,16 +244,16 @@ def double(self): self._q_core = self._q_core.double() self._a_QEq = self._a_QEq.double() self._a_Thole = self._a_Thole.double() - self._k = self._k.double() + self._k_Z = self._k_Z.double() self._ref_features = self._ref_features.double() self._ref_values_s = self._ref_values_s.double() self._ref_values_chi = self._ref_values_chi.double() self._ref_mean_s = self._ref_mean_s.double() self._ref_mean_chi = self._ref_mean_chi.double() + self._ref_mean_sqrtk = self._ref_mean_sqrtk.double() self._c_s = self._c_s.double() self._c_chi = self._c_chi.double() - self._ref_mean_k = self._ref_mean_k.double() - self._c_k = self._c_k.double() + self._c_sqrtk = self._c_sqrtk.double() return self def float(self): @@ -268,10 +269,10 @@ def float(self): self._ref_values_chi = self._ref_values_chi.float() self._ref_mean_s = self._ref_mean_s.float() self._ref_mean_chi = self._ref_mean_chi.float() + self._ref_mean_sqrtk = self._ref_mean_sqrtk.float() self._c_s = self._c_s.float() self._c_chi = self._c_chi.float() - self._ref_mean_k = self._ref_mean_k.float() - self._c_k = self._c_k.float() + self._c_sqrtk = self._c_sqrtk.float() return self def forward(self, atomic_numbers, xyz_qm, q_total): @@ -324,9 +325,9 @@ def forward(self, atomic_numbers, xyz_qm, q_total): q_val = q - 
q_core if self._alpha_mode == "species": - k = self._k[species_id] + k = self._k_Z[species_id] else: - k = self._gpr(aev, self._ref_mean_k, self._c_k, species_id) ** 2 + k = self._gpr(aev, self._ref_mean_sqrtk, self._c_sqrtk, species_id) ** 2 A_thole = self._get_A_thole(r_data, s, q_val, k) From 61eaedcc441ca026c90dea7f709714440df0ff56 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 20:10:13 +0200 Subject: [PATCH 10/39] Remove ref_values tensors from buffers (never used) --- emle/models/_emle_base.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 521b169..a8c486f 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -164,9 +164,6 @@ def __init__( self.register_buffer("_k_Z", k_Z) self.register_buffer("_ref_features", ref_features) self.register_buffer("_n_ref", n_ref) - self.register_buffer("_ref_values_s", ref_values_s) - self.register_buffer("_ref_values_chi", ref_values_chi) - self.register_buffer("_ref_values_sqrtk", ref_values_sqrtk) self.register_buffer("_ref_mean_s", ref_mean_s) self.register_buffer("_ref_mean_chi", ref_mean_chi) self.register_buffer("_ref_mean_sqrtk", ref_mean_sqrtk) @@ -186,8 +183,6 @@ def to(self, *args, **kwargs): self._k_Z = self._k_Z.to(*args, **kwargs) self._ref_features = self._ref_features.to(*args, **kwargs) self._n_ref = self._n_ref.to(*args, **kwargs) - self._ref_values_s = self._ref_values_s.to(*args, **kwargs) - self._ref_values_chi = self._ref_values_chi.to(*args, **kwargs) self._ref_mean_s = self._ref_mean_s.to(*args, **kwargs) self._ref_mean_chi = self._ref_mean_chi.to(*args, **kwargs) self._ref_mean_sqrtk = self._ref_mean_sqrtk.to(*args, **kwargs) @@ -207,8 +202,6 @@ def cuda(self, **kwargs): self._k_Z = self._k_Z.cuda(**kwargs) self._ref_features = self._ref_features.cuda(**kwargs) self._n_ref = self._n_ref.cuda(**kwargs) - self._ref_values_s = self._ref_values_s.cuda(**kwargs) - self._ref_values_chi = 
self._ref_values_chi.cuda(**kwargs) self._ref_mean_s = self._ref_mean_s.cuda(**kwargs) self._ref_mean_chi = self._ref_mean_chi.cuda(**kwargs) self._ref_mean_sqrtk = self._ref_mean_sqrtk.cuda(**kwargs) @@ -228,8 +221,6 @@ def cpu(self, **kwargs): self._k_Z = self._k_Z.cpu(**kwargs) self._ref_features = self._ref_features.cpu(**kwargs) self._n_ref = self._n_ref.cpu(**kwargs) - self._ref_values_s = self._ref_values_s.cpu(**kwargs) - self._ref_values_chi = self._ref_values_chi.cpu(**kwargs) self._ref_mean_s = self._ref_mean_s.cpu(**kwargs) self._ref_mean_chi = self._ref_mean_chi.cpu(**kwargs) self._ref_mean_sqrtk = self._ref_mean_sqrtk.to(**kwargs) @@ -246,8 +237,6 @@ def double(self): self._a_Thole = self._a_Thole.double() self._k_Z = self._k_Z.double() self._ref_features = self._ref_features.double() - self._ref_values_s = self._ref_values_s.double() - self._ref_values_chi = self._ref_values_chi.double() self._ref_mean_s = self._ref_mean_s.double() self._ref_mean_chi = self._ref_mean_chi.double() self._ref_mean_sqrtk = self._ref_mean_sqrtk.double() @@ -265,8 +254,6 @@ def float(self): self._a_Thole = self._a_Thole.float() self._k = self._k.float() self._ref_features = self._ref_features.float() - self._ref_values_s = self._ref_values_s.float() - self._ref_values_chi = self._ref_values_chi.float() self._ref_mean_s = self._ref_mean_s.float() self._ref_mean_chi = self._ref_mean_chi.float() self._ref_mean_sqrtk = self._ref_mean_sqrtk.float() From 62c6d082e489f69ad9ba86805d7cfb2973bf4a87 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 20:22:25 +0200 Subject: [PATCH 11/39] Register buffer for Kinv (will be needed during training) --- emle/models/_emle_base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index a8c486f..171ed11 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -158,6 +158,7 @@ def __init__( # Register constants as buffers. 
self.register_buffer("_species_map", species_map) self.register_buffer("_aev_mask", aev_mask) + self.register_buffer("_Kinv", Kinv) self.register_buffer("_q_core", q_core) self.register_buffer("_a_QEq", a_QEq) self.register_buffer("_a_Thole", a_Thole) @@ -176,6 +177,7 @@ def __init__( def to(self, *args, **kwargs): self._species_map = self._species_map.to(*args, **kwargs) + self._Kinv = self._Kinv.to(*args, **kwargs) self._aev_mask = self._aev_mask.to(*args, **kwargs) self._q_core = self._q_core.to(*args, **kwargs) self._a_QEq = self._a_QEq.to(*args, **kwargs) @@ -195,6 +197,7 @@ def cuda(self, **kwargs): Move all model parameters and buffers to CUDA memory. """ self._species_map = self._species_map.cuda(**kwargs) + self._Kinv = self._Kinv.cuda(**kwargs) self._aev_mask = self._aev_mask.cuda(**kwargs) self._q_core = self._q_core.cuda(**kwargs) self._a_QEq = self._a_QEq.cuda(**kwargs) @@ -214,6 +217,7 @@ def cpu(self, **kwargs): Move all model parameters and buffers to CPU memory. """ self._species_map = self._species_map.cpu(**kwargs) + self._Kinv = self._Kinv.cpu(**kwargs) self._aev_mask = self._aev_mask.cpu(**kwargs) self._q_core = self._q_core.cpu(**kwargs) self._a_QEq = self._a_QEq.cpu(**kwargs) @@ -232,6 +236,7 @@ def double(self): """ Casts all floating point model parameters and buffers to float64 precision. """ + self._Kinv = self._Kinv.double() self._q_core = self._q_core.double() self._a_QEq = self._a_QEq.double() self._a_Thole = self._a_Thole.double() @@ -249,6 +254,7 @@ def float(self): """ Casts all floating point model parameters and buffers to float32 precision. 
""" + self._Kinv = self._Kinv.float() self._q_core = self._q_core.float() self._a_QEq = self._a_QEq.float() self._a_Thole = self._a_Thole.float() From d3d3862acd7c4b42c2c583b18623dd2fa95f4612 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 20:22:54 +0200 Subject: [PATCH 12/39] Typo --- emle/models/_emle_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 171ed11..abb69f6 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -141,7 +141,7 @@ def __init__( raise ValueError(msg) else: try: - k_Z = torch.empty(0, dtype=dtype, device=device) + k_Z = _torch.empty(0, dtype=dtype, device=device) ref_values_sqrtk = _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device) ref_mean_sqrtk, c_sqrtk = self._get_c(n_ref, ref_values_sqrtk, Kinv) From 2605d86883d6582d6df88c89bed3a8f4d10ebacf Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 20:41:05 +0200 Subject: [PATCH 13/39] Move all parameters together, ensure same parameter shapes for species/reference models --- emle/models/_emle_base.py | 64 +++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index abb69f6..c6ef0d6 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -96,43 +96,20 @@ def __init__( self._aev_computer = aev_computer - # Create a map between species and their indices. - species_map = _np.full(max(species) + 1, fill_value=-1, dtype=_np.int64) - for i, s in enumerate(species): - species_map[s] = i - - # Convert to a tensor. - species_map = _torch.tensor(species_map, dtype=_torch.int64, device=device) + # q_core is not trained, so not a parameter + q_core = _torch.tensor(params["q_core"], dtype=dtype, device=device) # Store model parameters as tensors. 
- aev_mask = _torch.tensor(params["aev_mask"], dtype=_torch.bool, device=device) - q_core = _torch.tensor(params["q_core"], dtype=dtype, device=device) a_QEq = _torch.tensor(params["a_QEq"], dtype=dtype, device=device) a_Thole = _torch.tensor(params["a_Thole"], dtype=dtype, device=device) - - # Extract the reference features. - ref_features = _torch.tensor(params["ref_aev"], dtype=dtype, device=device) - - # Compute the inverse of the K matrix. - Kinv = self._get_Kinv(ref_features, 1e-3) - - # Extract number of references per element - n_ref = _torch.tensor(params["n_ref"], dtype=_torch.int64, device=device) - - # Extract the reference values and GPR coefficients for the valence shell widths. ref_values_s = _torch.tensor(params["s_ref"], dtype=dtype, device=device) - ref_mean_s, c_s = self._get_c(n_ref, ref_values_s, Kinv) - - # Extract the reference values and GPR coefficients for the electronegativities. ref_values_chi = _torch.tensor(params["chi_ref"], dtype=dtype, device=device) - ref_mean_chi, c_chi = self._get_c(n_ref, ref_values_chi, Kinv) if self._alpha_mode == "species": try: k_Z = _torch.tensor(params["k_Z"], dtype=dtype, device=device) - ref_values_sqrtk = _torch.empty(0, dtype=dtype, device=device) - ref_mean_sqrtk = _torch.empty(0, dtype=dtype, device=device) - c_sqrtk = _torch.empty(0, dtype=dtype, device=device) + ref_values_sqrtk = _torch.zeros_like(ref_values_s, + dtype=dtype, device=device) except: msg = ( "Missing 'k_Z' key in params. This is required when " @@ -141,10 +118,9 @@ def __init__( raise ValueError(msg) else: try: - k_Z = _torch.empty(0, dtype=dtype, device=device) + k_Z = _torch.zeros_like(q_core, dtype=dtype, device=device) ref_values_sqrtk = _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device) - ref_mean_sqrtk, c_sqrtk = self._get_c(n_ref, ref_values_sqrtk, Kinv) except: msg = ( "Missing 'sqrtk_ref' key in params. 
This is required when " @@ -152,6 +128,36 @@ def __init__( ) raise ValueError(msg) + # Create a map between species (1, 6, 8) + # and their indices in the model (0, 1, 2). + species_map = _np.full(max(species) + 1, fill_value=-1, dtype=_np.int64) + for i, s in enumerate(species): + species_map[s] = i + species_map = _torch.tensor(species_map, dtype=_torch.int64, device=device) + + aev_mask = _torch.tensor(params["aev_mask"], dtype=_torch.bool, device=device) + + # Extract number of references per element + n_ref = _torch.tensor(params["n_ref"], dtype=_torch.int64, device=device) + + # Extract the reference features. + ref_features = _torch.tensor(params["ref_aev"], dtype=dtype, device=device) + + # Compute the inverse of the K matrix. + Kinv = self._get_Kinv(ref_features, 1e-3) + + # Calculate GPR coefficients for the valence shell widths (s) + # and electronegativities (chi). + ref_mean_s, c_s = self._get_c(n_ref, ref_values_s, Kinv) + ref_mean_chi, c_chi = self._get_c(n_ref, ref_values_chi, Kinv) + + if self._alpha_mode == "species": + ref_mean_sqrtk = _torch.zeros_like(ref_mean_s, dtype=dtype, + device=device) + c_sqrtk = _torch.zeros_like(c_s, dtype=dtype, device=device) + else: + ref_mean_sqrtk, c_sqrtk = self._get_c(n_ref, ref_values_sqrtk, Kinv) + # Store the current device. 
self._device = device From c52c91018f72834ba1b781085077775850af6711 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 22:56:02 +0200 Subject: [PATCH 14/39] Register trainable model parameters with nn.Parameter --- emle/models/_emle.py | 19 ++++++++++++++++-- emle/models/_emle_base.py | 41 ++++++++++++++++----------------------- 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index e8ca8eb..655595a 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -307,8 +307,23 @@ def __init__( except: raise IOError(f"Unable to load model parameters from: '{model}'") - self._emle_base = EMLEBase(params, self._aev_computer, species, - alpha_mode, device, dtype) + q_core = _torch.tensor(params["q_core"], dtype=dtype, device=device) + aev_mask = _torch.tensor(params["aev_mask"], dtype=_torch.bool, device=device) + n_ref = _torch.tensor(params["n_ref"], dtype=_torch.int64, device=device) + ref_features = _torch.tensor(params["ref_aev"], dtype=dtype, device=device) + + emle_params = { + 'a_QEq': _torch.tensor(params["a_QEq"], dtype=dtype, device=device), + 'a_Thole': _torch.tensor(params["a_Thole"], dtype=dtype, device=device), + 'ref_values_s': _torch.tensor(params["s_ref"], dtype=dtype, device=device), + 'ref_values_chi': _torch.tensor(params["chi_ref"], dtype=dtype, device=device), + 'k_Z': _torch.tensor(params["k_Z"], dtype=dtype, device=device) + if 'k_Z' in params else None, + 'sqrtk_ref': _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device) + if 'sqrtk_ref' in params else None + } + self._emle_base = EMLEBase(emle_params, self._aev_computer, aev_mask, species, + n_ref, ref_features, q_core, alpha_mode, device, dtype) q_total = _torch.tensor( params.get("total_charge", 0), dtype=dtype, device=device diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index c6ef0d6..d620c9b 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -14,10 +14,12 @@ 
def __init__( self, params, aev_computer, - # method="electrostatic", # Not used here, always electrostatic - species=None, + aev_mask, + species, + n_ref, + ref_features, + q_core, alpha_mode="species", - # atomic_numbers=None, # Not used here, since aev_computer is provided device=None, dtype=None, ): @@ -96,20 +98,18 @@ def __init__( self._aev_computer = aev_computer - # q_core is not trained, so not a parameter - q_core = _torch.tensor(params["q_core"], dtype=dtype, device=device) - # Store model parameters as tensors. - a_QEq = _torch.tensor(params["a_QEq"], dtype=dtype, device=device) - a_Thole = _torch.tensor(params["a_Thole"], dtype=dtype, device=device) - ref_values_s = _torch.tensor(params["s_ref"], dtype=dtype, device=device) - ref_values_chi = _torch.tensor(params["chi_ref"], dtype=dtype, device=device) + a_QEq = _torch.nn.Parameter(params["a_QEq"]) + a_Thole = _torch.nn.Parameter(params["a_Thole"]) + ref_values_s = _torch.nn.Parameter(params["ref_values_s"]) + ref_values_chi = _torch.nn.Parameter(params["ref_values_chi"]) if self._alpha_mode == "species": try: - k_Z = _torch.tensor(params["k_Z"], dtype=dtype, device=device) - ref_values_sqrtk = _torch.zeros_like(ref_values_s, - dtype=dtype, device=device) + k_Z = _torch.nn.Parameter(params["k_Z"]) + ref_values_sqrtk = _torch.nn.Parameter( + _torch.zeros_like(ref_values_s) + ) except: msg = ( "Missing 'k_Z' key in params. This is required when " @@ -118,9 +118,10 @@ def __init__( raise ValueError(msg) else: try: - k_Z = _torch.zeros_like(q_core, dtype=dtype, device=device) - ref_values_sqrtk = _torch.tensor(params["sqrtk_ref"], - dtype=dtype, device=device) + k_Z = _torch.nn.Parameter( + _torch.zeros_like(q_core, dtype=dtype, device=device) + ) + ref_values_sqrtk = _torch.nn.Parameter(params["sqrtk_ref"]) except: msg = ( "Missing 'sqrtk_ref' key in params. 
This is required when " @@ -135,14 +136,6 @@ def __init__( species_map[s] = i species_map = _torch.tensor(species_map, dtype=_torch.int64, device=device) - aev_mask = _torch.tensor(params["aev_mask"], dtype=_torch.bool, device=device) - - # Extract number of references per element - n_ref = _torch.tensor(params["n_ref"], dtype=_torch.int64, device=device) - - # Extract the reference features. - ref_features = _torch.tensor(params["ref_aev"], dtype=dtype, device=device) - # Compute the inverse of the K matrix. Kinv = self._get_Kinv(ref_features, 1e-3) From 1cc242b9f8b743cfe6847f702c4244700782cceb Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 22:56:37 +0200 Subject: [PATCH 15/39] Update EMLEBase.__init__ docstring --- emle/models/_emle_base.py | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index d620c9b..15f6abb 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -34,25 +34,22 @@ def __init__( aev_computer: AEVComputer instance (torchani/NNPOps) - method: str - The desired embedding method. Options are: - "electrostatic": - Full ML electrostatic embedding. - "mechanical": - ML predicted charges for the core, but zero valence charge. - "nonpol": - Non-polarisable ML embedding. Here the induced component of - the potential is zeroed. - "mm": - MM charges are used for the core charge and valence charges - are set to zero. If this option is specified then the user - should also specify the MM charges for atoms in the QM - region. + aev_mask: torch.Tensor + mask for features coming from aev_computer species: List[int], Tuple[int], numpy.ndarray, torch.Tensor List of species (atomic numbers) supported by the EMLE model. If None, then the default species list will be used. 
+ n_ref: torch.Tensor + number of GPR references for each element in species list + + ref_features: torch.Tensor + Feature vectors for GPR references + + q_core: torch.Tensor + Core charges for each element in species list + alpha_mode: str How atomic polarizabilities are calculated. "species": @@ -61,10 +58,6 @@ def __init__( scaling factors are obtained with GPR using the values learned for each reference environment - mm_charges: List[float], Tuple[Float], numpy.ndarray, torch.Tensor - List of MM charges for atoms in the QM region in units of mod - electron charge. This is required if the 'mm' method is specified. - device: torch.device The device on which to run the model. From 2e23f432f9b0f323f2080efd15fa1817cc23447e Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 23:01:17 +0200 Subject: [PATCH 16/39] Fix nans in padded A_thole calculation --- emle/models/_emle_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 15f6abb..c6a0888 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -570,7 +570,7 @@ def _get_A_thole(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, q_val, alphap = alpha * self._a_Thole alphap_mat = alphap[:, :, None] * alphap[:, None, :] - au3 = r_data[0] ** 3 / _torch.sqrt(alphap_mat) + au3 = _torch.where(alphap_mat > 0, r_data[0] ** 3 / _torch.sqrt(alphap_mat), 0) au31 = au3.repeat_interleave(3, dim=2) au32 = au31.repeat_interleave(3, dim=1) From d1e673df170e9e8091ae393ad7c4eea234618c57 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Sun, 20 Oct 2024 23:35:17 +0200 Subject: [PATCH 17/39] Fix species mapping for highest supported element --- emle/models/_emle_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index c6a0888..1e778b2 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -124,7 +124,7 @@ def __init__( 
# Create a map between species (1, 6, 8) # and their indices in the model (0, 1, 2). - species_map = _np.full(max(species) + 1, fill_value=-1, dtype=_np.int64) + species_map = _np.full(max(species) + 2, fill_value=-1, dtype=_np.int64) for i, s in enumerate(species): species_map[s] = i species_map = _torch.tensor(species_map, dtype=_torch.int64, device=device) From 9aae5700f9c58bcc8dd8949743b8a2054b1483a7 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Mon, 21 Oct 2024 11:44:57 +0200 Subject: [PATCH 18/39] Fix padding in batched A_QEq matrix --- emle/models/_emle_base.py | 52 ++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 1e778b2..c5e00db 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -303,10 +303,10 @@ def forward(self, atomic_numbers, xyz_qm, q_total): xyz_qm_bohr = xyz_qm * ANGSTROM_TO_BOHR - r_data = self._get_r_data(xyz_qm_bohr) + r_data = self._get_r_data(xyz_qm_bohr, mask) q_core = self._q_core[species_id] * mask - q = self._get_q(r_data, s, chi, q_total) + q = self._get_q(r_data, s, chi, q_total, mask) q_val = q - q_core if self._alpha_mode == "species": @@ -391,7 +391,7 @@ def _gpr(self, mol_features, ref_mean, c, zid): return result @classmethod - def _get_r_data(cls, xyz): + def _get_r_data(cls, xyz, mask): """ Internal method to calculate r_data object. @@ -401,15 +401,19 @@ def _get_r_data(cls, xyz): xyz: torch.Tensor (N_BATCH, N_ATOMS, 3) Atomic positions. + mask: torch.Tensor (N_BATCH, N_ATOMS) + Mask for padded coordinates + Returns ------- result: r_data object """ n_batch, n_atoms_max = xyz.shape[:2] + mask_mat = mask[:, :, None] * mask[:, None, :] rr_mat = xyz[:, :, None, :] - xyz[:, None, :, :] - r_mat = _torch.cdist(xyz, xyz) + r_mat = _torch.where(mask_mat, _torch.cdist(xyz, xyz), 0.) 
r_inv = _torch.where(r_mat == 0.0, 0.0, 1.0 / r_mat) r_inv1 = r_inv.repeat_interleave(3, dim=2) @@ -431,7 +435,8 @@ def _get_r_data(cls, xyz): return r_mat, t01, t21, t22 - def _get_q(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, chi, q_total): + def _get_q(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], + s, chi, q_total, mask): """ Internal method that predicts MBIS charges (Eq. 16 in 10.1021/acs.jctc.2c00914) @@ -450,17 +455,20 @@ def _get_q(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, chi, q_total) q_total: torch.Tensor (N_BATCH,) Total charge + mask: torch.Tensor (N_BATCH, N_ATOMS) + Mask for padded coordinates + Returns ------- result: torch.Tensor (N_BATCH, N_ATOMS,) Predicted MBIS charges. """ - A = self._get_A_QEq(r_data, s) + A = self._get_A_QEq(r_data, s, mask) b = _torch.hstack([-chi, q_total[:, None]]) return _torch.linalg.solve(A, b)[:, :-1] - def _get_A_QEq(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s): + def _get_A_QEq(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, mask): """ Internal method, generates A matrix for charge prediction (Eq. 16 in 10.1021/acs.jctc.2c00914) @@ -473,6 +481,9 @@ def _get_A_QEq(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s): s: torch.Tensor (N_BATCH, N_ATOMS,) MBIS valence shell widths. 
+ mask: torch.Tensor (N_BATCH, N_ATOMS) + Mask for padded coordinates + Returns ------- @@ -490,19 +501,25 @@ def _get_A_QEq(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s): diag_ones = _torch.ones_like(A.diagonal(dim1=-2, dim2=-1), dtype=dtype, device=device) pi = _torch.sqrt(_torch.tensor([_torch.pi], dtype=dtype, device=device)) - new_diag = diag_ones * _torch.where(s2 > 0, 1.0 / (s_gauss * pi), 0) + new_diag = diag_ones * _torch.where(s > 0, 1.0 / (s_gauss * pi), 0) - mask = _torch.diag_embed(diag_ones) - A = mask * _torch.diag_embed(new_diag) + (1.0 - mask) * A + diag_mask = _torch.diag_embed(diag_ones) + A = diag_mask * _torch.diag_embed(new_diag) + (1.0 - diag_mask) * A # Store the dimensions of A. - x, y = A.shape[1:] + n_batch, x, y = A.shape # Create an tensor of ones with one more row and column than A. - B = _torch.ones(len(A), x + 1, y + 1, dtype=dtype, device=device) + B_diag = _torch.ones((n_batch, x + 1), dtype=dtype, device=device) + B = _torch.diag_embed(B_diag) # Copy A into B. - B[:, :x, :y] = A + mask_mat = mask[:, :, None] * mask[:, None, :] + B[:, :x, :y] = _torch.where(mask_mat, A, B[:, :x, :y]) + + # Set last row and column to 1 (masked) + B[:, -1, :-1] = mask.float() + B[:, :-1, -1] = mask.float() # Set the final entry on the diagonal to zero. B[:, -1, -1] = 0.0 @@ -531,13 +548,8 @@ def _get_T0_gaussian(t01, r, s_mat): results: torch.Tensor (N_BATCH, N_ATOMS, N_ATOMS) """ - return t01 * _torch.erf( - r - / ( - s_mat - * _torch.sqrt(_torch.tensor([2.0], dtype=r.dtype, device=r.device)) - ) - ) + sqrt2 = _torch.sqrt(_torch.tensor([2.0], dtype=r.dtype, device=r.device)) + return t01 * _torch.where(s_mat > 0, _torch.erf(r / (s_mat * sqrt2)), 0.) def _get_A_thole(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, q_val, k): """ From 40f4c0a6cecd264d73079108ee991e9756823133 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Mon, 21 Oct 2024 12:03:43 +0200 Subject: [PATCH 19/39] Remove model parameters from buffers and to/cuda etc. 
methods --- emle/models/_emle_base.py | 46 ++++++++++++--------------------------- 1 file changed, 14 insertions(+), 32 deletions(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index c5e00db..5f84b8a 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -92,16 +92,16 @@ def __init__( self._aev_computer = aev_computer # Store model parameters as tensors. - a_QEq = _torch.nn.Parameter(params["a_QEq"]) - a_Thole = _torch.nn.Parameter(params["a_Thole"]) - ref_values_s = _torch.nn.Parameter(params["ref_values_s"]) - ref_values_chi = _torch.nn.Parameter(params["ref_values_chi"]) + self.a_QEq = _torch.nn.Parameter(params["a_QEq"]) + self.a_Thole = _torch.nn.Parameter(params["a_Thole"]) + self.ref_values_s = _torch.nn.Parameter(params["ref_values_s"]) + self.ref_values_chi = _torch.nn.Parameter(params["ref_values_chi"]) if self._alpha_mode == "species": try: - k_Z = _torch.nn.Parameter(params["k_Z"]) - ref_values_sqrtk = _torch.nn.Parameter( - _torch.zeros_like(ref_values_s) + self.k_Z = _torch.nn.Parameter(params["k_Z"]) + self.ref_values_sqrtk = _torch.nn.Parameter( + _torch.zeros_like(self.ref_values_s) ) except: msg = ( @@ -111,10 +111,10 @@ def __init__( raise ValueError(msg) else: try: - k_Z = _torch.nn.Parameter( + self.k_Z = _torch.nn.Parameter( _torch.zeros_like(q_core, dtype=dtype, device=device) ) - ref_values_sqrtk = _torch.nn.Parameter(params["sqrtk_ref"]) + self.ref_values_sqrtk = _torch.nn.Parameter(params["sqrtk_ref"]) except: msg = ( "Missing 'sqrtk_ref' key in params. This is required when " @@ -134,15 +134,15 @@ def __init__( # Calculate GPR coefficients for the valence shell widths (s) # and electronegativities (chi). 
- ref_mean_s, c_s = self._get_c(n_ref, ref_values_s, Kinv) - ref_mean_chi, c_chi = self._get_c(n_ref, ref_values_chi, Kinv) + ref_mean_s, c_s = self._get_c(n_ref, self.ref_values_s, Kinv) + ref_mean_chi, c_chi = self._get_c(n_ref, self.ref_values_chi, Kinv) if self._alpha_mode == "species": ref_mean_sqrtk = _torch.zeros_like(ref_mean_s, dtype=dtype, device=device) c_sqrtk = _torch.zeros_like(c_s, dtype=dtype, device=device) else: - ref_mean_sqrtk, c_sqrtk = self._get_c(n_ref, ref_values_sqrtk, Kinv) + ref_mean_sqrtk, c_sqrtk = self._get_c(n_ref, self.ref_values_sqrtk, Kinv) # Store the current device. self._device = device @@ -152,9 +152,6 @@ def __init__( self.register_buffer("_aev_mask", aev_mask) self.register_buffer("_Kinv", Kinv) self.register_buffer("_q_core", q_core) - self.register_buffer("_a_QEq", a_QEq) - self.register_buffer("_a_Thole", a_Thole) - self.register_buffer("_k_Z", k_Z) self.register_buffer("_ref_features", ref_features) self.register_buffer("_n_ref", n_ref) self.register_buffer("_ref_mean_s", ref_mean_s) @@ -172,9 +169,6 @@ def to(self, *args, **kwargs): self._Kinv = self._Kinv.to(*args, **kwargs) self._aev_mask = self._aev_mask.to(*args, **kwargs) self._q_core = self._q_core.to(*args, **kwargs) - self._a_QEq = self._a_QEq.to(*args, **kwargs) - self._a_Thole = self._a_Thole.to(*args, **kwargs) - self._k_Z = self._k_Z.to(*args, **kwargs) self._ref_features = self._ref_features.to(*args, **kwargs) self._n_ref = self._n_ref.to(*args, **kwargs) self._ref_mean_s = self._ref_mean_s.to(*args, **kwargs) @@ -192,9 +186,6 @@ def cuda(self, **kwargs): self._Kinv = self._Kinv.cuda(**kwargs) self._aev_mask = self._aev_mask.cuda(**kwargs) self._q_core = self._q_core.cuda(**kwargs) - self._a_QEq = self._a_QEq.cuda(**kwargs) - self._a_Thole = self._a_Thole.cuda(**kwargs) - self._k_Z = self._k_Z.cuda(**kwargs) self._ref_features = self._ref_features.cuda(**kwargs) self._n_ref = self._n_ref.cuda(**kwargs) self._ref_mean_s = self._ref_mean_s.cuda(**kwargs) @@ 
-212,9 +203,6 @@ def cpu(self, **kwargs): self._Kinv = self._Kinv.cpu(**kwargs) self._aev_mask = self._aev_mask.cpu(**kwargs) self._q_core = self._q_core.cpu(**kwargs) - self._a_QEq = self._a_QEq.cpu(**kwargs) - self._a_Thole = self._a_Thole.cpu(**kwargs) - self._k_Z = self._k_Z.cpu(**kwargs) self._ref_features = self._ref_features.cpu(**kwargs) self._n_ref = self._n_ref.cpu(**kwargs) self._ref_mean_s = self._ref_mean_s.cpu(**kwargs) @@ -230,9 +218,6 @@ def double(self): """ self._Kinv = self._Kinv.double() self._q_core = self._q_core.double() - self._a_QEq = self._a_QEq.double() - self._a_Thole = self._a_Thole.double() - self._k_Z = self._k_Z.double() self._ref_features = self._ref_features.double() self._ref_mean_s = self._ref_mean_s.double() self._ref_mean_chi = self._ref_mean_chi.double() @@ -248,9 +233,6 @@ def float(self): """ self._Kinv = self._Kinv.float() self._q_core = self._q_core.float() - self._a_QEq = self._a_QEq.float() - self._a_Thole = self._a_Thole.float() - self._k = self._k.float() self._ref_features = self._ref_features.float() self._ref_mean_s = self._ref_mean_s.float() self._ref_mean_chi = self._ref_mean_chi.float() @@ -489,7 +471,7 @@ def _get_A_QEq(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, mask): result: torch.Tensor (N_BATCH, N_ATOMS + 1, N_ATOMS + 1) """ - s_gauss = s * self._a_QEq + s_gauss = s * self.a_QEq s2 = s_gauss**2 s_mat = _torch.sqrt(s2[:, :, None] + s2[:, None, :]) @@ -579,7 +561,7 @@ def _get_A_thole(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, q_val, v = -60 * q_val * s**3 alpha = v * k - alphap = alpha * self._a_Thole + alphap = alpha * self.a_Thole alphap_mat = alphap[:, :, None] * alphap[:, None, :] au3 = _torch.where(alphap_mat > 0, r_data[0] ** 3 / _torch.sqrt(alphap_mat), 0) From 2582b701076a4ffbb34100823c8e613aca7c5865 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Mon, 21 Oct 2024 19:02:01 +0200 Subject: [PATCH 20/39] Typo --- emle/models/_emle_base.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 5f84b8a..891b784 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -176,7 +176,7 @@ def to(self, *args, **kwargs): self._ref_mean_sqrtk = self._ref_mean_sqrtk.to(*args, **kwargs) self._c_s = self._c_s.to(*args, **kwargs) self._c_chi = self._c_chi.to(*args, **kwargs) - self._c_k = self._c_k.to(*args, **kwargs) + self._c_sqrtk = self._c_sqrtk.to(*args, **kwargs) def cuda(self, **kwargs): """ From af0bfde4c4713bc30401e9e451cda78cffaee7b1 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Mon, 21 Oct 2024 19:02:45 +0200 Subject: [PATCH 21/39] Fix mean GPR reference calculation --- emle/models/_emle_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 891b784..3e96f42 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -328,7 +328,8 @@ def _get_Kinv(cls, ref_features, sigma): @classmethod def _get_c(cls, n_ref, ref, Kinv): - ref_mean = _torch.sum(ref, dim=1) / n_ref + mask = _torch.arange(ref.shape[1]) < n_ref[:, None] + ref_mean = _torch.sum(ref * mask, dim=1) / n_ref ref_shifted = ref - ref_mean[:, None] return ref_mean, (Kinv @ ref_shifted[:, :, None]).squeeze() From 4e547695c6f6101ed50c029f211c0831ec7a7bf9 Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Mon, 21 Oct 2024 19:56:54 +0200 Subject: [PATCH 22/39] Cleanup --- emle/models/_emle_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 3e96f42..2ef35e9 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -368,7 +368,6 @@ def _gpr(self, mol_features, ref_mean, c, zid): mol_features_z = mol_features[zid == i] K_mol_ref2 = (mol_features_z @ ref_features_z.T) ** 2 - # K_mol_ref2 = K_mol_ref2.reshape(K_mol_ref2.shape[:-1]) result[zid == i] = K_mol_ref2 @ c[i, :n_ref] + ref_mean[i] return result From 
4e685b457d56f9cdb1a5af71a9ce2ae2ada5651d Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Mon, 21 Oct 2024 20:07:44 +0200 Subject: [PATCH 23/39] Redefine reference model to work as a correction to the species one --- emle/models/_emle_base.py | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 2ef35e9..196aff4 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -96,24 +96,10 @@ def __init__( self.a_Thole = _torch.nn.Parameter(params["a_Thole"]) self.ref_values_s = _torch.nn.Parameter(params["ref_values_s"]) self.ref_values_chi = _torch.nn.Parameter(params["ref_values_chi"]) + self.k_Z = _torch.nn.Parameter(params["k_Z"]) - if self._alpha_mode == "species": - try: - self.k_Z = _torch.nn.Parameter(params["k_Z"]) - self.ref_values_sqrtk = _torch.nn.Parameter( - _torch.zeros_like(self.ref_values_s) - ) - except: - msg = ( - "Missing 'k_Z' key in params. This is required when " - "using 'species' alpha mode." 
- ) - raise ValueError(msg) - else: + if self._alpha_mode == "reference": try: - self.k_Z = _torch.nn.Parameter( - _torch.zeros_like(q_core, dtype=dtype, device=device) - ) self.ref_values_sqrtk = _torch.nn.Parameter(params["sqrtk_ref"]) except: msg = ( @@ -291,10 +277,11 @@ def forward(self, atomic_numbers, xyz_qm, q_total): q = self._get_q(r_data, s, chi, q_total, mask) q_val = q - q_core - if self._alpha_mode == "species": - k = self._k_Z[species_id] - else: - k = self._gpr(aev, self._ref_mean_sqrtk, self._c_sqrtk, species_id) ** 2 + k = self.k_Z[species_id] + + if self._alpha_mode == "reference": + k_scale = self._gpr(aev, self._ref_mean_sqrtk, self._c_sqrtk, species_id) ** 2 + k = k_scale * k A_thole = self._get_A_thole(r_data, s, q_val, k) From 1b81f8a843f03757697338b064ebb8072ad5202c Mon Sep 17 00:00:00 2001 From: Kirill Zinovjev Date: Mon, 21 Oct 2024 20:23:25 +0200 Subject: [PATCH 24/39] Fix auxiliary tensors created on wrong device --- emle/models/_emle_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 196aff4..afbfb39 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -254,7 +254,7 @@ def forward(self, atomic_numbers, xyz_qm, q_total): Valence widths, core charges, valence charges, A_thole tensor """ - mask = atomic_numbers > 0 + mask = _torch.tensor(atomic_numbers > 0, device=self._ref_mean_s.device) # Convert the atomic numbers to species IDs. 
species_id = self._species_map[atomic_numbers] @@ -315,7 +315,7 @@ def _get_Kinv(cls, ref_features, sigma): @classmethod def _get_c(cls, n_ref, ref, Kinv): - mask = _torch.arange(ref.shape[1]) < n_ref[:, None] + mask = _torch.arange(ref.shape[1], device=n_ref.device) < n_ref[:, None] ref_mean = _torch.sum(ref * mask, dim=1) / n_ref ref_shifted = ref - ref_mean[:, None] return ref_mean, (Kinv @ ref_shifted[:, :, None]).squeeze() From fa7517689a1f79e193f953107edae25033ae8999 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 09:55:47 +0100 Subject: [PATCH 25/39] Blacken. --- emle/models/_emle.py | 45 +++++++++++++++++++++++++++------------ emle/models/_emle_base.py | 29 ++++++++++++++----------- 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 655595a..2bbc1d9 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -313,17 +313,35 @@ def __init__( ref_features = _torch.tensor(params["ref_aev"], dtype=dtype, device=device) emle_params = { - 'a_QEq': _torch.tensor(params["a_QEq"], dtype=dtype, device=device), - 'a_Thole': _torch.tensor(params["a_Thole"], dtype=dtype, device=device), - 'ref_values_s': _torch.tensor(params["s_ref"], dtype=dtype, device=device), - 'ref_values_chi': _torch.tensor(params["chi_ref"], dtype=dtype, device=device), - 'k_Z': _torch.tensor(params["k_Z"], dtype=dtype, device=device) - if 'k_Z' in params else None, - 'sqrtk_ref': _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device) - if 'sqrtk_ref' in params else None + "a_QEq": _torch.tensor(params["a_QEq"], dtype=dtype, device=device), + "a_Thole": _torch.tensor(params["a_Thole"], dtype=dtype, device=device), + "ref_values_s": _torch.tensor(params["s_ref"], dtype=dtype, device=device), + "ref_values_chi": _torch.tensor( + params["chi_ref"], dtype=dtype, device=device + ), + "k_Z": ( + _torch.tensor(params["k_Z"], dtype=dtype, device=device) + if "k_Z" in params + else None + ), + "sqrtk_ref": ( + 
_torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device) + if "sqrtk_ref" in params + else None + ), } - self._emle_base = EMLEBase(emle_params, self._aev_computer, aev_mask, species, - n_ref, ref_features, q_core, alpha_mode, device, dtype) + self._emle_base = EMLEBase( + emle_params, + self._aev_computer, + aev_mask, + species, + n_ref, + ref_features, + q_core, + alpha_mode, + device, + dtype, + ) q_total = _torch.tensor( params.get("total_charge", 0), dtype=dtype, device=device @@ -455,9 +473,9 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): if len(xyz_mm) == 0: return _torch.zeros(2, dtype=xyz_qm.dtype, device=xyz_qm.device) - s, q_core, q_val, A_thole = self._emle_base(atomic_numbers[None, :], - xyz_qm[None, :, :], - self._q_total[None]) + s, q_core, q_val, A_thole = self._emle_base( + atomic_numbers[None, :], xyz_qm[None, :, :], self._q_total[None] + ) s, q_core, q_val, A_thole = s[0], q_core[0], q_val[0], A_thole[0] # Convert coordinates to Bohr. @@ -535,7 +553,6 @@ def _get_mu_ind( mu_ind = _torch.linalg.solve(A, fields) return mu_ind.reshape((-1, 3)) - @staticmethod def _get_vpot_q(q, T0): """ diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index afbfb39..ccf55da 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -124,8 +124,7 @@ def __init__( ref_mean_chi, c_chi = self._get_c(n_ref, self.ref_values_chi, Kinv) if self._alpha_mode == "species": - ref_mean_sqrtk = _torch.zeros_like(ref_mean_s, dtype=dtype, - device=device) + ref_mean_sqrtk = _torch.zeros_like(ref_mean_s, dtype=dtype, device=device) c_sqrtk = _torch.zeros_like(c_s, dtype=dtype, device=device) else: ref_mean_sqrtk, c_sqrtk = self._get_c(n_ref, self.ref_values_sqrtk, Kinv) @@ -280,7 +279,9 @@ def forward(self, atomic_numbers, xyz_qm, q_total): k = self.k_Z[species_id] if self._alpha_mode == "reference": - k_scale = self._gpr(aev, self._ref_mean_sqrtk, self._c_sqrtk, species_id) ** 2 + k_scale = ( + self._gpr(aev, 
self._ref_mean_sqrtk, self._c_sqrtk, species_id) ** 2 + ) k = k_scale * k A_thole = self._get_A_thole(r_data, s, q_val, k) @@ -382,7 +383,7 @@ def _get_r_data(cls, xyz, mask): mask_mat = mask[:, :, None] * mask[:, None, :] rr_mat = xyz[:, :, None, :] - xyz[:, None, :, :] - r_mat = _torch.where(mask_mat, _torch.cdist(xyz, xyz), 0.) + r_mat = _torch.where(mask_mat, _torch.cdist(xyz, xyz), 0.0) r_inv = _torch.where(r_mat == 0.0, 0.0, 1.0 / r_mat) r_inv1 = r_inv.repeat_interleave(3, dim=2) @@ -395,17 +396,18 @@ def _get_r_data(cls, xyz, mask): id2 = _torch.tile( _torch.eye(3, dtype=xyz.dtype, device=xyz.device).T, - (1, n_atoms_max, n_atoms_max) + (1, n_atoms_max, n_atoms_max), ) t01 = r_inv - t21 = -id2 * r_inv2 ** 3 - t22 = 3 * outer * r_inv2 ** 5 + t21 = -id2 * r_inv2**3 + t22 = 3 * outer * r_inv2**5 return r_mat, t01, t21, t22 - def _get_q(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], - s, chi, q_total, mask): + def _get_q( + self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, chi, q_total, mask + ): """ Internal method that predicts MBIS charges (Eq. 16 in 10.1021/acs.jctc.2c00914) @@ -467,8 +469,9 @@ def _get_A_QEq(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, mask): A = self._get_T0_gaussian(r_data[1], r_data[0], s_mat) - diag_ones = _torch.ones_like(A.diagonal(dim1=-2, dim2=-1), - dtype=dtype, device=device) + diag_ones = _torch.ones_like( + A.diagonal(dim1=-2, dim2=-1), dtype=dtype, device=device + ) pi = _torch.sqrt(_torch.tensor([_torch.pi], dtype=dtype, device=device)) new_diag = diag_ones * _torch.where(s > 0, 1.0 / (s_gauss * pi), 0) @@ -518,7 +521,7 @@ def _get_T0_gaussian(t01, r, s_mat): results: torch.Tensor (N_BATCH, N_ATOMS, N_ATOMS) """ sqrt2 = _torch.sqrt(_torch.tensor([2.0], dtype=r.dtype, device=r.device)) - return t01 * _torch.where(s_mat > 0, _torch.erf(r / (s_mat * sqrt2)), 0.) 
+ return t01 * _torch.where(s_mat > 0, _torch.erf(r / (s_mat * sqrt2)), 0.0) def _get_A_thole(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, q_val, k): """ @@ -558,7 +561,7 @@ def _get_A_thole(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, q_val, A = -self._get_T2_thole(r_data[2], r_data[3], au32) alpha3 = alpha.repeat_interleave(3, dim=1) - new_diag = _torch.where(alpha3 > 0, 1.0 / alpha3, 1.) + new_diag = _torch.where(alpha3 > 0, 1.0 / alpha3, 1.0) diag_ones = _torch.ones_like(new_diag, dtype=A.dtype, device=A.device) mask = _torch.diag_embed(diag_ones) A = mask * _torch.diag_embed(new_diag) + (1.0 - mask) * A From 3991d8d3dc4930a69431f207cf42f4bf5bb4687d Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 10:31:47 +0100 Subject: [PATCH 26/39] Float variable needs to be local. --- emle/models/_emle_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index ccf55da..6760102 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -5,8 +5,6 @@ from torch import Tensor from typing import Tuple -ANGSTROM_TO_BOHR = 1.8897261258369282 - class EMLEBase(_torch.nn.Module): @@ -268,6 +266,8 @@ def forward(self, atomic_numbers, xyz_qm, q_total): # Compute the electronegativities. chi = self._gpr(aev, self._ref_mean_chi, self._c_chi, species_id) + # Convert coordinates to Bohr. + ANGSTROM_TO_BOHR = 1.8897261258369282 xyz_qm_bohr = xyz_qm * ANGSTROM_TO_BOHR r_data = self._get_r_data(xyz_qm_bohr, mask) From fd46988920c7ac9f3238f8651df84725943bae97 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 10:32:02 +0100 Subject: [PATCH 27/39] Improve docstrings. 
--- emle/models/_emle_base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 6760102..fc2a12f 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -30,10 +30,11 @@ def __init__( params: dict EMLE model parameters - aev_computer: AEVComputer instance (torchani/NNPOps) + aev_computer: torchani.AEVComputer + AEV computer instance used to compute AEVs. aev_mask: torch.Tensor - mask for features coming from aev_computer + Mask for features coming from aev_computer. species: List[int], Tuple[int], numpy.ndarray, torch.Tensor List of species (atomic numbers) supported by the EMLE model. If From 1fa05e306746f861368eba5aa4d6487ab17e412e Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 10:34:45 +0100 Subject: [PATCH 28/39] Fix padded coordinates mask. --- emle/models/_emle_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index fc2a12f..cdfa5fe 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -252,7 +252,8 @@ def forward(self, atomic_numbers, xyz_qm, q_total): Valence widths, core charges, valence charges, A_thole tensor """ - mask = _torch.tensor(atomic_numbers > 0, device=self._ref_mean_s.device) + # Mask for padded coordinates. + mask = atomic_numbers > 0 # Convert the atomic numbers to species IDs. species_id = self._species_map[atomic_numbers] From de48b290f0268f358b7841edb81a1f79c9d9394f Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 10:37:44 +0100 Subject: [PATCH 29/39] Fix docstrings. --- emle/models/_emle_base.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index cdfa5fe..4725898 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -7,6 +7,12 @@ class EMLEBase(_torch.nn.Module): + """ + Base class for the EMLE model. 
This is used to compute valence shell + widths, core charges, valence charges, and the A_thole tensor for a batch + of QM systems, which in turn can be used to compute static and induced + electrostating embedding energies using the EMLE model. + """ def __init__( self, @@ -228,7 +234,8 @@ def float(self): def forward(self, atomic_numbers, xyz_qm, q_total): """ - Computes the static and induced EMLE energy components. + Compute the valence widths, core charges, valence charges, and + A_thole tensor for a batch of QM systems. Parameters ---------- From 756c08c22e956e5f5a1ea7cbed1683eb7d4b5afe Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 11:44:52 +0100 Subject: [PATCH 30/39] Fixes to allow TorchScript plus refactoring and input validation. --- emle/models/_ani.py | 6 +- emle/models/_emle.py | 30 +++---- emle/models/_emle_base.py | 177 +++++++++++++++++++++++++++++++++----- 3 files changed, 173 insertions(+), 40 deletions(-) diff --git a/emle/models/_ani.py b/emle/models/_ani.py index 60ba708..619e316 100644 --- a/emle/models/_ani.py +++ b/emle/models/_ani.py @@ -241,7 +241,7 @@ def hook( input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]], output: Tuple[Tensor, Tensor], ): - module._aev = output[1][0] + module._aev = output[1] else: @@ -250,7 +250,7 @@ def hook( input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]], output: _torchani.aev.SpeciesAEV, ): - module._aev = output[1][0] + module._aev = output[1] # Register the hook. self._aev_hook = self._ani2x.aev_computer.register_forward_hook(hook) @@ -351,7 +351,7 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): # Set the AEVs captured by the forward hook as an attribute of the # EMLE model. - self._emle._aev = self._ani2x.aev_computer._aev + self._emle._emle_base._aev = self._ani2x.aev_computer._aev # Get the EMLE energy components. 
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 2bbc1d9..fd6c786 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -39,7 +39,7 @@ from typing import Optional, Tuple, List from . import _patches -from . import EMLEBase +from . import EMLEBase as _EMLEBase # Monkey-patch the TorchANI BuiltInModel and BuiltinEnsemble classes so that # they call self.aev_computer using args only to allow forward hooks to work @@ -330,19 +330,8 @@ def __init__( else None ), } - self._emle_base = EMLEBase( - emle_params, - self._aev_computer, - aev_mask, - species, - n_ref, - ref_features, - q_core, - alpha_mode, - device, - dtype, - ) + # Store the total charge. q_total = _torch.tensor( params.get("total_charge", 0), dtype=dtype, device=device ) @@ -359,8 +348,19 @@ def __init__( self.register_buffer("_q_total", q_total) self.register_buffer("_q_core_mm", q_core_mm) - # Initalise an empty AEV tensor to use to store the AEVs in derived classes. - self._aev = _torch.empty(0, dtype=dtype, device=device) + # Create the base EMLE model. 
+ self._emle_base = _EMLEBase( + emle_params, + n_ref, + ref_features, + q_core, + aev_computer=self._aev_computer, + aev_mask=aev_mask, + alpha_mode=self._alpha_mode, + species=self._species, + device=device, + dtype=dtype, + ) def _to_dict(self): """ diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 4725898..5b95f9c 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -1,3 +1,25 @@ +####################################################################### +# EMLE-Engine: https://github.com/chemle/emle-engine +# +# Copyright: 2023-2024 +# +# Authors: Lester Hedges +# Kirill Zinovjev +# +# EMLE-Engine is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# EMLE-Engine is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with EMLE-Engine. If not, see . +##################################################################### + import numpy as _np import torch as _torch @@ -5,6 +27,17 @@ from torch import Tensor from typing import Tuple +import torchani as _torchani + +try: + import NNPOps as _NNPOps + + _NNPOps.OptimizedTorchANI = _patches.OptimizedTorchANI + + _has_nnpops = True +except: + _has_nnpops = False + class EMLEBase(_torch.nn.Module): """ @@ -14,15 +47,18 @@ class EMLEBase(_torch.nn.Module): electrostating embedding energies using the EMLE model. """ + # Store the list of supported species. 
+ _species = [1, 6, 7, 8, 16] + def __init__( self, params, - aev_computer, - aev_mask, - species, n_ref, ref_features, q_core, + aev_computer=None, + aev_mask=None, + species=None, alpha_mode="species", device=None, dtype=None, @@ -34,26 +70,16 @@ def __init__( ---------- params: dict - EMLE model parameters - - aev_computer: torchani.AEVComputer - AEV computer instance used to compute AEVs. - - aev_mask: torch.Tensor - Mask for features coming from aev_computer. - - species: List[int], Tuple[int], numpy.ndarray, torch.Tensor - List of species (atomic numbers) supported by the EMLE model. If - None, then the default species list will be used. + EMLE model parameters. n_ref: torch.Tensor number of GPR references for each element in species list ref_features: torch.Tensor - Feature vectors for GPR references + Feature vectors for GPR references. q_core: torch.Tensor - Core charges for each element in species list + Core charges for each element in species list. alpha_mode: str How atomic polarizabilities are calculated. @@ -63,6 +89,15 @@ def __init__( scaling factors are obtained with GPR using the values learned for each reference environment + aev_computer: torchani.AEVComputer + AEV computer instance used to compute AEVs. + + aev_mask: torch.Tensor + Mask for features coming from aev_computer. + + species: List[int], Tuple[int], numpy.ndarray, torch.Tensor + List of species (atomic numbers) supported by the EMLE model. + device: torch.device The device on which to run the model. @@ -73,6 +108,46 @@ def __init__( # Call the base class constructor. super().__init__() + # Validate the parameters. + if not isinstance(params, dict): + raise TypeError("'params' must be of type 'dict'") + if not all( + k in params + for k in ["a_QEq", "a_Thole", "ref_values_s", "ref_values_chi", "k_Z"] + ): + raise ValueError( + "'params' must contain keys 'a_QEq', 'a_Thole', 'ref_values_s', 'ref_values_chi', and 'k_Z'" + ) + + # Validate the number of references. 
+ if not isinstance(n_ref, _torch.Tensor): + raise TypeError("'n_ref' must be of type 'torch.Tensor'") + if len(n_ref.shape) != 1: + raise ValueError("'n_ref' must be a 1D tensor") + if not n_ref.dtype == _torch.int64: + raise ValueError("'n_ref' must have dtype 'torch.int64'") + + # Validate the reference features. + if not isinstance(ref_features, _torch.Tensor): + raise TypeError("'ref_features' must be of type 'torch.Tensor'") + if len(ref_features.shape) != 3: + raise ValueError("'ref_features' must be a 3D tensor") + if not ref_features.dtype in (_torch.float64, _torch.float32): + raise ValueError( + "'ref_features' must have dtype 'torch.float64' or 'torch.float32'" + ) + + # Validate the core charges. + if not isinstance(q_core, _torch.Tensor): + raise TypeError("'q_core' must be of type 'torch.Tensor'") + if len(q_core.shape) != 1: + raise ValueError("'q_core' must be a 1D tensor") + if not q_core.dtype in (_torch.float64, _torch.float32): + raise ValueError( + "'q_core' must have dtype 'torch.float64' or 'torch.float32'" + ) + + # Validate the alpha mode. if alpha_mode is None: alpha_mode = "species" if not isinstance(alpha_mode, str): @@ -82,6 +157,32 @@ def __init__( raise ValueError("'alpha_mode' must be 'species' or 'reference'") self._alpha_mode = alpha_mode + # Validate the AEV computer. + if aev_computer is not None: + allowed_types = [_torchani.AEVComputer] + if _has_nnpops: + allowed_types.append( + _NNPOps.SymmetryFunctions.TorchANISymmetryFunctions + ) + if not isinstance(aev_computer, tuple(allowed_types)): + raise TypeError( + "'aev_computer' must be of type 'torchani.AEVComputer' or 'NNPOps.SymmetryFunctions.TorchANISymmetryFunctions'" + ) + self._aev_computer = aev_computer + else: + self._aev_computer = None + + # Validate the AEV mask. 
+ if aev_mask is not None: + if not isinstance(aev_mask, _torch.Tensor): + raise TypeError("'aev_mask' must be of type 'torch.Tensor'") + if len(aev_mask.shape) != 1: + raise ValueError("'aev_mask' must be a 1D tensor") + if not aev_mask.dtype == _torch.bool: + raise ValueError("'aev_mask' must have dtype 'torch.bool'") + else: + aev_mask = _torch.ones(ref_features.shape[2], dtype=_torch.bool) + if device is not None: if not isinstance(device, _torch.device): raise TypeError("'device' must be of type 'torch.device'") @@ -94,8 +195,6 @@ def __init__( else: dtype = _torch.get_default_dtype() - self._aev_computer = aev_computer - # Store model parameters as tensors. self.a_QEq = _torch.nn.Parameter(params["a_QEq"]) self.a_Thole = _torch.nn.Parameter(params["a_Thole"]) @@ -113,8 +212,22 @@ def __init__( ) raise ValueError(msg) - # Create a map between species (1, 6, 8) - # and their indices in the model (0, 1, 2). + # Validate the species. + if species is None: + # Use the default species. + species = self._species + if isinstance(species, (_np.ndarray, _torch.Tensor)): + species = species.tolist() + if not isinstance(species, (tuple, list)): + raise TypeError( + "'species' must be of type 'list', 'tuple', or 'numpy.ndarray'" + ) + if not all(isinstance(s, int) for s in species): + raise TypeError("All elements of 'species' must be of type 'int'") + if not all(s > 0 for s in species): + raise ValueError("All elements of 'species' must be greater than zero") + + # Create a map between species and their indices in the model. species_map = _np.full(max(species) + 2, fill_value=-1, dtype=_np.int64) for i, s in enumerate(species): species_map[s] = i @@ -151,10 +264,12 @@ def __init__( self.register_buffer("_c_chi", c_chi) self.register_buffer("_c_sqrtk", c_sqrtk) - # Initalise an empty AEV tensor to use to store the AEVs in derived classes. + # Initalise an empty AEV tensor to use to store the AEVs in parent models. 
self._aev = _torch.empty(0, dtype=dtype, device=device) def to(self, *args, **kwargs): + if self._aev_computer is not None: + self._aev_computer = self._aev_computer.to(*args, **kwargs) self._species_map = self._species_map.to(*args, **kwargs) self._Kinv = self._Kinv.to(*args, **kwargs) self._aev_mask = self._aev_mask.to(*args, **kwargs) @@ -168,10 +283,18 @@ def to(self, *args, **kwargs): self._c_chi = self._c_chi.to(*args, **kwargs) self._c_sqrtk = self._c_sqrtk.to(*args, **kwargs) + # Check for a device type in args and update the device attribute. + for arg in args: + if isinstance(arg, _torch.device): + self._device = arg + break + def cuda(self, **kwargs): """ Move all model parameters and buffers to CUDA memory. """ + if self._aev_computer is not None: + self._aev_computer = self._aev_computer.cuda(**kwargs) self._species_map = self._species_map.cuda(**kwargs) self._Kinv = self._Kinv.cuda(**kwargs) self._aev_mask = self._aev_mask.cuda(**kwargs) @@ -189,6 +312,8 @@ def cpu(self, **kwargs): """ Move all model parameters and buffers to CPU memory. """ + if self._aev_computer is not None: + self._aev_computer = self._aev_computer.cpu(**kwargs) self._species_map = self._species_map.cpu(**kwargs) self._Kinv = self._Kinv.cpu(**kwargs) self._aev_mask = self._aev_mask.cpu(**kwargs) @@ -206,6 +331,8 @@ def double(self): """ Casts all floating point model parameters and buffers to float64 precision. """ + if self._aev_computer is not None: + self._aev_computer = self._aev_computer.double() self._Kinv = self._Kinv.double() self._q_core = self._q_core.double() self._ref_features = self._ref_features.double() @@ -221,6 +348,8 @@ def float(self): """ Casts all floating point model parameters and buffers to float32 precision. 
""" + if self._aev_computer is not None: + self._aev_computer = self._aev_computer.float() self._Kinv = self._Kinv.float() self._q_core = self._q_core.float() self._ref_features = self._ref_features.float() @@ -266,7 +395,11 @@ def forward(self, atomic_numbers, xyz_qm, q_total): species_id = self._species_map[atomic_numbers] # Compute the AEVs. - aev = self._aev_computer((species_id, xyz_qm))[1][:, :, self._aev_mask] + if self._aev_computer is not None: + aev = self._aev_computer((species_id, xyz_qm))[1][:, :, self._aev_mask] + # The AEVs have been pre-computed by a parent model. + else: + aev = self._aev[:, :, self._aev_mask] aev = aev / _torch.linalg.norm(aev, ord=2, dim=2, keepdim=True) # Compute the MBIS valence shell widths. From 56da0671b198bc64c5dccbd5d73e6da89405771a Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 11:55:41 +0100 Subject: [PATCH 31/39] Remove unnecessary logging. --- emle/models/_emle.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index fd6c786..f23330d 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -27,8 +27,6 @@ __all__ = ["EMLE"] -from loguru import logger as _logger - import numpy as _np import os as _os import scipy.io as _scipy_io @@ -217,17 +215,13 @@ def __init__( if model is not None: if not isinstance(model, str): - msg = "'model' must be of type 'str'" - _logger.error(msg) - raise TypeError(msg) + raise TypeError("'model' must be of type 'str'") # Convert to an absolute path. abs_model = _os.path.abspath(model) if not _os.path.isfile(abs_model): - msg = f"Unable to locate EMLE embedding model file: '{model}'" - _logger.error(msg) - raise IOError(msg) + raise IOError(f"Unable to locate EMLE embedding model file: '{model}'") self._model = abs_model # Validate the species for the custom model. 
From 185cfd2342a98153eeb88b6ab0064a186ae113dd Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 12:01:04 +0100 Subject: [PATCH 32/39] Add module attributes. [ci skip] --- emle/models/_emle_base.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 5b95f9c..9c45569 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -20,6 +20,13 @@ # along with EMLE-Engine. If not, see . ##################################################################### +"""EMLE base model implementation.""" + +__author__ = "Kirill Zinovjev" +__email__ = "kzinovjev@gmail.com" + +__all__ = ["EMLEBase"] + import numpy as _np import torch as _torch From 874ddd8e388224dc232eb59896667b1f3579ce57 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 12:27:11 +0100 Subject: [PATCH 33/39] Fix overloaded module methods and k_Z buffer. --- emle/models/_emle_base.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 9c45569..772738d 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -207,7 +207,7 @@ def __init__( self.a_Thole = _torch.nn.Parameter(params["a_Thole"]) self.ref_values_s = _torch.nn.Parameter(params["ref_values_s"]) self.ref_values_chi = _torch.nn.Parameter(params["ref_values_chi"]) - self.k_Z = _torch.nn.Parameter(params["k_Z"]) + k_Z = _torch.nn.Parameter(params["k_Z"]) if self._alpha_mode == "reference": try: @@ -270,6 +270,7 @@ def __init__( self.register_buffer("_c_s", c_s) self.register_buffer("_c_chi", c_chi) self.register_buffer("_c_sqrtk", c_sqrtk) + self.register_buffer("_k_Z", k_Z) # Initalise an empty AEV tensor to use to store the AEVs in parent models. 
self._aev = _torch.empty(0, dtype=dtype, device=device) @@ -289,6 +290,7 @@ def to(self, *args, **kwargs): self._c_s = self._c_s.to(*args, **kwargs) self._c_chi = self._c_chi.to(*args, **kwargs) self._c_sqrtk = self._c_sqrtk.to(*args, **kwargs) + self._k_Z = self._k_Z.to(*args, **kwargs) # Check for a device type in args and update the device attribute. for arg in args: @@ -296,6 +298,8 @@ def to(self, *args, **kwargs): self._device = arg break + return self + def cuda(self, **kwargs): """ Move all model parameters and buffers to CUDA memory. @@ -314,6 +318,12 @@ def cuda(self, **kwargs): self._c_s = self._c_s.cuda(**kwargs) self._c_chi = self._c_chi.cuda(**kwargs) self._c_sqrtk = self._c_sqrtk.cuda(**kwargs) + self._k_Z = self._k_Z.cuda(**kwargs) + + # Update the device attribute. + self._device = self._species_map.device + + return self def cpu(self, **kwargs): """ @@ -333,6 +343,12 @@ def cpu(self, **kwargs): self._c_s = self._c_s.cpu(**kwargs) self._c_chi = self._c_chi.cpu(**kwargs) self._c_sqrtk = self._c_sqrtk.cpu(**kwargs) + self._k_Z = self._k_Z.cpu(**kwargs) + + # Update the device attribute. 
+ self._device = self._species_map.device + + return self def double(self): """ @@ -349,6 +365,7 @@ def double(self): self._c_s = self._c_s.double() self._c_chi = self._c_chi.double() self._c_sqrtk = self._c_sqrtk.double() + self._k_Z = self._k_Z.double() return self def float(self): @@ -366,6 +383,7 @@ def float(self): self._c_s = self._c_s.float() self._c_chi = self._c_chi.float() self._c_sqrtk = self._c_sqrtk.float() + self._k_Z = self._k_Z.float() return self def forward(self, atomic_numbers, xyz_qm, q_total): @@ -425,7 +443,7 @@ def forward(self, atomic_numbers, xyz_qm, q_total): q = self._get_q(r_data, s, chi, q_total, mask) q_val = q - q_core - k = self.k_Z[species_id] + k = self._k_Z[species_id] if self._alpha_mode == "reference": k_scale = ( From c2f12f97c7ae37286a936d43a62046219c2d7e6e Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 12:54:19 +0100 Subject: [PATCH 34/39] Fix types. --- emle/calculator.py | 57 ++++++++++++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/emle/calculator.py b/emle/calculator.py index c20044a..4aad64b 100644 --- a/emle/calculator.py +++ b/emle/calculator.py @@ -1163,27 +1163,36 @@ def run(self, path=None): E_vac += delta_E grad_vac += delta_grad - # Store a copy of the QM coordinates as a NumPy array. + # Store a copy of the atomic numbers and QM coordinates as NumPy arrays. + atomic_numbers_np = atomic_numbers xyz_qm_np = xyz_qm # Convert inputs to Torch tensors. + atomic_numbers = _torch.tensor( + atomic_numbers, dtype=_torch.int64, device=self._device + ) + charges_mm = _torch.tensor( + charges_mm, dtype=_torch.float32, device=self._device + ) xyz_qm = _torch.tensor( xyz_qm, dtype=_torch.float32, device=self._device, requires_grad=True ) xyz_mm = _torch.tensor( xyz_mm, dtype=_torch.float32, device=self._device, requires_grad=True ) - charges_mm = _torch.tensor( - charges_mm, dtype=_torch.float32, device=self._device - ) # Compute energy and gradients. 
- E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm) - dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad( - E.sum(), (xyz_qm, xyz_mm) - ) - dE_dxyz_qm_bohr = dE_dxyz_qm_bohr.cpu().numpy() - dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy() + try: + E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm) + dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad( + E.sum(), (xyz_qm, xyz_mm) + ) + dE_dxyz_qm_bohr = dE_dxyz_qm_bohr.cpu().numpy() + dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy() + except Exception as e: + msg = f"Failed to compute EMLE energies and gradients: {e}" + _logger.error(msg) + raise RuntimeError(msg) # Compute the total energy and gradients. E_tot = E_vac + E.sum().detach().cpu().numpy() @@ -1283,7 +1292,7 @@ def run(self, path=None): # Write out the QM region to the xyz trajectory file. if self._qm_xyz_frequency > 0 and self._step % self._qm_xyz_frequency == 0: - atoms = _ase.Atoms(positions=xyz_qm_np, numbers=atomic_numbers) + atoms = _ase.Atoms(positions=xyz_qm_np, numbers=atomic_numbers_np) if hasattr(self, "_max_f_std"): atoms.info = {"max_f_std": self._max_f_std} _ase_io.write(self._qm_xyz_file, atoms, append=True) @@ -1553,23 +1562,31 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None ) # Convert inputs to Torch tensors. + atomic_numbers = _torch.tensor( + atomic_numbers, dtype=_torch.int64, device=self._device + ) + charges_mm = _torch.tensor( + charges_mm, dtype=_torch.float32, device=self._device + ) xyz_qm = _torch.tensor( xyz_qm, dtype=_torch.float32, device=self._device, requires_grad=True ) xyz_mm = _torch.tensor( xyz_mm, dtype=_torch.float32, device=self._device, requires_grad=True ) - charges_mm = _torch.tensor( - charges_mm, dtype=_torch.float32, device=self._device - ) # Compute energy and gradients. 
- E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm) - dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad( - E.sum(), (xyz_qm, xyz_mm) - ) - dE_dxyz_qm_bohr = dE_dxyz_qm_bohr.cpu().numpy() - dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy() + try: + E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm) + dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad( + E.sum(), (xyz_qm, xyz_mm) + ) + dE_dxyz_qm_bohr = dE_dxyz_qm_bohr.cpu().numpy() + dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy() + except Exception as e: + msg = f"Failed to compute EMLE energies and gradients: {e}" + _logger.error(msg) + raise RuntimeError(msg) # Compute the total energy and gradients. E_tot = E_vac + E.sum().detach().cpu().numpy() From 085f33d9b90d8819af790bf683fa60b253989eaf Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 13:07:53 +0100 Subject: [PATCH 35/39] Docstring tweaks. [ci skip] --- emle/models/_emle_base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 772738d..24e0663 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -401,7 +401,7 @@ def forward(self, atomic_numbers, xyz_qm, q_total): Positions of QM atoms in Angstrom. q_total: torch.Tensor (1,) - Total charge + Total charge. Returns ------- @@ -483,6 +483,10 @@ def _get_Kinv(cls, ref_features, sigma): @classmethod def _get_c(cls, n_ref, ref, Kinv): + """ + Internal method to compute the coefficients of the GPR model. + """ + mask = _torch.arange(ref.shape[1], device=n_ref.device) < n_ref[:, None] ref_mean = _torch.sum(ref * mask, dim=1) / n_ref ref_shifted = ref - ref_mean[:, None] From 8374a704d4df1e10308cffedc36012bcc949ad8b Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 13:51:42 +0100 Subject: [PATCH 36/39] Clarify call to base EMLE model. 
[ci skip] --- emle/models/_emle.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index f23330d..7422383 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -467,6 +467,10 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): if len(xyz_mm) == 0: return _torch.zeros(2, dtype=xyz_qm.dtype, device=xyz_qm.device) + # Get the parameters from the base model: + # valence widths, core charges, valence charges, A_thole tensor + # These are returned as batched tensors, so we need to extract the + # first element of each. s, q_core, q_val, A_thole = self._emle_base( atomic_numbers[None, :], xyz_qm[None, :, :], self._q_total[None] ) From 3b4d10477541d329c8b9e9d2cd54f09079c24b14 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 14:00:11 +0100 Subject: [PATCH 37/39] Clarify AEV member attribute that can be set by parent model. [ci skip] --- emle/models/_emle_base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 24e0663..bff8015 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -273,6 +273,8 @@ def __init__( self.register_buffer("_k_Z", k_Z) # Initalise an empty AEV tensor to use to store the AEVs in parent models. + # If AEVs are computed externally, then this tensor will be set by the + # parent. self._aev = _torch.empty(0, dtype=dtype, device=device) def to(self, *args, **kwargs): From 1bc96bf07bd9697565b2ff8899b328bb5f8a4fb3 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 14:02:23 +0100 Subject: [PATCH 38/39] Default to combined AEV model. 
--- emle/models/_emle.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 7422383..9f17ac7 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -73,11 +73,8 @@ class EMLE(_torch.nn.Module): _os.path.dirname(_os.path.abspath(__file__)), "..", "resources" ) - # Create the name of the default model file for each alpha mode. - _default_models = { - "species": _os.path.join(_resource_dir, "emle_qm7_aev_species.mat"), - "reference": _os.path.join(_resource_dir, "emle_qm7_aev_reference.mat"), - } + # Create the name of the default model file. + _default_model = _os.path.join(_resource_dir, "emle_qm7_aev.mat") # Store the list of supported species. _species = [1, 6, 7, 8, 16] @@ -247,8 +244,8 @@ def __init__( # Set to None as this will be used in any calculator configuration. self._model = None - # Choose the model based on the alpha_mode. - model = self._default_models[alpha_mode] + # Use the default model. + model = self._default_model # Use the default species. species = self._species From 85040b2793cc6f7aae0f115e7df9b94414f4ff07 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Tue, 22 Oct 2024 16:35:37 +0100 Subject: [PATCH 39/39] Make k_Z a torch parameter and fix module methods. --- emle/models/_ani.py | 22 +++++++++++++++++++++- emle/models/_emle.py | 15 ++------------- emle/models/_emle_base.py | 13 +++++-------- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/emle/models/_ani.py b/emle/models/_ani.py index 619e316..b17d99d 100644 --- a/emle/models/_ani.py +++ b/emle/models/_ani.py @@ -144,6 +144,7 @@ def __init__( raise TypeError("'device' must be of type 'torch.device'") else: device = _torch.get_default_device() + self._device = device if dtype is not None: if not isinstance(dtype, _torch.dtype): @@ -229,8 +230,15 @@ def __init__( except: pass + # Add a hook to the ANI2x model to capture the AEV features. 
+ self._add_hook() + + def _add_hook(self): + """ + Add a hook to the ANI2x model to capture the AEV features. + """ # Assign a tensor attribute that can be used for assigning the AEVs. - self._ani2x.aev_computer._aev = _torch.empty(0, device=device) + self._ani2x.aev_computer._aev = _torch.empty(0, device=self._device) # Hook the forward pass of the ANI2x model to get the AEV features. # Note that this currently requires a patched versions of TorchANI and NNPOps. @@ -261,6 +269,13 @@ def to(self, *args, **kwargs): """ self._emle = self._emle.to(*args, **kwargs) self._ani2x = self._ani2x.to(*args, **kwargs) + + # Check for a device type in args and update the device attribute. + for arg in args: + if isinstance(arg, _torch.device): + self._device = arg + break + return self def cpu(self, **kwargs): @@ -269,6 +284,7 @@ def cpu(self, **kwargs): """ self._emle = self._emle.cpu(**kwargs) self._ani2x = self._ani2x.cpu(**kwargs) + self._device = _torch.device("cpu") return self def cuda(self, **kwargs): @@ -277,6 +293,7 @@ def cuda(self, **kwargs): """ self._emle = self._emle.cuda(**kwargs) self._ani2x = self._ani2x.cuda(**kwargs) + self._device = _torch.device("cuda") return self def double(self): @@ -306,6 +323,9 @@ def float(self): except: pass + # Re-append the hook. + self._add_hook() + return self def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 9f17ac7..344faf1 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -353,17 +353,6 @@ def __init__( dtype=dtype, ) - def _to_dict(self): - """ - Return the configuration of the module as a dictionary. - """ - return { - "model": self._model, - "method": self._method, - "species": self._species_map.tolist(), - "alpha_mode": self._alpha_mode, - } - def to(self, *args, **kwargs): """ Performs Tensor dtype and/or device conversion on the model. 
@@ -393,7 +382,7 @@ def cuda(self, **kwargs): self._emle_base = self._emle_base.cuda(**kwargs) # Update the device attribute. - self._device = self._species_map.device + self._device = self._q_total.device return self @@ -408,7 +397,7 @@ def cpu(self, **kwargs): self._emle_base = self._emle_base.cpu() # Update the device attribute. - self._device = self._species_map.device + self._device = self._q_total.device return self diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index bff8015..5b9b3b1 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -207,7 +207,7 @@ def __init__( self.a_Thole = _torch.nn.Parameter(params["a_Thole"]) self.ref_values_s = _torch.nn.Parameter(params["ref_values_s"]) self.ref_values_chi = _torch.nn.Parameter(params["ref_values_chi"]) - k_Z = _torch.nn.Parameter(params["k_Z"]) + self.k_Z = _torch.nn.Parameter(params["k_Z"]) if self._alpha_mode == "reference": try: @@ -270,7 +270,6 @@ def __init__( self.register_buffer("_c_s", c_s) self.register_buffer("_c_chi", c_chi) self.register_buffer("_c_sqrtk", c_sqrtk) - self.register_buffer("_k_Z", k_Z) # Initalise an empty AEV tensor to use to store the AEVs in parent models. # If AEVs are computed externally, then this tensor will be set by the @@ -292,7 +291,6 @@ def to(self, *args, **kwargs): self._c_s = self._c_s.to(*args, **kwargs) self._c_chi = self._c_chi.to(*args, **kwargs) self._c_sqrtk = self._c_sqrtk.to(*args, **kwargs) - self._k_Z = self._k_Z.to(*args, **kwargs) # Check for a device type in args and update the device attribute. for arg in args: @@ -320,7 +318,7 @@ def cuda(self, **kwargs): self._c_s = self._c_s.cuda(**kwargs) self._c_chi = self._c_chi.cuda(**kwargs) self._c_sqrtk = self._c_sqrtk.cuda(**kwargs) - self._k_Z = self._k_Z.cuda(**kwargs) + self.k_Z = self.k_Z.cuda(**kwargs) # Update the device attribute. 
self._device = self._species_map.device @@ -345,7 +343,6 @@ def cpu(self, **kwargs): self._c_s = self._c_s.cpu(**kwargs) self._c_chi = self._c_chi.cpu(**kwargs) self._c_sqrtk = self._c_sqrtk.cpu(**kwargs) - self._k_Z = self._k_Z.cpu(**kwargs) # Update the device attribute. self._device = self._species_map.device @@ -367,7 +364,7 @@ def double(self): self._c_s = self._c_s.double() self._c_chi = self._c_chi.double() self._c_sqrtk = self._c_sqrtk.double() - self._k_Z = self._k_Z.double() + self.k_Z = _torch.nn.Parameter(self.k_Z.double()) return self def float(self): @@ -385,7 +382,7 @@ def float(self): self._c_s = self._c_s.float() self._c_chi = self._c_chi.float() self._c_sqrtk = self._c_sqrtk.float() - self._k_Z = self._k_Z.float() + self.k_Z = _torch.nn.Parameter(self.k_Z.float()) return self def forward(self, atomic_numbers, xyz_qm, q_total): @@ -445,7 +442,7 @@ def forward(self, atomic_numbers, xyz_qm, q_total): q = self._get_q(r_data, s, chi, q_total, mask) q_val = q - q_core - k = self._k_Z[species_id] + k = self.k_Z[species_id] if self._alpha_mode == "reference": k_scale = (