From bb91e49b7128564f197bbc2ee64879a61a0279f4 Mon Sep 17 00:00:00 2001 From: wildkid1024 Date: Fri, 8 Mar 2024 18:39:18 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=B7=BB=E5=8A=A0python=20Tensor=E7=BA=A7a?= =?UTF-8?q?pi?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyfastllm/README.md | 20 +- pyfastllm/examples/chatglm2.py | 293 ++++++++++++++++++ pyfastllm/examples/test_chatglm2_cpp.py | 263 ++++++++++++++++ pyfastllm/examples/test_ops.py | 155 +++++---- pyfastllm/fastllm/functions/__init__.py | 2 +- pyfastllm/fastllm/functions/custom_ops.py | 1 + pyfastllm/fastllm/functions/fastllm_ops.py | 79 +++-- pyfastllm/fastllm/functions/numpy_ops.py | 41 +++ pyfastllm/fastllm/nn/__init__.py | 3 +- .../nn/{BaseModule.py => base_module.py} | 7 +- pyfastllm/fastllm/nn/modules.py | 63 ++++ src/fastllm.cpp | 9 +- src/pybinding.cpp | 221 ++++++++++--- 13 files changed, 996 insertions(+), 161 deletions(-) create mode 100644 pyfastllm/examples/chatglm2.py create mode 100644 pyfastllm/examples/test_chatglm2_cpp.py create mode 100644 pyfastllm/fastllm/functions/numpy_ops.py rename pyfastllm/fastllm/nn/{BaseModule.py => base_module.py} (80%) create mode 100644 pyfastllm/fastllm/nn/modules.py diff --git a/pyfastllm/README.md b/pyfastllm/README.md index b8d67838..283c3d8f 100644 --- a/pyfastllm/README.md +++ b/pyfastllm/README.md @@ -6,12 +6,19 @@ pyfastllm是基于fastllm的python api接口实现,通过pyfastllm可以更加 - 对接fastapi、flask等web框架,向外提供数据接口 - 利用python yield生成器语言特性,流式问答响应 +- 类似于torch的python低级接口,目前支持到cpu版本 - 对接Lora、Ptuning等微调方法,下游任务可微调(开发中...) - 无缝对接加速HugingFace模型库,无痛加速迁移原有业务代码(开发中...) - 其他更多... ## 版本更新 +### v0.2.1 2024-03-08 +- 增加了低级python接口 +- 测试低级接口,实现了纯python版本的chatglm2 +- 增加了一些新的op + + ### v0.2.0 2023-10-23 - 代码结构调整优化 @@ -103,6 +110,7 @@ python3 cli_simple.py -p chatglm-6b-int8.flm -t 8 - `examples/convert_model.py`: 模型转换示例 - `examples/web_api.py`, `examples/web_api_client.py`: fastapi webapi调用 - `examples/test_ops.py`: 部分op的使用样例及测试 +- `examples/chatglm2.py`: 低级python接口下的chatglm2模型(目前仅支持cpu) ### 命令行工具 @@ -211,10 +219,10 @@ python web_api.py -m 0 -p path_for_chatglm --max_batch_size 32 - [x] 修改response_batch的output_str函数,以返回值的形式返回答案 - [x] 编解码部分优化,合并不同的返回类型 -- [ ] 对接numpy等矩阵库 -- [ ] Tensor的深复制和浅复制,以及基础运算符重载 -- [ ] fix low_api下pastKV复制的bug +- [x] 对接numpy等矩阵库 +- [ ] Tensor的深复制和浅复制,以及基础运算符重载,在python端编写 +- [x] fix low_api下pastKV复制的bug - [x] 模型运行参数对象类,封装模型运行时参数,包含模型路径、运行线程数、是否为低内存模型、惩罚因子、温度等 -- [ ] 增加更多的op -- [ ] 增加module - +- [x] 增加更多的op以及module,后续可增加更多 +- [ ] 增加其他后端 +- [ ] 更新文档接口说明 \ No newline at end of file diff --git a/pyfastllm/examples/chatglm2.py b/pyfastllm/examples/chatglm2.py new file mode 100644 index 00000000..0638c9e8 --- /dev/null +++ b/pyfastllm/examples/chatglm2.py @@ -0,0 +1,293 @@ +import sys +import pytest +import numpy as np +import fastllm + +import pyfastllm +import gc + +from pathlib import Path + +np.random.seed(42) + +from abc import ABC +from abc import abstractmethod +from typing import Any + +from fastllm import ops +# import ops + +def diff(dataA, dataB): + mae = np.max(np.abs(dataA - dataB)) + print('max abs err is ', mae) + return mae + +def to_tensor(data): + if not isinstance(data, np.ndarray): + return None + + return pyfastllm.from_numpy(data) + +def to_numpy(data): + if not isinstance(data, fastllm.Tensor): + return None + + return np.array(data, copy=False) + +def load_weights(): + file = "/home/pan/Public/Models/models-flm/chatglm2-6b.flm" + state_dict = fastllm.load(file) + # print(state_dict.keys()) + return state_dict + +state_dict = 
load_weights() + +def get_sin_cos(): + base = 1e4 + dim = 128 + inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim)) + + t = np.arange(0, 32768) + freqs = np.einsum('i,j->ij', t, inv_freq) + + emb = np.concatenate((freqs, freqs), axis=-1) + return np.sin(emb), np.cos(emb) + +def get_postion_id(seq_len): + pos_id = np.zeros(shape=[2, seq_len]) + for i in range(seq_len): + pos_id[0, i] = i + pos_id[1, -1] = 1 + return pos_id + +def get_mask(seq_len): + attn_mask = np.zeros(shape=[seq_len, seq_len]) + for i in range(seq_len): + attn_mask[i, -1] = 1 + + for i in range(seq_len): + for j in range(i+1, seq_len): + attn_mask[i, j] = 1 + + return attn_mask + +sin_data, cos_data = get_sin_cos() +sin_data = to_tensor(sin_data) +cos_data = to_tensor(cos_data) + + +def core_attention(q, k, v, attn_mask, pastkv): + seq_len, batch, num_attention_heads, attn_dim = q.shape + embed_dim = num_attention_heads * attn_dim + + k.reshape([k.shape[0], k.shape[1] * k.shape[2], k.shape[3]]) + v.reshape([v.shape[0], v.shape[1] * v.shape[2], v.shape[3]]) + + k = ops.permute(k, [1, 0, 2]) + v = ops.permute(v, [1, 0, 2]) + + pastKey = pastkv[0] + pastValue = pastkv[1] + + unitLen = 64 + while ( + (len(pastKey.shape) == 0 and (len(pastKey.expansionDims) == 0 or k.shape[1] > pastKey.expansionDims[1])) + or (len(pastKey.shape) > 0 and (len(pastKey.expansionDims) == 0 or pastKey.shape[1] + k.shape[1] > pastKey.expansionDims[1])) + ): + if pastKey.count(0) == 0 or len(pastKey.shape) == 0: + newDims =[k.shape[0], int(((k.shape[1] - 1) / unitLen + 1) * unitLen), k.shape[2]] + else: + newDims = pastKey.shape + newDims[1] += int(((k.shape[1] - 1) / unitLen + 1) * unitLen) + + # print(newDims) + pastKey.expansion(newDims) + + while ( + (len(pastValue.shape) == 0 and (len(pastValue.expansionDims) == 0 or v.shape[1] > pastValue.expansionDims[1])) + or (len(pastValue.shape) > 0 and (len(pastValue.expansionDims) == 0 or pastValue.shape[1] + v.shape[1] > pastValue.expansionDims[1])) + ): + if pastValue.count(0) == 0 or len(pastValue.shape) == 0: + newDims =[v.shape[0], int(((v.shape[1] - 1) / unitLen + 1) * unitLen), v.shape[2]] + else: + newDims = pastValue.shape + newDims[1] += int(((v.shape[1] - 1) / unitLen + 1) * unitLen) + + pastValue.expansion(newDims) + + pyfastllm.cat_direct(pastKey, k, 1) + pyfastllm.cat_direct(pastValue, v, 1) + + q.reshape([q.shape[0], q.shape[1] * q.shape[2], q.shape[3]]) + q = ops.permute(q, [1, 0, 2]) + + context = ops.attention(q, pastKey, pastValue, attn_mask, q.shape[0]//pastKey.shape[0], 1.0/math.sqrt(attn_dim)) + context.reshape([batch, num_attention_heads, seq_len, -1]) + context = ops.permute(context, [2, 0, 1, 3]) + context.reshape([context.dims[0], context.dims[1], embed_dim]) + return context + + +def transformer(hidden_states, i, attn_mask, num_attention_heads, rotary_dim, pos_id, pastkvs): + seq_len, batch, embed_dim = hidden_states.shape + inputRMSWeightName = f"transformer.encoder.layers.{i}.input_layernorm.weight" + atten_input = ops.rms_norm(hidden_states, state_dict[inputRMSWeightName], eps=1e-5) + # print("rms norm ok") + qkv_weight_name = f"transformer.encoder.layers.{i}.self_attention.query_key_value.weight" + qkv_bias_name = f"transformer.encoder.layers.{i}.self_attention.query_key_value.bias" + qkv = ops.linear(atten_input, state_dict[qkv_weight_name], state_dict[qkv_bias_name]) + # print("transformer qkv ok") + + qLen = embed_dim + kvLen = (qkv.shape[-1] - embed_dim) // 2 + q = ops.split(qkv, -1, 0, qLen) + k = ops.split(qkv, -1, qLen, qLen + kvLen) + v = ops.split(qkv, -1, 
qLen + kvLen, qLen + kvLen + kvLen) + + q.reshape([q.shape[0], q.shape[1], -1, embed_dim // num_attention_heads]) + k.reshape([k.shape[0], k.shape[1], -1, embed_dim // num_attention_heads]) + v.reshape([v.shape[0], v.shape[1], -1, embed_dim // num_attention_heads]) + + q = pyfastllm.nearlyrotateposition2D(q, pos_id, sin_data, cos_data, rotary_dim) + k = pyfastllm.nearlyrotateposition2D(k, pos_id, sin_data, cos_data, rotary_dim) + + + context = core_attention(q, k, v, attn_mask, pastkv=pastkvs[i]) + + # print("transformer attention ok") + + denseWeightName = f"transformer.encoder.layers.{i}.self_attention.dense.weight" + denseBiasName = f"transformer.encoder.layers.{i}.self_attention.dense.bias" + attnOutput = ops.linear(context, state_dict[denseWeightName], state_dict[denseBiasName]) + hidden_states = ops.add(hidden_states, attnOutput, 1.0) + + # print("transformer lr ok") + return hidden_states + + +def mlp(inputs, i): + fcInKeyName = f"transformer.encoder.layers.{i}.mlp.dense_h_to_4h" + fcOutKeyName = f"transformer.encoder.layers.{i}.mlp.dense_4h_to_h" + + middle = ops.linear(inputs, weights=state_dict[fcInKeyName+".weight"], bias=state_dict[fcInKeyName+".bias"]) + middle = ops.activation(middle, activate_type='swiglu') + middle = ops.linear(middle, weights=state_dict[fcOutKeyName+".weight"], bias=state_dict[fcOutKeyName+".bias"]) + + return middle + +def forward( + input_ids, + attn_mask, + pos_id, + pastkvs + ): + batch = input_ids.shape[0] + seq_len = input_ids.shape[1] + + input_ids = ops.permute(input_ids, [1, 0]) + input_embedding = ops.embedding(inputs=input_ids, embedding_weights=state_dict['transformer.embedding.word_embeddings.weight']) + hidden_states = input_embedding + + # print("embedding ok") + + rotary_dim = 64 + layer_num = 28 + num_attention_heads = 32 + embed_dim = 4096 + head_dim = embed_dim // num_attention_heads + scale_attn = math.sqrt(head_dim) + + for i in range(layer_num): + mlp_input = transformer(hidden_states, i, attn_mask, num_attention_heads, rotary_dim, pos_id, pastkvs) + print("transformer ok") + postRMSWeightName = f"transformer.encoder.layers.{i}.post_attention_layernorm.weight" + temp = ops.mul(hidden_states, 1.0) + mlp_input = ops.rms_norm(hidden_states, state_dict[postRMSWeightName], 1e-5) + mlp_output = mlp(mlp_input, i) + hidden_states = ops.add(mlp_output, temp, 1.0) + print("mlp ok") + + if seq_len > 1: + hidden_states = ops.split(hidden_states, 0, seq_len - 1, seq_len) + + hidden_states = ops.rms_norm(hidden_states, state_dict["transformer.encoder.final_layernorm.weight"], 1e-5) + logits = ops.linear(hidden_states, state_dict["transformer.output_layer.weight"]) + + topk = ops.topk(logits, 1) + # print("topk ok") + + topk.to("cpu") + topk_np = np.array(topk, copy=False) + token = int(topk_np[0, 0, 0] + 1e-3) + return token, pastkvs + +from transformers import AutoModel, AutoTokenizer +import math +from typing import List, Tuple + +def build_inputs(tokenizer, query: str, history: List[Tuple[str, str]] = None): + prompt = tokenizer.build_prompt(query, history=history) + inputs = tokenizer([prompt], return_tensors="np") + return inputs + +def warmup(): + bos_token_id = 64792 + input_ids = pyfastllm.Tensor(fastllm.float32, [1, 1], [bos_token_id, ]) + attn_mask = pyfastllm.Tensor(fastllm.float32, [1, 1], [0]) + pos_id = pyfastllm.Tensor(fastllm.float32, [2, 1], [0, 0]) + + pastKeyValues = [] + for i in range(28): + pastKey = pyfastllm.Tensor(fastllm.float32) + pastValue = pyfastllm.Tensor(fastllm.float32) + pastKeyValues.append([pastKey, pastValue]) + 
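
The cache-growth loops in `core_attention` above round the required length up to a multiple of `unitLen` (64) before calling `expansion`, so that repeated `pyfastllm.cat_direct` appends do not reallocate on every decoded token. A minimal sketch of just that rounding rule in plain Python (helper name is illustrative, not part of the API):

```python
UNIT_LEN = 64  # expansion granularity used by core_attention above

def rounded_expansion_len(append_len: int, unit_len: int = UNIT_LEN) -> int:
    """Round append_len up to the next multiple of unit_len."""
    return ((append_len - 1) // unit_len + 1) * unit_len

assert rounded_expansion_len(1) == 64    # the first token still reserves a full unit
assert rounded_expansion_len(64) == 64   # exact multiples are unchanged
assert rounded_expansion_len(65) == 128  # one element over spills into a new unit
```
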
+ forward(input_ids, attn_mask, pos_id, pastKeyValues) + + +def chatglm2(): + query = "你好" + model_path = "/home/pan/Public/Models/models-hf/chatglm2-6b" + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + input_ids = build_inputs(tokenizer, query=query)['input_ids'] + print(input_ids) + + batch = input_ids.shape[0] + seq_len = input_ids.shape[1] + input_ids = to_tensor(input_ids) + pos_id = to_tensor(get_postion_id(seq_len)) + attn_mask = to_tensor(get_mask(seq_len)) + + pastKeyValues = [] + for i in range(28): + pastKey = pyfastllm.Tensor(fastllm.float32) + pastValue = pyfastllm.Tensor(fastllm.float32) + pastKeyValues.append([pastKey, pastValue]) + + index = 0 + promptLen = seq_len - 2 + + results = [] + while True: + token, pastKeyValues = forward(input_ids, attn_mask, pos_id, pastKeyValues) + + if token == 2: + break + + results.append(token) + ret = tokenizer.decode(results) + print(ret) + + index += 1 + + if index >= 256: + break + + input_ids.copy_from(fastllm.Tensor(fastllm.float32, [1, 1], [token])) + attn_mask = fastllm.Tensor(fastllm.float32) + pos_id.copy_from(fastllm.Tensor(fastllm.float32, [2, 1], [promptLen + index + 1, (index + 1)])) + + +if __name__ == "__main__": + # warmup() + chatglm2() diff --git a/pyfastllm/examples/test_chatglm2_cpp.py b/pyfastllm/examples/test_chatglm2_cpp.py new file mode 100644 index 00000000..c1eb657a --- /dev/null +++ b/pyfastllm/examples/test_chatglm2_cpp.py @@ -0,0 +1,263 @@ +import sys +import pytest +import numpy as np +import fastllm + +import pyfastllm +import gc + +np.random.seed(42) +from fastllm import ops +# import ops + +def to_tensor(data): + return pyfastllm.from_numpy(data) + +def to_numpy(data): + return np.array(data, copy=False) + + +## 模型测试 +def load_weights(): + file = "/home/pan/Public/Models/models-flm/chatglm2-6b.flm" + state_dict = fastllm.load(file) + # print(state_dict.keys()) + return state_dict + +state_dict = load_weights() + +def get_sin_cos(): + base = 1e4 + dim = 128 + inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim)) + + t = np.arange(0, 32768) + freqs = np.einsum('i,j->ij', t, inv_freq) + + emb = np.concatenate((freqs, freqs), axis=-1) + return np.sin(emb), np.cos(emb) + + +def get_postion_id(seq_len): + pos_id = np.zeros(shape=[2, seq_len]) + for i in range(seq_len): + pos_id[0, i] = i + pos_id[1, -1] = 1 + return pos_id + +def get_mask(seq_len): + attn_mask = np.zeros(shape=[seq_len, seq_len]) + for i in range(seq_len): + attn_mask[i, -1] = 1 + + for i in range(seq_len): + for j in range(i+1, seq_len): + attn_mask[i, j] = 1 + + return attn_mask + + + +from transformers import AutoModel, AutoTokenizer +import math +from typing import List, Tuple + +def build_inputs(tokenizer, query: str, history: List[Tuple[str, str]] = None): + prompt = tokenizer.build_prompt(query, history=history) + inputs = tokenizer([prompt], return_tensors="np") + return inputs + +sin_data, cos_data = get_sin_cos() +sin_data = to_tensor(sin_data) +cos_data = to_tensor(cos_data) + +def chatglm(): + query = "你好" + model_path = "/home/pan/Public/Models/models-hf/chatglm2-6b" + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + input_ids = build_inputs(tokenizer, query=query)['input_ids'] + print(input_ids) + + batch = input_ids.shape[0] + seq_len = input_ids.shape[1] + input_ids = to_tensor(input_ids) + pos_id = to_tensor(get_postion_id(seq_len)) + attn_mask = to_tensor(get_mask(seq_len)) + + pastKeyValues = [] + for i in range(28): + pastKey = pyfastllm.Tensor(fastllm.float32) 
+ pastValue = pyfastllm.Tensor(fastllm.float32) + pastKeyValues.append([pastKey, pastValue]) + + index = 0 + promptLen = seq_len - 2 + + results = [] + while True: + token, pastKeyValues = forward(input_ids, attn_mask, pos_id, pastKeyValues) + + if token == 2: + break + + results.append(token) + ret = tokenizer.decode(results) + print(ret) + + index += 1 + + if index >= 256: + break + + input_ids.copy_from(fastllm.Tensor(fastllm.float32, [1, 1], [token])) + attn_mask = fastllm.Tensor(fastllm.float32) + pos_id.copy_from(fastllm.Tensor(fastllm.float32, [2, 1], [promptLen + index + 1, (index + 1)])) + +def warmup(): + bos_token_id = 64792 + input_ids = pyfastllm.Tensor(fastllm.float32, [1, 1], [bos_token_id, ]) + attn_mask = pyfastllm.Tensor(fastllm.float32, [1, 1], [0]) + pos_id = pyfastllm.Tensor(fastllm.float32, [2, 1], [0, 0]) + + pastKeyValues = [] + for i in range(28): + pastKey = pyfastllm.Tensor(fastllm.float32) + pastValue = pyfastllm.Tensor(fastllm.float32) + pastKeyValues.append([pastKey, pastValue]) + + forward(input_ids, attn_mask, pos_id, pastKeyValues) + +def mlp(inputs, i): + fcInKeyName = f"transformer.encoder.layers.{i}.mlp.dense_h_to_4h" + fcOutKeyName = f"transformer.encoder.layers.{i}.mlp.dense_4h_to_h" + + middle = ops.linear(inputs, weights=state_dict[fcInKeyName+".weight"], bias=state_dict[fcInKeyName+".bias"]) + middle = ops.activation(middle, activate_type='swiglu') + middle = ops.linear(middle, weights=state_dict[fcOutKeyName+".weight"], bias=state_dict[fcOutKeyName+".bias"]) + + return middle + +def forward(input_ids, + attn_mask, + pos_id, + pastkv + ): + + batch = input_ids.shape[0] + seq_len = input_ids.shape[1] + + input_ids = ops.permute(input_ids, [1, 0]) + input_embedding = ops.embedding(inputs=input_ids, embedding_weights=state_dict['transformer.embedding.word_embeddings.weight']) + hidden_states = input_embedding + + rotary_dim = 128 + layer_num = 28 + num_attention_heads = 32 + embed_dim = 4096 + head_dim = embed_dim // num_attention_heads + scale_attn = math.sqrt(head_dim) + + for i in range(layer_num): + print(i) + inputRMSWeightName = f"transformer.encoder.layers.{i}.input_layernorm.weight" + atten_input = ops.rms_norm(hidden_states, state_dict[inputRMSWeightName]) + # print(atten_input) + + qkv_weight_name = f"transformer.encoder.layers.{i}.self_attention.query_key_value.weight" + qkv_bias_name = f"transformer.encoder.layers.{i}.self_attention.query_key_value.bias" + qkv = ops.linear(atten_input, state_dict[qkv_weight_name], state_dict[qkv_bias_name]) + + qLen = embed_dim + kvLen = (qkv.shape[-1] - embed_dim) // 2 + q = ops.split(qkv, -1, 0, qLen) + k = ops.split(qkv, -1, qLen, qLen + kvLen) + v = ops.split(qkv, -1, qLen + kvLen, qLen + kvLen + kvLen) + + q.reshape([q.shape[0], q.shape[1], -1, embed_dim // num_attention_heads]) + k.reshape([k.shape[0], k.shape[1], -1, embed_dim // num_attention_heads]) + v.reshape([v.shape[0], v.shape[1], -1, embed_dim // num_attention_heads]) + + q = pyfastllm.nearlyrotateposition2D(q, pos_id, sin_data, cos_data, rotary_dim) + k = pyfastllm.nearlyrotateposition2D(k, pos_id, sin_data, cos_data, rotary_dim) + + # print(k) + + k.reshape([k.shape[0], k.shape[1] * k.shape[2], k.shape[3]]) + v.reshape([v.shape[0], v.shape[1] * v.shape[2], v.shape[3]]) + + k = ops.permute(k, [1, 0, 2]) + v = ops.permute(v, [1, 0, 2]) + + pastKey = pastkv[i][0] + pastValue = pastkv[i][1] + + unitLen = 64 + while ( + (len(pastKey.shape) == 0 and (len(pastKey.expansionDims) == 0 or k.shape[1] > pastKey.expansionDims[1])) + or 
(len(pastKey.shape) > 0 and (len(pastKey.expansionDims) == 0 or pastKey.shape[1] + k.shape[1] > pastKey.expansionDims[1])) + ): + if pastKey.count(0) == 0 or len(pastKey.shape) == 0: + newDims =[k.shape[0], int(((k.shape[1] - 1) / unitLen + 1) * unitLen), k.shape[2]] + else: + newDims = pastKey.shape + newDims[1] += int(((k.shape[1] - 1) / unitLen + 1) * unitLen) + + # print(newDims) + pastKey.expansion(newDims) + + while ( + (len(pastValue.shape) == 0 and (len(pastValue.expansionDims) == 0 or v.shape[1] > pastValue.expansionDims[1])) + or (len(pastValue.shape) > 0 and (len(pastValue.expansionDims) == 0 or pastValue.shape[1] + v.shape[1] > pastValue.expansionDims[1])) + ): + if pastValue.count(0) == 0 or len(pastValue.shape) == 0: + newDims =[v.shape[0], int(((v.shape[1] - 1) / unitLen + 1) * unitLen), v.shape[2]] + else: + newDims = pastValue.shape + newDims[1] += int(((v.shape[1] - 1) / unitLen + 1) * unitLen) + + pastValue.expansion(newDims) + + pyfastllm.cat_direct(pastKey, k, 1) + pyfastllm.cat_direct(pastValue, v, 1) + + q.reshape([q.shape[0], q.shape[1] * q.shape[2], q.shape[3]]) + q = ops.permute(q, [1, 0, 2]) + + context = ops.attention(q, pastKey, pastValue, attn_mask, q.shape[0]//pastKey.shape[0], 1.0 / scale_attn) + context.reshape([batch, num_attention_heads, seq_len, -1]) + context = ops.permute(context, [2, 0, 1, 3]) + context.reshape([context.dims[0], context.dims[1], embed_dim]) + + denseWeightName = f"transformer.encoder.layers.{i}.self_attention.dense.weight" + denseBiasName = f"transformer.encoder.layers.{i}.self_attention.dense.bias" + attnOutput = ops.linear(context, state_dict[denseWeightName],state_dict[denseBiasName]) + hidden_states = ops.add(hidden_states, attnOutput, 1.0) + + temp = ops.mul(hidden_states, 1.0) + postRMSWeightName = f"transformer.encoder.layers.{i}.post_attention_layernorm.weight" + mlp_input = ops.rms_norm(hidden_states, state_dict[postRMSWeightName], 1e-5) + mlp_output = mlp(mlp_input, i) + hidden_states = ops.add(mlp_output, temp, 1.0) + # hidden_states = mlp_output + gc.collect() + + if seq_len > 1: + hidden_states = ops.split(hidden_states, 0, seq_len - 1, seq_len) + + hidden_states = ops.rms_norm(hidden_states, state_dict["transformer.encoder.final_layernorm.weight"], 1e-5) + logits = ops.linear(hidden_states, state_dict["transformer.output_layer.weight"]) + + topk = ops.topk(logits, 1) + # print(topk.size()) + # print(topk) + print(topk) + topk_np = np.array(topk, copy=False) + token = int(topk_np[0, 0, 0] + 1e-3) + print(token) + return token, pastkv + + +if __name__ == "__main__": + # warmup() + chatglm() + diff --git a/pyfastllm/examples/test_ops.py b/pyfastllm/examples/test_ops.py index 576a4beb..8b862ac5 100644 --- a/pyfastllm/examples/test_ops.py +++ b/pyfastllm/examples/test_ops.py @@ -1,95 +1,90 @@ +import sys import pytest import numpy as np import fastllm -def np_rms_norm(inputs, weights, eps): - channel = inputs.shape[-1] - sqrt_mean = np.sqrt(np.sum(inputs**2)/channel + eps) - return inputs / sqrt_mean *weights +import pyfastllm +import gc +# import np_ops +# import ops as flm_ops -def np_layer_norm(inputs, gamma, beta, axis=-1): - assert axis < len(inputs.shapes), "axis should less than inputs dims" - channel = inputs.shape[axis] - mean = np.mean(inputs, axis=axis) - var = np.var(inputs, axis=axis) +from fastllm import ops as flm_ops +from fastllm import np_ops - output = (inputs - mean) / var * gamma + beta - return output +np.random.seed(42) -def np_linear(inputs, weights, bias): - output = np.matmul(inputs, weights.T) + 
bias - return output +def diff(dataA, dataB): + mae = np.max(np.abs(dataA - dataB)) + print('max abs err is ', mae) + return mae + +def to_tensor(data): + return pyfastllm.from_numpy(data) + +def to_numpy(data): + return np.array(data, copy=False, order='C') + +def test_rms_norm(inputs=None, weights=None, eps=1e-6): + if not inputs: + inputs = np.random.random(size=[1, 256]) + weights = np.random.random(size=[1, 256]) -def np_softmax(inputs, axis=None): - maxv = inputs.max(axis, keepdims=True) - exp_v = np.exp(inputs - maxv) - exp_sum = np.sum(exp_v, axis=axis) - return exp_v / exp_sum + np_out = np_ops.rms_norm(inputs, weights, eps) + flm_out = flm_ops.rms_norm(to_tensor(inputs), to_tensor(weights), eps) + mae = diff(np_out, to_numpy(flm_out)) + assert mae <= 1e-6 + return flm_out -def np_silu(inputs, ): - return inputs / (1 + np.exp(-inputs)) +def test_swiglu(inputs=None): + if not inputs: + inputs = np.array([1, 5]).reshape([1, 2]) -def np_attention(q, k, v, mask=None, group=None, scale=None): - qk = np_softmax(q @ k.T * scale, axis=-1) - attn = qk @ v - return attn - -def test_linear(): - inputs = np.array([[1, 2]]) - weight = np.array([[3, 4, 5, 5, 6, 7]]).reshape([3, 2]) - bias = np.array([0, 1, 1]) - np_output = np_linear(inputs, weight, bias) - print(np_output) - - input = fastllm.Tensor(fastllm.float32, [1, 2], [1, 2]) - weights = fastllm.Tensor(fastllm.float32, [3, 2], [3, 4, 5, 5, 6, 7]) - bias = fastllm.Tensor(fastllm.float32, [3], [0, 1, 1]) - out = fastllm.ops.linear(input, weights, bias) - print(out) - -def test_rms_norm(): - inputs = np.array([1, 5]).reshape([1, 2]) - weights = np.array([1, 3]).reshape([1, 2]) - eps = 1e-6 - - np_out = np_rms_norm(inputs, weights, eps) - print(np_out) - - input = fastllm.Tensor(fastllm.float32, [1, 2], [1, 5]) - weights = fastllm.Tensor(fastllm.float32, [1, 2], [1, 3]) - out = fastllm.Tensor() - out = fastllm.ops.rms_norm(input, weights, eps=1e-6) - print(out) - -def test_silu(): - inputs = np.array([1, 5]).reshape([1, 2]) - output = np_softmax(inputs) - # output = np_silu(inputs) - print(output) - - inputs = fastllm.Tensor(fastllm.float32, [1, 2], [1, 5]) - out = fastllm.ops.activation(input=inputs, activate_type="softmax") - # out = fastllm.ops.activation(input=inputs, activate_type="silu") - print(out) - -def test_attention(): - q = np.array([1, 2, 3, 4, 5, 6]).reshape([2, 3]) - k = np.array([5, 6, 7, 8, 9, 10]).reshape([2, 3]) - v = np.array([1, 1, 1, 2, 1, 3]).reshape([2, 3]) - scale = 1 / np.sqrt(q.shape[-1]) - output = np_attention(q, k, v, scale=scale) - print(output) - - q = fastllm.Tensor(fastllm.float32, [1, 2, 3], [1, 2, 3, 4, 5, 6]) - k = fastllm.Tensor(fastllm.float32, [1, 2, 3], [5, 6, 7, 8, 9, 10]) - v = fastllm.Tensor(fastllm.float32, [1, 2, 3], [1, 1, 1, 2, 1, 3]) + np_out = np_ops.swiglu(inputs) + out = flm_ops.activation(input=to_tensor(inputs), activate_type="swiglu") + mae = diff(np_out, out) + assert mae <= 1e-6 + return out + +def test_attention(q=None, k=None, v=None, mask=None, group=1, scale=1.0): + if q is None: + q = np.random.random(size=[1, 12, 4096]) + k = np.random.random(size=[1, 12, 4096]) + v = np.random.random(size=[1, 12, 4096]) + scale = 1 / np.sqrt(q.shape[-1]) + + np_out = np_ops.attention(q, k, v, scale=scale) + mask = fastllm.Tensor() - output = fastllm.ops.attention(q, k, v, mask, group=1, scale=scale, attentionType=0) - print(output) + flm_out = flm_ops.attention(to_tensor(q), to_tensor(k), to_tensor(v), mask, group=group, scale=scale, attentionType=0) + + mae = diff(np_out, to_numpy(flm_out)) + 
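
The numpy baseline this test compares against is ordinary scaled dot-product attention; for a single head with 2-D `q`, `k`, `v` it reduces to the following sketch (numerically stable softmax, no mask assumed):

```python
import numpy as np

def ref_attention(q, k, v, scale):
    # softmax(q @ k^T * scale) @ v, with max-subtraction for stability
    logits = q @ k.T * scale
    logits -= logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v
```
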
assert mae <= 1e-6 + return flm_out + -if __name__ == "__main__": +def test_linear(inputs=None, + weights=None, + bias=None): + + if not inputs: + inputs = np.random.random(size=[1, 2, 4096]) + + np_out = np_ops.linear(inputs=inputs, weights=weights, bias=None) + + if not bias: + bias = fastllm.Tensor() + + output = flm_ops.linear(to_tensor(inputs), weights, bias) + mae = diff(np_out, to_numpy(output)) + + assert mae <= 1e-6 + return output + +if __name__ == "__main__": + test_rms_norm() test_attention() - test_silu() test_linear() - test_rms_norm() + test_swiglu() + + diff --git a/pyfastllm/fastllm/functions/__init__.py b/pyfastllm/fastllm/functions/__init__.py index efb0421c..8b137891 100644 --- a/pyfastllm/fastllm/functions/__init__.py +++ b/pyfastllm/fastllm/functions/__init__.py @@ -1 +1 @@ -from .fastllm_ops import * + diff --git a/pyfastllm/fastllm/functions/custom_ops.py b/pyfastllm/fastllm/functions/custom_ops.py index e69de29b..b0c7fa92 100644 --- a/pyfastllm/fastllm/functions/custom_ops.py +++ b/pyfastllm/fastllm/functions/custom_ops.py @@ -0,0 +1 @@ +import triton as tl \ No newline at end of file diff --git a/pyfastllm/fastllm/functions/fastllm_ops.py b/pyfastllm/fastllm/functions/fastllm_ops.py index 5ca0a618..08ed7bb5 100644 --- a/pyfastllm/fastllm/functions/fastllm_ops.py +++ b/pyfastllm/fastllm/functions/fastllm_ops.py @@ -1,31 +1,39 @@ import pyfastllm +def embedding(inputs: pyfastllm.Tensor, embedding_weights:pyfastllm.Tensor): + output = pyfastllm.Tensor() + pyfastllm.embedding(inputs, embedding_weights, output) + return output -def embedding(data: pyfastllm.Tensor, ): - # some check - return pyfastllm.embedding(data, ) - -def rms_norm(input:pyfastllm.Tensor, weight: pyfastllm.Tensor, eps: float, output: pyfastllm.Tensor=None): - output = pyfastllm.rms_norm(input, weight, eps) +def rms_norm(inputs:pyfastllm.Tensor, weights: pyfastllm.Tensor, eps: float=1e-5): + output = pyfastllm.Tensor() + pyfastllm.rms_norm(inputs, weights, eps, output) return output -def layer_norm(input: pyfastllm.Tensor, +def layer_norm(inputs: pyfastllm.Tensor, gamma: pyfastllm.Tensor, beta: pyfastllm.Tensor, axis:int=-1 ): - output = pyfastllm.layer_norm(input, gamma, beta,axis) + output = pyfastllm.Tensor() + pyfastllm.layer_norm(inputs, gamma, beta,axis, output) return output -def linear(input: pyfastllm.Tensor, - weight: pyfastllm.Tensor, - bias: pyfastllm.Tensor): - output = pyfastllm.linear(input, weight, bias) +def linear(inputs: pyfastllm.Tensor, + weights: pyfastllm.Tensor, + bias: pyfastllm.Tensor=None): + output = pyfastllm.Tensor() + # print(weights) + if not bias: + bias = pyfastllm.Tensor() + + pyfastllm.linear(inputs, weights, bias, output) return output -def matmul(input0: pyfastllm.Tensor, - input1: pyfastllm.Tensor, +def matmul(inputs0: pyfastllm.Tensor, + inputs1: pyfastllm.Tensor, alpha: pyfastllm.Tensor): - output = pyfastllm.matmul(input0, input1, alpha) + output = pyfastllm.Tensor() + pyfastllm.matmul(inputs0, inputs1, alpha, output) return output def attention(q: pyfastllm.Tensor, @@ -34,26 +42,45 @@ def attention(q: pyfastllm.Tensor, mask: pyfastllm.Tensor, group: int, scale: float, - attentionType: int): - output = pyfastllm.attention(q, k, v, mask, group, scale, attentionType) + attentionType:int = 0): + output = pyfastllm.Tensor() + pyfastllm.attention(q, k, v, mask, group, scale, attentionType, output) return output -def activation(input: pyfastllm.Tensor, axis=-1, activate_type="silu"): +def activation(inputs: pyfastllm.Tensor, axis=-1, activate_type="silu"): assert 
activate_type in ("softmax", "silu", "gelu", "swiglu") func = getattr(pyfastllm, activate_type) + + output = pyfastllm.Tensor() if activate_type == "softmax": - return func(input, axis) - return func(input) + func(inputs, axis, output) + else: + func(inputs, output) + return output -def mul(input: pyfastllm.Tensor, v: int): - output = pyfastllm.mul(input, v) +def mul(inputs: pyfastllm.Tensor, v: int): + output = pyfastllm.Tensor() + pyfastllm.mul(inputs, v, output) return output -def matmul_transB(): - pass +def add(input0: pyfastllm.Tensor, input1: pyfastllm.Tensor, v: int): + output = pyfastllm.Tensor() + output = pyfastllm.add(input0, input1, v) + return output + +def permute(inputs: pyfastllm.Tensor, dims=None): + output = pyfastllm.Tensor() + pyfastllm.permute(inputs, dims, output) + return output + +def split(inputs: pyfastllm.Tensor, axis:int, start:int, end:int): + output = pyfastllm.Tensor() + pyfastllm.split(inputs, axis, start, end, output) + return output -def add(input0: pyfastllm.Tensor, input1: pyfastllm.Tensor): - output = pyfastllm.add(input0, input1) +def topk(logits:pyfastllm.Tensor, axis:int = 1): + output = pyfastllm.Tensor() + pyfastllm.topk(logits, axis, output) return output def AttentionMask(): diff --git a/pyfastllm/fastllm/functions/numpy_ops.py b/pyfastllm/fastllm/functions/numpy_ops.py new file mode 100644 index 00000000..3fb36ba7 --- /dev/null +++ b/pyfastllm/fastllm/functions/numpy_ops.py @@ -0,0 +1,41 @@ +import numpy as np +from numba import cuda +from numba import jit + +@jit(nopython=True) +def rms_norm(inputs, weights, eps): + channel = inputs.shape[-1] + sqrt_mean = np.sqrt(np.sum(inputs**2)/channel + eps) + return inputs / sqrt_mean *weights + +@jit(nopython=True) +def layer_norm(inputs, gamma, beta, axis=-1): + assert axis < len(inputs.shapes), "axis should less than inputs dims" + channel = inputs.shape[axis] + mean = np.mean(inputs, axis=axis) + var = np.var(inputs, axis=axis) + + output = (inputs - mean) / var * gamma + beta + return output + +@jit(nopython=True) +def softmax(inputs, axis=None): + maxv = inputs.max(axis, keepdims=True) + exp_v = np.exp(inputs - maxv) + exp_sum = np.sum(exp_v, axis=axis) + return exp_v / exp_sum + +@jit(nopython=True) +def silu(inputs, ): + return inputs / (1 + np.exp(-inputs)) + +@jit(nopython=True) +def linear(inputs, weights, bias): + output = np.matmul(inputs, weights.T) + bias + return output + +@jit(nopython=True) +def np_self_attention(q, k, v, mask=None, group=None, scale=None): + qk = softmax(q @ k.T * scale, axis=-1) + attn = qk @ v + return attn \ No newline at end of file diff --git a/pyfastllm/fastllm/nn/__init__.py b/pyfastllm/fastllm/nn/__init__.py index f026804c..b886302b 100644 --- a/pyfastllm/fastllm/nn/__init__.py +++ b/pyfastllm/fastllm/nn/__init__.py @@ -1 +1,2 @@ -from BaseModule import Module +from .base_module import Module +from .modules import Linear diff --git a/pyfastllm/fastllm/nn/BaseModule.py b/pyfastllm/fastllm/nn/base_module.py similarity index 80% rename from pyfastllm/fastllm/nn/BaseModule.py rename to pyfastllm/fastllm/nn/base_module.py index 942834e4..f391290a 100644 --- a/pyfastllm/fastllm/nn/BaseModule.py +++ b/pyfastllm/fastllm/nn/base_module.py @@ -1,5 +1,6 @@ +import pyfastllm from typing import Any - +from abc import abstractmethod class Module(): def __init__(self) -> None: @@ -8,11 +9,9 @@ def __init__(self) -> None: def __call__(self, *args: Any, **kwds: Any) -> Any: return self.forward(*args, **args) - @classmethod + @abstractmethod def forward(self, ): pass def 
_init_weight(self, ): pass - - diff --git a/pyfastllm/fastllm/nn/modules.py b/pyfastllm/fastllm/nn/modules.py new file mode 100644 index 00000000..689e5fc8 --- /dev/null +++ b/pyfastllm/fastllm/nn/modules.py @@ -0,0 +1,63 @@ +from .base_module import Module +from ..functions import fastllm_ops as F + +class Linear(Module): + def __init__(self, input_size, output_size, bias=False) -> None: + self.weight = None + self.bias = None + super().__init__() + + def forward(self, x): + return F.linear(x, self.weight, self.bias) + +class SiLU(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, axis=-1): + return F.activation(x, axis=axis, activate_type='silu') + +class SwiGLU(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, inputs): + return F.activation(input=inputs, activate_type="swiglu") + +class Embedding(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self): + return F.embedding() + +class Embedding(Module): + def __init__(self, vocab_size, embed_dim) -> None: + super().__init__() + self.vocab_size = vocab_size + self.embed_dim = embed_dim + self.embedding_weights = None + + # def _init_weight(self): + # self.embedding_weights = to_tensor(np.random.random(size=[self.vocab_size, self.embed_dim])) + + def forward(self, inputs): + return F.embedding(inputs, self.embedding_weights) + +class RMSNorm(Module): + def __init__(self) -> None: + super().__init__() + self.weights = None + + def _init_weight(self): + return super()._init_weight() + + def forward(self, inputs): + return F.rms_norm(inputs, self.weights, eps=self.eps) + +class Attention(Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, q, k, v, mask, group, scale): + return F.attention(q, k, v, mask, group=group, scale=scale, attentionType=0) \ No newline at end of file diff --git a/src/fastllm.cpp b/src/fastllm.cpp index be044e55..acf5c00c 100644 --- a/src/fastllm.cpp +++ b/src/fastllm.cpp @@ -381,6 +381,7 @@ namespace fastllm { this->expansionBytes = (size * this->unitSize - 1) / this->unitSizeDiv + 1; if (this->dataDevice == DataDevice::CPU) { this->cpuData = new uint8_t[this->expansionBytes]; + memset(this->cpuData, 0, this->expansionBytes*sizeof(uint8_t)); } else if (this->dataDevice == DataDevice::CUDA) { #ifdef USE_CUDA if (this->directMemory) { @@ -559,13 +560,13 @@ namespace fastllm { */ int n = Count(0) / dims.back(), m = dims.back(); for (int i = 0; i < n; i++) { - for (int j = 0; j < 10 && j < m; j++) { + for (int j = 0; j < 3 && j < m; j++) { printf("%f ", ((float*)cpuData)[i * m + j]); } - if (m > 10) { + if (m > 3) { printf("... 
"); - for (int j = 0; j < 10 && j < m; j++) { - printf("%f ", ((float *) cpuData)[i * m + (m - 10 + j)]); + for (int j = 0; j < 3 && j < m; j++) { + printf("%f ", ((float *) cpuData)[i * m + (m - 3 + j)]); } } printf("\n"); diff --git a/src/pybinding.cpp b/src/pybinding.cpp index 41f5ff20..aace6155 100644 --- a/src/pybinding.cpp +++ b/src/pybinding.cpp @@ -1,13 +1,24 @@ #include "model.h" #include "factoryllm.h" +#include +#include + + namespace pyfastllm{ // 对接不断更新的后端接口 // 需优化,减少内存拷贝 - fastllm::Data RMSNorm(const fastllm::Data &input, const fastllm::Data &weight, float eps){ - fastllm::Data output; + + fastllm::Data Embedding(const fastllm::Data &input, fastllm::Data &weight, fastllm::Data &output){ + fastllm::Embedding(input, weight, output); + // output.ToDevice(fastllm::DataDevice::CPU); + return output; + } + fastllm::Data &RMSNorm(const fastllm::Data &input, const fastllm::Data &weight, float eps, fastllm::Data &output){ + // fastllm::Data output; // std::cout<<"run rms norm"< &axis, fastllm::Data &output){ + fastllm::Permute(input, axis, output); + // output.ToDevice(fastllm::DataDevice::CPU); + return output; + } + + fastllm::Data Cat(const fastllm::Data &input0, const fastllm::Data &input1, int axis) { + fastllm::Data output; + fastllm::Cat(input0, input1, axis, output); + // output.ToDevice(fastllm::DataDevice::CPU); + return output; + } + + fastllm::Data CatDirect(fastllm::Data &input0, const fastllm::Data &input1, int axis) { + fastllm::CatDirect(input0, input1, axis); + // input0.ToDevice(fastllm::DataDevice::CPU); return input0; } + fastllm::Data TopK(const fastllm::Data &input, int topk, fastllm::Data &output) { + fastllm::TopK(input, output, topk); + // output.ToDevice(fastllm::DataDevice::CPU); + return output; + } + + fastllm::Data RotatePosition2D(fastllm::Data &input, const fastllm::Data &positionIds, fastllm::Data &sinData, fastllm::Data &cosData, int rotaryDim){ + fastllm::RotatePosition2D(input, positionIds, sinData, cosData, rotaryDim); + return input; + } + + fastllm::Data NearlyRotatePosition2D(fastllm::Data &input, + const fastllm::Data &positionIds, + fastllm::Data &sinData, + fastllm::Data &cosData, + int rotaryDim){ + fastllm::NearlyRotatePosition2D(input, positionIds, sinData, cosData, rotaryDim); + return input; + } + std::string String(const fastllm::Data &data){ std::string ss; - ss += "["; + ss += "tensor(["; int last_dim = data.dims.back(); int n = data.Count(0) / last_dim, m = last_dim; for (int i = 0; i < n; i++) { if (i > 0) ss += "\n"; - for (int j = 0; j < 10 && j < m; j++) { - if (j>0) ss += " "; + for (int j = 0; j < 3 && j < m; j++) { + if (j>0) ss += ", "; ss += std::to_string(reinterpret_cast(data.cpuData)[i*m+j]); } - if (m > 10) { - ss += "... 
"; - for (int j = 0; j < 10 && j < m; j++) { - if (j>0) ss += " "; - ss += std::to_string(reinterpret_cast(data.cpuData)[i*m + (m-10+j)]); + if (m > 3) { + ss += "..., "; + for (int j = 0; j < 3 && j < m; j++) { + if (j>0) ss += ", "; + ss += std::to_string(reinterpret_cast(data.cpuData)[i*m + (m-3+j)]); } } } - ss += "]"; + ss += "])"; + return ss; + } + + std::string Size(const fastllm::Data &data){ + std::string ss = "Size(["; + for (int i : data.dims) { + ss += std::to_string(i) + ", "; + } + ss += "])"; return ss; } + + fastllm::Data ToDevice(fastllm::Data &data, const std::string &devices){ + size_t pos = devices.find(":"); + int len = devices.length(); + + std::vectordeviceIds; + std::string deviceStr = devices; + int device = fastllm::DataDevice::CPU; + int deviceNum = 0; + + if (pos != -1){ + int deviceNum = atoi(devices.substr(pos, len-pos-1).c_str()); + deviceStr = devices.substr(0, pos); + } + + deviceIds = {deviceNum}; + std::cout< + fastllm::Data fromNumpy(pybind11::array_t NpData){ + pybind11::buffer_info buf = NpData.request(); + // printf("%u \n", buf.ptr); + float *ptr = (float*) buf.ptr; + + std::vector dataSize; + int dataNum = 1; + for (auto sz:buf.shape){ + dataSize.emplace_back((int)sz); + dataNum *= sz; + } + + std::vector Vdata; + for (int i=0;i #include #include @@ -145,15 +272,14 @@ PYBIND11_MODULE(pyfastllm, m) { }); // low level m.def("get_llm_type", &fastllm::GetModelTypeFromFile); - m.def("llm_sampling", &fastllm::LLMSampling) - // .def("embedding", &fastllm::Embedding) + .def("embedding", &pyfastllm::Embedding) .def("rms_norm", &pyfastllm::RMSNorm) .def("layer_norm", &pyfastllm::LayerNorm) .def("linear", &pyfastllm::Linear) - // .def("split", &fastllm::Split) - // .def("cat", &fastllm::Cat) - // .def("cat_direct", &fastllm::CatDirect) + .def("split", &pyfastllm::Split) + .def("cat", &pyfastllm::Cat) + .def("cat_direct", &pyfastllm::CatDirect) .def("matmul", &pyfastllm::MatMul) // .def("matmul_transB", &fastllm::MatMulTransB) .def("softmax", &pyfastllm::Softmax) @@ -161,18 +287,21 @@ PYBIND11_MODULE(pyfastllm, m) { .def("gelu", &pyfastllm::Gelu) .def("swiglu", &pyfastllm::Swiglu) .def("mul", &pyfastllm::Mul) - .def("attention", &pyfastllm::Attention); - // .def("mul_to", &fastllm::MulTo) - // .def("add_to", &fastllm::AddTo) + .def("attention", &pyfastllm::Attention) + .def("mul_to", &fastllm::MulTo) + .def("add_to", &fastllm::AddTo) + .def("add", &pyfastllm::Add) // .def("attention_mask", &fastllm::AttentionMask) // .def("alibi_mask", &fastllm::AlibiMask) - // .def("permute", &fastllm::Permute) + .def("permute", &pyfastllm::Permute) // .def("permute_self", &fastllm::PermuteSelf) - // .def("topk", &fastllm::TopK) - // .def("rotateposition2D", &fastllm::RotatePosition2D) - // .def("nearlyrotateposition2D", &fastllm::NearlyRotatePosition2D) - // .def("llama_rotateposition2D", &fastllm::LlamaRotatePosition2D) + .def("topk", &pyfastllm::TopK) + .def("rotateposition2D", &pyfastllm::RotatePosition2D) + .def("nearlyrotateposition2D", &pyfastllm::NearlyRotatePosition2D) + .def("llama_rotateposition2D", &fastllm::LlamaRotatePosition2D) // .def("repeat_penalty", &fastllm::RepeatPenalty); + .def("load", &pyfastllm::LoadWeights) + .def("from_numpy", &pyfastllm::fromNumpy); py::enum_(m, "Dtype") .value("float32", fastllm::DataType::FLOAT32) @@ -194,19 +323,24 @@ PYBIND11_MODULE(pyfastllm, m) { py::format_descriptor::format(), /* Python struct-style format descriptor */ m.dims.size(), /* Number of dimensions */ m.dims, /* Buffer dimensions */ - { sizeof(float) * m.dims[1], 
/* Strides (in bytes) for each index */ - sizeof(float) } + m.strides + // { sizeof(float) * m.dims[1], /* Strides (in bytes) for each index */ + // sizeof(float) } ); }) .def_readonly("dims", &fastllm::Data::dims) + .def_readonly("expansionDims", &fastllm::Data::expansionDims) .def(py::init<>()) .def(py::init()) .def(py::init&>()) .def(py::init&, const std::vector&>()) .def(py::init()) - .def_readonly("shape", &fastllm::Data::dims) .def("copy_from", &fastllm::Data::CopyFrom) .def("count", &fastllm::Data::Count) + + .def_readonly("shape", &fastllm::Data::dims) + .def("reshape", &fastllm::Data::Reshape) + .def("expansion", &fastllm::Data::Expansion) .def("to_list", [](fastllm::Data& data){ std::vector vecData; for (int i = 0; i < data.Count(0); i++) { @@ -215,8 +349,8 @@ PYBIND11_MODULE(pyfastllm, m) { return vecData; }) .def("__str__", &pyfastllm::String) - .def("print", &fastllm::Data::Print) - .def("to", static_cast(&fastllm::Data::ToDevice)); + .def("size", &pyfastllm::Size) + .def("to", &pyfastllm::ToDevice); m.def("zeros", [](const std::vector &dims, fastllm::DataType dtype)->fastllm::Data { int nums = 1; @@ -281,11 +415,20 @@ PYBIND11_MODULE(pyfastllm, m) { py::class_(m, "WeightMap") .def_readonly("tokenizer", &fastllm::WeightMap::tokenizer) + .def("load", &fastllm::WeightMap::LoadFromFile) .def("save_lowbit", &fastllm::WeightMap::SaveLowBitModel) .def("set_kv", &fastllm::WeightMap::AddDict) .def("set_weight", &fastllm::WeightMap::AddWeight) .def("__getitem__", [](fastllm::WeightMap &weight, std::string key){ - return weight[key]; }); + return weight[key]; + }) + .def("keys", [](fastllm::WeightMap &weight){ + std::vector keys; + for (auto iter:weight.weight){ + keys.push_back(iter.first); + } + return keys; + }); // model classes From e161065f53c9d60f66a2d21619ba3014ac9852f6 Mon Sep 17 00:00:00 2001 From: wildkid1024 Date: Wed, 13 Mar 2024 18:06:07 +0800 Subject: [PATCH 2/2] =?UTF-8?q?python=E7=AB=AF=E5=AE=8C=E6=95=B4=E6=94=AF?= =?UTF-8?q?=E6=8C=81chatglm2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyfastllm/README.md | 14 + pyfastllm/examples/test_chatglm2.py | 31 ++ pyfastllm/examples/test_chatglm2_cpp.py | 2 +- .../{chatglm2.py => test_chatglm2_func.py} | 44 +-- pyfastllm/examples/test_ops.py | 31 +- pyfastllm/fastllm/functions/__init__.py | 3 +- pyfastllm/fastllm/functions/fastllm_ops.py | 22 +- pyfastllm/fastllm/functions/numpy_ops.py | 39 ++- pyfastllm/fastllm/functions/util.py | 18 + pyfastllm/fastllm/hub/__init__.py | 0 pyfastllm/fastllm/hub/chatglm2.py | 330 ++++++++++++++++++ pyfastllm/fastllm/nn/__init__.py | 2 +- pyfastllm/fastllm/nn/base_module.py | 50 ++- pyfastllm/fastllm/nn/modules.py | 71 ++-- src/pybinding.cpp | 320 ++++++++++++----- test/ops/cppOps.cpp | 4 + 16 files changed, 823 insertions(+), 158 deletions(-) create mode 100644 pyfastllm/examples/test_chatglm2.py rename pyfastllm/examples/{chatglm2.py => test_chatglm2_func.py} (86%) create mode 100644 pyfastllm/fastllm/functions/util.py create mode 100644 pyfastllm/fastllm/hub/__init__.py create mode 100644 pyfastllm/fastllm/hub/chatglm2.py diff --git a/pyfastllm/README.md b/pyfastllm/README.md index 283c3d8f..49e3aa9c 100644 --- a/pyfastllm/README.md +++ b/pyfastllm/README.md @@ -13,6 +13,20 @@ pyfastllm是基于fastllm的python api接口实现,通过pyfastllm可以更加 ## 版本更新 + +### 已知BUG +1. 从cpp到python存在内存拷贝 +2. 由于1的问题,fastllm后端采用的深拷贝策略,cuda data将被忽略 +3. 
每个op都将转化为Host端返回,GPU内存释放存在问题 + + +### v0.2.1.1 2024-03-13 +- 解决了numpy转换的一些bug +- 增加了一些Module +- 增加了op测试 +- 完整测试并支持chatglm2 + + ### v0.2.1 2024-03-08 - 增加了低级python接口 - 测试低级接口,实现了纯python版本的chatglm2 diff --git a/pyfastllm/examples/test_chatglm2.py b/pyfastllm/examples/test_chatglm2.py new file mode 100644 index 00000000..d9d41a10 --- /dev/null +++ b/pyfastllm/examples/test_chatglm2.py @@ -0,0 +1,31 @@ +import fastllm +from fastllm.hub.chatglm2 import ChatGLM2, ChatGLMConfig +import fastllm.functions as ops + +from transformers import AutoTokenizer + +def load_weights(): + file = "/home/pan/Public/Models/models-flm/chatglm2-6b.flm" + state_dict = ops.load(file) + return state_dict + +def run(): + # fastllm.set_device_map({"cuda:0": 28}) + state_dict = load_weights() + cfg = ChatGLMConfig() + model = ChatGLM2(cfg) + model.set_weights(state_dict) + print("model loaded!!!") + + model_path = "/home/pan/Public/Models/models-hf/chatglm2-6b" + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + # model.warmup() + res = "" + for output in model.stream_chat(query="飞机为什么会飞", tokenizer=tokenizer): + res = output + + print("最终问答", res) + +if __name__ == "__main__": + run() + diff --git a/pyfastllm/examples/test_chatglm2_cpp.py b/pyfastllm/examples/test_chatglm2_cpp.py index c1eb657a..908b3680 100644 --- a/pyfastllm/examples/test_chatglm2_cpp.py +++ b/pyfastllm/examples/test_chatglm2_cpp.py @@ -20,7 +20,7 @@ def to_numpy(data): ## 模型测试 def load_weights(): file = "/home/pan/Public/Models/models-flm/chatglm2-6b.flm" - state_dict = fastllm.load(file) + state_dict = ops.load(file) # print(state_dict.keys()) return state_dict diff --git a/pyfastllm/examples/chatglm2.py b/pyfastllm/examples/test_chatglm2_func.py similarity index 86% rename from pyfastllm/examples/chatglm2.py rename to pyfastllm/examples/test_chatglm2_func.py index 0638c9e8..b1e67767 100644 --- a/pyfastllm/examples/chatglm2.py +++ b/pyfastllm/examples/test_chatglm2_func.py @@ -36,7 +36,7 @@ def to_numpy(data): def load_weights(): file = "/home/pan/Public/Models/models-flm/chatglm2-6b.flm" - state_dict = fastllm.load(file) + state_dict = ops.load(file) # print(state_dict.keys()) return state_dict @@ -80,8 +80,8 @@ def core_attention(q, k, v, attn_mask, pastkv): seq_len, batch, num_attention_heads, attn_dim = q.shape embed_dim = num_attention_heads * attn_dim - k.reshape([k.shape[0], k.shape[1] * k.shape[2], k.shape[3]]) - v.reshape([v.shape[0], v.shape[1] * v.shape[2], v.shape[3]]) + k.reshape([k.size(0), k.size(1) * k.size(2), k.size(3)]) + v.reshape([v.size(0), v.size(1) * v.size(2), v.size(3)]) k = ops.permute(k, [1, 0, 2]) v = ops.permute(v, [1, 0, 2]) @@ -91,40 +91,40 @@ def core_attention(q, k, v, attn_mask, pastkv): unitLen = 64 while ( - (len(pastKey.shape) == 0 and (len(pastKey.expansionDims) == 0 or k.shape[1] > pastKey.expansionDims[1])) - or (len(pastKey.shape) > 0 and (len(pastKey.expansionDims) == 0 or pastKey.shape[1] + k.shape[1] > pastKey.expansionDims[1])) + (len(pastKey.shape) == 0 and (len(pastKey.expansionDims) == 0 or k.size(1) > pastKey.expansionDims[1])) + or (len(pastKey.shape) > 0 and (len(pastKey.expansionDims) == 0 or pastKey.size(1) + k.size(1) > pastKey.expansionDims[1])) ): if pastKey.count(0) == 0 or len(pastKey.shape) == 0: - newDims =[k.shape[0], int(((k.shape[1] - 1) / unitLen + 1) * unitLen), k.shape[2]] + newDims =[k.size(0), int(((k.size(1) - 1) / unitLen + 1) * unitLen), k.size(2)] else: newDims = pastKey.shape - newDims[1] += int(((k.shape[1] - 1) / unitLen + 1) * unitLen) + 
newDims[1] += int(((k.size(1) - 1) / unitLen + 1) * unitLen) # print(newDims) pastKey.expansion(newDims) while ( - (len(pastValue.shape) == 0 and (len(pastValue.expansionDims) == 0 or v.shape[1] > pastValue.expansionDims[1])) - or (len(pastValue.shape) > 0 and (len(pastValue.expansionDims) == 0 or pastValue.shape[1] + v.shape[1] > pastValue.expansionDims[1])) + (len(pastValue.shape) == 0 and (len(pastValue.expansionDims) == 0 or v.size(1) > pastValue.expansionDims[1])) + or (len(pastValue.shape) > 0 and (len(pastValue.expansionDims) == 0 or pastValue.size(1) + v.size(1) > pastValue.expansionDims[1])) ): if pastValue.count(0) == 0 or len(pastValue.shape) == 0: - newDims =[v.shape[0], int(((v.shape[1] - 1) / unitLen + 1) * unitLen), v.shape[2]] + newDims =[v.size(0), int(((v.size(1) - 1) / unitLen + 1) * unitLen), v.size(2)] else: newDims = pastValue.shape - newDims[1] += int(((v.shape[1] - 1) / unitLen + 1) * unitLen) + newDims[1] += int(((v.size(1) - 1) / unitLen + 1) * unitLen) pastValue.expansion(newDims) pyfastllm.cat_direct(pastKey, k, 1) pyfastllm.cat_direct(pastValue, v, 1) - q.reshape([q.shape[0], q.shape[1] * q.shape[2], q.shape[3]]) + q.reshape([q.size(0), q.size(1) * q.size(2), q.size(3)]) q = ops.permute(q, [1, 0, 2]) - context = ops.attention(q, pastKey, pastValue, attn_mask, q.shape[0]//pastKey.shape[0], 1.0/math.sqrt(attn_dim)) + context = ops.attention(q, pastKey, pastValue, attn_mask, q.size(0)//pastKey.size(0), 1.0/math.sqrt(attn_dim)) context.reshape([batch, num_attention_heads, seq_len, -1]) context = ops.permute(context, [2, 0, 1, 3]) - context.reshape([context.dims[0], context.dims[1], embed_dim]) + context.reshape([context.size(0), context.size(1), embed_dim]) return context @@ -139,14 +139,14 @@ def transformer(hidden_states, i, attn_mask, num_attention_heads, rotary_dim, po # print("transformer qkv ok") qLen = embed_dim - kvLen = (qkv.shape[-1] - embed_dim) // 2 + kvLen = (qkv.size(-1) - embed_dim) // 2 q = ops.split(qkv, -1, 0, qLen) k = ops.split(qkv, -1, qLen, qLen + kvLen) v = ops.split(qkv, -1, qLen + kvLen, qLen + kvLen + kvLen) - q.reshape([q.shape[0], q.shape[1], -1, embed_dim // num_attention_heads]) - k.reshape([k.shape[0], k.shape[1], -1, embed_dim // num_attention_heads]) - v.reshape([v.shape[0], v.shape[1], -1, embed_dim // num_attention_heads]) + q.reshape([q.size(0), q.size(1), -1, embed_dim // num_attention_heads]) + k.reshape([k.size(0), k.size(1), -1, embed_dim // num_attention_heads]) + v.reshape([v.size(0), v.size(1), -1, embed_dim // num_attention_heads]) q = pyfastllm.nearlyrotateposition2D(q, pos_id, sin_data, cos_data, rotary_dim) k = pyfastllm.nearlyrotateposition2D(k, pos_id, sin_data, cos_data, rotary_dim) @@ -181,14 +181,15 @@ def forward( pos_id, pastkvs ): - batch = input_ids.shape[0] - seq_len = input_ids.shape[1] + batch = input_ids.size(0) + seq_len = input_ids.size(1) input_ids = ops.permute(input_ids, [1, 0]) input_embedding = ops.embedding(inputs=input_ids, embedding_weights=state_dict['transformer.embedding.word_embeddings.weight']) hidden_states = input_embedding - # print("embedding ok") + print("embedding ok") + print(hidden_states) rotary_dim = 64 layer_num = 28 @@ -217,6 +218,7 @@ def forward( # print("topk ok") topk.to("cpu") + print(topk) topk_np = np.array(topk, copy=False) token = int(topk_np[0, 0, 0] + 1e-3) return token, pastkvs diff --git a/pyfastllm/examples/test_ops.py b/pyfastllm/examples/test_ops.py index 8b862ac5..d12e09e0 100644 --- a/pyfastllm/examples/test_ops.py +++ b/pyfastllm/examples/test_ops.py @@ -6,15 
+6,17 @@ import pyfastllm import gc -# import np_ops -# import ops as flm_ops +import np_ops +import ops as flm_ops -from fastllm import ops as flm_ops -from fastllm import np_ops +# from fastllm import ops as flm_ops +# from fastllm import np_ops np.random.seed(42) def diff(dataA, dataB): + # print(dataA) + # print(dataB) mae = np.max(np.abs(dataA - dataB)) print('max abs err is ', mae) return mae @@ -23,6 +25,7 @@ def to_tensor(data): return pyfastllm.from_numpy(data) def to_numpy(data): + # return data.numpy() return np.array(data, copy=False, order='C') def test_rms_norm(inputs=None, weights=None, eps=1e-6): @@ -38,19 +41,19 @@ def test_rms_norm(inputs=None, weights=None, eps=1e-6): def test_swiglu(inputs=None): if not inputs: - inputs = np.array([1, 5]).reshape([1, 2]) + inputs = np.random.random(size=[1, 256]) np_out = np_ops.swiglu(inputs) - out = flm_ops.activation(input=to_tensor(inputs), activate_type="swiglu") - mae = diff(np_out, out) + out = flm_ops.activation(inputs=to_tensor(inputs), activate_type="swiglu") + mae = diff(np_out, to_numpy(out)) assert mae <= 1e-6 return out def test_attention(q=None, k=None, v=None, mask=None, group=1, scale=1.0): if q is None: - q = np.random.random(size=[1, 12, 4096]) - k = np.random.random(size=[1, 12, 4096]) - v = np.random.random(size=[1, 12, 4096]) + q = np.random.random(size=[12, 1, 4096]) + k = np.random.random(size=[12, 1, 4096]) + v = np.random.random(size=[12, 1, 4096]) scale = 1 / np.sqrt(q.shape[-1]) np_out = np_ops.attention(q, k, v, scale=scale) @@ -68,19 +71,21 @@ def test_linear(inputs=None, bias=None): if not inputs: - inputs = np.random.random(size=[1, 2, 4096]) + inputs = np.random.random(size=[1, 12, 4096]) + weights = np.random.random(size=[256, 4096]) np_out = np_ops.linear(inputs=inputs, weights=weights, bias=None) if not bias: bias = fastllm.Tensor() - output = flm_ops.linear(to_tensor(inputs), weights, bias) + output = flm_ops.linear(to_tensor(inputs), to_tensor(weights), bias) mae = diff(np_out, to_numpy(output)) - assert mae <= 1e-6 + assert mae <= 1e-3 return output + if __name__ == "__main__": test_rms_norm() test_attention() diff --git a/pyfastllm/fastllm/functions/__init__.py b/pyfastllm/fastllm/functions/__init__.py index 8b137891..827ceb47 100644 --- a/pyfastllm/fastllm/functions/__init__.py +++ b/pyfastllm/fastllm/functions/__init__.py @@ -1 +1,2 @@ - +from .fastllm_ops import * +from . 
import util \ No newline at end of file diff --git a/pyfastllm/fastllm/functions/fastllm_ops.py b/pyfastllm/fastllm/functions/fastllm_ops.py index 08ed7bb5..022d2205 100644 --- a/pyfastllm/fastllm/functions/fastllm_ops.py +++ b/pyfastllm/fastllm/functions/fastllm_ops.py @@ -58,12 +58,15 @@ def activation(inputs: pyfastllm.Tensor, axis=-1, activate_type="silu"): func(inputs, output) return output +def cat_(inputs, cur_data, axis=1): + pyfastllm.cat_direct(inputs, cur_data, axis) + def mul(inputs: pyfastllm.Tensor, v: int): output = pyfastllm.Tensor() pyfastllm.mul(inputs, v, output) return output -def add(input0: pyfastllm.Tensor, input1: pyfastllm.Tensor, v: int): +def add(input0: pyfastllm.Tensor, input1: pyfastllm.Tensor, v:int=1.0): output = pyfastllm.Tensor() output = pyfastllm.add(input0, input1, v) return output @@ -71,6 +74,7 @@ def add(input0: pyfastllm.Tensor, input1: pyfastllm.Tensor, v: int): def permute(inputs: pyfastllm.Tensor, dims=None): output = pyfastllm.Tensor() pyfastllm.permute(inputs, dims, output) + # pyfastllm.permute_(inputs, dims) return output def split(inputs: pyfastllm.Tensor, axis:int, start:int, end:int): @@ -83,20 +87,22 @@ def topk(logits:pyfastllm.Tensor, axis:int = 1): pyfastllm.topk(logits, axis, output) return output +def load(filepath): + state_dict = pyfastllm.WeightMap() + state_dict.load(filepath) + return state_dict + def AttentionMask(): pass def AlibiMask(): pass -def topk(): - pass - -def RotatePosition2D(): - pass +def RotatePosition2D(data, pos_id, sin_data, cos_data, rotary_dim): + return pyfastllm.rotateposition2D(data, pos_id, sin_data, cos_data, rotary_dim) -def NearlyRotatePosition2D(): - pass +def NearlyRotatePosition2D(data, pos_id, sin_data, cos_data, rotary_dim): + return pyfastllm.nearlyrotateposition2D(data, pos_id, sin_data, cos_data, rotary_dim) def LlamaRotatePosition2D(): pass diff --git a/pyfastllm/fastllm/functions/numpy_ops.py b/pyfastllm/fastllm/functions/numpy_ops.py index 3fb36ba7..f2b5372d 100644 --- a/pyfastllm/fastllm/functions/numpy_ops.py +++ b/pyfastllm/fastllm/functions/numpy_ops.py @@ -18,7 +18,7 @@ def layer_norm(inputs, gamma, beta, axis=-1): output = (inputs - mean) / var * gamma + beta return output -@jit(nopython=True) +# @jit def softmax(inputs, axis=None): maxv = inputs.max(axis, keepdims=True) exp_v = np.exp(inputs - maxv) @@ -29,13 +29,38 @@ def softmax(inputs, axis=None): def silu(inputs, ): return inputs / (1 + np.exp(-inputs)) -@jit(nopython=True) +@jit +def swiglu(inputs, ): + dim = inputs.shape[1] // 2 + for batch in range(inputs.shape[0]): + return inputs[batch, :dim] / (1 + np.exp(-inputs[batch, :dim])) * inputs[batch, dim:] + +# @jit def linear(inputs, weights, bias): - output = np.matmul(inputs, weights.T) + bias + if len(inputs.shape) == 2: + inputs = inputs[None, :] + weights = weights[None, :] + + output = np.zeros(shape=[inputs.shape[0], inputs.shape[1], weights.shape[0]]) + for batch in range(inputs.shape[0]): + output[batch] = np.matmul(inputs[batch], weights.T) + + if bias: + output[batch] += bias[batch] + return output -@jit(nopython=True) -def np_self_attention(q, k, v, mask=None, group=None, scale=None): - qk = softmax(q @ k.T * scale, axis=-1) - attn = qk @ v +# @jit +def attention(q, k, v, mask=None, group=None, scale=None): + print("shape:", q.shape) + if len(q.shape) == 2: + q = q[None, :] + k = k[None, :] + v = v[None, :] + # mask = mask[None, :] + + attn = np.zeros_like(q) + for batch in range(q.shape[0]): + qk = softmax(q[batch] @ k[batch].T * scale, axis=-1) + attn[batch, :, :] = 
diff --git a/pyfastllm/fastllm/functions/util.py b/pyfastllm/fastllm/functions/util.py
new file mode 100644
index 00000000..7c732bf3
--- /dev/null
+++ b/pyfastllm/fastllm/functions/util.py
@@ -0,0 +1,18 @@
+import numpy as np
+import pyfastllm
+
+def diff(dataA, dataB):
+    mae = np.max(np.abs(dataA - dataB))
+    print('max abs err is ', mae)
+    return mae
+
+def to_tensor(data):
+    if not isinstance(data, np.ndarray):
+        return None
+    return pyfastllm.from_numpy(data)
+
+def to_numpy(data):
+    if not isinstance(data, pyfastllm.Tensor):
+        return None
+
+    return np.array(data, copy=False)
\ No newline at end of file
diff --git a/pyfastllm/fastllm/hub/__init__.py b/pyfastllm/fastllm/hub/__init__.py
new file mode 100644
index 00000000..e69de29b
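util.py gives the Python layer a single home for numpy/Tensor conversion; to_numpy builds its array with copy=False, so the result is a view over the tensor's CPU buffer rather than a copy. A short round-trip sketch (values are arbitrary):

import numpy as np
from fastllm.functions import util

a = np.arange(6, dtype=np.float32).reshape(2, 3)
t = util.to_tensor(a)   # shallow copy via pyfastllm.from_numpy
b = util.to_numpy(t)    # zero-copy view over the tensor's buffer
util.diff(a, b)         # prints: max abs err is  0.0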
diff --git a/pyfastllm/fastllm/hub/chatglm2.py b/pyfastllm/fastllm/hub/chatglm2.py
new file mode 100644
index 00000000..a76b02a4
--- /dev/null
+++ b/pyfastllm/fastllm/hub/chatglm2.py
@@ -0,0 +1,330 @@
+import sys
+import numpy as np
+import fastllm
+
+import pyfastllm
+np.random.seed(42)
+
+from fastllm import ops
+from fastllm.nn import Module, Linear, SwiGLU, NearlyRoPE, RMSNorm, Embedding
+from typing import List, Tuple
+import math
+
+class ChatGLMConfig():
+    model_type = "chatglm"
+    def __init__(
+        self,
+        num_layers=28,
+        padded_vocab_size=65024,
+        hidden_size=4096,
+        ffn_hidden_size=13696,
+        kv_channels=128,
+        num_attention_heads=32,
+        seq_length=2048,
+        hidden_dropout=0.0,
+        attention_dropout=0.0,
+        layernorm_epsilon=1e-5,
+        rmsnorm=True,
+        apply_residual_connection_post_layernorm=False,
+        post_layer_norm=True,
+        add_bias_linear=False,
+        add_qkv_bias=False,
+        interleaved_qkv=False,
+        bias_dropout_fusion=True,
+        multi_query_attention=False,
+        multi_query_group_num=1,
+        apply_query_key_layer_scaling=True,
+        attention_softmax_in_fp32=True,
+        fp32_residual_connection=False,
+        quantization_bit=0,
+        pre_seq_len=None,
+        prefix_projection=False,
+        **kwargs
+    ):
+        self.num_layers = num_layers
+        self.vocab_size = padded_vocab_size
+        self.padded_vocab_size = padded_vocab_size
+        self.hidden_size = hidden_size
+        self.ffn_hidden_size = ffn_hidden_size
+        self.kv_channels = kv_channels
+        self.num_attention_heads = num_attention_heads
+        self.seq_length = seq_length
+        self.hidden_dropout = hidden_dropout
+        self.attention_dropout = attention_dropout
+        self.layernorm_epsilon = layernorm_epsilon
+        self.rmsnorm = rmsnorm
+        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
+        self.post_layer_norm = post_layer_norm
+        self.add_bias_linear = add_bias_linear
+        self.add_qkv_bias = add_qkv_bias
+        self.bias_dropout_fusion = bias_dropout_fusion
+        self.multi_query_attention = multi_query_attention
+        self.multi_query_group_num = multi_query_group_num
+        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+        self.fp32_residual_connection = fp32_residual_connection
+        self.quantization_bit = quantization_bit
+        self.pre_seq_len = pre_seq_len
+        self.prefix_projection = prefix_projection
+        super().__init__()
+
+
+class MLP(Module):
+    def __init__(self, config:ChatGLMConfig) -> None:
+        super().__init__()
+        self.rms = RMSNorm()
+        self.dense_h_to_4h = Linear(config.hidden_size, config.ffn_hidden_size * 2)
+        self.dense_4h_to_h = Linear(config.ffn_hidden_size, config.hidden_size)
+        self.act = SwiGLU()
+
+    def forward(self, inputs):
+        outputs = self.rms(inputs)
+        # print("start dense")
+        outputs = self.dense_4h_to_h(self.act(self.dense_h_to_4h(outputs)))
+        # print("start add")
+        outputs = ops.add(outputs, inputs)
+        return outputs
+
+class CoreAttn(Module):
+    def __init__(self,) -> None:
+        super().__init__()
+        # self.embed_dim = config.hidden_size
+
+    def _expand_dims(self, past_data, cur_data, unitLen=64):
+        while (
+            (len(past_data.size()) == 0 and (len(past_data.expansionDims) == 0 or cur_data.size(1) > past_data.expansionDims[1]))
+            or (len(past_data.size()) > 0 and (len(past_data.expansionDims) == 0 or past_data.size(1) + cur_data.size(1) > past_data.expansionDims[1]))
+        ):
+            if past_data.count(0) == 0 or len(past_data.size()) == 0:
+                newDims = [cur_data.size(0), int(((cur_data.size(1) - 1) / unitLen + 1) * unitLen), cur_data.size(2)]
+            else:
+                newDims = past_data.size()
+                newDims[1] += int(((cur_data.size(1) - 1) / unitLen + 1) * unitLen)
+
+            # print(newDims)
+            past_data.expansion(newDims)
+
+        ops.cat_(past_data, cur_data, 1)
+
+    def forward(self, q, k, v, attn_mask, pastkv):
+        seq_len, batch, num_attention_heads, attn_dim = q.size()
+        embed_dim = num_attention_heads * attn_dim
+
+        k.reshape([k.size(0), k.size(1) * k.size(2), k.size(3)])
+        v.reshape([v.size(0), v.size(1) * v.size(2), v.size(3)])
+
+        k = ops.permute(k, [1, 0, 2])
+        v = ops.permute(v, [1, 0, 2])
+
+        pastKey = pastkv[0]
+        pastValue = pastkv[1]
+        self._expand_dims(past_data=pastKey, cur_data=k)
+        self._expand_dims(past_data=pastValue, cur_data=v)
+
+        q.reshape([q.size(0), q.size(1) * q.size(2), q.size(3)])
+        q = ops.permute(q, [1, 0, 2])
+        context = ops.attention(q, pastKey, pastValue, attn_mask, group=q.size(0)//pastKey.size(0), scale=1.0/math.sqrt(attn_dim))
+        context.reshape([batch, num_attention_heads, seq_len, -1])
+        context = ops.permute(context, [2, 0, 1, 3])
+        context.reshape([context.size(0), context.size(1), embed_dim])
+        return context
+
+class Transformer(Module):
+    def __init__(self, config:ChatGLMConfig) -> None:
+        super().__init__()
+        # print("start building transformer")
+        self.config = config
+        self.rms = RMSNorm(dim=config.hidden_size)
+        self.qkv = Linear(in_dim=config.hidden_size, out_dim=4608, bias=True)
+        self.rope = NearlyRoPE()
+        self.attn = CoreAttn()
+        # print("finished building core attn")
+        self.post_linear = Linear(in_dim=config.hidden_size, out_dim=config.hidden_size)
+        self.mlp = MLP(config=config)
+
+    def _split_qkv(self, qkv):
+        embed_dim = self.config.hidden_size
+        num_attention_heads = self.config.num_attention_heads
+
+        qLen = embed_dim
+        kvLen = (qkv.size(-1) - embed_dim) // 2
+        q = ops.split(qkv, -1, 0, qLen)
+        k = ops.split(qkv, -1, qLen, qLen + kvLen)
+        v = ops.split(qkv, -1, qLen + kvLen, qLen + kvLen + kvLen)
+
+        q.reshape([q.size(0), q.size(1), -1, embed_dim // num_attention_heads])
+        k.reshape([k.size(0), k.size(1), -1, embed_dim // num_attention_heads])
+        v.reshape([v.size(0), v.size(1), -1, embed_dim // num_attention_heads])
+
+        return (q, k, v)
+
+    def forward(self, inputs, pos_id, attn_mask, pastkv):
+        # print("computing rms norm")
+        atten_input = self.rms(inputs)
+        # print("computing qkv")
+        qkv = self.qkv(atten_input)
+        # print("splitting qkv")
+        q, k, v = self._split_qkv(qkv)
+
+        q = self.rope(q, pos_id)
+        k = self.rope(k, pos_id)
+
+        context = self.attn(q, k, v, attn_mask, pastkv)
+        outputs = self.post_linear(context)
+        outputs = ops.add(outputs, inputs)  # TODO: implement Tensor +=
+
+        # print("running mlp")
+        outputs = self.mlp(outputs)
+
+        return outputs
+
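+# Note on the cache growth above: _expand_dims reserves KV-cache capacity in
+# whole chunks of unitLen (64) positions so that ops.cat_ can append in place
+# instead of reallocating on every decode step. The expression
+#     int(((n - 1) / unitLen + 1) * unitLen)
+# rounds n up to the next multiple of 64: n=1 -> 64, n=64 -> 64, n=65 -> 128.
+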
+class ChatGLM2(Module):
+    def __init__(self, config: ChatGLMConfig) -> None:
+        super().__init__()
+        # print("initializing model")
+        self.config = config
+        self.num_layers = config.num_layers
+        self.rotary_dim = 64
+        self.num_attention_heads = config.num_attention_heads
+        self.embed_dim = config.hidden_size
+        self.head_dim = self.embed_dim // self.num_attention_heads
+        scale_attn = math.sqrt(self.head_dim)
+
+        # print("building embedding")
+        self.embedding = Embedding(vocab_size=65024, embed_dim=4096)
+        self.decoder = [Transformer(config) for i in range(self.num_layers)]
+        # print("finished building decoder")
+        self.rms = RMSNorm(eps=1e-5)
+        self.head = Linear(config.hidden_size, config.vocab_size)
+
+    def _get_position_id(self, seq_len):
+        pos_id = np.zeros(shape=[2, seq_len])
+        pos_id[0, :] = np.arange(0, seq_len)
+        pos_id[1, -1] = 1
+        return pos_id
+
+    def _get_mask(self, seq_len):
+        attn_mask = np.zeros(shape=[seq_len, seq_len])
+        attn_mask[:, -1] = 1
+        for i in range(seq_len):
+            for j in range(i+1, seq_len):
+                attn_mask[i, j] = 1
+
+        return attn_mask
+
+    def forward(
+        self,
+        input_ids,
+        attn_mask,
+        pos_id,
+        pastkvs
+    ):
+        batch = input_ids.size(0)
+        seq_len = input_ids.size(1)
+        input_ids = ops.permute(input_ids, [1, 0])
+        input_embedding = self.embedding(inputs=input_ids)
+        hidden_states = input_embedding
+
+        # hidden_states.to("cuda")
+        # print(hidden_states)
+
+        for i in range(self.num_layers):
+            hidden_states = self.decoder[i].forward(hidden_states, pos_id, attn_mask, pastkv=pastkvs[i])
+
+        if seq_len > 1:
+            hidden_states = ops.split(hidden_states, 0, seq_len - 1, seq_len)
+
+        hidden_states = self.rms(hidden_states)
+        logits = self.head(hidden_states)
+
+        topk = ops.topk(logits, 1)
+        topk.to("cpu")
+        topk_np = ops.util.to_numpy(topk)
+        token = int(topk_np[0, 0, 0] + 1e-3)
+        return token, pastkvs
+
+
+    def set_weights(model, state_dict=None):
+        # state_dict = load_weights()
+        # print("weights loaded")
+        model.embedding.weights.value = state_dict[f"transformer.embedding.word_embeddings.weight"]
+        model.head.weights.value = state_dict[f"transformer.output_layer.weight"]
+        model.rms.weights.value = state_dict[f"transformer.encoder.final_layernorm.weight"]
+
+        for i in range(model.num_layers):
+            model.decoder[i].rms.weights.value = state_dict[f"transformer.encoder.layers.{i}.input_layernorm.weight"]
+            model.decoder[i].qkv.weights.value = state_dict[f"transformer.encoder.layers.{i}.self_attention.query_key_value.weight"]
+            model.decoder[i].qkv.bias.value = state_dict[f"transformer.encoder.layers.{i}.self_attention.query_key_value.bias"]
+            model.decoder[i].post_linear.weights.value = state_dict[f"transformer.encoder.layers.{i}.self_attention.dense.weight"]
+            model.decoder[i].mlp.rms.weights.value = state_dict[f"transformer.encoder.layers.{i}.post_attention_layernorm.weight"]
+            model.decoder[i].mlp.dense_h_to_4h.weights.value = state_dict[f"transformer.encoder.layers.{i}.mlp.dense_h_to_4h.weight"]
+            model.decoder[i].mlp.dense_4h_to_h.weights.value = state_dict[f"transformer.encoder.layers.{i}.mlp.dense_4h_to_h.weight"]
+
+
+    def warmup(model):
+        bos_token_id = 64792
+        input_ids = pyfastllm.Tensor(fastllm.float32, [1, 1], [bos_token_id, ])
+        attn_mask = pyfastllm.Tensor(fastllm.float32, [1, 1], [0])
+        pos_id = pyfastllm.Tensor(fastllm.float32, [2, 1], [0, 0])
+
+        pastKeyValues = []
+        for i in range(28):
+            pastKey = pyfastllm.Tensor(fastllm.float32)
+            pastValue = pyfastllm.Tensor(fastllm.float32)
+            pastKeyValues.append([pastKey, pastValue])
+
+        model.forward(input_ids, attn_mask, pos_id, pastKeyValues)
+
+    def build_inputs(model, tokenizer, query: str, history: List[Tuple[str, str]] = None):
+        prompt = tokenizer.build_prompt(query, history=history)
+        inputs = tokenizer([prompt], return_tensors="np")
+        return inputs
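+
+    # Decode-loop convention used by stream_chat below: pos_id carries two rows
+    # (absolute position, and the block/generation flag from _get_position_id),
+    # the mask marks disallowed attention pairs with 1, and token id 2 is
+    # treated as end-of-sequence. After prefill, the mask is left empty and
+    # pos_id advances by one per generated token.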
+    def stream_chat(model, query="", tokenizer=None):
+        # query = "Hello"
+        # model_path = "/home/pan/Public/Models/models-hf/chatglm2-6b"
+        # tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        input_ids = model.build_inputs(tokenizer, query=query)['input_ids']
+
+        batch = input_ids.shape[0]
+        seq_len = input_ids.shape[1]
+        input_ids = ops.util.to_tensor(input_ids)
+        pos_id = ops.util.to_tensor(model._get_position_id(seq_len))
+        attn_mask = ops.util.to_tensor(model._get_mask(seq_len))
+
+        pastKeyValues = []
+        for i in range(28):
+            pastKey = pyfastllm.Tensor(fastllm.float32)
+            pastValue = pyfastllm.Tensor(fastllm.float32)
+            pastKeyValues.append([pastKey, pastValue])
+
+        index = 0
+        promptLen = seq_len - 2
+
+        results = []
+        while True:
+            token, pastKeyValues = model(input_ids, attn_mask, pos_id, pastKeyValues)
+
+            if token == 2:
+                break
+
+            results.append(token)
+            ret = tokenizer.decode(results)
+            yield ret
+
+            index += 1
+
+            if index >= 2048:
+                break
+
+            input_ids.copy_from(fastllm.Tensor(fastllm.float32, [1, 1], [token]))
+            attn_mask = fastllm.Tensor(fastllm.float32)
+            pos_id.copy_from(fastllm.Tensor(fastllm.float32, [2, 1], [promptLen + index + 1, (index + 1)]))
+
+
+
diff --git a/pyfastllm/fastllm/nn/__init__.py b/pyfastllm/fastllm/nn/__init__.py
index b886302b..f7cfc9be 100644
--- a/pyfastllm/fastllm/nn/__init__.py
+++ b/pyfastllm/fastllm/nn/__init__.py
@@ -1,2 +1,2 @@
 from .base_module import Module
-from .modules import Linear
+from .modules import *
diff --git a/pyfastllm/fastllm/nn/base_module.py b/pyfastllm/fastllm/nn/base_module.py
index f391290a..467731e6 100644
--- a/pyfastllm/fastllm/nn/base_module.py
+++ b/pyfastllm/fastllm/nn/base_module.py
@@ -7,7 +7,7 @@ def __init__(self) -> None:
         pass
 
     def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **args)
+        return self.forward(*args, **kwds)
 
     @abstractmethod
     def forward(self, ):
@@ -15,3 +15,51 @@ def forward(self, ):
 
     def _init_weight(self, ):
         pass
+
+import numpy as np
+import pyfastllm
+from typing import Union, Sequence
+from pyfastllm import Tensor
+from ..functions import util
+
+class Parameter(object):
+    _DEFAULT_DTYPE = pyfastllm.float32
+
+    def __init__(self,
+                 value: Union[np.ndarray, None] = None,
+                 shape: Sequence[int] = None,
+                 dtype: Union[pyfastllm.DataType, None] = None):
+        dtype = self._DEFAULT_DTYPE if dtype is None else dtype
+        if value is None:
+            assert isinstance(shape, (list, tuple))
+            self._value = pyfastllm.Tensor(dtype)
+            """
+            value = np.zeros(shape=shape, dtype=np.float32)
+
+            if len(shape) == 2:
+                v_range = np.sqrt(6) / np.sqrt(shape[0] + shape[1])
+            else:
+                v_range = 0.1
+
+            # value ~ U[-1, 1]
+            value = np.random.random(size=shape) * 2 - 1
+            value = np.array(value, dtype=np.float32)
+            # value ~ U[-v_range, v_range]
+            value *= v_range
+            """
+        else:
+            self._value = util.to_tensor(value)
+
+    @property
+    def value(self) -> Tensor:
+        if isinstance(self._value, np.ndarray):
+            self._value = util.to_tensor(self._value)
+
+        return self._value
+
+    @value.setter
+    def value(self, v: np.ndarray):
+        assert isinstance(v, np.ndarray) or isinstance(v, pyfastllm.Tensor)
+        # assert v.shape == self._value.shape, \
+        #     ('The value updated is not the same shape as the original. ', \
+        #      f'Updated: {v.shape}, original: {self._value.shape}')
+        self._value = v
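Parameter lets a module be constructed before its weights exist: an empty pyfastllm.Tensor (or a numpy array) is held until real values are assigned, and the .value property promotes arrays to Tensors lazily. A minimal sketch, with random arrays standing in for real checkpoint weights:

import numpy as np
from fastllm.nn import Linear

fc = Linear(in_dim=4, out_dim=8, bias=True)
fc.weights.value = np.random.rand(8, 4).astype(np.float32)  # (out_dim, in_dim)
fc.bias.value = np.random.rand(8).astype(np.float32)
# the arrays are converted to pyfastllm Tensors on first access to .value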
diff --git a/pyfastllm/fastllm/nn/modules.py b/pyfastllm/fastllm/nn/modules.py
index 689e5fc8..02bdae11 100644
--- a/pyfastllm/fastllm/nn/modules.py
+++ b/pyfastllm/fastllm/nn/modules.py
@@ -1,14 +1,25 @@
 from .base_module import Module
+from .base_module import Parameter
 from ..functions import fastllm_ops as F
+from ..functions import util
+import numpy as np
 
 class Linear(Module):
-    def __init__(self, input_size, output_size, bias=False) -> None:
-        self.weight = None
+    def __init__(self, in_dim, out_dim, bias=False) -> None:
+        self.has_bias = bias
+        self.weights = Parameter(shape=(out_dim, in_dim))
         self.bias = None
+
+        if bias:
+            self.bias = Parameter(shape=(out_dim, ))
+
         super().__init__()
 
     def forward(self, x):
-        return F.linear(x, self.weight, self.bias)
+        if self.has_bias:
+            return F.linear(x, self.weights.value, self.bias.value)
+
+        return F.linear(x, self.weights.value)
 
 class SiLU(Module):
     def __init__(self) -> None:
@@ -22,42 +33,56 @@ def __init__(self) -> None:
         super().__init__()
 
     def forward(self, inputs):
-        return F.activation(input=inputs, activate_type="swiglu")
-
-class Embedding(Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self):
-        return F.embedding()
+        return F.activation(inputs=inputs, activate_type="swiglu")
 
 class Embedding(Module):
     def __init__(self, vocab_size, embed_dim) -> None:
         super().__init__()
         self.vocab_size = vocab_size
         self.embed_dim = embed_dim
-        self.embedding_weights = None
-
-        # def _init_weight(self):
-        #     self.embedding_weights = to_tensor(np.random.random(size=[self.vocab_size, self.embed_dim]))
+        self.weights = Parameter(shape=[vocab_size, embed_dim])
 
     def forward(self, inputs):
-        return F.embedding(inputs, self.embedding_weights)
+        return F.embedding(inputs, self.weights.value)
 
 class RMSNorm(Module):
-    def __init__(self) -> None:
+    def __init__(self, dim=4096, eps=1e-5) -> None:
         super().__init__()
-        self.weights = None
-
-    def _init_weight(self):
-        return super()._init_weight()
+        self.weights = Parameter(shape=[dim, ])
+        self.eps = eps
 
     def forward(self, inputs):
-        return F.rms_norm(inputs, self.weights, eps=self.eps)
+        return F.rms_norm(inputs, self.weights.value, eps=self.eps)
 
 class Attention(Module):
     def __init__(self) -> None:
         super().__init__()
 
     def forward(self, q, k, v, mask, group, scale):
-        return F.attention(q, k, v, mask, group=group, scale=scale, attentionType=0)
\ No newline at end of file
+        return F.attention(q, k, v, mask, group=group, scale=scale, attentionType=0)
+
+class RoPE(Module):
+    def __init__(self, rotary_dim=128) -> None:
+        super().__init__()
+        self.rotary_dim = rotary_dim
+        self.sin_data, self.cos_data = self._get_sin_cos_data()
+        self.sin_data = util.to_tensor(self.sin_data)
+        self.cos_data = util.to_tensor(self.cos_data)
+
+    def _get_sin_cos_data(self, base=1e4, seq_len=32768, dim=128):
+        inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
+        t = np.arange(0, seq_len)
+        freqs = np.einsum('i,j->ij', t, inv_freq)
+        emb = np.concatenate((freqs, freqs), axis=-1)
+        return np.sin(emb), np.cos(emb)
+
+    def forward(self, data, pos_id):
+        return F.RotatePosition2D(data, pos_id, self.sin_data, self.cos_data, self.rotary_dim)
+
+class NearlyRoPE(RoPE):
+    def __init__(self, rotary_dim=64) -> None:
+        super().__init__(rotary_dim)
+
+    def forward(self, data, pos_id):
+        outputs = F.NearlyRotatePosition2D(data, pos_id, self.sin_data, self.cos_data, self.rotary_dim)
+        return outputs
\ No newline at end of file
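Together with Parameter, these modules form a small torch-like layer over fastllm_ops: weights live in Parameters, and forward simply chains F.* calls. A sketch of ChatGLM2's MLP wiring built from the pieces above (sizes follow ChatGLMConfig's defaults):

from fastllm.nn import Linear, SwiGLU

hidden, ffn = 4096, 13696
dense_h_to_4h = Linear(hidden, ffn * 2)   # fused gate + up projection
act = SwiGLU()                            # halves the last dimension
dense_4h_to_h = Linear(ffn, hidden)

def mlp(x):
    return dense_4h_to_h(act(dense_h_to_4h(x)))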
diff --git a/src/pybinding.cpp b/src/pybinding.cpp
index aace6155..4e214e4a 100644
--- a/src/pybinding.cpp
+++ b/src/pybinding.cpp
@@ -3,9 +3,16 @@
 #include
 #include
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+#include <pybind11/stl.h>
+#include <pybind11/functional.h>
+namespace py = pybind11;
+using namespace pybind11::literals;
 
 namespace pyfastllm{
+  // TODO: bug: GPU memory is never released
   // adapts to the continually changing backend interface
   // needs optimization: reduce memory copies
@@ -18,111 +25,112 @@ namespace pyfastllm{
     // fastllm::Data output;
     // std::cout<<"run rms norm"<<std::endl;
 
     fastllm::Data Permute(const fastllm::Data &input, const std::vector<int> &axis, fastllm::Data &output){
         fastllm::Permute(input, axis, output);
-        // output.ToDevice(fastllm::DataDevice::CPU);
+        output.ToDevice(fastllm::DataDevice::CPU);
         return output;
     }
 
     fastllm::Data Cat(const fastllm::Data &input0, const fastllm::Data &input1, int axis) {
         fastllm::Data output;
         fastllm::Cat(input0, input1, axis, output);
-        // output.ToDevice(fastllm::DataDevice::CPU);
+        output.ToDevice(fastllm::DataDevice::CPU);
         return output;
     }
 
     fastllm::Data CatDirect(fastllm::Data &input0, const fastllm::Data &input1, int axis) {
         fastllm::CatDirect(input0, input1, axis);
-        // input0.ToDevice(fastllm::DataDevice::CPU);
+        input0.ToDevice(fastllm::DataDevice::CPU);
         return input0;
     }
 
     fastllm::Data TopK(const fastllm::Data &input, int topk, fastllm::Data &output) {
         fastllm::TopK(input, output, topk);
-        // output.ToDevice(fastllm::DataDevice::CPU);
+        output.ToDevice(fastllm::DataDevice::CPU);
         return output;
     }
 
     fastllm::Data RotatePosition2D(fastllm::Data &input, const fastllm::Data &positionIds, fastllm::Data &sinData, fastllm::Data &cosData, int rotaryDim){
         fastllm::RotatePosition2D(input, positionIds, sinData, cosData, rotaryDim);
+        // input.ToDevice(fastllm::DataDevice::CPU);
         return input;
     }
 
@@ -132,6 +140,7 @@ namespace pyfastllm{
                                 fastllm::Data &cosData, int rotaryDim){
         fastllm::NearlyRotatePosition2D(input, positionIds, sinData, cosData, rotaryDim);
+        input.ToDevice(fastllm::DataDevice::CPU);
         return input;
     }
 
@@ -159,14 +168,16 @@ namespace pyfastllm{
         return ss;
     }
 
-    std::string Size(const fastllm::Data &data){
-        std::string ss = "Size([";
-        for (int i : data.dims) {
-            ss += std::to_string(i) + ", ";
-        }
-        ss += "])";
-        return ss;
+    std::vector<int> GetDims(const fastllm::Data &data){
+        return data.dims;
+    }
+
+    int GetSize(const fastllm::Data &data, int idx){
+        int n = data.dims.size();
+        idx = (idx + n) % n;
+        return data.dims[idx];
     }
+
 
     fastllm::Data ToDevice(fastllm::Data &data, const std::string &devices){
         size_t pos = devices.find(":");
@@ -193,45 +204,196 @@ namespace pyfastllm{
         return data;
     }
 
-    fastllm::WeightMap LoadWeights(const std::string &fileName){
+
+    fastllm::Data ToCuda(fastllm::Data &data){
+        std::vector<int> deviceIds{0};
+        data.ToDevice(fastllm::DataDevice::CUDA, deviceIds);
+        return data;
+    }
+
+    fastllm::Data ToCpu(fastllm::Data &data){
+        std::vector<int> deviceIds{0};
+        data.ToDevice(fastllm::DataDevice::CPU, deviceIds);
+        return data;
+    }
+
+
+    // TODO: fix data double free bug
+    template <typename data_t>
+    py::array_t<data_t> ToNumpy(fastllm::Data &data){
+        py::capsule free_when_done_d(data.cpuData, [](void* f) {
+            delete[] f;
+        });
+        std::vector<uint64_t> newStrides(std::move(data.strides));
+        for (auto &stride:newStrides){
+            stride *= sizeof(data_t);
+        }
+        return py::array_t<data_t>(
+            data.dims,             // shape
+            newStrides,            // C-style contiguous strides for each index
+            (data_t*)data.cpuData, // the data pointer
+            free_when_done_d
+        );
+    }
+
+
+    class Tensor {
+      public:
+
+        Tensor(){}
+
+        Tensor(const Tensor& rhs){
+            this->ptr = rhs.ptr;
+        }
+
+        Tensor(fastllm::DataType type) {
+            // auto *tensor = new fastllm::Data(type);
+            ptr = std::make_shared<fastllm::Data>(type);
+        }
+
+        Tensor(fastllm::DataType type, const std::vector<int> &dims) {
+            // auto *tensor = new fastllm::Data(type, dims);
+            this->ptr = std::make_shared<fastllm::Data>(type, dims);
+        }
+
+        Tensor(fastllm::DataType type, const std::vector<int> &dims, const std::vector<float> &data){
+            // auto *tensor = new fastllm::Data(type, dims, data);
+            this->ptr = std::make_shared<fastllm::Data>(type, dims, data);
+        }
+
+        Tensor(fastllm::Data &data){
+            // auto *tensor = new fastllm::Data(data);
+            this->ptr = std::make_shared<fastllm::Data>(data);
+        }
+
+        py::buffer_info MemBuffer(){
+            std::vector<uint64_t> newStrides(std::move(ptr->strides));
+            for(auto &stride:newStrides){
+                stride *= sizeof(float);
+            }
+            return py::buffer_info(
+                ptr->cpuData,                           /* Pointer to buffer */
+                sizeof(float),                          /* Size of one scalar */
+                py::format_descriptor<float>::format(), /* Python struct-style format descriptor */
+                ptr->dims.size(),                       /* Number of dimensions */
+                ptr->dims,                              /* Buffer dimensions */
+                newStrides                              /* Strides (in bytes) for each index */
+            );
+        }
+
+        py::list ToList(){
+            auto n = this->Count(0);
+            py::list data(n);
+            for (int i = 0; i < n; i++)
+                data.append(((float*)ptr->cpuData)[i]);
+            return data;
+        }
+
+        uint64_t Count(int idx) const {
+            return this->ptr->Count(idx);
+        }
+
+        std::vector<int> GetDims() const {
+            return this->ptr->dims;
+        }
+
+        int GetSize(int idx) const {
+            int n = ptr->dims.size();
+            idx = (idx + n) % n;
+            return ptr->dims[idx];
+        }
+
+        std::vector<int> GetExpansionDims () const{
+            return this->ptr->expansionDims;
+        }
+
+        void Reshape(const std::vector<int> &dims){
+            ptr->Reshape(dims);
+        }
+
+        void Expansion(const std::vector<int> &dims) {
+            ptr->Expansion(dims);
+        }
+
+        std::string HostString(){
+            std::string ss;
+            ss += "tensor([";
+            int last_dim = this->ptr->dims.back();
+            int n = this->ptr->Count(0) / last_dim, m = last_dim;
+            for (int i = 0; i < n; i++) {
+                if (i > 0) ss += "\n";
+                for (int j = 0; j < 3 && j < m; j++) {
+                    if (j>0) ss += ", ";
+                    ss += std::to_string(reinterpret_cast<float*>(this->ptr->cpuData)[i*m+j]);
+                }
+                if (m > 3) {
+                    ss += "..., ";
+                    for (int j = 0; j < 3 && j < m; j++) {
+                        if (j>0) ss += ", ";
+                        ss += std::to_string(reinterpret_cast<float*>(this->ptr->cpuData)[i*m + (m-3+j)]);
+                    }
+                }
+            }
+            ss += "])";
+            return ss;
+        }
+
+        std::shared_ptr<fastllm::Data> ptr;
+    };
+
+    fastllm::WeightMap LoadWeights(const std::string &fileName){
         fastllm::WeightMap wm;
         wm.LoadFromFile(fileName);
         return wm;
     }
 
     // shallow copy
-    // template <typename data_t>
-    fastllm::Data fromNumpy(pybind11::array_t<float> NpData){
+    template <typename data_t>
+    fastllm::Data fromNumpy(pybind11::array_t<data_t> NpData){
         pybind11::buffer_info buf = NpData.request();
         // printf("%u \n", buf.ptr);
-        float *ptr = (float*) buf.ptr;
+        data_t *ptr = (data_t*) buf.ptr;
         std::vector<int> dataSize;
-        int dataNum = 1;
+        uint64_t dataNum = 1;
         for (auto sz:buf.shape){
             dataSize.emplace_back((int)sz);
             dataNum *= sz;
         }
-        std::vector<float> Vdata;
+        std::vector<data_t> Vdata;
         for (int i=0;i<dataNum;i++){
             Vdata.push_back(ptr[i]);
         }
-        fastllm::Data data(fastllm::DataType::FLOAT32, dataSize, Vdata);
+        auto n = buf.shape.size();
+        std::vector<uint64_t> newStrides(n);
+        std::vector<int> newShape(n);
+        for (auto i=0;i<n;i++){
+            newShape[i] = (int)buf.shape[i];
+            newStrides[i] = (uint64_t)buf.strides[i] / sizeof(data_t);
+        }
+        fastllm::Data data(fastllm::DataType::FLOAT32, newShape, Vdata);
         return data;
     }
 }
 
-#include <pybind11/pybind11.h>
-#include <pybind11/numpy.h>
-#include <pybind11/stl.h>
-#include <pybind11/functional.h>
+// #ifdef PY_API
 
 namespace py = pybind11;
 using namespace pybind11::literals;
 
@@ -270,6 +432,7 @@ PYBIND11_MODULE(pyfastllm, m) {
     m.def("std_hash", [](std::string input) -> size_t {
         return std::hash<std::string>{}(input);
     });
+    // low level
     m.def("get_llm_type", &fastllm::GetModelTypeFromFile);
 
     m.def("llm_sampling", &fastllm::LLMSampling)
@@ -294,16 +457,16 @@ PYBIND11_MODULE(pyfastllm, m) {
         // .def("attention_mask", &fastllm::AttentionMask)
         // .def("alibi_mask", &fastllm::AlibiMask)
         .def("permute", &pyfastllm::Permute)
-        // .def("permute_self", &fastllm::PermuteSelf)
+        .def("permute_", &fastllm::PermuteSelf)
         .def("topk", &pyfastllm::TopK)
         .def("rotateposition2D", &pyfastllm::RotatePosition2D)
         .def("nearlyrotateposition2D", &pyfastllm::NearlyRotatePosition2D)
         .def("llama_rotateposition2D", &fastllm::LlamaRotatePosition2D)
         // .def("repeat_penalty", &fastllm::RepeatPenalty);
-        .def("load", &pyfastllm::LoadWeights)
-        .def("from_numpy", &pyfastllm::fromNumpy);
+        // .def("load", &pyfastllm::LoadWeights)
+        .def("from_numpy", &pyfastllm::fromNumpy<float>);
 
-    py::enum_<fastllm::DataType>(m, "Dtype")
+    py::enum_<fastllm::DataType>(m, "DataType")
         .value("float32", fastllm::DataType::FLOAT32)
         .value("bfloat16", fastllm::DataType::BFLOAT16)
         .value("int16", fastllm::DataType::INT16)
@@ -317,18 +480,22 @@ PYBIND11_MODULE(pyfastllm, m) {
 
     py::class_<fastllm::Data>(m, "Tensor", py::buffer_protocol())
         .def_buffer([](fastllm::Data &m) -> py::buffer_info {
+            std::vector<uint64_t> newStrides(std::move(m.strides));
+            for(auto &stride:newStrides){
+                stride *= sizeof(float);
+            }
             return py::buffer_info(
                 m.cpuData,                              /* Pointer to buffer */
                 sizeof(float),                          /* Size of one scalar */
                 py::format_descriptor<float>::format(), /* Python struct-style format descriptor */
                 m.dims.size(),                          /* Number of dimensions */
                 m.dims,                                 /* Buffer dimensions */
-                m.strides
+                newStrides
                 // { sizeof(float) * m.dims[1],         /* Strides (in bytes) for each index */
                 //   sizeof(float) }
             );
         })
-        .def_readonly("dims", &fastllm::Data::dims)
+        .def_readonly("shape", &fastllm::Data::dims)
         .def_readonly("expansionDims", &fastllm::Data::expansionDims)
         .def(py::init<>())
         .def(py::init<fastllm::DataType>())
         .def(py::init<fastllm::DataType, const std::vector<int>&>())
         .def(py::init<fastllm::DataType, const std::vector<int>&, const std::vector<float>&>())
         .def(py::init<fastllm::Data>())
         .def("copy_from", &fastllm::Data::CopyFrom)
         .def("count", &fastllm::Data::Count)
-
-        .def_readonly("shape", &fastllm::Data::dims)
         .def("reshape", &fastllm::Data::Reshape)
         .def("expansion", &fastllm::Data::Expansion)
         .def("to_list", [](fastllm::Data& data){
             std::vector<float> vecData;
             for (int i = 0; i < data.Count(0); i++) {
                 vecData.push_back(((float*)data.cpuData)[i]);
             }
             return vecData;
         })
         .def("__str__", &pyfastllm::String)
-        .def("size", &pyfastllm::Size)
-        .def("to", &pyfastllm::ToDevice);
+        .def("size", &pyfastllm::GetDims)
+        .def("size", &pyfastllm::GetSize)
+        .def("to", &pyfastllm::ToDevice)
+        .def("cuda", &pyfastllm::ToCuda)
+        .def("cpu", &pyfastllm::ToCpu);
 
     m.def("zeros", [](const std::vector<int> &dims, fastllm::DataType dtype)->fastllm::Data {
         int nums = 1;
         for (auto dim:dims){ nums *= dim; }
         std::vector<float> zero_data(nums, 0);
         fastllm::Data data = fastllm::Data(dtype, dims, zero_data);
         return data;
     }, py::arg("dims"), py::arg("dtype"));
 
-    m.def("cat", [](std::vector<fastllm::Data> datas, int dim)->fastllm::Data {
-        // int pos_dim = 0;
-        // // dim check
-        // for (int i=0;i<datas.size();i++){
-        //     ...
-        // }
-
-        std::vector<float> vecData;
-        for (auto data:datas){
-            for (int i = 0; i < data.Count(0); i++) {
-                vecData.push_back(((float*)data.cpuData)[i]);
-            }
-        }
-        int seqLen = vecData.size();
-        return fastllm::Data(fastllm::DataType::FLOAT32, {1, seqLen}, vecData);
-    });
 
     py::class_<fastllm::Tokenizer>(m, "Tokenizer")
         .def_readonly("add_dummy_prefix", &fastllm::Tokenizer::addDummyPrefix)
@@ -414,14 +550,16 @@ PYBIND11_MODULE(pyfastllm, m) {
         .def("set_special_tokens", &fastllm::Tokenizer::SetSpecialTokens);
 
     py::class_<fastllm::WeightMap>(m, "WeightMap")
+        .def(py::init<>())
         .def_readonly("tokenizer", &fastllm::WeightMap::tokenizer)
         .def("load", &fastllm::WeightMap::LoadFromFile)
         .def("save_lowbit", &fastllm::WeightMap::SaveLowBitModel)
         .def("set_kv", &fastllm::WeightMap::AddDict)
         .def("set_weight", &fastllm::WeightMap::AddWeight)
-        .def("__getitem__", [](fastllm::WeightMap &weight, std::string key){
-            return weight[key];
-        })
+        .def("__getitem__", [](fastllm::WeightMap &weight, std::string key) -> fastllm::Data& {
+            fastllm::Data &data = weight[key];
+            return data;
+        }, py::return_value_policy::reference_internal)
         .def("keys", [](fastllm::WeightMap &weight){
             std::vector<std::string> keys;
             for (auto iter:weight.weight){
                 keys.push_back(iter.first);
             }
             return keys;
         });
 
+    py::class_<pyfastllm::Tensor>(m, "Tensor_", py::buffer_protocol())
+        .def(py::init<>())
+        .def(py::init<fastllm::DataType>())
+        .def(py::init<fastllm::DataType, const std::vector<int>&>())
+        .def(py::init<fastllm::DataType, const std::vector<int>&, const std::vector<float>&>())
+        // .def(py::init<fastllm::Data>())
+        .def_buffer(&pyfastllm::Tensor::MemBuffer)
+        // .def("size", py::overload_cast<>(&pyfastllm::Tensor::GetDims))
+        .def("size", &pyfastllm::Tensor::GetSize)
+        .def("size", &pyfastllm::Tensor::GetDims)
+        .def("expand_size", &pyfastllm::Tensor::GetExpansionDims)
+        .def("count", &pyfastllm::Tensor::Count)
+        .def("reshape", &pyfastllm::Tensor::Reshape)
+        .def("expansion", &pyfastllm::Tensor::Expansion)
+        .def("to_list", &pyfastllm::Tensor::ToList)
+        .def("__str__", &pyfastllm::Tensor::HostString);
+        // .def("numpy", &pyfastllm::ToNumpy)
+        // .def("to", &pyfastllm::ToDevice);
 
     // model classes
     py::class_<fastllm::basellm>(m, "basellm");
@@ -580,4 +736,4 @@ PYBIND11_MODULE(pyfastllm, m) {
 }
 
-#endif
+// #endif
diff --git a/test/ops/cppOps.cpp b/test/ops/cppOps.cpp
index d4dd865f..74dcfe1e 100644
--- a/test/ops/cppOps.cpp
+++ b/test/ops/cppOps.cpp
@@ -28,6 +28,7 @@ void callBaseOp(int optype=0){
         default:
             break;
     }
+    outputs.ToDevice(fastllm::DataDevice::CPU);
     outputs.Print();
 }
 
@@ -49,6 +50,7 @@ void callNormOp(int normType=0){
         default:
             break;
     }
+    outputs.ToDevice(fastllm::DataDevice::CPU);
     outputs.Print();
 }
 
@@ -59,6 +61,7 @@ void callLinearOp(){
     fastllm::Data bias = fastllm::Data(fastllm::DataType::FLOAT32, {1, 3}, {0, 1, 1});
     fastllm::Data outputs;
     fastllm::Linear(inputs, weights, bias, outputs);
+    outputs.ToDevice(fastllm::DataDevice::CPU);
     outputs.Print();
 }
 
@@ -82,6 +85,7 @@ void callActivationOp(int activateType=0){
         default:
             break;
     }
+    outputs.ToDevice(fastllm::DataDevice::CPU);
     outputs.Print();
 }