diff --git a/test_runner.py b/test_runner.py
index 6d706a64..16f2ed5c 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -142,7 +142,7 @@ def build_test_list():
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 4",
                     "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
-                    "--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b",
+                    "--experimental.pipeline_parallel_schedule FlexibleInterleaved1F1B",
                 ],
             ],
             "PP looped flexible 1f1b test",
@@ -265,7 +265,7 @@ def build_test_list():
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 4",
                     "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
-                    "--experimental.pipeline_parallel_schedule interleaved_1f1b",
+                    "--experimental.pipeline_parallel_schedule Interleaved1F1B",
                 ],
             ],
             "PP looped 1f1b test",
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 67c82d53..8b8a3026 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -299,14 +299,14 @@ def __init__(self):
         self.parser.add_argument(
             "--experimental.pipeline_parallel_schedule",
             type=str,
-            choices=["1f1b", "gpipe", "interleaved_1f1b", "flexible_interleaved_1f1b"],
+            choices=["1f1b", "gpipe", "Interleaved1F1B", "FlexibleInterleaved1F1B"],
             default="1f1b",
             help="""
                 Specify the Pipeline Parallel schedule to use.
                 The schedule must be compatible with the split points and stages_per_rank.

-                Looped schedules (e.g. interleaved_1f1b) require specifying pipeline_paralle_degree = number of ranks,
+                Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
                 and split_points = number of stages - 1""",
         )
         self.parser.add_argument(
diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py
index a5c61e62..b3b97e62 100644
--- a/torchtitan/parallelisms/pipelining_utils.py
+++ b/torchtitan/parallelisms/pipelining_utils.py
@@ -6,34 +6,21 @@
 from typing import Tuple

 from torch.distributed.pipelining import (
-    Schedule1F1B,
     ScheduleFlexibleInterleaved1F1B,
-    ScheduleGPipe,
     ScheduleInterleaved1F1B,
 )
+from torch.distributed.pipelining.schedules import get_schedule_class
 from torchtitan.logging import logger


 def build_pipeline_schedule(job_config, stages, loss_fn):
     looped_schedule = False

-    if job_config.experimental.pipeline_parallel_schedule == "1f1b":
-        schedule_class = Schedule1F1B
-    elif job_config.experimental.pipeline_parallel_schedule == "gpipe":
-        schedule_class = ScheduleGPipe
-    elif job_config.experimental.pipeline_parallel_schedule == "interleaved_1f1b":
-        schedule_class = ScheduleInterleaved1F1B
-        looped_schedule = True
-    elif (
+    schedule_class = get_schedule_class(
         job_config.experimental.pipeline_parallel_schedule
-        == "flexible_interleaved_1f1b"
-    ):
-        schedule_class = ScheduleFlexibleInterleaved1F1B
+    )
+    if schedule_class in [ScheduleInterleaved1F1B, ScheduleFlexibleInterleaved1F1B]:
         looped_schedule = True
-    else:
-        raise NotImplementedError(
-            f"{job_config.experimental.pipeline_parallel_schedule} is not implemented"
-        )
     logger.info(
         f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule}"
     )
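
Note on the looped-schedule constraint in the config_manager.py help text: with the test configuration above (pipeline_parallel_degree 4 and seven split points), the model is cut into eight stages, two per rank. A quick sanity check of that arithmetic, in plain Python rather than torchtitan code:

split_points = ["layers.1", "layers.2", "layers.3", "layers.4",
                "layers.5", "layers.6", "layers.7"]
pipeline_parallel_degree = 4  # number of pipeline ranks

num_stages = len(split_points) + 1  # split_points = number of stages - 1
assert num_stages % pipeline_parallel_degree == 0
stages_per_rank = num_stages // pipeline_parallel_degree
print(num_stages, stages_per_rank)  # 8 stages, 2 stages per rank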
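
And a minimal sketch of the name-to-class lookup the pipelining_utils.py refactor relies on: get_schedule_class (imported from torch.distributed.pipelining.schedules in the diff) is assumed to resolve the class-style schedule names and to raise on unknown ones, which is what makes the removed if/elif chain and its NotImplementedError redundant:

from torch.distributed.pipelining import (
    ScheduleFlexibleInterleaved1F1B,
    ScheduleInterleaved1F1B,
)
from torch.distributed.pipelining.schedules import get_schedule_class

# Resolve the CLI string to a schedule class; an unknown name is expected
# to raise inside get_schedule_class, so no explicit else branch is needed.
schedule_class = get_schedule_class("Interleaved1F1B")

# Same looped-schedule check that build_pipeline_schedule performs above:
# interleaved schedules place multiple stages on each rank.
looped_schedule = schedule_class in [
    ScheduleInterleaved1F1B,
    ScheduleFlexibleInterleaved1F1B,
]
assert looped_schedule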