Add tests for cross attention (#7609)
dvorjackz authored Jan 11, 2025
1 parent b83b3af commit b8cea10
Showing 1 changed file with 129 additions and 24 deletions.
153 changes: 129 additions & 24 deletions extension/llm/modules/test/test_attention.py
@@ -33,6 +33,7 @@ def setUp(self):
self.num_kv_heads = 8
self.head_dim = 64
self.max_seq_len = 128
self.encoder_max_seq_len = 128
self.rope_base = 500_000
self.scale_factor = 32

@@ -86,16 +87,26 @@ def setUp(self):
max_seq_len=self.max_seq_len,
)
self.et_mha.load_state_dict(self.tt_mha.state_dict())

# Common inputs.
seq_len = 10
self.x = torch.randn(1, seq_len, self.embed_dim)
self.y = torch.randn(1, seq_len, self.embed_dim)
self.input_pos = torch.arange(seq_len).unsqueeze(0) # shape [1, seq_len]
seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
self.dynamic_shapes = (
{0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
{0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
{0: torch.export.Dim.STATIC, 1: seq_len_dim},
)
self.seq_len_dim = torch.export.Dim("seq_len", min=1, max=self.max_seq_len)
self.dynamic_shapes = {
"x": {
0: torch.export.Dim.STATIC,
1: self.seq_len_dim,
2: torch.export.Dim.STATIC,
},
"y": {
0: torch.export.Dim.STATIC,
1: self.seq_len_dim,
2: torch.export.Dim.STATIC,
},
"input_pos": {0: torch.export.Dim.STATIC, 1: self.seq_len_dim},
}
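# Aside (illustration only, not part of the test file): torch.export consumes a
# dynamic_shapes mapping like the one above keyed by forward() argument name,
# with each dim index mapped to a Dim (dynamic) or Dim.STATIC (fixed). A toy
# export of a plain nn.Linear stand-in (hypothetical, not the module under test)
# using the same pattern:
_toy = torch.nn.Linear(self.embed_dim, self.embed_dim)
_toy_seq_dim = torch.export.Dim("toy_seq_len", min=1, max=self.max_seq_len)
_toy_ep = torch.export.export(
    _toy,
    (torch.randn(1, 10, self.embed_dim),),
    dynamic_shapes={
        "input": {
            0: torch.export.Dim.STATIC,
            1: _toy_seq_dim,
            2: torch.export.Dim.STATIC,
        },
    },
)
# _toy_ep.module() now accepts any sequence length in [1, max_seq_len] at dim 1.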
self.causal_mask = torch.tril(
torch.ones(
size=(self.max_seq_len, self.max_seq_len),
@@ -110,8 +121,8 @@ def test_attention_eager(self):
assert_close(et_res, tt_res)

# test with kv cache
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

et_res = self.et_mha(self.x, self.x) # Self attention.
tt_res = self.tt_mha(self.x, self.x) # Self attention.
@@ -144,12 +155,12 @@ def test_attention_export(self):
# Self attention.

# test with kv cache
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
with torch.no_grad():
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.x),
(self.x, self.y),
kwargs={"input_pos": self.input_pos},
dynamic_shapes=self.dynamic_shapes,
strict=True,
@@ -166,8 +177,8 @@ def test_attention_aoti(self):
# Self attention.

# test with kv cache
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
with torch.no_grad():
so = torch._export.aot_compile(
self.et_mha,
@@ -189,13 +200,13 @@ def test_attention_aoti(self):

def test_attention_executorch(self):
# Self attention.
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

with torch.no_grad():
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.x),
(self.x, self.y),
kwargs={"input_pos": self.input_pos},
dynamic_shapes=self.dynamic_shapes,
strict=True,
@@ -222,22 +233,18 @@ def test_attention_executorch(self):

def test_attention_torch_cond_eager(self):
# Different from vanilla torchtune MHA, we rewrite the if condition with torch.cond. We need to make sure they are giving the same results regarding the if condition.
# For the first run of MHA we provide `y` (self.x) but for the second run it will be a tensor full of nan.
# For the first run of MHA we provide `y` but for the second run it will be a tensor full of nan.
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

mask = self.causal_mask[self.input_pos, :]
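# Aside (shape check only): indexing the [max_seq_len, max_seq_len] causal mask
# with an input_pos of shape [1, seq_len] picks one mask row per position, so
# the resulting mask is [1, seq_len, max_seq_len].
assert mask.shape == (1, self.x.shape[1], self.max_seq_len)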
# First run.
et_res = self.et_mha(
self.x, self.x, mask=mask, input_pos=self.input_pos
) # Self attention with input pos.
tt_res = self.tt_mha(
self.x, self.x, mask=mask, input_pos=self.input_pos
) # Self attention with input pos.
et_res = self.et_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)

assert_close(et_res, tt_res)

# Second run test kv cache read. Input pos is [10, 11, ..., 19]
# Second run tests kv cache read. Input pos is [10, 11, ..., 19]
next_input_pos = torch.arange(10, 20).unsqueeze(0)

empty_y = torch.full_like(self.x, torch.nan)
@@ -246,3 +253,101 @@ def test_attention_torch_cond_eager(self):
tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

assert_close(et_res, tt_res)
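
# Aside (toy sketch, not the module's real control flow): the torch.cond rewrite
# mentioned above expresses an eager `if` on the encoder input as a predicate
# tensor plus two branch functions with identical signatures and output shapes,
# which keeps the branch traceable under torch.export. All names below are
# hypothetical.
def _refresh_cache(cache, new_kv):
    return new_kv + 0.0  # pretend: encoder input present, overwrite cached K/V

def _reuse_cache(cache, new_kv):
    return cache + 0.0  # pretend: encoder input absent (None/nan), reuse cache

_cache = torch.zeros(1, 4, 8)
_new_kv = torch.randn(1, 4, 8)
_has_encoder_input = torch.tensor(True)  # stands in for `y is not None`
_picked = torch.cond(
    _has_encoder_input, _refresh_cache, _reuse_cache, (_cache, _new_kv)
)
assert_close(_picked, _new_kv)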

def test_attention_torch_cond_export(self):
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
mask = self.causal_mask[self.input_pos, :]
dynamic_shapes = {
**self.dynamic_shapes,
**{
"mask": {
0: torch.export.Dim.STATIC,
1: self.seq_len_dim,
2: torch.export.Dim.STATIC,
}
},
}
with torch.no_grad():
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.y),
kwargs={
"mask": mask,
"input_pos": self.input_pos,
},
dynamic_shapes=dynamic_shapes,
strict=True,
)

# First run.
et_res = et_mha_ep.module()(self.x, self.y, mask=mask, input_pos=self.input_pos)
tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)

assert_close(et_res, tt_res)

# Second run tests kv cache read. Input pos is [10, 11, ..., 19]
next_input_pos = torch.arange(10, 20).unsqueeze(0)
empty_y = torch.full_like(self.y, torch.nan)
mask = self.causal_mask[next_input_pos, :]
et_res = et_mha_ep.module()(
self.x, empty_y, mask=mask, input_pos=next_input_pos
)
tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

assert_close(et_res, tt_res)

def test_attention_torch_cond_executorch(self):
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
mask = self.causal_mask[self.input_pos, :]
dynamic_shapes = {
**self.dynamic_shapes,
**{
"mask": {
0: torch.export.Dim.STATIC,
1: self.seq_len_dim,
2: torch.export.Dim.STATIC,
}
},
}
with torch.no_grad():
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.y),
kwargs={
"mask": mask,
"input_pos": self.input_pos,
},
dynamic_shapes=dynamic_shapes,
strict=True,
)
et_program = to_edge(
et_mha_ep,
compile_config=EdgeCompileConfig(
_core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
_check_ir_validity=False,
),
).to_executorch(
config=ExecutorchBackendConfig(
passes=[InitializedMutableBufferPass(["cache_pos"])],
)
)

# First run.
runtime = Runtime.get()
program = runtime.load_program(et_program.buffer)
method = program.load_method("forward")
et_res = method.execute((self.x, self.y, mask, self.input_pos))
tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)

assert_close(et_res[0], tt_res)

# Second run tests kv cache read. Input pos is [10, 11, ..., 19]
next_input_pos = torch.arange(10, 20).unsqueeze(0)
empty_y = torch.full_like(self.y, torch.nan)
mask = self.causal_mask[next_input_pos, :]
et_res = method.execute((self.x, empty_y, mask, next_input_pos))
tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

assert_close(et_res[0], tt_res)
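
For reference, a minimal end-to-end sketch of the export -> to_edge -> to_executorch -> runtime round trip that test_attention_executorch exercises, using a toy module and assumed import paths rather than the attention module under test (the test above additionally passes Edge and ExecuTorch configs):

import torch
from executorch.exir import to_edge
from executorch.runtime import Runtime

class Toy(torch.nn.Module):  # hypothetical stand-in for the exported module
    def forward(self, x, y):
        return x + y

ep = torch.export.export(Toy(), (torch.randn(2, 3), torch.randn(2, 3)), strict=True)
et_program = to_edge(ep).to_executorch()

runtime = Runtime.get()
program = runtime.load_program(et_program.buffer)
method = program.load_method("forward")
outputs = method.execute((torch.randn(2, 3), torch.randn(2, 3)))
# outputs is a sequence; outputs[0] holds the single forward() result.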
