diff --git a/tests/test_rf_attention.py b/tests/test_rf_attention.py
index 763b41dbc..da5b34662 100644
--- a/tests/test_rf_attention.py
+++ b/tests/test_rf_attention.py
@@ -280,6 +280,7 @@ def test_rope_causal_self_att():
         LlamaRotaryEmbedding,
         LlamaConfig,
         apply_rotary_pos_emb,
+        eager_attention_forward,
     )
 
     config = LlamaConfig(
@@ -322,7 +323,8 @@ def test_rope_causal_self_att():
 
     position_ids = rf.expand_dim(rf.range_over_dim(seq_dim), batch_dim)  # LlamaRotaryEmbedding wants this
     with PyTracer(
-        [LlamaAttention.forward, LlamaRotaryEmbedding.forward, apply_rotary_pos_emb], torch.Tensor
+        [LlamaAttention.forward, LlamaRotaryEmbedding.forward, apply_rotary_pos_emb, eager_attention_forward],
+        torch.Tensor,
     ) as trace_hf:
         # causal_mask code copied from LlamaAttention
         sequence_length = target_length = in_.raw_tensor.shape[1]
@@ -340,11 +342,11 @@ def test_rope_causal_self_att():
         position_embeddings = rotary_emb(
             torch.zeros(()), position_ids=position_ids.copy_compatible_to_dims_raw((batch_dim, seq_dim))
         )
-        out_hf, _, _ = model_hf(in_.raw_tensor, attention_mask=causal_mask, position_embeddings=position_embeddings)
+        out_hf, *_ = model_hf(in_.raw_tensor, attention_mask=causal_mask, position_embeddings=position_embeddings)
     pprint(trace_hf.captured_locals)
 
     print("First HF att weight tensor:")
-    print(trace_hf.captured_locals[LlamaAttention.forward][0]["attn_weights"][2][0, 0, 0].detach().numpy())
+    print(trace_hf.captured_locals[LlamaAttention.forward][0]["attn_weights"][-1][0, 0, 0].detach().numpy())
 
     check_py_traces_rf_to_pt_equal(
         trace_rf.captured_locals,
@@ -352,8 +354,13 @@ def test_rope_causal_self_att():
         [
             (
                 (rf.RotaryPosCausalSelfAttention.__call__, 0, "q", 0),
-                (LlamaAttention.forward, 0, "query_states", 1),
+                # input: batch_dim, seq_dim, model_dim
+                # input_shape: batch_dim, seq_dim
+                # HF query_states': (batch_dim, seq_dim, num_heads, self.head_dim),
+                # then transposed to (batch_dim, num_heads, seq_dim, self.head_dim)
+                (LlamaAttention.forward, 0, "query_states", 0),
                 lambda x, *, name, **_: rf.convert_to_tensor(
+                    # reorder complex numbers
                     x.reshape(*x.shape[:-1], 2, -1).transpose(-1, -2).flatten(-2),
                     dims=(batch_dim, model_rf.num_heads, seq_dim, model_rf.key_dim_per_head),
                     name=name,
@@ -361,8 +368,9 @@ def test_rope_causal_self_att():
             ),
             (
                 (rf.RotaryPosCausalSelfAttention.__call__, 0, "k", 0),
-                (LlamaAttention.forward, 0, "key_states", 1),
+                (LlamaAttention.forward, 0, "key_states", 0),
                 lambda x, *, name, **_: rf.convert_to_tensor(
+                    # reorder complex numbers
                     x.reshape(*x.shape[:-1], 2, -1).transpose(-1, -2).flatten(-2),
                     dims=(batch_dim, model_rf.num_heads, seq_dim, model_rf.key_dim_per_head),
                     name=name,
@@ -418,25 +426,25 @@ def test_rope_causal_self_att():
             ),
             (
                 (rf.dot_attention, 0, "energy", 0),
-                (LlamaAttention.forward, 0, "attn_weights", 0),
+                (eager_attention_forward, 0, "attn_weights", 0),
                 (batch_dim, model_rf.num_heads, seq_dim, "axis"),
             ),
             (
                 (rf.dot_attention, 0, "att_weights", 0),
-                (LlamaAttention.forward, 0, "attn_weights", 2),
+                (LlamaAttention.forward, 0, "attn_weights", -1),
                 (batch_dim, model_rf.num_heads, seq_dim, "axis"),
             ),
             (
                 (rf.dot_attention, 0, "att", 0),
                 (LlamaAttention.forward, 0, "attn_output", 0),
-                (batch_dim, model_rf.num_heads, seq_dim, model_rf.value_dim_per_head),
+                (batch_dim, seq_dim, model_rf.num_heads, model_rf.value_dim_per_head),
             ),
         ],
     )
 
     print("Final check...")
     assert out_rf.raw_tensor.shape == out_hf.shape
-    torch.testing.assert_allclose(out_rf.raw_tensor, out_hf)
+    torch.testing.assert_close(out_rf.raw_tensor, out_hf)
     print(" all matched!")
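
A note on the "# reorder complex numbers" lambdas above: the expression x.reshape(*x.shape[:-1], 2, -1).transpose(-1, -2).flatten(-2) turns a last dimension stored as two stacked halves into interleaved pairs, presumably because the two rotary-embedding implementations lay out the paired components of the head dimension differently. A tiny standalone demonstration of just that expression (the input tensor is made up for illustration):

    import torch

    # half-split layout: [a0, a1, a2, a3, b0, b1, b2, b3]
    x = torch.arange(8.0)
    y = x.reshape(*x.shape[:-1], 2, -1).transpose(-1, -2).flatten(-2)
    print(y)  # tensor([0., 4., 1., 5., 2., 6., 3., 7.])  -> interleaved pairs (a0, b0, a1, b1, ...)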
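
Two of the smaller changes are plain API updates, sketched below assuming a recent PyTorch: the star-unpacking out_hf, *_ = ... no longer hard-codes how many values the HF module call returns, and torch.testing.assert_close replaces the deprecated torch.testing.assert_allclose. The fake_attention helper here is a hypothetical stand-in, not a Hugging Face API:

    import torch

    def fake_attention():
        # hypothetical stand-in for the HF module call; the number of
        # trailing return values does not matter to the caller below
        return torch.ones(2, 3), None

    out, *_ = fake_attention()  # works for 2-tuples as well as longer tuples

    # assert_close replaces the deprecated assert_allclose and raises an
    # AssertionError with a detailed mismatch report if the tensors differ
    torch.testing.assert_close(out, torch.ones(2, 3))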