[dtensor][debug] Creating recipe allowing users to learn how CommDebugMode works and use visual browser #2993

197 changes: 197 additions & 0 deletions recipes_source/distributed_comm_debug_mode.rst
Using CommDebugMode
=====================================================

**Author**: `Anshul Sinha <https://github.com/sinhaanshul>`__

Prerequisites:

- `Distributed Communication Package - torch.distributed <https://pytorch.org/docs/stable/distributed.html>`__
- Python 3.8 - 3.11
- PyTorch 2.2


What is CommDebugMode and why is it useful
------------------------------------------

As the size of models continues to increase, users are seeking to leverage various combinations
of parallel strategies to scale up distributed training. However, the lack of interoperability
between existing solutions poses a significant challenge, primarily due to the absence of a
unified abstraction that can bridge these different parallelism strategies. To address this
issue, PyTorch has proposed DistributedTensor (DTensor), which abstracts away the complexities
of tensor communication in distributed training, providing a seamless user experience. However,
this abstraction creates a lack of transparency that can make it challenging for users to
identify and resolve issues. ``CommDebugMode``, a Python context manager, addresses this
challenge by serving as one of the primary debugging tools for DTensors, enabling users to view
when and why collective operations happen when using DTensors.


How to use CommDebugMode
------------------------

Using CommDebugMode and retrieving its output is very simple:

.. code-block:: python

    # CommDebugMode is part of the DTensor debugging utilities
    from torch.distributed._tensor.debug import CommDebugMode

    # ``model`` is a DTensor-parallelized module and ``inp`` its input
    comm_mode = CommDebugMode()
    with comm_mode:
        output = model(inp)

    # print the operation level collective tracing information
    print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))

    # log the operation level collective tracing information to a file
    comm_mode.log_comm_debug_tracing_table_to_file(
        noise_level=1, file_name="transformer_operation_log.txt"
    )

    # dump the operation level collective tracing information to a JSON file,
    # used in the visual browser below
    comm_mode.generate_json_dump(noise_level=2)

This is what the output looks like for an MLPModule at noise level 0:

.. code-block:: text

    Expected Output:
        Global
          FORWARD PASS
            *c10d_functional.all_reduce: 1
          MLPModule
            FORWARD PASS
              *c10d_functional.all_reduce: 1
            MLPModule.net1
            MLPModule.relu
            MLPModule.net2
              FORWARD PASS
                *c10d_functional.all_reduce: 1

All users have to do is wrap the code running the model in ``CommDebugMode`` and call the API
they want to use to display the data. One important thing to note is that users can pass a
``noise_level`` argument to control how much information is displayed. The list below shows
what each noise level displays:

| 0. prints module-level collective counts
| 1. prints DTensor operations (not including trivial operations) and module information
| 2. prints operations (not including trivial operations)
| 3. prints all operations

In the example above, users can see in the output that the collective operation, all_reduce,
occurs once in the forward pass of the MLPModule. The module-level breakdown provides a greater
level of detail, allowing users to pinpoint that the all-reduce operation happens in
``MLPModule.net2``, the second linear layer of the MLPModule.


Below is the interactive module tree visualization, to which users can upload their JSON dump:

.. raw:: html

    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>CommDebugMode Module Tree</title>
        <style>
            ul, #tree-container {
                list-style-type: none;
                margin: 0;
                padding: 0;
            }
            .caret {
                cursor: pointer;
                user-select: none;
            }
            .caret::before {
                content: "\25B6";
                color: black;
                display: inline-block;
                margin-right: 6px;
            }
            .caret-down::before {
                transform: rotate(90deg);
            }
            .tree {
                padding-left: 20px;
            }
            .tree ul {
                padding-left: 20px;
            }
            .nested {
                display: none;
            }
            .active {
                display: block;
            }
            .forward-pass,
            .backward-pass {
                margin-left: 40px;
            }
            .forward-pass table {
                margin-left: 40px;
                width: auto;
            }
            .forward-pass table td, .forward-pass table th {
                padding: 8px;
            }
            .forward-pass ul {
                display: none;
            }
            table {
                font-family: arial, sans-serif;
                border-collapse: collapse;
                width: 100%;
            }
            td, th {
                border: 1px solid #dddddd;
                text-align: left;
                padding: 8px;
            }
            tr:nth-child(even) {
                background-color: #dddddd;
            }
            #drop-area {
                position: relative;
                width: 25%;
                height: 100px;
                border: 2px dashed #ccc;
                border-radius: 5px;
                padding: 0px;
                text-align: center;
            }
            .drag-drop-block {
                display: inline-block;
                width: 200px;
                height: 50px;
                background-color: #f7f7f7;
                border: 1px solid #ccc;
                border-radius: 5px;
                padding: 10px;
                font-size: 14px;
                color: #666;
                cursor: pointer;
            }
            #file-input {
                position: absolute;
                top: 0;
                left: 0;
                width: 100%;
                height: 100%;
                opacity: 0;
            }
        </style>
    </head>
    <body>
        <div id="drop-area">
            <div class="drag-drop-block">
                <span>Drag file here</span>
            </div>
            <input type="file" id="file-input" accept=".json">
        </div>
        <div id="tree-container"></div>
        <script src="https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/distributed/_tensor/debug/comm_mode_broswer_visual.js"></script>
    </body>
    </html>

Conclusion
------------------------------------------

In this recipe, we have learned how to use ``CommDebugMode`` to debug Distributed Tensors, and
how to explore its JSON dumps in the embedded visual browser.

For more detailed information about ``CommDebugMode``, see
`comm_mode_features_example.py
<https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/examples/comm_mode_features_example.py>`__.
8 changes: 8 additions & 0 deletions recipes_source/recipes_index.rst
@@ -395,6 +395,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch features
:link: ../recipes/distributed_async_checkpoint_recipe.html
:tags: Distributed-Training

.. customcarditem::
:header: Getting Started with CommDebugMode
:card_description: Learn how to use CommDebugMode for DTensors
:image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
:link: ../recipes/distributed_comm_debug_mode.html
:tags: Distributed-Training

.. TorchServe

.. customcarditem::
@@ -449,3 +456,4 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch features
/recipes/cuda_rpc
/recipes/distributed_optim_torchscript
/recipes/mobile_interpreter
/recipes/distributed_comm_debug_mode