[RWKV]: update rwkv support

uniartisan committed Feb 8, 2025
1 parent 58436a0 commit 1408c4a

Showing 6 changed files with 78 additions and 11 deletions.

2 changes: 1 addition & 1 deletion in fla/layers/rwkv6.py

```diff
@@ -146,7 +146,7 @@ def forward(
         cu_seqlens = kwargs.get('cu_seqlens', None)
         if mode == 'fused_recurrent':
             o, recurrent_state = fused_recurrent_rwkv6(
-                r=r,
+                q=r,
                 k=k,
                 v=v,
                 w=w,
```
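
This is the call-site half of the parameter rename made in `fla/ops/rwkv6/fused_recurrent.py` below: the receptance tensor is now passed as `q=` instead of `r=`. A minimal sketch of what a downstream keyword caller has to change; the sizes, the `u` bonus term, and the `head_first=False` layout are illustrative assumptions drawn from the docstring, and the fused kernel is assumed to require a Triton-capable GPU:

```python
import torch

from fla.ops.rwkv6 import fused_recurrent_rwkv6

# Illustrative sizes; tensors go on GPU since the fused kernel is a Triton op.
B, T, H, K, V = 1, 16, 4, 64, 64
r = torch.randn(B, T, H, K, device='cuda')   # receptance, RWKV's name for the query
k = torch.randn(B, T, H, K, device='cuda')
v = torch.randn(B, T, H, V, device='cuda')
w = -torch.rand(B, T, H, K, device='cuda')   # data-dependent decays, log space
u = torch.randn(H, K, device='cuda')         # per-head bonus

# Keyword callers must switch from `r=` to `q=`; positional calls are unaffected.
o, state = fused_recurrent_rwkv6(q=r, k=k, v=v, w=w, u=u, head_first=False)
```
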
2 changes: 1 addition & 1 deletion in fla/models/rwkv6/configuration_rwkv6.py

```diff
@@ -7,7 +7,7 @@
 
 class RWKV6Config(PretrainedConfig):
 
-    model_type = 'rwkv6'
+    model_type = 'rwkv6_fla'
     keys_to_ignore_at_inference = ['past_key_values']
 
     def __init__(
```
2 changes: 1 addition & 1 deletion in fla/models/rwkv7/configuration_rwkv7.py

```diff
@@ -7,7 +7,7 @@
 
 class RWKV7Config(PretrainedConfig):
 
-    model_type = 'rwkv7'
+    model_type = 'rwkv7_fla'
     keys_to_ignore_at_inference = ['past_key_values']
 
     def __init__(
```
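
Both configuration diffs make the same change: `model_type` gains an `_fla` suffix, so these classes stop colliding with the RWKV implementations registered under the plain `rwkv6`/`rwkv7` names in Transformers, and checkpoints must carry the suffixed string in their `config.json`. A hedged sketch of how the new identifier resolves through `AutoConfig`; the explicit `register` call is illustrative, since `fla` may already perform registration on import:

```python
from transformers import AutoConfig

from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config

# Illustrative manual registration under the suffixed model_type.
AutoConfig.register('rwkv6_fla', RWKV6Config)

config = AutoConfig.for_model('rwkv6_fla')  # resolves via the new model_type
assert config.model_type == 'rwkv6_fla'     # plain 'rwkv6' now maps elsewhere
```
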
4 changes: 3 additions & 1 deletion in fla/ops/rwkv6/__init__.py

```diff
@@ -2,8 +2,10 @@
 
 from .chunk import chunk_rwkv6
 from .fused_recurrent import fused_recurrent_rwkv6
+from .recurrent_naive import native_recurrent_rwkv6
 
 __all__ = [
     'chunk_rwkv6',
-    'fused_recurrent_rwkv6'
+    'fused_recurrent_rwkv6',
+    'native_recurrent_rwkv6'
 ]
```
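
With the added export, the naive reference kernel sits next to the fused and chunked variants on the op's public surface:

```python
# All three RWKV6 kernel variants are now importable from one place.
from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6, native_recurrent_rwkv6
```
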
12 changes: 6 additions & 6 deletions in fla/ops/rwkv6/fused_recurrent.py

```diff
@@ -610,7 +610,7 @@ def backward(ctx, do, dht):
 
 
 def fused_recurrent_rwkv6(
-    r: torch.Tensor,
+    q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     w: torch.Tensor,
@@ -624,7 +624,7 @@ def fused_recurrent_rwkv6(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Args:
-        r (torch.Tensor):
+        q (torch.Tensor):
             reception of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
             Alias: q, query in linear attention.
         k (torch.Tensor):
@@ -688,10 +688,10 @@ def fused_recurrent_rwkv6(
     >>> assert o.allclose(o_var.view(o.shape))
     >>> assert ht.allclose(ht_var)
     """
-    set_torch_device(r)
+    set_torch_device(q)
     if cu_seqlens is not None:
-        if r.shape[0] != 1:
-            raise ValueError(f"The batch size is expected to be 1 rather than {r.shape[0]} when using `cu_seqlens`."
+        if q.shape[0] != 1:
+            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                              f"Please flatten variable-length inputs before processing.")
         if head_first:
             raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
@@ -701,7 +701,7 @@ def fused_recurrent_rwkv6(
     if scale is None:
         scale = k.shape[-1] ** -0.5
     o, final_state = FusedRecurrentRWKV6Function.apply(
-        r,
+        q,
         k,
         v,
         w,
```
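
The rename is most visible on the variable-length path, where the batch-size guard now reads `q.shape[0]`. A minimal sketch of the flattened `cu_seqlens` call implied by the error messages; lengths and sizes are illustrative, and the fused kernel is again assumed to need a GPU:

```python
import torch

from fla.ops.rwkv6 import fused_recurrent_rwkv6

H, K, V = 4, 64, 64
# Two sequences of lengths 5 and 11, flattened into one batch row of length 16.
cu_seqlens = torch.tensor([0, 5, 16], dtype=torch.long, device='cuda')
T = 16

q = torch.randn(1, T, H, K, device='cuda')   # batch size must be 1 with cu_seqlens
k = torch.randn(1, T, H, K, device='cuda')
v = torch.randn(1, T, H, V, device='cuda')
w = -torch.rand(1, T, H, K, device='cuda')   # log-space decays
u = torch.randn(H, K, device='cuda')

# head_first must be False: variable-length inputs reject head-first mode.
o, ht = fused_recurrent_rwkv6(q=q, k=k, v=v, w=w, u=u,
                              output_final_state=True,
                              cu_seqlens=cu_seqlens, head_first=False)
```
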
67 changes: 66 additions & 1 deletion in fla/ops/rwkv6/recurrent_naive.py

```diff
@@ -1,8 +1,9 @@
 # -*- coding: utf-8 -*-
 
-from typing import Optional
+from typing import Tuple, Optional
 
 import torch
+from fla.utils import autocast_custom_fwd, autocast_custom_bwd
 
 
 def naive_recurrent_rwkv6(
@@ -101,3 +102,67 @@ def naive_recurrent_rwkv6_bwd(
         dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i]
 
     return dq, dk, dv, dw, du, dh
+
+
+class NativeRecurrentRWKV6Function(torch.autograd.Function):
+    @staticmethod
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, w, u, scale, initial_state, output_final_state: bool = False):
+        o, ht = naive_recurrent_rwkv6(q, k, v, w, u, scale, initial_state, output_final_state)
+        if initial_state is not None:
+            initial_state = initial_state.clone()
+
+        ctx.save_for_backward(q, k, v, w, u, o, initial_state)
+        ctx.scale = scale
+        return o, ht
+
+    @staticmethod
+    @autocast_custom_bwd
+    def backward(ctx, do, dht):
+        q, k, v, w, u, o, initial_state = ctx.saved_tensors
+        dq, dk, dv, dw, du, dh = naive_recurrent_rwkv6_bwd(q, k, v, w, u, o, do, dht, initial_state, ctx.scale)
+        dh = dh if initial_state is not None else None
+        return dq, dk, dv, dw, du, None, dh, None
+
+
+def native_recurrent_rwkv6(
+    r: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    w: torch.Tensor,
+    u: torch.Tensor,
+    scale: float = 1.0,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+    head_first: bool = True
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""
+    Args:
+        r (torch.Tensor):
+            reception of shape `(B, H, T, K)`. Alias: q, query in linear attention.
+        k (torch.Tensor):
+            keys of shape `(B, H, T, K)`
+        v (torch.Tensor):
+            values of shape `(B, H, T, V)`
+        w (torch.Tensor):
+            data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g.
+        u (torch.Tensor):
+            bonus of shape `(H, K)` or `(B, H, K)` for each head.
+        scale (Optional[int]):
+            Scale factor for the RWKV6 attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `(B, H, K, V)`. Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
+    """
+    if scale == -1.0:
+        scale = r.shape[-1] ** -0.5
+
+    assert cu_seqlens is None, "cu_seqlens is not supported in the native implementation."
+    assert head_first, "head_first=False is not supported in the native implementation."
+
+    o, final_state = NativeRecurrentRWKV6Function.apply(r, k, v, w, u, scale, initial_state, output_final_state)
+
+    return o, final_state
```
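
The autograd wrapper makes the naive reference path differentiable, so it can serve as a ground truth when validating the fused kernel. A hedged sketch of that use, with head-first `(B, H, T, K)` shapes as the asserts above require and illustrative sizes:

```python
import torch

from fla.ops.rwkv6 import native_recurrent_rwkv6

B, H, T, K, V = 1, 2, 32, 16, 16
q = torch.randn(B, H, T, K, requires_grad=True)
k = torch.randn(B, H, T, K, requires_grad=True)
v = torch.randn(B, H, T, V, requires_grad=True)
w = (-torch.rand(B, H, T, K)).requires_grad_()  # log-space decays
u = torch.randn(H, K, requires_grad=True)
h0 = torch.zeros(B, H, K, V)

o, ht = native_recurrent_rwkv6(q, k, v, w, u, scale=1.0,
                               initial_state=h0, output_final_state=True)
(o.sum() + ht.sum()).backward()  # gradients flow through naive_recurrent_rwkv6_bwd

# On a GPU, fused_recurrent_rwkv6 with the same inputs and head_first=True
# should agree with (o, ht) within floating-point tolerance.
```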
