diff --git a/python/rascaline/rascaline/utils/clebsch_gordan/_cg_product.py b/python/rascaline/rascaline/utils/clebsch_gordan/_cg_product.py index 6bb47c6de..6e5268752 100644 --- a/python/rascaline/rascaline/utils/clebsch_gordan/_cg_product.py +++ b/python/rascaline/rascaline/utils/clebsch_gordan/_cg_product.py @@ -300,10 +300,10 @@ def _cg_tensor_product( # 3. a) Apply key selections if selected_keys is not None: - output_keys, combinations = _utils._apply_key_selection( - output_keys, - combinations, - selected_keys, + selected_idx = output_keys.select(selected_keys) + combinations = [combinations[i] for i in selected_idx] + output_keys = Labels( + names=output_keys.names, values=output_keys.values[selected_idx] ) # 3. b) Apply key filter diff --git a/python/rascaline/rascaline/utils/clebsch_gordan/_utils.py b/python/rascaline/rascaline/utils/clebsch_gordan/_utils.py index 01505650e..b02292d3d 100644 --- a/python/rascaline/rascaline/utils/clebsch_gordan/_utils.py +++ b/python/rascaline/rascaline/utils/clebsch_gordan/_utils.py @@ -160,62 +160,6 @@ def _compute_output_keys( return output_keys, combinations -def _apply_key_selection( - output_keys: Labels, - combinations: List[Tuple[int, int]], - selected_keys: Labels, -) -> Tuple[Labels, List[Tuple[int, int]]]: - """ - Applies a selection according to ``selected_keys`` to the keys of an output - TensorMap ``output_keys`` produced by the provided ``combinations`` of blocks. - - After application of the selections, returned is a reduced set of keys and set of - corresponding parents key entries. - - If a selection in ``selected_keys`` is not valid based on the keys in - ``output_keys``, we raise an error. - """ - # Extract the relevant columns from `selected_keys` that the selection will be - # performed on - col_idx = _dispatch.int_array_like( - [output_keys.names.index(name) for name in selected_keys.names], - output_keys.values, - ) - output_keys_values = output_keys.values[:, col_idx] - - # First check that all of the selected keys exist in the output keys - for selected in selected_keys.values: - if not any( - [ - bool(all(selected == output_keys_values[i])) - for i in range(len(output_keys_values)) - ] - ): - raise ValueError( - f"selected key {selected_keys.names} = {selected} not found " - "in the output keys" - ) - - # Build a mask of the selected keys - mask = _dispatch.bool_array_like( - [ - any([bool(all(i == j)) for j in selected_keys.values]) - for i in output_keys_values - ], - like=selected_keys.values, - ) - - mask_indices = _dispatch.int_array_like( - list(range(len(combinations))), like=selected_keys.values - )[mask] - - # Apply the mask to combinations and keys - combinations = [combinations[i] for i in mask_indices] - output_keys = Labels(names=output_keys.names, values=output_keys.values[mask]) - - return output_keys, combinations - - def _group_combinations_of_same_blocks( output_keys: Labels, combinations: List[Tuple[int, int]],