
Commit 585c765
fix logits tests
Signed-off-by: Austin Liu <[email protected]>
austin362667 committed Jan 28, 2025
1 parent 8791f16 commit 585c765
Showing 2 changed files with 168 additions and 168 deletions.
23 changes: 9 additions & 14 deletions src/liger_kernel/transformers/model/llama.py
@@ -8,7 +8,6 @@
import torch.nn.functional as F

from torch.nn import CrossEntropyLoss
-from torch.nn.attention.flex_attention import create_block_mask
from torch.nn.attention.flex_attention import flex_attention
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
@@ -256,8 +255,6 @@ def lce_forward(


# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/flex_attention.py#L12


def flex_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
@@ -279,26 +276,24 @@ def causal_mod(score, b, h, q_idx, kv_idx):
score = score + causal_mask[b][0][q_idx][kv_idx]
return score

-    # We only got `attention_mask` tensors, so we recreate `causal_mask` function as specific llama causal attention
-    # TODO: Consider other customized `attention_mask` in the future, e.g., shared prefix
-    def causal_mask_fn(b, h, q_idx, kv_idx):
-        return q_idx >= kv_idx
+    # def causal_mask_fn(b, h, q_idx, kv_idx):
+    #     return q_idx >= kv_idx

-    # To construct block attention mask that leverages sparsity.
-    sparse_causal_mask = create_block_mask(causal_mask_fn, None, None, query.shape[-2], query.shape[-2], device="cuda")
+    # TODO: Construct block attention mask that leverages sparsity
+    # sparse_causal_mask = create_block_mask(
+    #     causal_mask_fn, B=None, H=None, Q_LEN=query.shape[-2], KV_LEN=key.shape[-2], device=query.device, BLOCK_SIZE=1
+    # )

attn_output, attention_weights = flex_attention(
query,
key,
value,
score_mod=causal_mod,
-        block_mask=sparse_causal_mask,
+        # block_mask=sparse_causal_mask,
enable_gqa=True,
scale=scaling,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=True,
-        kernel_options={ # different harware might need different configs
+        kernel_options={
"BLOCK_M": 32,
"BLOCK_N": 32,
"BLOCK_M1": 16,
@@ -307,7 +302,7 @@ def causal_mask_fn(b, h, q_idx, kv_idx):
"BLOCK_N2": 16,
},
)
# lse is returned in float32

attention_weights = attention_weights.to(value.dtype)
attn_output = attn_output.transpose(1, 2).contiguous()

