Skip to content

Commit

Permalink
Convert from old style to new style hypers
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Sep 26, 2024
1 parent 88228be commit 604c172
Show file tree
Hide file tree
Showing 8 changed files with 345 additions and 24 deletions.
18 changes: 9 additions & 9 deletions docs/src/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,15 @@ def setup(app):
}

intersphinx_mapping = {
"ase": ("https://wiki.fysik.dtu.dk/ase/", None),
"chemfiles": ("https://chemfiles.org/chemfiles.py/latest/", None),
"metatensor": ("https://docs.metatensor.org/latest/", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
"skmatter": ("https://scikit-matter.readthedocs.io/en/latest/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"python": ("https://docs.python.org/3", None),
# "ase": ("https://wiki.fysik.dtu.dk/ase/", None),
# "chemfiles": ("https://chemfiles.org/chemfiles.py/latest/", None),
# "metatensor": ("https://docs.metatensor.org/latest/", None),
# "matplotlib": ("https://matplotlib.org/stable/", None),
# "numpy": ("https://numpy.org/doc/stable/", None),
# "scipy": ("https://docs.scipy.org/doc/scipy/", None),
# "skmatter": ("https://scikit-matter.readthedocs.io/en/latest/", None),
# "torch": ("https://pytorch.org/docs/stable/", None),
# "python": ("https://docs.python.org/3", None),
}

# -- Options for HTML output -------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions docs/src/references/api/python/misc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ Miscellaneous
.. autoclass:: rascaline.Profiler
:members:
:undoc-members:

.. autofunction:: rascaline.convert_hypers
1 change: 1 addition & 0 deletions python/rascaline/rascaline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .profiling import Profiler # noqa
from .status import RascalError # noqa
from .systems import IntoSystem, SystemBase # noqa
from .utils import convert_hypers # noqa
from .version import __version__ # noqa


Expand Down
57 changes: 42 additions & 15 deletions python/rascaline/rascaline/calculators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json

from .utils import BadHyperParameters, convert_hypers


try:
# see rascaline-torch/calculators.py for the explanation of what's going on here
Expand Down Expand Up @@ -96,6 +98,27 @@ def __init__(self, *, cutoff, max_neighbors, separate_neighbor_types):
super().__init__("sorted_distances", json.dumps(parameters))


def _check_for_old_hypers(calculator, hypers):
try:
new_hypers = convert_hypers(
origin="rascaline",
representation=calculator,
hypers=hypers,
)
except BadHyperParameters as e:
print(e)
raise ValueError(
f"invalid hyper parameters to {calculator}, "
"expected `density` and `basis` to be present"
)

raise ValueError(
f"{calculator} hyper parameter changed recently, "
"please update your code. Here are the new equivalent parameters:\n"
+ new_hypers
)


class SphericalExpansion(CalculatorBase):
"""Spherical expansion of Smooth Overlap of Atomic Positions (SOAP).
Expand All @@ -115,9 +138,9 @@ class SphericalExpansion(CalculatorBase):
:ref:`documentation <spherical-expansion>`.
"""

def __init__(self, *, cutoff, density, basis, **kwargs):
if len(kwargs) != 0:
raise ValueError("TODO: old style parameters")
def __init__(self, *, cutoff=None, density=None, basis=None, **kwargs):
if len(kwargs) != 0 or density is None or basis is None:
_check_for_old_hypers("SphericalExpansion", {"cutoff": cutoff, **kwargs})

parameters = {
"cutoff": cutoff,
Expand Down Expand Up @@ -161,9 +184,11 @@ class SphericalExpansionByPair(CalculatorBase):
:ref:`documentation <spherical-expansion-by-pair>`.
"""

def __init__(self, *, cutoff, density, basis, **kwargs):
if len(kwargs) != 0:
raise ValueError("TODO: old style parameters")
def __init__(self, *, cutoff=None, density=None, basis=None, **kwargs):
if len(kwargs) != 0 or density is None or basis is None:
_check_for_old_hypers(
"SphericalExpansionByPair", {"cutoff": cutoff, **kwargs}
)

parameters = {
"cutoff": cutoff,
Expand Down Expand Up @@ -191,9 +216,9 @@ class SoapRadialSpectrum(CalculatorBase):
:ref:`documentation <soap-radial-spectrum>`.
"""

def __init__(self, *, cutoff, density, basis, **kwargs):
if len(kwargs) != 0:
raise ValueError("TODO: old style parameters")
def __init__(self, *, cutoff=None, density=None, basis=None, **kwargs):
if len(kwargs) != 0 or density is None or basis is None:
_check_for_old_hypers("SoapRadialSpectrum", {"cutoff": cutoff, **kwargs})

parameters = {
"cutoff": cutoff,
Expand Down Expand Up @@ -225,9 +250,9 @@ class SoapPowerSpectrum(CalculatorBase):
allows to compute the power spectrum from different spherical expansions.
"""

def __init__(self, *, cutoff, density, basis, **kwargs):
if len(kwargs) != 0:
raise ValueError("TODO: old style parameters")
def __init__(self, *, cutoff=None, density=None, basis=None, **kwargs):
if len(kwargs) != 0 or density is None or basis is None:
_check_for_old_hypers("SoapPowerSpectrum", {"cutoff": cutoff, **kwargs})

parameters = {
"cutoff": cutoff,
Expand Down Expand Up @@ -255,9 +280,11 @@ class LodeSphericalExpansion(CalculatorBase):
:ref:`documentation <lode-spherical-expansion>`.
"""

def __init__(self, *, density, basis, k_cutoff=None, **kwargs):
if len(kwargs) != 0:
raise ValueError("TODO: old style parameters")
def __init__(self, *, density=None, basis=None, k_cutoff=None, **kwargs):
if len(kwargs) != 0 or density is None or basis is None:
_check_for_old_hypers(
"LodeSphericalExpansion", {"k_cutoff": k_cutoff, **kwargs}
)

parameters = {
"k_cutoff": k_cutoff,
Expand Down
1 change: 1 addition & 0 deletions python/rascaline/rascaline/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
calculate_cg_coefficients,
cartesian_to_spherical,
)
from .hypers import BadHyperParameters, convert_hypers # noqa
from .power_spectrum import PowerSpectrum # noqa


Expand Down
49 changes: 49 additions & 0 deletions python/rascaline/rascaline/utils/hypers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import json
import re

from . import _rascaline


class BadHyperParameters(Exception):
pass


def convert_hypers(origin, representation=None, hypers=None):
"""Convert hyper-parameters from other software into the format used by rascaline.
:param origin: which software do the hyper-parameters come from? Valid values are:
- ``"rascaline"`` for old rascaline format;
:param representation: which representation are these hyper for? The meaning depend
on the ``origin``:
- for ``origin="rascaline"``, this is the name of the calculator class;
:param hypers: the hyper parameter to convert. The type depend on the ``origin``:
- for ``origin="rascaline"``, this should be a dictionary;
:return: A string containing the code corresponding to the requested representation
and hypers
"""
if origin == "rascaline":
if representation in [
"SphericalExpansion",
"SphericalExpansionByPair",
"SoapPowerSpectrum",
]:
hypers = _rascaline.convert_soap(hypers)
elif representation == "SoapRadialSpectrum":
hypers = _rascaline.convert_radial_spectrum(hypers)
elif representation == "LodeSphericalExpansion":
hypers = _rascaline.convert_lode(hypers)
else:
raise ValueError(
"no hyper conversion exists for rascaline representation "
f"'{representation}'"
)

hypers_dict = json.dumps(hypers, indent=4)
hypers_dict = re.sub(r"\bnull\b", "None", hypers_dict)
return f"{representation}(**{hypers_dict})"
else:
raise ValueError(f"no hyper conversion exists for {origin} software")
166 changes: 166 additions & 0 deletions python/rascaline/rascaline/utils/hypers/_rascaline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
def convert_soap(hypers):
"""convert from old style rascaline hypers for SOAP-related representations"""
cleaned = {
"cutoff": _process_cutoff(hypers),
"density": _process_density(hypers),
}

max_angular = _get_or_error(hypers, "max_angular", "<root>")
radial, spline_accuracy = _process_radial_basis(hypers)
cleaned["basis"] = {
"type": "TensorProduct",
"max_angular": max_angular,
"radial": radial,
}
if spline_accuracy is not None:
if isinstance(spline_accuracy, float):
cleaned["basis"]["spline_accuracy"] = spline_accuracy
else:
cleaned["basis"]["spline_accuracy"] = None

return cleaned


def convert_radial_spectrum(hypers):
"""convert from old style rascaline hypers for SOAP radial spectrum"""
cleaned = {
"cutoff": _process_cutoff(hypers),
"density": _process_density(hypers),
}

radial, spline_accuracy = _process_radial_basis(hypers)
cleaned["basis"] = {"radial": radial}
if spline_accuracy is not None:
if isinstance(spline_accuracy, float):
cleaned["basis"]["spline_accuracy"] = spline_accuracy
else:
cleaned["basis"]["spline_accuracy"] = None

return cleaned


def convert_lode(hypers):
"""convert from old style rascaline hypers for LODE spherical expansion"""

cleaned = {
"density": _process_density(hypers),
}

max_angular = _get_or_error(hypers, "max_angular", "<root>")
radial, spline_accuracy = _process_radial_basis(hypers, lode=True)
cleaned["basis"] = {
"type": "TensorProduct",
"max_angular": max_angular,
"radial": radial,
}
if spline_accuracy is not None:
if isinstance(spline_accuracy, float):
cleaned["basis"]["spline_accuracy"] = spline_accuracy
else:
cleaned["basis"]["spline_accuracy"] = None

k_cutoff = hypers.get("k_cutoff")
if k_cutoff is not None:
cleaned["k_cutoff"] = k_cutoff

return cleaned


def _process_cutoff(hypers):
cutoff = {
"radius": _get_or_error(hypers, "cutoff", "<root>"),
}

cutoff_fn = _get_or_error(hypers, "cutoff_function", "<root>")
if "Step" in cutoff_fn:
cutoff["smoothing"] = {"type": "Step"}
if "ShiftedCosine" in cutoff_fn:
width = _get_or_error(
cutoff_fn["ShiftedCosine"], "width", "cutoff_function.ShiftedCosine"
)
cutoff["smoothing"] = {"type": "ShiftedCosine", "width": width}

return cutoff


def _process_density(hypers):
gaussian_width = _get_or_error(hypers, "atomic_gaussian_width", "<root>")
center_weight = _get_or_error(hypers, "center_atom_weight", "<root>")
exponent = hypers.get("potential_exponent")

if exponent is None:
density = {
"type": "Gaussian",
"width": gaussian_width,
}
else:
density = {
"type": "LongRangeGaussian",
"width": gaussian_width,
"exponent": exponent,
}

if center_weight != 1.0:
density["center_atom_weight"] = center_weight

if "radial_scaling" in hypers:
radial_scaling = hypers["radial_scaling"]
if radial_scaling is None:
pass

if "None" in radial_scaling:
pass

if "Willatt2018" in radial_scaling:
exponent = _get_or_error(
radial_scaling["Willatt2018"], "exponent", "radial_scaling.Willatt2018"
)
rate = _get_or_error(
radial_scaling["Willatt2018"], "rate", "radial_scaling.Willatt2018"
)
scale = _get_or_error(
radial_scaling["Willatt2018"], "scale", "radial_scaling.Willatt2018"
)

density["scaling"] = {
"type": "Willatt2018",
"exponent": exponent,
"rate": rate,
"scale": scale,
}

return density


def _process_radial_basis(hypers, lode=False):
spline_accuracy = None
max_radial = _get_or_error(hypers, "max_radial", "<root>") - 1
radial_basis = _get_or_error(hypers, "radial_basis", "<root>")

if "Gto" in radial_basis:
radial = {"type": "Gto", "max_radial": max_radial}

if lode:
cutoff = _get_or_error(hypers, "cutoff", "<root>") - 1
radial["radius"] = cutoff

gto_basis = radial_basis["Gto"]
do_splines = gto_basis.get("splined_radial_integral", True)
if do_splines:
spline_accuracy = gto_basis.get("spline_accuracy")
else:
spline_accuracy = False

elif "TabulatedRadialIntegral" in radial_basis:
raise NotImplementedError("TabulatedRadialIntegral radial basis")

return radial, spline_accuracy


def _get_or_error(hypers, name, path):
from . import BadHyperParameters

if name not in hypers:
raise BadHyperParameters(f"missing {name} at {path} in hypers")

return hypers.pop(name)
Loading

0 comments on commit 604c172

Please sign in to comment.