From 96806b177702dec48ea63837ff890cb2479674de Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Thu, 6 Jun 2024 21:06:05 -0700 Subject: [PATCH] [pipelining][doc] Add frontend description and change tracer example (#128070) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128070 Approved by: https://github.com/wconstab, https://github.com/H-Huang --- docs/source/distributed.pipelining.rst | 362 ++++++++++++++++--------- torch/distributed/pipelining/_IR.py | 15 +- 2 files changed, 245 insertions(+), 132 deletions(-) diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index 2f4218a0d9808..48f66b5d3276c 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -4,184 +4,314 @@ Pipeline Parallelism #################### -.. note:: ``torch.distributed.pipelining`` is a package migrated from the `PiPPy `_ project. It is currently in alpha state and under extensive development. For examples that work with our APIs, please refer to PiPPy's `examples `_ directory. +.. note:: + ``torch.distributed.pipelining`` is currently in alpha state and under + development. API changes may be possible. It was migrated from the `PiPPy + `_ project. + Why Pipeline Parallel? ********************** -One of the most important techniques for advancing the state of the art in deep learning is scaling. Common techniques for scaling neural networks include *data parallelism*, *tensor/operation parallelism*, and *pipeline parallelism* (or *pipelining*). Pipelining is a technique in which the *code* of the model is partitioned and multiple *micro-batches* execute different parts of the model code concurrently. In many cases, pipeline parallelism can be an effective technique for scaling, in particular for large-scale jobs or bandwidth-limited interconnects. To learn more about pipeline parallelism in deep learning, see `this article `_. - -What is ``torch.distributed.pipelining``? -***************************************** +Pipeline Parallelism is one of the **primitive** parallelism for deep learning. +It allows the **execution** of a model to be partitioned such that multiple +**micro-batches** can execute different parts of the model code concurrently. +Pipeline parallelism can be an effective technique for: -.. automodule:: torch.distributed.pipelining +* large-scale training +* bandwidth-limited clusters +* large model inference. -.. currentmodule:: torch.distributed.pipelining +The above scenarios share a commonality that the computation per device cannot +hide the communication of conventional parallelism, for example, the weight +all-gather of FSDP. -While promising for scaling, pipelining is often difficult to implement, requiring intrusive code changes to model code and difficult-to-implement runtime orchestration code. ``torch.distributed.pipelining`` aims to provide **a toolkit that does said things automatically to allow high-productivity scaling of models.** It consists of a **compiler** and a **runtime** stack for easy pipelining of PyTorch models. In particular, it provides the following features: -* Splitting of model code based on your specification. The goal is for the user to provide model code as-is to the system for parallelization, without having to make heavyweight modifications to make parallelism work. The specification is also simple. -* Support for rich pipeline scheduling paradigms, including GPipe, 1F1B, Interleaved 1F1B and Looped BFS. It will be also easy to customize your own schedule under this framework. -* First-class support for cross-host pipeline parallelism, as this is where PP is typically used (over slower interconnects). -* Composability with other PyTorch parallel schemes such as data parallelism (DDP, FSDP) or tensor parallelism (overall, known as "3d parallelism"). +What is ``torch.distributed.pipelining``? +***************************************** -Examples -******** +While promising for scaling, pipelining is often difficult to implement because +it needs to **partition the execution** of a model in addition to model weights. +The partitioning of execution often requires intrusive code changes to your +model. Another aspect of complexity comes from **scheduling micro-batches in a +distributed environment**, with **data flow dependency** considered. -In the `PiPPy `_ repo where this package is migrated from, we provide rich examples based on realistic models. In particular, we show how to apply pipelining without any model code change. You can refer to the `HuggingFace examples directory `_. Popular examples include: `GPT2 `_, and `LLaMA `_. +The ``pipelining`` package provides a toolkit that does said things +**automatically** which allows easy implementation of pipeline parallelism +on **general** models. -Techniques Explained -******************** +It consists of two parts: a +**splitting frontend** and a **distributed runtime**. +The splitting frontend takes your model code as-is, splits it up into "model +partitions", and capture the data-flow relationship. The distributed runtime +executes the pipeline stages on different devices in parallel, handling things +like micro-batch splitting, scheduling, communication, and gradient propagation, +etc. -``torch.distributed.pipelining`` consists of two parts: a *compiler* and a *runtime*. The compiler takes your model code, splits it up, and transforms it into a ``Pipe``, which is a wrapper that describes the model at each pipeline stage and their data-flow relationship. The runtime executes the ``PipelineStage`` in parallel, handling things like micro-batch splitting, scheduling, communication, and gradient propagation, etc. We will cover the APIs for these concepts in this section. +Overall, the ``pipelining`` package provides the following features: -Splitting a Model with ``pipeline`` -=================================== +* Splitting of model code based on simple specification. The goal is to make + parallelism work for your model with **zero model code change**. +* Rich support for pipeline schedules, including GPipe, 1F1B, + Interleaved 1F1B and Looped BFS, and provide the infrastruture for writing + customized schedules. +* First-class support for cross-host pipeline parallelism, as this is where PP + is typically used (over slower interconnects). +* Composability with other PyTorch parallel techniques such as data parallel + (DDP, FSDP) or tensor parallel. The `TorchTitan + `_ project demonstrates a "3D parallel" + application on the Llama model. -To see how we can split a model into a pipeline, let's first take an example trivial neural network: -.. code-block:: python +Step 1: choose the frontend that fits your need +*********************************************** - import torch +The ``pipelining`` package provides two frontends for two different use cases. +You can make your choice based on whether you have: - class MyNetworkBlock(torch.nn.Module): - def __init__(self, in_dim, out_dim): - super().__init__() - self.lin = torch.nn.Linear(in_dim, out_dim) +* a full model, or +* module constructor for each stage. - def forward(self, x): - x = self.lin(x) - x = torch.relu(x) - return x +Frontend 1: the ``pipeline`` API -- if you have a full model +============================================================ - class MyNetwork(torch.nn.Module): - def __init__(self, in_dim, layer_dims): - super().__init__() +If you have a full model and do not want to spend time on modifying it into a +sequence of "model partitions", the ``pipeline`` API is here to help. +Here is a brief example: - prev_dim = in_dim - for i, dim in enumerate(layer_dims): - setattr(self, f'layer{i}', MyNetworkBlock(prev_dim, dim)) - prev_dim = dim - - self.num_layers = len(layer_dims) - # 10 output classes - self.output_proj = torch.nn.Linear(layer_dims[-1], 10) - - def forward(self, x): - for i in range(self.num_layers): - x = getattr(self, f'layer{i}')(x) +.. code-block:: python - return self.output_proj(x) + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.emb = torch.nn.Embedding(10, 3) + self.layers = torch.nn.ModuleList( + Layer() for _ in range(2) + ) + self.lm = LMHead() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.emb(x) + for layer in self.layers: + x = layer(x) + x = self.lm(x) + return x - in_dim = 512 - layer_dims = [512, 1024, 256] - mn = MyNetwork(in_dim, layer_dims).to(device) +If we print the model, we can see multiple hierarchies, which makes it hard to split by hand:: -This network is written as free-form Python code; it has not been modified for any specific parallelism technique. + Model( + (emb): Embedding(10, 3) + (layers): ModuleList( + (0-1): 2 x Layer( + (lin): Linear(in_features=3, out_features=3, bias=True) + ) + ) + (lm): LMHead( + (proj): Linear(in_features=3, out_features=3, bias=True) + ) + ) -Let us see our usage of the ``pipeline`` interface: +Let us see how the ``pipeline`` API works: .. code-block:: python - from torch.distributed.pipelining import annotate_split_points, pipeline, Pipe, SplitPoint + from torch.distributed.pipelining import pipeline, SplitPoint - annotate_split_points(mn, {'layer0': SplitPoint.END, - 'layer1': SplitPoint.END}) - - batch_size = 32 - example_input = torch.randn(batch_size, in_dim, device=device) - chunks = 4 + x = torch.LongTensor([1, 2, 4, 5]) + pipe = pipeline( + module=mod, + num_chunks=1, + example_args=(x,), + split_spec={ + "layers.1": SplitPoint.BEGINNING, + } + ) - pipe = pipeline(mn, chunks, example_args=(example_input,)) - print(pipe) +The ``pipeline`` API splits your model given a ``split_spec``, where +``SplitPoint.BEGINNING`` stands for adding a split point +*before* execution of certain submodule in the ``forward`` function, and +similarly, ``SplitPoint.END`` for split point *after* such. -:: +If we ``print(pipe)``, we can see:: - ************************************* pipe ************************************* GraphModule( (submod_0): GraphModule( - (layer0): InterpreterModule( - (lin): InterpreterModule() + (emb): InterpreterModule() + (layers): Module( + (0): InterpreterModule( + (lin): InterpreterModule() + ) ) ) (submod_1): GraphModule( - (layer1): InterpreterModule( - (lin): InterpreterModule() + (layers): Module( + (1): InterpreterModule( + (lin): InterpreterModule() + ) ) - ) - (submod_2): GraphModule( - (layer2): InterpreterModule( - (lin): InterpreterModule() + (lm): InterpreterModule( + (proj): InterpreterModule() ) - (output_proj): InterpreterModule() ) ) - def forward(self, arg8_1): - submod_0 = self.submod_0(arg8_1); arg8_1 = None + def forward(self, x): + submod_0 = self.submod_0(x); x = None submod_1 = self.submod_1(submod_0); submod_0 = None - submod_2 = self.submod_2(submod_1); submod_1 = None - return (submod_2,) - -So what's going on here? First, ``pipeline`` turns our model into a directed acyclic graph (DAG) by tracing the model. Then, it groups together the operations and parameters into *pipeline stages*. Stages are represented as ``submod_N`` submodules, where ``N`` is a natural number. + return (submod_1,) -We used ``annotate_split_points`` to specify that the code should be split and the end of ``layer0`` and ``layer1``. Our code has thus been split into *three* pipeline stages. Our library also provides ``SplitPoint.BEGINNING`` if a user wants to split before certain annotation point. -While the ``annotate_split_points`` API gives users a way to specify the split points without modifying the model, our library also provides an API for in-model annotation: ``pipe_split()``. For details, you can read `this example `_. +The "model partitions" are represented by submodules (``submod_0``, +``submod_1``), each of which is reconstructed with original model operations +and hierarchies. In addition, a "root-level" ``forward`` function is +reconstructed to capture the data flow between those partitions. Such data flow +will be replayed by the pipeline runtime later, in a distributed fashion. -This covers the basic usage of the ``Pipe`` API. For more information, please see the documentation. +The ``Pipe`` object provides a method for retrieving the "model partitions": -Using ``PipelineSchedule`` for Execution -======================================== +.. code-block:: python -After transforming the model into a ``Pipe`` representation, we can run its stages in a distributed *runtime*. This can be done in two steps: -* instantiate a ``PipelineStage`` from a stage module of ``Pipe``; -* run the ``PipelineStage`` according to a ``PipelineSchedule``. + stage_mod : nn.Module = pipe.get_stage_module(stage_idx) -First off, let us instantiate a ``PipelineStage`` instance: +You can also create a distributed stage runtime on a device using ``Pipe``: .. code-block:: python - # We are using `torchrun` to run this example with multiple processes. - # `torchrun` defines two environment variables: `RANK` and `WORLD_SIZE`. - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) + from torch.distributed.pipelining import PipelineStage - # Initialize distributed environment - import torch.distributed as dist - dist.init_process_group(rank=rank, world_size=world_size) + stage = PipelineStage(pipe, stage_idx, device) - # Pipeline stage is our main pipeline runtime. It takes in the pipe object, - # the rank of this process, and the device. - from torch.distributed.pipelining import PipelineStage - stage = PipelineStage(pipe, rank, device) +.. note:: + The ``pipeline`` frontend uses a tracer (``torch.export``) to capture your + model into a single graph. If your model is not full-graph'able, you can use + our manual frontend below. + +Frontend 2: ``ManualPipelineStage`` -- if you already have module for each stage +================================================================================ -We can now attach the ``PipelineStage`` to a pipeline schedule, GPipe for example, and run with data: +If you already have the module for each stage, you can skip the pipeline split +step above and directly connect to our runtime offering: ``ManualPipelineStage``. +The ``ManualPipelineStage`` wraps your stage module given a distributed context, +i.e. a ``ProcessGroup`` along the pipeline dimension. + +TODO: manual example here + + +Step 2: use ``PipelineSchedule`` for execution +********************************************** + +We can now attach the ``PipelineStage`` to a pipeline schedule, and run the +schedule with input data. Here is a GPipe example: .. code-block:: python from torch.distributed.pipelining import ScheduleGPipe - schedule = ScheduleGPipe(stage, chunks) - # Input data + # Create a schedule + schedule = ScheduleGPipe(stage, n_microbatches) + + # Input data (whole batch) x = torch.randn(batch_size, in_dim, device=device) - # Run the pipeline with input `x`. Divide the batch into 4 micro-batches - # and run them in parallel on the pipeline + # Run the pipeline with input `x` + # `x` will be divided into microbatches automatically if rank == 0: schedule.step(x) else: output = schedule.step() -Note that since we split our model into three stages, we must run this script with three workers. For this example, we will use ``torchrun`` to run multiple processes within a single machine for demonstration purposes. We can collect up all of the code blocks above into a file named `example.py `_ and then run it with ``torchrun`` like so: +Note that the above code needs to be launched for each worker, thus we use a +launcher service to launch multiple processes: .. code-block:: bash - torchrun --nproc_per_node=3 example.py + torchrun --nproc_per_node=2 example.py + + +Hugging Face Examples +********************* + +In the `PiPPy `_ repo where this package was +original created, we kept examples based on unmodified Hugging Face models. +See the `examples/huggingface +`_ directory. + +Examples include: + +* `GPT2 `_ +* `Llama `_ + + +Technical Deep Dive +******************* + +How does the ``pipeline`` API split a model? +============================================ + +First, the ``pipeline`` API turns our model into a directed acyclic graph (DAG) +by tracing the model. It traces the model using ``torch.export`` -- a PyTorch 2 +full-graph capturing tool. + +Then, it groups together the **operations and parameters** needed by a stage +into a reconstructed submodule: ``submod_0``, ``submod_1``, ... + +Different from conventional submodule access methods like ``Module.children()``, +the ``pipeline`` API does not only cut the module structure of your model, but +also the **forward** function of your model. + +This is necessary because model structure like ``Module.children()`` merely +captures information during ``Module.__init__()``, and does not capture any +information about ``Module.forward()``. Said differently, ``Module.children()`` +lacks information about the following aspects key to pipelininig: -Pipeline Transformation APIs +* Exectuion order of child modules in ``forward`` +* Activation flows between child modules +* Whether there are any functional operators between child modules (for example, + ``relu`` or ``add`` operations will not be captured by ``Module.children()``). + +The ``pipeline`` API, on the contrary, makes sure that the ``forward`` behavior +is truly preserved. It also captures the activation flow between the partitions, +helping the distributed runtime to make correct send/receive calls without human +intervention. + +Another flexibility of the ``pipeline`` API is that split points can be at +arbitrary hierarchy of your model. In the split partitions, the original model +hierarchy related to that partition will be reconstructed at no cost of yours. +At a result, fully-qualified names (FQNs) pointing to a submodule or parameter +would be still valid, and services that relies on FQNs (such as FSDP, TP or +checkpointing) can still run with your partitioned modules with almost zero code +change. + + +Implementing Your Own Schedule +****************************** + +You can implement your own pipeline schedule by extending one of the following two class: + +* ``PipelineScheduleSingle`` +* ``PipelineScheduleMulti`` + +``PipelineScheduleSingle`` is for schedules that assigns *only one* stage per rank. +``PipelineScheduleMulti`` is for schedules that assigns multiple stages per rank. + +For example, ``ScheduleGPipe`` and ``Schedule1F1B`` are subclasses of ``PipelineScheduleSingle``. +Whereas, ``ScheduleInterleaved1F1B`` and ``ScheduleLoopedBFS`` are subclasses of ``PipelineScheduleMulti``. + +.. currentmodule:: torch.distributed.pipelining.PipelineSchedule + +.. autoclass:: PipelineScheduleSingle + +.. autoclass:: PipelineScheduleMulti + + +API Reference +************* + +.. automodule:: torch.distributed.pipelining + +Model Split APIs ============================ The following set of APIs transform your model into a pipeline representation. @@ -240,23 +370,3 @@ Pipeline Schedules .. autoclass:: ScheduleInterleaved1F1B .. autoclass:: ScheduleLoopedBFS - -Implementing Your Own Schedule -============================== - -You can implement your own pipeline schedule by extending one of the following two class: - -* ``PipelineScheduleSingle`` -* ``PipelineScheduleMulti`` - -``PipelineScheduleSingle`` is for schedules that assigns *only one* stage per rank. -``PipelineScheduleMulti`` is for schedules that assigns multiple stages per rank. - -For example, ``ScheduleGPipe`` and ``Schedule1F1B`` are subclasses of ``PipelineScheduleSingle``. -Whereas, ``ScheduleInterleaved1F1B`` and ``ScheduleLoopedBFS`` are subclasses of ``PipelineScheduleMulti``. - -.. currentmodule:: torch.distributed.pipelining.PipelineSchedule - -.. autoclass:: PipelineScheduleSingle - -.. autoclass:: PipelineScheduleMulti diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index c7ea787f98b57..0a45c4459f305 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -487,17 +487,17 @@ def _direct_serialization_reduce(self): class Pipe(torch.nn.Module): # Class variables - """ - args_chunk_spec: - Chunking specification for positional inputs. (default: `None`) - kwargs_chunk_spec: - Chunking specification for keyword inputs. (default: `None`) - """ # args_chunk_spec and kwargs_chunk_spec are used to specify how to chunk # inputs. They are used to create microbatched examples before tracing. # See context managers `ArgsChunkSpec` and `KwargsChunkSpec`. # TODO: Do we need to support `_Replicate`? It's unclear, dropping for now. + + # args_chunk_spec: + # Chunking specification for positional inputs. (default: `None`) args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None + + # kwargs_chunk_spec: + # Chunking specification for keyword inputs. (default: `None`) kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None @dataclass @@ -622,6 +622,9 @@ def forward(self, *args, **kwargs): return res def get_stage_module(self, stage_idx: int) -> torch.nn.Module: + """ + Return a stage module corresponding to `stage_idx` of the `pipe`. + """ if stage_idx < 0 or stage_idx >= self.num_stages: raise ValueError(f"Invalid stage index {stage_idx}!") return getattr(self.split_gm, f"submod_{stage_idx}")