diff --git a/training.py b/training.py index b043f22..a3f063d 100644 --- a/training.py +++ b/training.py @@ -132,10 +132,10 @@ def __repr__(self): def data_source_tfds(name, use_tf=True, split="all"): import tensorflow_datasets as tfds if use_tf: - def data_source(): + def data_source(path_override): return tfds.load(name, split=split, shuffle_files=True) else: - def data_source(): + def data_source(path_override): return tfds.data_source(name, split=split, try_gcs=False) return data_source @@ -874,7 +874,7 @@ def main(args): IMAGE_SIZE = args.image_size dataset_name = args.dataset - datalen = len(datasetMap[dataset_name]['source']()) + datalen = len(datasetMap[dataset_name]['source'](args.dataset_path)) batches = datalen // BATCH_SIZE # Define the configuration using the command-line arguments attention_configs = [