Commit

Merge pull request #71 from RWKV/rwkv-x-interweave-datapack
Rwkv x interweave datapack
PicoCreator authored Feb 1, 2024
2 parents 84fb00a + fd0342f commit d4fa285
Showing 12 changed files with 3,111 additions and 110 deletions.
18 changes: 10 additions & 8 deletions RWKV-v5/src/data.py
@@ -480,11 +480,13 @@ def map_tokenizer(x):
for i in range(len(conversation)):
# let's loop through each key in the io pair
for key, value in conversation[i].items():
# Get the sender key
sender = key
# let's get the prefix for this key
prefix = conversation_prefix_encoding_map[key] if sender in conversation_prefix_encoding_map[key] else None
prefix = conversation_prefix_encoding_map[key] if sender in conversation_prefix_encoding_map else None

# Add the prefix
if prefix is not None:
if prefix is not None:
input_ids += prefix['input_ids']
token_type_ids += prefix['token_type_ids']
attention_mask += prefix['attention_mask']
@@ -1038,22 +1040,22 @@ def dataloader_collator_fn(records):

# Compute the total length of the records
input_ids_len = 0
token_type_ids_len = 0
attention_mask_len = 0
# token_type_ids_len = 0
# attention_mask_len = 0

# Loop through the records and compute the max length
for i in range(records_len):
input_ids_len = max(input_ids_len, len(records[i]["input_ids"]))
token_type_ids_len = max(token_type_ids_len, len(records[i]["token_type_ids"]))
attention_mask_len = max(attention_mask_len, len(records[i]["attention_mask"]))
# token_type_ids_len = max(token_type_ids_len, len(records[i]["token_type_ids"]))
# attention_mask_len = max(attention_mask_len, len(records[i]["attention_mask"]))

# First row of the records
first_row = records[0]

# Create the output arrays, with the default 0 values (no learning mask)
out_input_ids = torch.zeros((records_len, input_ids_len), dtype=first_row["input_ids"].dtype)
out_token_type_ids = torch.zeros((records_len, token_type_ids_len), dtype=first_row["token_type_ids"].dtype)
out_attention_mask = torch.zeros((records_len, attention_mask_len), dtype=first_row["attention_mask"].dtype)
out_token_type_ids = torch.zeros((records_len, input_ids_len), dtype=first_row["token_type_ids"].dtype)
out_attention_mask = torch.zeros((records_len, input_ids_len), dtype=first_row["attention_mask"].dtype)
out_data_ctx_len = torch.zeros((records_len), dtype=torch.int32)

out_index = 0
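For context, a minimal sketch of the padding behaviour the collator change above relies on: token_type_ids and attention_mask are built alongside input_ids for every record, so all three can be padded to the same per-batch maximum input_ids length. The helper below is illustrative only, assuming each record is a dict of 1D tensors (collate_sketch is not the repo's actual dataloader_collator_fn).

import torch

def collate_sketch(records):
    # Illustrative collator: pad every field to the longest input_ids in the batch
    batch_size = len(records)
    max_len = max(len(r["input_ids"]) for r in records)

    out = {
        key: torch.zeros((batch_size, max_len), dtype=records[0][key].dtype)
        for key in ("input_ids", "token_type_ids", "attention_mask")
    }
    data_ctx_len = torch.zeros((batch_size,), dtype=torch.int32)

    for i, record in enumerate(records):
        seq_len = len(record["input_ids"])
        data_ctx_len[i] = seq_len
        for key in out:
            out[key][i, :seq_len] = record[key]

    out["data_ctx_len"] = data_ctx_len
    return out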
25 changes: 17 additions & 8 deletions RWKV-v5/src/model.py
@@ -919,6 +919,16 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool = False, is_valid
# Checkpoint steps
def checkpointed_step(idx, targets, mask, last_shift_states,
last_wkv_states):
# Skip if there are no tokens of value to learn from
if idx.shape[1] == 0:
# Prepare dummy loss
train_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_()
sample_loss = train_loss.clone().detach().requires_grad_(False)

# Return the checkpoint values
return sample_loss, train_loss, last_shift_states, last_wkv_states, 0

# Get the logits, and the new states
logits, new_shift_states, new_wkv_states = self(
idx, last_shift_states, last_wkv_states)

@@ -1041,10 +1051,10 @@ def checkpointed_step(idx, targets, mask, last_shift_states,
# it also helps ensure the segment cutoff points are more varied, across mixed dataset sizes
# and avoid potentially undesired training behaviour at fixed cutoff points
# (this only applies for segmented learning)
segment_size = min(math.ceil(T / segment_count)+1, self.ctx_len)
segment_size = min(math.ceil(T / segment_count)+2, self.ctx_len)

# Dummy 2D tensor of shape [1,1], are used to do "dummy checkpoint/forward/backprop" to keep everything in sync
dummy_2d_zero = torch.tensor([[0]], dtype=torch.long, device=cur_device)
# Dummy 2D tensor of shape [B,0], are used to do "dummy checkpoint/forward/backprop" to keep everything in sync
dummy_empty_zero = torch.zeros(B,0, dtype=torch.long, device=cur_device)

# Get the max segment count across all GPUs, in the current substep, which is used to keep all devices in sync
# Once a thread has completed all its segments, it will do dummy checkpoint/forward/backprop with one token,
@@ -1134,9 +1144,9 @@ def checkpointed_step(idx, targets, mask, last_shift_states,
cur_tar = targets[:, i * segment_size:(i + 1) * segment_size]
cur_msk = seq_mask[:, i * segment_size:(i + 1) * segment_size]
else:
cur_idx = dummy_2d_zero
cur_tar = dummy_2d_zero
cur_msk = dummy_2d_zero
cur_idx = dummy_empty_zero
cur_tar = dummy_empty_zero
cur_msk = dummy_empty_zero

# Segmented learning applies the forward pass over each chunk separately
segment_sample_loss, segment_train_loss, new_shift_states, new_wkv_states, segment_train_tokens = checkpointed_step(
@@ -1286,8 +1296,7 @@ def checkpointed_step(idx, targets, mask, last_shift_states,
# Dataset based tracking
f'dataset/validation/{dataset_index}.loss': training_loss,
f'dataset/validation/{dataset_index}.data_loss': sampling_loss,
f'dataset/validation/{dataset_index}.tokens': tokens,
f'dataset/validation/{dataset_index}.ctx_len': ctx_len,
f'dataset/validation/{dataset_index}.ctx_len': T,
f'dataset/validation/{dataset_index}.name': dataset_name,

# Step and trainer tracking
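As a rough illustration of the segmented-learning change above: once a device has exhausted its real segments it now forwards a zero-length [B, 0] tensor instead of a [1, 1] dummy token, which keeps the checkpoint/forward/backprop cadence in sync across GPUs without learning from a fake token. A hedged sketch (slice_segments_sketch is a made-up helper; the real logic lives inside compute_loss):

import math
import torch

def slice_segments_sketch(idx, targets, seq_mask, ctx_len, segment_count):
    # Illustrative only: idx/targets/seq_mask are assumed to be [B, T] tensors
    B, T = idx.shape
    segment_size = min(math.ceil(T / segment_count) + 2, ctx_len)

    # Zero-length dummy of shape [B, 0], used once this device runs out of real segments
    dummy_empty_zero = torch.zeros(B, 0, dtype=torch.long, device=idx.device)

    segments = []
    for i in range(segment_count):
        if i * segment_size < T:
            cur_idx = idx[:, i * segment_size:(i + 1) * segment_size]
            cur_tar = targets[:, i * segment_size:(i + 1) * segment_size]
            cur_msk = seq_mask[:, i * segment_size:(i + 1) * segment_size]
        else:
            cur_idx = cur_tar = cur_msk = dummy_empty_zero
        segments.append((cur_idx, cur_tar, cur_msk))
    return segments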
13 changes: 9 additions & 4 deletions RWKV-v5/src/module/CoreDependencies.py
@@ -45,7 +45,7 @@ def is_torch_version_above(required_version):

# Torch versioning flags
IS_TORCH_2_1_COMPATIBLE = is_torch_version_above("2.1.0")
IS_TORCH_2_1_2_COMPATIBLE = is_torch_version_above("2.1.2")
# IS_TORCH_2_1_2_COMPATIBLE = is_torch_version_above("2.1.2")

# Get the JIT / torch compile option flags from the environment
# This default is FOR inference mode, the trainer mode default is configured in the lightning_trainer.py
@@ -54,15 +54,20 @@ def is_torch_version_above(required_version):
if 'RWKV_JIT_ON' not in globals():
RWKV_JIT_ON = os.getenv("RWKV_JIT_ON", "1").lower() in ("1", "true", "yes")
if 'RWKV_TORCH_COMPILE' not in globals():
RWKV_TORCH_COMPILE = os.getenv("RWKV_TORCH_COMPILE", f"0").lower() in ("1", "true", "yes")
RWKV_TORCH_COMPILE = os.getenv("RWKV_TORCH_COMPILE", f"1").lower() in ("1", "true", "yes")

# The RWKV_NO_CUDA global
global RWKV_NO_CUDA
if 'RWKV_NO_CUDA' not in globals():
RWKV_NO_CUDA = os.getenv("RWKV_NO_CUDA", f"0").lower() in ("1", "true", "yes")

# Disable torch compile if its not atleast v2.1.2
if not IS_TORCH_2_1_2_COMPATIBLE:
# Enforce no cuda, if there is no cuda
if torch.cuda is None or torch.cuda.is_available() == False or torch.cuda.device_count() <= 0:
print(f"[RWKV.model] No CUDA device found, setting RWKV_NO_CUDA=True")
RWKV_NO_CUDA = True

# Disable torch compile if it's not at least v2.1.0
if not IS_TORCH_2_1_COMPATIBLE:
RWKV_TORCH_COMPILE = False

# We enable JITMod*/Function when supporting torch.jit
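The flag handling above reduces to: read the env var, then override it from the runtime environment. A condensed, hedged restatement (resolve_flags is illustrative; the real module also wires these into globals and JIT settings):

import os
import torch

def resolve_flags():
    # Boolean env vars follow the "1"/"true"/"yes" convention used in CoreDependencies.py
    torch_compile = os.getenv("RWKV_TORCH_COMPILE", "1").lower() in ("1", "true", "yes")
    no_cuda = os.getenv("RWKV_NO_CUDA", "0").lower() in ("1", "true", "yes")

    # Force-disable CUDA when no device is actually present
    if not torch.cuda.is_available() or torch.cuda.device_count() <= 0:
        no_cuda = True

    # torch.compile is only kept when torch >= 2.1.0 (version check elided here)
    return torch_compile, no_cuda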
100 changes: 32 additions & 68 deletions RWKV-v5/src/module/TimeMix.py
@@ -1,6 +1,7 @@
# Dependencies
from .CoreDependencies import *
from .OptimizedOps import modified_lerp
from .rwkv_inner import rwkv_inner
import os

# Current code file path
@@ -274,7 +275,7 @@ def forward(self, x, last_state: tuple[torch.Tensor,torch.Tensor]):
return self._forward_cuda(x, last_state)

# Run without cuda (cpu mode, etc)
return self._forward_nocuda(x, last_state)
return self._forward_nocuda_optimized(x, last_state)

def _forward_cuda(self, x, last_state: tuple[torch.Tensor,torch.Tensor]):
# Get the x sizing
@@ -313,94 +314,57 @@ def _forward_cuda(self, x, last_state: tuple[torch.Tensor,torch.Tensor]):
# Return the logits and the state
return (x_logits, (x[:,-1],state))

# Doing the forward pass withotu CUDA, this is currently rather slow
# and is not recommended - but it works (awaiting future optimization)
#
# We intentionally split the forward pass into smaller chunks of 32
# to ensure it processes at a resonable rate / vram usage
@TCompileMax
@JITModMethod
def _forward_nocuda(self, x, last_state: tuple[torch.Tensor,torch.Tensor]):
# Get the x sizing
B, TT, C = x.size()
def _forward_nocuda_optimized(self, x, last_state: tuple[torch.Tensor,torch.Tensor]):
shift_state_out = x[:,-1]

# Logits to return
x_logits = torch.zeros(B, TT, C, device=x.device, dtype=x.dtype)
# 24 is optimal chunk length (longer will use too much memory and cause precision problems or even numerical instability, shorter is inefficient)
chunk_len = 24

# Process in chunks
chunk_len = 32
for i in range(0, TT, chunk_len):
# Get the chunk
chunk = x[:, i:i+chunk_len]
# padding to support fast path for non-exact chunk size multiple sequence lengths
n_padding = (chunk_len - x.size(-2) % chunk_len) % chunk_len
if n_padding != 0:
x = F.pad(x, [0, 0, 0, n_padding, 0, 0])

# Process the chunk
chunk_logits, last_state = self._forward_nocuda_noChunking(chunk, last_state)

# Store the chunk logits
x_logits[:, i:i+chunk_len] = chunk_logits

# Return the logits and the state
return (x_logits, last_state)

# The no chunking varient of forwarding without cuda
@JITModMethod
def _forward_nocuda_noChunking(self, x, last_state: tuple[torch.Tensor,torch.Tensor]):
# Get the x sizing
B, TT, C = x.size()
B, T, C = x.size()
H = self.n_head
K = self.head_size
V = K

# Perform the tokenshift, and get the respective state
xx = torch.concat((last_state[0].unsqueeze(1), x[:, :-1]), dim=1)
# Get the xk, xv, xr, xg

# Get the xk, xv, xr, xg, and rkvg
xk = modified_lerp(x, self.time_mix_k, xx)
xv = modified_lerp(x, self.time_mix_v, xx)
xr = modified_lerp(x, self.time_mix_r, xx)
xg = modified_lerp(x, self.time_mix_g, xx)

r = self.receptance(xr).view(B, TT, self.n_head, 1, -1)
k = self.key(xk).view(B, TT, self.n_head, -1, 1)
v = self.value(xv).view(B, TT, self.n_head, 1, -1)
r = self.receptance(xr).view(B, T, H, K).transpose(1, 2) # BHTK
k = self.key(xk).view(B, T, H, K).transpose(1, 2) # BHTK
v = self.value(xv).view(B, T, H, V).transpose(1, 2) # BHTV
g = F.silu(self.gate(xg))

# The WKV state to update
if last_state[1] is None:
wkv_state = torch.zeros((B, self.n_head, self.head_size, self.head_size)).to(r.dtype)
else:
wkv_state = last_state[1].clone().to(r.dtype)

# # Compute attent and the initial output tensor
# at = k @ v
# u = self.time_faaaa.view(1,1,self.n_head, 1, -1)
w = torch.exp(-torch.exp(self.time_decay.float())).view(1,H,1,K).expand(1,H,T,K)

# # Slightly inefficent, but it works, lets compute all the tokens
# w = self.time_decay.exp().neg().exp().reshape(1, self.n_head,-1,1)
u = self.time_faaaa.view(1,H,1,K).to(r.dtype)

# out = (u * r) @ at
# for t in range(TT):
# out[:,t] += r[:,t] @ wkv_state

# # We make a clone copy, so the previous object backprop state is tracked seperately
# wkv_state = wkv_state.clone()
# wkv_state *= w
# wkv_state += at[:,t]

wkv_state, out = compute_wkv_state(
k, v, r,
self.time_faaaa,
self.time_decay,
wkv_state,
self.n_head, self.head_size,
B, TT
)
# Logits and state
wkv_state = last_state[1].to(r.dtype)

x_logits, wkv_state = rwkv_inner(r, k, v, w, u, wkv_state, chunk_len=chunk_len)
x_logits = x_logits.transpose(1,2).reshape(B,T,C)

# Reshape and normalize the logits
x_logits = self._x_logits_gate(x_logits, g)

# Compute the final x output
x_logits = out.view(-1, C)
x_logits = self.ln_x(x_logits / self.head_size_divisor).view(B, TT, C)
x_logits = self.output(x_logits * g)
if n_padding != 0:
x_logits = x_logits[..., :-n_padding, :] # BHTV

# Return the logits and the state
return (x_logits, (x[:,-1],wkv_state))
# return x_logits, (x[:,-1],ms[-1])
return (x_logits, (shift_state_out,wkv_state))

@TCompileMax
@JITModMethod
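To make the new chunking arithmetic in _forward_nocuda_optimized concrete: with chunk_len = 24, the sequence is right-padded to the next multiple of 24 before rwkv_inner runs, then the padding is trimmed from the logits. A small shape-only example with toy sizes (not the model's real dimensions):

import torch
import torch.nn.functional as F

chunk_len = 24
B, T, C = 2, 50, 64                                    # toy sizes

x = torch.randn(B, T, C)
n_padding = (chunk_len - T % chunk_len) % chunk_len    # (24 - 50 % 24) % 24 = 22
if n_padding != 0:
    x = F.pad(x, [0, 0, 0, n_padding, 0, 0])           # pad the time dimension on the right
assert x.shape == (B, T + n_padding, C)                # (2, 72, 64) -> exactly 3 chunks of 24

# after rwkv_inner, the padded positions are dropped again:
# x_logits = x_logits[..., :-n_padding, :]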
112 changes: 112 additions & 0 deletions RWKV-v5/src/module/rwkv_inner.py
@@ -0,0 +1,112 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor

# 24 is optimal chunk length (longer will use too much memory and cause precision problems or even numerical instability, shorter is inefficient)
def rwkv_inner(r,k,v,w,u,kv_state,chunk_len:int=24,precision_dtype:torch.dtype=torch.float32):
"""
expects
r : (B,H,L,K)
k : (B,H,L,K)
v : (B,H,L,V)
w : (B,H,L,K) or (1,H,L,K)
u : (1,H,1,K)
kv_state : (B,H,K,V)
"""
B,H,L,K = k.size()
V = v.size(-1)
T = chunk_len

if L == 1:
kv = k @ v
out = r @ (kv_state + u * kv)
kv_state = w * kv_state + kv
return out, kv_state
else:
# FIXME - support fast path for non-exact multiples
# ensure it's an exact multiple
if L % T != 0:
T = 1

N = L // T

# this has to be done to avoid numerical instability (inf/NaN) when w is used as a divisor up to chunk_length//2 places away (so precision_min_val^(T//2) has to be in fp range)
# NOTE - this does not account for the impact of the size of R, K so we currently use the chunk_len=32 numbers for chunk_len=24
assert(precision_dtype == torch.float32 or precision_dtype == torch.float64)
if precision_dtype == torch.float32:
precision_min_val = 0.005 # good for fp32 (1.175e-38 ^ (1/16.0) < 0.00426)
else: #elif precision_dtype == torch.float64:
precision_min_val = 1e-10 # good for fp64 (1.7e-308 ^ (1/16.0) < 5.8e-20)
w = w.clamp(precision_min_val)

# calculate cumulative decay in log space where it won't overflow
w_log = w.float().log() # (1,H,L,K) or (B,H,L,K)

# chunked view of w_log
wc_log = w_log.view(w.size(0),H,N,T,K)
wc_log_cum = wc_log.cumsum(dim=-2)

# chunked view of shifted_w_log
shifted_wc_log_cum = F.pad(wc_log_cum, (0, 0, 1, -1))


# NOTE - we have to apply the decay weight from TWO ahead.. ONE ahead gets no decay (log==0)
# pre-applied weights
# left side is prior chunk (w_inter), right side is current chunk (w_intra)
# without u...
# w0 w1 w2 w3 | w4 w5 w6 w7
# w1:4 w2:4 w3:4 w4:4 | w4:5 w4:6 w4:7 w4:8
# with u...
# w0 w1 w2 w3 | w4 w5 w6 w7
# w1:4 w2:4 w3:4 w4:4 | w4:4 w4:5 w4:6 w4:7

# ws decays the entire current state (representing t-1) to the prior block (t-2)
ws = wc_log.sum(dim=-2, keepdim=True) # 1HN1K or BHN1K
# w_inter is the decay to the end of the current block, since it will be applied at the next iteration when current (t) becomes prior (t-1)
# this formula because e.g. w1:4 = w0:4 - w0:1
w_inter = ws - wc_log_cum # 1HNTK or BHNTK (w^(T-1) ... w^0)
# w_intra is the decay from the beginning of the current block (t), since it will be applied to current queries (t) against prior state (representing keys+values up to but not including block t)
# this formula because e.g. w1:3 = w0:3 - w0
w_intra = wc_log_cum - wc_log # 1HNTK or BHNTK (w^0 ... w^(T-2))

ws = list(ws.mT.exp().to(r.dtype).unbind(dim=-3)) # N x 1HK1 or BHK1 !!NOTE THE .mT HERE!!
w_inter = w_inter.exp().to(r.dtype) # 1HNTK or BHNTK
w_intra = w_intra.exp().to(r.dtype) # 1HNTK or BHNTK

# chunked view of r, k, v
r = r.view(B,H,N,T,K)
k = k.view(B,H,N,T,K)
v = v.view(B,H,N,T,V)
u = u.unsqueeze(2).to(r.dtype) # (1,H,1,1,K)

# parallel calculation of all intra-chunk attention contributions
wc_log_offset = shifted_wc_log_cum[...,T//2:T//2+1,:] # B,H,N,1,K
r_decay = (shifted_wc_log_cum - wc_log_offset).to(precision_dtype).exp() # B,H,N,T,K
k_inv_decay = (wc_log_offset - wc_log_cum).to(precision_dtype).exp() # B,H,N,T,K
a = ((r*r_decay) @ (k*k_inv_decay).mT).to(r.dtype).tril(-1) # B,H,N,T,T
# add u term to attention (NOTE - the tril(-1) above zeroed the diagonal)
a = a + torch.einsum('bhntk,bhntk->bhnt', r, u * k).diag_embed()
out = a @ v # BHNTV
# alternate way of adding in u
# out = out + torch.einsum('bhntk,bhntk,bhntv->bhntv', r, u * k, v)

# parallel precalculation of chunked (k*wk).mT@v for use in recurrent state calc below
wkv = (k * w_inter).mT @ v # BHNKV
wkv = list(wkv.unbind(dim=-3)) # N x BHKV

# recurrent calculation of all states
states = []
for i in range(N):
states.append(kv_state)
kv_state = kv_state * ws[i] + wkv[i] # BHKV
# equivalent non-precalced version
#wkv = (k[...,i,:,:] * wk[...,i,:,:]).mT @ v[...,i,:,:]
#kv_state = kv_state * ws[i] + wkv
states = torch.stack(states, dim=2) # BHNKV

# parallel application of all r to states
out = out + (r * w_intra) @ states # BHNTV
out = out.view(B,H,L,V)
return out, kv_state
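A hedged usage sketch for rwkv_inner, using random tensors with the shapes documented in its docstring (values are arbitrary, only the shapes matter; the import path is assumed and may need adjusting to your package layout):

import torch
from src.module.rwkv_inner import rwkv_inner  # assumed path

B, H, L, K = 2, 4, 48, 16      # L is a multiple of chunk_len=24, so the chunked path is taken
V = K

r = torch.randn(B, H, L, K)
k = torch.randn(B, H, L, K)
v = torch.randn(B, H, L, V)
w = torch.rand(1, H, L, K).clamp(0.01, 0.99)   # per-channel decay in (0, 1)
u = torch.randn(1, H, 1, K)
kv_state = torch.zeros(B, H, K, V)

out, kv_state = rwkv_inner(r, k, v, w, u, kv_state, chunk_len=24)
print(out.shape, kv_state.shape)   # torch.Size([2, 4, 48, 16]) torch.Size([2, 4, 16, 16])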