diff --git a/init2winit/dataset_lib/imagenet_dataset.py b/init2winit/dataset_lib/imagenet_dataset.py index 7ea7a650..4b826653 100644 --- a/init2winit/dataset_lib/imagenet_dataset.py +++ b/init2winit/dataset_lib/imagenet_dataset.py @@ -442,7 +442,7 @@ def get_imagenet(shuffle_rng, batch_size, eval_batch_size, hps, global_step=0): 'test', hps=hps, image_size=image_size, - tfds_dataset_name='imagenet_v2/matched-frequency', + tfds_dataset_name='imagenet_v2/matched-frequency:3.0.0', global_step=global_step, ) test_ds = tfds.as_numpy(test_ds)