Skip to content

Commit

Permalink
Add a 3-stage PP config
Browse files Browse the repository at this point in the history
Pipelining is unique in that there is no need to stick to power-of-2
numbers of stages, and there maybe reasons an odd number is optimal
depending on how you divide up your cluster.

Anyway, I use this for validation of the 1f1b schedule in a slightly-more-complicated
than 2-stage but simpler than 4-stage setup.

seems to run fine, if run with an even batch size
(`--training.batch_size 12`)

ghstack-source-id: 289eeb8473afa84e3b767986f9fb285f1d91fbf2
Pull Request resolved: #345
  • Loading branch information
wconstab committed May 22, 2024
1 parent fadb3ab commit 82faac9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 2 deletions.
15 changes: 15 additions & 0 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ def build_test_list(args):
"""
integration_tests_flavors = defaultdict(list)
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_1f1b_3stage/",
"--experimental.pipeline_parallel_degree 3",
"--experimental.pipeline_parallel_split_points layers.1, layers.2",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 1",
],
],
"PP 1D test 1f1b with 3 PP stages",
requires_seed_checkpoint=True,
ngpu=3,
),
OverrideDefinitions(
[
[
Expand Down
4 changes: 2 additions & 2 deletions torchtitan/models/llama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
__all__ = ["Transformer"]

llama2_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16),
"debugmodel": ModelArgs(dim=256, n_layers=3, n_heads=16),
"271M": ModelArgs(dim=1024, n_layers=16, n_heads=8),
"1B": ModelArgs(dim=2048, n_layers=18, n_heads=16),
"7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
Expand All @@ -29,7 +29,7 @@
}

llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16, rope_theta=500000),
"debugmodel": ModelArgs(dim=256, n_layers=3, n_heads=16, rope_theta=500000),
"8B": ModelArgs(
dim=4096,
n_layers=32,
Expand Down
6 changes: 6 additions & 0 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ def pipeline_llama_manual(

logger.info(f"PP rank {pp_rank} is using this model chunk\n{model}")

# TODO, support this? or just guard against it inside the lib
if job_config.training.batch_size % parallel_dims.pp != 0:
raise ValueError(
f"batch_size {job_config.training.batch_size} not divisible by pp dim, currently unsupported"
)

# TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
# get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
# layers of the model that map to this stage, not the whole model.
Expand Down

0 comments on commit 82faac9

Please sign in to comment.