Skip to content

Commit

Permalink
Update to metatensor-torch v0.6.0
Browse files Browse the repository at this point in the history
This adds a `pbc` argument to System (for now only all True or all False
is supported) and a `strict` option to the neighbor list. We request non-strict
NL since we will make a copy & re-filter the neighbor lists anyway.
  • Loading branch information
Luthaf committed Oct 30, 2024
1 parent b70b19e commit dd4e110
Show file tree
Hide file tree
Showing 17 changed files with 69 additions and 52 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [

dependencies = [
"metatensor-core >=0.1.0,<0.2.0",
"metatensor-operations >=0.2.0,<0.3.0",
"metatensor-operations >=0.3.0,<0.4.0",
"wigners",
]

Expand Down
2 changes: 1 addition & 1 deletion python/rascaline-torch/build-backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_requires_for_build_wheel(config_settings=None):
defaults = build_meta.get_requires_for_build_wheel(config_settings)
return defaults + [
"torch >= 1.12",
"metatensor-torch >=0.5.0,<0.6.0",
"metatensor-torch >=0.6.0,<0.7.0",
RASCALINE_DEP,
]

Expand Down
3 changes: 3 additions & 0 deletions python/rascaline-torch/rascaline/torch/calculator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def requested_neighbor_lists(self) -> List[NeighborListOptions]:
NeighborListOptions(
cutoff=cutoff,
full_list=False,
# we will re-filter the NL when converting to rascaline internal
# type, so we don't need the engine to pre-filter it for us
strict=False,
requestor="rascaline",
)
)
Expand Down
4 changes: 4 additions & 0 deletions python/rascaline-torch/rascaline/torch/system.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional, Sequence, overload

import torch
import numpy as np
from metatensor.torch.atomistic import System
from packaging import version

Expand Down Expand Up @@ -65,6 +66,9 @@ def _system_to_torch(system, positions_requires_grad, cell_requires_grad):
types=torch.tensor(system.types()),
positions=torch.tensor(system.positions()),
cell=torch.tensor(system.cell()),
pbc=torch.tensor([False, False, False])
if np.all(system.cell() == 0.0)
else torch.tensor([True, True, True]),
)

if positions_requires_grad is not None:
Expand Down
2 changes: 1 addition & 1 deletion python/rascaline-torch/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def git_extra_version():

install_requires = [
"torch >= 1.12",
"metatensor-torch >=0.5.0,<0.6.0",
"metatensor-torch >=0.6.0,<0.7.0",
]
if os.path.exists(RASCALINE_C_API):
# we are building from a git checkout
Expand Down
51 changes: 27 additions & 24 deletions python/rascaline-torch/tests/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ def _create_random_system(n_atoms, cell_size):
cell = torch.tensor(cell[:], dtype=torch.float64)

positions = torch.rand((n_atoms, 3), dtype=torch.float64) @ cell
pbc = torch.tensor([True, True, True])

return types, positions, cell
return types, positions, cell, pbc


def _compute_spherical_expansion(types, positions, cell):
system = System(types=types, positions=positions, cell=cell)
def _compute_spherical_expansion(types, positions, cell, pbc):
system = System(types=types, positions=positions, cell=cell, pbc=pbc)

calculator = SphericalExpansion(**HYPERS)
descriptor = calculator(system)
Expand All @@ -47,8 +48,8 @@ def _compute_spherical_expansion(types, positions, cell):
return descriptor.block(0).values


def _compute_power_spectrum(types, positions, cell):
system = System(types=types, positions=positions, cell=cell)
def _compute_power_spectrum(types, positions, cell, pbc):
system = System(types=types, positions=positions, cell=cell, pbc=pbc)

calculator = SoapPowerSpectrum(**HYPERS)
descriptor = calculator(system)
Expand All @@ -59,56 +60,57 @@ def _compute_power_spectrum(types, positions, cell):


def test_spherical_expansion_positions_grad():
types, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
types, positions, cell, pbc = _create_random_system(n_atoms=75, cell_size=5.0)
positions.requires_grad = True

assert torch.autograd.gradcheck(
_compute_spherical_expansion,
(types, positions, cell),
(types, positions, cell, pbc),
fast_mode=True,
)


def test_spherical_expansion_cell_grad():
types, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
types, positions, cell, pbc = _create_random_system(n_atoms=75, cell_size=5.0)
cell.requires_grad = True

assert torch.autograd.gradcheck(
_compute_spherical_expansion,
(types, positions, cell),
(types, positions, cell, pbc),
fast_mode=True,
)


def test_power_spectrum_positions_grad():
types, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
types, positions, cell, pbc = _create_random_system(n_atoms=75, cell_size=5.0)
positions.requires_grad = True

assert torch.autograd.gradcheck(
_compute_power_spectrum,
(types, positions, cell),
(types, positions, cell, pbc),
fast_mode=True,
)


def test_power_spectrum_cell_grad():
types, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
types, positions, cell, pbc = _create_random_system(n_atoms=75, cell_size=5.0)
cell.requires_grad = True

assert torch.autograd.gradcheck(
_compute_power_spectrum,
(types, positions, cell),
(types, positions, cell, pbc),
fast_mode=True,
)


def test_power_spectrum_register_autograd():
# check autograd when registering the graph after pre-computing a representation
types, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
types, positions, cell, pbc = _create_random_system(n_atoms=75, cell_size=5.0)

calculator = SoapPowerSpectrum(**HYPERS)
precomputed = calculator(
System(types, positions, cell), gradients=["positions", "cell"]
System(types, positions, cell, pbc),
gradients=["positions", "cell"],
)

# no grad_fn for now
Expand All @@ -118,7 +120,7 @@ def compute(new_positions, new_cell):
same_positions = (new_positions - positions).norm() < 1e-30
same_cell = (new_cell - cell).norm() < 1e-30

system = System(types=types, positions=positions, cell=cell)
system = System(types, positions, cell, pbc)

if same_positions and same_cell:
# we can only re-use the calculation when working with the same input
Expand Down Expand Up @@ -146,14 +148,14 @@ def compute(new_positions, new_cell):


def test_power_spectrum_positions_grad_grad():
types, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
types, positions, cell, pbc = _create_random_system(n_atoms=75, cell_size=5.0)
positions.requires_grad = True

X = _compute_power_spectrum(types, positions, cell)
X = _compute_power_spectrum(types, positions, cell, pbc)
weights = torch.rand((X.shape[-1], 1), requires_grad=True, dtype=torch.float64)

def compute(weights):
X = _compute_power_spectrum(types, positions, cell)
X = _compute_power_spectrum(types, positions, cell, pbc)
A = X @ weights

return torch.autograd.grad(
Expand Down Expand Up @@ -184,14 +186,14 @@ def compute(weights):


def test_power_spectrum_cell_grad_grad():
types, positions, cell = _create_random_system(n_atoms=75, cell_size=5.0)
types, positions, cell, pbc = _create_random_system(n_atoms=75, cell_size=5.0)
cell.requires_grad = True

X = _compute_power_spectrum(types, positions, cell)
X = _compute_power_spectrum(types, positions, cell, pbc)
weights = torch.rand((X.shape[-1], 1), requires_grad=True, dtype=torch.float64)

def compute(weights):
X = _compute_power_spectrum(types, positions, cell)
X = _compute_power_spectrum(types, positions, cell, pbc)
A = X @ weights

return torch.autograd.grad(
Expand Down Expand Up @@ -244,7 +246,7 @@ def test_different_device_dtype():
options.append((torch.device("cuda:0"), torch.float64))

for device, dtype in options:
types, positions, cell = _create_random_system(n_atoms=10, cell_size=3.0)
types, positions, cell, pbc = _create_random_system(n_atoms=10, cell_size=3.0)
positions = positions.to(dtype=dtype, device=device, copy=True)
positions.requires_grad = True
assert positions.grad is None
Expand All @@ -254,11 +256,12 @@ def test_different_device_dtype():
assert cell.grad is None

types = types.to(device=device, copy=True)
pbc = pbc.to(device=device, copy=True)

with warnings.catch_warnings():
warnings.filterwarnings("ignore")

X = _compute_power_spectrum(types, positions, cell)
X = _compute_power_spectrum(types, positions, cell, pbc)

assert X.dtype == dtype
assert X.device == device
Expand Down
2 changes: 1 addition & 1 deletion python/rascaline-torch/tests/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def system():
types=torch.tensor([1, 1, 8, 8]),
positions=torch.tensor([[0.0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]]),
cell=torch.tensor([[10.0, 0, 0], [0, 10, 0], [0, 0, 10]]),
pbc=torch.tensor([True, True, True]),
)


Expand Down Expand Up @@ -188,7 +189,6 @@ def test_different_device_dtype_errors(system):
custom_device = torch.device("cuda:0")

if custom_device is not None:

device_system = system.to(device=custom_device)

torch.set_warn_always(True)
Expand Down
3 changes: 2 additions & 1 deletion python/rascaline-torch/tests/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,10 @@ def test_export_as_metatensor_model(tmpdir):
model = Model(types=[1, 6, 8])
model.eval()

energy_output = ModelOutput()
energy_output = ModelOutput(quantity="energy", unit="eV")
capabilities = ModelCapabilities(
supported_devices=["cpu"],
length_unit="A",
interaction_range=HYPERS["cutoff"],
atomic_types=[1, 6, 8],
dtype="float64",
Expand Down
7 changes: 3 additions & 4 deletions python/rascaline-torch/tests/utils/cg_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,10 @@ def system():
[2.56633400, 2.50000000, 2.50370100],
[1.97361700, 1.73067300, 2.47063400],
[1.97361700, 3.26932700, 2.47063400],
],
),
cell=torch.zeros(
(3, 3),
]
),
cell=torch.zeros((3, 3)),
pbc=torch.tensor([False, False, False]),
)


Expand Down
1 change: 1 addition & 0 deletions python/rascaline-torch/tests/utils/density_correlations.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def system():
]
),
cell=torch.zeros((3, 3)),
pbc=torch.tensor([False, False, False]),
)


Expand Down
1 change: 1 addition & 0 deletions python/rascaline-torch/tests/utils/power_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def system():
types=torch.tensor([1, 1, 8, 8]),
positions=torch.tensor([[0.0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]]),
cell=torch.tensor([[10.0, 0, 0], [0, 10, 0], [0, 0, 10]]),
pbc=torch.tensor([True, True, True]),
)


Expand Down
4 changes: 2 additions & 2 deletions rascaline-c-api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ endif()
# ============================================================================ #
# Setup metatensor

set(METATENSOR_FETCH_VERSION "0.1.10")
set(METATENSOR_FETCH_VERSION "0.1.11")
set(METATENSOR_REQUIRED_VERSION "0.1")
if (RASCALINE_FETCH_METATENSOR)
message(STATUS "Fetching metatensor @ ${METATENSOR_FETCH_VERSION} from github")
Expand All @@ -232,7 +232,7 @@ if (RASCALINE_FETCH_METATENSOR)
FetchContent_Declare(
metatensor
URL ${URL_ROOT}/metatensor-core-v${METATENSOR_FETCH_VERSION}/metatensor-core-cxx-${METATENSOR_FETCH_VERSION}.tar.gz
URL_HASH SHA256=3ec0775da67bb0eb3246b81770426e612f83b6591442a39eb17aad6969b5f9d9
URL_HASH SHA256=b79435dbeb59a95361b4a5881574dfde4a4ed13f12b05a225df161dcaa7ea61e
)

if (CMAKE_VERSION VERSION_GREATER 3.18)
Expand Down
6 changes: 3 additions & 3 deletions rascaline-torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ find_package(Torch 1.12 REQUIRED)
# ============================================================================ #
# Setup metatensor_torch

set(METATENSOR_FETCH_VERSION "0.5.5")
set(REQUIRED_METATENSOR_TORCH_VERSION "0.5")
set(METATENSOR_FETCH_VERSION "0.6.0")
set(REQUIRED_METATENSOR_TORCH_VERSION "0.6")
if (RASCALINE_TORCH_FETCH_METATENSOR_TORCH)
message(STATUS "Fetching metatensor-torch @ ${METATENSOR_FETCH_VERSION} from github")

Expand All @@ -68,7 +68,7 @@ if (RASCALINE_TORCH_FETCH_METATENSOR_TORCH)
FetchContent_Declare(
metatensor_torch
URL ${URL_ROOT}/metatensor-torch-v${METATENSOR_FETCH_VERSION}/metatensor-torch-cxx-${METATENSOR_FETCH_VERSION}.tar.gz
URL_HASH SHA256=dac306ab59ac8b59167827405f468397dbf0d4a69988fce7b9f4285f2816a57c
URL_HASH SHA256=f050743662ece38948b2087dd025d60110645716840dbfc5370c059e1275d0cf
)

if (CMAKE_VERSION VERSION_GREATER 3.18)
Expand Down
21 changes: 12 additions & 9 deletions rascaline-torch/src/system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ SystemAdapter::SystemAdapter(metatensor_torch::System system): system_(std::move
for (const auto& options: system_->known_neighbor_lists()) {
for (const auto& requestor: options->requestors()) {
if (requestor == "rascaline") {
auto cutoff = options->cutoff();
auto neighbors = system_->get_neighbor_list(options);
auto samples_values = neighbors->samples()->values().to(torch::kCPU).contiguous();
auto samples = samples_values.accessor<int32_t, 2>();
Expand All @@ -24,7 +25,6 @@ SystemAdapter::SystemAdapter(metatensor_torch::System system): system_(std::move
auto n_pairs = samples.size(0);

auto pairs = std::vector<rascal_pair_t>();
pairs.reserve(static_cast<size_t>(n_pairs));
for (int64_t i=0; i<n_pairs; i++) {
auto x = distances[i][0];
auto y = distances[i][1];
Expand All @@ -38,18 +38,21 @@ SystemAdapter::SystemAdapter(metatensor_torch::System system): system_(std::move
assert(pair.second < this->size());

pair.distance = std::sqrt(x*x + y*y + z*z);
pair.vector[0] = x;
pair.vector[1] = y;
pair.vector[2] = z;

pair.cell_shift_indices[0] = samples[i][2];
pair.cell_shift_indices[1] = samples[i][3];
pair.cell_shift_indices[2] = samples[i][4];
if (pair.distance < cutoff) {
pair.vector[0] = x;
pair.vector[1] = y;
pair.vector[2] = z;

pairs.emplace_back(pair);
pair.cell_shift_indices[0] = samples[i][2];
pair.cell_shift_indices[1] = samples[i][3];
pair.cell_shift_indices[2] = samples[i][4];

pairs.emplace_back(pair);
}
}

this->set_precomputed_pairs(options->cutoff(), std::move(pairs));
this->set_precomputed_pairs(cutoff, std::move(pairs));
continue;
}
}
Expand Down
4 changes: 3 additions & 1 deletion rascaline-torch/tests/calculator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,5 +299,7 @@ metatensor_torch::System test_system(bool positions_grad, bool cell_grad) {
auto cell = 10 * torch::eye(3);
cell.requires_grad_(cell_grad);

return torch::make_intrusive<metatensor_torch::SystemHolder>(types, positions, cell);
auto pbc = torch::ones(3, torch::TensorOptions().dtype(torch::kBool));

return torch::make_intrusive<metatensor_torch::SystemHolder>(types, positions, cell, pbc);
}
3 changes: 2 additions & 1 deletion rascaline-torch/tests/cmake-project/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ int main() {
auto system = torch::make_intrusive<metatensor_torch::SystemHolder>(
torch::zeros({5}, torch::kI32),
torch::rand({5, 3}, torch::kF64),
torch::zeros({3, 3}, torch::kF64)
torch::zeros({3, 3}, torch::kF64),
torch::zeros({3}, torch::kBool)
);

const auto* HYPERS_JSON = R"({
Expand Down
5 changes: 2 additions & 3 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ metatensor-core-requirement =
metatensor-core >=0.1.0,<0.2.0

metatensor-torch-requirement =
metatensor-torch >=0.5.0,<0.6.0
metatensor-torch >=0.6.0,<0.7.0

build-single-wheel = --no-deps --no-build-isolation --check-build-dependencies
warning_options =
Expand Down Expand Up @@ -77,8 +77,7 @@ deps =
torch
pyscf;platform_system!="Windows"
wigners
# TODO: add mops once it becomes stable enough (and potentially supports windows)
#mops@git+https://github.com/lab-cosmo/mops ; platform_system!="Windows"

commands =
pytest {[testenv]test_options} {posargs}

Expand Down

0 comments on commit dd4e110

Please sign in to comment.