Skip to content

Commit

Permalink
Update causal_language_model_training_example.py
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar authored May 27, 2024
1 parent 6ee12b8 commit 7e1bddb
Showing 1 changed file with 1 addition and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def main():

dtype = jnp.bfloat16
sharding_axis_dims = eval(FLAGS.sharding_axis_dims)
FLAGS.input_shape = eval(FLAGS.input_shape)
qps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp")
kps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp")
vps = PartitionSpec(("dp", "fsdp"), "sp", None, "tp")
Expand Down

0 comments on commit 7e1bddb

Please sign in to comment.