Commit

Merge pull request #72 from RWKV/rwkv-x-eagle-notebooks
RWKV x Eagle notebooks
PicoCreator authored Feb 2, 2024
2 parents d4fa285 + 6d52048 commit a7b090d
Showing 14 changed files with 4,516 additions and 1,638 deletions.
6 changes: 3 additions & 3 deletions RWKV-v5/config-example.yaml
@@ -431,15 +431,15 @@ data:
# or throw an error if the default fallback is not found
#
# IMPORTANT NOTE: as newlines are commonly used for multi_column_suffix, etc.
- # you should use single quotes to ensure such values dun get escaped.
- # eg. multi_column_suffix: ['\n\n']
+ # you should use double quotes to ensure such values don't get escaped.
+ # eg. multi_column_suffix: ["\n\n"]
#
# See: https://github.com/RWKV/RWKV-infctx-trainer/issues/34
# Need to use " or the new lines won't be tokenized properly
# ---
# multi_column_keys: ["instruction", "input", "output"]
# multi_column_prefix: ["Instruction:\n", "Input:\n", "Output:\n"]
# multi_column_suffix: ["\n\n", "\n\n", "\n\n"]
# multi_column_suffix: ['', '', '']
# multi_column_train_mask: [true, false, true]
# multi_column_separator: "\n\n"
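
To make the merge behaviour concrete, here is a minimal Python sketch of how a single instruction/input/output record is expected to be combined under the commented-out example above (an illustration of the documented behaviour, not the trainer's actual code; the record contents are hypothetical):

# Hypothetical record, plus the example settings from the config comments above
record = {"instruction": "Summarize the passage", "input": "RWKV is an RNN with ...", "output": "RWKV is ..."}

multi_column_keys      = ["instruction", "input", "output"]
multi_column_prefix    = ["Instruction:\n", "Input:\n", "Output:\n"]
multi_column_suffix    = ["", "", ""]
multi_column_separator = "\n\n"

# Each non-empty column is wrapped in its prefix/suffix, and the pieces are
# joined with the separator. multi_column_train_mask decides which pieces
# contribute to the training loss (with [true, false, true]: instruction and
# output are trained on, the input is masked).
parts = [
    prefix + record[key] + suffix
    for key, prefix, suffix in zip(multi_column_keys, multi_column_prefix, multi_column_suffix)
    if record.get(key)
]
merged_text = multi_column_separator.join(parts)
print(merged_text)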

7 changes: 4 additions & 3 deletions RWKV-v5/src/model.py
@@ -94,6 +94,7 @@ def __init__(self, layer_id, n_layer, n_embd, n_head, head_size, dropout, dim_at
self.drop0 = nn.Dropout(p = dropout)
self.drop1 = nn.Dropout(p = dropout)

+    @TCompileBaseline
def forward(self, x, last_state: BlockState):
if self.layer_id == 0:
x = self.ln0(x)
@@ -599,7 +600,7 @@ def deepspeed_stage(self) -> int:
return "stage" in cfg
return -1

-    @TCompileBaseline
+    # @TCompileBaseline
def forward(self, idx: torch.Tensor, last_shift_states: torch.Tensor = None,
last_wkv_states: torch.Tensor = None):
B, T = idx.size()
@@ -797,7 +798,7 @@ def manual_backward(self, loss: torch.Tensor, *args, **kwargs):
#
# Main compute_loss function, this is called by the trainer loop
#
-    @TCompileBaseline
+    # @TCompileBaseline
def compute_loss(self, batch, batch_idx, is_training_run: bool = False, is_validation_run: bool = False):

# Used for token/second performance tracking
@@ -1334,7 +1335,7 @@ def training_step(self, batch, batch_idx):

return training_loss

-    @TCompileBaseline
+    # @TCompileBaseline
def validation_step(self, batch, batch_idx):
sampling_loss, training_loss = self.compute_loss(batch, batch_idx, False, True)
self.log('validation/loss', sampling_loss, prog_bar=True, sync_dist=True)
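
For context: `TCompileBaseline` is the torch.compile-related wrapper defined in `src/module/CoreDependencies.py`, whose exact definition is not shown in this diff. A decorator of this kind typically follows the pattern sketched below (a hypothetical illustration, not the repository's actual implementation): it wraps the function with `torch.compile` only when the `RWKV_TORCH_COMPILE` flag is enabled, and otherwise returns it unchanged. Commenting the decorator out, as done above, removes the compile wrapper from these methods regardless of the flag.

import os
import torch

# Hypothetical sketch of a conditional compile decorator; the real
# TCompileBaseline may differ in detail (e.g. compile mode/options).
RWKV_TORCH_COMPILE = os.getenv("RWKV_TORCH_COMPILE", "0").lower() in ("1", "true", "yes")

def TCompileBaseline(fn):
    # Wrap with torch.compile only when compilation is enabled
    if RWKV_TORCH_COMPILE:
        return torch.compile(fn)
    return fn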
6 changes: 3 additions & 3 deletions RWKV-v5/src/module/CoreDependencies.py
@@ -54,16 +54,16 @@ def is_torch_version_above(required_version):
if 'RWKV_JIT_ON' not in globals():
RWKV_JIT_ON = os.getenv("RWKV_JIT_ON", "1").lower() in ("1", "true", "yes")
if 'RWKV_TORCH_COMPILE' not in globals():
-    RWKV_TORCH_COMPILE = os.getenv("RWKV_TORCH_COMPILE", f"1").lower() in ("1", "true", "yes")
+    RWKV_TORCH_COMPILE = os.getenv("RWKV_TORCH_COMPILE", f"0").lower() in ("1", "true", "yes")

# The RWKV_NO_CUDA global
global RWKV_NO_CUDA
if 'RWKV_NO_CUDA' not in globals():
-    RWKV_NO_CUDA = os.getenv("RWKV_NO_CUDA", f"0").lower() in ("1", "true", "yes")
+    RWKV_NO_CUDA = os.getenv("RWKV_NO_CUDA", f"1").lower() in ("1", "true", "yes")

# Enforce no cuda, if there is no cuda
if torch.cuda is None or torch.cuda.is_available() == False or torch.cuda.device_count() <= 0:
print(f"[RWKV.model] No CUDA device found, setting RWKV_NO_CUDA=True")
print(f"[RWKV.model] No CUDA device found, enforcing RWKV_NO_CUDA=True")
RWKV_NO_CUDA = True

# Disable torch compile if it's not at least v2.1.0
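
Taken together, these changes disable torch.compile and the custom CUDA kernel path by default; both remain controllable through environment variables. A minimal sketch of opting back in (the flags are read at import time, so they must be set before the RWKV modules are imported; the import path below is illustrative):

import os

# Re-enable the custom CUDA kernel, and torch.compile where decorators are still active
os.environ["RWKV_NO_CUDA"] = "0"
os.environ["RWKV_TORCH_COMPILE"] = "1"

from src.model import RWKV  # illustrative import; adjust to how the trainer is launched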
4 changes: 2 additions & 2 deletions RWKV-v5/src/module/TimeMix.py
@@ -314,7 +314,7 @@ def _forward_cuda(self, x, last_state: tuple[torch.Tensor,torch.Tensor]):
# Return the logits and the state
return (x_logits, (x[:,-1],state))

-    @TCompileMax
+    # @TCompileMax
@JITModMethod
def _forward_nocuda_optimized(self, x, last_state: tuple[torch.Tensor,torch.Tensor]):
shift_state_out = x[:,-1]
@@ -366,7 +366,7 @@ def _forward_nocuda_optimized(self, x, last_state: tuple[torch.Tensor,torch.Tensor]):
# Return the logits and the state
return (x_logits, (shift_state_out,wkv_state))

-    @TCompileMax
+    # @TCompileMax
@JITModMethod
def _x_logits_gate(self, x_logits, gate):
B, T, C = x_logits.size()
685 changes: 685 additions & 0 deletions notebook/finetune-example/Eagle-x-ALMA-prompt-completion.ipynb

Large diffs are not rendered by default.

195 changes: 195 additions & 0 deletions notebook/finetune-example/Eagle-x-ALMA-prompt-completion.yaml
@@ -0,0 +1,195 @@
###############################################
##
## Trainer settings are kept minimal
##
## See the full `config-example.yaml` for more
## details on the trainer/model configs
##
###############################################

trainer:
# Limit to 1 epoch
max_epochs: 1

# Reasonable batch size, for a more realistic it/s rate
# this is currently overwritten in the notebook
target_batch_size: 64

# Logger settings for wandb; comment out the whole logger section below if you want to disable wandb
# ---
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
name: 'Eagle-x-finetune'
project: 'RWKV-V5-Eagle-Finetune'
tags: ['Eagle', 'RWKV-V5']

# Checkpoint settings for the training process
callbacks:
class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
# Configure this to the path you want to save your checkpoints to
# note that a subdir will be created with the name `epoch=x-step=y.ckpt`
#
# to convert a checkpoint to a model, you can use the
# `python3 export_checkpoint.py <checkpoint path>` script,
# which will create a `rwkv_model.pth` in the checkpoint directory.
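#   e.g. (hypothetical path, assuming the `save_last: true` setting below):
#   python3 export_checkpoint.py ../checkpoint/finetune-example/Eagle-x-ALMA-prompt-completion/last.ckpt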
#
# Do not use the `zero_to_fp32.py` script as that will have export format issues
dirpath: ../checkpoint/finetune-example/Eagle-x-ALMA-prompt-completion/
filename: null

# Save the top/last K checkpoints
save_top_k: 2
# Choose the most recent checkpoints by steps
monitor: 'step'
mode: max

# If enabled (true), save a copy of the latest checkpoint to 'last.ckpt'
# useful to simplify checkpoint resume scripts, at the price of some disk performance
save_last: true

# DO NOT set this to true, as the exported model weights will have format issues
# export as a checkpoint, and use the `export_checkpoint.py` script to convert to a model instead
save_weights_only: false

# How frequently you want to save a checkpoint, in training steps.
# This will happen for every X data samples, where X = every_n_train_steps * accumulate_grad_batches
#
# In general you will want to avoid putting a low number (especially if accumulate_grad_batches <= 100)
# as the checkpoint process will pause all the GPU training for some time, slowing down the overall process
# However you do not want to configure too high a number either, or you will lose too much progress if the training crashes
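# (Hypothetical illustration: with every_n_train_steps: 25 and accumulate_grad_batches: 8,
#  a checkpoint would be saved roughly every 25 * 8 = 200 data samples)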
every_n_train_steps: null
every_n_epochs: 1
save_on_train_epoch_end: true
train_time_interval: null

# Other pytorch lightning settings, which in most cases you can remove/ignore
# ---
# verbose: false
# auto_insert_metric_name: true

model:
# The model to load
load_model: ../model/L6-D512-neox-init.pth

# Starting and ending learning rate
lr_init: 5e-5
lr_final: 5e-5

# Training context length; note that the dataset can be
# larger than the context size, in which case the trainer
# will process the dataset in chunks
ctx_len: 2048

# BPTT learning; this allows you to run the trainer against datasets
# larger than the training context length
bptt_learning: true
bptt_learning_range: -1
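# (-1 is understood here to mean no fixed range limit, ie. BPTT across the full data sample)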

########################################
## Training model settings
########################################
data:
# data_path for the prebuilt dataset, using HF `load_from_disk()`
#
# Use this if you have built your own dataset and saved it with `save_to_disk()`
# with source left as null. Otherwise, configure this to a directory where the
# dataset will be built and tokenized by the huggingface dataset process.
data_path: ../datapath/world/alma-repacked-16k/

# Otherwise provide the source path, which is used as the huggingface dataset path;
# this will be used to populate the data_path
#
# Use either the following
# - hugging face dataset
# - Directory path to a directory containing dataset files
# - Path to a single dataset file
# - hugging face dataset mode (ie: text, csv, etc - use data_dir to configure the path)
# - null
#
# If source is disabled, all other params, except data_path, are ignored
source: "kristaller486/ALMA-prompt-completion"
# source: text
# source: /home/ubuntu/RWKV-LM-LoRA/dataset-text/enwik8.txt

# # Additional source dataset params, used to grab subsets of the dataset
# source_dataset_params:
# language: en

# # Use data_dir, if you are using source=text/json/etc
# # this should be relative to the trainer script path
# source_data_dir: null

# Tokenizer to use, use either the inbuilt 'neox' or 'world' tokenizer
# If using a custom tokenizer, provide the tokenizer file path
# ---
tokenizer: world

# Minimum / Maximum token size of the dataset to use
# useful for filtering out small noisy data samples from large datasets
# (eg. removal of small articles of less than 512 tokens from wikipedia)
#
# This is ignored, if set to -1
min_token_size: -1
max_token_size: -1

# Multi Column merging process, default setting is used to support and merge
# "instruction", "input", "output", datasets. To disable set multi_column_keys to []
#
# A minimum of 2 columns is required, with non-empty data, for the merge to occur
# If no match is found, this will fallback to the default prompt/completion or text column,
# or throw an error if the default fallback is not found
#
# IMPORTANT NOTE: as newlines are commonly used for multi_column_suffix, etc.
# you should use double quotes to ensure such values don't get escaped.
# eg. multi_column_suffix: ["\n\n"]
#
# See: https://github.com/RWKV/RWKV-infctx-trainer/issues/34
# Need to use " or the new lines won't be tokenized properly
# ---
# multi_column_keys: ["instruction", "input", "output"]
# multi_column_prefix: ["Instruction:\n", "Input:\n", "Output:\n"]
# multi_column_suffix: ["\n\n", "\n\n", "\n\n"]
# multi_column_train_mask: [true, false, true]
# multi_column_separator: "\n\n"

# If processing prompt/completion jsonl pairs, the prompt is masked by default
# use this flag to disable this default behaviour
# ---
# disable_prompt_completion_mask: false

# After loading the dataset, split out test data used for validation.
# This process is skipped if the dataset already includes a test split,
# or if this is set to zero
test_split: 0.01
test_split_shuffle: false

# ----------------------------
# Dataset packing support
# Recommended for finetuning on datasets with mixed document sizes
# For foundation model training "from scratch", rechunking is typically used instead
# ----------------------------

# Boolean flag to enable / disable dataset packing
packing_enable: True

# Used to ensure all training samples within this batch size are the same length
# Ideally this should align exactly with your real "batch size"
#
# Uses `8 * (3 * 4 * 5 * 6 * 7) = 20160` by default, as it should align across
# a large number of batch size combinations. This helps reduce the amount of
# misaligned batches, and thus reduce the amount of wasted training time.
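# (For reference: 8 * (3 * 4 * 5 * 6 * 7) = 8 * 2520 = 20160, which divides evenly by
#  every batch size from 1 to 10, as well as 12, 14, 16, 32, 64, and so on)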
packing_batchsize: 64

# Chunking size to align within each batch, this ideally should be equal to
# the training context length used.
packing_chunksize: 8192

# Minimum size to pack up to, this should be a multiple of packing_chunksize
# defaults to -1, which equals packing_chunksize
packing_min_ctx_len: -1

# Pack the data sequentially if possible, in accordance with the dataset sequence
# this can be used together with sort_by_length, otherwise a shuffle will be done
packing_in_sequence: False
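
For reference, the `data_path` directory configured in the data section above is a HuggingFace datasets directory readable with `load_from_disk()`. A minimal sketch for inspecting whatever the trainer has built there (the directory contents depend on the trainer's own preprocessing, which is not reproduced here):

from datasets import load_from_disk

# Load the prebuilt/tokenized dataset from the configured data_path for inspection
ds = load_from_disk("../datapath/world/alma-repacked-16k/")
print(ds)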