diff --git a/python/chemiscope/explore.py b/python/chemiscope/explore.py index b279444d4..8f253542f 100644 --- a/python/chemiscope/explore.py +++ b/python/chemiscope/explore.py @@ -110,10 +110,9 @@ def soap_kpca_featurize(frames, _environments): # Use default featurizer else: - centers = None if environments is not None: - centers, frames = _pick_env_frames(environments, frames) - X_reduced = soap_pca_featurize(frames, centers) + frames = [frames[env_index] for env_index, _, _ in environments] + X_reduced = soap_pca_featurize(frames, environments) # Add dimensionality reduction results to properties properties["features"] = X_reduced @@ -122,7 +121,7 @@ def soap_kpca_featurize(frames, _environments): return show(frames=frames, properties=properties, mode=mode) -def soap_pca_featurize(frames, centers=None): +def soap_pca_featurize(frames, environments=None): """ Computes SOAP features for a given set of atomic structures and performs dimensionality reduction using PCA. Custom featurize functions should @@ -142,6 +141,12 @@ def soap_pca_featurize(frames, centers=None): f"Required package not found: {str(e)}. Please install dependency " + "using 'pip install chemiscope[explore]'." 
) + centers = None + + # Get the atom indexes from the environments, grouped per structure + if environments is not None: + centers = _extract_environment_indices(environments) + # Get global species species = set() for frame in frames: @@ -174,18 +179,16 @@ def soap_pca_featurize(frames, centers=None): return pca.fit_transform(feats) -def _pick_env_frames(envs, frames): +def _extract_environment_indices(envs): """ - Get environment indices par structures and pick corresponding frames + Extract environment indices per structure :param: list envs: each element is a list of [env_index, atom_index, cutoff] - :param: list frames: list of frames + :return: list with one list of atom indices per structure, ordered by the first appearance of each structure in envs """ grouped_envs = {} - picked_frames = [] for [env_index, atom_index, _cutoff] in envs: if env_index not in grouped_envs: grouped_envs[env_index] = [] grouped_envs[env_index].append(atom_index) - picked_frames.append(frames[env_index]) - return list(grouped_envs.values()), picked_frames + return list(grouped_envs.values())