From 321e0a6cd247cf93ee976c40a99759503c7e4aa4 Mon Sep 17 00:00:00 2001 From: Hubert Beck Date: Fri, 20 Dec 2024 18:20:09 +0000 Subject: [PATCH] fix preprocessed test sets --- mace/cli/preprocess_data.py | 4 +++- mace/cli/run_train.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index ef9f1343..a6ac7001 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -123,7 +123,7 @@ def multi_valid_hdf5(process, args, split_valid, drop_last): def multi_test_hdf5(process, name, args, split_test, drop_last): with h5py.File( - args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w" + args.h5_prefix + "test/" + name + "/" + "test_" + str(process) + ".h5", "w" ) as f: f.attrs["drop_last"] = drop_last save_configurations_as_HDF5(split_test[process], process, f) @@ -268,6 +268,8 @@ def run(args: argparse.Namespace): if args.test_file is not None: logging.info("Preparing test sets") for name, subset in collections.tests: + if not os.path.exists(args.h5_prefix + "test/" + name): + os.makedirs(args.h5_prefix + "test/" + name) drop_last = False if len(subset) % 2 == 1: drop_last = True diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 1c0898b7..6f82ebc5 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -706,7 +706,7 @@ def run(args: argparse.Namespace) -> None: else: test_folders = glob(head_config.test_dir + "/*") for folder in test_folders: - name = os.path.splitext(os.path.basename(test_file))[0] + name = os.path.splitext(os.path.basename(folder))[0] test_sets[name] = data.dataset_from_sharded_hdf5( folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name )