From fd0986009b65b529871ec5715bd7551f982aa29a Mon Sep 17 00:00:00 2001 From: Matteo Giantomassi Date: Wed, 4 Sep 2024 18:04:09 +0200 Subject: [PATCH] abiml: add support for orb and sevenn --- abipy/ml/aseml.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/abipy/ml/aseml.py b/abipy/ml/aseml.py index f50f1dd2f..3e75a784b 100644 --- a/abipy/ml/aseml.py +++ b/abipy/ml/aseml.py @@ -1466,6 +1466,8 @@ class CalcBuilder: "nequip", "metatensor", "deepmd", + "orb", + "sevenn", ] def __init__(self, name: str, dftd3_args=None, **kwargs): @@ -1740,6 +1742,45 @@ class MyDpCalculator(_MyCalculator, DP): cls = MyDpCalculator if with_delta else DP calc = cls(self.model_path, **self.calc_kwargs) + elif self.nn_type == "orb": + try: + from orb_models.forcefield import pretrained + from orb_models.forcefield.calculator import ORBCalculator + except ImportError as exc: + raise ImportError("orb not installed. See https://github.com/orbital-materials/orb-models") from exc + + + class MyOrbCalculator(_MyCalculator, ORBCalculator): + """Add abi_forces and abi_stress""" + + model_name = "orb-v1" if self.model_name is None else self.model_name + # Mapping model_name --> function returning the model e.g. {"orb-v1": orb_v1} + f = pretrained.ORB_PRETRAINED_MODELS[model_name] + model = f() + + cls = MyOrbCalculator if with_delta else OrbCalculator + calc = cls(model, **self.calc_kwargs) + + elif self.nn_type == "sevenn": + try: + from sevenn.sevennet_calculator import SevenNetCalculator + except ImportError as exc: + raise ImportError("sevenn not installed. See https://github.com/MDIL-SNU/SevenNet") from exc + + class MySevenNetCalculator(_MyCalculator, SevenNetCalculator): + """Add abi_forces and abi_stress""" + + # 7net-0, SevenNet-0, 7net-0_22May2024, 7net-0_11July2024 ... + # model_name = "7net-0" if self.model_name is None else self.model_name + # SevenNet-0 (11July2024) + # This model was trained on MPtrj. We suggest starting with this model as we found that it performs better + # than the previous SevenNet-0 (22May2024). + # Check Matbench Discovery leaderborad for this model's performance on materials discovery. For more information, click here. + + model_name = "SevenNet-0" if self.model_name is None else self.model_name + cls = MySevenNetCalculator if with_delta else SevenNetCalculator + calc = MySevenNetCalculator(model=model_name, **self.calc_kwargs) + else: raise ValueError(f"Invalid {self.nn_type=}")