From 5a13282c75f8eedc73c5dd27a88619a766ec8797 Mon Sep 17 00:00:00 2001
From: Simon Fan
Date: Mon, 21 Oct 2024 17:32:41 -0700
Subject: [PATCH] [compiled autograd] tls access helpers (#138061)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138061
Approved by: https://github.com/yf225
ghstack dependencies: #137953, #137821
---
 .../_composable/fsdp/test_fully_shard_compile.py   |  2 +-
 test/dynamo/test_activation_checkpointing.py       |  4 +---
 test/inductor/test_compiled_autograd.py            |  2 +-
 torch/_dynamo/compiled_autograd.py                 | 11 ++++++++---
 torch/_dynamo/utils.py                             |  2 +-
 torch/_dynamo/variables/distributed.py             |  2 +-
 torch/_dynamo/variables/misc.py                    |  2 +-
 torch/_dynamo/variables/tensor.py                  |  2 +-
 torch/_functorch/_aot_autograd/autograd_cache.py   |  2 +-
 .../_aot_autograd/collect_metadata_analysis.py     |  2 +-
 .../_aot_autograd/jit_compile_runtime_wrappers.py  |  2 +-
 torch/distributed/_composable/fsdp/_fsdp_common.py |  4 ++--
 12 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py
index fb75cb735531c1..4d02a06af69004 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py
@@ -256,7 +256,7 @@ def _check_count(copy_count, resize_count):
                 f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}",  # noqa: B950
             )
 
-        if not torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+        if not torch._dynamo.compiled_autograd.in_compiled_autograd_region():
             _check_count(fwd_copy_count, fwd_resize_count)  # fwd graph
         else:
             _check_count(bwd_copy_count, bwd_resize_count)  # bwd graph
diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py
index 3a4c44cdd3c869..afdc7bcadf600a 100644
--- a/test/dynamo/test_activation_checkpointing.py
+++ b/test/dynamo/test_activation_checkpointing.py
@@ -86,9 +86,7 @@ def match_rng_op(node, op):
 
 
 def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]):
-    if not torch._dynamo.compiled_autograd.local.get(
-        "in_compiled_autograd_region"
-    ):  # fwd graph
+    if not torch._dynamo.compiled_autograd.in_compiled_autograd_region():  # fwd graph
         return_node = list(graph.nodes)[-1]
         assert return_node.target == "output"
         for x in return_node.args[0]:
diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index 5a4256ce469f09..7c09ea9d49e041 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -2412,7 +2412,7 @@ def train(errors, model, x):
             try:
                 out = model(x)
                 with compiled_autograd.enable(compiler_fn):
-                    self.assertEqual(compiled_autograd.local.enabled(), True)
+                    self.assertEqual(compiled_autograd.enabled(), True)
                     self.assertEqual(compiled_autograd.local.get("next_ctx_id"), 1)
             except Exception as e:
                 print(f"Found error: {e}")
diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py
index 8cb7c15c3e8d55..ada3aa4eee0b45 100644
--- a/torch/_dynamo/compiled_autograd.py
+++ b/torch/_dynamo/compiled_autograd.py
@@ -88,9 +88,6 @@ def revert():
 
         return revert
 
-    def enabled(self) -> bool:
-        return self.get("compiler") is not None
-
     def enter_ctx(self) -> Callable[[], None]:
         state = self._get_tls()
         state.next_ctx_id += 1
@@ -127,6 +124,14 @@ def exit():
         return exit
 
 
 local = TLSWrapper()
 
 
+def enabled() -> bool:
+    return local.get("compiler") is not None
+
+
+def in_compiled_autograd_region() -> bool:
+    return local.get("in_compiled_autograd_region")
+
+
 def maybe_clone(x):
     if x is not None:
         return clone_preserve_strides(x)
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 40e3e76e12f701..2325dda0edabeb 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -3051,7 +3051,7 @@ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
         if node.op == "placeholder" and node.meta.get("steal_arg", False)
     ]
 
-    if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+    if torch._dynamo.compiled_autograd.in_compiled_autograd_region():
         # fast path, avoid pytree overhead
         # compiled autograd inputs are always a list of tensors, maybe followed by symints
         assert inputs_idx_to_clear == [0]
diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py
index 42aec8b40ba607..6afffc15ad169d 100644
--- a/torch/_dynamo/variables/distributed.py
+++ b/torch/_dynamo/variables/distributed.py
@@ -313,7 +313,7 @@ def create(
         user_hooks: VariableTracker,
         user_pre_hooks: VariableTracker,
     ):
-        if not compiled_autograd.local.enabled():
+        if not compiled_autograd.enabled():
             unimplemented("module-level backwards hooks require compiled autograd")
 
         def _in_graph_bw_hooks(bw_state: BackwardState):
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index 8526c9b965a3db..32a09df5961c33 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -929,7 +929,7 @@ def call_method(
         kwargs: "Dict[str, VariableTracker]",
     ) -> "VariableTracker":
         if name == "queue_callback":
-            if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+            if torch._dynamo.compiled_autograd.in_compiled_autograd_region():
                 assert (
                     tx.one_graph
                 ), "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index fb3fa959ceafd0..c782a19d7f1e4d 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -1007,7 +1007,7 @@ def _method_register_hook(self, name: str, hook: VariableTracker):
         tx = InstructionTranslator.current_tx()
 
         if not self.source:
-            if not compiled_autograd.local.enabled():
+            if not compiled_autograd.enabled():
                 # TODO(voz):
                 # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
                 # python state.
diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py
index e31e572e624a16..9512e6561a438a 100644
--- a/torch/_functorch/_aot_autograd/autograd_cache.py
+++ b/torch/_functorch/_aot_autograd/autograd_cache.py
@@ -176,7 +176,7 @@ def check_cacheable(gm: torch.fx.GraphModule):
     Checks that the graph module only uses supported operators
     """
     nodes = gm.graph.nodes
-    if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+    if torch._dynamo.compiled_autograd.in_compiled_autograd_region():
         raise BypassAOTAutogradCache(
             "Cannot cache a graph with compiled autograd enabled"
         )
diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
index 59d01817bbdbf1..d84321ddf4b07e 100644
--- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
+++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
@@ -704,7 +704,7 @@ def view_avoid_dupes_with_primals(t):
         traced_tangent_memory_formats = [t[1] for t in tangents_and_memory_formats]
         nonlocal static_input_indices
         static_input_indices = static_input_indices or []
-        if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+        if torch._dynamo.compiled_autograd.in_compiled_autograd_region():
            passed_indices = set(static_input_indices)
            static_input_indices = [
                i
diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
index bfe90378b3b527..801a0dfb8bd97d 100644
--- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
@@ -760,7 +760,7 @@ def aot_dispatch_autograd(
     # becomes the lazy version again. One example is when dynamic shape is enabled
     # upfront, the bw_compiler will be called above which can cause extra
     # graph module recompilation on bw_module.
-    if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+    if torch._dynamo.compiled_autograd.in_compiled_autograd_region():
         from torch.fx._lazy_graph_module import _LazyGraphModule
 
         _LazyGraphModule.force_recompile(bw_module)
diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py
index c4be02d10ae7fe..d967a55d254541 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_common.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_common.py
@@ -33,9 +33,9 @@ def detect_compiled_autograd():
     import torch._dynamo.compiled_autograd as ca
 
     _compiled_autograd_enabled = (
-        ca.local.enabled()
+        ca.enabled()
         or ca.compiled_autograd_enabled_force_eager
-        or ca.local.get("in_compiled_autograd_region")
+        or ca.in_compiled_autograd_region()
     )
 
 
 def compiled_autograd_enabled():
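
Usage note (illustrative sketch, not part of the patch): call sites that previously reached into the TLS wrapper through `compiled_autograd.local` now go through the two module-level helpers added in torch/_dynamo/compiled_autograd.py. The snippet below is a hypothetical caller written for this note; only `ca.enabled()` and `ca.in_compiled_autograd_region()` come from the patch itself, everything else is for illustration.

    import torch._dynamo.compiled_autograd as ca

    def example_call_site() -> None:
        # Hypothetical function, for illustration only.
        # Before this patch: ca.local.enabled()
        if ca.enabled():
            print("a compiled autograd compiler is set in thread-local state")
        # Before this patch: ca.local.get("in_compiled_autograd_region")
        if ca.in_compiled_autograd_region():
            print("currently executing inside a compiled autograd graph")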