Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EMLE base model #29

Merged
merged 39 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
7797084
Import logger in models/_emle.py
kzinovjev Oct 20, 2024
885f99c
Refactor EMLE._get_mu_ind to receive A_thole as argument
kzinovjev Oct 20, 2024
0bd8b8c
Move core EMLE model functionality to EMLEBase class
kzinovjev Oct 20, 2024
870bf9e
Cleanup to/cuda/cpu/double/float methods in EMLE and EMLEBase
kzinovjev Oct 20, 2024
293cf51
Pass q_total as argument to EMLEBase.forward
kzinovjev Oct 20, 2024
cb19458
Remove _gpr method from EMLE class
kzinovjev Oct 20, 2024
a1fc1ff
Refactor EMLEBase to work with batches
kzinovjev Oct 20, 2024
fc55b1e
Refactor calculation of GPR coefficients (EMLEBase._get_c method)
kzinovjev Oct 20, 2024
3e7ed58
Explicit k_Z/sqrtk variables for species/reference models
kzinovjev Oct 20, 2024
61eaedc
Remove ref_values tensors from buffers (never used)
kzinovjev Oct 20, 2024
62c6d08
Register buffer for Kinv (will be needed during training)
kzinovjev Oct 20, 2024
d3d3862
Typo
kzinovjev Oct 20, 2024
2605d86
Move all parameters together, ensure same parameter shapes for specie…
kzinovjev Oct 20, 2024
c52c910
Register trainable model parameters with nn.Parameter
kzinovjev Oct 20, 2024
1cc242b
Update EMLEBase.__init__ docstring
kzinovjev Oct 20, 2024
2e23f43
Fix nans in padded A_thole calculation
kzinovjev Oct 20, 2024
d1e673d
Fix species mapping for highest supported element
kzinovjev Oct 20, 2024
9aae570
Fix padding in batched A_QEq matrix
kzinovjev Oct 21, 2024
40f4c0a
Remove model parameters from buffers and to/cuda etc. methods
kzinovjev Oct 21, 2024
2582b70
Typo
kzinovjev Oct 21, 2024
af0bfde
Fix mean GPR reference calculation
kzinovjev Oct 21, 2024
4e54769
Cleanup
kzinovjev Oct 21, 2024
4e685b4
Redefine reference model to work as a correction to the species one
kzinovjev Oct 21, 2024
1b81f8a
Fix auxiliary tensors created on wrong device
kzinovjev Oct 21, 2024
fa75176
Blacken.
lohedges Oct 22, 2024
3991d8d
Float variable needs to be local.
lohedges Oct 22, 2024
fd46988
Improve docstrings.
lohedges Oct 22, 2024
1fa05e3
Fix padded coordinates mask.
lohedges Oct 22, 2024
de48b29
Fix docstrings.
lohedges Oct 22, 2024
756c08c
Fixes to allow TorchScript plus refactoring and input validation.
lohedges Oct 22, 2024
56da067
Remove unnecessary logging.
lohedges Oct 22, 2024
185cfd2
Add module attributes. [ci skip]
lohedges Oct 22, 2024
874ddd8
Fix overloaded module methods and k_Z buffer.
lohedges Oct 22, 2024
c2f12f9
Fix types.
lohedges Oct 22, 2024
085f33d
Docstring tweaks. [ci skip]
lohedges Oct 22, 2024
8374a70
Clarify call to base EMLE model. [ci skip]
lohedges Oct 22, 2024
3b4d104
Clarify AEV member attribute that can be set by parent model. [ci skip]
lohedges Oct 22, 2024
1bc96bf
Default to combined AEV model.
lohedges Oct 22, 2024
85040b2
Make k_Z a torch parameter and fix module methods.
lohedges Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 37 additions & 20 deletions emle/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,27 +1163,36 @@ def run(self, path=None):
E_vac += delta_E
grad_vac += delta_grad

# Store a copy of the QM coordinates as a NumPy array.
# Store a copy of the atomic numbers and QM coordinates as NumPy arrays.
atomic_numbers_np = atomic_numbers
xyz_qm_np = xyz_qm

# Convert inputs to Torch tensors.
atomic_numbers = _torch.tensor(
atomic_numbers, dtype=_torch.int64, device=self._device
)
charges_mm = _torch.tensor(
charges_mm, dtype=_torch.float32, device=self._device
)
xyz_qm = _torch.tensor(
xyz_qm, dtype=_torch.float32, device=self._device, requires_grad=True
)
xyz_mm = _torch.tensor(
xyz_mm, dtype=_torch.float32, device=self._device, requires_grad=True
)
charges_mm = _torch.tensor(
charges_mm, dtype=_torch.float32, device=self._device
)

# Compute energy and gradients.
E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad(
E.sum(), (xyz_qm, xyz_mm)
)
dE_dxyz_qm_bohr = dE_dxyz_qm_bohr.cpu().numpy()
dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy()
try:
E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad(
E.sum(), (xyz_qm, xyz_mm)
)
dE_dxyz_qm_bohr = dE_dxyz_qm_bohr.cpu().numpy()
dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy()
except Exception as e:
msg = f"Failed to compute EMLE energies and gradients: {e}"
_logger.error(msg)
raise RuntimeError(msg)

# Compute the total energy and gradients.
E_tot = E_vac + E.sum().detach().cpu().numpy()
Expand Down Expand Up @@ -1283,7 +1292,7 @@ def run(self, path=None):

# Write out the QM region to the xyz trajectory file.
if self._qm_xyz_frequency > 0 and self._step % self._qm_xyz_frequency == 0:
atoms = _ase.Atoms(positions=xyz_qm_np, numbers=atomic_numbers)
atoms = _ase.Atoms(positions=xyz_qm_np, numbers=atomic_numbers_np)
if hasattr(self, "_max_f_std"):
atoms.info = {"max_f_std": self._max_f_std}
_ase_io.write(self._qm_xyz_file, atoms, append=True)
Expand Down Expand Up @@ -1553,23 +1562,31 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None
)

# Convert inputs to Torch tensors.
atomic_numbers = _torch.tensor(
atomic_numbers, dtype=_torch.int64, device=self._device
)
charges_mm = _torch.tensor(
charges_mm, dtype=_torch.float32, device=self._device
)
xyz_qm = _torch.tensor(
xyz_qm, dtype=_torch.float32, device=self._device, requires_grad=True
)
xyz_mm = _torch.tensor(
xyz_mm, dtype=_torch.float32, device=self._device, requires_grad=True
)
charges_mm = _torch.tensor(
charges_mm, dtype=_torch.float32, device=self._device
)

# Compute energy and gradients.
E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad(
E.sum(), (xyz_qm, xyz_mm)
)
dE_dxyz_qm_bohr = dE_dxyz_qm_bohr.cpu().numpy()
dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy()
try:
E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad(
E.sum(), (xyz_qm, xyz_mm)
)
dE_dxyz_qm_bohr = dE_dxyz_qm_bohr.cpu().numpy()
dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy()
except Exception as e:
msg = f"Failed to compute EMLE energies and gradients: {e}"
_logger.error(msg)
raise RuntimeError(msg)

# Compute the total energy and gradients.
E_tot = E_vac + E.sum().detach().cpu().numpy()
Expand Down
1 change: 1 addition & 0 deletions emle/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# avoid severe module import overheads when running the client code,
# which requires no EMLE functionality.

from ._emle_base import EMLEBase
from ._emle import EMLE
from ._ani import ANI2xEMLE
from ._mace import MACEEMLE
28 changes: 24 additions & 4 deletions emle/models/_ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def __init__(
raise TypeError("'device' must be of type 'torch.device'")
else:
device = _torch.get_default_device()
self._device = device

if dtype is not None:
if not isinstance(dtype, _torch.dtype):
Expand Down Expand Up @@ -229,8 +230,15 @@ def __init__(
except:
pass

# Add a hook to the ANI2x model to capture the AEV features.
self._add_hook()

def _add_hook(self):
"""
Add a hook to the ANI2x model to capture the AEV features.
"""
# Assign a tensor attribute that can be used for assigning the AEVs.
self._ani2x.aev_computer._aev = _torch.empty(0, device=device)
self._ani2x.aev_computer._aev = _torch.empty(0, device=self._device)

# Hook the forward pass of the ANI2x model to get the AEV features.
# Note that this currently requires patched versions of TorchANI and NNPOps.
Expand All @@ -241,7 +249,7 @@ def hook(
input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]],
output: Tuple[Tensor, Tensor],
):
module._aev = output[1][0]
module._aev = output[1]

else:

Expand All @@ -250,7 +258,7 @@ def hook(
input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]],
output: _torchani.aev.SpeciesAEV,
):
module._aev = output[1][0]
module._aev = output[1]

# Register the hook.
self._aev_hook = self._ani2x.aev_computer.register_forward_hook(hook)
Expand All @@ -261,6 +269,13 @@ def to(self, *args, **kwargs):
"""
self._emle = self._emle.to(*args, **kwargs)
self._ani2x = self._ani2x.to(*args, **kwargs)

# Check for a device type in args and update the device attribute.
for arg in args:
if isinstance(arg, _torch.device):
self._device = arg
break

return self

def cpu(self, **kwargs):
Expand All @@ -269,6 +284,7 @@ def cpu(self, **kwargs):
"""
self._emle = self._emle.cpu(**kwargs)
self._ani2x = self._ani2x.cpu(**kwargs)
self._device = _torch.device("cpu")
return self

def cuda(self, **kwargs):
Expand All @@ -277,6 +293,7 @@ def cuda(self, **kwargs):
"""
self._emle = self._emle.cuda(**kwargs)
self._ani2x = self._ani2x.cuda(**kwargs)
self._device = _torch.device("cuda")
return self

def double(self):
Expand Down Expand Up @@ -306,6 +323,9 @@ def float(self):
except:
pass

# Re-append the hook.
self._add_hook()

return self

def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
Expand Down Expand Up @@ -351,7 +371,7 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):

# Set the AEVs captured by the forward hook as an attribute of the
# EMLE model.
self._emle._aev = self._ani2x.aev_computer._aev
self._emle._emle_base._aev = self._ani2x.aev_computer._aev

# Get the EMLE energy components.
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
Expand Down
Loading
Loading