diff --git a/datasets/gcsfuse.sh b/datasets/gcsfuse.sh index 5e5e9f2..de6235f 100755 --- a/datasets/gcsfuse.sh +++ b/datasets/gcsfuse.sh @@ -38,15 +38,15 @@ fi if [[ -d ${MOUNT_PATH} ]]; then echo "$MOUNT_PATH exists, removing..." - sudo fusermount -u $MOUNT_PATH || rm -rf $MOUNT_PATH + fusermount -u $MOUNT_PATH || rm -rf $MOUNT_PATH fi -sudo mkdir -p $MOUNT_PATH +mkdir -p $MOUNT_PATH # see https://cloud.google.com/storage/docs/gcsfuse-cli for all configurable options of gcsfuse CLI # Grain uses _PROCESS_MANAGEMENT_MAX_THREADS = 64 (https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py) # Please make sure max-conns-per-host > grain_worker_count * _PROCESS_MANAGEMENT_MAX_THREADS -sudo gcsfuse -o ro --implicit-dirs --http-client-timeout=5s --max-conns-per-host=0 --max-idle-conns-per-host=10000 \ +gcsfuse -o ro --implicit-dirs --http-client-timeout=5s --max-conns-per-host=0 --max-idle-conns-per-host=10000 \ --experimental-enable-json-read --kernel-list-cache-ttl-secs=-1 -o ro --config-file=$HOME/gcsfuse.yml \ --log-file=$HOME/gcsfuse.json "$DATASET_GCS_BUCKET" "$MOUNT_PATH" \ No newline at end of file diff --git a/training.py b/training.py index a7ab86c..c73dcbe 100644 --- a/training.py +++ b/training.py @@ -172,8 +172,8 @@ def map(self, element) -> Dict[str, jnp.array]: # CC12m and other GCS data sources -------------------------------------------------------------# # -----------------------------------------------------------------------------------------------# -def data_source_gcs(source="/mnt/gcs_mount/arrayrecord/cc12m/"): - def data_source(): +def data_source_gcs(): + def data_source(source="/mnt/gcs_mount/arrayrecord/cc12m/"): records_path = source records = [os.path.join(records_path, i) for i in os.listdir( records_path) if 'array_record' in i] @@ -235,7 +235,7 @@ def map(self, element) -> Dict[str, jnp.array]: "augmenter": tfds_augmenters, }, "cc12m": { - "source": data_source_gcs("/mnt/gcs_mount/arrayrecord/cc12m/"), + "source": data_source_gcs(), "augmenter": gcs_augmenters, }, } @@ -254,9 +254,10 @@ def get_dataset_grain( grain_read_buffer_size=50, grain_worker_buffer_size=20, seed=0, + dataset_source="/mnt/gcs_mount/arrayrecord/cc12m/", ): dataset = datasetMap[data_name] - data_source = dataset["source"]() + data_source = dataset["source"](dataset_source) augmenter = dataset["augmenter"](image_scale, method) local_batch_size = batch_size // jax.process_count() @@ -806,6 +807,8 @@ def fit(self, data, steps_per_epoch, epochs): default=None, help='Steps per epoch') parser.add_argument('--dataset', type=str, default='cc12m', help='Dataset to use') +parser.add_argument('--dataset_path', type=str, + default='/home/mrwhite0racle/gcs_mount', "Dataset location path") parser.add_argument('--learning_rate', type=float, default=2e-4, help='Learning rate')