diff --git a/ffn/jax/train.py b/ffn/jax/train.py index 698e9af..4149747 100644 --- a/ffn/jax/train.py +++ b/ffn/jax/train.py @@ -432,7 +432,7 @@ def train_and_evaluate( train_state_path, args=ocp.args.StandardRestore(state) ) checkpointed_state['train_iter'] = iter_handler.restore( - train_iter_path, args + train_iter_path, args=args ) logging.info('Initializing training from %r', config.init_from_cpoint) elif latest_step is not None: