From 328b99d4ca7323ea3c85f7f33eaaca31d1dd38ff Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 19 Oct 2023 18:33:44 -0700 Subject: [PATCH] FastMRI dataset split fix: - Fix end indexing for validation set. - Correct number of test examples. PiperOrigin-RevId: 575063140 --- init2winit/dataset_lib/fastmri_dataset.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/init2winit/dataset_lib/fastmri_dataset.py b/init2winit/dataset_lib/fastmri_dataset.py index 574a81df..1b0ddde8 100644 --- a/init2winit/dataset_lib/fastmri_dataset.py +++ b/init2winit/dataset_lib/fastmri_dataset.py @@ -45,7 +45,7 @@ # two, 100 files for validation, 99 for test. This amounts to 3,554 slices # for validation, 3,581 for test. test_dir='knee_singlecoil_val', - test_size=3581, + test_size=3548, num_test_h5_files=99, eval_seed=0, )) @@ -212,7 +212,10 @@ def load_split(per_host_batch_size, split, hps, shuffle_rng=None): # entirely to the end of it on the last host, because otherwise we will drop # the last `{train,valid}_size % split_size` elements. if jax.process_index() == jax.process_count() - 1: - end = -1 + if split == 'val': + end = hps.num_valid_h5_files + else: + end = -1 data_dir = hps.data_dir @@ -229,7 +232,6 @@ def load_split(per_host_batch_size, split, hps, shuffle_rng=None): h5_paths = [ os.path.join(data_dir, path) for path in listdir(data_dir) ][start:end] - ds = tf.data.Dataset.from_tensor_slices(h5_paths) ds = ds.interleave( _create_generator,