diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 5c69ac4e..894d97f0 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -209,7 +209,11 @@ def pipeline_llama_manual(
     batch_size = job_config.training.batch_size
     local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
     layers_io_shape = (batch_size, local_seq_len, model_config.dim)
-    output_layer_shape = (batch_size, local_seq_len, model_config.vocab_size)
+    output_layer_shape = (
+        batch_size,
+        job_config.training.seq_len,
+        model_config.vocab_size,
+    )
     if pp_rank == 0:
         # first layer
         input = torch.randint(