Skip to content

Commit

Permalink
Create metatensor_featurizer function
Browse files Browse the repository at this point in the history
This allows taking a metatensor model that compute "features"
and use it in chemiscope.explore
  • Loading branch information
sofiia-chorna authored and Luthaf committed Sep 26, 2024
1 parent 01f8ee4 commit d423e68
Show file tree
Hide file tree
Showing 11 changed files with 488 additions and 95 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ dependencies with:
pip install chemiscope[explore]
```

To use `chemiscope.metatensor_featurizer` for providing your trained model
to get the features for `chemiscope.explore`, install the dependencies with:
```bash
pip install chemiscope[metatensor]
```

## sphinx and sphinx-gallery integration

Chemiscope provides also extensions for `sphinx` and `sphinx-gallery` to
Expand Down
2 changes: 2 additions & 0 deletions docs/src/python/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@
.. autofunction:: chemiscope.ase_tensors_to_ellipsoids

.. autofunction:: chemiscope.explore

.. autofunction:: chemiscope.metatensor_featurizer
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,8 @@ explore = [
"dscribe",
"scikit-learn",
]

metatensor = [
"metatensor",
"metatensor[torch]"
]
2 changes: 1 addition & 1 deletion python/chemiscope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
extract_properties,
librascal_atomic_environments,
)
from .explore import explore # noqa: F401
from .explore import explore, metatensor_featurizer # noqa: F401
from .version import __version__ # noqa: F401

from .jupyter import show, show_input # noqa
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from ..jupyter import show
from ._soap_pca import soap_pca_featurize
from ._metatensor import metatensor_featurizer

from .jupyter import show
__all__ = ["explore", "metatensor_featurizer"]


def explore(frames, featurize=None, properties=None, environments=None, mode="default"):
Expand Down Expand Up @@ -116,96 +118,9 @@ def soap_kpca_featurize(frames, environments):
# Add dimensionality reduction results to properties
properties["features"] = X_reduced

# Return chemiscope widget
return show(
frames=frames, properties=properties, mode=mode, environments=environments
frames=frames,
properties=properties,
environments=environments,
mode=mode,
)


def soap_pca_featurize(frames, environments=None):
"""
Computes SOAP features for a given set of atomic structures and performs
dimensionality reduction using PCA. Custom featurize functions should
have the same signature.
Note:
- The SOAP descriptor parameters are pre-defined.
- We use all available CPU cores for parallel computation of SOAP descriptors.
"""

# Check if dependencies were installed
try:
from dscribe.descriptors import SOAP
from sklearn.decomposition import PCA
except ImportError as e:
raise ImportError(
f"Required package not found: {str(e)}. Please install dependency "
+ "using 'pip install chemiscope[explore]'."
)
centers = None

# Get the atom indexes from the environments and pick related frames
if environments is not None:
centers = _extract_environment_indices(environments)

# Pick frames and properties related to the environments if provided
if environments is not None:
# Sort environments by structure id and atom id
environments = sorted(environments, key=lambda x: (x[0], x[1]))

# Check structure indexes
unique_structures = list({env[0] for env in environments})
if any(index >= len(frames) for index in unique_structures):
raise IndexError(
"Some structure indices in 'environments' are larger than the number of"
"frames"
)

if len(unique_structures) != len(frames):
# only include frames that are present in the user-provided environments
frames = [frames[index] for index in unique_structures]

# Get global species
species = set()
for frame in frames:
species.update(frame.get_chemical_symbols())
species = list(species)

# Check if periodic
is_periodic = all(all(frame.get_pbc()) for frame in frames)

# Initialize calculator
soap = SOAP(
species=species,
r_cut=4.5,
n_max=8,
l_max=6,
sigma=0.2,
rbf="gto",
average="outer",
periodic=is_periodic,
weighting={"function": "pow", "c": 1, "m": 5, "d": 1, "r0": 3.5},
compression={"mode": "mu1nu1"},
)

# Calculate descriptors
n_jobs = min(len(frames), os.cpu_count())
feats = soap.create(frames, centers=centers, n_jobs=n_jobs)

# Compute pca
pca = PCA(n_components=2)
return pca.fit_transform(feats)


def _extract_environment_indices(envs):
"""
Convert from chemiscope's environements to DScribe's centers selection
:param: list envs: each element is a list of [env_index, atom_index, cutoff]
"""
grouped_envs = {}
for [env_index, atom_index, _cutoff] in envs:
if env_index not in grouped_envs:
grouped_envs[env_index] = []
grouped_envs[env_index].append(atom_index)
return list(grouped_envs.values())
148 changes: 148 additions & 0 deletions python/chemiscope/explore/_metatensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import numpy as np


def metatensor_featurizer(
model,
extensions_directory=None,
check_consistency=False,
device=None,
):
"""
Create a featurizer function using a `metatensor`_ model to obtain the features from
structures. The model must be able to create a ``"feature"`` output.
:param model: model to use for the calculation. It can be a file path, a Python
instance of :py:class:`metatensor.torch.atomistic.MetatensorAtomisticModel`, or
the output of :py:func:`torch.jit.script` on
:py:class:`metatensor.torch.atomistic.MetatensorAtomisticModel`.
:param extensions_directory: a directory where model extensions are located
:param check_consistency: should we check the model for consistency when running,
defaults to False.
:param device: a torch device to use for the calculation. If ``None``, the function
will use the options in model's ``supported_device`` attribute.
:returns: a function that takes a list of frames and returns the features.
To use this function, additional dependencies are required. They can be installed
with the following command:
.. code:: bash
pip install chemiscope[metatensor]
Here is an example using a pre-trained `metatensor`_ model, stored as a ``model.pt``
file with the compiled extensions stored in the ``extensions/`` directory. To obtain
the details on how to get it, see metatensor `tutorial
<https://lab-cosmo.github.io/metatrain/latest/getting-started/usage.html>`_. The
frames are obtained by reading structures from a file that `ase <ase-io_>`_ can
read.
.. code-block:: python
import chemiscope
import ase.io
# Read the structures from the dataset frames =
ase.io.read("data/explore_c-gap-20u.xyz", ":")
# Provide model file ("model.pt") to `metatensor_featurizer`
featurizer = chemiscope.metatensor_featurizer(
"model.pt", extensions_directory="extensions"
)
chemiscope.explore(frames, featurize=featurizer)
For more examples, see the related :ref:`documentation
<chemiscope-explore-metatensor>`.
.. _metatensor: https://docs.metatensor.org/latest/index.html
.. _chemiscope-explore-metatensor:
https://chemiscope.org/docs/examples/7-explore-advanced.html#example-with-metatensor-model
"""

# Check if dependencies were installed
try:
from metatensor.torch.atomistic import ModelOutput
from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator
except ImportError as e:
raise ImportError(
f"Required package not found: {e}. Please install the dependency using "
"'pip install chemiscope[metatensor]'."
)

# Initialize metatensor calculator
CALCULATOR = MetatensorCalculator(
model=model,
extensions_directory=extensions_directory,
check_consistency=check_consistency,
device=device,
)

def get_features(atoms, environments):
"""Run the model on a single atomic structure and extract the features"""
outputs = {"features": ModelOutput(per_atom=environments is not None)}
selected_atoms = _create_selected_atoms(environments)
output = CALCULATOR.run_model(atoms, outputs, selected_atoms)

return output["features"].block().values.detach().cpu().numpy()

def featurize(frames, environments):
if isinstance(frames, list):
envs_per_frame = _get_environments_per_frame(environments, len(frames))

outputs = [
get_features(frame, envs) for frame, envs in zip(frames, envs_per_frame)
]
return np.vstack(outputs)
else:
return get_features(frames, environments)

return featurize


def _get_environments_per_frame(environments, num_frames):
"""
Organize the environments for each frame
:param list environments: a list of atomic environments
:param int num_frames: total number of frames
"""
envs_per_frame = [None] * num_frames

if environments is None:
return envs_per_frame

frames_dict = {}

# Group environments by structure_id
for env in environments:
structure_id = env[0]
if structure_id not in frames_dict:
frames_dict[structure_id] = []
frames_dict[structure_id].append(env)

# Insert environments to the frame indices
for structure_id, envs in frames_dict.items():
if structure_id < num_frames:
envs_per_frame[structure_id] = envs

return envs_per_frame


def _create_selected_atoms(environments):
"""
Convert environments into ``Labels`` object, to be used as ``selected_atoms``
:param environments: a list of atom-centered environments
"""
import torch
from metatensor.torch import Labels

if environments is None:
return None

# Extract system and atom indices from environments, overriding the structure id to
# be 0 (since we will only give a single frame to the calculator at the same time).
values = torch.tensor([(0, atom_id) for _, atom_id, _ in environments])

return Labels(names=["system", "atom"], values=values)
90 changes: 90 additions & 0 deletions python/chemiscope/explore/_soap_pca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import os


def soap_pca_featurize(frames, environments=None):
"""
Computes SOAP features for a given set of atomic structures and performs
dimensionality reduction using PCA. Custom featurize functions should
have the same signature.
Note:
- The SOAP descriptor parameters are pre-defined.
- We use all available CPU cores for parallel computation of SOAP descriptors.
"""

# Check if dependencies were installed
try:
from dscribe.descriptors import SOAP
from sklearn.decomposition import PCA
except ImportError as e:
raise ImportError(
f"Required package not found: {str(e)}. Please install dependency "
+ "using 'pip install chemiscope[explore]'."
)
centers = None

# Get the atom indexes from the environments and pick related frames
if environments is not None:
centers = _extract_environment_indices(environments)

# Pick frames and properties related to the environments if provided
if environments is not None:
# Sort environments by structure id and atom id
environments = sorted(environments, key=lambda x: (x[0], x[1]))

# Check structure indexes
unique_structures = list({env[0] for env in environments})
if any(index >= len(frames) for index in unique_structures):
raise IndexError(
"Some structure indices in 'environments' are larger than the number of"
"frames"
)

if len(unique_structures) != len(frames):
# only include frames that are present in the user-provided environments
frames = [frames[index] for index in unique_structures]

# Get global species
species = set()
for frame in frames:
species.update(frame.get_chemical_symbols())
species = list(species)

# Check if periodic
is_periodic = all(all(frame.get_pbc()) for frame in frames)

# Initialize calculator
soap = SOAP(
species=species,
r_cut=4.5,
n_max=8,
l_max=6,
sigma=0.2,
rbf="gto",
average="outer",
periodic=is_periodic,
weighting={"function": "pow", "c": 1, "m": 5, "d": 1, "r0": 3.5},
compression={"mode": "mu1nu1"},
)

# Calculate descriptors
n_jobs = min(len(frames), os.cpu_count())
feats = soap.create(frames, centers=centers, n_jobs=n_jobs)

# Compute pca
pca = PCA(n_components=2)
return pca.fit_transform(feats)


def _extract_environment_indices(environments):
"""
Convert from chemiscope's environments to DScribe's centers selection
:param: list environments: each element is a list of [env_index, atom_index, cutoff]
"""
grouped_envs = {}
for [env_index, atom_index, _cutoff] in environments:
if env_index not in grouped_envs:
grouped_envs[env_index] = []
grouped_envs[env_index].append(atom_index)
return list(grouped_envs.values())
1 change: 1 addition & 0 deletions python/examples/7-explore-advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def mace_mp0_tsne(frames, environments):
fetch_dataset("mace-mp-tsne-m3cd.json.gz")
chemiscope.show_input("data/mace-mp-tsne-m3cd.json.gz")


# %%
#
# Example with SOAP, t-SNE and environments
Expand Down
Loading

0 comments on commit d423e68

Please sign in to comment.