Merge branch 'master' into feature/int8_try2
karpathy authored Oct 9, 2023
2 parents 5186b50 + 46ded11 commit 1f8af82
Showing 4 changed files with 135 additions and 31 deletions.
17 changes: 15 additions & 2 deletions README.md
@@ -4,6 +4,8 @@
<img src="assets/llama_cute.jpg" width="300" height="300" alt="Cute Llama">
</p>

Have you ever wanted to inference a baby [Llama 2](https://ai.meta.com/llama/) model in pure C? No? Well, now you can!

Train the Llama 2 LLM architecture in PyTorch then inference it with one simple 700-line C file ([run.c](run.c)). You might think that you need many billion parameter LLMs to do anything useful, but in fact very small LLMs can have surprisingly strong performance if you make the domain narrow enough (ref: [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) paper). This repo is a "fullstack" train + inference solution for Llama 2 LLM, with focus on minimalism and simplicity.

As the architecture is identical, you can also load and inference Meta's Llama 2 models. However, the current code only inferences models in fp32, so you will most likely not be able to productively load models larger than 7B (the fp32 weights for 7B alone are already roughly 27 GB). Work on model quantization is currently ongoing.
@@ -14,7 +16,7 @@ Please note that this repo started recently as a fun weekend project: I took my

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/karpathy/llama2.c/blob/master/run.ipynb)

First, navigate to the folder when you keep your projects and clone this repository to this folder:
First, navigate to the folder where you keep your projects and clone this repository to this folder:

```bash
git clone https://github.com/karpathy/llama2.c.git
@@ -109,8 +111,9 @@ Chat with Code Llama Instruct:
python export.py codellama2_7b_instruct.bin --meta-llama /path/to/CodeLlama-7b-Instruct
python tokenizer.py --tokenizer-model=/path/to/CodeLlama-7b-Instruct/tokenizer.model
./run codellama2_7b_instruct.bin -m chat -z /path/to/CodeLlama-7b-Instruct/tokenizer.bin
```

## hugginface models
## huggingface models

We can load any huggingface models that use the Llama 2 architecture. See the script [export.py](export.py) and the `--hf` flag to export the model .bin file.
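
For example, a minimal sketch (the checkpoint path below is illustrative, assuming you have a local copy of a Llama 2-architecture model downloaded from huggingface):

```bash
python export.py llama2_7b_hf.bin --hf /path/to/hf-llama-2-7b
```

If the model ships its own tokenizer, export it with tokenizer.py and pass the resulting tokenizer.bin to ./run with -z, as shown in the Code Llama example above.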

@@ -311,6 +314,7 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
- [llama2-rs](https://github.com/danielgrittner/llama2-rs) by @[danielgrittner](https://github.com/danielgrittner): a Rust port of this project
- [llama2.rs](https://github.com/lintian06/llama2.rs) by @[lintian06](https://github.com/lintian06): A Rust port of this project
- [pecca.rs](https://github.com/rahoua/pecca-rs) by @[rahoua](https://github.com/rahoua): A Rust port leveraging [ndarray](https://github.com/rust-ndarray/ndarray), supports BLAS.
- [llama2.rs](https://github.com/flaneur2020/llama2.rs) by @[flaneur2020](https://github.com/flaneur2020): A Rust port of this project.
- Go
- [go-llama2](https://github.com/tmc/go-llama2) by @[tmc](https://github.com/tmc): a Go port of this project
- [llama2.go](https://github.com/nikolaydubina/llama2.go) by @[nikolaydubina](https://github.com/nikolaydubina): a Go port of this project
@@ -323,6 +327,7 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
- [llama2.cpp](https://github.com/leloykun/llama2.cpp) by @[leloykun](https://github.com/leloykun): a C++ port of this project
- JavaScript
- [llama2.js](https://github.com/epicure/llama2.js) by @[epicure](https://github.com/epicure): a JavaScript port of this project
- [llamajs](https://github.com/agershun/llamajs) by @[agershun](https://github.com/agershun): a JavaScript port of this project
- [llama2.ts](https://github.com/wizzard0/llama2.ts) by @[oleksandr_now](https://twitter.com/oleksandr_now): a TypeScript port of this project. Full Llama2-7B capable.
- [llama2.c-emscripten](https://github.com/gohai/llama2.c-emscripten) by @[gohai](https://github.com/gohai): Emscripten (JavaScript) port, based on @ggerganov's initial prototype
- Zig
@@ -343,8 +348,16 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
- [llama2.cs](https://github.com/trrahul/llama2.cs) by @[trrahul](https://github.com/trrahul): a C# port of this project
- Dart
- [llama2.dart](https://github.com/yiminghan/llama2.dart) by @[yiminghan](https://github.com/yiminghan/llama2.dart): one-file dart port of this project, works with Flutter!
- Web
- [llama2c-web](https://github.com/dmarcos/llama2.c-web) by @[dmarcos](https://github.com/dmarcos): Super simple way to build unmodified llama2.c to WASM and run it in the browser. [Demo](https://diegomarcos.com/llama2.c-web/)
- WebAssembly
- [icpp-llm](https://github.com/icppWorld/icpp-llm): LLMs for the Internet Computer
- Fortran
- [llama2.f90](https://github.com/rbitr/llama2.f90): a Fortran port of this project
- Mojo
- [llama2.🔥](https://github.com/tairov/llama2.mojo) by @[tairov](https://github.com/tairov): pure Mojo port of this project
- OCaml
- [llama2.ml](https://github.com/jackpeck/llama2.ml) by @[jackpeck](https://github.com/jackpeck): an OCaml port of this project
- [llama2.c - Llama 2 Everywhere](https://github.com/trholding/llama2.c) by @[trholding](https://github.com/trholding): Standalone, Bootable & Portable Binary Llama 2
- [llama2.c-zh - Bilingual Chinese and English](https://github.com/chenyangMl/llama2.c-zh) by @[chenyangMl](https://github.com/chenyangMl): Expand tokenizer to support training and inference in both Chinese and English

100 changes: 98 additions & 2 deletions export.py
@@ -259,6 +259,96 @@ def version2_export(model, filepath, group_size=64):
out_file.close()
print(f"wrote {filepath}")

def hf_export(llama_model, filepath, group_size=64, dtype=torch.float32):
""" Generate the pytorch_model.bin state_dict and config.json for HuggingFace """

try:
from transformers.models.llama.configuration_llama import LlamaConfig
except ImportError:
print("Error: transformers package is required to load huggingface models")
print("Please run `pip install transformers` to install it")
return None

# Generate LlamaModel state_dict
hf_state_dict = {}

# Sometimes we have repeated key values for the heads
dim = llama_model.params.dim
num_key_value_heads = llama_model.params.n_kv_heads
n_rep = llama_model.params.n_heads // num_key_value_heads
key_value_dim = dim // n_rep

# HuggingFace needs the weights permuted.
# See: https://github.com/huggingface/transformers/blob/b132c1703eb1c8bd9dfa4ad6a9be2bfd6ef819e9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122
def permute_original(w, n_heads=llama_model.params.n_heads, dim1=dim, dim2=dim):
return w.view(dim1, dim2).reshape(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

# Transfer weights from llama model to the HF state dictionary format
hf_state_dict['model.embed_tokens.weight'] = llama_model.tok_embeddings.weight.clone().to(dtype)
hf_state_dict['model.norm.weight'] = llama_model.norm.weight.clone().to(dtype)

# Add each layer's weights to the HF state dictionary
for i, layer in enumerate(llama_model.layers):
layer_id = layer.layer_id
hf_state_dict[f'model.layers.{i}.input_layernorm.weight'] = llama_model.layers[layer_id].attention_norm.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wq.weight.clone()).to(dtype)
hf_state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wk.weight.clone(), num_key_value_heads, key_value_dim, dim).to(dtype)
hf_state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = llama_model.layers[layer_id].attention.wv.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = llama_model.layers[layer_id].attention.wo.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = llama_model.layers[layer_id].ffn_norm.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = llama_model.layers[layer_id].feed_forward.w1.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = llama_model.layers[layer_id].feed_forward.w2.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = llama_model.layers[layer_id].feed_forward.w3.weight.clone().to(dtype)

# llama2.c usually uses tied weights -> reference the embed_tokens.weight instead
hf_state_dict['lm_head.weight'] = hf_state_dict['model.embed_tokens.weight']

# We check that the embeddings are tied, else use manual output weights
_embeddings_are_tied: bool = torch.equal(llama_model.tok_embeddings.weight, llama_model.output.weight)
if not _embeddings_are_tied:
hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype)


# Generate LlamaConfig (seen in transformers.models.llama.configuration_llama)

# Extract necessary attributes from llama.c model
vocab_size = llama_model.params.vocab_size
hidden_size = llama_model.params.dim
intermediate_size = llama_model.layers[0].feed_forward.w1.weight.shape[0]
num_hidden_layers = llama_model.params.n_layers
num_attention_heads = llama_model.params.n_heads
num_key_value_heads = llama_model.params.n_kv_heads
max_position_embeddings = llama_model.params.max_seq_len
rms_norm_eps = llama_model.params.norm_eps

# TODO check values for:
# pretraining_tp, initializer_range, use_cache,
# rope_theta, and rope_scaling.

config = LlamaConfig(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=rms_norm_eps,
tie_word_embeddings=_embeddings_are_tied,
# Manual
architectures=["LlamaForCausalLM"],
hidden_act="silu",
)


# Save files in directory filepath
# First make the directory if it doesn't exist
os.makedirs(filepath, exist_ok=True)

# Save the state dictionary in .bin format, and config as .json
torch.save(hf_state_dict, os.path.join(filepath, "pytorch_model.bin"))
config.save_pretrained(filepath)
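# note: the saved directory (pytorch_model.bin + config.json) should be loadable with
# transformers, e.g. LlamaForCausalLM.from_pretrained(filepath)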


# -----------------------------------------------------------------------------
# Load / import functions
@@ -399,19 +489,23 @@ def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim)
# -----------------------------------------------------------------------------
# API entrypoint

def model_export(model, filepath, version):
def model_export(model, filepath, version, dtype=torch.float32):
"""
Versions docs:
v-1:huggingface export, i.e. intended for use outside of this repo, in HF
v0: legacy llama2.c float format, DEPRECATED
v1: float32 export
v2: int8 quantized Q8_0 export, similar to llama.cpp, in groups
# TODO: add dtype export support for other versions (?)
"""
if version == 0:
legacy_export(model, filepath)
elif version == 1:
version1_export(model, filepath)
elif version == 2:
version2_export(model, filepath)
elif version == -1:
hf_export(model, filepath, dtype=dtype)
else:
raise ValueError(f"unknown version {version}")

@@ -451,11 +545,13 @@ def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
parser = argparse.ArgumentParser()
parser.add_argument("filepath", type=str, help="the output filepath")
parser.add_argument("--version", default=0, type=int, help="the version to export with")
parser.add_argument("--dtype", type=str, help="dtype of the model (fp16, fp32)", default="fp32")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
group.add_argument("--meta-llama", type=str, help="meta llama model path")
group.add_argument("--hf", type=str, help="huggingface model path")
args = parser.parse_args()
dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype]

if args.checkpoint:
model = load_checkpoint(args.checkpoint)
@@ -468,4 +564,4 @@ def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
parser.error("Can't load input model!")

# export
model_export(model, args.filepath, args.version)
model_export(model, args.filepath, args.version, dtype)
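
A minimal sketch of driving the new HuggingFace export path end to end; the checkpoint path and output directory below are illustrative, not part of this commit:

```bash
# writes pytorch_model.bin and config.json into the ./llama2c-hf directory
python export.py llama2c-hf --version=-1 --dtype fp16 --checkpoint out/model.pt
```

Version -1 writes a directory rather than a single .bin file, since the HF format expects the state dict and config.json side by side.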
47 changes: 21 additions & 26 deletions run.c
@@ -83,16 +83,13 @@ void malloc_run_state(RunState* s, Config* p) {
s->hb = calloc(p->hidden_dim, sizeof(float));
s->hb2 = calloc(p->hidden_dim, sizeof(float));
s->q = calloc(p->dim, sizeof(float));
s->k = calloc(kv_dim, sizeof(float));
s->v = calloc(kv_dim, sizeof(float));
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
s->logits = calloc(p->vocab_size, sizeof(float));
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
s->logits = calloc(p->vocab_size, sizeof(float));
// ensure all mallocs went fine
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|| !s->value_cache) {
|| !s->key_cache || !s->value_cache || !s->att || !s->logits) {
fprintf(stderr, "malloc failed!\n");
exit(EXIT_FAILURE);
}
@@ -105,8 +102,6 @@ void free_run_state(RunState* s) {
free(s->hb);
free(s->hb2);
free(s->q);
free(s->k);
free(s->v);
free(s->att);
free(s->logits);
free(s->key_cache);
@@ -115,26 +110,28 @@

void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
int head_size = p->dim / p->n_heads;
// make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
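// e.g. for Llama 2 13B (n_layers=40, dim=5120, hidden_dim=13824), n_layers*dim*hidden_dim is
// about 2.8e9, which would overflow a 32-bit int (INT_MAX is about 2.1e9)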
unsigned long long n_layers = p->n_layers;
w->token_embedding_table = ptr;
ptr += p->vocab_size * p->dim;
w->rms_att_weight = ptr;
ptr += p->n_layers * p->dim;
ptr += n_layers * p->dim;
w->wq = ptr;
ptr += p->n_layers * p->dim * (p->n_heads * head_size);
ptr += n_layers * p->dim * (p->n_heads * head_size);
w->wk = ptr;
ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
w->wv = ptr;
ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
w->wo = ptr;
ptr += p->n_layers * (p->n_heads * head_size) * p->dim;
ptr += n_layers * (p->n_heads * head_size) * p->dim;
w->rms_ffn_weight = ptr;
ptr += p->n_layers * p->dim;
ptr += n_layers * p->dim;
w->w1 = ptr;
ptr += p->n_layers * p->dim * p->hidden_dim;
ptr += n_layers * p->dim * p->hidden_dim;
w->w2 = ptr;
ptr += p->n_layers * p->hidden_dim * p->dim;
ptr += n_layers * p->hidden_dim * p->dim;
w->w3 = ptr;
ptr += p->n_layers * p->dim * p->hidden_dim;
ptr += n_layers * p->dim * p->hidden_dim;
w->rms_final_weight = ptr;
ptr += p->dim;
ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE)
@@ -249,11 +246,16 @@ float* forward(Transformer* transformer, int token, int pos) {
memcpy(x, content_row, dim*sizeof(*x));

// forward all the layers
for(int l = 0; l < p->n_layers; l++) {
for(unsigned long long l = 0; l < p->n_layers; l++) {

// attention rmsnorm
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);

// key and value point to the kv cache
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
s->k = s->key_cache + loff + pos * kv_dim;
s->v = s->value_cache + loff + pos * kv_dim;

// qkv matmuls for this position
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
@@ -276,13 +278,6 @@
}
}

// save key,value at this time step (pos) to our kv cache
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
float* key_cache_row = s->key_cache + loff + pos * kv_dim;
float* value_cache_row = s->value_cache + loff + pos * kv_dim;
memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));

// multihead attention. iterate over all heads
int h;
#pragma omp parallel for private(h)
@@ -754,7 +749,7 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
// forward the transformer to get logits for the next token
float* logits = forward(transformer, token, pos);

// advance the state state machine
// advance the state machine
if (pos < num_prompt_tokens - 1) {
// if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos + 1];
2 changes: 1 addition & 1 deletion tinystories.py
@@ -88,7 +88,7 @@ def train_vocab(vocab_size):
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))

print(f"Writing temporary file {tiny_file} with {num_shards} shards...")
with open(tiny_file, "w") as of:
with open(tiny_file, "w", encoding="utf-8") as of:
for shard in tqdm(shard_filenames[:num_shards]):
with open(shard, "r") as f:
data = json.load(f)
