From d14ff1cf446de503c3dd7d23949c79de31c6e51c Mon Sep 17 00:00:00 2001
From: Simon Fan
Date: Wed, 4 Sep 2024 11:46:22 -0700
Subject: [PATCH] update

---
 .../compiled_autograd_tutorial.py | 146 ++++++++++--------
 1 file changed, 79 insertions(+), 67 deletions(-)

diff --git a/intermediate_source/compiled_autograd_tutorial.py b/intermediate_source/compiled_autograd_tutorial.py
index 4b5e2bbebf8..4fd58e9743f 100644
--- a/intermediate_source/compiled_autograd_tutorial.py
+++ b/intermediate_source/compiled_autograd_tutorial.py
@@ -4,58 +4,57 @@
 Compiled Autograd: Capturing a larger backward graph for ``torch.compile``
 ==========================================================================
+**Author:** `Simon Fan `_
+
+.. grid:: 2
+
+   .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
+      :class-card: card-prerequisites
+
+      * How compiled autograd interacts with torch.compile
+      * How to use the compiled autograd API
+      * How to inspect logs using TORCH_LOGS
+
+   .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
+      :class-card: card-prerequisites
+
+      * PyTorch 2.4
+      * `torch.compile `_ familiarity
+
 """

 ######################################################################
+# Overview
+# ------------
 # Compiled Autograd is a torch.compile extension introduced in PyTorch 2.4
-# that allows the capture of a larger backward graph. It is highly recommended
-# to familiarize yourself with `torch.compile `_.
+# that allows the capture of a larger backward graph.
 #
-
-######################################################################
 # Doesn't torch.compile already capture the backward graph?
 # ------------
-# Partially. AOTAutograd captures the backward graph ahead-of-time, but with certain limitations:
-# - Graph breaks in the forward lead to graph breaks in the backward
-# - `Backward hooks `_ are not captured
+# And it does, **partially**. AOTAutograd captures the backward graph ahead-of-time, but with certain limitations:
+# 1. Graph breaks in the forward lead to graph breaks in the backward
+# 2. `Backward hooks `_ are not captured
 #
 # Compiled Autograd addresses these limitations by directly integrating with the autograd engine, allowing
 # it to capture the full backward graph at runtime. Models with these two characteristics should try
 # Compiled Autograd, and potentially observe better performance.
 #
 # However, Compiled Autograd has its own limitations:
-# - Dynamic autograd structure leads to recompiles
+# 1. Additional runtime overhead at the start of the backward
+# 2. Dynamic autograd structure leads to recompiles
+#
+# .. note:: Compiled Autograd is under active development and is not yet compatible with all existing PyTorch features. For the latest status on a particular feature, refer to `Compiled Autograd Landing Page `_.
 #

-######################################################################
-# Tutorial output cells setup
-# ------------
-#
-
-import os
-
-class ScopedLogging:
-    def __init__(self):
-        assert "TORCH_LOGS" not in os.environ
-        assert "TORCH_LOGS_FORMAT" not in os.environ
-        os.environ["TORCH_LOGS"] = "compiled_autograd_verbose"
-        os.environ["TORCH_LOGS_FORMAT"] = "short"
-
-    def __del__(self):
-        del os.environ["TORCH_LOGS"]
-        del os.environ["TORCH_LOGS_FORMAT"]
-
 ######################################################################
-# Basic Usage
+# Setup
 # ------------
-#
+# In this tutorial, we'll base our examples on this toy model.
+#
 
 import torch
 
-# NOTE: Must be enabled before using the decorator
-torch._dynamo.config.compiled_autograd = True
-
 class Model(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -64,24 +63,30 @@ def __init__(self):
 
     def forward(self, x):
         return self.linear(x)
 
+
+######################################################################
+# Basic usage
+# ------------
+# .. note:: The ``torch._dynamo.config.compiled_autograd = True`` config must be enabled before calling the torch.compile API.
+#
+
+model = Model()
+x = torch.randn(10)
+
+torch._dynamo.config.compiled_autograd = True
 
 @torch.compile
 def train(model, x):
     loss = model(x).sum()
     loss.backward()
 
-model = Model()
-x = torch.randn(10)
 train(model, x)
 
 ######################################################################
 # Inspecting the compiled autograd logs
 # ------------
-# Run the script with either TORCH_LOGS environment variables
-#
-# - To only print the compiled autograd graph, use `TORCH_LOGS="compiled_autograd" python example.py`
-# - To sacrifice some performance, in order to print the graph with more tensor medata and recompile reasons, use `TORCH_LOGS="compiled_autograd_verbose" python example.py`
-#
-# Logs can also be enabled through the private API torch._logging._internal.set_logs.
+# Run the script with the ``TORCH_LOGS`` environment variable:
+# - To only print the compiled autograd graph, use ``TORCH_LOGS="compiled_autograd" python example.py``
+# - To print the graph with more tensor metadata and recompile reasons, at the cost of some performance, use ``TORCH_LOGS="compiled_autograd_verbose" python example.py``
 #
 
 @torch.compile
@@ -92,13 +97,11 @@ def train(model, x):
 train(model, x)
 
 ######################################################################
-# The compiled autograd graph should now be logged to stdout. Certain graph nodes will have names that are prefixed by aot0_,
-# these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0.
-#
-# NOTE: This is the graph that we will call torch.compile on, NOT the optimized graph. Compiled Autograd basically
-# generated some python code to represent the entire C++ autograd execution.
+# The compiled autograd graph should now be logged to stderr. Certain graph nodes will have names prefixed by ``aot0_``;
+# these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0. For example, ``aot0_view_2`` corresponds to ``view_2`` of the AOT backward graph with id=0.
 #
-"""
+
+stderr_output = """
 DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
 DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
  ===== Compiled autograd graph =====
@@ -152,6 +155,10 @@ def forward(self, inputs, sizes, scalars, hooks):
         return []
 """
 
+######################################################################
+# .. note:: This is the graph that we will call torch.compile on, NOT the optimized graph. Compiled Autograd generates some Python code to represent the entire C++ autograd execution.
+#
+
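+######################################################################
+# If setting environment variables is inconvenient, for example in a notebook, the logs can also be
+# enabled programmatically. The sketch below uses ``torch._logging._internal.set_logs``, the private
+# API referenced in an earlier revision of this tutorial; because it is private, its location and
+# signature may change between releases, so treat this as an illustrative sketch rather than a
+# stable recipe.
+#
+
+# Roughly equivalent to running with TORCH_LOGS="compiled_autograd_verbose" (private API, may change).
+torch._logging._internal.set_logs(compiled_autograd_verbose=True)
+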
 ######################################################################
 # Compiling the forward and backward pass using different flags
 # ------------
 
@@ -163,7 +170,7 @@ def train(model, x):
     torch.compile(lambda: loss.backward(), fullgraph=True)()
 
 ######################################################################
-# Or you can use the context manager, which will apply to all autograd calls within it
+# Or you can use the context manager, which will apply to all autograd calls within its scope.
 #
 
 def train(model, x):
@@ -174,7 +181,7 @@ def train(model, x):
 
 ######################################################################
-# Demonstrating the limitations of AOTAutograd addressed by Compiled Autograd
+# Compiled Autograd addresses certain limitations of AOTAutograd
 # ------------
 # 1. Graph breaks in the forward lead to graph breaks in the backward
 #
 
@@ -208,7 +215,12 @@ def fn(x):
 
 ######################################################################
-# 2. `Backward hooks are not captured
+# In the ``1. base torch.compile`` case, we see that 3 backward graphs were produced due to the 2 graph breaks in the compiled function ``fn``.
+# In the ``2. torch.compile with compiled autograd`` case, by contrast, a full backward graph was traced despite the graph breaks.
+#
+
+######################################################################
+# 2. Backward hooks are not captured
 #
 
 @torch.compile(backend="aot_eager")
 def fn(x):
@@ -223,19 +235,19 @@ def fn(x):
     loss.backward()
 
 ######################################################################
-# There is a `call_hook` node in the graph, which dynamo will inline
+# There should be a ``call_hook`` node in the graph, which dynamo will later inline
 #
 
-"""
+stderr_output = """
 DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
 DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
-    ===== Compiled autograd graph =====
-    .2 class CompiledAutograd(torch.nn.Module):
-     def forward(self, inputs, sizes, scalars, hooks):
-         ...
-         getitem_2 = hooks[0]; hooks = None
-         call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
-         ...
+===== Compiled autograd graph =====
+.2 class CompiledAutograd(torch.nn.Module):
+  def forward(self, inputs, sizes, scalars, hooks):
+    ...
+    getitem_2 = hooks[0]; hooks = None
+    call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
+    ...
 """
 
 ######################################################################
@@ -250,10 +262,10 @@ def forward(self, inputs, sizes, scalars, hooks):
     torch.compile(lambda: loss.backward(), backend="eager")()
 
 ######################################################################
-# You should see some cache miss logs (recompiles):
+# You should see some recompile messages: **Cache miss due to new autograd node**.
 #
 
-"""
+stderr_output = """
 Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
 ...
 Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
 ...
 """
 
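+######################################################################
+# The verbose log above attributes each recompile to a new autograd node type (for example, ``SubBackward0``).
+# As a small illustrative sketch that is not part of the original example, you can inspect ``grad_fn`` to see
+# the distinct node types that different operators record, which is what triggers these cache misses:
+#
+
+x = torch.randn(10, requires_grad=True)
+for op in (torch.add, torch.sub):
+    loss = op(x, x).sum()
+    # The root node is SumBackward0; its inputs reveal the op-specific node, e.g. AddBackward0 or SubBackward0.
+    print(type(loss.grad_fn).__name__, [type(fn).__name__ for fn, _ in loss.grad_fn.next_functions])
+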
@@ -268,7 +280,6 @@ def forward(self, inputs, sizes, scalars, hooks):
 ######################################################################
 # 2. Due to dynamic shapes
 #
-torch._logging._internal.set_logs(compiled_autograd_verbose=True)
 torch._dynamo.config.compiled_autograd = True
 for i in [10, 100, 10]:
     x = torch.randn(i, i, requires_grad=True)
 
@@ -276,10 +287,10 @@ def forward(self, inputs, sizes, scalars, hooks):
     torch.compile(lambda: loss.backward(), backend="eager")()
 
 ######################################################################
-# You should see some cache miss logs (recompiles):
+# You should see some recompile messages: **Cache miss due to changed shapes**.
 #
 
-"""
+stderr_output = """
 ...
 Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
 Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
 ...
 """
 
 ######################################################################
-# Compatibility and rough edges
-# ------------
-#
-# Compiled Autograd is under active development and is not yet compatible with all existing PyTorch features.
-# For the latest status on a particular feature, refer to: https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY.
+# Conclusion
+# ----------
+# In this tutorial, we went over how compiled autograd fits into the torch.compile ecosystem, the basics of the compiled autograd API, and a few common recompilation reasons.
+#
+# For feedback on this tutorial, please file an issue on https://github.com/pytorch/tutorials.
+
\ No newline at end of file