Skip to content

Commit

Permalink
Update causal_language_model_training_example.py
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar authored May 27, 2024
1 parent 7e1bddb commit 8ce8119
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions examples/training/causal_language_model_training_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def main():

dtype = jnp.bfloat16
sharding_axis_dims = eval(FLAGS.sharding_axis_dims)
FLAGS.input_shape = eval(FLAGS.input_shape)
input_shape = eval(FLAGS.input_shape)
qps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp")
kps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp")
vps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp")
Expand All @@ -83,7 +83,7 @@ def main():
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
FLAGS.pretrained_model_name_or_path,
device=jax.devices('cpu')[0],
input_shape=FLAGS.input_shape,
input_shape=input_shape,
device_map="auto",
sharding_axis_dims=sharding_axis_dims,
config_kwargs=dict(
Expand Down Expand Up @@ -115,7 +115,7 @@ def main():
"config": config,
"dtype": dtype,
"param_dtype": dtype,
"input_shape": FLAGS.input_shape
"input_shape": input_shape
}

if tokenizer.pad_token == None:
Expand Down Expand Up @@ -144,7 +144,7 @@ def main():
scheduler=FLAGS.scheduler,
weight_decay=FLAGS.weight_decay,
total_batch_size=FLAGS.total_batch_size,
init_input_shape=FLAGS.input_shape,
init_input_shape=input_shape,
max_sequence_length=FLAGS.max_length,
model_name=FLAGS.model_name,
training_time=FLAGS.training_time,
Expand Down Expand Up @@ -183,7 +183,7 @@ def main():
with jax.default_device(jax.devices("cpu")[0]):
state = EasyDeLState.load_state(
output.checkpoint_path,
input_shape=FLAGS.input_shape,
input_shape=input_shape,
)

if model_use_tie_word_embedding:
Expand Down

0 comments on commit 8ce8119

Please sign in to comment.