Skip to content

Commit

Permalink
Move extraction of env indices inside default featurizer
Browse files Browse the repository at this point in the history
  • Loading branch information
sofiia-chorna committed Sep 3, 2024
1 parent 89bfd9b commit 8ed4406
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions python/chemiscope/explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,9 @@ def soap_kpca_featurize(frames, _environments):

# Use default featurizer
else:
centers = None
if environments is not None:
centers, frames = _pick_env_frames(environments, frames)
X_reduced = soap_pca_featurize(frames, centers)
frames = [frames[env_index] for env_index, _, _ in environments]
X_reduced = soap_pca_featurize(frames, environments)

# Add dimensionality reduction results to properties
properties["features"] = X_reduced
Expand All @@ -122,7 +121,7 @@ def soap_kpca_featurize(frames, _environments):
return show(frames=frames, properties=properties, mode=mode)


def soap_pca_featurize(frames, centers=None):
def soap_pca_featurize(frames, environments=None):
"""
Computes SOAP features for a given set of atomic structures and performs
dimensionality reduction using PCA. Custom featurize functions should
Expand All @@ -142,6 +141,12 @@ def soap_pca_featurize(frames, centers=None):
f"Required package not found: {str(e)}. Please install dependency "
+ "using 'pip install chemiscope[explore]'."
)
centers = None

# Get the atom indices from the environments, grouped per structure
if environments is not None:
centers = _extract_environment_indices(environments)

# Get global species
species = set()
for frame in frames:
Expand Down Expand Up @@ -174,18 +179,16 @@ def soap_pca_featurize(frames, centers=None):
return pca.fit_transform(feats)


def _pick_env_frames(envs, frames):
def _extract_environment_indices(envs):
"""
Get environment indices par structures and pick corresponding frames
Extract environment indices per structure
:param: list envs: each element is a list of [env_index, atom_index, cutoff]
:param: list frames: list of frames
:return: list of lists of atom indices, one list per structure index
"""
grouped_envs = {}
picked_frames = []
for [env_index, atom_index, _cutoff] in envs:
if env_index not in grouped_envs:
grouped_envs[env_index] = []
grouped_envs[env_index].append(atom_index)
picked_frames.append(frames[env_index])
return list(grouped_envs.values()), picked_frames
return list(grouped_envs.values())

0 comments on commit 8ed4406

Please sign in to comment.