[compiled autograd] Compiled autograd configs in TLS (pytorch#137821)
Multithreading doesn't work yet; this adds Python-side TLS for the Python-side state only.

Pull Request resolved: pytorch#137821
Approved by: https://github.com/jansel, https://github.com/yf225
ghstack dependencies: pytorch#137953
xmfan authored and pytorchmergebot committed Oct 22, 2024
1 parent 7525914 commit 49fa437
Showing 16 changed files with 221 additions and 103 deletions.
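
At a high level, the change replaces the module-level globals in torch/_dynamo/compiled_autograd.py (compiled_autograd_enabled, in_compiled_autograd_region, the installed compiler) with a per-thread dataclass behind a small wrapper object. Below is a minimal sketch of that pattern, simplified from this diff; the real TLSWrapper additionally mirrors the state into C++ TLS and calls notify_autograd_engine(), and the names here are illustrative only.

import threading
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class CompiledAutogradState:
    next_ctx_id: int = 0
    in_compiled_autograd_region: bool = False
    compiler: Optional[Callable] = None


class ThreadLocalState:
    def __init__(self) -> None:
        self._local = threading.local()

    def _state(self) -> CompiledAutogradState:
        # Each thread lazily gets its own independent copy of the state.
        if not hasattr(self._local, "state"):
            self._local.state = CompiledAutogradState()
        return self._local.state

    def get(self, name: str) -> Any:
        return getattr(self._state(), name)

    def set_tls(self, **kwargs: Any) -> Callable[[], None]:
        # Swap in new field values; return a closure that restores the priors.
        state = self._state()
        priors: Dict[str, Any] = {k: getattr(state, k) for k in kwargs}
        for k, v in kwargs.items():
            setattr(state, k, v)

        def revert() -> None:
            for k, v in priors.items():
                setattr(state, k, v)

        return revert

    def enabled(self) -> bool:
        return self.get("compiler") is not None


local = ThreadLocalState()

Each thread that touches compiled autograd sees its own copy of these fields, which is what the new test_multithreading_tls test below exercises.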
@@ -256,7 +256,7 @@ def _check_count(copy_count, resize_count):
f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}", # noqa: B950
)

if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:
if not torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
_check_count(fwd_copy_count, fwd_resize_count) # fwd graph
else:
_check_count(bwd_copy_count, bwd_resize_count) # bwd graph
4 changes: 3 additions & 1 deletion test/dynamo/test_activation_checkpointing.py
@@ -86,7 +86,9 @@ def match_rng_op(node, op):


def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]):
if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: # fwd graph
if not torch._dynamo.compiled_autograd.local.get(
"in_compiled_autograd_region"
): # fwd graph
return_node = list(graph.nodes)[-1]
assert return_node.target == "output"
for x in return_node.args[0]:
35 changes: 35 additions & 0 deletions test/inductor/test_compiled_autograd.py
@@ -6,9 +6,11 @@
import itertools
import logging
import os
import queue
import re
import subprocess
import sys
import threading
import unittest
from importlib.machinery import SourceFileLoader
from pathlib import Path
@@ -2405,6 +2407,39 @@ def test_logs(self):
not in logs.getvalue()
)

def test_multithreading_tls(self):
def train(errors, model, x):
try:
out = model(x)
with compiled_autograd.enable(compiler_fn):
self.assertEqual(compiled_autograd.local.enabled(), True)
self.assertEqual(compiled_autograd.local.get("next_ctx_id"), 1)
except Exception as e:
print(f"Found error: {e}")
errors.put(1)
raise

model = torch.nn.Sequential(
torch.nn.Linear(4, 4),
torch.nn.ReLU(),
torch.nn.Linear(4, 4),
torch.nn.ReLU(),
)
x = torch.randn([2, 4])

threads = []
errors = queue.Queue()
with compiled_autograd.enable(compiler_fn):
for i in range(4):
thread = threading.Thread(target=train, args=(errors, model, x))
threads.append(thread)
thread.start()

for thread in threads:
thread.join()

assert errors.empty()

def test_verbose_logs_graph(self):
def fn():
model = torch.nn.Sequential(
6 changes: 1 addition & 5 deletions torch/_C/_dynamo/compiled_autograd.pyi
@@ -1,10 +1,6 @@
from typing import Callable

from torch._dynamo.compiled_autograd import AutogradCompilerInstance

def set_autograd_compiler(
autograd_compiler: Callable[[], AutogradCompilerInstance] | None,
) -> Callable[[], AutogradCompilerInstance] | None: ...
def notify_autograd_engine() -> None: ...
def clear_cache() -> None: ...
def is_cache_empty() -> bool: ...
def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...
149 changes: 112 additions & 37 deletions torch/_dynamo/compiled_autograd.py
@@ -1,7 +1,10 @@
# mypy: allow-untyped-defs
import contextlib
import functools
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import threading
from dataclasses import dataclass
from logging import Logger
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from torch._dynamo.external_utils import (
@@ -38,14 +41,90 @@
verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")


def snapshot_verbose_logging_enabled():
return torch._logging._internal.log_state.is_artifact_enabled(
"compiled_autograd_verbose"
)
@dataclass
class CompiledAutogradTLS:
next_ctx_id: int = 0
in_compiled_autograd_region: bool = False
compiler: Optional["AutogradCompilerInstance"] = None
vlogger: Optional[Logger] = None


class TLSWrapper:
tls_key = "compiled_autograd_state"

def __init__(self):
self._local = threading.local()

def _get_tls(self) -> CompiledAutogradTLS:
if hasattr(self._local, self.tls_key):
# first look in python
state = getattr(self._local, self.tls_key)
elif torch._C._is_key_in_tls(self.tls_key):
# then look in cpp
state = torch._C._get_obj_in_tls(self.tls_key)
else:
# init new thread created outside of autograd
# TODO: what if context manager wrapped outside of thread?
setattr(self._local, self.tls_key, CompiledAutogradTLS())
state = getattr(self._local, self.tls_key)
torch._C._stash_obj_in_tls(self.tls_key, state)
return state

# queries on the object stored in TLS
def get(self, name):
return getattr(self._get_tls(), name)

def set_tls(self, **kwargs) -> Callable[[], None]:
priors: Dict[str, Any] = {}
for k, v in kwargs.items():
state = self._get_tls()
priors[k] = getattr(state, k)
setattr(state, k, v)

torch._C._dynamo.compiled_autograd.notify_autograd_engine()

def revert():
self.set_tls(**priors)

return revert

def enabled(self) -> bool:
return self.get("compiler") is not None

def enter_ctx(self) -> Callable[[], None]:
state = self._get_tls()
state.next_ctx_id += 1
id = state.next_ctx_id

def exit():
assert (
state is self._get_tls()
), "Runtime must begin and end on the same thread"
assert state.next_ctx_id == id, (
"Error nesting compiled autograd context managers: "
"inner context managers must have shorter lifetime than the outer context manager"
)
state.next_ctx_id -= 1

return exit

def enter_compiled_region(self) -> Callable[[], None]:
state = self._get_tls()
prior = state.in_compiled_autograd_region
state.in_compiled_autograd_region = True
assert prior is False, "Nested compiled autograd regions are not supported"

def exit():
assert (
state is self._get_tls()
), "Runtime must begin and end on the same thread"
assert state.in_compiled_autograd_region is True
state.in_compiled_autograd_region = prior

return exit


def snapshot_cudagraph_enabled():
return torch._inductor.config.triton.cudagraphs
local = TLSWrapper()


def maybe_clone(x):
@@ -307,7 +386,7 @@ def end_capture(self, outputs):
self.rename_aot_dispatcher_nodes()
self.reorder_accumulate_grad_nodes()
runtime_inputs_to_move: List[int] = []
if snapshot_cudagraph_enabled():
if torch._inductor.config.triton.cudagraphs:
runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)

graph = GraphModule(
@@ -329,16 +408,15 @@ def end_capture(self, outputs):
)

def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
global in_compiled_autograd_region
try:
in_compiled_autograd_region = True
exit_compiled_region = local.enter_compiled_region()
for i in runtime_inputs_to_move:
inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)

with disable():
return compiled_fn(inputs, sizes, scalars, hooks)
finally:
in_compiled_autograd_region = False
exit_compiled_region()

return runtime_wrapper, self.compiler_fn(graph)

@@ -510,15 +588,9 @@ def set_node_origin(
set_stack_trace(new_stack_trace)


# state of the autograd engine dispatch, kept in sync by enable/disable context managers
compiled_autograd_enabled = False

# global flag to check if compiled autograd is enabled but Dynamo stance is "force_eager"
compiled_autograd_enabled_force_eager = False

# global flag to check if we are processing graphs produced from a compiled autograd graph
in_compiled_autograd_region = False


@contextlib.contextmanager
def enable(compiler_fn):
@@ -538,39 +610,42 @@ def enable(compiler_fn):
# we need to lazily import it, because of circular dependencies
import torch._inductor.cudagraph_trees

prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
functools.partial(AutogradCompilerInstance, compiler_fn)
exit_ctx = local.enter_ctx()
revert_tls = local.set_tls(
compiler=functools.partial(AutogradCompilerInstance, compiler_fn),
vlogger=verbose_log
if torch._logging._internal.log_state.is_artifact_enabled(
"compiled_autograd_verbose"
)
else None,
)
if snapshot_verbose_logging_enabled():
torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)
global compiled_autograd_enabled
compiled_autograd_enabled = True
try:
with torch.autograd.set_multithreading_enabled(False):
yield
finally:
if not prior:
compiled_autograd_enabled = False
torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
revert_tls()
exit_ctx()


@contextlib.contextmanager
def disable():
prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
global compiled_autograd_enabled
compiled_autograd_enabled = False
exit_ctx = local.enter_ctx()
revert_tls = local.set_tls(
compiler=None,
vlogger=None,
)
try:
yield
finally:
if prior:
compiled_autograd_enabled = True
torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
revert_tls()
exit_ctx()


# return to starting state of a new process
def reset() -> None:
global compiled_autograd_enabled
compiled_autograd_enabled = False
assert not in_compiled_autograd_region
torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
torch._C._dynamo.compiled_autograd.set_verbose_logger(None)
assert local.get("next_ctx_id") == 0
assert local.get("in_compiled_autograd_region") is False
local.set_tls(
compiler=None,
vlogger=None,
)
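
For call sites elsewhere in the diff, the new module-level `local` object replaces the old globals. A hedged usage sketch, assuming the module layout added in this diff:

from torch._dynamo import compiled_autograd

# Old: compiled_autograd.compiled_autograd_enabled
if compiled_autograd.local.enabled():
    pass  # a compiler factory is installed for the current thread

# Old: compiled_autograd.in_compiled_autograd_region
if compiled_autograd.local.get("in_compiled_autograd_region"):
    pass  # currently executing a graph produced by compiled autograd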
5 changes: 5 additions & 0 deletions torch/_dynamo/config.py
@@ -472,6 +472,11 @@ def default_debug_dir_root():
# Overrides torch.compile() kwargs for Compiled Autograd:
compiled_autograd_kwargs_override: Dict[str, Any] = {}

# Compiled Autograd will attempt to automatically wrap C++ autograd functions found in the autograd graph,
# and make them opaque to the compiler. This does not work when the C++ backward implementation involves
# other dispatcher subsystems e.g. custom subclasses, autocast, vmap.
compiled_autograd_opaque_cpp_node = False

# Enables use of collectives *during* compilation to synchronize behavior
# across ranks. Today, this is used solely to modify automatic_dynamic_shapes
# behavior, making it so that we infer that if an input is dynamic by
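
The compiled_autograd_opaque_cpp_node flag added in the hunk above is a plain dynamo config attribute; a hedged example of opting in (it defaults to False per this diff):

import torch._dynamo.config as dynamo_config

# Wrap C++ autograd functions and keep them opaque to the compiler
# (defaults to False per the hunk above).
dynamo_config.compiled_autograd_opaque_cpp_node = True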
2 changes: 1 addition & 1 deletion torch/_dynamo/utils.py
@@ -3051,7 +3051,7 @@ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
if node.op == "placeholder" and node.meta.get("steal_arg", False)
]

if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
# fast path, avoid pytree overhead
# compiled autograd inputs are always a list of tensors, maybe followed by symints
assert inputs_idx_to_clear == [0]
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/distributed.py
@@ -313,7 +313,7 @@ def create(
user_hooks: VariableTracker,
user_pre_hooks: VariableTracker,
):
if not compiled_autograd.compiled_autograd_enabled:
if not compiled_autograd.local.enabled():
unimplemented("module-level backwards hooks require compiled autograd")

def _in_graph_bw_hooks(bw_state: BackwardState):
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/misc.py
@@ -929,7 +929,7 @@ def call_method(
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "queue_callback":
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
assert (
tx.one_graph
), "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/tensor.py
@@ -1007,7 +1007,7 @@ def _method_register_hook(self, name: str, hook: VariableTracker):
tx = InstructionTranslator.current_tx()

if not self.source:
if not compiled_autograd.compiled_autograd_enabled:
if not compiled_autograd.local.enabled():
# TODO(voz):
# We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
# python state.
2 changes: 1 addition & 1 deletion torch/_functorch/_aot_autograd/autograd_cache.py
@@ -176,7 +176,7 @@ def check_cacheable(gm: torch.fx.GraphModule):
Checks that the graph module only uses supported operators
"""
nodes = gm.graph.nodes
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
raise BypassAOTAutogradCache(
"Cannot cache a graph with compiled autograd enabled"
)
@@ -704,7 +704,7 @@ def view_avoid_dupes_with_primals(t):
traced_tangent_memory_formats = [t[1] for t in tangents_and_memory_formats]
nonlocal static_input_indices
static_input_indices = static_input_indices or []
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
passed_indices = set(static_input_indices)
static_input_indices = [
i
@@ -760,7 +760,7 @@ def aot_dispatch_autograd(
# becomes the lazy version again. One example is when dynamic shape is enabled
# upfront, the bw_compiler will be called above which can cause extra
# graph module recompilation on bw_module.
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
from torch.fx._lazy_graph_module import _LazyGraphModule

_LazyGraphModule.force_recompile(bw_module)