Revert "[test] AOTAutograd: support mutations on buffers that happen …
Browse files Browse the repository at this point in the history
…during th bw (pytorch#112906)"

This reverts commit c8974d6.

Reverted pytorch#112906 on behalf of https://github.com/huydhn because there are lots of failures after this change (https://hud.pytorch.org/pytorch/pytorch/commit/c8974d649d684a33a5c02a0b112a6e0743201d97); this is probably a landrace ([comment](pytorch#112906 (comment)))
pytorchmergebot committed Nov 29, 2023
1 parent 4bfb198 commit 48820c9
Showing 2 changed files with 1 addition and 123 deletions.
test/functorch/test_aotdispatch.py (0 additions, 85 deletions)
@@ -1438,91 +1438,6 @@ def inp_callable(req_grad):
with self.assertRaisesRegex(AssertionError, "attempted to compile the backward with incorrect subclass metadata"):
self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True)

# Mutations in the backward are allowed as long as the mutated object does not require grad
def test_backward_mutation_data(self):
class BwMutation(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x.clone()

@staticmethod
def backward(ctx, grad_output):
x, = ctx.saved_tensors
# bw mutation
x.mul_(2)
return grad_output.clone()

def f(a, b):
out = BwMutation.apply(b)
return a * out

inp_no_grad = [
torch.ones(3, 3, requires_grad=True),
torch.ones(3, 3, requires_grad=False),
]

# Mutation on buffer that does not require grad during the backward is allowed
self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)

inp_grad = [
torch.ones(3, 3, requires_grad=True),
torch.ones(3, 3, requires_grad=True),
]
with self.assertRaisesRegex(AssertionError, "input that requires_grad and was mutated in the backward"):
self.verify_aot_autograd(f, inp_grad, test_mutation=True)

def test_backward_mutation_metadata(self):
class BwMutation(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
ctx.save_for_backward(b)
return a.clone(), b.clone()

@staticmethod
def backward(ctx, grad_a, grad_b):
b, = ctx.saved_tensors
# bw metadata mutation
b.transpose_(1, 0)
return grad_a.clone(), grad_b.clone()

def f(a, b):
a_, b_ = BwMutation.apply(a, b)
out = a_ * b_
return out

inp_no_grad = [
torch.ones(3, 3, requires_grad=True),
torch.ones(3, 3, requires_grad=False),
]

with self.assertRaisesRegex(AssertionError, "input that had its metadata mutated in the backward"):
self.verify_aot_autograd(f, inp_no_grad, test_mutation=True)

def test_backward_mutation_on_grad_out(self):
class BwMutation(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.clone()

@staticmethod
def backward(ctx, grad_output):
grad_output.mul_(2)
return grad_output.clone()

def f(a, b):
tmp = a * b
out = BwMutation.apply(tmp)
return out

inp_grad = [
torch.ones(3, 3, requires_grad=True),
torch.ones(3, 3, requires_grad=True),
]
f_compiled = aot_function(f, nop)
with self.assertRaisesRegex(AssertionError, "input to the backward that was mutated during the backward"):
out = f_compiled(*inp_grad)

# Partially addresses https://github.com/pytorch/pytorch/issues/106457
def test_input_mutation_false_aliasing(self):
def f(a, b):
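Note (not part of the diff): the removed tests all follow the same pattern, a custom autograd.Function that mutates a tensor during its backward, compiled through AOTAutograd. A minimal standalone sketch of that pattern follows; it mirrors test_backward_mutation_data, assumes the same functorch.compile imports that the test file uses, and whether the backward mutation is kept in the graph or rejected depends on whether the reverted support is present.

# Illustrative sketch only, mirroring the removed test_backward_mutation_data.
import torch
from functorch.compile import aot_function, nop

class BwMutation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        x.mul_(2)  # data mutation on a saved input during the backward
        return grad_output.clone()

def f(a, b):
    return a * BwMutation.apply(b)

a = torch.ones(3, 3, requires_grad=True)
b = torch.ones(3, 3, requires_grad=False)  # the mutated tensor does not require grad

compiled = aot_function(f, nop)
out = compiled(a, b)
# Depending on which side of this revert you are on, the mutation in the backward
# is either kept in the graph or rejected with an assertion while tracing.
out.sum().backward()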
torch/_functorch/aot_autograd.py (1 addition, 38 deletions)
@@ -1910,43 +1910,6 @@ def functionalized_f_helper(*args):
        # Run the joint
        f_outs = fn(*f_args)

        if trace_joint:
            # We support a limited amount of mutation of graph inputs during the backward pass.
            # (This is used e.g. by Float8, which needs to update buffers during the backward pass)
            # Here, we perform extra checks for primals that were mutated in the **backward**
            # We're doing the checks here instead of doing them with the rest of the input mutation handling because:
            # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened
            #   during the forward, because the handling is different: some input mutations from the forward
            #   can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same
            #   types of mutations in the backward we would need a bw-only runtime epilogue.
            # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in
            #   the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would
            #   require an extra round of tracing though, so it's more efficient to do in-line here.
            assert isinstance(args, tuple) and len(args) == 2 and isinstance(args[0], (list, tuple))
            # Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw)
            primals_before = args[0]
            primals_after = pytree.tree_map(from_fun, f_args[0])
            for before, after, inpt_info in zip(primals_before, primals_after, meta.input_info):
                # Ban metadata mutations on fw inputs during the bw
                if not inpt_info.mutates_metadata:
                    assert not was_metadata_updated(before, after), \
                        "Found a graph input that had its metadata mutated in the backward. This is not supported"
                # Allow data mutations on fw inputs during the bw, but only if they do not require grad
                # So we can guarantee that we can keep the mutations in the graph
                if was_updated(before, after) and not inpt_info.mutates_data:
                    assert not inpt_info.requires_grad, \
                        "Found a graph input that requires_grad and was mutated in the backward. This is not supported"
                    # Otherwise, put the mutation in the graph
                    before.copy_(after)
            # Now that we covered mutations to *forward* inputs during the backward,
            # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out).
            # Today, we will just error in all cases of this happening unless someone needs us to support it.
            tangents_before = args[1]
            tangents_after = pytree.tree_map(from_fun, f_args[1])
            for before, after in zip(tangents_before, tangents_after):
                assert not was_metadata_updated(before, after) and not was_updated(before, after), \
                    "Found an input to the backward that was mutated during the backward pass. This is not supported"

        if aot_config.keep_inference_input_mutations:
            # Note: This is a bit annoying. There's a layering issue here, where:
            # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.
@@ -2230,7 +2193,7 @@ def assert_functional_graph(fx_g: torch.fx.Graph, *, allow_input_mutations: bool
if n.op == "placeholder":
placeholders.add(n)
if isinstance(n.target, torch._ops.OpOverload):
if n.target is aten.copy_.default:
if n.target is aten.copy_.default and allow_input_mutations:
suffix = True
# Can only copy_ into an input, and can only do so once
assert n.args[0] in placeholders
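Note (not part of the diff): the removed block above compares every primal before and after the joint graph runs, using helpers named was_updated and was_metadata_updated. Those helpers inspect functionalization state; purely as an illustration of the checks being performed, an eager-mode approximation could look like the sketch below (names and logic are assumptions, not the real implementation).

import torch

def was_metadata_updated_approx(before: torch.Tensor, after: torch.Tensor) -> bool:
    # Approximation: a metadata mutation changes sizes, strides, or storage offset.
    return (
        before.size() != after.size()
        or before.stride() != after.stride()
        or before.storage_offset() != after.storage_offset()
    )

def was_updated_approx(before: torch.Tensor, after: torch.Tensor) -> bool:
    # Approximation: a data mutation changes values while keeping the same shape.
    return before.size() == after.size() and not torch.equal(before, after)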
