[dtensor][debug] Creating recipe allowing users to learn how CommDebugMode works and use visual browser #2993

recipes_source/distributed_comm_debug_mode.rst (new file, +202 lines)

Using CommDebugMode
=====================================================

**Author**: `Anshul Sinha <https://github.com/sinhaanshul>`__

Prerequisites
-------------

- Python 3.8 - 3.11
- PyTorch 2.2 or later


What is CommDebugMode and why is it useful?
--------------------------------------------
As the size of models continues to increase, users are seeking to leverage various combinations
of parallel strategies to scale up distributed training. However, the lack of interoperability
between existing solutions poses a significant challenge, primarily due to the absence of a
unified abstraction that can bridge these different parallelism strategies. To address this
issue, PyTorch has proposed `DistributedTensor (DTensor)
<https://github.com/pytorch/pytorch/issues/88838>`_,
which abstracts away the complexities of tensor communication in distributed training,
providing a seamless user experience. However, this abstraction creates a lack of transparency
that can make it challenging for users to identify and resolve issues. ``CommDebugMode``, a
Python context manager, addresses this challenge by serving as one of the primary debugging
tools for DTensors, enabling users to view when and why collective operations happen.
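
As a concrete illustration of the communication that DTensor hides, the sketch below shards a
tensor across a device mesh and performs a reduction that silently requires an all-reduce. The
four-GPU mesh and the use of ``full_tensor()`` are illustrative assumptions for this sketch (run
it under ``torchrun --nproc-per-node=4``), not part of the recipe itself:

.. code-block:: python

    import torch
    from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

    # assumes torchrun has already initialized the default process group
    mesh = DeviceMesh("cuda", list(range(4)))

    # shard the rows of a global tensor across the 4 ranks
    dtensor = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

    # summing over the sharded dimension leaves a partial result on each rank;
    # materializing the full tensor forces DTensor to issue an all-reduce,
    # exactly the kind of hidden collective that CommDebugMode surfaces
    total = dtensor.sum().full_tensor()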


How to use CommDebugMode
------------------------

Here is how you can use ``CommDebugMode``:

.. code-block:: python

    # CommDebugMode ships with DTensor's (private) debug module
    from torch.distributed._tensor.debug import CommDebugMode

    comm_mode = CommDebugMode()
    with comm_mode:
        output = model(inp)

    # print the operation-level collective tracing information
    print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))

    # log the operation-level collective tracing information to a file
    comm_mode.log_comm_debug_tracing_table_to_file(
        noise_level=1, file_name="transformer_operation_log.txt"
    )

    # dump the operation-level collective tracing information to a JSON file,
    # which is used in the visual browser below
    comm_mode.generate_json_dump(noise_level=2)
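
The three APIs above render the same underlying data in different forms. If you need the counts
programmatically (for example, in a test), the collected data can also be read back directly;
the following is a minimal sketch, assuming the ``get_total_counts`` and ``get_comm_counts``
accessors that ``CommDebugMode`` exposes in recent PyTorch builds:

.. code-block:: python

    comm_mode = CommDebugMode()
    with comm_mode:
        output = model(inp)

    # total number of collective operations recorded under the context manager
    print(comm_mode.get_total_counts())

    # per-collective breakdown, keyed by the functional collective op
    # (for example, torch.ops.c10d_functional.all_reduce)
    for op, count in comm_mode.get_comm_counts().items():
        print(f"{op}: {count}")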

.. code-block:: python

    """
    This is what the output looks like for an MLPModule at noise level 0
    Expected Output:
      Global
        FORWARD PASS
          *c10d_functional.all_reduce: 1
        MLPModule
          FORWARD PASS
            *c10d_functional.all_reduce: 1
          MLPModule.net1
          MLPModule.relu
          MLPModule.net2
            FORWARD PASS
              *c10d_functional.all_reduce: 1
    """

To use ``CommDebugMode``, wrap the code running the model in the context manager and call the API
that you want to use to display the data. You can also use the ``noise_level`` argument to control
the verbosity of the displayed information. Here is what each noise level displays:

| 0. Prints module-level collective counts
| 1. Prints DTensor operations (excluding trivial operations) and module information
| 2. Prints operations (excluding trivial operations)
| 3. Prints all operations

In the example above, you can see that the collective operation, all_reduce, occurs once in the forward pass
of the ``MLPModule``. Furthermore, you can use ``CommDebugMode`` to pinpoint that the all-reduce operation happens
in the second linear layer of the ``MLPModule``.
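
For reference, a minimal sketch of a setup that produces a trace like the one above is shown
below. The module definition, dimensions, and the colwise/rowwise parallelization plan are
illustrative assumptions (the classic tensor-parallel MLP layout, in which the row-wise second
linear layer triggers the single forward all-reduce), and the script assumes four GPUs launched
with ``torchrun``:

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.distributed._tensor import DeviceMesh
    from torch.distributed._tensor.debug import CommDebugMode
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    class MLPModule(nn.Module):
        def __init__(self, dim=16):
            super().__init__()
            self.net1 = nn.Linear(dim, dim * 4)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(dim * 4, dim)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))

    # assumes torchrun has already initialized the default process group
    mesh = DeviceMesh("cuda", list(range(4)))
    model = parallelize_module(
        MLPModule().to("cuda"),
        mesh,
        {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
    )

    comm_mode = CommDebugMode()
    with comm_mode:
        output = model(torch.randn(8, 16, device="cuda"))

    # the single all_reduce should appear under MLPModule.net2's forward pass
    print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))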


Below is the interactive module tree visualization that you can use to upload your own JSON dump:

.. raw:: html

<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>CommDebugMode Module Tree</title>
<style>
ul, #tree-container {
list-style-type: none;
margin: 0;
padding: 0;
}
.caret {
cursor: pointer;
user-select: none;
}
.caret::before {
content: "\25B6";
color:black;
display: inline-block;
margin-right: 6px;
}
.caret-down::before {
transform: rotate(90deg);
}
.tree {
padding-left: 20px;
}
.tree ul {
padding-left: 20px;
}
.nested {
display: none;
}
.active {
display: block;
}
.forward-pass,
.backward-pass {
margin-left: 40px;
}
.forward-pass table {
margin-left: 40px;
width: auto;
}
.forward-pass table td, .forward-pass table th {
padding: 8px;
}
.forward-pass ul {
display: none;
}
table {
font-family: arial, sans-serif;
border-collapse: collapse;
width: 100%;
}
td, th {
border: 1px solid #dddddd;
text-align: left;
padding: 8px;
}
tr:nth-child(even) {
background-color: #dddddd;
}
#drop-area {
position: relative;
width: 25%;
height: 100px;
border: 2px dashed #ccc;
border-radius: 5px;
padding: 0px;
text-align: center;
}
.drag-drop-block {
display: inline-block;
width: 200px;
height: 50px;
background-color: #f7f7f7;
border: 1px solid #ccc;
border-radius: 5px;
padding: 10px;
font-size: 14px;
color: #666;
cursor: pointer;
}
#file-input {
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
opacity: 0;
}
</style>
</head>
<body>
<div id="drop-area">
<div class="drag-drop-block">
<span>Drag file here</span>
</div>
<input type="file" id="file-input" accept=".json">
</div>
<div id="tree-container"></div>
<script src="https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/distributed/_tensor/debug/comm_mode_broswer_visual.js"></script>
</body>
</html>

Conclusion
------------------------------------------

In this recipe, we have learned how to use ``CommDebugMode`` to debug Distributed Tensors. You can
also upload your own JSON dumps to the embedded visual browser above.

For more detailed information about ``CommDebugMode``, see
`comm_mode_features_example.py
<https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/examples/comm_mode_features_example.py>`_.
recipes_source/recipes_index.rst (+8 lines)

.. customcarditem::
:header: Getting Started with CommDebugMode
:card_description: Learn how to use CommDebugMode for DTensors
:image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
:link: ../recipes/distributed_comm_debug_mode.html
:tags: Distributed-Training

/recipes/cuda_rpc
/recipes/distributed_optim_torchscript
/recipes/mobile_interpreter
/recipes/distributed_comm_debug_mode