diff --git a/intermediate_source/torch_compile_tutorial.py b/intermediate_source/torch_compile_tutorial.py index 5e7112f5b9..67b055d9ff 100644 --- a/intermediate_source/torch_compile_tutorial.py +++ b/intermediate_source/torch_compile_tutorial.py @@ -73,17 +73,21 @@ def foo(x, y): ###################################################################### # Alternatively, we can decorate the function. +t1 = torch.randn(10, 10) +t2 = torch.randn(10, 10) @torch.compile def opt_foo2(x, y): a = torch.sin(x) b = torch.cos(y) return a + b -print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10))) +print(opt_foo2(t1, t2)) ###################################################################### # We can also optimize ``torch.nn.Module`` instances. +t = torch.randn(10, 100) + class MyModule(torch.nn.Module): def __init__(self): super().__init__() @@ -94,7 +98,101 @@ def forward(self, x): mod = MyModule() opt_mod = torch.compile(mod) -print(opt_mod(torch.randn(10, 100))) +print(opt_mod(t)) + +###################################################################### +# torch.compile and Nested Calls +# ------------------------------ +# Nested function calls within the decorated function will also be compiled. + +def nested_function(x): + return torch.sin(x) + +@torch.compile +def outer_function(x, y): + a = nested_function(x) + b = torch.cos(y) + return a + b + +print(outer_function(t1, t2)) + +###################################################################### +# In the same fashion, when compiling a module all sub-modules and methods +# within it, that are not in a skip list, are also compiled. + +class OuterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.inner_module = MyModule() + self.outer_lin = torch.nn.Linear(10, 2) + + def forward(self, x): + x = self.inner_module(x) + return torch.nn.functional.relu(self.outer_lin(x)) + +outer_mod = OuterModule() +opt_outer_mod = torch.compile(outer_mod) +print(opt_outer_mod(t)) + +###################################################################### +# We can also disable some functions from being compiled by using +# ``torch.compiler.disable``. Suppose you want to disable the tracing on just +# the ``complex_function`` function, but want to continue the tracing back in +# ``complex_conjugate``. In this case, you can use +# ``torch.compiler.disable(recursive=False)`` option. Otherwise, the default is +# ``recursive=True``. + +def complex_conjugate(z): + return torch.conj(z) + +@torch.compiler.disable(recursive=False) +def complex_function(real, imag): + # Assuming this function cause problems in the compilation + z = torch.complex(real, imag) + return complex_conjugate(z) + +def outer_function(): + real = torch.tensor([2, 3], dtype=torch.float32) + imag = torch.tensor([4, 5], dtype=torch.float32) + z = complex_function(real, imag) + return torch.abs(z) + +# Try to compile the outer_function +try: + opt_outer_function = torch.compile(outer_function) + print(opt_outer_function()) +except Exception as e: + print("Compilation of outer_function failed:", e) + +###################################################################### +# Best Practices and Recommendations +# ---------------------------------- +# +# Behavior of ``torch.compile`` with Nested Modules and Function Calls +# +# When you use ``torch.compile``, the compiler will try to recursively compile +# every function call inside the target function or module inside the target +# function or module that is not in a skip list (such as built-ins, some functions in +# the torch.* namespace). +# +# **Best Practices:** +# +# 1. **Top-Level Compilation:** One approach is to compile at the highest level +# possible (i.e., when the top-level module is initialized/called) and +# selectively disable compilation when encountering excessive graph breaks or +# errors. If there are still many compile issues, compile individual +# subcomponents instead. +# +# 2. **Modular Testing:** Test individual functions and modules with ``torch.compile`` +# before integrating them into larger models to isolate potential issues. +# +# 3. **Disable Compilation Selectively:** If certain functions or sub-modules +# cannot be handled by `torch.compile`, use the `torch.compiler.disable` context +# managers to recursively exclude them from compilation. +# +# 4. **Compile Leaf Functions First:** In complex models with multiple nested +# functions and modules, start by compiling the leaf functions or modules first. +# For more information see `TorchDynamo APIs for fine-grained tracing `__. ###################################################################### # Demonstrating Speedups