-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create metatensor_featurizer function
This allows taking a metatensor model that computes "features" and using it in chemiscope.explore
- Loading branch information
1 parent
01f8ee4
commit d423e68
Showing
11 changed files
with
488 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,3 +75,8 @@ explore = [ | |
"dscribe", | ||
"scikit-learn", | ||
] | ||
|
||
metatensor = [ | ||
"metatensor", | ||
"metatensor[torch]" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import numpy as np | ||
|
||
|
||
def metatensor_featurizer(
    model,
    extensions_directory=None,
    check_consistency=False,
    device=None,
):
    """
    Create a featurizer function using a `metatensor`_ model to obtain the features from
    structures. The model must be able to create a ``"features"`` output.

    :param model: model to use for the calculation. It can be a file path, a Python
        instance of :py:class:`metatensor.torch.atomistic.MetatensorAtomisticModel`, or
        the output of :py:func:`torch.jit.script` on
        :py:class:`metatensor.torch.atomistic.MetatensorAtomisticModel`.
    :param extensions_directory: a directory where model extensions are located
    :param check_consistency: should we check the model for consistency when running,
        defaults to False.
    :param device: a torch device to use for the calculation. If ``None``, the function
        will use the options in model's ``supported_device`` attribute.

    :returns: a function that takes a list of frames and returns the features.
    :raises ImportError: if the optional metatensor dependencies are not installed.

    To use this function, additional dependencies are required. They can be installed
    with the following command:

    .. code:: bash

        pip install chemiscope[metatensor]

    Here is an example using a pre-trained `metatensor`_ model, stored as a ``model.pt``
    file with the compiled extensions stored in the ``extensions/`` directory. To obtain
    the details on how to get it, see metatensor `tutorial
    <https://lab-cosmo.github.io/metatrain/latest/getting-started/usage.html>`_. The
    frames are obtained by reading structures from a file that `ase <ase-io_>`_ can
    read.

    .. code-block:: python

        import chemiscope
        import ase.io

        # Read the structures from the dataset
        frames = ase.io.read("data/explore_c-gap-20u.xyz", ":")

        # Provide model file ("model.pt") to `metatensor_featurizer`
        featurizer = chemiscope.metatensor_featurizer(
            "model.pt", extensions_directory="extensions"
        )

        chemiscope.explore(frames, featurize=featurizer)

    For more examples, see the related :ref:`documentation
    <chemiscope-explore-metatensor>`.

    .. _metatensor: https://docs.metatensor.org/latest/index.html
    .. _chemiscope-explore-metatensor:
        https://chemiscope.org/docs/examples/7-explore-advanced.html#example-with-metatensor-model
    """

    # Check if dependencies were installed; keep the original error as the cause so
    # the user can see which import actually failed.
    try:
        from metatensor.torch.atomistic import ModelOutput
        from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator
    except ImportError as e:
        raise ImportError(
            f"Required package not found: {e}. Please install the dependency using "
            "'pip install chemiscope[metatensor]'."
        ) from e

    # Initialize the metatensor calculator once; it is shared by every call of the
    # returned featurizer.
    calculator = MetatensorCalculator(
        model=model,
        extensions_directory=extensions_directory,
        check_consistency=check_consistency,
        device=device,
    )

    def get_features(atoms, environments):
        """Run the model on a single atomic structure and extract the features"""
        # Per-atom output is requested only when specific environments are selected
        outputs = {"features": ModelOutput(per_atom=environments is not None)}
        selected_atoms = _create_selected_atoms(environments)
        output = calculator.run_model(atoms, outputs, selected_atoms)

        return output["features"].block().values.detach().cpu().numpy()

    def featurize(frames, environments=None):
        """
        Compute the features of ``frames`` (a single structure or a list of them),
        optionally restricted to the given atom-centered ``environments``.
        """
        if not isinstance(frames, list):
            return get_features(frames, environments)

        envs_per_frame = _get_environments_per_frame(environments, len(frames))

        # When some frames have selected environments, frames without any must be
        # skipped: computing them would add per-structure rows in the middle of
        # the per-atom rows of the other frames.
        has_selection = any(envs is not None for envs in envs_per_frame)

        outputs = [
            get_features(frame, envs)
            for frame, envs in zip(frames, envs_per_frame)
            if envs is not None or not has_selection
        ]
        return np.vstack(outputs)

    return featurize
|
||
|
||
def _get_environments_per_frame(environments, num_frames): | ||
""" | ||
Organize the environments for each frame | ||
:param list environments: a list of atomic environments | ||
:param int num_frames: total number of frames | ||
""" | ||
envs_per_frame = [None] * num_frames | ||
|
||
if environments is None: | ||
return envs_per_frame | ||
|
||
frames_dict = {} | ||
|
||
# Group environments by structure_id | ||
for env in environments: | ||
structure_id = env[0] | ||
if structure_id not in frames_dict: | ||
frames_dict[structure_id] = [] | ||
frames_dict[structure_id].append(env) | ||
|
||
# Insert environments to the frame indices | ||
for structure_id, envs in frames_dict.items(): | ||
if structure_id < num_frames: | ||
envs_per_frame[structure_id] = envs | ||
|
||
return envs_per_frame | ||
|
||
|
||
def _create_selected_atoms(environments): | ||
""" | ||
Convert environments into ``Labels`` object, to be used as ``selected_atoms`` | ||
:param environments: a list of atom-centered environments | ||
""" | ||
import torch | ||
from metatensor.torch import Labels | ||
|
||
if environments is None: | ||
return None | ||
|
||
# Extract system and atom indices from environments, overriding the structure id to | ||
# be 0 (since we will only give a single frame to the calculator at the same time). | ||
values = torch.tensor([(0, atom_id) for _, atom_id, _ in environments]) | ||
|
||
return Labels(names=["system", "atom"], values=values) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import os | ||
|
||
|
||
def soap_pca_featurize(frames, environments=None):
    """
    Computes SOAP features for a given set of atomic structures and performs
    dimensionality reduction using PCA. Custom featurize functions should
    have the same signature.

    Note:
        - The SOAP descriptor parameters are pre-defined.
        - We use all available CPU cores for parallel computation of SOAP descriptors.

    :param frames: list of atomic structures; each one must provide
        ``get_chemical_symbols()`` and ``get_pbc()``
    :param environments: optional list of ``[structure_index, atom_index, cutoff]``
        environments. When given, only the structures they refer to are used, and
        the atom indices are passed to the SOAP calculation as ``centers``.

    :returns: the result of a 2-component PCA ``fit_transform`` on the SOAP features
    :raises IndexError: if an environment refers to a structure index that is not
        smaller than ``len(frames)``
    :raises ImportError: if the optional ``dscribe``/``scikit-learn`` dependencies are
        not installed
    """
    structure_ids = None
    if environments is not None:
        # Sort environments by structure id and atom id, so that the centers
        # extracted below are grouped in the same order as the selected frames
        environments = sorted(environments, key=lambda env: (env[0], env[1]))
        structure_ids = sorted({env[0] for env in environments})

        # Check structure indexes. This only needs the inputs, so it runs before the
        # optional dependencies are imported.
        if any(index >= len(frames) for index in structure_ids):
            raise IndexError(
                "Some structure indices in 'environments' are larger than the number "
                "of frames"
            )

    # Check if dependencies were installed
    try:
        from dscribe.descriptors import SOAP
        from sklearn.decomposition import PCA
    except ImportError as e:
        raise ImportError(
            f"Required package not found: {str(e)}. Please install dependency "
            + "using 'pip install chemiscope[explore]'."
        ) from e

    centers = None
    if environments is not None:
        # One list of atom indexes per structure, in increasing structure order
        centers = _extract_environment_indices(environments)

        if len(structure_ids) != len(frames):
            # only include frames that are present in the user-provided environments
            frames = [frames[index] for index in structure_ids]

    # Get global species
    species = set()
    for frame in frames:
        species.update(frame.get_chemical_symbols())
    species = list(species)

    # Treat the dataset as periodic only if every frame is periodic in all directions
    is_periodic = all(all(frame.get_pbc()) for frame in frames)

    # Initialize calculator
    soap = SOAP(
        species=species,
        r_cut=4.5,
        n_max=8,
        l_max=6,
        sigma=0.2,
        rbf="gto",
        average="outer",
        periodic=is_periodic,
        weighting={"function": "pow", "c": 1, "m": 5, "d": 1, "r0": 3.5},
        compression={"mode": "mu1nu1"},
    )

    # Calculate descriptors. os.cpu_count() may return None, fall back to a single job
    n_jobs = max(1, min(len(frames), os.cpu_count() or 1))
    feats = soap.create(frames, centers=centers, n_jobs=n_jobs)

    # Compute pca
    pca = PCA(n_components=2)
    return pca.fit_transform(feats)
|
||
|
||
def _extract_environment_indices(environments): | ||
""" | ||
Convert from chemiscope's environments to DScribe's centers selection | ||
:param: list environments: each element is a list of [env_index, atom_index, cutoff] | ||
""" | ||
grouped_envs = {} | ||
for [env_index, atom_index, _cutoff] in environments: | ||
if env_index not in grouped_envs: | ||
grouped_envs[env_index] = [] | ||
grouped_envs[env_index].append(atom_index) | ||
return list(grouped_envs.values()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.