update
xmfan committed Sep 4, 2024
1 parent 45a4087 commit d14ff1c
Showing 1 changed file with 79 additions and 67 deletions.
146 changes: 79 additions & 67 deletions intermediate_source/compiled_autograd_tutorial.py
Compiled Autograd: Capturing a larger backward graph for ``torch.compile``
==========================================================================
**Author:** `Simon Fan <https://github.com/xmfan>`_

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How compiled autograd interacts with torch.compile
       * How to use the compiled autograd API
       * How to inspect logs using TORCH_LOGS

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.4
       * `torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_ familiarity
"""

######################################################################
# Overview
# ------------
# Compiled Autograd is a torch.compile extension introduced in PyTorch 2.4
# that allows the capture of a larger backward graph.
#

######################################################################
# Doesn't torch.compile already capture the backward graph?
# ------------
# And it does, **partially**. AOTAutograd captures the backward graph ahead-of-time, but with certain limitations:
#
# 1. Graph breaks in the forward lead to graph breaks in the backward
# 2. `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured
#
# Compiled Autograd addresses these limitations by directly integrating with the autograd engine, allowing
# it to capture the full backward graph at runtime. Models with these two characteristics should try
# Compiled Autograd and may observe better performance.
#
# However, Compiled Autograd has its own limitations:
#
# 1. Additional runtime overhead at the start of the backward
# 2. Dynamic autograd structure leads to recompiles
#
# .. note:: Compiled Autograd is under active development and is not yet compatible with all existing PyTorch features. For the latest status on a particular feature, refer to `Compiled Autograd Landing Page <https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY>`_.
#

######################################################################
# Tutorial output cells setup
# ------------
#

import os

class ScopedLogging:
    def __init__(self):
        assert "TORCH_LOGS" not in os.environ
        assert "TORCH_LOGS_FORMAT" not in os.environ
        os.environ["TORCH_LOGS"] = "compiled_autograd_verbose"
        os.environ["TORCH_LOGS_FORMAT"] = "short"

    def __del__(self):
        del os.environ["TORCH_LOGS"]
        del os.environ["TORCH_LOGS_FORMAT"]


######################################################################
# Setup
# ------------
#
# In this tutorial, we'll base our examples on this toy model.
#

import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)


######################################################################
# Basic usage
# ------------
# .. note:: The ``torch._dynamo.config.compiled_autograd = True`` config must be enabled before calling the torch.compile API.
#

model = Model()
x = torch.randn(10)

torch._dynamo.config.compiled_autograd = True
@torch.compile
def train(model, x):
    loss = model(x).sum()
    loss.backward()

train(model, x)
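
######################################################################
# As a quick sanity check (not part of the original tutorial), the compiled training step should
# produce the same gradients as running the same model in eager mode:
#

eager_model = Model()
eager_model.load_state_dict(model.state_dict())
eager_model(x).sum().backward()
torch.testing.assert_close(eager_model.linear.weight.grad, model.linear.weight.grad)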

######################################################################
# Inspecting the compiled autograd logs
# ------------
# Run the script with the TORCH_LOGS environment variables:
#
# - To only print the compiled autograd graph, use ``TORCH_LOGS="compiled_autograd" python example.py``
# - To print the graph with more tensor metadata and recompile reasons, at the cost of performance, use ``TORCH_LOGS="compiled_autograd_verbose" python example.py``
#
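
######################################################################
# Logs can also be enabled programmatically instead of through environment variables. This is a
# sketch using the private ``torch._logging._internal.set_logs`` helper that is also called later
# in this tutorial; being private, it may change between releases.
#

torch._logging._internal.set_logs(compiled_autograd_verbose=True)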

@torch.compile
def train(model, x):
    loss = model(x).sum()
    loss.backward()

train(model, x)

######################################################################
# The compiled autograd graph should now be logged to stderr. Certain graph nodes will have names that are prefixed by ``aot0_``;
# these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0, for example, ``aot0_view_2`` corresponds to ``view_2`` of the AOT backward graph with id=0.
#
"""

stderr_output = """
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
===== Compiled autograd graph =====
    def forward(self, inputs, sizes, scalars, hooks):
        ...
        return []
"""

######################################################################
# .. note:: This is the graph that we will call torch.compile on, NOT the optimized graph. Compiled Autograd generates some Python code to represent the entire C++ autograd execution.
#

######################################################################
# Compiling the forward and backward pass using different flags
# ------------
def train(model, x):
    ...  # the forward pass is elided in this diff
    torch.compile(lambda: loss.backward(), fullgraph=True)()

######################################################################
# Or you can use the context manager, which will apply to all autograd calls within its scope.
#

def train(model, x):
    ...  # body elided in this diff; a sketch of the context manager pattern follows below
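
######################################################################
# A sketch of the context manager pattern, under stated assumptions: I assume the context manager
# referred to above is ``torch._dynamo.compiled_autograd.enable``, which takes a compiler callable
# and applies compiled autograd to every backward call inside its scope. The function name and the
# exact compiler flags below are illustrative, not the elided tutorial code.
#

def train_with_context_manager(model, x):
    model = torch.compile(model)      # compile the forward with one set of flags
    loss = model(x).sum()
    # compile the backward with a different compiler via the context manager
    with torch._dynamo.compiled_autograd.enable(torch.compile(backend="eager")):
        loss.backward()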


######################################################################
# Compiled Autograd addresses certain limitations of AOTAutograd
# ------------
# 1. Graph breaks in the forward lead to graph breaks in the backward
#
# (the code for this example is elided in this diff: a ``torch.compile``-d function ``fn``
# containing two graph breaks, whose backward is run once with base torch.compile and once with
# compiled autograd; a sketch of its shape follows below)
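
######################################################################
# A minimal sketch of that comparison (illustrative, not the elided tutorial code; the names below
# are mine): ``fn_sketch`` contains two explicit graph breaks, so base torch.compile produces three
# backward graphs, while enabling compiled autograd traces the whole backward as a single graph.
#

@torch.compile(backend="aot_eager")
def fn_sketch(x):
    temp = x + 10
    torch._dynamo.graph_break()
    temp = temp + 10
    torch._dynamo.graph_break()
    return temp.sum()

x_sketch = torch.randn(10, 10, requires_grad=True)

# 1. base torch.compile: the backward inherits the forward's graph breaks
torch._dynamo.config.compiled_autograd = False
loss_sketch = fn_sketch(x_sketch)
loss_sketch.backward()

# 2. torch.compile with compiled autograd: one full backward graph despite the graph breaks
torch._dynamo.config.compiled_autograd = True
loss_sketch = fn_sketch(x_sketch)
torch.compile(lambda: loss_sketch.backward(), backend="eager")()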


######################################################################
# In the ``1. base torch.compile`` case, we see that 3 backward graphs were produced due to the 2 graph breaks in the compiled function ``fn``.
# Whereas in ``2. torch.compile with compiled autograd``, we see that a full backward graph was traced despite the graph breaks.
#

######################################################################
# 2. Backward hooks are not captured
#

@torch.compile(backend="aot_eager")
def fn(x):
    ...

# (the rest of this example is elided in this diff; a sketch of its shape follows below)
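
######################################################################
# A minimal sketch of such an example (illustrative, not the elided tutorial code; the names below
# are mine): a tensor pre-hook is registered on the input, and the backward is run with compiled
# autograd enabled so that the hook call is captured in the graph.
#

@torch.compile(backend="aot_eager")
def fn_hook_sketch(x):
    return x.sum()

x_hook = torch.randn(10, 10, requires_grad=True)
x_hook.register_hook(lambda grad: grad + 10)
loss_hook = fn_hook_sketch(x_hook)

torch._dynamo.config.compiled_autograd = True
torch.compile(lambda: loss_hook.backward(), backend="aot_eager")()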

######################################################################
# There should be a ``call_hook`` node in the graph, which dynamo will later inline.
#

"""
stderr_output = """
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
===== Compiled autograd graph =====
<eval_with_key>.2 class CompiledAutograd(torch.nn.Module):
def forward(self, inputs, sizes, scalars, hooks):
...
getitem_2 = hooks[0]; hooks = None
call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
...
===== Compiled autograd graph =====
<eval_with_key>.2 class CompiledAutograd(torch.nn.Module):
def forward(self, inputs, sizes, scalars, hooks):
...
getitem_2 = hooks[0]; hooks = None
call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
...
"""

######################################################################
# Common recompilation reasons for Compiled Autograd
# ------------
# 1. Due to change in autograd structure
#
# (the code for this example is elided in this diff; a sketch of its shape follows below)
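
######################################################################
# A minimal sketch of such an example (illustrative, not the elided tutorial code): each op in the
# loop produces a different autograd node type in the backward, so every iteration misses the
# compiled autograd cache and triggers a recompile.
#

torch._logging._internal.set_logs(compiled_autograd_verbose=True)
torch._dynamo.config.compiled_autograd = True
x_ops = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
    loss_ops = op(x_ops, x_ops).sum()
    torch.compile(lambda: loss_ops.backward(), backend="eager")()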

######################################################################
# You should see some recompile messages: **Cache miss due to new autograd node**.
#

"""
stderr_output = """
Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
...
Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
...
"""

######################################################################
# 2. Due to dynamic shapes
#

torch._logging._internal.set_logs(compiled_autograd_verbose=True)
torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
    x = torch.randn(i, i, requires_grad=True)
    loss = x.sum()
    torch.compile(lambda: loss.backward(), backend="eager")()

######################################################################
# You should see some recompile messages: **Cache miss due to changed shapes**.
#

"""
stderr_output = """
...
Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
Expand All @@ -289,8 +300,9 @@ def forward(self, inputs, sizes, scalars, hooks):
"""

######################################################################
# Conclusion
# ----------
# In this tutorial, we went over the high-level ecosystem of ``torch.compile`` with compiled autograd,
# the basics of compiled autograd, and a few common recompilation reasons.
#
# For feedback on this tutorial, please file an issue on https://github.com/pytorch/tutorials.
#
