Skip to content

Commit

Permalink
Update to metatensor-core v0.1.10
Browse files Browse the repository at this point in the history
And use Labels::select where relevant
  • Loading branch information
Luthaf committed Sep 17, 2024
1 parent 71cb929 commit 064daac
Show file tree
Hide file tree
Showing 11 changed files with 107 additions and 75 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
matrix:
include:
- os: ubuntu-22.04
python-version: "3.8"
python-version: "3.9"
- os: ubuntu-22.04
python-version: "3.12"
- os: macos-14
Expand Down Expand Up @@ -68,7 +68,7 @@ jobs:
name: Python ${{ matrix.python-version }} / check build
strategy:
matrix:
python-version: ['3.8', '3.12']
python-version: ['3.9', '3.12']
os: [ubuntu-22.04]
steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/torch-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
include:
- os: ubuntu-22.04
torch-version: 1.12.*
python-version: "3.8"
python-version: "3.9"
cargo-test-flags: --release

- os: ubuntu-22.04
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "rascaline"
dynamic = ["version", "authors", "optional-dependencies"]
requires-python = ">=3.8"
requires-python = ">=3.9"

readme = "README.rst"
license = {text = "BSD-3-Clause"}
Expand Down
2 changes: 1 addition & 1 deletion python/rascaline-torch/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "rascaline-torch"
dynamic = ["version", "authors", "dependencies"]
requires-python = ">=3.8"
requires-python = ">=3.9"

readme = "README.rst"
license = {text = "BSD-3-Clause"}
Expand Down
61 changes: 47 additions & 14 deletions python/rascaline/tests/calculators/keys_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from metatensor import Labels, TensorBlock, TensorMap

from rascaline import RascalError
from rascaline.calculators import DummyCalculator
from rascaline.calculators import DummyCalculator, SphericalExpansion

from ..test_systems import SystemForTests

Expand Down Expand Up @@ -38,17 +38,54 @@ def test_selection_existing():
system = SystemForTests()
calculator = DummyCalculator(cutoff=3.2, delta=2, name="")

# no selection
descriptor = calculator.compute(system, selected_keys=None)
assert len(descriptor.keys) == 2
assert descriptor.keys.values.tolist() == [[1], [8]]

# Manually select the keys
selected_keys = Labels(
names=["center_type"],
values=np.array([[1]], dtype=np.int32),
)
descriptor = calculator.compute(
system, use_native_system=False, selected_keys=selected_keys
)
descriptor = calculator.compute(system, selected_keys=selected_keys)

assert len(descriptor.keys) == 1
assert tuple(descriptor.keys[0]) == (1,)
assert descriptor.keys[0].values.tolist() == [1]


def test_selection_partial():
system = SystemForTests()
calculator = SphericalExpansion(
cutoff=2.5,
max_radial=1,
max_angular=1,
atomic_gaussian_width=0.2,
radial_basis={"Gto": {}},
cutoff_function={"ShiftedCosine": {"width": 0.5}},
center_atom_weight=1.0,
)

# Manually select the keys
selected_keys = Labels(
names=["center_type"],
values=np.array([[1]], dtype=np.int32),
)
descriptor = calculator.compute(system, selected_keys=selected_keys)

assert len(descriptor.keys) == 6
assert descriptor.keys.names == [
"o3_lambda",
"o3_sigma",
"center_type",
"neighbor_type",
]
assert descriptor.keys.values.tolist() == [
[0, 1, 1, 1],
[1, 1, 1, 1],
[0, 1, 1, 8],
[1, 1, 1, 8],
]


def test_select_key_not_in_systems():
Expand All @@ -60,9 +97,7 @@ def test_select_key_not_in_systems():
names=["center_type"],
values=np.array([[4]], dtype=np.int32),
)
descriptor = calculator.compute(
system, use_native_system=False, selected_keys=selected_keys
)
descriptor = calculator.compute(system, selected_keys=selected_keys)

C_block = descriptor.block(center_type=4)
assert C_block.values.shape == (0, 2)
Expand Down Expand Up @@ -97,7 +132,6 @@ def test_predefined_selection():

descriptor = calculator.compute(
system,
use_native_system=False,
selected_properties=selected_properties,
selected_keys=selected_keys,
)
Expand All @@ -119,11 +153,11 @@ def test_name_errors():
)

message = (
"invalid parameter: names for the keys of the calculator "
"\\[center_type\\] and selected keys \\[bad_name\\] do not match"
"invalid parameter: 'bad_name' in keys selection is not "
"part of the keys of this calculator"
)
with pytest.raises(RascalError, match=message):
calculator.compute(system, use_native_system=False, selected_keys=selected_keys)
calculator.compute(system, selected_keys=selected_keys)


def test_key_errors():
Expand All @@ -137,7 +171,7 @@ def test_key_errors():

message = "invalid parameter: selected keys can not be empty"
with pytest.raises(RascalError, match=message):
calculator.compute(system, use_native_system=False, selected_keys=selected_keys)
calculator.compute(system, selected_keys=selected_keys)

# in the case where both selected_properties/selected_samples and
# selected_keys are given, the selected keys must be in the keys of the
Expand Down Expand Up @@ -171,7 +205,6 @@ def test_key_errors():
with pytest.raises(RascalError, match=message):
calculator.compute(
system,
use_native_system=False,
selected_properties=selected_properties,
selected_keys=selected_keys,
)
2 changes: 1 addition & 1 deletion python/rascaline/tests/calculators/properties_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_errors():

message = (
"invalid parameter: 'bad_name' in properties selection is not "
"one of the properties of this calculator"
"part of the properties of this calculator"
)
with pytest.raises(RascalError, match=message):
calculator.compute(
Expand Down
2 changes: 1 addition & 1 deletion python/rascaline/tests/calculators/sample_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_errors():
)

message = (
"invalid parameter: 'bad_name' in samples selection is not one "
"invalid parameter: 'bad_name' in samples selection is not part "
"of the samples of this calculator"
)
with pytest.raises(RascalError, match=message):
Expand Down
4 changes: 2 additions & 2 deletions rascaline-c-api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ endif()
# ============================================================================ #
# Setup metatensor

set(METATENSOR_FETCH_VERSION "0.1.7")
set(METATENSOR_FETCH_VERSION "0.1.10")
set(METATENSOR_REQUIRED_VERSION "0.1")
if (RASCALINE_FETCH_METATENSOR)
message(STATUS "Fetching metatensor @ ${METATENSOR_FETCH_VERSION} from github")
Expand All @@ -232,7 +232,7 @@ if (RASCALINE_FETCH_METATENSOR)
FetchContent_Declare(
metatensor
URL ${URL_ROOT}/metatensor-core-v${METATENSOR_FETCH_VERSION}/metatensor-core-cxx-${METATENSOR_FETCH_VERSION}.tar.gz
URL_HASH SHA256=005c39aefdd5aaf8a7596b78ac01688976070795c4db21e20e6b8db4f2421e97
URL_HASH SHA256=3ec0775da67bb0eb3246b81770426e612f83b6591442a39eb17aad6969b5f9d9
)

if (CMAKE_VERSION VERSION_GREATER 3.18)
Expand Down
4 changes: 2 additions & 2 deletions rascaline-torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ find_package(Torch 1.12 REQUIRED)
# ============================================================================ #
# Setup metatensor_torch

set(METATENSOR_FETCH_VERSION "0.5.0")
set(METATENSOR_FETCH_VERSION "0.5.5")
set(REQUIRED_METATENSOR_TORCH_VERSION "0.5")
if (RASCALINE_TORCH_FETCH_METATENSOR_TORCH)
message(STATUS "Fetching metatensor-torch @ ${METATENSOR_FETCH_VERSION} from github")
Expand All @@ -68,7 +68,7 @@ if (RASCALINE_TORCH_FETCH_METATENSOR_TORCH)
FetchContent_Declare(
metatensor_torch
URL ${URL_ROOT}/metatensor-torch-v${METATENSOR_FETCH_VERSION}/metatensor-torch-cxx-${METATENSOR_FETCH_VERSION}.tar.gz
URL_HASH SHA256=904cf858d8f98b67b948e8a453d8a6da56111e022050d6c8c3d32a9a2cc83464
URL_HASH SHA256=dac306ab59ac8b59167827405f468397dbf0d4a69988fce7b9f4285f2816a57c
)

if (CMAKE_VERSION VERSION_GREATER 3.18)
Expand Down
2 changes: 1 addition & 1 deletion rascaline/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ name = "soap-power-spectrum"
harness = false

[dependencies]
metatensor = {version = "0.1", features = ["rayon"]}
metatensor = {version = "0.1.6", features = ["rayon"]}

ndarray = {version = "0.15", features = ["approx-0_5", "rayon", "serde"]}
num-traits = "0.2"
Expand Down
97 changes: 48 additions & 49 deletions rascaline/src/calculator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use std::collections::BTreeMap;
use std::convert::TryFrom;

use log::warn;
use metatensor::c_api::MTS_INVALID_PARAMETER_ERROR;
use once_cell::sync::Lazy;

use metatensor::{LabelValue, Labels, LabelsBuilder};
use metatensor::{Labels, LabelsBuilder};
use metatensor::{TensorBlockRef, TensorBlock, TensorMap};
use ndarray::ArrayD;

Expand Down Expand Up @@ -43,6 +44,30 @@ pub enum LabelsSelection<'a> {
Predefined(&'a TensorMap),
}

fn map_selection_error<'a>(
default_names: &'a [&str],
selection_names: &'a [&str],
label_kind: &'a str
) -> impl FnOnce(metatensor::Error) -> Error + 'a{
return move |err| {
match err.code {
Some(MTS_INVALID_PARAMETER_ERROR) => {
for name in selection_names {
if !default_names.contains(name) {
return Error::InvalidParameter(format!(
"'{}' in {} selection is not part of the {} of this calculator",
name, label_kind, label_kind
));
}
}
// it was some other error, bubble it up
Error::from(err)
}
_ => Error::from(err)
}
};
}

impl<'a> LabelsSelection<'a> {
fn select<'call, F, G, H>(
&self,
Expand All @@ -69,45 +94,17 @@ impl<'a> LabelsSelection<'a> {
let default_names = get_default_names();

let mut results = Vec::new();
if selection.names() == default_names {
for labels in default_labels {
let mut builder = LabelsBuilder::new(default_names.clone());
for entry in selection {
if labels.contains(entry) {
builder.add(entry);
}
}
results.push(builder.finish());
}
} else {
let mut dimensions_to_match = Vec::new();
for variable in selection.names() {
let i = match default_names.iter().position(|&v| v == variable) {
Some(index) => index,
None => {
return Err(Error::InvalidParameter(format!(
"'{}' in {} selection is not one of the {} of this calculator",
variable, label_kind, label_kind
)))
}
};
dimensions_to_match.push(i);
}
for labels in default_labels {
let mut builder = LabelsBuilder::new(default_names.clone());

let mut candidate = vec![LabelValue::new(0); dimensions_to_match.len()];
for labels in default_labels {
let mut builder = LabelsBuilder::new(default_names.clone());
for entry in &labels {
for (i, &v) in dimensions_to_match.iter().enumerate() {
candidate[i] = entry[v];
}

if selection.contains(&candidate) {
builder.add(entry);
}
}
results.push(builder.finish());
// better error message in case of un-matched names
let matches = labels.select(selection)
.map_err(map_selection_error(&default_names, &selection.names(), label_kind))?;

for entry in matches {
builder.add(&labels[entry as usize]);
}
results.push(builder.finish());
}

return Ok(results);
Expand Down Expand Up @@ -292,18 +289,20 @@ impl Calculator {
let default_keys = self.implementation.keys(systems)?;

let keys = match options.selected_keys {
Some(keys) if keys.is_empty() => {
return Err(Error::InvalidParameter("selected keys can not be empty".into()));
}
Some(keys) => {
if default_keys.names() == keys.names() {
keys.clone()

Some(selection) => {
if selection.is_empty() {
return Err(Error::InvalidParameter("selected keys can not be empty".into()));
} else if default_keys.names() == selection.names() {
selection.clone()
} else {
return Err(Error::InvalidParameter(format!(
"names for the keys of the calculator [{}] and selected keys [{}] do not match",
default_keys.names().join(", "),
keys.names().join(", "))
));
let mut builder = LabelsBuilder::new(default_keys.names());
let matches = default_keys.select(selection)
.map_err(map_selection_error(&default_keys.names(), &selection.names(), "keys"))?;
for entry in matches {
builder.add(&default_keys[entry as usize]);
}
builder.finish()
}
}
None => default_keys,
Expand Down

0 comments on commit 064daac

Please sign in to comment.