Skip to content

Commit

Permalink
FastMRI dataset split fix:
Browse files Browse the repository at this point in the history
- Fix end indexing for validation set.
- Correct number of test examples.

PiperOrigin-RevId: 573916348
  • Loading branch information
priyakasimbeg authored and copybara-github committed Oct 20, 2023
1 parent ae878a7 commit 960c2e1
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions init2winit/dataset_lib/fastmri_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
# two, 100 files for validation, 99 for test. This amounts to 3,554 slices
# for validation, 3,581 for test.
test_dir='knee_singlecoil_val',
test_size=3581,
test_size=3548,
num_test_h5_files=99,
eval_seed=0,
))
Expand Down Expand Up @@ -212,7 +212,10 @@ def load_split(per_host_batch_size, split, hps, shuffle_rng=None):
# entirely to the end of it on the last host, because otherwise we will drop
# the last `{train,valid}_size % split_size` elements.
if jax.process_index() == jax.process_count() - 1:
end = -1
if split == 'val':
end = hps.num_valid_h5_files
else:
end = -1

data_dir = hps.data_dir

Expand All @@ -229,7 +232,6 @@ def load_split(per_host_batch_size, split, hps, shuffle_rng=None):
h5_paths = [
os.path.join(data_dir, path) for path in listdir(data_dir)
][start:end]

ds = tf.data.Dataset.from_tensor_slices(h5_paths)
ds = ds.interleave(
_create_generator,
Expand Down

0 comments on commit 960c2e1

Please sign in to comment.