diff --git a/python/chemiscope/explore.py b/python/chemiscope/explore.py index 9df944d32..ff6fc48db 100644 --- a/python/chemiscope/explore.py +++ b/python/chemiscope/explore.py @@ -111,19 +111,22 @@ def soap_kpca_featurize(frames, environments): # Pick frames and properties related to the environments if provided if environments is not None: - unique_structure_indices = list({env[0] for env in environments}) - if any(index >= len(frames) for index in unique_structure_indices): + unique_structures = list({env[0] for env in environments}) + if any(index >= len(frames) for index in unique_structures): raise IndexError( "Some or more indices are greater than the length of the frames" ) - if len(unique_structure_indices) != len(frames): + if len(unique_structures) != len(frames): + # Sort environments by structure id and atom id + environments = sorted(environments, key=lambda x: (x[0], x[1])) + # Pick frames corresponding to the environments - frames = [frames[index] for index in unique_structure_indices] + frames = [frames[index] for index in unique_structures] # Pick properties corresponding to the environments properties = _extract_properties_by_environments( - properties, unique_structure_indices, len(environments) + properties, unique_structures ) # Apply dimensionality reduction from the provided featurizer @@ -212,7 +215,7 @@ def _extract_environment_indices(envs): return list(grouped_envs.values()) -def _extract_properties_by_environments(properties, structures, n_envs): +def _extract_properties_by_environments(properties, structures): """ Filter properties based on the structure indexes @@ -220,21 +223,19 @@ def _extract_properties_by_environments(properties, structures, n_envs): dictionary with list of values :param: list structures: structure indexes taken from the environments - - :param: int n_envs: total number of environments """ for prop_name, prop_val in properties.items(): if isinstance(prop_val, list): # Pick property values corresponding to the structure indexes - if len(prop_val) != n_envs: + if len(prop_val) != len(structures): properties[prop_name] = [ prop_val[struct_index] for struct_index in structures ] else: # Nested list with property values values = prop_val.get("values", []) - if len(values) != n_envs: + if len(values) != len(structures): properties[prop_name]["values"] = [ values[struct_index] for struct_index in structures ]