Skip to content

Commit

Permalink
Changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jwa7 committed Sep 4, 2024
1 parent 837ef79 commit 39484d0
Show file tree
Hide file tree
Showing 8 changed files with 476 additions and 169 deletions.
1 change: 0 additions & 1 deletion python/rascaline-torch/tests/utils/density_correlations.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def spherical_expansion():
return calculator.compute(system())


# copy of def test_correlate_density_angular_selection(
@pytest.mark.parametrize("selected_keys", [None, SELECTED_KEYS])
@pytest.mark.parametrize("skip_redundant", [True, False])
def test_torch_script_correlate_density_angular_selection(
Expand Down
84 changes: 84 additions & 0 deletions python/rascaline-torch/tests/utils/tensor_correlator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
import os

import metatensor.torch
import pytest
import torch
from metatensor.torch import Labels
from metatensor.torch.atomistic import System

import rascaline.torch
from rascaline.torch.utils.clebsch_gordan import TensorCorrelator, _utils


DATA_ROOT = os.path.join(os.path.dirname(__file__), "data")


SPHERICAL_EXPANSION_HYPERS = {
"cutoff": 2.5,
"max_radial": 3,
"max_angular": 3,
"atomic_gaussian_width": 0.2,
"radial_basis": {"Gto": {}},
"cutoff_function": {"ShiftedCosine": {"width": 0.5}},
"center_atom_weight": 1.0,
}

SELECTED_KEYS = Labels(names=["o3_lambda"], values=torch.tensor([1, 3]).reshape(-1, 1))


def system():
return System(
types=torch.tensor([8, 1, 1]),
positions=torch.tensor(
[
[2.56633400, 2.50000000, 2.50370100],
[1.97361700, 1.73067300, 2.47063400],
[1.97361700, 3.26932700, 2.47063400],
]
),
cell=torch.zeros((3, 3)),
)


def spherical_expansion():
"""Returns a rascaline SphericalExpansion"""
calculator = rascaline.torch.SphericalExpansion(**SPHERICAL_EXPANSION_HYPERS)
return calculator.compute(system())


# copy of def test_correlate_density_angular_selection(
@pytest.mark.parametrize("selected_keys", [None, SELECTED_KEYS])
@pytest.mark.parametrize("skip_redundant", [True, False])
def test_torch_script_correlate_density_angular_selection(
selected_keys: Labels,
skip_redundant: bool,
):
"""
Tests that the correct angular channels are output based on the specified
``selected_keys``.
"""
nu_1 = spherical_expansion()
corr_calculator = TensorCorrelator(
max_angular=SPHERICAL_EXPANSION_HYPERS["max_angular"] * 2,
)

ref_nu_2 = corr_calculator.compute(
_utils._increment_property_name_suffices(nu_1, 1),
_utils._increment_property_name_suffices(nu_1, 2),
selected_keys=selected_keys,
)
scripted_corr_calculator = torch.jit.script(corr_calculator)

# Test compute
scripted_nu_2 = scripted_corr_calculator.compute(
_utils._increment_property_name_suffices(nu_1, 1),
_utils._increment_property_name_suffices(nu_1, 2),
selected_keys=selected_keys,
)
metatensor.torch.equal_metadata_raise(scripted_nu_2, ref_nu_2)
assert metatensor.torch.allclose(scripted_nu_2, ref_nu_2)

# Test compute_metadata
scripted_nu_2 = scripted_corr_calculator.compute_metadata(nu_1)
assert metatensor.torch.equal_metadata(scripted_nu_2, ref_nu_2)
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
"""
This module provides convenience calculators for preforming density correlations, i.e.
the (iterative) CG tensor products of density (body order 2) tensors.
All of these calculators wrap the :py:class:`TensorCorrelator` class, handling the
higher-level metadata manipulation to produce the desired output tensors.
"""
from typing import List, Optional, Tuple

from .. import _dispatch
from .._backend import (
Labels,
TensorMap,
TorchModule,
operations,
torch_jit_export,
)
from . import _coefficients, _utils
from ._tensor_correlator import TensorCorrelator


try:
import torch

HAS_TORCH = True
except ImportError:
HAS_TORCH = False


class DensityCorrelations(TorchModule):
"""
Iterative products of a density to form higher arbitrary body order tensors.
:param max_angular: :py:class:`int`, the maximum angular momentum to compute CG
coefficients for.
:param arrays_backend: :py:class:`str`, the backend to use for array operations. If
``None``, the backend is automatically selected based on the environment.
Possible values are "numpy" and "torch".
:param cg_backend: :py:class:`str`, the backend to use for the CG tensor product. If
``None``, the backend is automatically selected based on the arrays backend.
"""

def __init__(
self,
n_correlations: int,
max_angular: int,
arrays_backend: Optional[str] = None,
cg_backend: Optional[str] = None,
) -> None:

super().__init__()

self._n_correlations = n_correlations
self._tensor_correlator = TensorCorrelator(
max_angular=max_angular,
arrays_backend=arrays_backend,
)

def forward(
self,
tensor: TensorMap,
density: TensorMap,
selected_keys: Optional[Labels] = None,
angular_cutoff: Optional[int] = None,
skip_redundant: bool = False,
) -> TensorMap:
"""
Calls the :py:meth:`DensityCorrelations.compute` function.
This is intended for :py:class:`torch.nn.Module` compatibility, and should be
ignored in pure Python mode.
See :py:meth:`compute` for a full description of the parameters.
"""
return self.compute(
tensor,
density,
tensor_2,
selected_keys=selected_keys,
angular_cutoff=angular_cutoff,
skip_redundant=skip_redundant,
)

@torch_jit_export
def compute_metadata(
self,
tensor: TensorMap,
density: TensorMap,
n_correlations: int,
selected_keys: Optional[Labels] = None,
angular_cutoff: Optional[int] = None,
skip_redundant: bool = False,
) -> TensorMap:
"""
Returns the metadata-only :py:class:`TensorMap` that would be output by the
function :py:meth:`compute` for the same calculator under the same settings,
without performing the actual Clebsch-Gordan tensor products.
See :py:meth:`compute` for a full description of the parameters.
"""
return self._density_correlations(
tensor,
density,
selected_keys,
angular_cutoff,
skip_redundant,
compute_metadata=True,
)

def compute(
self,
tensor: TensorMap,
density: TensorMap,
selected_keys: Optional[Labels] = None,
angular_cutoff: Optional[int] = None,
skip_redundant: bool = False,
) -> TensorMap:
"""
Takes ``n_correlations`` of iterative CG tensor products of a tensor with a
density.
.. math::
\\T^{\\nu=\\nu'+n_{corr}} = T^{\\nu=\\nu'}
\\otimes \\rho^{\\nu=1} \\ldots \\otimes
\\rho^{\\nu=1}
where T is the input ``tensor`` of arbitrary correlation order \\nu' and \\rho
is the input ``density`` tensor of correlation order 1 (body order 2).
As the density is by definition a correlation order 1 tensor, the correlation
order of ``tensor`` will be increased by ``n_correlations`` from its original
correlation order.
``tensor`` and ``density`` must have metadata that is compatible for a CG tensor
product by the :py:class:`TensorCorrelator` class. For every iteration after the
first, the property dimension names of ``density`` are incremented numerically
by 1 so that the metadata is compatible for the next tensor product.
``selected_keys`` can be passed to select the keys to compute on the final
iteration. If ``None``, all keys are computed. To limit the maximum angular
momenta to compute on **intermediate** iterations, pass ``angular_cutoff``.
If ``angular_cutoff`` and ``selected_keys`` are both passed, ``angular_cutoff``
is ignored on the final iteration.
``skip_redundant`` can be passed to skip redundant computations on intermediate
iterations.
:param tensor: :py:class:`TensorMap`, the input tensor of arbitrary correlation
order.
:param density: :py:class:`TensorMap`, the input density tensor of correlation
order 1.
:param n_correlations: :py:class:`int`, the number of CG tensor products to
perform.
:param selected_keys: :py:class:`Labels`, the keys to compute on the final
iteration. If ``None``, all keys are computed.
"""
return self._density_correlations(
tensor,
density,
selected_keys,
angular_cutoff,
skip_redundant,
compute_metadata=False,
)


def _density_correlations(
self,
tensor: TensorMap,
density: TensorMap,
selected_keys: Optional[Labels],
angular_cutoff: Optional[int],
skip_redundant: bool,
compute_metadata: bool,
) -> TensorMap:
"""
Performs the iterative CG tensor products.
"""
# Parse selection filters
selected_keys, angular_cutoff = _parse_selection_filters(
self._n_correlations, selected_keys, angular_cutoff
)

# Perform iterative CG tensor products
new_lambda_names = []
density_correlations = tensor
for i_correlation in range(self._n_correlations):

# Rename density property dimensions
if i_correlation > 0: # metadata assumed ok on first iteration
for name in density.property_names:
density = operations.rename_dimension(
density,
"properties",
name,
_utils._increment_numeric_suffix(name),
)

# Define new key dimension names for tracking intermediate correlations
if i_correlation == 0:
o3_lambda_1_name = f"l_{i_correlation + 1}"
else:
o3_lambda_1_name = f"k_{i_correlation + 1}"
o3_lambda_2_name = f"l_{i_correlation + 2}"
new_lambda_names.extend([o3_lambda_1_name, o3_lambda_2_name])

# Compute CG tensor product
density_correlations = self._tensor_correlator._cg_tensor_product(
density_correlations,
density,
o3_lambda_1_name,
o3_lambda_2_name,
selected_keys=selected_keys[i_correlation],
angular_cutoff=angular_cutoff[i_correlation],
skip_redundant=skip_redundant,
compute_metadata=compute_metadata,
)

return density_correlations


def _parse_selection_filters(
n_correlations: int,
selected_keys: Optional[Labels],
angular_cutoff: Optional[int],
) -> Tuple[List]:

# Parse selected_keys
selected_keys_ = [None] * (n_correlations - 1)
selected_keys_ += [selected_keys]

# Parse angular_cutoff and selected_keys
angular_cutoff_ = [angular_cutoff] * (n_correlations - 1)
if selected_keys is None:
angular_cutoff_ += [angular_cutoff]
else:
angular_cutoff_ += [None]

return selected_keys_, angular_cutoff_
Loading

0 comments on commit 39484d0

Please sign in to comment.