Skip to content

Commit

Permalink
yet more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Aug 2, 2024
1 parent 25448cf commit 644f4bc
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions training.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ def model_loss(params):
return state, loss

if distributed_training:
train_step = jax.pmap(axis_name="device")(train_step)
train_step = jax.pmap(train_step, axis_name="device")
else:
train_step = jax.jit(train_step)

Expand Down Expand Up @@ -808,7 +808,7 @@ def fit(self, data, steps_per_epoch, epochs):
parser.add_argument('--dataset', type=str,
default='cc12m', help='Dataset to use')
parser.add_argument('--dataset_path', type=str,
default='/home/mrwhite0racle/gcs_mount', help="Dataset location path")
default='/home/mrwhite0racle/gcs_mount/arrayrecord/cc12m', help="Dataset location path")

parser.add_argument('--learning_rate', type=float,
default=2e-4, help='Learning rate')
Expand Down

0 comments on commit 644f4bc

Please sign in to comment.