Skip to content

Commit

Permalink
Add environments aurgument
Browse files Browse the repository at this point in the history
  • Loading branch information
sofiia-chorna committed Sep 2, 2024
1 parent d162f87 commit ba8b273
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
30 changes: 25 additions & 5 deletions python/chemiscope/explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .jupyter import show


def explore(frames, featurize=None, properties=None, mode="default"):
def explore(frames, featurize=None, properties=None, environments=None, mode="default"):
"""
Automatically explore a dataset containing all structures in ``frames``.
Expand All @@ -30,6 +30,12 @@ def explore(frames, featurize=None, properties=None, mode="default"):
with the atomic structures. Properties can be extracted from frames with
:py:func:`extract_properties` or manually defined by the user.
:param environments: optional. List of environments (described as
``(structure id, center id, cutoff)``) to include when extracting the
atomic properties. Can be extracted from frames with
:py:func:`all_atomic_environments` (or :py:func:`librascal_atomic_environments`)
or manually defined.
:param str mode: optional. Visualization mode for the chemiscope widget. Can be one
of "default", "structure", or "map". The default mode is "default".
Expand Down Expand Up @@ -100,11 +106,14 @@ def soap_kpca_featurize(frames):

# Apply dimensionality reduction from the provided featurizer
if featurize is not None:
X_reduced = featurize(frames)
X_reduced = featurize(frames, environments)

# Use default featurizer
else:
X_reduced = soap_pca_featurize(frames)
centers = None
if environments is not None:
centers, frames = _pick_env_frames(environments, frames)
X_reduced = soap_pca_featurize(frames, centers)

# Add dimensionality reduction results to properties
properties["features"] = X_reduced
Expand All @@ -113,7 +122,7 @@ def soap_kpca_featurize(frames):
return show(frames=frames, properties=properties, mode=mode)


def soap_pca_featurize(frames):
def soap_pca_featurize(frames, centers=None):
"""
Computes SOAP features for a given set of atomic structures and performs
dimensionality reduction using PCA. Custom featurize functions should
Expand Down Expand Up @@ -158,8 +167,19 @@ def soap_pca_featurize(frames):

# Calculate descriptors
n_jobs = min(len(frames), os.cpu_count())
feats = soap.create(frames, n_jobs=n_jobs)
feats = soap.create(frames, centers=centers, n_jobs=n_jobs)

# Compute pca
pca = PCA(n_components=2)
return pca.fit_transform(feats)


def _pick_env_frames(envs, frames):
grouped_envs = {}
picked_frames = []
for [env_index, atom_index, _cutoff] in envs:
if env_index not in grouped_envs:
grouped_envs[env_index] = []
grouped_envs[env_index].append(atom_index)
picked_frames.append(frames[env_index])
return list(grouped_envs.values()), picked_frames
8 changes: 5 additions & 3 deletions python/examples/6-explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,10 @@ def fetch_dataset(filename, base_url="https://zenodo.org/records/12748925/files/
#
# Provide the frames to the :py:func:`chemiscope.explore`. It will generate a Chemiscope
# interactive widget with the reduced dimensionality of data.

chemiscope.explore(frames)
envs = [(0, 0, 3.5),
(1, 1, 3.5),
(2, 2, 3.5)]
chemiscope.explore(frames, environments=envs)

# %%
#
Expand Down Expand Up @@ -109,7 +111,7 @@ def fetch_dataset(filename, base_url="https://zenodo.org/records/12748925/files/
# passed to the ``featurize`` function.


def soap_kpca_featurize(frames):
def soap_kpca_featurize(frames, _environments):
# Initialise soap calculator. The detailed explanation of the provided
# hyperparameters can be checked in the documentation of the library (``dscribe``).
soap = SOAP(
Expand Down
4 changes: 2 additions & 2 deletions python/examples/7-explore-advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def fetch_dataset(filename, base_url="https://zenodo.org/records/12748925/files/
# return the reduced data.


def mace_off_tsne(frames):
def mace_off_tsne(frames, _environments):
# At first, we initialize a mace_off calculator:
descriptor_opt = {"model": "small", "device": "cpu", "default_dtype": "float64"}
calculator = mace_off(**descriptor_opt)
Expand Down Expand Up @@ -141,7 +141,7 @@ def mace_off_tsne(frames):
# different mace calculator.


def mace_mp0_tsne(frames):
def mace_mp0_tsne(frames, _environments):
# Initialise a mace-mp0 calculator
descriptor_opt = {"model": "small", "device": "cpu", "default_dtype": "float64"}
calculator = mace_mp(**descriptor_opt)
Expand Down

0 comments on commit ba8b273

Please sign in to comment.