expose model version api #170

Open · wants to merge 2 commits into master
3 changes: 3 additions & 0 deletions .github/workflows/build.yml
@@ -284,6 +284,9 @@ jobs:
- macOS-latest-cmake
- windows-latest-cmake

+permissions:
+  contents: write

steps:
  - name: Download artifacts
    id: download-artifact
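A note on the workflow change: `contents: write` is the permission the default `GITHUB_TOKEN` needs in order to create releases and upload release assets, which is presumably what the steps elided below this hunk do with the downloaded artifacts.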
6 changes: 6 additions & 0 deletions CMakeLists.txt
@@ -105,6 +105,7 @@ endif()

if (RWKV_CUBLAS)
    cmake_minimum_required(VERSION 3.17)
+   set(CMAKE_CUDA_COMPILER_FORCED TRUE)

    find_package(CUDAToolkit)

@@ -417,6 +418,11 @@ target_compile_features(ggml PUBLIC c_std_11) # Don't bump

if (MSVC)
    target_link_libraries(ggml PUBLIC ${RWKV_EXTRA_LIBS} Threads::Threads)
+   if (RWKV_CUBLAS)
+       target_compile_options(ggml PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
+           -allow-unsupported-compiler
+       >)
+   endif()
else()
    if (WIN32 AND RWKV_HIPBLAS)
        target_link_libraries(ggml PUBLIC ${RWKV_EXTRA_LIBS} Threads::Threads)
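For context on the two CUDA-related additions: setting `CMAKE_CUDA_COMPILER_FORCED` makes CMake skip its CUDA compiler sanity test, and nvcc's `-allow-unsupported-compiler` flag disables the host-compiler version check, which otherwise fails the build when the installed MSVC is newer than what the CUDA Toolkit officially supports.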
2 changes: 2 additions & 0 deletions README.md
@@ -10,6 +10,8 @@ This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrap

[RWKV v5](https://huggingface.co/BlinkDL/rwkv-5-world) is a major upgrade to RWKV architecture, making it competitive with Transformers in quality. RWKV v5 models are supported.

+[RWKV v6](https://huggingface.co/BlinkDL/rwkv-6-world) is a further improvement to the RWKV architecture, delivering better quality. RWKV v6 models are supported.

Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).

## Quality and performance
21 changes: 18 additions & 3 deletions python/convert_pytorch_to_ggml.py
@@ -34,8 +34,11 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t

    is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict
    is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict
+   is_v6_0: bool = 'blocks.0.att.time_maa_x' in state_dict

-   if is_v5_2:
+   if is_v6_0:
+       print('Detected RWKV v6.0')
+   elif is_v5_2:
        print('Detected RWKV v5.2')
    elif is_v5_1_or_2:
        print('Detected RWKV v5.1')
@@ -57,13 +60,25 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
        1 if is_FP16 else 0
    ))

+   if is_v6_0:
+       n_head: int = state_dict['blocks.0.att.time_faaaa'].shape[0]

    for k in state_dict.keys():
        tensor: torch.Tensor = state_dict[k].float()

        if '.time_' in k:
            tensor = tensor.squeeze()

-       if is_v5_1_or_2:
+       if is_v6_0:
+           if '.time_faaaa' in k:
+               tensor = tensor.unsqueeze(-1)
+           if '.time_maa_w1' in k or '.time_decay_w' in k:
+               tensor = tensor.transpose(0, 1)
+           if '.time_maa_w2' in k:
+               tensor = tensor.transpose(1, 2)
+           if '.time_decay' in k and '_w' not in k:
+               tensor = tensor.reshape(n_head, -1, 1)
+
+       elif is_v5_1_or_2:
            if '.time_decay' in k:
                if is_v5_2:
                    tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
@@ -105,7 +120,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t

        out_file.write(k_encoded)

-       tensor.numpy().tofile(out_file)
+       tensor.detach().numpy().tofile(out_file)

def main() -> None:
    args = parse_args()
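For orientation, here is a minimal, self-contained sketch of the v6.0 tensor layout changes made above, run on dummy tensors. The concrete sizes (n_embd = 64, head_size = 8, extra_dim = 32) are illustrative assumptions, not values taken from this PR:

```python
import torch

# Assumed illustrative sizes, not values from this PR.
n_embd, head_size, extra_dim = 64, 8, 32
n_head = n_embd // head_size

time_faaaa = torch.zeros(n_head, head_size)
time_maa_w1 = torch.zeros(n_embd, 5 * extra_dim)
time_maa_w2 = torch.zeros(5, extra_dim, n_embd)
time_decay = torch.zeros(n_embd)

# .time_faaaa: append a singleton dim, giving one column vector per head.
assert time_faaaa.unsqueeze(-1).shape == (n_head, head_size, 1)
# .time_maa_w1 (and .time_decay_w*): swap the first two dims.
assert time_maa_w1.transpose(0, 1).shape == (5 * extra_dim, n_embd)
# .time_maa_w2: swap the last two dims.
assert time_maa_w2.transpose(1, 2).shape == (5, n_embd, extra_dim)
# .time_decay: split per head, with a trailing singleton dim.
assert time_decay.reshape(n_head, -1, 1).shape == (n_head, head_size, 1)
```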
16 changes: 13 additions & 3 deletions python/merge_lora_into_ggml.py
@@ -13,7 +13,7 @@
def parse_args():
    parser = argparse.ArgumentParser(description='Merge a PyTorch LoRA checkpoint (.pth) into an rwkv.cpp model file')
    parser.add_argument('src_path', help='Path to source rwkv.cpp model')
-   parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2', type=str, choices=['v4', 'v5.1', 'v5.2'])
+   parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2, v6.0', type=str, choices=['v4', 'v5.1', 'v5.2', 'v6.0'])
    parser.add_argument('lora_path', help='Path to LoRA checkpoint in PyTorch format')
    parser.add_argument('lora_alpha', help='Value of lora_alpha parameter used when training this LoRA checkpoint', type=int)
    parser.add_argument('dest_path', help='Path to destination rwkv.cpp model, will be overwritten with the merged model')
@@ -47,7 +47,7 @@ def main() -> None:

    arch_version: str = args.rwkv_arch_version

-   if not (arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2'):
+   if not (arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2' or arch_version == 'v6.0'):
        raise ValueError(f'Invalid RWKV architecture version {arch_version}')

    print(f'Reading {args.lora_path}')
@@ -108,7 +108,17 @@ def main() -> None:
        if '.time_' in key:
            replacement = replacement.squeeze()

-       if arch_version == 'v5.1' or arch_version == 'v5.2':
+       if arch_version == 'v6.0':
+           if '.time_faaaa' in key:
+               replacement = replacement.unsqueeze(-1)
+           if '.time_maa_w1' in key or '.time_decay_w' in key:
+               replacement = replacement.transpose(0, 1)
+           if '.time_maa_w2' in key:
+               n_head: int = replacement.shape[1]
+               replacement = replacement.transpose(1, 2)
+           if '.time_decay' in key and '_w' not in key:
+               replacement = replacement.reshape(n_head, -1, 1)
+       elif arch_version == 'v5.1' or arch_version == 'v5.2':
            if '.time_decay' in key:
                if arch_version == 'v5.2':
                    replacement = torch.exp(-torch.exp(replacement)).unsqueeze(-1)
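For reference, a hypothetical invocation of the updated script with a v6.0 model (all file names below are placeholders):

```
python python/merge_lora_into_ggml.py rwkv-v6.bin v6.0 my-lora.pth 32 rwkv-v6-merged.bin
```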
8 changes: 8 additions & 0 deletions python/rwkv_cpp/rwkv_cpp_model.py
@@ -94,6 +94,14 @@ def gpu_offload_layers(self, layer_count: int) -> bool:

        return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)

+   @property
+   def arch_version_major(self) -> int:
+       return self._library.rwkv_get_arch_version_major(self._ctx)
+
+   @property
+   def arch_version_minor(self) -> int:
+       return self._library.rwkv_get_arch_version_minor(self._ctx)
+
    @property
    def n_vocab(self) -> int:
        return self._library.rwkv_get_n_vocab(self._ctx)
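A short usage sketch of the new properties; the model path is a placeholder, and the imports assume the `python` directory of this repository is on `sys.path`:

```python
from rwkv_cpp import rwkv_cpp_model, rwkv_cpp_shared_library

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, 'path/to/model.bin')

# For a v6.0 model this would print "Model architecture: RWKV v6.0".
print(f'Model architecture: RWKV v{model.arch_version_major}.{model.arch_version_minor}')
```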
30 changes: 30 additions & 0 deletions python/rwkv_cpp/rwkv_cpp_shared_library.py
@@ -80,6 +80,12 @@ def __init__(self, shared_library_path: str) -> None:
        ]
        self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool

+       self.library.rwkv_get_arch_version_major.argtypes = [ctypes.c_void_p]
+       self.library.rwkv_get_arch_version_major.restype = ctypes.c_size_t
+
+       self.library.rwkv_get_arch_version_minor.argtypes = [ctypes.c_void_p]
+       self.library.rwkv_get_arch_version_minor.restype = ctypes.c_size_t
+
        self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
        self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t
@@ -284,6 +290,30 @@ def rwkv_eval_sequence_in_chunks(
    ):
        raise ValueError('rwkv_eval_sequence_in_chunks failed, check stderr')

+   def rwkv_get_arch_version_major(self, ctx: RWKVContext) -> int:
+       """
+       Returns the major architecture version used by the given model.
+
+       Parameters
+       ----------
+       ctx : RWKVContext
+           RWKV context obtained from rwkv_init_from_file.
+       """
+
+       return self.library.rwkv_get_arch_version_major(ctx.ptr)
+
+   def rwkv_get_arch_version_minor(self, ctx: RWKVContext) -> int:
+       """
+       Returns the minor architecture version used by the given model.
+
+       Parameters
+       ----------
+       ctx : RWKVContext
+           RWKV context obtained from rwkv_init_from_file.
+       """
+
+       return self.library.rwkv_get_arch_version_minor(ctx.ptr)
+
    def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
        """
        Returns the number of tokens in the given model's vocabulary.
12 changes: 12 additions & 0 deletions rwkv.cpp
@@ -49,6 +49,8 @@ static_assert(sizeof(decltype(ftell(NULL))) >= 8, "File offsets should be 64-bit

#include "rwkv_operators_wkv_v5.inc"

+#include "rwkv_operators_wkv_v6.inc"

#include "rwkv_graph.inc"
// API function.
@@ -104,6 +106,16 @@ extern "C" RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct r
    return rwkv_get_logits_len(ctx);
}

+// API function.
+size_t rwkv_get_arch_version_major(const struct rwkv_context * ctx) {
+    return (size_t) ctx->model->arch_version_major;
+}
+
+// API function.
+size_t rwkv_get_arch_version_minor(const struct rwkv_context * ctx) {
+    return (size_t) ctx->model->arch_version_minor;
+}
+
// API function.
size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) {
    return (size_t) ctx->model->header.n_vocab;
6 changes: 6 additions & 0 deletions rwkv.h
@@ -179,6 +179,12 @@ extern "C" {
    float * logits_out
);

+// Returns the major architecture version used by the given model.
+RWKV_API size_t rwkv_get_arch_version_major(const struct rwkv_context * ctx);
+
+// Returns the minor architecture version used by the given model.
+RWKV_API size_t rwkv_get_arch_version_minor(const struct rwkv_context * ctx);
+
// Returns the number of tokens in the given model's vocabulary.
// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx);