From 1ceaa4e2adc8ef5a0864f99e126e4ab18cd7db8f Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Thu, 23 May 2024 17:02:11 -0700
Subject: [PATCH] Add test for PP tracer frontend

- switch to using public PipelineStage API
- clean up some asserts in tracer codepath

ghstack-source-id: 2d069b7d45c4f3c788dec8fc85d8a7e83e463fcd
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/357
---
 test_runner.py                               | 13 ++++++
 torchtitan/parallelisms/parallelize_llama.py | 46 ++++++++++----------
 2 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/test_runner.py b/test_runner.py
index 834fc080..59bc49a4 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -113,6 +113,19 @@ def build_test_list(args):
             "PP+TP 2D test",
             requires_seed_checkpoint=True,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    f"--job.dump_folder {args.output_dir}/pp_tracer/",
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_split_points layers.1",
+                    "--model.norm_type rmsnorm",  # fused_rmsnorm not yet compatible with tracer
+                ],
+            ],
+            "PP tracer frontend test",
+            requires_seed_checkpoint=True,
+        ),
         OverrideDefinitions(
             [
                 [
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 425d3abe..3617eb23 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -18,10 +18,11 @@
     checkpoint_wrapper as ptd_checkpoint_wrapper,
     CheckpointImpl,
 )
-from torch.distributed.pipelining import pipeline, SplitPoint
-from torch.distributed.pipelining.PipelineStage import (
-    _PipelineStage,
+from torch.distributed.pipelining import (
     ManualPipelineStage,
+    pipeline,
+    PipelineStage,
+    SplitPoint,
 )
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -159,6 +160,14 @@ def _llama_trace_input(job_config, model_config, device="meta"):
     return (tokens,)
 
 
+def _mixed_precision_dtype(
+    job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
+) -> torch.dtype:
+    """Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
+    mp_arg = job_config.training.mixed_precision_param
+    return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default
+
+
 def pipeline_llama_manual(
     model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
 ):
@@ -204,8 +213,7 @@ def pipeline_llama_manual(
     # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
     # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
     # layers of the model that map to this stage, not the whole model.
-    mp_arg = job_config.training.mixed_precision_param
-    mp_dtype = TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else torch.float32
+    mp_dtype = _mixed_precision_dtype(job_config, parallel_dims)
     batch_size = job_config.training.batch_size
     local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
     layers_io_shape = (batch_size, local_seq_len, model_config.dim)
@@ -216,12 +224,7 @@ def pipeline_llama_manual(
     )
     if pp_rank == 0:
         # first layer
-        input = torch.randint(
-            model_config.vocab_size,
-            size=(batch_size, job_config.training.seq_len),
-            dtype=torch.int64,
-            device=device,
-        )
+        (input,) = _llama_trace_input(job_config, model_config, device=device)
     else:
         # later layers (assume all start w/ a transformer layer)
         input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)
@@ -257,21 +260,21 @@ def pipeline_llama_tracer(
             "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
         )
 
-    # TODO(whc) maybe we can just fix this by feeding bf16 into the tracer for its input shapes?
-    raise NotImplementedError(
-        "pipeline tracer doesn't work with fsdp mixed precision currently. "
-        "To work around, edit fsdp mixed precision config to use fp32."
-    )
+    if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
+        raise NotImplementedError(
+            "pipeline tracer doesn't work with fsdp mixed precision currently. "
+            "To work around, edit fsdp mixed precision config to use fp32."
+        )
+
     pp_mesh = world_mesh["pp"]
     pp_rank = pp_mesh.get_local_rank()
-    stage_idx = pp_mesh.get_local_rank()
+    stage_idx = pp_rank
     layers_per_rank = len(model.layers) // parallel_dims.pp
     split_spec = {
         f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
         for i in range(1, parallel_dims.pp)
     }
 
-    # Create a pipeline representation from the model
     pipe = pipeline(
         model,
         job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp,
@@ -279,10 +282,9 @@ def pipeline_llama_tracer(
         split_spec=split_spec,
     )
     model = pipe.get_stage_module(stage_idx)
-    stage = _PipelineStage(
-        stage_module=model,
-        stage_index=pp_rank,
-        pipe_info=pipe.pipe_info,
+    stage = PipelineStage(
+        pipe,
+        stage_index=stage_idx,
         device=device,
         group=pp_mesh.get_group(),
     )