From 8ce811941fc8beab6c1e18b92cbe834c1b2455a9 Mon Sep 17 00:00:00 2001 From: Erfan Zare Chavoshi <59269023+erfanzar@users.noreply.github.com> Date: Tue, 28 May 2024 01:48:24 +0330 Subject: [PATCH] Update causal_language_model_training_example.py --- .../training/causal_language_model_training_example.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/training/causal_language_model_training_example.py b/examples/training/causal_language_model_training_example.py index 73bb80ba..eb329913 100644 --- a/examples/training/causal_language_model_training_example.py +++ b/examples/training/causal_language_model_training_example.py @@ -65,7 +65,7 @@ def main(): dtype = jnp.bfloat16 sharding_axis_dims = eval(FLAGS.sharding_axis_dims) - FLAGS.input_shape = eval(FLAGS.input_shape) + input_shape = eval(FLAGS.input_shape) qps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp") kps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp") vps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp") @@ -83,7 +83,7 @@ def main(): model, params = AutoEasyDeLModelForCausalLM.from_pretrained( FLAGS.pretrained_model_name_or_path, device=jax.devices('cpu')[0], - input_shape=FLAGS.input_shape, + input_shape=input_shape, device_map="auto", sharding_axis_dims=sharding_axis_dims, config_kwargs=dict( @@ -115,7 +115,7 @@ def main(): "config": config, "dtype": dtype, "param_dtype": dtype, - "input_shape": FLAGS.input_shape + "input_shape": input_shape } if tokenizer.pad_token == None: @@ -144,7 +144,7 @@ def main(): scheduler=FLAGS.scheduler, weight_decay=FLAGS.weight_decay, total_batch_size=FLAGS.total_batch_size, - init_input_shape=FLAGS.input_shape, + init_input_shape=input_shape, max_sequence_length=FLAGS.max_length, model_name=FLAGS.model_name, training_time=FLAGS.training_time, @@ -183,7 +183,7 @@ def main(): with jax.default_device(jax.devices("cpu")[0]): state = EasyDeLState.load_state( output.checkpoint_path, - input_shape=FLAGS.input_shape, + input_shape=input_shape, ) if model_use_tie_word_embedding: