From 20def6c517213d55886fde6f759b66a6d9173bf5 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Mon, 14 Oct 2024 18:16:58 +0200 Subject: [PATCH] Implement an Explicit basis, containing a map from angular channels to radial basis --- docs/extensions/rascaline_json_schema.py | 31 ++++- docs/rascaline-json-schema/Cargo.toml | 2 +- docs/rascaline-json-schema/main.rs | 42 ++++--- docs/src/references/api/python/basis.rst | 3 + python/rascaline/rascaline/basis.py | 41 ++++++- python/rascaline/tests/calculators/hypers.py | 17 +-- rascaline/Cargo.toml | 2 +- .../calculators/lode/radial_integral/mod.rs | 94 ++++++++++----- .../lode/radial_integral/spline.rs | 5 +- .../calculators/lode/spherical_expansion.rs | 75 ++++++++++-- rascaline/src/calculators/shared/basis/mod.rs | 72 +++++++++++- rascaline/src/calculators/shared/mod.rs | 2 +- .../src/calculators/soap/power_spectrum.rs | 12 ++ .../calculators/soap/radial_integral/mod.rs | 109 ++++++++++++------ .../src/calculators/soap/radial_spectrum.rs | 13 ++- .../calculators/soap/spherical_expansion.rs | 60 +++++++++- .../soap/spherical_expansion_pair.rs | 72 ++++++++++-- 17 files changed, 521 insertions(+), 131 deletions(-) diff --git a/docs/extensions/rascaline_json_schema.py b/docs/extensions/rascaline_json_schema.py index 64a365c1b..19a4b4bd8 100644 --- a/docs/extensions/rascaline_json_schema.py +++ b/docs/extensions/rascaline_json_schema.py @@ -42,11 +42,11 @@ def run(self, *args, **kwargs): section.insert(1, schema_node) # add missing entries to self._definitions - for name in schema.get("definitions", {}).keys(): + for name in schema.get("$defs", {}).keys(): self._definition_used(name) for name in self._definitions.keys(): - definition = schema["definitions"][name] + definition = schema["$defs"][name] target, subsection = self._transform(definition, name) section += target @@ -112,7 +112,7 @@ def _json_schema_to_nodes( assert "allOf" not in schema ref = schema["$ref"] - assert ref.startswith("#/definitions/") + assert ref.startswith("#/$defs/") type_name = ref.split("/")[-1] self._definition_used(type_name) @@ -228,6 +228,10 @@ def _json_schema_to_nodes( field_list += body + additional = schema.get("additionalProperties") + if additional is not None: + pass + return field_list else: object_node = nodes.inline() @@ -261,6 +265,25 @@ def _json_schema_to_nodes( object_node += field + additional = schema.get("additionalProperties") + if isinstance(additional, dict): + # JSON Schema does not have a concept of key type being anything + # else than string. In rascaline, we annotate `HashMap` with a + # custom `x-key-type` to carry this information all the way to + # here + key_type = schema.get("x-key-type") + if key_type is None: + key_type = "string" + + field = nodes.inline() + field += nodes.Text("[key: ") + field += nodes.literal(text=key_type) + field += nodes.Text("]: ") + + field += self._json_schema_to_nodes(additional) + + object_node += field + object_node += nodes.Text("}") return object_node @@ -285,6 +308,8 @@ def _json_schema_to_nodes( if "enum" in schema: values = [f'"{v}"' for v in schema["enum"]] return nodes.literal(text=" | ".join(values)) + elif "const" in schema: + return nodes.Text('"' + schema["const"] + '"') else: return nodes.literal(text="string") diff --git a/docs/rascaline-json-schema/Cargo.toml b/docs/rascaline-json-schema/Cargo.toml index bff70a568..fbccd0ae7 100644 --- a/docs/rascaline-json-schema/Cargo.toml +++ b/docs/rascaline-json-schema/Cargo.toml @@ -14,5 +14,5 @@ test = false [dependencies] rascaline = {path = "../../rascaline"} -schemars = "0.8.6" +schemars = "=1.0.0-alpha.15" serde_json = "1" diff --git a/docs/rascaline-json-schema/main.rs b/docs/rascaline-json-schema/main.rs index b82a02984..03d617243 100644 --- a/docs/rascaline-json-schema/main.rs +++ b/docs/rascaline-json-schema/main.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use schemars::schema::RootSchema; +use schemars::Schema; use rascaline::calculators::AtomicComposition; use rascaline::calculators::SortedDistances; @@ -45,30 +45,40 @@ struct RenameRefInSchema { in_docs: &'static str, } -impl schemars::visit::Visitor for RenameRefInSchema { - fn visit_schema_object(&mut self, schema: &mut schemars::schema::SchemaObject) { - schemars::visit::visit_schema_object(self, schema); - - let in_code_reference = format!("#/definitions/{}", self.in_code); - - if let Some(reference) = &schema.reference { - if reference == &in_code_reference { - schema.reference = Some(format!("#/definitions/{}", self.in_docs)); +impl schemars::transform::Transform for RenameRefInSchema { + fn transform(&mut self, schema: &mut Schema) { + let in_code_reference = format!("#/$defs/{}", self.in_code); + if let Some(schema_object) = schema.as_object_mut() { + if let Some(reference) = schema_object.get_mut("$ref") { + if reference == &in_code_reference { + *reference = format!("#/$defs/{}", self.in_docs).into(); + } } } + schemars::transform::transform_subschemas(self, schema); } } -fn save_schema(name: &str, mut schema: RootSchema) { - for transform in REFS_TO_RENAME { - if let Some(value) = schema.definitions.remove(transform.in_code) { - assert!(!schema.definitions.contains_key(transform.in_docs)); - schema.definitions.insert(transform.in_docs.into(), value); +fn save_schema(name: &str, mut schema: Schema) { + let schema_object = schema.as_object_mut().expect("schema should be an object"); - schemars::visit::visit_root_schema(&mut transform.clone(), &mut schema); + // rename some of the autogenerate names. + // Step 1: rename the definitions + for transform in REFS_TO_RENAME { + if let Some(definitions) = schema_object.get_mut("$defs") { + let definitions = definitions.as_object_mut().expect("$defs should be an object"); + if let Some(value) = definitions.remove(transform.in_code) { + assert!(!definitions.contains_key(transform.in_docs)); + definitions.insert(transform.in_docs.into(), value); + } } } + // Step 2: rename the references to these definitions + for transform in REFS_TO_RENAME { + schemars::transform::transform_subschemas(&mut transform.clone(), &mut schema); + } + let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); path.pop(); path.push("build"); diff --git a/docs/src/references/api/python/basis.rst b/docs/src/references/api/python/basis.rst index 7c670bcbf..2a1fef155 100644 --- a/docs/src/references/api/python/basis.rst +++ b/docs/src/references/api/python/basis.rst @@ -4,6 +4,9 @@ Basis functions .. autoclass:: rascaline.basis.TensorProduct :show-inheritance: +.. autoclass:: rascaline.basis.Explicit + :show-inheritance: + .. autoclass:: rascaline.basis.ExpansionBasis :members: diff --git a/python/rascaline/rascaline/basis.py b/python/rascaline/rascaline/basis.py index 766fc35bd..7f05d9836 100644 --- a/python/rascaline/rascaline/basis.py +++ b/python/rascaline/rascaline/basis.py @@ -1,5 +1,5 @@ import abc -from typing import Optional +from typing import Optional, Dict import numpy as np @@ -355,7 +355,7 @@ def Jn_zeros(angular_size: int, radial_size: int) -> np.ndarray: ######################################################################################## -class ExpansionBasis(metaclass=abc.ABCMeta): +class ExpansionBasis: """ Base class representing a set of basis functions used by spherical expansions. @@ -420,3 +420,40 @@ def get_hypers(self): "radial": self.radial, "spline_accuracy": self.spline_accuracy, } + + +class Explicit(ExpansionBasis): + r""" + An expansion basis where combinations of radial and angular functions is picked + explicitly. + + The angular basis functions are still spherical harmonics, but only the degrees + included as keys in ``by_angular`` will be part of the output. Each of these angular + basis function can then be associated with a set of different radial basis function, + potentially of different sizes. + """ + + def __init__( + self, + *, + by_angular: Dict[int, RadialBasis], + spline_accuracy: Optional[float] = 1e-8, + ): + self.by_angular = by_angular + + if spline_accuracy is None: + self.spline_accuracy = None + else: + self.spline_accuracy = float(spline_accuracy) + assert self.spline_accuracy > 0 + + for angular, radial in self.by_angular.items(): + assert angular >= 0 + assert isinstance(radial, RadialBasis) + + def get_hypers(self): + return { + "type": "Explicit", + "by_angular": self.by_angular, + "spline_accuracy": self.spline_accuracy, + } diff --git a/python/rascaline/tests/calculators/hypers.py b/python/rascaline/tests/calculators/hypers.py index 68a05386c..0dd4cf204 100644 --- a/python/rascaline/tests/calculators/hypers.py +++ b/python/rascaline/tests/calculators/hypers.py @@ -97,9 +97,11 @@ def test_hypers_classes(): }, }, "basis": { - "type": "TensorProduct", - "max_angular": 5, - "radial": {"type": "Gto", "max_radial": 5}, + "type": "Explicit", + "by_angular": { + 3: {"type": "Gto", "max_radial": 5}, + 5: {"type": "Gto", "max_radial": 2}, + }, }, } @@ -115,9 +117,11 @@ def test_hypers_classes(): scaling=rascaline.density.Willatt2018(exponent=3, rate=2.2, scale=1.1), center_atom_weight=0.3, ), - basis=rascaline.basis.TensorProduct( - max_angular=5, - radial=rascaline.basis.Gto(max_radial=5), + basis=rascaline.basis.Explicit( + by_angular={ + 3: rascaline.basis.Gto(max_radial=5), + 5: rascaline.basis.Gto(max_radial=2), + } ), ) @@ -126,7 +130,6 @@ def test_hypers_classes(): def test_hypers_custom_classes_errors(): - class MyCustomSmoothing(rascaline.cutoff.SmoothingFunction): def compute(self, cutoff, positions, derivative): pass diff --git a/rascaline/Cargo.toml b/rascaline/Cargo.toml index d31a68504..45596d82f 100644 --- a/rascaline/Cargo.toml +++ b/rascaline/Cargo.toml @@ -46,7 +46,7 @@ time-graph = "0.3.0" serde = { version = "1", features = ["derive"] } serde_json = "1" -schemars = "0.8" +schemars = "=1.0.0-alpha.15" chemfiles = {version = "0.10", optional = true} diff --git a/rascaline/src/calculators/lode/radial_integral/mod.rs b/rascaline/src/calculators/lode/radial_integral/mod.rs index c4761242a..dc4d1ea5c 100644 --- a/rascaline/src/calculators/lode/radial_integral/mod.rs +++ b/rascaline/src/calculators/lode/radial_integral/mod.rs @@ -52,6 +52,45 @@ pub struct LodeRadialIntegralCache { } impl LodeRadialIntegralCache { + fn new( + o3_lambda: usize, + radial: &LodeRadialBasis, + density: DensityKind, + k_cutoff: f64, + spline_accuracy: Option, + ) -> Result { + let implementation = match radial { + LodeRadialBasis::Gto { .. } => { + let gto = LodeRadialIntegralGto::new(radial, o3_lambda)?; + + if let Some(accuracy) = spline_accuracy { + let do_center_contribution = o3_lambda == 0; + Box::new(LodeRadialIntegralSpline::with_accuracy( + gto, density, k_cutoff, accuracy, do_center_contribution + )?) + } else { + Box::new(gto) as Box + } + }, + LodeRadialBasis::Tabulated(ref tabulated) => { + Box::new(LodeRadialIntegralSpline::from_tabulated( + tabulated.clone(), + density, + )) as Box + } + }; + + let size = implementation.size(); + let values = Array1::from_elem(size, 0.0); + let gradients = Array1::from_elem(size, 0.0); + + return Ok(LodeRadialIntegralCache { + implementation, + values, + gradients, + }); + } + /// Run the calculation, the results are stored inside `self.values` and /// `self.gradients` pub fn compute(&mut self, k_norm: f64, do_gradients: bool) { @@ -91,43 +130,36 @@ impl LodeRadialIntegralCacheByAngular { SphericalExpansionBasis::TensorProduct(basis) => { let mut by_angular = BTreeMap::new(); for o3_lambda in 0..=basis.max_angular { - // We only support some specific radial basis - let implementation = match basis.radial { - LodeRadialBasis::Gto { .. } => { - let gto = LodeRadialIntegralGto::new(&basis.radial, o3_lambda)?; - - if let Some(accuracy) = basis.spline_accuracy { - let do_center_contribution = o3_lambda == 0; - Box::new(LodeRadialIntegralSpline::with_accuracy( - gto, density, k_cutoff, accuracy, do_center_contribution - )?) - } else { - Box::new(gto) as Box - } - }, - LodeRadialBasis::Tabulated(ref tabulated) => { - Box::new(LodeRadialIntegralSpline::from_tabulated( - tabulated.clone(), - density, - )) as Box - } - }; - - let size = implementation.size(); - let values = Array1::from_elem(size, 0.0); - let gradients = Array1::from_elem(size, 0.0); - - by_angular.insert(o3_lambda, LodeRadialIntegralCache { - implementation, - values, - gradients, - }); + let cache = LodeRadialIntegralCache::new( + o3_lambda, + &basis.radial, + density, + k_cutoff, + basis.spline_accuracy + )?; + by_angular.insert(o3_lambda, cache); } return Ok(LodeRadialIntegralCacheByAngular { by_angular }); } + SphericalExpansionBasis::Explicit(basis) => { + let mut by_angular = BTreeMap::new(); + for (&o3_lambda, radial) in &*basis.by_angular { + let cache = LodeRadialIntegralCache::new( + o3_lambda, + radial, + density, + k_cutoff, + basis.spline_accuracy + )?; + by_angular.insert(o3_lambda, cache); + } + return Ok(LodeRadialIntegralCacheByAngular { + by_angular + }); + } } } diff --git a/rascaline/src/calculators/lode/radial_integral/spline.rs b/rascaline/src/calculators/lode/radial_integral/spline.rs index 915a1b651..753e1c5f0 100644 --- a/rascaline/src/calculators/lode/radial_integral/spline.rs +++ b/rascaline/src/calculators/lode/radial_integral/spline.rs @@ -94,7 +94,10 @@ impl LodeRadialIntegral for LodeRadialIntegralSpline { } if self.center_contribution.is_none() { - return Err(Error::InvalidParameter("TODO".into())); + return Err(Error::InvalidParameter( + "`center_contribution` must be defined for the Tabulated radial \ + basis used with the L=0 angular channel".into() + )); } return Ok(self.center_contribution.clone().expect("just checked")); diff --git a/rascaline/src/calculators/lode/spherical_expansion.rs b/rascaline/src/calculators/lode/spherical_expansion.rs index c44e0bbc9..c5932ca4f 100644 --- a/rascaline/src/calculators/lode/spherical_expansion.rs +++ b/rascaline/src/calculators/lode/spherical_expansion.rs @@ -70,9 +70,9 @@ pub struct LodeSphericalExpansion { /// implementation + cached allocation to compute the radial integral radial_integral: ThreadLocal>, /// Cached allocations for the k-vector to nlm projection coefficients. - /// The vector contains different l values, and the Array is indexed by + /// The map contains different l values, and the Array is indexed by /// `m, n, k`. - k_vector_to_m_n: ThreadLocal>>>, + k_vector_to_m_n: ThreadLocal>>>, /// Cached allocation for everything that only depends on the k vector k_dependent_values: ThreadLocal>>, } @@ -241,9 +241,9 @@ impl LodeSphericalExpansion { }).borrow_mut(); let mut k_vector_to_m_n = self.k_vector_to_m_n.get_or(|| { - let mut k_vector_to_m_n = Vec::new(); - for _ in self.parameters.basis.angular_channels() { - k_vector_to_m_n.push(Array3::from_elem((0, 0, 0), 0.0)); + let mut k_vector_to_m_n = BTreeMap::new(); + for o3_lambda in self.parameters.basis.angular_channels() { + k_vector_to_m_n.insert(o3_lambda, Array3::from_elem((0, 0, 0), 0.0)); } return RefCell::new(k_vector_to_m_n); @@ -252,7 +252,7 @@ impl LodeSphericalExpansion { for o3_lambda in self.parameters.basis.angular_channels() { let radial_size = radial_integral.get(o3_lambda).expect("missing o3_lambda").size(); let shape = (2 * o3_lambda + 1, radial_size, k_vectors.len()); - resize_array3(&mut k_vector_to_m_n[o3_lambda], shape); + resize_array3(k_vector_to_m_n.get_mut(&o3_lambda).expect("missing o3_lambda"), shape); } for (ik, k_vector) in k_vectors.iter().enumerate() { @@ -265,10 +265,11 @@ impl LodeSphericalExpansion { let spherical_harmonics = spherical_harmonics.values.angular_slice(o3_lambda); let radial_integral = radial_integral.get(o3_lambda).expect("missing o3_lambda"); let radial_integral = &radial_integral.values; + let array = k_vector_to_m_n.get_mut(&o3_lambda).expect("missing o3_lambda"); for (m, sph_value) in spherical_harmonics.iter().enumerate() { for (n, ri_value) in radial_integral.iter().enumerate() { - k_vector_to_m_n[o3_lambda][[m, n, ik]] = ri_value * sph_value; + array[[m, n, ik]] = ri_value * sph_value; } } } @@ -569,6 +570,8 @@ impl CalculatorBase for LodeSphericalExpansion { } fn properties(&self, keys: &Labels) -> Vec { + assert_eq!(keys.names(), ["o3_lambda", "o3_sigma", "center_type", "neighbor_type"]); + match self.parameters.basis { SphericalExpansionBasis::TensorProduct(ref basis) => { let mut properties = LabelsBuilder::new(self.property_names()); @@ -578,6 +581,20 @@ impl CalculatorBase for LodeSphericalExpansion { return vec![properties.finish(); keys.count()]; } + SphericalExpansionBasis::Explicit(ref basis) => { + let mut result = Vec::new(); + for [o3_lambda, _, _, _] in keys.iter_fixed_size() { + let mut properties = LabelsBuilder::new(self.property_names()); + + let radial = basis.by_angular.get(&o3_lambda.usize()).expect("missing o3_lambda"); + for n in 0..radial.size() { + properties.add(&[n]); + } + + result.push(properties.finish()); + } + return result; + } } } @@ -700,7 +717,7 @@ impl CalculatorBase for LodeSphericalExpansion { sf_per_center_imag }; - let k_vector_to_m_n = &k_vector_to_m_n[o3_lambda]; + let k_vector_to_m_n = k_vector_to_m_n.get(&o3_lambda).expect("missing o3_lambda"); let data = block.data_mut(); let samples = &*data.samples; @@ -866,9 +883,10 @@ impl CalculatorBase for LodeSphericalExpansion { #[cfg(test)] mod tests { + use crate::calculators::shared::ExplicitBasis; use crate::Calculator; use crate::calculators::{CalculatorBase, DensityKind, LodeRadialBasis, TensorProductBasis}; - use crate::systems::test_utils::test_system; + use crate::systems::test_utils::{test_system, test_systems}; use Vector3D; use approx::assert_relative_eq; @@ -1062,4 +1080,43 @@ mod tests { max_relative=1e-4 ); } + + #[test] + fn explicit_basis() { + let mut by_angular = BTreeMap::new(); + by_angular.insert(1, LodeRadialBasis::Gto { max_radial: 5, radius: 5.5 }); + by_angular.insert(12, LodeRadialBasis::Gto { max_radial: 3, radius: 3.4 }); + + let mut calculator = Calculator::from(Box::new(LodeSphericalExpansion::new( + LodeSphericalExpansionParameters { + k_cutoff: None, + density: Density { + kind: DensityKind::SmearedPowerLaw { + smearing: 0.8, + exponent: 1, + }, + scaling: None, + center_atom_weight: 1.0, + }, + basis: SphericalExpansionBasis::Explicit(ExplicitBasis { + by_angular: by_angular.into(), + spline_accuracy: Some(1e-8), + }), + } + ).unwrap()) as Box); + + let mut systems = test_systems(&["water"]); + + let descriptor = calculator.compute(&mut systems, Default::default()).unwrap(); + + for (key, block) in &descriptor { + if key[0] == 1 { + assert_eq!(block.properties().count(), 6); + } else if key[0] == 12 { + assert_eq!(block.properties().count(), 4); + } else { + panic!("unexpected o3_lambda value"); + } + } + } } diff --git a/rascaline/src/calculators/shared/basis/mod.rs b/rascaline/src/calculators/shared/basis/mod.rs index 2a72d1d91..dab75e876 100644 --- a/rascaline/src/calculators/shared/basis/mod.rs +++ b/rascaline/src/calculators/shared/basis/mod.rs @@ -1,5 +1,7 @@ pub(crate) mod radial; +use std::collections::BTreeMap; + pub use self::radial::{SoapRadialBasis, LodeRadialBasis}; /// Possible Basis functions to use for the SOAP or LODE spherical expansion. @@ -11,9 +13,13 @@ pub use self::radial::{SoapRadialBasis, LodeRadialBasis}; #[serde(deny_unknown_fields)] #[serde(tag = "type")] pub enum SphericalExpansionBasis { - /// A Tensor product basis, combining all possible radial basis functions - /// with all possible angular basis functions. + /// This defines a tensor product basis, combining all possible radial basis + /// functions with all possible angular basis functions. TensorProduct(TensorProductBasis), + /// This defines an explicit basis, where only a specific subset of angular + /// basis can be used, and every angular basis can use a different radial + /// basis. + Explicit(ExplicitBasis), } impl SphericalExpansionBasis { @@ -22,10 +28,15 @@ impl SphericalExpansionBasis { SphericalExpansionBasis::TensorProduct(basis) => { return (0..=basis.max_angular).collect(); } + SphericalExpansionBasis::Explicit(basis) => { + return basis.by_angular.keys().copied().collect(); + } } } } +#[allow(clippy::unnecessary_wraps)] +fn serde_default_spline_accuracy() -> Option { Some(1e-8) } /// Information about "tensor product" spherical expansion basis functions #[derive(Debug, Clone)] @@ -47,5 +58,58 @@ pub struct TensorProductBasis { pub spline_accuracy: Option, } -#[allow(clippy::unnecessary_wraps)] -fn serde_default_spline_accuracy() -> Option { Some(1e-8) } +#[derive(Debug, Clone)] +#[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)] +#[serde(deny_unknown_fields)] +// work around https://github.com/serde-rs/serde/issues/1183 +#[serde(try_from = "BTreeMap")] +pub struct ByAngular(BTreeMap); + +impl std::ops::Deref for ByAngular { + type Target = BTreeMap; + + fn deref(&self) -> &Self::Target { + & self.0 + } +} + +impl TryFrom> for ByAngular { + type Error = ::Err; + + fn try_from(value: BTreeMap) -> Result { + let mut result = BTreeMap::new(); + for (angular, radial) in value { + let angular: usize = angular.parse()?; + result.insert(angular, radial); + } + Ok(ByAngular(result)) + } +} + +impl From> for ByAngular { + fn from(value: BTreeMap) -> Self { + ByAngular(value) + } +} + +/// Information about "explicit" spherical expansion basis functions +#[derive(Debug, Clone)] +#[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)] +#[serde(deny_unknown_fields)] +pub struct ExplicitBasis { + /// A map of radial basis to use for the specified angular channels. + /// + /// Only angular channels included in this map will be included in the + /// output. Different angular channels are allowed to use completely + /// different radial basis functions. + #[schemars(extend("x-key-type" = "integer"))] + pub by_angular: ByAngular, + /// Accuracy for splining the radial integral. Using splines is typically + /// faster than analytical implementations. If this is None, no splining is + /// done. + /// + /// The number of control points in the spline is automatically determined + /// to ensure the average absolute error is close to the requested accuracy. + #[serde(default = "serde_default_spline_accuracy")] + pub spline_accuracy: Option, +} diff --git a/rascaline/src/calculators/shared/mod.rs b/rascaline/src/calculators/shared/mod.rs index 8454252c8..5bfa59a91 100644 --- a/rascaline/src/calculators/shared/mod.rs +++ b/rascaline/src/calculators/shared/mod.rs @@ -6,7 +6,7 @@ mod density; pub use self::density::{Density, DensityKind, DensityScaling}; pub(crate) mod basis; -pub use self::basis::{SphericalExpansionBasis, TensorProductBasis}; +pub use self::basis::{SphericalExpansionBasis, TensorProductBasis, ExplicitBasis}; pub use self::basis::{SoapRadialBasis, LodeRadialBasis}; pub mod descriptors_by_systems; diff --git a/rascaline/src/calculators/soap/power_spectrum.rs b/rascaline/src/calculators/soap/power_spectrum.rs index 1371d0749..9e65e37c9 100644 --- a/rascaline/src/calculators/soap/power_spectrum.rs +++ b/rascaline/src/calculators/soap/power_spectrum.rs @@ -509,6 +509,18 @@ impl CalculatorBase for SoapPowerSpectrum { return vec![properties.finish(); keys.count()]; } + SphericalExpansionBasis::Explicit(ref basis) => { + let mut properties = LabelsBuilder::new(self.property_names()); + for (&l, radial) in &*basis.by_angular { + for n1 in 0..radial.size() { + for n2 in 0..radial.size() { + properties.add(&[l, n1, n2]); + } + } + } + + return vec![properties.finish(); keys.count()]; + }, } } diff --git a/rascaline/src/calculators/soap/radial_integral/mod.rs b/rascaline/src/calculators/soap/radial_integral/mod.rs index 5894c7b47..2c1492509 100644 --- a/rascaline/src/calculators/soap/radial_integral/mod.rs +++ b/rascaline/src/calculators/soap/radial_integral/mod.rs @@ -75,6 +75,53 @@ pub struct SoapRadialIntegralCache { } impl SoapRadialIntegralCache { + fn new( + o3_lambda: usize, + radial: &SoapRadialBasis, + density: DensityKind, + cutoff: f64, + spline_accuracy: Option, + ) -> Result { + // We only support some specific combinations of density and basis + let implementation = match (density, radial) { + // Gaussian density + GTO basis + (DensityKind::Gaussian {..}, &SoapRadialBasis::Gto { .. }) => { + let gto = SoapRadialIntegralGto::new(cutoff, density, radial, o3_lambda)?; + + if let Some(accuracy) = spline_accuracy { + Box::new(SoapRadialIntegralSpline::with_accuracy( + gto, cutoff, accuracy + )?) + } else { + Box::new(gto) as Box + } + }, + // Dirac density + tabulated basis (also used for + // tabulated radial integral with a different density) + (DensityKind::DiracDelta, SoapRadialBasis::Tabulated(tabulated)) => { + Box::new(SoapRadialIntegralSpline::from_tabulated( + tabulated.clone() + )) as Box + } + // Everything else is an error + _ => { + return Err(Error::InvalidParameter( + "this combination of basis and density is not supported in SOAP".into() + )) + } + }; + + let size = implementation.size(); + let values = Array1::from_elem(size, 0.0); + let gradients = Array1::from_elem(size, 0.0); + + return Ok(SoapRadialIntegralCache { + implementation, + values, + gradients, + }); + } + /// Run the calculation, the results are stored inside `self.values` and /// `self.gradients` pub fn compute(&mut self, distance: f64, do_gradients: bool) { @@ -104,48 +151,34 @@ impl SoapRadialIntegralCacheByAngular { SphericalExpansionBasis::TensorProduct(basis) => { let mut by_angular = BTreeMap::new(); for o3_lambda in 0..=basis.max_angular { - // We only support some specific combinations of density and basis - let implementation = match (density, &basis.radial) { - // Gaussian density + GTO basis - (DensityKind::Gaussian {..}, &SoapRadialBasis::Gto { .. }) => { - let gto = SoapRadialIntegralGto::new(cutoff, density, &basis.radial, o3_lambda)?; - - if let Some(accuracy) = basis.spline_accuracy { - Box::new(SoapRadialIntegralSpline::with_accuracy( - gto, cutoff, accuracy - )?) - } else { - Box::new(gto) as Box - } - }, - // Dirac density + tabulated basis (also used for - // tabulated radial integral with a different density) - (DensityKind::DiracDelta, SoapRadialBasis::Tabulated(tabulated)) => { - Box::new(SoapRadialIntegralSpline::from_tabulated( - tabulated.clone() - )) as Box - } - // Everything else is an error - _ => { - return Err(Error::InvalidParameter( - "this combination of basis and density is not supported in SOAP".into() - )) - } - }; - - let size = implementation.size(); - let values = Array1::from_elem(size, 0.0); - let gradients = Array1::from_elem(size, 0.0); - - by_angular.insert(o3_lambda, SoapRadialIntegralCache { - implementation, - values, - gradients, - }); + let cache = SoapRadialIntegralCache::new( + o3_lambda, + &basis.radial, + density, + cutoff, + basis.spline_accuracy + )?; + by_angular.insert(o3_lambda, cache); } return Ok(SoapRadialIntegralCacheByAngular { by_angular }); } + SphericalExpansionBasis::Explicit(basis) => { + let mut by_angular = BTreeMap::new(); + for (&o3_lambda, radial) in &*basis.by_angular { + let cache = SoapRadialIntegralCache::new( + o3_lambda, + radial, + density, + cutoff, + basis.spline_accuracy + )?; + by_angular.insert(o3_lambda, cache); + } + return Ok(SoapRadialIntegralCacheByAngular { + by_angular + }); + } } } diff --git a/rascaline/src/calculators/soap/radial_spectrum.rs b/rascaline/src/calculators/soap/radial_spectrum.rs index 09f83337c..c9b1ca510 100644 --- a/rascaline/src/calculators/soap/radial_spectrum.rs +++ b/rascaline/src/calculators/soap/radial_spectrum.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use metatensor::{EmptyArray, TensorBlock, TensorMap}; use metatensor::{LabelValue, Labels, LabelsBuilder}; @@ -10,7 +12,7 @@ use crate::calculators::shared::{ Density, SoapRadialBasis, SphericalExpansionBasis, - TensorProductBasis, + ExplicitBasis }; @@ -74,12 +76,15 @@ impl std::fmt::Debug for SoapRadialSpectrum { impl SoapRadialSpectrum { pub fn new(parameters: RadialSpectrumParameters) -> Result { + // radial spectrum only needs a single angular basis function + let mut by_angular = BTreeMap::new(); + by_angular.insert(0, parameters.basis.radial.clone()); + let expansion_parameters = SphericalExpansionParameters { cutoff: parameters.cutoff, density: parameters.density, - basis: SphericalExpansionBasis::TensorProduct(TensorProductBasis { - max_angular: 0, - radial: parameters.basis.radial.clone(), + basis: SphericalExpansionBasis::Explicit(ExplicitBasis { + by_angular: by_angular.into(), spline_accuracy: parameters.basis.spline_accuracy, }) }; diff --git a/rascaline/src/calculators/soap/spherical_expansion.rs b/rascaline/src/calculators/soap/spherical_expansion.rs index 8def42f58..71986c89a 100644 --- a/rascaline/src/calculators/soap/spherical_expansion.rs +++ b/rascaline/src/calculators/soap/spherical_expansion.rs @@ -32,9 +32,8 @@ pub struct SphericalExpansion { impl SphericalExpansion { /// Create a new `SphericalExpansion` calculator with the given parameters pub fn new(parameters: SphericalExpansionParameters) -> Result { - let m_1_pow_l = parameters.basis.angular_channels() - .into_iter() - .map(|l| f64::powi(-1.0, l as i32)) + let max_angular = parameters.basis.angular_channels().into_iter().max().expect("there should be at least one angular channel"); + let m_1_pow_l = (0..=max_angular).map(|l| f64::powi(-1.0, l as i32)) .collect::>(); return Ok(SphericalExpansion { @@ -132,6 +131,9 @@ impl SphericalExpansion { SphericalExpansionBasis::TensorProduct(ref basis) => { vec![basis.radial.size(); basis.max_angular + 1] }, + SphericalExpansionBasis::Explicit(ref basis) => { + basis.by_angular.values().map(|radial| radial.size()).collect() + }, }; let angular_channels = self.by_pair.parameters.basis.angular_channels(); @@ -772,6 +774,8 @@ impl CalculatorBase for SphericalExpansion { } fn properties(&self, keys: &Labels) -> Vec { + assert_eq!(keys.names(), ["o3_lambda", "o3_sigma", "center_type", "neighbor_type"]); + match self.by_pair.parameters.basis { SphericalExpansionBasis::TensorProduct(ref basis) => { let mut properties = LabelsBuilder::new(self.property_names()); @@ -781,6 +785,20 @@ impl CalculatorBase for SphericalExpansion { return vec![properties.finish(); keys.count()]; } + SphericalExpansionBasis::Explicit(ref basis) => { + let mut result = Vec::new(); + for [o3_lambda, _, _, _] in keys.iter_fixed_size() { + let mut properties = LabelsBuilder::new(self.property_names()); + + let radial = basis.by_angular.get(&o3_lambda.usize()).expect("missing o3_lambda"); + for n in 0..radial.size() { + properties.add(&[n]); + } + + result.push(properties.finish()); + } + return result; + } } } @@ -834,6 +852,8 @@ impl CalculatorBase for SphericalExpansion { #[cfg(test)] mod tests { + use std::collections::BTreeMap; + use ndarray::ArrayD; use metatensor::{Labels, TensorBlock, EmptyArray, LabelsBuilder, TensorMap}; @@ -843,7 +863,7 @@ mod tests { use super::{SphericalExpansion, SphericalExpansionParameters}; use crate::calculators::soap::{Cutoff, Smoothing}; - use crate::calculators::shared::{Density, DensityKind, DensityScaling}; + use crate::calculators::shared::{Density, DensityKind, DensityScaling, ExplicitBasis}; use crate::calculators::shared::{SoapRadialBasis, SphericalExpansionBasis, TensorProductBasis}; @@ -1053,4 +1073,36 @@ mod tests { let array = block.values.as_array(); assert_eq!(array.index_axis(ndarray::Axis(0), 0), ArrayD::from_elem(vec![1, 6], 0.0)); } + + #[test] + fn explicit_basis() { + let mut by_angular = BTreeMap::new(); + by_angular.insert(1, SoapRadialBasis::Gto { max_radial: 5 }); + by_angular.insert(12, SoapRadialBasis::Gto { max_radial: 3 }); + + let mut calculator = Calculator::from(Box::new(SphericalExpansion::new( + SphericalExpansionParameters { + basis: SphericalExpansionBasis::Explicit(ExplicitBasis { + by_angular: by_angular.into(), + spline_accuracy: None, + + }), + ..parameters() + } + ).unwrap()) as Box); + + let mut systems = test_systems(&["water"]); + + let descriptor = calculator.compute(&mut systems, Default::default()).unwrap(); + + for (key, block) in &descriptor { + if key[0] == 1 { + assert_eq!(block.properties().count(), 6); + } else if key[0] == 12 { + assert_eq!(block.properties().count(), 4); + } else { + panic!("unexpected o3_lambda value"); + } + } + } } diff --git a/rascaline/src/calculators/soap/spherical_expansion_pair.rs b/rascaline/src/calculators/soap/spherical_expansion_pair.rs index 1a81851e1..2139327db 100644 --- a/rascaline/src/calculators/soap/spherical_expansion_pair.rs +++ b/rascaline/src/calculators/soap/spherical_expansion_pair.rs @@ -168,9 +168,8 @@ impl SphericalExpansionByPair { pub fn new(mut parameters: SphericalExpansionParameters) -> Result { parameters.validate()?; - let m_1_pow_l = parameters.basis.angular_channels() - .into_iter() - .map(|l| f64::powi(-1.0, l as i32)) + let max_angular = parameters.basis.angular_channels().into_iter().max().expect("there should be at least one angular channel"); + let m_1_pow_l = (0..=max_angular).map(|l| f64::powi(-1.0, l as i32)) .collect::>(); Ok(SphericalExpansionByPair { @@ -353,10 +352,6 @@ impl SphericalExpansionByPair { RefCell::new(SphericalHarmonicsCache::new(max_angular)) }).borrow_mut(); - let radial_basis = match self.parameters.basis { - SphericalExpansionBasis::TensorProduct(ref basis) => &basis.radial, - }; - radial_integral.compute(distance, do_gradients.any()); spherical_harmonics.compute(direction, do_gradients.any()); @@ -364,6 +359,12 @@ impl SphericalExpansionByPair { let f_scaling_grad = self.scaling_functions_gradient(distance); for o3_lambda in self.parameters.basis.angular_channels() { + let radial_basis_size = match self.parameters.basis { + SphericalExpansionBasis::TensorProduct(ref basis) => basis.radial.size(), + SphericalExpansionBasis::Explicit(ref basis) => { + basis.by_angular.get(&o3_lambda).expect("missing o3_lambda").size() + }, + }; let spherical_harmonics_grad = [ spherical_harmonics.gradients[0].angular_slice(o3_lambda), @@ -396,7 +397,7 @@ impl SphericalExpansionByPair { let sph_grad_y = spherical_harmonics_grad[1][m]; let sph_grad_z = spherical_harmonics_grad[2][m]; - for n in 0..radial_basis.size() { + for n in 0..radial_basis_size { let ri_value = radial_integral[n]; let ri_grad = radial_integral_grad[n]; @@ -699,6 +700,8 @@ impl CalculatorBase for SphericalExpansionByPair { } fn properties(&self, keys: &Labels) -> Vec { + assert_eq!(keys.names(), ["o3_lambda", "o3_sigma", "first_atom_type", "second_atom_type"]); + match self.parameters.basis { SphericalExpansionBasis::TensorProduct(ref basis) => { let mut properties = LabelsBuilder::new(self.property_names()); @@ -708,6 +711,20 @@ impl CalculatorBase for SphericalExpansionByPair { return vec![properties.finish(); keys.count()]; } + SphericalExpansionBasis::Explicit(ref basis) => { + let mut result = Vec::new(); + for [o3_lambda, _, _, _] in keys.iter_fixed_size() { + let mut properties = LabelsBuilder::new(self.property_names()); + + let radial = basis.by_angular.get(&o3_lambda.usize()).expect("missing o3_lambda"); + for n in 0..radial.size() { + properties.add(&[n]); + } + + result.push(properties.finish()); + } + return result; + } } } @@ -730,6 +747,9 @@ impl CalculatorBase for SphericalExpansionByPair { SphericalExpansionBasis::TensorProduct(ref basis) => { vec![basis.radial.size(); basis.max_angular + 1] }, + SphericalExpansionBasis::Explicit(ref basis) => { + basis.by_angular.values().map(|radial| radial.size()).collect() + }, }; let mut contribution = PairContribution::new( @@ -823,6 +843,8 @@ impl CalculatorBase for SphericalExpansionByPair { #[cfg(test)] mod tests { + use std::collections::BTreeMap; + use metatensor::Labels; use ndarray::{s, Axis}; use approx::assert_ulps_eq; @@ -834,7 +856,7 @@ mod tests { use super::{SphericalExpansionByPair, SphericalExpansionParameters}; use crate::calculators::soap::{Cutoff, Smoothing}; - use crate::calculators::shared::{Density, DensityKind, DensityScaling}; + use crate::calculators::shared::{Density, DensityKind, DensityScaling, ExplicitBasis}; use crate::calculators::shared::{SoapRadialBasis, SphericalExpansionBasis, TensorProductBasis}; fn basis() -> TensorProductBasis { @@ -997,4 +1019,36 @@ mod tests { } } } + + #[test] + fn explicit_basis() { + let mut by_angular = BTreeMap::new(); + by_angular.insert(1, SoapRadialBasis::Gto { max_radial: 5 }); + by_angular.insert(12, SoapRadialBasis::Gto { max_radial: 3 }); + + let mut calculator = Calculator::from(Box::new(SphericalExpansionByPair::new( + SphericalExpansionParameters { + basis: SphericalExpansionBasis::Explicit(ExplicitBasis { + by_angular: by_angular.into(), + spline_accuracy: None, + + }), + ..parameters() + } + ).unwrap()) as Box); + + let mut systems = test_systems(&["water"]); + + let descriptor = calculator.compute(&mut systems, Default::default()).unwrap(); + + for (key, block) in &descriptor { + if key[0] == 1 { + assert_eq!(block.properties().count(), 6); + } else if key[0] == 12 { + assert_eq!(block.properties().count(), 4); + } else { + panic!("unexpected o3_lambda value"); + } + } + } }