diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py
index b1552909d457d..fb75cb735531c 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py
@@ -256,7 +256,7 @@ def _check_count(copy_count, resize_count):
                     f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}",  # noqa: B950
                 )
 
-            if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+            if not torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
                 _check_count(fwd_copy_count, fwd_resize_count)  # fwd graph
             else:
                 _check_count(bwd_copy_count, bwd_resize_count)  # bwd graph
diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py
index e82ad38d1fbd5..3a4c44cdd3c86 100644
--- a/test/dynamo/test_activation_checkpointing.py
+++ b/test/dynamo/test_activation_checkpointing.py
@@ -86,7 +86,9 @@ def match_rng_op(node, op):
 
 
 def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]):
-    if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:  # fwd graph
+    if not torch._dynamo.compiled_autograd.local.get(
+        "in_compiled_autograd_region"
+    ):  # fwd graph
         return_node = list(graph.nodes)[-1]
         assert return_node.target == "output"
         for x in return_node.args[0]:
diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index f65e700749977..5a4256ce469f0 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -6,9 +6,11 @@
 import itertools
 import logging
 import os
+import queue
 import re
 import subprocess
 import sys
+import threading
 import unittest
 from importlib.machinery import SourceFileLoader
 from pathlib import Path
@@ -2405,6 +2407,39 @@ def test_logs(self):
             not in logs.getvalue()
         )
 
+    def test_multithreading_tls(self):
+        def train(errors, model, x):
+            try:
+                out = model(x)
+                with compiled_autograd.enable(compiler_fn):
+                    self.assertEqual(compiled_autograd.local.enabled(), True)
+                    self.assertEqual(compiled_autograd.local.get("next_ctx_id"), 1)
+            except Exception as e:
+                print(f"Found error: {e}")
+                errors.put(1)
+                raise
+
+        model = torch.nn.Sequential(
+            torch.nn.Linear(4, 4),
+            torch.nn.ReLU(),
+            torch.nn.Linear(4, 4),
+            torch.nn.ReLU(),
+        )
+        x = torch.randn([2, 4])
+
+        threads = []
+        errors = queue.Queue()
+        with compiled_autograd.enable(compiler_fn):
+            for i in range(4):
+                thread = threading.Thread(target=train, args=(errors, model, x))
+                threads.append(thread)
+                thread.start()
+
+        for thread in threads:
+            thread.join()
+
+        assert errors.empty()
+
     def test_verbose_logs_graph(self):
         def fn():
             model = torch.nn.Sequential(
diff --git a/torch/_C/_dynamo/compiled_autograd.pyi b/torch/_C/_dynamo/compiled_autograd.pyi
index 80144e3a77907..b308f63844ed6 100644
--- a/torch/_C/_dynamo/compiled_autograd.pyi
+++ b/torch/_C/_dynamo/compiled_autograd.pyi
@@ -1,10 +1,6 @@
 from typing import Callable
 
-from torch._dynamo.compiled_autograd import AutogradCompilerInstance
-
-def set_autograd_compiler(
-    autograd_compiler: Callable[[], AutogradCompilerInstance] | None,
-) -> Callable[[], AutogradCompilerInstance] | None: ...
+def notify_autograd_engine() -> None: ...
 def clear_cache() -> None: ...
 def is_cache_empty() -> bool: ...
 def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...
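Note: the new test_multithreading_tls above drives the per-thread compiled autograd state from several Python threads at once. Below is a minimal standalone sketch of the same pattern — an illustration, not part of the patch. It assumes the accessors added in this change (compiled_autograd.local.enabled(), local.get("next_ctx_id")), and the identity compiler_fn is a stand-in for the real compiler helper used by the test file.

import queue
import threading

from torch._dynamo import compiled_autograd


def compiler_fn(gm):
    # stand-in compiler: the real test wraps the captured graph with torch.compile
    return gm


def worker(errors: queue.Queue) -> None:
    try:
        with compiled_autograd.enable(compiler_fn):
            # each thread sees its own CompiledAutogradTLS instance
            assert compiled_autograd.local.enabled()
            assert compiled_autograd.local.get("next_ctx_id") == 1
    except Exception:
        errors.put(1)
        raise


errors: queue.Queue = queue.Queue()
threads = [threading.Thread(target=worker, args=(errors,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert errors.empty()

Before this change the equivalent flags were module-level globals, so an enable() on one thread was visible to every other thread; with the TLS wrapper each worker toggles only its own state.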
diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py
index 950737d4bcba6..8cb7c15c3e8d5 100644
--- a/torch/_dynamo/compiled_autograd.py
+++ b/torch/_dynamo/compiled_autograd.py
@@ -1,7 +1,10 @@
 # mypy: allow-untyped-defs
 import contextlib
 import functools
-from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+import threading
+from dataclasses import dataclass
+from logging import Logger
+from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
 
 import torch
 from torch._dynamo.external_utils import (
@@ -38,14 +41,90 @@
 verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")
 
 
-def snapshot_verbose_logging_enabled():
-    return torch._logging._internal.log_state.is_artifact_enabled(
-        "compiled_autograd_verbose"
-    )
+@dataclass
+class CompiledAutogradTLS:
+    next_ctx_id: int = 0
+    in_compiled_autograd_region: bool = False
+    compiler: Optional["AutogradCompilerInstance"] = None
+    vlogger: Optional[Logger] = None
+
+
+class TLSWrapper:
+    tls_key = "compiled_autograd_state"
+
+    def __init__(self):
+        self._local = threading.local()
+
+    def _get_tls(self) -> CompiledAutogradTLS:
+        if hasattr(self._local, self.tls_key):
+            # first look in python
+            state = getattr(self._local, self.tls_key)
+        if torch._C._is_key_in_tls(self.tls_key):
+            # then look in cpp
+            state = torch._C._get_obj_in_tls(self.tls_key)
+        else:
+            # init new thread created outside of autograd
+            # TODO: what if context manager wrapped outside of thread?
+            setattr(self._local, self.tls_key, CompiledAutogradTLS())
+            state = getattr(self._local, self.tls_key)
+            torch._C._stash_obj_in_tls(self.tls_key, state)
+        return state
+
+    # queries on the object stored in TLS
+    def get(self, name):
+        return getattr(self._get_tls(), name)
+
+    def set_tls(self, **kwargs) -> Callable[[], None]:
+        priors: Dict[str, Any] = {}
+        for k, v in kwargs.items():
+            state = self._get_tls()
+            priors[k] = getattr(state, k)
+            setattr(state, k, v)
+
+        torch._C._dynamo.compiled_autograd.notify_autograd_engine()
+
+        def revert():
+            self.set_tls(**priors)
+
+        return revert
+
+    def enabled(self) -> bool:
+        return self.get("compiler") is not None
+
+    def enter_ctx(self) -> Callable[[], None]:
+        state = self._get_tls()
+        state.next_ctx_id += 1
+        id = state.next_ctx_id
+
+        def exit():
+            assert (
+                state is self._get_tls()
+            ), "Runtime must begin and end on the same thread"
+            assert state.next_ctx_id == id, (
+                "Error nesting compiled autograd context managers: "
+                "inner context managers must have shorter lifetime than the outer context manager"
+            )
+            state.next_ctx_id -= 1
+
+        return exit
+
+    def enter_compiled_region(self) -> Callable[[], None]:
+        state = self._get_tls()
+        prior = state.in_compiled_autograd_region
+        state.in_compiled_autograd_region = True
+        assert prior is False, "Nested compiled autograd regions are not supported"
+
+        def exit():
+            assert (
+                state is self._get_tls()
+            ), "Runtime must begin and end on the same thread"
+            assert state.in_compiled_autograd_region is True
+            state.in_compiled_autograd_region = prior
+
+        return exit
 
 
-def snapshot_cudagraph_enabled():
-    return torch._inductor.config.triton.cudagraphs
+local = TLSWrapper()
 
 
 def maybe_clone(x):
@@ -307,7 +386,7 @@ def end_capture(self, outputs):
         self.rename_aot_dispatcher_nodes()
         self.reorder_accumulate_grad_nodes()
         runtime_inputs_to_move: List[int] = []
-        if snapshot_cudagraph_enabled():
+        if torch._inductor.config.triton.cudagraphs:
             runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
 
         graph = GraphModule(
@@ -329,16 +408,15 @@
         )
 
         def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
-            global in_compiled_autograd_region
             try:
-                in_compiled_autograd_region = True
+                exit_compiled_region = local.enter_compiled_region()
                 for i in runtime_inputs_to_move:
                     inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)
 
                 with disable():
                     return compiled_fn(inputs, sizes, scalars, hooks)
             finally:
-                in_compiled_autograd_region = False
+                exit_compiled_region()
 
         return runtime_wrapper, self.compiler_fn(graph)
 
@@ -510,15 +588,9 @@ def set_node_origin(
         set_stack_trace(new_stack_trace)
 
 
-# state of the autograd engine dispatch, kept in sync by enable/disable context managers
-compiled_autograd_enabled = False
-
 # global flag to check if compiled autograd is enabled but Dynamo stance is "force_eager"
 compiled_autograd_enabled_force_eager = False
 
-# global flag to check if we are processing graphs produced from a compiled autograd graph
-in_compiled_autograd_region = False
-
 
 @contextlib.contextmanager
 def enable(compiler_fn):
@@ -538,39 +610,42 @@ def enable(compiler_fn):
     # we need to lazily import it, because of circular dependencies
     import torch._inductor.cudagraph_trees
 
-    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
-        functools.partial(AutogradCompilerInstance, compiler_fn)
+    exit_ctx = local.enter_ctx()
+    revert_tls = local.set_tls(
+        compiler=functools.partial(AutogradCompilerInstance, compiler_fn),
+        vlogger=verbose_log
+        if torch._logging._internal.log_state.is_artifact_enabled(
+            "compiled_autograd_verbose"
+        )
+        else None,
     )
-    if snapshot_verbose_logging_enabled():
-        torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)
-    global compiled_autograd_enabled
-    compiled_autograd_enabled = True
     try:
         with torch.autograd.set_multithreading_enabled(False):
             yield
     finally:
-        if not prior:
-            compiled_autograd_enabled = False
-        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
+        revert_tls()
+        exit_ctx()
 
 
 @contextlib.contextmanager
 def disable():
-    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
-    global compiled_autograd_enabled
-    compiled_autograd_enabled = False
+    exit_ctx = local.enter_ctx()
+    revert_tls = local.set_tls(
+        compiler=None,
+        vlogger=None,
+    )
     try:
         yield
     finally:
-        if prior:
-            compiled_autograd_enabled = True
-        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
+        revert_tls()
+        exit_ctx()
 
 
 # return to starting state of a new process
 def reset() -> None:
-    global compiled_autograd_enabled
-    compiled_autograd_enabled = False
-    assert not in_compiled_autograd_region
-    torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
-    torch._C._dynamo.compiled_autograd.set_verbose_logger(None)
+    assert local.get("next_ctx_id") == 0
+    assert local.get("in_compiled_autograd_region") is False
+    local.set_tls(
+        compiler=None,
+        vlogger=None,
+    )
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index c3db01dfa454b..47c3b5470612d 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -472,6 +472,11 @@ def default_debug_dir_root():
 # Overrides torch.compile() kwargs for Compiled Autograd:
 compiled_autograd_kwargs_override: Dict[str, Any] = {}
 
+# Compiled Autograd will attempt to automatically wrap C++ autograd functions found in the autograd graph,
+# and make them opaque to the compiler. This does not work when the C++ backward implementation involves
+# other dispatcher subsystems e.g. custom subclasses, autocast, vmap.
+compiled_autograd_opaque_cpp_node = False
+
 # Enables use of collectives *during* compilation to synchronize behavior
 # across ranks. Today, this is used solely to modify automatic_dynamic_shapes
 # behavior, making it so that we infer that if an input is dynamic by
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 3d4f00480be0f..40e3e76e12f70 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -3051,7 +3051,7 @@ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
         if node.op == "placeholder" and node.meta.get("steal_arg", False)
     ]
 
-    if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+    if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
         # fast path, avoid pytree overhead
         # compiled autograd inputs are always a list of tensors, maybe followed by symints
         assert inputs_idx_to_clear == [0]
diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py
index c14b8794cba5f..42aec8b40ba60 100644
--- a/torch/_dynamo/variables/distributed.py
+++ b/torch/_dynamo/variables/distributed.py
@@ -313,7 +313,7 @@ def create(
         user_hooks: VariableTracker,
         user_pre_hooks: VariableTracker,
     ):
-        if not compiled_autograd.compiled_autograd_enabled:
+        if not compiled_autograd.local.enabled():
             unimplemented("module-level backwards hooks require compiled autograd")
 
         def _in_graph_bw_hooks(bw_state: BackwardState):
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index 7f4ad96601a73..8526c9b965a3d 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -929,7 +929,7 @@ def call_method(
         kwargs: "Dict[str, VariableTracker]",
     ) -> "VariableTracker":
         if name == "queue_callback":
-            if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+            if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
                 assert (
                     tx.one_graph
                 ), "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index c6c9918a71449..fb3fa959ceafd 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -1007,7 +1007,7 @@ def _method_register_hook(self, name: str, hook: VariableTracker):
         tx = InstructionTranslator.current_tx()
 
         if not self.source:
-            if not compiled_autograd.compiled_autograd_enabled:
+            if not compiled_autograd.local.enabled():
                 # TODO(voz):
                 # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
                 # python state.
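Note: the torch/_dynamo call sites above all apply the same mechanical substitution — the old module-level globals are replaced with queries on the new local wrapper. A hedged sketch of the before/after mapping follows; the two helper names are hypothetical and exist only for illustration.

from torch._dynamo import compiled_autograd


def ca_enabled() -> bool:
    # before this patch: compiled_autograd.compiled_autograd_enabled (process-wide global)
    return compiled_autograd.local.enabled()


def in_ca_backward_graph() -> bool:
    # before this patch: compiled_autograd.in_compiled_autograd_region (process-wide global)
    return compiled_autograd.local.get("in_compiled_autograd_region")

Because every consumer (Dynamo variable trackers, AOTAutograd caching, FSDP) now reads through local, each call automatically observes the state of the thread it runs on.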
diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py
index 833958c78cb9e..e31e572e624a1 100644
--- a/torch/_functorch/_aot_autograd/autograd_cache.py
+++ b/torch/_functorch/_aot_autograd/autograd_cache.py
@@ -176,7 +176,7 @@ def check_cacheable(gm: torch.fx.GraphModule):
     Checks that the graph module only uses supported operators
     """
     nodes = gm.graph.nodes
-    if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+    if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
         raise BypassAOTAutogradCache(
             "Cannot cache a graph with compiled autograd enabled"
         )
diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
index 9be7779789196..59d01817bbdbf 100644
--- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
+++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py
@@ -704,7 +704,7 @@ def view_avoid_dupes_with_primals(t):
         traced_tangent_memory_formats = [t[1] for t in tangents_and_memory_formats]
         nonlocal static_input_indices
         static_input_indices = static_input_indices or []
-        if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+        if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
             passed_indices = set(static_input_indices)
             static_input_indices = [
                 i
diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
index 496d93d0cd035..bfe90378b3b52 100644
--- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
@@ -760,7 +760,7 @@ def aot_dispatch_autograd(
         # becomes the lazy version again. One example is when dynamic shape is enabled
         # upfront, the bw_compiler will be called above which can cause extra
         # graph module recompilation on bw_module.
-        if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+        if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
             from torch.fx._lazy_graph_module import _LazyGraphModule
 
             _LazyGraphModule.force_recompile(bw_module)
diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h
index db69b17ecee39..2f7f364300105 100644
--- a/torch/csrc/dynamo/compiled_autograd.h
+++ b/torch/csrc/dynamo/compiled_autograd.h
@@ -224,11 +224,29 @@ struct LiftedIValueArgs {
   const std::optional<size_t>& active_node_call_idx;
 };
 
+// Hold GIL while using
+struct PyTLSWrapper {
+  PyTLSWrapper(PyObject* state) : state(state) {}
+  PyTLSWrapper(const PyTLSWrapper&) = delete;
+  PyTLSWrapper& operator=(const PyTLSWrapper&) = delete;
+  PyTLSWrapper(PyTLSWrapper&&) = default;
+  PyTLSWrapper& operator=(PyTLSWrapper&&) = default;
+
+  static PyTLSWrapper create();
+
+  PyObject* get(std::string_view key) const;
+
+ private:
+  PyObject* state;
+};
+
 struct AutogradCompilerCall {
-  AutogradCompilerCall()
+  AutogradCompilerCall() = delete;
+  AutogradCompilerCall(PyTLSWrapper&& state)
       : active_node_call_idx(std::nullopt),
         tensor_args(active_node_call_idx),
-        lifted_ivalue_args(active_node_call_idx) {}
+        lifted_ivalue_args(active_node_call_idx),
+        state(std::move(state)) {}
   void add_size_input(const c10::SymInt& s) {
     all_size_inputs.emplace_back(
         default_dyn_type, s.guard_int(__FILE__, __LINE__));
@@ -254,8 +272,11 @@
   std::vector<c10::SafePyObject> hooks;
   NodeCalls node_calls;
   SizeInput::DynType default_dyn_type = SizeInput::STATIC;
+  // NodeCall id of each size, only when verbose logging is enabled
   std::vector<uint32_t> size_input_origins;
+
+  const PyTLSWrapper state;
 };
 
 class CompiledNodeArgs {
diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp
index 732b43d7d5c7c..7a5969fffba16 100644
--- a/torch/csrc/dynamo/python_compiled_autograd.cpp
+++ b/torch/csrc/dynamo/python_compiled_autograd.cpp
@@ -88,10 +88,6 @@ static void check(bool result) {
   if (C10_UNLIKELY(!result))
     check(nullptr);
 }
-
-// snapshot of python verbose logging toggle
-static PyObject* python_verbose_logger = nullptr;
-
 struct PythonLogger {
   PythonLogger() = delete;
   explicit PythonLogger(PyObject* logger) : logger_(logger) {
@@ -135,15 +131,15 @@ struct PythonLogger {
 };
 
 struct VerboseLogger : public PythonLogger {
-  static std::optional<VerboseLogger> maybe_create() {
-    if (python_verbose_logger == nullptr) {
+  VerboseLogger(PyObject* vlogger) : PythonLogger(vlogger) {}
+
+  static std::optional<VerboseLogger> maybe_create(PyObject* vlogger) {
+    if (vlogger == Py_None) {
       return std::nullopt;
     }
-    return VerboseLogger(python_verbose_logger);
+    return VerboseLogger(vlogger);
   }
 
-  VerboseLogger(PyObject* vlogger) : PythonLogger(vlogger) {}
-
   void log_node_check(
       const Node& fn,
       size_t size_inputs_num,
@@ -368,8 +364,22 @@ struct InputBuffers : public std::unordered_map<Node*, InputBuffer> {
   }
 };
 
-static PyObject* the_autograd_compiler = nullptr;
-static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args);
+/* static */ PyTLSWrapper PyTLSWrapper::create() {
+  TORCH_INTERNAL_ASSERT(
+      at::impl::ThreadLocalPythonObjects::contains("compiled_autograd_state"));
+  PyObject* compiled_autograd_state =
+      check(at::impl::ThreadLocalPythonObjects::get("compiled_autograd_state")
+                ->ptr(getPyInterpreter()));
+  return PyTLSWrapper(compiled_autograd_state);
+}
+
+// Refer to fields in python class CompiledAutogradTLS
+// May return Py_None
+PyObject* PyTLSWrapper::get(std::string_view key) const {
+  return check(PyObject_GetAttrString(state, key.data()));
+}
+
+static PyObject* notify_autograd_engine(PyObject* dummy, PyObject* args);
 
 static PyObject* clear_cache(PyObject* dummy, PyObject* args) {
   HANDLE_TH_ERRORS;
@@ -387,28 +397,11 @@ static PyObject* is_cache_empty(PyObject* dummy, PyObject* args) {
   END_HANDLE_TH_ERRORS;
 }
 
-static PyObject* set_verbose_logger(PyObject* dummy, PyObject* args) {
-  HANDLE_TH_ERRORS;
-  PyObject* logger = nullptr;
-  if (!PyArg_ParseTuple(args, "O", &logger)) {
-    throw_python_error();
-  }
-
-  if (logger == Py_None) {
-    python_verbose_logger = nullptr;
-  } else {
-    python_verbose_logger = logger;
-  }
-  Py_RETURN_TRUE;
-  END_HANDLE_TH_ERRORS;
-}
-
 // NOLINTNEXTLINE(*array*)
 static PyMethodDef _methods[] = {
-    {"set_autograd_compiler", set_autograd_compiler, METH_VARARGS, nullptr},
+    {"notify_autograd_engine", notify_autograd_engine, METH_NOARGS, nullptr},
     {"clear_cache", clear_cache, METH_NOARGS, nullptr},
     {"is_cache_empty", is_cache_empty, METH_NOARGS, nullptr},
-    {"set_verbose_logger", set_verbose_logger, METH_VARARGS, nullptr},
     {nullptr, nullptr, 0, nullptr}};
 
 static struct PyModuleDef _module = {
@@ -568,7 +561,7 @@ CacheNode* _compiled_autograd_impl(
     THPObjectPtr* graph_arg_hooks) {
   std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
   std::vector<std::shared_ptr<Node>> worklist{graph_root};
-  AutogradCompilerCall compiler_call;
+  AutogradCompilerCall compiler_call(PyTLSWrapper::create());
 
   for (const auto i : c10::irange(output_edges.size())) {
     compiler_call.node_calls
@@ -583,7 +576,8 @@
       check_exec_info ? graph_task.exec_info_.size() : dependencies.size() + 1);
 
   int i = 0;
-  std::optional<VerboseLogger> vlogger = VerboseLogger::maybe_create();
+  std::optional<VerboseLogger> vlogger =
+      VerboseLogger::maybe_create(compiler_call.state.get("vlogger"));
   while (!worklist.empty()) {
     std::shared_ptr<Node> fn = std::move(worklist.back());
     worklist.pop_back();
@@ -642,6 +636,8 @@
   // TODO(jansel): some dynamic sizes seem to be ints not symints
   if (!cache->check_dynamic_sizes(compiler_call, vlogger)) {
     // cache miss, need to capture FX graph
+    PyObject* the_autograd_compiler = compiler_call.state.get("compiler");
+    TORCH_INTERNAL_ASSERT(the_autograd_compiler != Py_None);
     ClosingTHPObjectPtr py_compiler(
         check(PyObject_CallNoArgs((the_autograd_compiler))));
 
@@ -839,28 +835,16 @@ variable_list compiled_autograd(
   return outputs;
 }
 
-static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args) {
+static PyObject* notify_autograd_engine(PyObject* dummy, PyObject* args) {
   HANDLE_TH_ERRORS;
-  PyObject* obj = nullptr;
-  if (!PyArg_ParseTuple(args, "O", &obj)) {
-    return nullptr;
-  }
-
-  PyObject* prior = the_autograd_compiler;
-  if (obj == Py_None) { // disable
-    the_autograd_compiler = nullptr; // decref not needed due to `prior`
+  PyTLSWrapper state = PyTLSWrapper::create();
+  PyObject* compiler = state.get("compiler");
+  if (compiler == Py_None) { // disable
     Engine::set_compiled_autograd(nullptr);
   } else { // enable
-    Py_INCREF(obj);
-    the_autograd_compiler = obj;
     Engine::set_compiled_autograd(&compiled_autograd);
   }
-
-  if (prior == nullptr) {
-    Py_RETURN_NONE;
-  } else {
-    return prior;
-  }
+  Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS;
 }
 
diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py
index 74c6f4fdfea7b..c4be02d10ae7f 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_common.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_common.py
@@ -33,9 +33,9 @@ def detect_compiled_autograd():
     import torch._dynamo.compiled_autograd as ca
 
     _compiled_autograd_enabled = (
-        ca.compiled_autograd_enabled
+        ca.local.enabled()
         or ca.compiled_autograd_enabled_force_eager
-        or ca.in_compiled_autograd_region
+        or ca.local.get("in_compiled_autograd_region")
     )
 
 def compiled_autograd_enabled():
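Note: end to end, enable() now stores a CompiledAutogradTLS for the calling thread via set_tls(), which stashes the object in both Python's threading.local and the dispatcher-side ThreadLocalPythonObjects and then calls notify_autograd_engine(); the C++ engine later reads the per-thread compiler back through PyTLSWrapper. The sketch below shows the intended usage under those assumptions — it is illustrative, not part of the patch, and my_compiler is a hypothetical compiler callable.

import torch
from torch._dynamo import compiled_autograd


def my_compiler(gm: torch.fx.GraphModule):
    # hypothetical compiler_fn: hand the captured backward graph to torch.compile
    return torch.compile(gm)


model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

with compiled_autograd.enable(my_compiler):
    # set_tls() installed the compiler for this thread only and notified the engine
    loss = model(x).sum()
    loss.backward()  # the backward graph is captured and compiled via my_compiler

# leaving the context runs revert_tls()/exit_ctx(), restoring the prior per-thread state
assert not compiled_autograd.local.enabled()
assert compiled_autograd.local.get("in_compiled_autograd_region") is False

Other threads never observe the toggle, which is exactly what test_multithreading_tls asserts.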