diff --git a/docs/src/conf.py b/docs/src/conf.py index 1aa61e596..d21e8f371 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -139,15 +139,15 @@ def setup(app): } intersphinx_mapping = { - "ase": ("https://wiki.fysik.dtu.dk/ase/", None), - "chemfiles": ("https://chemfiles.org/chemfiles.py/latest/", None), - "metatensor": ("https://docs.metatensor.org/latest/", None), - "matplotlib": ("https://matplotlib.org/stable/", None), - "numpy": ("https://numpy.org/doc/stable/", None), - "scipy": ("https://docs.scipy.org/doc/scipy/", None), - "skmatter": ("https://scikit-matter.readthedocs.io/en/latest/", None), - "torch": ("https://pytorch.org/docs/stable/", None), - "python": ("https://docs.python.org/3", None), + # "ase": ("https://wiki.fysik.dtu.dk/ase/", None), + # "chemfiles": ("https://chemfiles.org/chemfiles.py/latest/", None), + # "metatensor": ("https://docs.metatensor.org/latest/", None), + # "matplotlib": ("https://matplotlib.org/stable/", None), + # "numpy": ("https://numpy.org/doc/stable/", None), + # "scipy": ("https://docs.scipy.org/doc/scipy/", None), + # "skmatter": ("https://scikit-matter.readthedocs.io/en/latest/", None), + # "torch": ("https://pytorch.org/docs/stable/", None), + # "python": ("https://docs.python.org/3", None), } # -- Options for HTML output ------------------------------------------------- diff --git a/docs/src/references/api/python/misc.rst b/docs/src/references/api/python/misc.rst index 5475d1cc8..6706b5509 100644 --- a/docs/src/references/api/python/misc.rst +++ b/docs/src/references/api/python/misc.rst @@ -13,3 +13,5 @@ Miscellaneous .. autoclass:: rascaline.Profiler :members: :undoc-members: + +.. autofunction:: rascaline.convert_hypers diff --git a/python/rascaline/rascaline/__init__.py b/python/rascaline/rascaline/__init__.py index c11deb7a8..ce7d98bd5 100644 --- a/python/rascaline/rascaline/__init__.py +++ b/python/rascaline/rascaline/__init__.py @@ -17,6 +17,7 @@ from .profiling import Profiler # noqa from .status import RascalError # noqa from .systems import IntoSystem, SystemBase # noqa +from .utils import convert_hypers # noqa from .version import __version__ # noqa diff --git a/python/rascaline/rascaline/calculators.py b/python/rascaline/rascaline/calculators.py index 05dabec37..43ec0607e 100644 --- a/python/rascaline/rascaline/calculators.py +++ b/python/rascaline/rascaline/calculators.py @@ -1,5 +1,7 @@ import json +from .utils import BadHyperParameters, convert_hypers + try: # see rascaline-torch/calculators.py for the explanation of what's going on here @@ -96,6 +98,27 @@ def __init__(self, *, cutoff, max_neighbors, separate_neighbor_types): super().__init__("sorted_distances", json.dumps(parameters)) +def _check_for_old_hypers(calculator, hypers): + try: + new_hypers = convert_hypers( + origin="rascaline", + representation=calculator, + hypers=hypers, + ) + except BadHyperParameters as e: + print(e) + raise ValueError( + f"invalid hyper parameters to {calculator}, " + "expected `density` and `basis` to be present" + ) + + raise ValueError( + f"{calculator} hyper parameter changed recently, " + "please update your code. Here are the new equivalent parameters:\n" + + new_hypers + ) + + class SphericalExpansion(CalculatorBase): """Spherical expansion of Smooth Overlap of Atomic Positions (SOAP). @@ -115,9 +138,9 @@ class SphericalExpansion(CalculatorBase): :ref:`documentation `. """ - def __init__(self, *, cutoff, density, basis, **kwargs): - if len(kwargs) != 0: - raise ValueError("TODO: old style parameters") + def __init__(self, *, cutoff=None, density=None, basis=None, **kwargs): + if len(kwargs) != 0 or density is None or basis is None: + _check_for_old_hypers("SphericalExpansion", {"cutoff": cutoff, **kwargs}) parameters = { "cutoff": cutoff, @@ -161,9 +184,11 @@ class SphericalExpansionByPair(CalculatorBase): :ref:`documentation `. """ - def __init__(self, *, cutoff, density, basis, **kwargs): - if len(kwargs) != 0: - raise ValueError("TODO: old style parameters") + def __init__(self, *, cutoff=None, density=None, basis=None, **kwargs): + if len(kwargs) != 0 or density is None or basis is None: + _check_for_old_hypers( + "SphericalExpansionByPair", {"cutoff": cutoff, **kwargs} + ) parameters = { "cutoff": cutoff, @@ -191,9 +216,9 @@ class SoapRadialSpectrum(CalculatorBase): :ref:`documentation `. """ - def __init__(self, *, cutoff, density, basis, **kwargs): - if len(kwargs) != 0: - raise ValueError("TODO: old style parameters") + def __init__(self, *, cutoff=None, density=None, basis=None, **kwargs): + if len(kwargs) != 0 or density is None or basis is None: + _check_for_old_hypers("SoapRadialSpectrum", {"cutoff": cutoff, **kwargs}) parameters = { "cutoff": cutoff, @@ -225,9 +250,9 @@ class SoapPowerSpectrum(CalculatorBase): allows to compute the power spectrum from different spherical expansions. """ - def __init__(self, *, cutoff, density, basis, **kwargs): - if len(kwargs) != 0: - raise ValueError("TODO: old style parameters") + def __init__(self, *, cutoff=None, density=None, basis=None, **kwargs): + if len(kwargs) != 0 or density is None or basis is None: + _check_for_old_hypers("SoapPowerSpectrum", {"cutoff": cutoff, **kwargs}) parameters = { "cutoff": cutoff, @@ -255,9 +280,11 @@ class LodeSphericalExpansion(CalculatorBase): :ref:`documentation `. """ - def __init__(self, *, density, basis, k_cutoff=None, **kwargs): - if len(kwargs) != 0: - raise ValueError("TODO: old style parameters") + def __init__(self, *, density=None, basis=None, k_cutoff=None, **kwargs): + if len(kwargs) != 0 or density is None or basis is None: + _check_for_old_hypers( + "LodeSphericalExpansion", {"k_cutoff": k_cutoff, **kwargs} + ) parameters = { "k_cutoff": k_cutoff, diff --git a/python/rascaline/rascaline/utils/__init__.py b/python/rascaline/rascaline/utils/__init__.py index a63c95567..e39e05520 100644 --- a/python/rascaline/rascaline/utils/__init__.py +++ b/python/rascaline/rascaline/utils/__init__.py @@ -6,6 +6,7 @@ calculate_cg_coefficients, cartesian_to_spherical, ) +from .hypers import BadHyperParameters, convert_hypers # noqa from .power_spectrum import PowerSpectrum # noqa diff --git a/python/rascaline/rascaline/utils/hypers/__init__.py b/python/rascaline/rascaline/utils/hypers/__init__.py new file mode 100644 index 000000000..fa42977a4 --- /dev/null +++ b/python/rascaline/rascaline/utils/hypers/__init__.py @@ -0,0 +1,49 @@ +import json +import re + +from . import _rascaline + + +class BadHyperParameters(Exception): + pass + + +def convert_hypers(origin, representation=None, hypers=None): + """Convert hyper-parameters from other software into the format used by rascaline. + + :param origin: which software do the hyper-parameters come from? Valid values are: + + - ``"rascaline"`` for old rascaline format; + :param representation: which representation are these hyper for? The meaning depend + on the ``origin``: + + - for ``origin="rascaline"``, this is the name of the calculator class; + :param hypers: the hyper parameter to convert. The type depend on the ``origin``: + + - for ``origin="rascaline"``, this should be a dictionary; + + :return: A string containing the code corresponding to the requested representation + and hypers + """ + if origin == "rascaline": + if representation in [ + "SphericalExpansion", + "SphericalExpansionByPair", + "SoapPowerSpectrum", + ]: + hypers = _rascaline.convert_soap(hypers) + elif representation == "SoapRadialSpectrum": + hypers = _rascaline.convert_radial_spectrum(hypers) + elif representation == "LodeSphericalExpansion": + hypers = _rascaline.convert_lode(hypers) + else: + raise ValueError( + "no hyper conversion exists for rascaline representation " + f"'{representation}'" + ) + + hypers_dict = json.dumps(hypers, indent=4) + hypers_dict = re.sub(r"\bnull\b", "None", hypers_dict) + return f"{representation}(**{hypers_dict})" + else: + raise ValueError(f"no hyper conversion exists for {origin} software") diff --git a/python/rascaline/rascaline/utils/hypers/_rascaline.py b/python/rascaline/rascaline/utils/hypers/_rascaline.py new file mode 100644 index 000000000..df73bb31a --- /dev/null +++ b/python/rascaline/rascaline/utils/hypers/_rascaline.py @@ -0,0 +1,166 @@ +def convert_soap(hypers): + """convert from old style rascaline hypers for SOAP-related representations""" + cleaned = { + "cutoff": _process_cutoff(hypers), + "density": _process_density(hypers), + } + + max_angular = _get_or_error(hypers, "max_angular", "") + radial, spline_accuracy = _process_radial_basis(hypers) + cleaned["basis"] = { + "type": "TensorProduct", + "max_angular": max_angular, + "radial": radial, + } + if spline_accuracy is not None: + if isinstance(spline_accuracy, float): + cleaned["basis"]["spline_accuracy"] = spline_accuracy + else: + cleaned["basis"]["spline_accuracy"] = None + + return cleaned + + +def convert_radial_spectrum(hypers): + """convert from old style rascaline hypers for SOAP radial spectrum""" + cleaned = { + "cutoff": _process_cutoff(hypers), + "density": _process_density(hypers), + } + + radial, spline_accuracy = _process_radial_basis(hypers) + cleaned["basis"] = {"radial": radial} + if spline_accuracy is not None: + if isinstance(spline_accuracy, float): + cleaned["basis"]["spline_accuracy"] = spline_accuracy + else: + cleaned["basis"]["spline_accuracy"] = None + + return cleaned + + +def convert_lode(hypers): + """convert from old style rascaline hypers for LODE spherical expansion""" + + cleaned = { + "density": _process_density(hypers), + } + + max_angular = _get_or_error(hypers, "max_angular", "") + radial, spline_accuracy = _process_radial_basis(hypers, lode=True) + cleaned["basis"] = { + "type": "TensorProduct", + "max_angular": max_angular, + "radial": radial, + } + if spline_accuracy is not None: + if isinstance(spline_accuracy, float): + cleaned["basis"]["spline_accuracy"] = spline_accuracy + else: + cleaned["basis"]["spline_accuracy"] = None + + k_cutoff = hypers.get("k_cutoff") + if k_cutoff is not None: + cleaned["k_cutoff"] = k_cutoff + + return cleaned + + +def _process_cutoff(hypers): + cutoff = { + "radius": _get_or_error(hypers, "cutoff", ""), + } + + cutoff_fn = _get_or_error(hypers, "cutoff_function", "") + if "Step" in cutoff_fn: + cutoff["smoothing"] = {"type": "Step"} + if "ShiftedCosine" in cutoff_fn: + width = _get_or_error( + cutoff_fn["ShiftedCosine"], "width", "cutoff_function.ShiftedCosine" + ) + cutoff["smoothing"] = {"type": "ShiftedCosine", "width": width} + + return cutoff + + +def _process_density(hypers): + gaussian_width = _get_or_error(hypers, "atomic_gaussian_width", "") + center_weight = _get_or_error(hypers, "center_atom_weight", "") + exponent = hypers.get("potential_exponent") + + if exponent is None: + density = { + "type": "Gaussian", + "width": gaussian_width, + } + else: + density = { + "type": "LongRangeGaussian", + "width": gaussian_width, + "exponent": exponent, + } + + if center_weight != 1.0: + density["center_atom_weight"] = center_weight + + if "radial_scaling" in hypers: + radial_scaling = hypers["radial_scaling"] + if radial_scaling is None: + pass + + if "None" in radial_scaling: + pass + + if "Willatt2018" in radial_scaling: + exponent = _get_or_error( + radial_scaling["Willatt2018"], "exponent", "radial_scaling.Willatt2018" + ) + rate = _get_or_error( + radial_scaling["Willatt2018"], "rate", "radial_scaling.Willatt2018" + ) + scale = _get_or_error( + radial_scaling["Willatt2018"], "scale", "radial_scaling.Willatt2018" + ) + + density["scaling"] = { + "type": "Willatt2018", + "exponent": exponent, + "rate": rate, + "scale": scale, + } + + return density + + +def _process_radial_basis(hypers, lode=False): + spline_accuracy = None + max_radial = _get_or_error(hypers, "max_radial", "") - 1 + radial_basis = _get_or_error(hypers, "radial_basis", "") + + if "Gto" in radial_basis: + radial = {"type": "Gto", "max_radial": max_radial} + + if lode: + cutoff = _get_or_error(hypers, "cutoff", "") - 1 + radial["radius"] = cutoff + + gto_basis = radial_basis["Gto"] + do_splines = gto_basis.get("splined_radial_integral", True) + if do_splines: + spline_accuracy = gto_basis.get("spline_accuracy") + else: + spline_accuracy = False + + elif "TabulatedRadialIntegral" in radial_basis: + raise NotImplementedError("TabulatedRadialIntegral radial basis") + + return radial, spline_accuracy + + +def _get_or_error(hypers, name, path): + from . import BadHyperParameters + + if name not in hypers: + raise BadHyperParameters(f"missing {name} at {path} in hypers") + + return hypers.pop(name) diff --git a/python/rascaline/tests/calculators/hyper_conversion.py b/python/rascaline/tests/calculators/hyper_conversion.py new file mode 100644 index 000000000..13f3aeb16 --- /dev/null +++ b/python/rascaline/tests/calculators/hyper_conversion.py @@ -0,0 +1,75 @@ +import pytest + +from rascaline.calculators import ( + LodeSphericalExpansion, + SoapPowerSpectrum, + SoapRadialSpectrum, + SphericalExpansion, + SphericalExpansionByPair, +) + + +@pytest.mark.parametrize( + "CalculatorClass", + [ + SphericalExpansion, + SphericalExpansionByPair, + SoapPowerSpectrum, + SoapRadialSpectrum, + ], +) +def test_soap_hypers(CalculatorClass): + message = ( + "hyper parameter changed recently, please update your code. " + "Here are the new equivalent parameters" + ) + with pytest.raises(ValueError, match=message) as err: + CalculatorClass( + atomic_gaussian_width=0.3, + center_atom_weight=1.0, + cutoff=3.4, + cutoff_function={"ShiftedCosine": {"width": 0.5}}, + max_angular=5, + max_radial=3, + radial_basis={"Gto": {"spline_accuracy": 1e-3}}, + radial_scaling={"Willatt2018": {"exponent": 3, "rate": 2.2, "scale": 1.1}}, + ) + + error_message = str(err.value.args[0]) + first_line = error_message.find("\n") + code = error_message[first_line:] + + # max radial meaning changed + assert '"max_radial": 2' in code + + # check that the error message contains valid code that can be copy/pasted + eval(code) + + +def test_lode_hypers(): + message = ( + "hyper parameter changed recently, please update your code. " + "Here are the new equivalent parameters" + ) + with pytest.raises(ValueError, match=message) as err: + LodeSphericalExpansion( + atomic_gaussian_width=0.3, + center_atom_weight=0.5, + cutoff=3.4, + cutoff_function={"Step": {}}, + max_angular=5, + max_radial=3, + radial_basis={"Gto": {"splined_radial_integral": False}}, + potential_exponent=3, + k_cutoff=26.2, + ) + + error_message = str(err.value.args[0]) + first_line = error_message.find("\n") + code = error_message[first_line:] + + # max radial meaning changed + assert '"max_radial": 2' in code + + # check that the error message contains valid code that can be copy/pasted + eval(code)