Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Qwen2VL mrope for transformers 4.47.0 #464

Merged
merged 1 commit into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions src/liger_kernel/ops/qwen2vl_mrope.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def _triton_qwen2vl_mrope(
cos,
sin,
sl,
bs: tl.constexpr,
n_qh: tl.constexpr,
n_kh: tl.constexpr,
hd: tl.constexpr,
Expand Down Expand Up @@ -41,13 +42,12 @@ def _triton_qwen2vl_mrope(
t_end = mrope_section_t
h_end = t_end + mrope_section_h

cos_row_idx = pid % sl
t_cos = cos + cos_row_idx * hd
h_cos = t_cos + sl * hd
w_cos = h_cos + sl * hd
t_sin = sin + cos_row_idx * hd
h_sin = t_sin + sl * hd
w_sin = h_sin + sl * hd
t_cos = cos + pid * hd
h_cos = t_cos + bs * sl * hd
w_cos = h_cos + bs * sl * hd
t_sin = sin + pid * hd
h_sin = t_sin + bs * sl * hd
w_sin = h_sin + bs * sl * hd

cos_offsets = tl.arange(0, pad_hd // 2)
t_mask = cos_offsets < t_end
Expand Down Expand Up @@ -151,6 +151,7 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
cos,
sin,
seq_len,
batch_size,
n_q_head,
n_kv_head,
head_dim,
Expand Down Expand Up @@ -189,6 +190,7 @@ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
cos,
sin,
seq_len,
batch_size,
n_q_head,
n_kv_head,
head_dim,
Expand Down Expand Up @@ -216,8 +218,8 @@ def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
"""
q size: (bsz, n_q_head, seq_len, head_dim)
k size: (bsz, n_kv_head, seq_len, head_dim)
cos size: (3, 1, seq_len, head_dim)
sin size: (3, 1, seq_len, head_dim)
cos size: (3, bsz, seq_len, head_dim)
sin size: (3, bsz, seq_len, head_dim)
"""
q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
ctx.save_for_backward(cos, sin)
Expand All @@ -228,10 +230,9 @@ def backward(ctx, dq, dk):
"""
dq size: (bsz, n_q_head, seq_len, head_dim)
dk size: (bsz, n_kv_head, seq_len, head_dim)
cos size: (3, 1, seq_len, head_dim)
sin size: (3, 1, seq_len, head_dim)
cos size: (3, bsz, seq_len, head_dim)
sin size: (3, bsz, seq_len, head_dim)
"""

cos, sin = ctx.saved_tensors
mrope_section = ctx.mrope_section
dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
Expand Down
4 changes: 2 additions & 2 deletions src/liger_kernel/transformers/qwen2vl_mrope.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
Args:
q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
cos (torch.Tensor): The cosine tensor of shape (3, 1, seq_len, head_dim).
sin (torch.Tensor): The sine tensor of shape (3, 1, seq_len, head_dim).
cos (torch.Tensor): The cosine tensor of shape (3, bsz, seq_len, head_dim).
sin (torch.Tensor): The sine tensor of shape (3, bsz, seq_len, head_dim).
mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation.
unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.

Expand Down
8 changes: 6 additions & 2 deletions test/transformers/test_qwen2vl_mrope.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def test_correctness(
k2 = _tensor_k.clone().requires_grad_(True)

# NOTE: this position ids distribution is different from the real one, just to test op correctness
pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view(
3, bsz, seq_len
)
cos, sin = rotary_emb(k1, pos_ids)

# validate forward pass
Expand Down Expand Up @@ -130,7 +132,9 @@ def test_functional_correctness(

rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device)

pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view(
3, bsz, seq_len
)
cos, sin = rotary_emb(k1, pos_ids)

functional_q, functional_k = liger_qwen2vl_mrope(q1, k1, cos, sin, mrope_section)
Expand Down
Loading