[Inductor] Add FlexAttention backward kernel dynamic shape tests (pytorch#127728)

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#127728
Approved by: https://github.com/Chillee
yanboliang authored and pytorchmergebot committed Jun 4, 2024
1 parent e793ae2 commit 8d153e0
Showing 1 changed file with 110 additions and 48 deletions.
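
For orientation, the checks in this diff compare three runs of the same attention call: a float64 "golden" run, an eager reference in the test dtype, and the compiled output, and they only require the compiled error to stay within a fudge factor of the reference error. The sketch below is illustrative rather than taken from the test file; the metric (mean absolute error against the golden run) and the helper name `check_close` are assumptions, since `_check_equal` itself is not shown in this excerpt.

import torch

def check_close(golden, ref, compiled, fudge_factor, name="Out"):
    # Measure each implementation against the float64 golden run.
    ref_error = (golden - ref.to(torch.float64)).abs().mean()
    compiled_error = (golden - compiled.to(torch.float64)).abs().mean()
    # The compiled kernel is allowed to be less accurate (e.g. due to the
    # online softmax), but only up to fudge_factor times the reference error.
    assert compiled_error <= fudge_factor * ref_error, (
        f"{name} Compiled error {compiled_error} is greater than "
        f"ref error {ref_error} by more than {fudge_factor}X."
    )

The diff applies larger fudge factors to the query, key, and value gradients (2.5x and 4x the output fudge factor), since the backward pass is noisier than the forward pass.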
158 changes: 110 additions & 48 deletions test/inductor/test_flex_attention.py
@@ -154,6 +154,47 @@ def _check_equal(
msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
self.assertTrue(False, msg)

def _check_out_and_grad(
self,
golden_out: torch.Tensor,
ref_out: torch.Tensor,
compiled_out: torch.Tensor,
q_gold: torch.Tensor,
q_ref: torch.Tensor,
q: torch.Tensor,
k_gold: torch.Tensor,
k_ref: torch.Tensor,
k: torch.Tensor,
v_gold: torch.Tensor,
v_ref: torch.Tensor,
v: torch.Tensor,
):
dtype = ref_out.dtype
with torch.no_grad():
# Note, it seems like we really are less accurate than the float32
# computation, likely due to the online softmax
if dtype == torch.float32:
fudge_factor = 10.0
else:
fudge_factor = 1.1

# Check output
self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")

# Check gradients
q_fudge_factor = 2.5 * fudge_factor
self._check_equal(
q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query"
)
k_fudge_factor = 4 * fudge_factor
self._check_equal(
k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key"
)
v_fudge_factor = 4 * fudge_factor
self._check_equal(
v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value"
)

def run_test(
self,
score_mod: Callable,
@@ -190,30 +231,20 @@ def run_test(
ref_out.backward(backward_grad)
compiled_out.backward(backward_grad)

with torch.no_grad():
# Note, it seems like we really are less accurate than the float32
# computation, likely due to the online softmax
if dtype == torch.float32:
fudge_factor = 10.0
else:
fudge_factor = 1.1

# Check output
self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")

# Check gradients
q_fudge_factor = 2.5 * fudge_factor
self._check_equal(
q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query"
)
k_fudge_factor = 4 * fudge_factor
self._check_equal(
k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key"
)
v_fudge_factor = 4 * fudge_factor
self._check_equal(
v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value"
)
self._check_out_and_grad(
golden_out,
ref_out,
compiled_out,
q_gold,
q_ref,
q,
k_gold,
k_ref,
k,
v_gold,
v_ref,
v,
)

def run_dynamic_test(
self,
@@ -226,45 +257,76 @@
):
sdpa_partial = create_attention(score_mod)
# The first eager batch, shape (B, H, S, D)
q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
golden_out1 = sdpa_partial(
q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
)
ref_out1 = sdpa_partial(q1, k1, v1)
q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
q1_ref, k1_ref, v1_ref = query_key_value_clones(q1, k1, v1)
q1_gold, k1_gold, v1_gold = query_key_value_clones(q1, k1, v1, torch.float64)
ref_out1 = sdpa_partial(q1_ref, k1_ref, v1_ref)
golden_out1 = sdpa_partial(q1_gold, k1_gold, v1_gold)

backward_grad1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")

golden_out1.backward(backward_grad1.to(torch.float64))
ref_out1.backward(backward_grad1)

# The second eager batch, shape (B * 2, H, S / 2, D)
B = int(B * 2)
S = int(S / 2)
q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
golden_out2 = sdpa_partial(
q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
)
ref_out2 = sdpa_partial(q2, k2, v2)
q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
q2_ref, k2_ref, v2_ref = query_key_value_clones(q2, k2, v2)
q2_gold, k2_gold, v2_gold = query_key_value_clones(q2, k2, v2, torch.float64)
ref_out2 = sdpa_partial(q2_ref, k2_ref, v2_ref)
golden_out2 = sdpa_partial(q2_gold, k2_gold, v2_gold)

backward_grad2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")

golden_out2.backward(backward_grad2.to(torch.float64))
ref_out2.backward(backward_grad2)

# Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
# We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation.
torch._dynamo.reset()
# Compiling with dynamic shape in the first batch.
compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
compiled_out1 = compiled_sdpa(q1, k1, v1)

# Note, it seems like we really are less accurate than the float32
# computation, likely due to the online softmax
if dtype == torch.float32:
fudge_factor = 10.0
else:
fudge_factor = 1.1

self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor)
compiled_out1.backward(backward_grad1)

self._check_out_and_grad(
golden_out1,
ref_out1,
compiled_out1,
q1_gold,
q1_ref,
q1,
k1_gold,
k1_ref,
k1,
v1_gold,
v1_ref,
v1,
)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

# No re-compilation, use the compiled dynamic shape version.
compiled_out2 = compiled_sdpa(q2, k2, v2)
self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor)
compiled_out2.backward(backward_grad2)
self._check_out_and_grad(
golden_out2,
ref_out2,
compiled_out2,
q2_gold,
q2_ref,
q2,
k2_gold,
k2_ref,
k2,
v2_gold,
v2_ref,
v2,
)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

def run_automatic_dynamic_test(
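
The core of `run_dynamic_test` can also be summarized outside the test harness. The sketch below is illustrative and not part of the diff: it substitutes `scaled_dot_product_attention` for the FlexAttention partial, assumes a CUDA device, and picks arbitrary shapes; the relevant pattern is compiling with `dynamic=True`, resetting the dynamo counters first, and asserting that `counters["frames"]["ok"]` is still 1 after a call with a different batch size and sequence length.

import torch
import torch._dynamo.utils
import torch.nn.functional as F

def attention(q, k, v):
    # Stand-in for the FlexAttention partial exercised by the real test.
    return F.scaled_dot_product_attention(q, k, v)

# Clear dynamo state so the recompilation check starts from zero.
torch._dynamo.reset()
compiled = torch.compile(attention, dynamic=True)

# First batch, shape (B, H, S, D).
q1, k1, v1 = (torch.randn(2, 4, 256, 64, device="cuda") for _ in range(3))
out1 = compiled(q1, k1, v1)  # first call compiles a shape-generic kernel

# Second batch, shape (B * 2, H, S / 2, D).
q2, k2, v2 = (torch.randn(4, 4, 128, 64, device="cuda") for _ in range(3))
out2 = compiled(q2, k2, v2)  # should reuse the dynamic-shape kernel

# One successfully compiled frame means the second call did not recompile.
assert torch._dynamo.utils.counters["frames"]["ok"] == 1

The actual test additionally clones q, k, and v (via query_key_value_clones) so that the reference, golden, and compiled runs accumulate gradients on separate leaf tensors before `_check_out_and_grad` compares them.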
