[compiled autograd] tls access helpers (pytorch#138061)
Pull Request resolved: pytorch#138061
Approved by: https://github.com/yf225
ghstack dependencies: pytorch#137953, pytorch#137821
xmfan authored and pytorchmergebot committed Oct 22, 2024
1 parent 49fa437 commit 5a13282
Showing 12 changed files with 20 additions and 17 deletions.
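The change is mechanical at every call site: direct reads through the module's `local` TLS wrapper are replaced by the two new module-level helpers added in torch/_dynamo/compiled_autograd.py. A minimal before/after sketch, assuming a PyTorch build that contains this commit:

import torch._dynamo.compiled_autograd as ca

# Old pattern (removed by this commit): reach into the TLS wrapper directly.
#   ca.local.enabled()
#   ca.local.get("in_compiled_autograd_region")

# New pattern: use the module-level helpers.
if ca.enabled():
    pass  # a compiler has been installed for compiled autograd
if ca.in_compiled_autograd_region():
    pass  # currently executing a compiled autograd backward graph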
@@ -256,7 +256,7 @@ def _check_count(copy_count, resize_count):
f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}", # noqa: B950
)

-if not torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+if not torch._dynamo.compiled_autograd.in_compiled_autograd_region():
_check_count(fwd_copy_count, fwd_resize_count) # fwd graph
else:
_check_count(bwd_copy_count, bwd_resize_count) # bwd graph
4 changes: 1 addition & 3 deletions test/dynamo/test_activation_checkpointing.py
@@ -86,9 +86,7 @@ def match_rng_op(node, op):


def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]):
-if not torch._dynamo.compiled_autograd.local.get(
-    "in_compiled_autograd_region"
-): # fwd graph
+if not torch._dynamo.compiled_autograd.in_compiled_autograd_region(): # fwd graph
return_node = list(graph.nodes)[-1]
assert return_node.target == "output"
for x in return_node.args[0]:
2 changes: 1 addition & 1 deletion test/inductor/test_compiled_autograd.py
@@ -2412,7 +2412,7 @@ def train(errors, model, x):
try:
out = model(x)
with compiled_autograd.enable(compiler_fn):
-self.assertEqual(compiled_autograd.local.enabled(), True)
+self.assertEqual(compiled_autograd.enabled(), True)
self.assertEqual(compiled_autograd.local.get("next_ctx_id"), 1)
except Exception as e:
print(f"Found error: {e}")
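The test above checks the new helper inside an enable() context. A self-contained usage sketch along the same lines, assuming a build with this commit; the compiler_fn below is a placeholder for illustration, not taken from the test:

import torch
from torch._dynamo import compiled_autograd

def compiler_fn(gm):
    # Placeholder: compile the captured backward graph with the eager backend.
    return torch.compile(gm, backend="eager")

x = torch.randn(4, requires_grad=True)
loss = (x * x).sum()

assert not compiled_autograd.enabled()
with compiled_autograd.enable(compiler_fn):
    assert compiled_autograd.enabled()
    loss.backward()  # the backward pass runs through compiled autograd
assert not compiled_autograd.enabled()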
11 changes: 8 additions & 3 deletions torch/_dynamo/compiled_autograd.py
@@ -88,9 +88,6 @@ def revert():

return revert

-def enabled(self) -> bool:
-    return self.get("compiler") is not None
-
def enter_ctx(self) -> Callable[[], None]:
state = self._get_tls()
state.next_ctx_id += 1
@@ -127,6 +124,14 @@ def exit():
local = TLSWrapper()


+def enabled() -> bool:
+    return local.get("compiler") is not None
+
+
+def in_compiled_autograd_region() -> bool:
+    return local.get("in_compiled_autograd_region")
+
+
def maybe_clone(x):
if x is not None:
return clone_preserve_strides(x)
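For context, `local` is an instance of the module's TLSWrapper, which is not shown in this hunk; the new helpers are thin reads of per-thread state. A simplified, hypothetical sketch of the pattern (the _TLSSketch class and its defaults are assumptions for illustration, not the actual TLSWrapper):

import threading
from typing import Any

class _TLSSketch:
    """Hypothetical stand-in for the TLSWrapper that backs `local`."""

    def __init__(self) -> None:
        self._tls = threading.local()

    def _get_tls(self) -> Any:
        # Lazily set per-thread defaults on first access from each thread.
        if not hasattr(self._tls, "compiler"):
            self._tls.compiler = None
            self._tls.in_compiled_autograd_region = False
        return self._tls

    def get(self, name: str) -> Any:
        return getattr(self._get_tls(), name)

local = _TLSSketch()

def enabled() -> bool:
    # Compiled autograd counts as enabled once a compiler is installed in TLS.
    return local.get("compiler") is not None

def in_compiled_autograd_region() -> bool:
    # True only while a compiled autograd graph is executing on this thread.
    return local.get("in_compiled_autograd_region")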
2 changes: 1 addition & 1 deletion torch/_dynamo/utils.py
@@ -3051,7 +3051,7 @@ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
if node.op == "placeholder" and node.meta.get("steal_arg", False)
]

-if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+if torch._dynamo.compiled_autograd.in_compiled_autograd_region():
# fast path, avoid pytree overhead
# compiled autograd inputs are always a list of tensors, maybe followed by symints
assert inputs_idx_to_clear == [0]
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/distributed.py
@@ -313,7 +313,7 @@ def create(
user_hooks: VariableTracker,
user_pre_hooks: VariableTracker,
):
-if not compiled_autograd.local.enabled():
+if not compiled_autograd.enabled():
unimplemented("module-level backwards hooks require compiled autograd")

def _in_graph_bw_hooks(bw_state: BackwardState):
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/misc.py
@@ -929,7 +929,7 @@ def call_method(
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "queue_callback":
-if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+if torch._dynamo.compiled_autograd.in_compiled_autograd_region():
assert (
tx.one_graph
), "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/tensor.py
@@ -1007,7 +1007,7 @@ def _method_register_hook(self, name: str, hook: VariableTracker):
tx = InstructionTranslator.current_tx()

if not self.source:
-if not compiled_autograd.local.enabled():
+if not compiled_autograd.enabled():
# TODO(voz):
# We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
# python state.
2 changes: 1 addition & 1 deletion torch/_functorch/_aot_autograd/autograd_cache.py
@@ -176,7 +176,7 @@ def check_cacheable(gm: torch.fx.GraphModule):
Checks that the graph module only uses supported operators
"""
nodes = gm.graph.nodes
-if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+if torch._dynamo.compiled_autograd.in_compiled_autograd_region():
raise BypassAOTAutogradCache(
"Cannot cache a graph with compiled autograd enabled"
)
@@ -704,7 +704,7 @@ def view_avoid_dupes_with_primals(t):
traced_tangent_memory_formats = [t[1] for t in tangents_and_memory_formats]
nonlocal static_input_indices
static_input_indices = static_input_indices or []
-if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+if torch._dynamo.compiled_autograd.in_compiled_autograd_region():
passed_indices = set(static_input_indices)
static_input_indices = [
i
@@ -760,7 +760,7 @@ def aot_dispatch_autograd(
# becomes the lazy version again. One example is when dynamic shape is enabled
# upfront, the bw_compiler will be called above which can cause extra
# graph module recompilation on bw_module.
-if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
+if torch._dynamo.compiled_autograd.in_compiled_autograd_region():
from torch.fx._lazy_graph_module import _LazyGraphModule

_LazyGraphModule.force_recompile(bw_module)
4 changes: 2 additions & 2 deletions torch/distributed/_composable/fsdp/_fsdp_common.py
@@ -33,9 +33,9 @@ def detect_compiled_autograd():
import torch._dynamo.compiled_autograd as ca

_compiled_autograd_enabled = (
-ca.local.enabled()
+ca.enabled()
or ca.compiled_autograd_enabled_force_eager
-or ca.local.get("in_compiled_autograd_region")
+or ca.in_compiled_autograd_region()
)

def compiled_autograd_enabled():
