RF test_rope_causal_self_att, update for new HF version
albertz committed Jan 13, 2025
1 parent 1a1f5c6 commit 2332837
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions tests/test_rf_attention.py
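
For context on the "new HF version" part: in newer Hugging Face transformers releases, LlamaAttention.forward no longer computes the attention weights inline but delegates to an attention-interface function such as eager_attention_forward, and the resulting attn_output comes back in (batch, seq, heads, head_dim) order. That is why the test below additionally traces eager_attention_forward and reads attn_weights / attn_output from different capture points and with a different dim order. A rough sketch of that function's structure (simplified from memory, omitting grouped-query-attention handling via repeat_kv; the function name and the argument names query, key, value, attention_mask, scaling, dropout follow the import in the diff below, but this is not the exact HF source):

import torch
import torch.nn.functional as F


def eager_attention_forward_sketch(module, query, key, value, attention_mask, scaling, dropout=0.0):
    """Simplified stand-in for transformers' eager_attention_forward (assumption, not the real code)."""
    # query/key/value: (batch, num_heads, seq, head_dim)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling  # raw energies, pre-softmax
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]  # additive causal mask
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    # the transpose here is why the diff now expects attn_output dims
    # (batch_dim, seq_dim, num_heads, value_dim_per_head) instead of (batch_dim, num_heads, seq_dim, ...)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights

Under this structure, the pre-softmax energies are only visible inside eager_attention_forward (hence the (eager_attention_forward, 0, "attn_weights", 0) mapping in the diff), while the post-softmax weights are presumably the last "attn_weights" value seen by LlamaAttention.forward (hence index -1).
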
@@ -280,6 +280,7 @@ def test_rope_causal_self_att():
LlamaRotaryEmbedding,
LlamaConfig,
apply_rotary_pos_emb,
+eager_attention_forward,
)

config = LlamaConfig(
@@ -322,7 +323,8 @@ def test_rope_causal_self_att():

position_ids = rf.expand_dim(rf.range_over_dim(seq_dim), batch_dim) # LlamaRotaryEmbedding wants this
with PyTracer(
-[LlamaAttention.forward, LlamaRotaryEmbedding.forward, apply_rotary_pos_emb], torch.Tensor
+[LlamaAttention.forward, LlamaRotaryEmbedding.forward, apply_rotary_pos_emb, eager_attention_forward],
+torch.Tensor,
) as trace_hf:
# causal_mask code copied from LlamaAttention
sequence_length = target_length = in_.raw_tensor.shape[1]
@@ -340,29 +342,35 @@ def test_rope_causal_self_att():
position_embeddings = rotary_emb(
torch.zeros(()), position_ids=position_ids.copy_compatible_to_dims_raw((batch_dim, seq_dim))
)
-out_hf, _, _ = model_hf(in_.raw_tensor, attention_mask=causal_mask, position_embeddings=position_embeddings)
+out_hf, *_ = model_hf(in_.raw_tensor, attention_mask=causal_mask, position_embeddings=position_embeddings)
pprint(trace_hf.captured_locals)

print("First HF att weight tensor:")
-print(trace_hf.captured_locals[LlamaAttention.forward][0]["attn_weights"][2][0, 0, 0].detach().numpy())
+print(trace_hf.captured_locals[LlamaAttention.forward][0]["attn_weights"][-1][0, 0, 0].detach().numpy())

check_py_traces_rf_to_pt_equal(
trace_rf.captured_locals,
trace_hf.captured_locals,
[
(
(rf.RotaryPosCausalSelfAttention.__call__, 0, "q", 0),
(LlamaAttention.forward, 0, "query_states", 1),
# input: batch_dim, seq_dim, model_dim
# input_shape: batch_dim, seq_dim
# HF query_states': (batch_dim, seq_dim, num_heads, self.head_dim),
# then transposed to (batch_dim, num_heads, seq_dim, self.head_dim)
(LlamaAttention.forward, 0, "query_states", 0),
lambda x, *, name, **_: rf.convert_to_tensor(
# reorder complex numbers
x.reshape(*x.shape[:-1], 2, -1).transpose(-1, -2).flatten(-2),
dims=(batch_dim, model_rf.num_heads, seq_dim, model_rf.key_dim_per_head),
name=name,
),
),
(
(rf.RotaryPosCausalSelfAttention.__call__, 0, "k", 0),
(LlamaAttention.forward, 0, "key_states", 1),
(LlamaAttention.forward, 0, "key_states", 0),
lambda x, *, name, **_: rf.convert_to_tensor(
# reorder complex numbers
x.reshape(*x.shape[:-1], 2, -1).transpose(-1, -2).flatten(-2),
dims=(batch_dim, model_rf.num_heads, seq_dim, model_rf.key_dim_per_head),
name=name,
@@ -418,25 +426,25 @@ def test_rope_causal_self_att():
),
(
(rf.dot_attention, 0, "energy", 0),
(LlamaAttention.forward, 0, "attn_weights", 0),
(eager_attention_forward, 0, "attn_weights", 0),
(batch_dim, model_rf.num_heads, seq_dim, "axis"),
),
(
(rf.dot_attention, 0, "att_weights", 0),
(LlamaAttention.forward, 0, "attn_weights", 2),
(LlamaAttention.forward, 0, "attn_weights", -1),
(batch_dim, model_rf.num_heads, seq_dim, "axis"),
),
(
(rf.dot_attention, 0, "att", 0),
(LlamaAttention.forward, 0, "attn_output", 0),
-(batch_dim, model_rf.num_heads, seq_dim, model_rf.value_dim_per_head),
+(batch_dim, seq_dim, model_rf.num_heads, model_rf.value_dim_per_head),
),
],
)

print("Final check...")
assert out_rf.raw_tensor.shape == out_hf.shape
-torch.testing.assert_allclose(out_rf.raw_tensor, out_hf)
+torch.testing.assert_close(out_rf.raw_tensor, out_hf)
print(" all matched!")

