diff --git a/.github/workflows/integration_test_4gpu.yaml b/.github/workflows/integration_test_4gpu.yaml
index a82bd96c..993ae797 100644
--- a/.github/workflows/integration_test_4gpu.yaml
+++ b/.github/workflows/integration_test_4gpu.yaml
@@ -38,6 +38,5 @@ jobs:
           pip config --user set global.progress_bar off
 
           python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
-          USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
           mkdir artifacts-to-be-uploaded
           python ./test_runner.py artifacts-to-be-uploaded --ngpu 4
diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py
index 26645330..31527452 100644
--- a/torchtitan/models/norms.py
+++ b/torchtitan/models/norms.py
@@ -284,7 +284,6 @@ def backward(ctx, dy):
         M, N = dy.shape
 
         dx = torch.empty_like(x)
-        dw = torch.empty_like(weight)
 
         sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
         _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py
index a5c61e62..28640bc1 100644
--- a/torchtitan/parallelisms/pipelining_utils.py
+++ b/torchtitan/parallelisms/pipelining_utils.py
@@ -41,6 +41,18 @@ def build_pipeline_schedule(job_config, stages, loss_fn):
     if n_microbatches is None:
         n_microbatches = job_config.experimental.pipeline_parallel_degree
 
+    # Validation that the stages are compatible with the schedule
+    if isinstance(schedule_class, PipelineScheduleSingle):
+        if len(stages) != 1:
+            raise ValueError(
+                f"PipelineScheduleSingle requires exactly one stage, got {len(stages)}"
+            )
+    elif isinstance(schedule_class, PipelineScheduleMulti):
+        if len(stages) < 2:
+            raise ValueError(
+                f"PipelineScheduleMulti requires at least two stages, got {len(stages)}"
+            )
+
     return schedule_class(
         stages if looped_schedule else stages[0],
         n_microbatches=n_microbatches,
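
Note on the pipelining_utils.py hunk: the new check rejects stage lists that cannot be driven by the chosen schedule, distinguishing single-stage schedules (one pipeline stage per rank, e.g. GPipe or 1F1B) from looped multi-stage schedules (several stages per rank, e.g. interleaved 1F1B). Below is a minimal, self-contained sketch of that idea; ScheduleSingle, ScheduleMulti, and validate_stages are illustrative stand-ins rather than torchtitan's or torch.distributed.pipelining's actual API, and because schedule_class is a class object (it is instantiated in the return statement), the sketch performs the check with issubclass.

# Illustrative sketch of the stage/schedule compatibility validation above.
# ScheduleSingle / ScheduleMulti are hypothetical stand-ins for the
# PipelineScheduleSingle / PipelineScheduleMulti base classes.

class ScheduleSingle:        # schedules that run exactly one stage per rank
    pass

class ScheduleMulti:         # "looped" schedules that run several stages per rank
    pass

def validate_stages(schedule_class, stages):
    """Fail fast if the number of stages built on this rank does not match
    what the chosen schedule class can execute."""
    # schedule_class is a class, not an instance, so issubclass is used here.
    if issubclass(schedule_class, ScheduleSingle) and len(stages) != 1:
        raise ValueError(
            f"single-stage schedule requires exactly one stage, got {len(stages)}"
        )
    if issubclass(schedule_class, ScheduleMulti) and len(stages) < 2:
        raise ValueError(
            f"multi-stage schedule requires at least two stages, got {len(stages)}"
        )

validate_stages(ScheduleSingle, stages=["stage0"])      # passes
# validate_stages(ScheduleMulti, stages=["stage0"])     # would raise ValueError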