Skip to content

Commit

Permalink
Update symmetrisation function
Browse files Browse the repository at this point in the history
  • Loading branch information
ppegolo committed Jul 11, 2024
1 parent cda9c77 commit e1c7946
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 20 deletions.
20 changes: 20 additions & 0 deletions python/rascaline/rascaline/utils/_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,26 @@ def concatenate(arrays: List[TorchTensor], axis: int):
return np.concatenate(arrays, axis)
else:
raise TypeError(UNKNOWN_ARRAY_TYPE)

def stack(arrays: List[TorchTensor], axis: int):
"""
Stack a group of arrays along a new axis.
This function has the same behavior as ``numpy.stack(arrays, axis)``
and ``torch.stack(arrays, axis)``.
Passing `axis` as ``0`` is equivalent to stacking arrays along the first
dimension, ``1`` along the second dimension, and so on.
"""
if isinstance(arrays[0], TorchTensor):
_check_all_torch_tensor(arrays)
return torch.stack(arrays, axis)
elif isinstance(arrays[0], np.ndarray):
_check_all_np_ndarray(arrays)
return np.stack(arrays, axis)
else:
raise TypeError(UNKNOWN_ARRAY_TYPE)



def empty_like(array, shape: Optional[List[int]] = None, requires_grad: bool = False):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def _correlate_tensor_with_density(
pair_density, "properties", "n", "n_2"
)

# Symmetrise permutations. TODO: finish function
# Symmetrise permutations. TODO: to remove
pair_density = _utils._symmetrise_permutations(pair_density)

# Initialize the CorrelateTensorWithDensity calculator. Re-use the CG
Expand Down
36 changes: 17 additions & 19 deletions python/rascaline/rascaline/utils/clebsch_gordan/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,53 +736,51 @@ def _symmetrise_permutations(pairs: TensorMap) -> TensorMap:
"""
Symmetrise the permutations of the samples in the pairwise SphericalExpansion.
"""
raise NotImplementedError("This function is not yet implemented.")
# raise NotImplementedError("This function is not yet implemented.")
new_pairs = operations.insert_dimension(
pairs, axis="samples", index=len(pairs.sample_names), name="sign", values=1
)

new_blocks = []

for key, block in pairs.items():
for key, block in new_pairs.items():
same_types = key["first_atom_type"] == key["second_atom_type"]

if not same_types:
new_blocks.append(block)
continue

new_values, new_samples = [], []
samples = block.samples
for i_sample, sample in enumerate(samples):

A, i, j, x, y, z, sign = sample

# Always include the positive permutation
new_samples.append(sample.values)
new_values.append(block.values[i_sample])

if i == j and x == 0 and y == 0 and z == 0: # on-site
continue

# Create the negative label and append the negative permutation
negative_label = torch.tensor([A, j, i, x, y, z, -sign], dtype=torch.int32)
# Always include the positive permutation
new_samples.append(sample.values)
new_values.append(block.values[i_sample])

# TODO: finish this function
# Create the negative sample label inverting i,j and changing sign
new_samples.append(_dispatch.int_array_like([A, j, i, x, y, z, -sign], like = sample.values))

# new_samples.append(negative_label)
# new_values.append(block.values[block.samples.])
# Find the index of the negative sample label, defined inverting i,j and changing sign to x,y,z
neg_i_sample = samples.position(_dispatch.int_array_like([A, j, i, -x, -y, -z, sign], like = sample.values))
assert isinstance(neg_i_sample, int), f"Negative sample label not found for key={key}, neg_i_sample={neg_i_sample}"
new_values.append(block.values[neg_i_sample])

new_block = mts.TensorBlock(
values=new_values,
samples=new_samples,
components=block.components,
properties=block.properties,
new_block = TensorBlock(
values = _dispatch.stack(new_values, axis = 0),
samples = Labels(samples.names, _dispatch.stack(new_samples, axis = 0)),
components = block.components,
properties = block.properties,
)
new_blocks.append(new_block)

new_pairs = mts.TensorMap(pairs.keys, new_blocks)
new_pairs = TensorMap(pairs.keys, new_blocks)

return mts.sort(new_pairs)
return operations.sort(new_pairs)


# ======================================================================= #
Expand Down

0 comments on commit e1c7946

Please sign in to comment.