Update to metatensor-torch v0.6.0
This adds a `pbc` argument to `System` (for now only all-True or all-False is supported) and a
`strict` option to the neighbor list request. We request a non-strict neighbor list since we
make a copy of the neighbor lists and re-filter them anyway.
Luthaf committed Nov 1, 2024
1 parent 8a28eef commit 5326b6e
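For context, a minimal sketch of the two new API surfaces this update relies on, assuming metatensor-torch >= 0.6.0; the atom types, positions, and cutoff below are illustrative placeholders, not values from this commit:

```python
import torch
from metatensor.torch.atomistic import NeighborListOptions, System

# `pbc` is now a required System argument; only all-True or all-False
# is supported by rascaline for now.
system = System(
    types=torch.tensor([8, 1, 1]),
    positions=torch.rand((3, 3), dtype=torch.float64),
    cell=10.0 * torch.eye(3, dtype=torch.float64),
    pbc=torch.tensor([True, True, True]),
)

# `strict=False` means the engine may hand back a neighbor list containing
# extra pairs beyond the cutoff; rascaline copies and re-filters it anyway.
options = NeighborListOptions(cutoff=4.5, full_list=False, strict=False)
```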
Showing 19 changed files with 78 additions and 60 deletions.
4 changes: 2 additions & 2 deletions docs/requirements.txt
@@ -10,8 +10,8 @@ myst-parser # markdown => rst translation, used in extensions/rascaline_json

# dependencies for the tutorials
--extra-index-url https://download.pytorch.org/whl/cpu
metatensor
metatensor-torch >= 0.5.0,<0.6.0
metatensor-operations >=0.3.0,<0.4.0
metatensor-torch >= 0.6.0,<0.7.0
torch
chemfiles
matplotlib
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,7 +27,7 @@ classifiers = [

dependencies = [
"metatensor-core >=0.1.0,<0.2.0",
"metatensor-operations >=0.2.0,<0.3.0",
"metatensor-operations >=0.3.0,<0.4.0",
"wigners",
]

2 changes: 1 addition & 1 deletion python/rascaline-torch/build-backend/backend.py
@@ -30,7 +30,7 @@ def get_requires_for_build_wheel(config_settings=None):
defaults = build_meta.get_requires_for_build_wheel(config_settings)
return defaults + [
"torch >= 1.12",
"metatensor-torch >=0.5.0,<0.6.0",
"metatensor-torch >=0.6.0,<0.7.0",
RASCALINE_DEP,
]

3 changes: 3 additions & 0 deletions python/rascaline-torch/rascaline/torch/calculator_base.py
@@ -90,6 +90,9 @@ def requested_neighbor_lists(self) -> List[NeighborListOptions]:
NeighborListOptions(
cutoff=cutoff,
full_list=False,
# we will re-filter the NL when converting to rascaline internal
# type, so we don't need the engine to pre-filter it for us
strict=False,
requestor="rascaline",
)
)
6 changes: 6 additions & 0 deletions python/rascaline-torch/rascaline/torch/system.py
@@ -1,5 +1,6 @@
from typing import List, Optional, Sequence, overload

import numpy as np
import torch
from metatensor.torch.atomistic import System
from packaging import version
@@ -65,6 +66,11 @@ def _system_to_torch(system, positions_requires_grad, cell_requires_grad):
types=torch.tensor(system.types()),
positions=torch.tensor(system.positions()),
cell=torch.tensor(system.cell()),
pbc=(
torch.tensor([False, False, False])
if np.all(system.cell() == 0.0)
else torch.tensor([True, True, True])
),
)

if positions_requires_grad is not None:
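The same cell-to-`pbc` mapping as in the hunk above, pulled out as a hypothetical standalone helper (not part of this commit) to make the rule explicit — a zero cell means no periodic boundaries, anything else is treated as fully periodic, and mixed periodicity is not supported:

```python
import numpy as np
import torch


def _guess_pbc(cell: np.ndarray) -> torch.Tensor:
    # All-False for a zero cell (non-periodic system), all-True otherwise.
    if np.all(cell == 0.0):
        return torch.tensor([False, False, False])
    return torch.tensor([True, True, True])


# _guess_pbc(np.zeros((3, 3)))  -> tensor([False, False, False])
# _guess_pbc(10.0 * np.eye(3))  -> tensor([True, True, True])
```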
2 changes: 1 addition & 1 deletion python/rascaline-torch/setup.py
@@ -243,7 +243,7 @@ def git_extra_version():

install_requires = [
"torch >= 1.12",
"metatensor-torch >=0.5.0,<0.6.0",
"metatensor-torch >=0.6.0,<0.7.0",
]
if os.path.exists(RASCALINE_C_API):
# we are building from a git checkout
51 changes: 27 additions & 24 deletions python/rascaline-torch/tests/autograd.py
@@ -29,12 +29,13 @@ def _create_random_system(n_atoms, cell_size):
cell = torch.tensor(cell[:], dtype=torch.float64)

positions = torch.rand((n_atoms, 3), dtype=torch.float64) @ cell
pbc = torch.tensor([True, True, True])

return types, positions, cell
return types, positions, cell, pbc


def _compute_spherical_expansion(types, positions, cell):
system = System(types=types, positions=positions, cell=cell)
def _compute_spherical_expansion(types, positions, cell, pbc):
system = System(types=types, positions=positions, cell=cell, pbc=pbc)

calculator = SphericalExpansion(**HYPERS)
descriptor = calculator(system)
@@ -47,8 +48,8 @@ def _compute_spherical_expansion(types, positions, cell):
return descriptor.block(0).values


def _compute_power_spectrum(types, positions, cell):
system = System(types=types, positions=positions, cell=cell)
def _compute_power_spectrum(types, positions, cell, pbc):
system = System(types=types, positions=positions, cell=cell, pbc=pbc)

calculator = SoapPowerSpectrum(**HYPERS)
descriptor = calculator(system)
@@ -59,56 +60,57 @@ def _compute_power_spectrum(types, positions, cell):


def test_spherical_expansion_positions_grad():
types, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
types, positions, cell, pbc = _create_random_system(n_atoms=75, cell_size=5.0)
positions.requires_grad = True

assert torch.autograd.gradcheck(
_compute_spherical_expansion,
(types, positions, cell),
(types, positions, cell, pbc),
fast_mode=True,
)


def test_spherical_expansion_cell_grad():
types, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
types, positions, cell, pbc = _create_random_system(n_atoms=75, cell_size=5.0)
cell.requires_grad = True

assert torch.autograd.gradcheck(
_compute_spherical_expansion,
(types, positions, cell),
(types, positions, cell, pbc),
fast_mode=True,
)


def test_power_spectrum_positions_grad():
types, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
types, positions, cell, pbc = _create_random_system(n_atoms=75, cell_size=5.0)
positions.requires_grad = True

assert torch.autograd.gradcheck(
_compute_power_spectrum,
(types, positions, cell),
(types, positions, cell, pbc),
fast_mode=True,
)


def test_power_spectrum_cell_grad():
types, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
types, positions, cell, pbc = _create_random_system(n_atoms=75, cell_size=5.0)
cell.requires_grad = True

assert torch.autograd.gradcheck(
_compute_power_spectrum,
(types, positions, cell),
(types, positions, cell, pbc),
fast_mode=True,
)


def test_power_spectrum_register_autograd():
# check autograd when registering the graph after pre-computing a representation
types, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
types, positions, cell, pbc = _create_random_system(n_atoms=75, cell_size=5.0)

calculator = SoapPowerSpectrum(**HYPERS)
precomputed = calculator(
System(types, positions, cell), gradients=["positions", "cell"]
System(types, positions, cell, pbc),
gradients=["positions", "cell"],
)

# no grad_fn for now
@@ -118,7 +120,7 @@ def compute(new_positions, new_cell):
same_positions = (new_positions - positions).norm() < 1e-30
same_cell = (new_cell - cell).norm() < 1e-30

system = System(types=types, positions=positions, cell=cell)
system = System(types, positions, cell, pbc)

if same_positions and same_cell:
# we can only re-use the calculation when working with the same input
@@ -146,14 +148,14 @@ def compute(new_positions, new_cell):


def test_power_spectrum_positions_grad_grad():
types, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
types, positions, cell, pbc = _create_random_system(n_atoms=75, cell_size=5.0)
positions.requires_grad = True

X = _compute_power_spectrum(types, positions, cell)
X = _compute_power_spectrum(types, positions, cell, pbc)
weights = torch.rand((X.shape[-1], 1), requires_grad=True, dtype=torch.float64)

def compute(weights):
X = _compute_power_spectrum(types, positions, cell)
X = _compute_power_spectrum(types, positions, cell, pbc)
A = X @ weights

return torch.autograd.grad(
@@ -184,14 +186,14 @@ def compute(weights):


def test_power_spectrum_cell_grad_grad():
types, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
types, positions, cell, pbc = _create_random_system(n_atoms=75, cell_size=5.0)
cell.requires_grad = True

X = _compute_power_spectrum(types, positions, cell)
X = _compute_power_spectrum(types, positions, cell, pbc)
weights = torch.rand((X.shape[-1], 1), requires_grad=True, dtype=torch.float64)

def compute(weights):
X = _compute_power_spectrum(types, positions, cell)
X = _compute_power_spectrum(types, positions, cell, pbc)
A = X @ weights

return torch.autograd.grad(
@@ -244,7 +246,7 @@ def test_different_device_dtype():
options.append((torch.device("cuda:0"), torch.float64))

for device, dtype in options:
types, positions, cell = _create_random_system(n_atoms=10, cell_size=3.0)
types, positions, cell, pbc = _create_random_system(n_atoms=10, cell_size=3.0)
positions = positions.to(dtype=dtype, device=device, copy=True)
positions.requires_grad = True
assert positions.grad is None
@@ -254,11 +256,12 @@ def test_different_device_dtype():
assert cell.grad is None

types = types.to(device=device, copy=True)
pbc = pbc.to(device=device, copy=True)

with warnings.catch_warnings():
warnings.filterwarnings("ignore")

X = _compute_power_spectrum(types, positions, cell)
X = _compute_power_spectrum(types, positions, cell, pbc)

assert X.dtype == dtype
assert X.device == device
2 changes: 1 addition & 1 deletion python/rascaline-torch/tests/calculator.py
@@ -16,6 +16,7 @@ def system():
types=torch.tensor([1, 1, 8, 8]),
positions=torch.tensor([[0.0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]]),
cell=torch.tensor([[10.0, 0, 0], [0, 10, 0], [0, 0, 10]]),
pbc=torch.tensor([True, True, True]),
)


@@ -188,7 +189,6 @@ def test_different_device_dtype_errors(system):
custom_device = torch.device("cuda:0")

if custom_device is not None:

device_system = system.to(device=custom_device)

torch.set_warn_always(True)
3 changes: 2 additions & 1 deletion python/rascaline-torch/tests/export.py
@@ -87,9 +87,10 @@ def test_export_as_metatensor_model(tmpdir):
model = Model(types=[1, 6, 8])
model.eval()

energy_output = ModelOutput()
energy_output = ModelOutput(quantity="energy", unit="eV")
capabilities = ModelCapabilities(
supported_devices=["cpu"],
length_unit="A",
interaction_range=HYPERS["cutoff"],
atomic_types=[1, 6, 8],
dtype="float64",
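The hunk above is truncated, so here is a hedged sketch of what the full capabilities declaration could look like against metatensor-torch 0.6.0, where `ModelOutput` now carries an explicit quantity and unit and `ModelCapabilities` takes a `length_unit`; the interaction range below is a placeholder for `HYPERS["cutoff"]`:

```python
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput

energy_output = ModelOutput(quantity="energy", unit="eV", per_atom=False)
capabilities = ModelCapabilities(
    outputs={"energy": energy_output},
    atomic_types=[1, 6, 8],
    interaction_range=3.5,  # placeholder for HYPERS["cutoff"]
    length_unit="A",
    supported_devices=["cpu"],
    dtype="float64",
)
```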
7 changes: 3 additions & 4 deletions python/rascaline-torch/tests/utils/cg_product.py
@@ -32,11 +32,10 @@ def system():
[2.56633400, 2.50000000, 2.50370100],
[1.97361700, 1.73067300, 2.47063400],
[1.97361700, 3.26932700, 2.47063400],
],
),
cell=torch.zeros(
(3, 3),
]
),
cell=torch.zeros((3, 3)),
pbc=torch.tensor([False, False, False]),
)


1 change: 1 addition & 0 deletions python/rascaline-torch/tests/utils/density_correlations.py
@@ -35,6 +35,7 @@ def system():
]
),
cell=torch.zeros((3, 3)),
pbc=torch.tensor([False, False, False]),
)


1 change: 1 addition & 0 deletions python/rascaline-torch/tests/utils/power_spectrum.py
@@ -11,6 +11,7 @@ def system():
types=torch.tensor([1, 1, 8, 8]),
positions=torch.tensor([[0.0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]]),
cell=torch.tensor([[10.0, 0, 0], [0, 10, 0], [0, 0, 10]]),
pbc=torch.tensor([True, True, True]),
)


11 changes: 5 additions & 6 deletions python/rascaline/tests/utils/density_correlations.py
@@ -97,7 +97,6 @@ def h2o_isolated():


def h2o_periodic():

return [
ase.Atoms(
symbols=["O", "H", "H"],
@@ -353,7 +352,7 @@ def test_correlate_density_norm():
unique_samples = metatensor.unique_metadata(
ps, "samples", names=["system", "atom", "center_type"]
)
grouped_labels = [
selections = [
Labels(names=ps.sample_names, values=unique_samples.values[i].reshape(1, 3))
for i in range(len(unique_samples))
]
@@ -362,11 +361,11 @@
norm_nu1 = 0.0
norm_ps = 0.0
norm_ps_sorted = 0.0
for sample in grouped_labels:
for selection in selections:
# Slice the TensorMaps
nu1_sliced = metatensor.slice(density, "samples", labels=sample)
ps_sliced = metatensor.slice(ps, "samples", labels=sample)
ps_sorted_sliced = metatensor.slice(ps_sorted, "samples", labels=sample)
nu1_sliced = metatensor.slice(density, "samples", selection=selection)
ps_sliced = metatensor.slice(ps, "samples", selection=selection)
ps_sorted_sliced = metatensor.slice(ps_sorted, "samples", selection=selection)

# Calculate norms
norm_nu1 += get_norm(nu1_sliced) ** (n_correlations + 1)
4 changes: 2 additions & 2 deletions rascaline-c-api/CMakeLists.txt
@@ -222,7 +222,7 @@ endif()
# ============================================================================ #
# Setup metatensor

set(METATENSOR_FETCH_VERSION "0.1.10")
set(METATENSOR_FETCH_VERSION "0.1.11")
set(METATENSOR_REQUIRED_VERSION "0.1")
if (RASCALINE_FETCH_METATENSOR)
message(STATUS "Fetching metatensor @ ${METATENSOR_FETCH_VERSION} from github")
@@ -232,7 +232,7 @@ if (RASCALINE_FETCH_METATENSOR)
FetchContent_Declare(
metatensor
URL ${URL_ROOT}/metatensor-core-v${METATENSOR_FETCH_VERSION}/metatensor-core-cxx-${METATENSOR_FETCH_VERSION}.tar.gz
URL_HASH SHA256=3ec0775da67bb0eb3246b81770426e612f83b6591442a39eb17aad6969b5f9d9
URL_HASH SHA256=b79435dbeb59a95361b4a5881574dfde4a4ed13f12b05a225df161dcaa7ea61e
)

if (CMAKE_VERSION VERSION_GREATER 3.18)
6 changes: 3 additions & 3 deletions rascaline-torch/CMakeLists.txt
@@ -58,8 +58,8 @@ find_package(Torch 1.12 REQUIRED)
# ============================================================================ #
# Setup metatensor_torch

set(METATENSOR_FETCH_VERSION "0.5.5")
set(REQUIRED_METATENSOR_TORCH_VERSION "0.5")
set(METATENSOR_FETCH_VERSION "0.6.0")
set(REQUIRED_METATENSOR_TORCH_VERSION "0.6")
if (RASCALINE_TORCH_FETCH_METATENSOR_TORCH)
message(STATUS "Fetching metatensor-torch @ ${METATENSOR_FETCH_VERSION} from github")

@@ -68,7 +68,7 @@ if (RASCALINE_TORCH_FETCH_METATENSOR_TORCH)
FetchContent_Declare(
metatensor_torch
URL ${URL_ROOT}/metatensor-torch-v${METATENSOR_FETCH_VERSION}/metatensor-torch-cxx-${METATENSOR_FETCH_VERSION}.tar.gz
URL_HASH SHA256=dac306ab59ac8b59167827405f468397dbf0d4a69988fce7b9f4285f2816a57c
URL_HASH SHA256=f050743662ece38948b2087dd025d60110645716840dbfc5370c059e1275d0cf
)

if (CMAKE_VERSION VERSION_GREATER 3.18)