-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
476 additions
and
169 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# -*- coding: utf-8 -*- | ||
import os | ||
|
||
import metatensor.torch | ||
import pytest | ||
import torch | ||
from metatensor.torch import Labels | ||
from metatensor.torch.atomistic import System | ||
|
||
import rascaline.torch | ||
from rascaline.torch.utils.clebsch_gordan import TensorCorrelator, _utils | ||
|
||
|
||
DATA_ROOT = os.path.join(os.path.dirname(__file__), "data") | ||
|
||
|
||
SPHERICAL_EXPANSION_HYPERS = { | ||
"cutoff": 2.5, | ||
"max_radial": 3, | ||
"max_angular": 3, | ||
"atomic_gaussian_width": 0.2, | ||
"radial_basis": {"Gto": {}}, | ||
"cutoff_function": {"ShiftedCosine": {"width": 0.5}}, | ||
"center_atom_weight": 1.0, | ||
} | ||
|
||
SELECTED_KEYS = Labels(names=["o3_lambda"], values=torch.tensor([1, 3]).reshape(-1, 1)) | ||
|
||
|
||
def system(): | ||
return System( | ||
types=torch.tensor([8, 1, 1]), | ||
positions=torch.tensor( | ||
[ | ||
[2.56633400, 2.50000000, 2.50370100], | ||
[1.97361700, 1.73067300, 2.47063400], | ||
[1.97361700, 3.26932700, 2.47063400], | ||
] | ||
), | ||
cell=torch.zeros((3, 3)), | ||
) | ||
|
||
|
||
def spherical_expansion(): | ||
"""Returns a rascaline SphericalExpansion""" | ||
calculator = rascaline.torch.SphericalExpansion(**SPHERICAL_EXPANSION_HYPERS) | ||
return calculator.compute(system()) | ||
|
||
|
||
# copy of def test_correlate_density_angular_selection( | ||
@pytest.mark.parametrize("selected_keys", [None, SELECTED_KEYS]) | ||
@pytest.mark.parametrize("skip_redundant", [True, False]) | ||
def test_torch_script_correlate_density_angular_selection( | ||
selected_keys: Labels, | ||
skip_redundant: bool, | ||
): | ||
""" | ||
Tests that the correct angular channels are output based on the specified | ||
``selected_keys``. | ||
""" | ||
nu_1 = spherical_expansion() | ||
corr_calculator = TensorCorrelator( | ||
max_angular=SPHERICAL_EXPANSION_HYPERS["max_angular"] * 2, | ||
) | ||
|
||
ref_nu_2 = corr_calculator.compute( | ||
_utils._increment_property_name_suffices(nu_1, 1), | ||
_utils._increment_property_name_suffices(nu_1, 2), | ||
selected_keys=selected_keys, | ||
) | ||
scripted_corr_calculator = torch.jit.script(corr_calculator) | ||
|
||
# Test compute | ||
scripted_nu_2 = scripted_corr_calculator.compute( | ||
_utils._increment_property_name_suffices(nu_1, 1), | ||
_utils._increment_property_name_suffices(nu_1, 2), | ||
selected_keys=selected_keys, | ||
) | ||
metatensor.torch.equal_metadata_raise(scripted_nu_2, ref_nu_2) | ||
assert metatensor.torch.allclose(scripted_nu_2, ref_nu_2) | ||
|
||
# Test compute_metadata | ||
scripted_nu_2 = scripted_corr_calculator.compute_metadata(nu_1) | ||
assert metatensor.torch.equal_metadata(scripted_nu_2, ref_nu_2) |
241 changes: 241 additions & 0 deletions
241
python/rascaline/rascaline/utils/clebsch_gordan/_density_correlations 2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,241 @@ | ||
""" | ||
This module provides convenience calculators for preforming density correlations, i.e. | ||
the (iterative) CG tensor products of density (body order 2) tensors. | ||
All of these calculators wrap the :py:class:`TensorCorrelator` class, handling the | ||
higher-level metadata manipulation to produce the desired output tensors. | ||
""" | ||
from typing import List, Optional, Tuple | ||
|
||
from .. import _dispatch | ||
from .._backend import ( | ||
Labels, | ||
TensorMap, | ||
TorchModule, | ||
operations, | ||
torch_jit_export, | ||
) | ||
from . import _coefficients, _utils | ||
from ._tensor_correlator import TensorCorrelator | ||
|
||
|
||
try: | ||
import torch | ||
|
||
HAS_TORCH = True | ||
except ImportError: | ||
HAS_TORCH = False | ||
|
||
|
||
class DensityCorrelations(TorchModule): | ||
""" | ||
Iterative products of a density to form higher arbitrary body order tensors. | ||
:param max_angular: :py:class:`int`, the maximum angular momentum to compute CG | ||
coefficients for. | ||
:param arrays_backend: :py:class:`str`, the backend to use for array operations. If | ||
``None``, the backend is automatically selected based on the environment. | ||
Possible values are "numpy" and "torch". | ||
:param cg_backend: :py:class:`str`, the backend to use for the CG tensor product. If | ||
``None``, the backend is automatically selected based on the arrays backend. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
n_correlations: int, | ||
max_angular: int, | ||
arrays_backend: Optional[str] = None, | ||
cg_backend: Optional[str] = None, | ||
) -> None: | ||
|
||
super().__init__() | ||
|
||
self._n_correlations = n_correlations | ||
self._tensor_correlator = TensorCorrelator( | ||
max_angular=max_angular, | ||
arrays_backend=arrays_backend, | ||
) | ||
|
||
def forward( | ||
self, | ||
tensor: TensorMap, | ||
density: TensorMap, | ||
selected_keys: Optional[Labels] = None, | ||
angular_cutoff: Optional[int] = None, | ||
skip_redundant: bool = False, | ||
) -> TensorMap: | ||
""" | ||
Calls the :py:meth:`DensityCorrelations.compute` function. | ||
This is intended for :py:class:`torch.nn.Module` compatibility, and should be | ||
ignored in pure Python mode. | ||
See :py:meth:`compute` for a full description of the parameters. | ||
""" | ||
return self.compute( | ||
tensor, | ||
density, | ||
tensor_2, | ||
selected_keys=selected_keys, | ||
angular_cutoff=angular_cutoff, | ||
skip_redundant=skip_redundant, | ||
) | ||
|
||
@torch_jit_export | ||
def compute_metadata( | ||
self, | ||
tensor: TensorMap, | ||
density: TensorMap, | ||
n_correlations: int, | ||
selected_keys: Optional[Labels] = None, | ||
angular_cutoff: Optional[int] = None, | ||
skip_redundant: bool = False, | ||
) -> TensorMap: | ||
""" | ||
Returns the metadata-only :py:class:`TensorMap` that would be output by the | ||
function :py:meth:`compute` for the same calculator under the same settings, | ||
without performing the actual Clebsch-Gordan tensor products. | ||
See :py:meth:`compute` for a full description of the parameters. | ||
""" | ||
return self._density_correlations( | ||
tensor, | ||
density, | ||
selected_keys, | ||
angular_cutoff, | ||
skip_redundant, | ||
compute_metadata=True, | ||
) | ||
|
||
def compute( | ||
self, | ||
tensor: TensorMap, | ||
density: TensorMap, | ||
selected_keys: Optional[Labels] = None, | ||
angular_cutoff: Optional[int] = None, | ||
skip_redundant: bool = False, | ||
) -> TensorMap: | ||
""" | ||
Takes ``n_correlations`` of iterative CG tensor products of a tensor with a | ||
density. | ||
.. math:: | ||
\\T^{\\nu=\\nu'+n_{corr}} = T^{\\nu=\\nu'} | ||
\\otimes \\rho^{\\nu=1} \\ldots \\otimes | ||
\\rho^{\\nu=1} | ||
where T is the input ``tensor`` of arbitrary correlation order \\nu' and \\rho | ||
is the input ``density`` tensor of correlation order 1 (body order 2). | ||
As the density is by definition a correlation order 1 tensor, the correlation | ||
order of ``tensor`` will be increased by ``n_correlations`` from its original | ||
correlation order. | ||
``tensor`` and ``density`` must have metadata that is compatible for a CG tensor | ||
product by the :py:class:`TensorCorrelator` class. For every iteration after the | ||
first, the property dimension names of ``density`` are incremented numerically | ||
by 1 so that the metadata is compatible for the next tensor product. | ||
``selected_keys`` can be passed to select the keys to compute on the final | ||
iteration. If ``None``, all keys are computed. To limit the maximum angular | ||
momenta to compute on **intermediate** iterations, pass ``angular_cutoff``. | ||
If ``angular_cutoff`` and ``selected_keys`` are both passed, ``angular_cutoff`` | ||
is ignored on the final iteration. | ||
``skip_redundant`` can be passed to skip redundant computations on intermediate | ||
iterations. | ||
:param tensor: :py:class:`TensorMap`, the input tensor of arbitrary correlation | ||
order. | ||
:param density: :py:class:`TensorMap`, the input density tensor of correlation | ||
order 1. | ||
:param n_correlations: :py:class:`int`, the number of CG tensor products to | ||
perform. | ||
:param selected_keys: :py:class:`Labels`, the keys to compute on the final | ||
iteration. If ``None``, all keys are computed. | ||
""" | ||
return self._density_correlations( | ||
tensor, | ||
density, | ||
selected_keys, | ||
angular_cutoff, | ||
skip_redundant, | ||
compute_metadata=False, | ||
) | ||
|
||
|
||
def _density_correlations( | ||
self, | ||
tensor: TensorMap, | ||
density: TensorMap, | ||
selected_keys: Optional[Labels], | ||
angular_cutoff: Optional[int], | ||
skip_redundant: bool, | ||
compute_metadata: bool, | ||
) -> TensorMap: | ||
""" | ||
Performs the iterative CG tensor products. | ||
""" | ||
# Parse selection filters | ||
selected_keys, angular_cutoff = _parse_selection_filters( | ||
self._n_correlations, selected_keys, angular_cutoff | ||
) | ||
|
||
# Perform iterative CG tensor products | ||
new_lambda_names = [] | ||
density_correlations = tensor | ||
for i_correlation in range(self._n_correlations): | ||
|
||
# Rename density property dimensions | ||
if i_correlation > 0: # metadata assumed ok on first iteration | ||
for name in density.property_names: | ||
density = operations.rename_dimension( | ||
density, | ||
"properties", | ||
name, | ||
_utils._increment_numeric_suffix(name), | ||
) | ||
|
||
# Define new key dimension names for tracking intermediate correlations | ||
if i_correlation == 0: | ||
o3_lambda_1_name = f"l_{i_correlation + 1}" | ||
else: | ||
o3_lambda_1_name = f"k_{i_correlation + 1}" | ||
o3_lambda_2_name = f"l_{i_correlation + 2}" | ||
new_lambda_names.extend([o3_lambda_1_name, o3_lambda_2_name]) | ||
|
||
# Compute CG tensor product | ||
density_correlations = self._tensor_correlator._cg_tensor_product( | ||
density_correlations, | ||
density, | ||
o3_lambda_1_name, | ||
o3_lambda_2_name, | ||
selected_keys=selected_keys[i_correlation], | ||
angular_cutoff=angular_cutoff[i_correlation], | ||
skip_redundant=skip_redundant, | ||
compute_metadata=compute_metadata, | ||
) | ||
|
||
return density_correlations | ||
|
||
|
||
def _parse_selection_filters( | ||
n_correlations: int, | ||
selected_keys: Optional[Labels], | ||
angular_cutoff: Optional[int], | ||
) -> Tuple[List]: | ||
|
||
# Parse selected_keys | ||
selected_keys_ = [None] * (n_correlations - 1) | ||
selected_keys_ += [selected_keys] | ||
|
||
# Parse angular_cutoff and selected_keys | ||
angular_cutoff_ = [angular_cutoff] * (n_correlations - 1) | ||
if selected_keys is None: | ||
angular_cutoff_ += [angular_cutoff] | ||
else: | ||
angular_cutoff_ += [None] | ||
|
||
return selected_keys_, angular_cutoff_ |
Oops, something went wrong.