Skip to content

Commit

Permalink
Handle the case where selections create an empty TensorMap
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Oct 10, 2024
1 parent fb7f363 commit fdbe962
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 2 deletions.
4 changes: 3 additions & 1 deletion rascaline/src/calculator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,9 @@ impl Calculator {

let mut tensor = self.prepare(systems, options)?;

self.implementation.compute(systems, &mut tensor)?;
if tensor.keys().count() > 0 {
self.implementation.compute(systems, &mut tensor)?;
}

return Ok(tensor);
}
Expand Down
2 changes: 2 additions & 0 deletions rascaline/src/calculators/soap/power_spectrum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,8 @@ impl CalculatorBase for SoapPowerSpectrum {
#[time_graph::instrument(name = "SoapPowerSpectrum::compute")]
#[allow(clippy::too_many_lines)]
fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
assert!(descriptor.keys().count() > 0);

let mut gradients = Vec::new();
if descriptor.block_by_id(0).gradient("positions").is_some() {
gradients.push("positions");
Expand Down
2 changes: 2 additions & 0 deletions rascaline/src/calculators/soap/radial_spectrum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ impl CalculatorBase for SoapRadialSpectrum {
#[time_graph::instrument(name = "SoapRadialSpectrum::compute")]
fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
assert_eq!(descriptor.keys().names(), ["center_type", "neighbor_type"]);
assert!(descriptor.keys().count() > 0);

let mut gradients = Vec::new();
if descriptor.block_by_id(0).gradient("positions").is_some() {
gradients.push("positions");
Expand Down
1 change: 1 addition & 0 deletions rascaline/src/calculators/soap/spherical_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,7 @@ impl CalculatorBase for SphericalExpansion {
#[time_graph::instrument(name = "SphericalExpansion::compute")]
fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
assert_eq!(descriptor.keys().names(), ["o3_lambda", "o3_sigma", "center_type", "neighbor_type"]);
assert!(descriptor.keys().count() > 0);

let do_gradients = GradientsOptions {
positions: descriptor.block_by_id(0).gradient("positions").is_some(),
Expand Down
1 change: 1 addition & 0 deletions rascaline/src/calculators/soap/spherical_expansion_pair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,7 @@ impl CalculatorBase for SphericalExpansionByPair {
#[time_graph::instrument(name = "SphericalExpansionByPair::compute")]
fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
assert_eq!(descriptor.keys().names(), ["o3_lambda", "o3_sigma", "first_atom_type", "second_atom_type"]);
assert!(descriptor.keys().count() > 0);

let do_gradients = GradientsOptions {
positions: descriptor.block_by_id(0).gradient("positions").is_some(),
Expand Down
1 change: 1 addition & 0 deletions rascaline/src/tutorials/moments/moments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ impl CalculatorBase for GeometricMoments {
// [compute]
fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
assert_eq!(descriptor.keys().names(), ["center_type", "neighbor_type"]);
assert!(descriptor.keys().count() > 0);

let do_positions_gradients = descriptor.block_by_id(0).gradient("positions").is_some();

Expand Down
3 changes: 2 additions & 1 deletion rascaline/src/tutorials/moments/s3_compute_5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ impl CalculatorBase for GeometricMoments {
fn compute(&mut self, systems: &mut [Box<dyn System>], descriptor: &mut TensorMap) -> Result<(), Error> {
// ...

// add this line
// add these lines
assert!(descriptor.keys().count() > 0);
let do_positions_gradients = descriptor.block_by_id(0).gradient("positions").is_some();

for (system_i, system) in systems.iter_mut().enumerate() {
Expand Down

0 comments on commit fdbe962

Please sign in to comment.