diff --git a/src/autoplex/fitting/common/flows.py b/src/autoplex/fitting/common/flows.py index 7914d8248..50172c395 100644 --- a/src/autoplex/fitting/common/flows.py +++ b/src/autoplex/fitting/common/flows.py @@ -450,81 +450,40 @@ def make( adapter = AseAtomsAdaptor() + # must always exist + required_paths = ["train.extxyz", "test.extxyz"] + + optional_paths = [ + "phonon/train.extxyz", + "phonon/test.extxyz", + "rattled/train.extxyz", + "rattled/test.extxyz", + "without_regularization/train.extxyz", + "without_regularization/test.extxyz", + ] + database_dict = { - "train.extxyz": [ - adapter.get_structure(atoms) - for atoms in ase.io.read(Path.cwd() / "train.extxyz", ":") - ], - "test.extxyz": [ + path: [ adapter.get_structure(atoms) - for atoms in ase.io.read(Path.cwd() / "test.extxyz", ":") - ], - "phonon/train.extxyz": ( - None - if not Path(Path.cwd() / "phonon" / "train.extxyz").exists() - else [ - adapter.get_structure(atoms) - for atoms in ase.io.read( - Path.cwd() / "phonon" / "train.extxyz", ":" - ) - ] - ), - "phonon/test.extxyz": ( - None - if not Path(Path.cwd() / "phonon" / "test.extxyz").exists() - else [ - adapter.get_structure(atoms) - for atoms in ase.io.read( - Path.cwd() / "phonon" / "test.extxyz", ":" - ) - ] - ), - "rattled/train.extxyz": ( - None - if not Path(Path.cwd() / "rattled" / "train.extxyz").exists() - else [ - adapter.get_structure(atoms) - for atoms in ase.io.read( - Path.cwd() / "rattled" / "train.extxyz", ":" - ) - ] - ), - "rattled/test.extxyz": ( - None - if not Path(Path.cwd() / "rattled" / "test.extxyz").exists() - else [ - adapter.get_structure(atoms) - for atoms in ase.io.read( - Path.cwd() / "rattled" / "test.extxyz", ":" - ) - ] - ), - "without_regularization/train.extxyz": ( - None - if not Path( - Path.cwd() / "without_regularization" / "train.extxyz" - ).exists() - else [ - adapter.get_structure(atoms) - for atoms in ase.io.read( - Path.cwd() / "without_regularization" / "train.extxyz", ":" - ) - ] - ), - "without_regularization/test.extxyz": ( - None - if not Path( - Path.cwd() / "without_regularization" / "test.extxyz" - ).exists() - else [ - adapter.get_structure(atoms) - for atoms in ase.io.read( - Path.cwd() / "without_regularization" / "test.extxyz", ":" - ) - ] - ), + for atoms in ase.io.read(Path.cwd() / path, ":") + ] + for path in required_paths } + database_dict.update( + { + path: ( + [ + adapter.get_structure(atoms) + for atoms in ase.io.read(Path.cwd() / path, ":") + ] + if (Path.cwd() / path).exists() + else None + ) + for path in optional_paths + } + ) + return {"database_dir": Path.cwd(), "database_dict": database_dict} return {"database_dir": Path.cwd(), "database_dict": None} diff --git a/tests/auto/phonons/test_flows.py b/tests/auto/phonons/test_flows.py index 0087dd054..64db0bd4f 100644 --- a/tests/auto/phonons/test_flows.py +++ b/tests/auto/phonons/test_flows.py @@ -812,6 +812,7 @@ def test_iterative_complete_dft_vs_ml_benchmark_workflow_gap(vasp_test_dir, mock assert len(vasp_xyz) == 10 assert isinstance(complete_workflow.output.resolve(memory_jobstore)["dft_references"], list) + def test_iterative_complete_dft_vs_ml_benchmark_workflow_gap_add_phonon_false(vasp_test_dir, mock_vasp, test_dir, memory_jobstore, ref_paths4_mpid_new2, fake_run_vasp_kwargs4_mpid_new2, clean_dir): # first test with just one iteration (more tests need to be added) from ase.io import read @@ -957,7 +958,6 @@ def test_complete_dft_vs_gap_benchmark_workflow_database( assert expected_soap_dict in results_file, f"Expected soap_dict not found in {file_path}" - def test_complete_dft_vs_ml_benchmark_workflow_m3gnet( vasp_test_dir, mock_vasp, test_dir, memory_jobstore, ref_paths4_mpid, fake_run_vasp_kwargs4_mpid, clean_dir ):