Skip to content

Commit

Permalink
fix: some more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Aug 2, 2024
1 parent 3a21c70 commit 37041b3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
6 changes: 3 additions & 3 deletions datasets/gcsfuse.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ fi

if [[ -d ${MOUNT_PATH} ]]; then
echo "$MOUNT_PATH exists, removing..."
sudo fusermount -u $MOUNT_PATH || rm -rf $MOUNT_PATH
fusermount -u $MOUNT_PATH || rm -rf $MOUNT_PATH
fi

sudo mkdir -p $MOUNT_PATH
mkdir -p $MOUNT_PATH

# See https://cloud.google.com/storage/docs/gcsfuse-cli for all configurable options of the gcsfuse CLI.
# Grain uses _PROCESS_MANAGEMENT_MAX_THREADS = 64 (https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py).
# Please make sure max-conns-per-host > grain_worker_count * _PROCESS_MANAGEMENT_MAX_THREADS (gcsfuse treats a value of 0 as "no limit").

sudo gcsfuse -o ro --implicit-dirs --http-client-timeout=5s --max-conns-per-host=0 --max-idle-conns-per-host=10000 \
gcsfuse -o ro --implicit-dirs --http-client-timeout=5s --max-conns-per-host=0 --max-idle-conns-per-host=10000 \
--experimental-enable-json-read --kernel-list-cache-ttl-secs=-1 -o ro --config-file=$HOME/gcsfuse.yml \
--log-file=$HOME/gcsfuse.json "$DATASET_GCS_BUCKET" "$MOUNT_PATH"
11 changes: 7 additions & 4 deletions training.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ def map(self, element) -> Dict[str, jnp.array]:
# CC12m and other GCS data sources -------------------------------------------------------------#
# -----------------------------------------------------------------------------------------------#

def data_source_gcs(source="/mnt/gcs_mount/arrayrecord/cc12m/"):
def data_source():
def data_source_gcs():
def data_source(source="/mnt/gcs_mount/arrayrecord/cc12m/"):
records_path = source
records = [os.path.join(records_path, i) for i in os.listdir(
records_path) if 'array_record' in i]
Expand Down Expand Up @@ -235,7 +235,7 @@ def map(self, element) -> Dict[str, jnp.array]:
"augmenter": tfds_augmenters,
},
"cc12m": {
"source": data_source_gcs("/mnt/gcs_mount/arrayrecord/cc12m/"),
"source": data_source_gcs(),
"augmenter": gcs_augmenters,
},
}
Expand All @@ -254,9 +254,10 @@ def get_dataset_grain(
grain_read_buffer_size=50,
grain_worker_buffer_size=20,
seed=0,
dataset_source="/mnt/gcs_mount/arrayrecord/cc12m/",
):
dataset = datasetMap[data_name]
data_source = dataset["source"]()
data_source = dataset["source"](dataset_source)
augmenter = dataset["augmenter"](image_scale, method)

local_batch_size = batch_size // jax.process_count()
Expand Down Expand Up @@ -806,6 +807,8 @@ def fit(self, data, steps_per_epoch, epochs):
default=None, help='Steps per epoch')
parser.add_argument('--dataset', type=str,
default='cc12m', help='Dataset to use')
parser.add_argument('--dataset_path', type=str,
default='/home/mrwhite0racle/gcs_mount', help="Dataset location path")

parser.add_argument('--learning_rate', type=float,
default=2e-4, help='Learning rate')
Expand Down

0 comments on commit 37041b3

Please sign in to comment.