Skip to content

Commit

Permalink
Add sorting of the environments
Browse files Browse the repository at this point in the history
  • Loading branch information
sofiia-chorna committed Sep 4, 2024
1 parent d5e3873 commit 313233d
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions python/chemiscope/explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,22 @@ def soap_kpca_featurize(frames, environments):

# Pick frames and properties related to the environments if provided
if environments is not None:
unique_structure_indices = list({env[0] for env in environments})
if any(index >= len(frames) for index in unique_structure_indices):
unique_structures = list({env[0] for env in environments})
if any(index >= len(frames) for index in unique_structures):
raise IndexError(
"Some or more indices are greater than the length of the frames"
)

if len(unique_structure_indices) != len(frames):
if len(unique_structures) != len(frames):
# Sort environments by structure id and atom id
environments = sorted(environments, key=lambda x: (x[0], x[1]))

# Pick frames corresponding to the environments
frames = [frames[index] for index in unique_structure_indices]
frames = [frames[index] for index in unique_structures]

# Pick properties corresponding to the environments
properties = _extract_properties_by_environments(
properties, unique_structure_indices, len(environments)
properties, unique_structures
)

# Apply dimensionality reduction from the provided featurizer
Expand Down Expand Up @@ -212,29 +215,27 @@ def _extract_environment_indices(envs):
return list(grouped_envs.values())


def _extract_properties_by_environments(properties, structures, n_envs):
def _extract_properties_by_environments(properties, structures):
"""
Filter properties based on the structure indexes
:param: dict properties: properties where each property can be either a list or a
dictionary with list of values
:param: list structures: structure indexes taken from the environments
:param: int n_envs: total number of environments
"""

for prop_name, prop_val in properties.items():
if isinstance(prop_val, list):
# Pick property values corresponding to the structure indexes
if len(prop_val) != n_envs:
if len(prop_val) != len(structures):
properties[prop_name] = [
prop_val[struct_index] for struct_index in structures
]
else:
# Nested list with property values
values = prop_val.get("values", [])
if len(values) != n_envs:
if len(values) != len(structures):
properties[prop_name]["values"] = [
values[struct_index] for struct_index in structures
]
Expand Down

0 comments on commit 313233d

Please sign in to comment.