From ba8b273bc80ea34af28e94d10ec53e091cc9bb84 Mon Sep 17 00:00:00 2001 From: Sofiia Chorna Date: Mon, 2 Sep 2024 11:18:46 +0200 Subject: [PATCH] Add environments aurgument --- python/chemiscope/explore.py | 30 ++++++++++++++++++++++----- python/examples/6-explore.py | 8 ++++--- python/examples/7-explore-advanced.py | 4 ++-- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/python/chemiscope/explore.py b/python/chemiscope/explore.py index 9a583bc8c..7e187e6e2 100644 --- a/python/chemiscope/explore.py +++ b/python/chemiscope/explore.py @@ -3,7 +3,7 @@ from .jupyter import show -def explore(frames, featurize=None, properties=None, mode="default"): +def explore(frames, featurize=None, properties=None, environments=None, mode="default"): """ Automatically explore a dataset containing all structures in ``frames``. @@ -30,6 +30,12 @@ def explore(frames, featurize=None, properties=None, mode="default"): with the atomic structures. Properties can be extracted from frames with :py:func:`extract_properties` or manually defined by the user. + :param environments: optional. List of environments (described as + ``(structure id, center id, cutoff)``) to include when extracting the + atomic properties. Can be extracted from frames with + :py:func:`all_atomic_environments` (or :py:func:`librascal_atomic_environments`) + or manually defined. + :param str mode: optional. Visualization mode for the chemiscope widget. Can be one of "default", "structure", or "map". The default mode is "default". @@ -100,11 +106,14 @@ def soap_kpca_featurize(frames): # Apply dimensionality reduction from the provided featurizer if featurize is not None: - X_reduced = featurize(frames) + X_reduced = featurize(frames, environments) # Use default featurizer else: - X_reduced = soap_pca_featurize(frames) + centers = None + if environments is not None: + centers, frames = _pick_env_frames(environments, frames) + X_reduced = soap_pca_featurize(frames, centers) # Add dimensionality reduction results to properties properties["features"] = X_reduced @@ -113,7 +122,7 @@ def soap_kpca_featurize(frames): return show(frames=frames, properties=properties, mode=mode) -def soap_pca_featurize(frames): +def soap_pca_featurize(frames, centers=None): """ Computes SOAP features for a given set of atomic structures and performs dimensionality reduction using PCA. Custom featurize functions should @@ -158,8 +167,19 @@ def soap_pca_featurize(frames): # Calculate descriptors n_jobs = min(len(frames), os.cpu_count()) - feats = soap.create(frames, n_jobs=n_jobs) + feats = soap.create(frames, centers=centers, n_jobs=n_jobs) # Compute pca pca = PCA(n_components=2) return pca.fit_transform(feats) + + +def _pick_env_frames(envs, frames): + grouped_envs = {} + picked_frames = [] + for [env_index, atom_index, _cutoff] in envs: + if env_index not in grouped_envs: + grouped_envs[env_index] = [] + grouped_envs[env_index].append(atom_index) + picked_frames.append(frames[env_index]) + return list(grouped_envs.values()), picked_frames diff --git a/python/examples/6-explore.py b/python/examples/6-explore.py index 05d5d108e..062e11705 100644 --- a/python/examples/6-explore.py +++ b/python/examples/6-explore.py @@ -69,8 +69,10 @@ def fetch_dataset(filename, base_url="https://zenodo.org/records/12748925/files/ # # Provide the frames to the :py:func:`chemiscope.explore`. It will generate a Chemiscope # interactive widget with the reduced dimensionality of data. - -chemiscope.explore(frames) +envs = [(0, 0, 3.5), + (1, 1, 3.5), + (2, 2, 3.5)] +chemiscope.explore(frames, environments=envs) # %% # @@ -109,7 +111,7 @@ def fetch_dataset(filename, base_url="https://zenodo.org/records/12748925/files/ # passed to the ``featurize`` function. -def soap_kpca_featurize(frames): +def soap_kpca_featurize(frames, _environments): # Initialise soap calculator. The detailed explanation of the provided # hyperparameters can be checked in the documentation of the library (``dscribe``). soap = SOAP( diff --git a/python/examples/7-explore-advanced.py b/python/examples/7-explore-advanced.py index 619d713b0..e9e238cf6 100644 --- a/python/examples/7-explore-advanced.py +++ b/python/examples/7-explore-advanced.py @@ -62,7 +62,7 @@ def fetch_dataset(filename, base_url="https://zenodo.org/records/12748925/files/ # return the reduced data. -def mace_off_tsne(frames): +def mace_off_tsne(frames, _environments): # At first, we initialize a mace_off calculator: descriptor_opt = {"model": "small", "device": "cpu", "default_dtype": "float64"} calculator = mace_off(**descriptor_opt) @@ -141,7 +141,7 @@ def mace_off_tsne(frames): # different mace calculator. -def mace_mp0_tsne(frames): +def mace_mp0_tsne(frames, _environments): # Initialise a mace-mp0 calculator descriptor_opt = {"model": "small", "device": "cpu", "default_dtype": "float64"} calculator = mace_mp(**descriptor_opt)