From 638ec48941292914075222265fa0d8aadd95df99 Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Wed, 22 May 2024 09:26:54 -0700
Subject: [PATCH] Fix bug in PP output layer shape

This was a mostly harmless bug: the output shape of the last layer is not
used for send/recv purposes, so the runtime value overrides whatever value
it was configured with. However, now that in/out shape validation has been
added to the pipeline library in torch, the incorrect shape raises an error
and has to be fixed.

ghstack-source-id: 950e41529b7b506085ab280d8a492e345eaefd24
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/354
---
 torchtitan/parallelisms/parallelize_llama.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 5c69ac4e..894d97f0 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -209,7 +209,11 @@ def pipeline_llama_manual(
     batch_size = job_config.training.batch_size
     local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
     layers_io_shape = (batch_size, local_seq_len, model_config.dim)
-    output_layer_shape = (batch_size, local_seq_len, model_config.vocab_size)
+    output_layer_shape = (
+        batch_size,
+        job_config.training.seq_len,
+        model_config.vocab_size,
+    )
     if pp_rank == 0:
         # first layer
         input = torch.randint(
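
Note (illustrative, not part of the patch): the shape relationship the new
in/out validation checks can be sketched in a small standalone snippet. The
sizes and the toy logits tensor below are made up for illustration; only the
distinction between the TP-local sequence length and the full sequence
length mirrors the code being patched.

    import torch

    # Illustrative sizes only; real values come from torchtitan's job config.
    batch_size, seq_len, tp_degree = 2, 16, 4
    dim, vocab_size = 32, 100

    # Transformer-layer activations are sequence-sharded across TP ranks,
    # so the per-stage send/recv shape uses the local sequence length.
    local_seq_len = seq_len // tp_degree
    layers_io_shape = (batch_size, local_seq_len, dim)

    # The last stage emits logits over the full sequence, which is what the
    # patch encodes by using the global seq_len instead of local_seq_len.
    output_layer_shape = (batch_size, seq_len, vocab_size)        # fixed shape
    buggy_output_shape = (batch_size, local_seq_len, vocab_size)  # pre-fix shape

    # A toy stand-in for the last stage's runtime output.
    logits = torch.empty(output_layer_shape)

    assert tuple(logits.shape) == output_layer_shape  # matches: validation passes
    assert tuple(logits.shape) != buggy_output_shape  # the mismatch that shape
                                                      # validation would now flag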