diff --git a/emle/calculator.py b/emle/calculator.py index a5fbb27..b5bdec9 100644 --- a/emle/calculator.py +++ b/emle/calculator.py @@ -942,6 +942,7 @@ def __init__( self._species = self._emle._species self._method = self._emle._method self._alpha_mode = self._emle._alpha_mode + self._atomic_numbers = self._emle._atomic_numbers if isinstance(atomic_numbers, _np.ndarray): atomic_numbers = atomic_numbers.tolist() @@ -1754,7 +1755,12 @@ def _sire_callback_optimised( # Create the model. ani2x_emle = _ANI2xEMLE( emle_model=self._model, + emle_species=self._species, + alpha_model=self._alpha_model, + mm_charges=self._mm_charges, + model_index=self._ani2x_model_index, ani2x_model=self._torchani_model, + atomic_numbers=atomic_numbers, device=self._device, ) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index ccb7c75..c2db9fc 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -196,6 +196,7 @@ def __init__( raise ValueError( "All elements of 'atomic_numbers' must be greater than zero" ) + self._atomic_numbers = atomic_numbers if method == "mm": if mm_charges is None: