
Commit

Fix lint errors (#170)
jannalulu authored Feb 6, 2025
1 parent bfacf89 commit 19ad2e2
Showing 10 changed files with 57 additions and 51 deletions.
18 changes: 11 additions & 7 deletions fla/ops/generalized_delta_rule/dplr/naive.py
@@ -3,9 +3,11 @@
import torch
from einops import rearrange


def get_abs_err(x, y):
return (x-y).flatten().abs().max().item()


def get_err_ratio(x, y):
err = (x-y).flatten().square().mean().sqrt().item()
base = (x).flatten().square().mean().sqrt().item()
@@ -20,6 +22,8 @@ def assert_close(prefix, ref, tri, ratio):
# S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T
# q, k, alpha, beta [B, H, L, D_K]
# v [B, H, L, D_V]


def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True):
orig_dtype = q.dtype
b, h, l, d_k = q.shape
@@ -58,7 +62,8 @@ def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_st

# note that diagonal is masked.
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size).float(), [q, k, v, alpha, beta, gk])
q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d',
c=chunk_size).float(), [q, k, v, alpha, beta, gk])

gk_cumsum = gk.cumsum(-2)

@@ -78,7 +83,7 @@ def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_st
A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone()
mask = (torch.arange(chunk_size) < i).to(q.device)
# shift by one.
attn_i = (gk_i - gk[:,:,:,i,None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone()
A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone()

@@ -95,13 +100,14 @@ def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_st
for i in range(0, l // chunk_size):
q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
v2_i = u_i + w_i @ S

o_1 = A_qk[:, :, i] @ v_i
o_2 = A_qb[:, :, i] @ v2_i
o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S
o[:, :, i] = o_1 + o_2 + o_3
decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp()
S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + (beta_i * decay).transpose(-1, -2) @ v2_i
S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \
(beta_i * decay).transpose(-1, -2) @ v2_i
S = None if output_final_state is False else S
return rearrange(o, 'b h n c d -> b h (n c) d'), S

@@ -125,7 +131,7 @@ def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_st
beta = beta.clone().detach().requires_grad_(True)
gate_logit_normalizer = 16
w = torch.nn.functional.logsigmoid(torch.randn(B, H, L, DK)) / gate_logit_normalizer

w = w.cuda().requires_grad_(True)
o, s = dplr_recurrence(q.clone(), k.clone(), v.clone(), -alpha.clone(), beta.clone(), w.clone())
do = torch.randn_like(o).cuda()
@@ -136,7 +142,6 @@ def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_st
alpha_grad, alpha.grad = alpha.grad, None
beta_grad, beta.grad = beta.grad, None


o2, s2 = dplr_chunkwise(q.clone(), k.clone(), v.clone(), -alpha.clone(), beta.clone(), w.clone(), chunk_size=16)
o2.backward(do)
assert_close("o", o, o2, 0.002)
@@ -147,4 +152,3 @@ def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_st
assert_close("alpha.grad", alpha.grad, alpha_grad, 0.002)
assert_close("beta.grad", beta.grad, beta_grad, 0.002)
print("All passed!")

3 changes: 2 additions & 1 deletion fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py
@@ -7,6 +7,7 @@
import triton
import triton.language as tl


@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@@ -173,4 +174,4 @@ def chunk_dplr_bwd_wy(
BV=BV,
HEAD_FIRST=head_first
)
return dA_ab, dA_ak, dv, dag
return dA_ab, dA_ak, dv, dag
10 changes: 4 additions & 6 deletions fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py
@@ -8,6 +8,7 @@
import triton
import triton.language as tl


@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@@ -21,13 +22,13 @@
@triton.jit
def fwd_prepare_wy_repr_kernel_chunk32(
A_ab,
A_ab_inv,
A_ab_inv,
offsets,
indices,
T: tl.constexpr,
H: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr, #placeholder, do not delete
BC: tl.constexpr, # placeholder, do not delete
USE_OFFSETS: tl.constexpr,
HEAD_FIRST: tl.constexpr,
):
@@ -54,7 +55,7 @@ def fwd_prepare_wy_repr_kernel_chunk32(
b_A_ab = tl.where(mask[:, None], b_a, b_A_ab)
b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None
@@ -201,7 +202,6 @@ def fwd_wu_kernel(
b_ag = tl.load(p_ag, boundary_check=(0, 1))
b_w = tl.dot(b_Aab_inv.to(b_ag.dtype), b_ag, allow_tf32=False)
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))


for i_v in range(tl.cdiv(V, BV)):
if HEAD_FIRST:
@@ -312,5 +312,3 @@ def fwd_wu(
HEAD_FIRST=head_first
)
return w, u


21 changes: 12 additions & 9 deletions fla/ops/generalized_delta_rule/iplr/fused_recurrent.py
@@ -9,6 +9,7 @@

from fla.utils import contiguous


@triton.heuristics({
'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
@@ -34,7 +35,7 @@ def fused_recurrent_fwd_kernel(
ha, # tmp variable [B, H, L, V] for storing intermediate results of (h * a[None, :]).sum(0)
h0, # initial hidden state [B, H, K, V]
ht, # final hidden state [B, H, K, V]
offsets, # varlen offsets
offsets, # varlen offsets
scale, # K ** -0.5
H, # n_heads
T, # seq_len
@@ -56,7 +57,7 @@
T = eos - bos
else:
bos, eos = i_n * T, i_n * T + T

if HEAD_FIRST:
p_q = q + i_nh * T*K + tl.arange(0, BK)
p_k = k + i_nh * T*K + tl.arange(0, BK)
@@ -145,7 +146,7 @@ def fused_recurrent_bwd_kernel(
dha, # gradient of ha [NK, B, H, L, V]
h0, # initial state [B, H, K, V]
scale, # K ** -0.5
offsets, # offsets
offsets, # offsets
B, # batch_size
H, # n_heads
T, # seq_len
@@ -251,7 +252,7 @@ def fused_recurrent_bwd_kernel(
mask_kv = mask_k[:, None] & mask_v[None, :]
p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)

p_k = k + tl.arange(0, BK)
p_v = v + tl.arange(0, BV)
p_ha = ha + tl.arange(0, BV)
@@ -303,7 +304,8 @@ def forward(ctx, q, k, v, a, b, scale=None, initial_state=None, output_final_sta
final_state = None

ha = torch.empty_like(v, dtype=torch.float32)
grid = lambda meta: (

def grid(meta): return (
triton.cdiv(V, meta['BV']),
N * H
)
@@ -341,7 +343,7 @@ def backward(ctx, do, dht):
B, H, T, K, V = *q.shape, v.shape[-1]
else:
B, T, H, K, V = *q.shape, v.shape[-1]

N = B if ctx.offsets is None else len(ctx.offsets) - 1
scale = ctx.scale
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64)
@@ -429,7 +431,7 @@ def fused_recurrent_iplr_delta_rule(
output_final_state (Optional[bool]):
Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
offsets (Optional[torch.Tensor]):
"""
if offsets is not None:
if q.shape[0] != 1:
@@ -444,5 +446,6 @@ def fused_recurrent_iplr_delta_rule(
scale = q.shape[-1] ** -0.5
else:
assert scale > 0, "scale must be positive"
o, final_state = FusedRecurrentIPLRDeltaRuleFunction.apply(q, k, v, a, b, scale, initial_state, output_final_state, offsets, head_first)
return o, final_state
o, final_state = FusedRecurrentIPLRDeltaRuleFunction.apply(
q, k, v, a, b, scale, initial_state, output_final_state, offsets, head_first)
return o, final_state
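
(Hypothetical usage sketch, not an example from the repository: argument names and the scale / output_final_state behaviour are read off the signature and docstring fragments in the hunk above, and the shapes assume the head-first layout [B, H, T, ...] used by the kernels.)

import torch
from fla.ops.generalized_delta_rule.iplr.fused_recurrent import fused_recurrent_iplr_delta_rule

B, H, T, K, V = 2, 4, 128, 64, 64
q = torch.randn(B, H, T, K, device='cuda')
k = torch.randn(B, H, T, K, device='cuda')
a = torch.randn(B, H, T, K, device='cuda')  # low-rank terms; a real model would constrain these
b = torch.randn(B, H, T, K, device='cuda')
v = torch.randn(B, H, T, V, device='cuda')

o, final_state = fused_recurrent_iplr_delta_rule(
    q, k, v, a, b,
    scale=None,               # falls back to K ** -0.5, per the check above
    output_final_state=True,  # final_state has shape [B, H, K, V] per the docstring
)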
6 changes: 3 additions & 3 deletions fla/ops/generalized_delta_rule/iplr/wy_fast.py
@@ -8,6 +8,7 @@
import triton
import triton.language as tl


@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@@ -30,7 +31,7 @@ def fwd_prepare_wy_repr_kernel_chunk32(
K: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BC: tl.constexpr, # dummy placeholder
BC: tl.constexpr, # dummy placeholder
USE_OFFSETS: tl.constexpr,
HEAD_FIRST: tl.constexpr,
):
@@ -206,7 +207,7 @@ def fwd_wu_kernel(
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
else:
p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

b_A = tl.load(p_A, boundary_check=(0, 1))
b_Aak = tl.zeros([BT, BT], dtype=tl.float32)

@@ -342,4 +343,3 @@ def fwd_wu(
HEAD_FIRST=head_first
)
return w, u

2 changes: 1 addition & 1 deletion fla/ops/rwkv7/__init__.py
@@ -6,4 +6,4 @@
__all__ = [
'chunk_rwkv7',
'fused_recurrent_rwkv7'
]
]
2 changes: 1 addition & 1 deletion fla/ops/titans/__init__.py
@@ -6,4 +6,4 @@
__all__ = [
'fused_chunk_titans_linear',
'chunk_titans_linear'
]
]
14 changes: 7 additions & 7 deletions fla/ops/titans/naive.py
@@ -5,16 +5,16 @@


def combine_params(theta, alpha, eta, seq_len):
beta = torch.cumprod(1 - alpha, dim = 2) # β_t = ∏(1 - α_t) in titans paper
beta = torch.cumprod(1 - alpha, dim=2) # β_t = ∏(1 - α_t) in titans paper

m = torch.cumprod(eta, dim = 2) # [batch_size, head_dim, sequence_length, 1]
m = torch.cumprod(eta, dim=2) # [batch_size, head_dim, sequence_length, 1]
m[:, :, 0:1, :] = 1

n = m * theta # n_i=m_i*theta_i
beta_T = beta[:, :, -1:, :].clone() # [batch_size, head_dim, 1, 1]
# calculate beta_T/beta_j
beta_ratio = beta_T / beta # [batch_size, head_dim, sequence_length, 1]
mask = torch.triu(torch.ones(seq_len, seq_len, dtype = beta.dtype)).to(
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=beta.dtype)).to(
beta.device) # [sequence_length, sequence_length]
mask = mask.view(1, 1, seq_len, seq_len) # [1, 1, sequence_length, sequence_length]
beta_ratio = beta_ratio.view(*beta_ratio.shape[:-1], 1) # [batch_size, head_dim, sequence_length, 1]
@@ -41,7 +41,7 @@ def titans_linear(q, k, v, w, b, theta, alpha, eta, eps, BT, initial_state, outp
# [num_heads, 1, head_dim]
h = initial_state
if initial_state is None:
h = torch.zeros((B, H, D, D), device = v.device, dtype = v.dtype).to(torch.float32)
h = torch.zeros((B, H, D, D), device=v.device, dtype=v.dtype).to(torch.float32)
# [num_batch, B, num_heads, mini_batch_size, head_dim]
o = torch.empty_like(_v)

@@ -54,7 +54,7 @@ def titans_linear(q, k, v, w, b, theta, alpha, eta, eps, BT, initial_state, outp
reconstruction_target = v_i - k_i

mean = kh.mean(-1, True)
var = kh.var(-1, unbiased = False, keepdim = True).to(torch.float32)
var = kh.var(-1, unbiased=False, keepdim=True).to(torch.float32)
rstd = torch.sqrt(var + eps).to(torch.float32)
kh_hat = (kh - mean) / rstd

@@ -67,8 +67,8 @@ def titans_linear(q, k, v, w, b, theta, alpha, eta, eps, BT, initial_state, outp
h = beta_T * h - 2 * (f_i @ k_i).transpose(-1, -2) @ v_new
# layer norm with residuals

mean = o_i.mean(dim = -1, keepdim = True)
var = o_i.var(dim = -1, unbiased = False, keepdim = True).to(torch.float32)
mean = o_i.mean(dim=-1, keepdim=True)
var = o_i.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
rstd = torch.sqrt(var + eps).to(torch.float32)
o[i] = o_i + (o_i - mean) / rstd * w + b
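
(Quick sanity check of the algebra behind combine_params above, not part of the commit: since beta_t is the running product of (1 - alpha_s), the ratio beta_T / beta_j equals the decay still to be applied between step j and the end of the sequence, which is why the code builds beta_ratio = beta_T / beta.)

import torch

alpha = torch.rand(6)                      # per-step forget rates in (0, 1)
beta = torch.cumprod(1 - alpha, dim=0)     # beta_t = prod_{s<=t} (1 - alpha_s)
j = 2
remaining = torch.prod(1 - alpha[j + 1:])  # decay over the remaining steps j+1..T
assert torch.allclose(beta[-1] / beta[j], remaining)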

