Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patched docs for torch_compile_tutorial #2936

Merged
merged 9 commits into from
Aug 30, 2024
102 changes: 100 additions & 2 deletions intermediate_source/torch_compile_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,21 @@ def foo(x, y):

######################################################################
# Alternatively, we can decorate the function.
t1 = torch.randn(10, 10)
williamwen42 marked this conversation as resolved.
Show resolved Hide resolved
t2 = torch.randn(10, 10)

@torch.compile
def opt_foo2(x, y):
a = torch.sin(x)
b = torch.cos(y)
return a + b
print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))
print(opt_foo2(t1, t2))

######################################################################
# We can also optimize ``torch.nn.Module`` instances.

t = torch.randn(10, 100)

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -94,7 +98,101 @@ def forward(self, x):

mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(torch.randn(10, 100)))
print(opt_mod(t))

######################################################################
# torch.compile and Nested Calls
# ------------------------------
# Nested function calls within the decorated function will also be compiled.

def nested_function(x):
return torch.sin(x)

@torch.compile
def outer_function(x, y):
a = nested_function(x)
b = torch.cos(y)
return a + b

print(outer_function(t1, t2))

######################################################################
# In the same fashion, when compiling a module all sub-modules and methods
# within it, that are not in a skiplist, are also compiled.
svekars marked this conversation as resolved.
Show resolved Hide resolved

class OuterModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.inner_module = MyModule()
self.outer_lin = torch.nn.Linear(10, 2)

def forward(self, x):
x = self.inner_module(x)
return torch.nn.functional.relu(self.outer_lin(x))

outer_mod = OuterModule()
opt_outer_mod = torch.compile(outer_mod)
print(opt_outer_mod(t))

######################################################################
# We can also disable some functions from being compiled by using
williamwen42 marked this conversation as resolved.
Show resolved Hide resolved
# `torch.compiler.disable`. Suppose you want to disable the tracing on just
svekars marked this conversation as resolved.
Show resolved Hide resolved
# the `complex_function` function, but want to continue the tracing back in
svekars marked this conversation as resolved.
Show resolved Hide resolved
# `complex_conjugate`. In this case, you can use
svekars marked this conversation as resolved.
Show resolved Hide resolved
# `torch.compiler.disable(recursive=False)` option. Otherwise, the default is
svekars marked this conversation as resolved.
Show resolved Hide resolved
# `recursive=True`.
svekars marked this conversation as resolved.
Show resolved Hide resolved

def complex_conjugate(z):
return torch.conj(z)

@torch.compiler.disable(recursive=False)
def complex_function(real, imag):
# Assuming this function cause problems in the compilation
z = torch.complex(real, imag)
return complex_conjugate(z)

def outer_function():
real = torch.tensor([2, 3], dtype=torch.float32)
imag = torch.tensor([4, 5], dtype=torch.float32)
z = complex_function(real, imag)
return torch.abs(z)

# Try to compile the outer_function
try:
opt_outer_function = torch.compile(outer_function)
print(opt_outer_function())
except Exception as e:
print("Compilation of outer_function failed:", e)

######################################################################
# Best Practices and Recommendations
# ----------------------------------
#
# Behavior of ``torch.compile`` with Nested Modules and Function Calls
#
# When you use ``torch.compile``, the compiler will try to recursively compile
# every function call inside the target function or module inside the target
# function or module that is not in a skiplist (e.g. builtins, some functions in
svekars marked this conversation as resolved.
Show resolved Hide resolved
# the torch.* namespace).
#
# **Best Practices:**
williamwen42 marked this conversation as resolved.
Show resolved Hide resolved
#
# 1. **Top-Level Compilation:** One approach is to compile at the highest level
# possible (i.e., when the top-level module is initialized/called) and
# selectively disable compilation when encountering excessive graph breaks or
# errors. If there are still many compile issues, compile individual
# subcomponents instead.
#
# 2. **Modular Testing:** Test individual functions and modules with ``torch.compile``
# before integrating them into larger models to isolate potential issues.
#
# 3. **Disable Compilation Selectively:** If certain functions or sub-modules
# cannot be handled by `torch.compile`, use the `torch.compiler.disable` context
# managers to recursively exclude them from compilation.
#
# 4. **Compile Leaf Functions First:** In complex models with multiple nested
# functions and modules, start by compiling the leaf functions or modules first.
# For more information see `TorchDynamo APIs for fine-grained tracing <https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html>`__.

######################################################################
# Demonstrating Speedups
Expand Down
Loading