Skip to content

Commit

Permalink
Update TfGrainCheckpointHandler.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686405738
  • Loading branch information
Connectomics Team authored and copybara-github committed Oct 17, 2024
1 parent b9d4ca5 commit 0286a94
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion ffn/jax/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def train_and_evaluate(
train_state_path, args=ocp.args.StandardRestore(state)
)
checkpointed_state['train_iter'] = iter_handler.restore(
train_iter_path, args
train_iter_path, args=args
)
logging.info('Initializing training from %r', config.init_from_cpoint)
elif latest_step is not None:
Expand Down

0 comments on commit 0286a94

Please sign in to comment.