Patched docs for torch_compile_tutorial #2936

Merged: 9 commits, Aug 30, 2024
89 changes: 87 additions & 2 deletions intermediate_source/torch_compile_tutorial.py
@@ -73,17 +73,35 @@ def foo(x, y):

######################################################################
# Alternatively, we can decorate the function.
t1 = torch.randn(10, 10)
t2 = torch.randn(10, 10)

@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
print(opt_foo2(t1, t2))

# When using the decorator approach, nested function calls within the decorated
# function will also be compiled.

def nested_function(x):
    return torch.sin(x)

@torch.compile
def outer_function(x, y):
    a = nested_function(x)
    b = torch.cos(y)
    return a + b

print(outer_function(t1, t2))

######################################################################
# We can also optimize ``torch.nn.Module`` instances.

t = torch.randn(10, 100)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -94,7 +112,74 @@ def forward(self, x):

mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(t))

# In the same fashion, when compiling a module, all sub-modules and methods
# within it are also compiled.

class OuterModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner_module = MyModule()
        self.outer_lin = torch.nn.Linear(10, 2)

    def forward(self, x):
        x = self.inner_module(x)
        return torch.nn.functional.relu(self.outer_lin(x))

outer_mod = OuterModule()
opt_outer_mod = torch.compile(outer_mod)
print(opt_outer_mod(t))

######################################################################
# We can also disable compilation of specific functions by using
# ``torch.compiler.disable``.

@torch.compiler.disable
def complex_function(real, imag):
    # Assume this function causes problems during compilation.
    return torch.complex(real, imag)

def outer_function():
    real = torch.tensor([2, 3], dtype=torch.float32)
    imag = torch.tensor([4, 5], dtype=torch.float32)
    z = complex_function(real, imag)
    return torch.abs(z)

# Try to compile the outer_function
try:
    opt_outer_function = torch.compile(outer_function)
    print(opt_outer_function())
except Exception as e:
    print("Compilation of outer_function failed:", e)

######################################################################
# Best Practices and Recommendations
# ----------------------------------
#
# Behavior of ``torch.compile`` with Nested Modules and Function Calls
#
# When you use ``torch.compile``, the compiler will try to recursively compile
# every function call inside the target function or module that is not in a
# skiplist (e.g. builtins, some functions in the ``torch.*`` namespace).

Member: Don't have to mention inlining - going over how Dynamo inlines is a little more involved.

Member: "every function call" is not exactly right - there is a skiplist of various functions (e.g. builtins, some functions in the torch.* namespace).

Contributor Author: Thanks, so I will modify the wording, incorporating your comment, to something like "compile every function call inside the target function or module that is not in a skiplist (e.g. builtins, some functions in the torch.* namespace)."

#
# This includes:
#
# - **Nested function calls:** All functions called within the decorated or compiled function will also be compiled.
#
# - **Nested modules:** If a ``torch.nn.Module`` is compiled, all sub-modules and functions within the module are also compiled.

Member: There isn't really too much of a difference in the nested call behavior between torch.compile'd functions and torch.compile'd modules, so I don't think we need to highlight them as distinct cases.
#
# **Best Practices:**
#
# 1. **Modular Testing:** Test individual functions and modules with ``torch.compile``
# before integrating them into larger models to isolate potential issues.
#
# 2. **Disable Compilation Selectively:** If certain functions or sub-modules
# cannot be handled by ``torch.compile``, use the ``torch.compiler.disable``
# decorator to recursively exclude them from compilation.
#
# 3. **Compile Leaf Functions First:** In complex models with multiple nested
# functions and modules, start by compiling the leaf functions or modules first,
# as shown in the sketch below.
# For more information see `TorchDynamo APIs for fine-grained tracing <https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html>`__.
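
######################################################################
# A minimal sketch (added for illustration, not part of the original patch)
# of the workflow suggested above: compile and exercise the leaf module on
# its own first, then compile the composed model that uses it. It reuses the
# ``MyModule`` and ``OuterModule`` classes and the tensor ``t`` defined
# earlier; the names ``opt_leaf`` and ``opt_composed`` are arbitrary.

# Illustration only: compile and test the leaf module in isolation first.
leaf = MyModule()
opt_leaf = torch.compile(leaf)
print(opt_leaf(t))

# Once the leaf module runs cleanly under ``torch.compile``, compile the
# composed model; its sub-modules are compiled recursively as well.
composed = OuterModule()
opt_composed = torch.compile(composed)
print(opt_composed(t))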

######################################################################
# Demonstrating Speedups