Skip to content

Commit

Permalink
Fill neighbor species when construct Python PS
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Aug 7, 2023
1 parent b97bcec commit 6df6e2a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 6 deletions.
39 changes: 33 additions & 6 deletions python/rascaline/rascaline/utils/power_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,19 @@ def compute(
*,
gradients: Optional[List[str]] = None,
use_native_system: bool = True,
fill_species_neighbor: bool = False,
) -> TensorMap:
"""Runs a calculation with this calculator on the given ``systems``.
See :py:func:`rascaline.calculators.CalculatorBase.compute()` for details on the
parameters.
:param fill_species_neighbor: Blocks with the same ``species_neighbor`` keys are
merged along the ``properties`` dimensions. By default the power spectrum
will **only** contain neighbor species from existing blocks. This behaviour
might prevent later merging blocks along the ``sample`` direction. One can
this behevaiour and consider all possible species by setting
``fill_species_neighbor=True``.
:raises NotImplementedError: If a spherical expansions contains a gradient with
respect to an unknwon parameter.
"""
Expand All @@ -183,9 +190,18 @@ def compute(
]
assert spherical_expansion_1.keys.names == expected_key_names
assert spherical_expansion_1.property_names == ["n"]
spherical_expansion_1 = spherical_expansion_1.keys_to_properties(
"species_neighbor"
)

if fill_species_neighbor:
keys_to_move = Labels(
names="species_neighbor",
values=np.unique(
spherical_expansion_1.keys["species_neighbor"]
).reshape(-1, 1),
)
else:
keys_to_move = Labels.empty(names="species_neighbor")

spherical_expansion_1 = spherical_expansion_1.keys_to_properties(keys_to_move)

if self.calculator_2 is None:
spherical_expansion_2 = spherical_expansion_1
Expand All @@ -197,8 +213,19 @@ def compute(
)
assert spherical_expansion_2.keys.names == expected_key_names
assert spherical_expansion_2.property_names == ["n"]

if fill_species_neighbor:
keys_to_move = Labels(
names="species_neighbor",
values=np.unique(
spherical_expansion_2.keys["species_neighbor"]
).reshape(-1, 1),
)
else:
keys_to_move = Labels.empty(names="species_neighbor")

spherical_expansion_2 = spherical_expansion_2.keys_to_properties(
"species_neighbor"
keys_to_move
)

blocks = []
Expand All @@ -212,7 +239,7 @@ def compute(
spherical_harmonics_l=ell, species_center=species_center
)
for block_2 in blocks_2:
# Makre sure that samples are the same. This should not happen.
# Make sure that samples are the same. This should not happen.
assert block_1.samples == block_2.samples

properties = Labels(
Expand Down Expand Up @@ -258,7 +285,7 @@ def _positions_gradients(new_block, block_1, block_2, factor):
gradient_2 = block_2.gradient("positions")

if len(gradient_1.samples) == 0 or len(gradient_2.samples) == 0:
gradients_samples = Labels.empty()
gradients_samples = Labels.empty(names=["sample", "structure", "atom"])
gradient_values = np.array([]).reshape(0, 1, len(new_block.properties))
else:
# The "sample" dimension in the power spectrum gradient samples do
Expand Down
18 changes: 18 additions & 0 deletions python/rascaline/tests/utils/power_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,21 @@ def test_power_spectrum_unknown_gradient() -> None:
msg = "PowerSpectrum currently only supports gradients w.r.t. to positions"
with pytest.raises(NotImplementedError, match=msg):
PowerSpectrum(calculator).compute(SystemForTests(), gradients=["cell"])


def test_fill_species_neighbor() -> None:
"""Test that ``species_center`` keys can be merged."""

frames = [
ase.Atoms("H", positions=np.zeros([1, 3])),
ase.Atoms("O", positions=np.zeros([1, 3])),
]

calculator = PowerSpectrum(
calculator_1=rascaline.SphericalExpansion(**HYPERS),
calculator_2=rascaline.SphericalExpansion(**HYPERS),
)

descriptor = calculator.compute(frames, fill_species_neighbor=True)

descriptor.keys_to_samples("species_center")

0 comments on commit 6df6e2a

Please sign in to comment.