Skip to content

Commit

Permalink
import operations
Browse files Browse the repository at this point in the history
  • Loading branch information
jwa7 committed Jul 8, 2024
1 parent 5e9230f commit cda9c77
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions python/rascaline/rascaline/utils/clebsch_gordan/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import List, Optional, Tuple, Union

from .. import _dispatch
from .._backend import Array, Labels, TensorBlock, TensorMap, is_labels
from .._backend import Array, Labels, TensorBlock, TensorMap, is_labels, operations
from . import _coefficients


Expand Down Expand Up @@ -705,10 +705,10 @@ def _broadcast_first_block_samples(
"""
Broadcasts the values tensor of ``block_1`` along the samples dimension to match
those of ``block_2`` and returns the modified :py:class:`TensorBlock`.
Assumes that ``block_1`` has samples that are a subset of those of ``block_2``, and
are matching in the dimensions named in ``match_samples``.
are matching in the dimensions named in ``match_samples``.
Returns ``block_1`` with the same samples metadata as ``block_2`` and broadcasted
values along this axis.
"""
Expand Down Expand Up @@ -737,7 +737,9 @@ def _symmetrise_permutations(pairs: TensorMap) -> TensorMap:
Symmetrise the permutations of the samples in the pairwise SphericalExpansion.
"""
raise NotImplementedError("This function is not yet implemented.")
new_pairs = operations.insert_dimension(pairs, axis="samples", index=len(pairs.sample_names), name="sign", values=1)
new_pairs = operations.insert_dimension(
pairs, axis="samples", index=len(pairs.sample_names), name="sign", values=1
)

new_blocks = []

Expand All @@ -759,7 +761,7 @@ def _symmetrise_permutations(pairs: TensorMap) -> TensorMap:

if i == j and x == 0 and y == 0 and z == 0: # on-site
continue

# Create the negative label and append the negative permutation
negative_label = torch.tensor([A, j, i, x, y, z, -sign], dtype=torch.int32)
new_samples.append(sample.values)
Expand All @@ -770,7 +772,6 @@ def _symmetrise_permutations(pairs: TensorMap) -> TensorMap:
# new_samples.append(negative_label)
# new_values.append(block.values[block.samples.])


new_block = mts.TensorBlock(
values=new_values,
samples=new_samples,
Expand All @@ -779,7 +780,6 @@ def _symmetrise_permutations(pairs: TensorMap) -> TensorMap:
)
new_blocks.append(new_block)


new_pairs = mts.TensorMap(pairs.keys, new_blocks)

return mts.sort(new_pairs)
Expand Down

0 comments on commit cda9c77

Please sign in to comment.