From f4bf91c80b4f8151560e08a6bb8f4b4d329eb319 Mon Sep 17 00:00:00 2001 From: Yuan Chiang Date: Sun, 12 Jan 2025 14:54:20 -0800 Subject: [PATCH] promote tasks up to module level --- mlip_arena/tasks/__init__.py | 12 +++++++++++- mlip_arena/tasks/neb.py | 4 ++-- tests/test_neb.py | 5 ++--- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/mlip_arena/tasks/__init__.py b/mlip_arena/tasks/__init__.py index 7a13751..bdaf2e9 100644 --- a/mlip_arena/tasks/__init__.py +++ b/mlip_arena/tasks/__init__.py @@ -6,7 +6,17 @@ from mlip_arena.models import MLIP from mlip_arena.models import REGISTRY as MODEL_REGISTRY -with open(Path(__file__).parent / "registry.yaml") as f: + +from .elasticity import run as ELASTICITY +from .eos import run as EOS +from .md import run as MD +from .neb import run as NEB +from .neb import run_from_endpoints as NEB_FROM_ENDPOINTS +from .optimize import run as OPT + +__all__ = ["OPT", "EOS", "MD", "NEB", "NEB_FROM_ENDPOINTS", "ELASTICITY"] + +with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f: REGISTRY = yaml.safe_load(f) diff --git a/mlip_arena/tasks/neb.py b/mlip_arena/tasks/neb.py index e304e40..44aa527 100644 --- a/mlip_arena/tasks/neb.py +++ b/mlip_arena/tasks/neb.py @@ -167,11 +167,11 @@ def run( @task( - name="NEB from end points", + name="NEB from endpoints", task_run_name=_generate_task_run_name, cache_policy=TASK_SOURCE + INPUTS, ) -def run_from_end_points( +def run_from_endpoints( start: Atoms, end: Atoms, n_images: int, diff --git a/tests/test_neb.py b/tests/test_neb.py index 78bcef4..8db3599 100644 --- a/tests/test_neb.py +++ b/tests/test_neb.py @@ -2,8 +2,7 @@ import pytest from mlip_arena.models import MLIPEnum -from mlip_arena.tasks.neb import run as NEB -from mlip_arena.tasks.neb import run_from_end_points as NEB +from mlip_arena.tasks import NEB_FROM_ENDPOINTS from prefect.testing.utilities import prefect_test_harness from ase.spacegroup import crystal @@ -32,7 +31,7 @@ def test_neb(model: MLIPEnum): """ with prefect_test_harness(): - result = NEB( + result = NEB_FROM_ENDPOINTS( start=start.copy(), end=end.copy(), n_images=5,