Skip to content

Commit

Permalink
Use Labels.select for key selection in ClebschGordanProduct
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Sep 19, 2024
1 parent a4d7247 commit 83f4ae3
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,10 @@ def _cg_tensor_product(

# 3. a) Apply key selections
if selected_keys is not None:
output_keys, combinations = _utils._apply_key_selection(
output_keys,
combinations,
selected_keys,
selected_idx = output_keys.select(selected_keys)
combinations = [combinations[i] for i in selected_idx]
output_keys = Labels(
names=output_keys.names, values=output_keys.values[selected_idx]
)

# 3. b) Apply key filter
Expand Down
56 changes: 0 additions & 56 deletions python/rascaline/rascaline/utils/clebsch_gordan/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,62 +160,6 @@ def _compute_output_keys(
return output_keys, combinations


def _apply_key_selection(
output_keys: Labels,
combinations: List[Tuple[int, int]],
selected_keys: Labels,
) -> Tuple[Labels, List[Tuple[int, int]]]:
"""
Applies a selection according to ``selected_keys`` to the keys of an output
TensorMap ``output_keys`` produced by the provided ``combinations`` of blocks.
After application of the selections, returned is a reduced set of keys and set of
corresponding parents key entries.
If a selection in ``selected_keys`` is not valid based on the keys in
``output_keys``, we raise an error.
"""
# Extract the relevant columns from `selected_keys` that the selection will be
# performed on
col_idx = _dispatch.int_array_like(
[output_keys.names.index(name) for name in selected_keys.names],
output_keys.values,
)
output_keys_values = output_keys.values[:, col_idx]

# First check that all of the selected keys exist in the output keys
for selected in selected_keys.values:
if not any(
[
bool(all(selected == output_keys_values[i]))
for i in range(len(output_keys_values))
]
):
raise ValueError(
f"selected key {selected_keys.names} = {selected} not found "
"in the output keys"
)

# Build a mask of the selected keys
mask = _dispatch.bool_array_like(
[
any([bool(all(i == j)) for j in selected_keys.values])
for i in output_keys_values
],
like=selected_keys.values,
)

mask_indices = _dispatch.int_array_like(
list(range(len(combinations))), like=selected_keys.values
)[mask]

# Apply the mask to combinations and keys
combinations = [combinations[i] for i in mask_indices]
output_keys = Labels(names=output_keys.names, values=output_keys.values[mask])

return output_keys, combinations


def _group_combinations_of_same_blocks(
output_keys: Labels,
combinations: List[Tuple[int, int]],
Expand Down

0 comments on commit 83f4ae3

Please sign in to comment.