Skip to content

Commit

Permalink
Implement an Explicit basis, containing a map from angular channels t…
Browse files Browse the repository at this point in the history
…o radial basis
  • Loading branch information
Luthaf committed Nov 1, 2024
1 parent 80ebe56 commit 20def6c
Show file tree
Hide file tree
Showing 17 changed files with 521 additions and 131 deletions.
31 changes: 28 additions & 3 deletions docs/extensions/rascaline_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ def run(self, *args, **kwargs):
section.insert(1, schema_node)

# add missing entries to self._definitions
for name in schema.get("definitions", {}).keys():
for name in schema.get("$defs", {}).keys():
self._definition_used(name)

for name in self._definitions.keys():
definition = schema["definitions"][name]
definition = schema["$defs"][name]
target, subsection = self._transform(definition, name)

section += target
Expand Down Expand Up @@ -112,7 +112,7 @@ def _json_schema_to_nodes(
assert "allOf" not in schema

ref = schema["$ref"]
assert ref.startswith("#/definitions/")
assert ref.startswith("#/$defs/")
type_name = ref.split("/")[-1]

self._definition_used(type_name)
Expand Down Expand Up @@ -228,6 +228,10 @@ def _json_schema_to_nodes(

field_list += body

additional = schema.get("additionalProperties")
if additional is not None:
pass

return field_list
else:
object_node = nodes.inline()
Expand Down Expand Up @@ -261,6 +265,25 @@ def _json_schema_to_nodes(

object_node += field

additional = schema.get("additionalProperties")
if isinstance(additional, dict):
# JSON Schema does not have a concept of key type being anything
# else than string. In rascaline, we annotate `HashMap` with a
# custom `x-key-type` to carry this information all the way to
# here
key_type = schema.get("x-key-type")
if key_type is None:
key_type = "string"

field = nodes.inline()
field += nodes.Text("[key: ")
field += nodes.literal(text=key_type)
field += nodes.Text("]: ")

field += self._json_schema_to_nodes(additional)

object_node += field

object_node += nodes.Text("}")

return object_node
Expand All @@ -285,6 +308,8 @@ def _json_schema_to_nodes(
if "enum" in schema:
values = [f'"{v}"' for v in schema["enum"]]
return nodes.literal(text=" | ".join(values))
elif "const" in schema:
return nodes.Text('"' + schema["const"] + '"')
else:
return nodes.literal(text="string")

Expand Down
2 changes: 1 addition & 1 deletion docs/rascaline-json-schema/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ test = false

[dependencies]
rascaline = {path = "../../rascaline"}
schemars = "0.8.6"
schemars = "=1.0.0-alpha.15"
serde_json = "1"
42 changes: 26 additions & 16 deletions docs/rascaline-json-schema/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::path::PathBuf;

use schemars::schema::RootSchema;
use schemars::Schema;

use rascaline::calculators::AtomicComposition;
use rascaline::calculators::SortedDistances;
Expand Down Expand Up @@ -45,30 +45,40 @@ struct RenameRefInSchema {
in_docs: &'static str,
}

impl schemars::visit::Visitor for RenameRefInSchema {
fn visit_schema_object(&mut self, schema: &mut schemars::schema::SchemaObject) {
schemars::visit::visit_schema_object(self, schema);

let in_code_reference = format!("#/definitions/{}", self.in_code);

if let Some(reference) = &schema.reference {
if reference == &in_code_reference {
schema.reference = Some(format!("#/definitions/{}", self.in_docs));
impl schemars::transform::Transform for RenameRefInSchema {
fn transform(&mut self, schema: &mut Schema) {
let in_code_reference = format!("#/$defs/{}", self.in_code);
if let Some(schema_object) = schema.as_object_mut() {
if let Some(reference) = schema_object.get_mut("$ref") {
if reference == &in_code_reference {
*reference = format!("#/$defs/{}", self.in_docs).into();
}
}
}
schemars::transform::transform_subschemas(self, schema);
}
}

fn save_schema(name: &str, mut schema: RootSchema) {
for transform in REFS_TO_RENAME {
if let Some(value) = schema.definitions.remove(transform.in_code) {
assert!(!schema.definitions.contains_key(transform.in_docs));
schema.definitions.insert(transform.in_docs.into(), value);
fn save_schema(name: &str, mut schema: Schema) {
let schema_object = schema.as_object_mut().expect("schema should be an object");

schemars::visit::visit_root_schema(&mut transform.clone(), &mut schema);
// rename some of the autogenerate names.
// Step 1: rename the definitions
for transform in REFS_TO_RENAME {
if let Some(definitions) = schema_object.get_mut("$defs") {
let definitions = definitions.as_object_mut().expect("$defs should be an object");
if let Some(value) = definitions.remove(transform.in_code) {
assert!(!definitions.contains_key(transform.in_docs));
definitions.insert(transform.in_docs.into(), value);
}
}
}

// Step 2: rename the references to these definitions
for transform in REFS_TO_RENAME {
schemars::transform::transform_subschemas(&mut transform.clone(), &mut schema);
}

let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.pop();
path.push("build");
Expand Down
3 changes: 3 additions & 0 deletions docs/src/references/api/python/basis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ Basis functions
.. autoclass:: rascaline.basis.TensorProduct
:show-inheritance:

.. autoclass:: rascaline.basis.Explicit
:show-inheritance:

.. autoclass:: rascaline.basis.ExpansionBasis
:members:

Expand Down
41 changes: 39 additions & 2 deletions python/rascaline/rascaline/basis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
from typing import Optional
from typing import Optional, Dict

import numpy as np

Expand Down Expand Up @@ -355,7 +355,7 @@ def Jn_zeros(angular_size: int, radial_size: int) -> np.ndarray:
########################################################################################


class ExpansionBasis(metaclass=abc.ABCMeta):
class ExpansionBasis:
"""
Base class representing a set of basis functions used by spherical expansions.
Expand Down Expand Up @@ -420,3 +420,40 @@ def get_hypers(self):
"radial": self.radial,
"spline_accuracy": self.spline_accuracy,
}


class Explicit(ExpansionBasis):
r"""
An expansion basis where combinations of radial and angular functions is picked
explicitly.
The angular basis functions are still spherical harmonics, but only the degrees
included as keys in ``by_angular`` will be part of the output. Each of these angular
basis function can then be associated with a set of different radial basis function,
potentially of different sizes.
"""

def __init__(
self,
*,
by_angular: Dict[int, RadialBasis],
spline_accuracy: Optional[float] = 1e-8,
):
self.by_angular = by_angular

if spline_accuracy is None:
self.spline_accuracy = None
else:
self.spline_accuracy = float(spline_accuracy)
assert self.spline_accuracy > 0

for angular, radial in self.by_angular.items():
assert angular >= 0
assert isinstance(radial, RadialBasis)

def get_hypers(self):
return {
"type": "Explicit",
"by_angular": self.by_angular,
"spline_accuracy": self.spline_accuracy,
}
17 changes: 10 additions & 7 deletions python/rascaline/tests/calculators/hypers.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,11 @@ def test_hypers_classes():
},
},
"basis": {
"type": "TensorProduct",
"max_angular": 5,
"radial": {"type": "Gto", "max_radial": 5},
"type": "Explicit",
"by_angular": {
3: {"type": "Gto", "max_radial": 5},
5: {"type": "Gto", "max_radial": 2},
},
},
}

Expand All @@ -115,9 +117,11 @@ def test_hypers_classes():
scaling=rascaline.density.Willatt2018(exponent=3, rate=2.2, scale=1.1),
center_atom_weight=0.3,
),
basis=rascaline.basis.TensorProduct(
max_angular=5,
radial=rascaline.basis.Gto(max_radial=5),
basis=rascaline.basis.Explicit(
by_angular={
3: rascaline.basis.Gto(max_radial=5),
5: rascaline.basis.Gto(max_radial=2),
}
),
)

Expand All @@ -126,7 +130,6 @@ def test_hypers_classes():


def test_hypers_custom_classes_errors():

class MyCustomSmoothing(rascaline.cutoff.SmoothingFunction):
def compute(self, cutoff, positions, derivative):
pass
Expand Down
2 changes: 1 addition & 1 deletion rascaline/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ time-graph = "0.3.0"

serde = { version = "1", features = ["derive"] }
serde_json = "1"
schemars = "0.8"
schemars = "=1.0.0-alpha.15"

chemfiles = {version = "0.10", optional = true}

Expand Down
94 changes: 63 additions & 31 deletions rascaline/src/calculators/lode/radial_integral/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,45 @@ pub struct LodeRadialIntegralCache {
}

impl LodeRadialIntegralCache {
fn new(
o3_lambda: usize,
radial: &LodeRadialBasis,
density: DensityKind,
k_cutoff: f64,
spline_accuracy: Option<f64>,
) -> Result<LodeRadialIntegralCache, Error> {
let implementation = match radial {
LodeRadialBasis::Gto { .. } => {
let gto = LodeRadialIntegralGto::new(radial, o3_lambda)?;

if let Some(accuracy) = spline_accuracy {
let do_center_contribution = o3_lambda == 0;
Box::new(LodeRadialIntegralSpline::with_accuracy(
gto, density, k_cutoff, accuracy, do_center_contribution
)?)
} else {
Box::new(gto) as Box<dyn LodeRadialIntegral>
}
},
LodeRadialBasis::Tabulated(ref tabulated) => {
Box::new(LodeRadialIntegralSpline::from_tabulated(
tabulated.clone(),
density,
)) as Box<dyn LodeRadialIntegral>
}
};

let size = implementation.size();
let values = Array1::from_elem(size, 0.0);
let gradients = Array1::from_elem(size, 0.0);

return Ok(LodeRadialIntegralCache {
implementation,
values,
gradients,
});
}

/// Run the calculation, the results are stored inside `self.values` and
/// `self.gradients`
pub fn compute(&mut self, k_norm: f64, do_gradients: bool) {
Expand Down Expand Up @@ -91,43 +130,36 @@ impl LodeRadialIntegralCacheByAngular {
SphericalExpansionBasis::TensorProduct(basis) => {
let mut by_angular = BTreeMap::new();
for o3_lambda in 0..=basis.max_angular {
// We only support some specific radial basis
let implementation = match basis.radial {
LodeRadialBasis::Gto { .. } => {
let gto = LodeRadialIntegralGto::new(&basis.radial, o3_lambda)?;

if let Some(accuracy) = basis.spline_accuracy {
let do_center_contribution = o3_lambda == 0;
Box::new(LodeRadialIntegralSpline::with_accuracy(
gto, density, k_cutoff, accuracy, do_center_contribution
)?)
} else {
Box::new(gto) as Box<dyn LodeRadialIntegral>
}
},
LodeRadialBasis::Tabulated(ref tabulated) => {
Box::new(LodeRadialIntegralSpline::from_tabulated(
tabulated.clone(),
density,
)) as Box<dyn LodeRadialIntegral>
}
};

let size = implementation.size();
let values = Array1::from_elem(size, 0.0);
let gradients = Array1::from_elem(size, 0.0);

by_angular.insert(o3_lambda, LodeRadialIntegralCache {
implementation,
values,
gradients,
});
let cache = LodeRadialIntegralCache::new(
o3_lambda,
&basis.radial,
density,
k_cutoff,
basis.spline_accuracy
)?;
by_angular.insert(o3_lambda, cache);
}

return Ok(LodeRadialIntegralCacheByAngular {
by_angular
});
}
SphericalExpansionBasis::Explicit(basis) => {
let mut by_angular = BTreeMap::new();
for (&o3_lambda, radial) in &*basis.by_angular {
let cache = LodeRadialIntegralCache::new(
o3_lambda,
radial,
density,
k_cutoff,
basis.spline_accuracy
)?;
by_angular.insert(o3_lambda, cache);
}
return Ok(LodeRadialIntegralCacheByAngular {
by_angular
});
}
}
}

Expand Down
5 changes: 4 additions & 1 deletion rascaline/src/calculators/lode/radial_integral/spline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ impl LodeRadialIntegral for LodeRadialIntegralSpline {
}

if self.center_contribution.is_none() {
return Err(Error::InvalidParameter("TODO".into()));
return Err(Error::InvalidParameter(
"`center_contribution` must be defined for the Tabulated radial \
basis used with the L=0 angular channel".into()
));
}

return Ok(self.center_contribution.clone().expect("just checked"));
Expand Down
Loading

0 comments on commit 20def6c

Please sign in to comment.