implementing llama (#3)
ToluClassics authored Apr 22, 2024
1 parent 0f4e6f4 commit 3751ecc
Showing 14 changed files with 1,189 additions and 17 deletions.
45 changes: 43 additions & 2 deletions README.md
@@ -3,7 +3,7 @@
[![PyPI](https://img.shields.io/pypi/v/mlx-transformers?color=red)](https://pypi.org/project/mlx-transformers/)


MLX Transformers is a library that provides model implementations in [MLX](https://github.com/ml-explore/mlx). It uses a similar model interface as HuggingFace Transformers and provides a way to load and use models on Apple Silicon devices. Implemented models have the same modules as their HuggingFace counterparts.

MLX Transformers is currently only available for inference on Apple Silicon devices. Training support will be added in the future.

@@ -52,6 +52,38 @@ A list of the available models can be found in the `mlx_transformers.models` module
outputs = model(**inputs)
```

### Sentence Transformer Example

```python
import mlx.core as mx
import numpy as np

from transformers import AutoConfig, AutoTokenizer
from mlx_transformers.models import BertModel as MLXBertModel


def _mean_pooling(last_hidden_state: mx.array, attention_mask: mx.array):
    # Mean-pool token embeddings, masking out padding positions with the attention mask.
    token_embeddings = last_hidden_state
    input_mask_expanded = mx.expand_dims(attention_mask, -1)
    input_mask_expanded = mx.broadcast_to(
        input_mask_expanded, token_embeddings.shape
    ).astype(mx.float32)
    sum_embeddings = mx.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = mx.clip(input_mask_expanded.sum(axis=1), 1e-9, None)
    return sum_embeddings / sum_mask

sentences = ['This is an example sentence', 'Each sentence is converted']

tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
config = AutoConfig.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

model = MLXBertModel(config)
model.load_weights("all-MiniLM-L6-v2.npz", strict=True)

inputs = tokenizer(sentences, return_tensors="np", padding=True, truncation=True)
outputs = model(**inputs)

sentence_embeddings = _mean_pooling(outputs.last_hidden_state, mx.array(inputs.attention_mask))  # ensure the mask is an mx.array
```
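
To compare the pooled sentence embeddings, here is a short follow-up sketch (not part of this commit) that computes cosine similarity with plain MLX ops; `_cosine_similarity` is an illustrative helper, not a library function:

```python
import mlx.core as mx

def _cosine_similarity(a: mx.array, b: mx.array) -> mx.array:
    # Normalize each embedding to unit length, then take the dot product.
    a = a / mx.sqrt(mx.sum(a * a, axis=-1, keepdims=True))
    b = b / mx.sqrt(mx.sum(b * b, axis=-1, keepdims=True))
    return mx.sum(a * b, axis=-1)

# Similarity between the two example sentences embedded above.
score = _cosine_similarity(sentence_embeddings[0:1], sentence_embeddings[1:2])
print(score.item())
```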


## Available Models

@@ -61,10 +93,19 @@ The following models have been ported to MLX Transformers from Huggingface for inference
2. Roberta
3. XLMRoberta
4. M2M100
5. Sentence Transformers
6. CLIP -> Coming soon...
7. Llama
8. T5 -> Coming soon...

## Examples

Coming soon...
The `examples` directory contains a few examples that demonstrate how to use the models in MLX Transformers.

1. [Llama Example](examples/llama_generation.py)
```bash
python3 examples/llama_generation.py --model-name "meta-llama/Llama-2-7b-hf" --model-path meta-llama-Llama-2-7b-hf.npz
```
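
For readers who want to call the model from Python rather than the CLI, here is a condensed, non-authoritative sketch of the flow in `examples/llama_generation.py` (shown in full further down in this commit); the checkpoint name and the `.npz` weights path are placeholders for whatever you pass via `--model-name` and `--model-path`:

```python
import mlx.core as mx
from transformers import AutoTokenizer, LlamaConfig
from mlx_transformers.models import LlamaForCausalLM as MlxLlamaForCausalLM

model_name = "meta-llama/Llama-2-7b-hf"
weights_path = "meta-llama-Llama-2-7b-hf.npz"  # placeholder: converted MLX weights

config = LlamaConfig.from_pretrained(model_name)
model = MlxLlamaForCausalLM(config)
model.load_weights(weights_path, strict=False)
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("In the beginning the Universe was created.", return_tensors="np")
inputs = {key: mx.array(v) for key, v in inputs.items()}

# Greedy generation (temperature 0.0), capped here at 20 new tokens.
tokens = []
for token in model.generate(inputs, 0.0):
    tokens.append(token)
    if len(tokens) >= 20:
        break
mx.eval(tokens)
print(tokenizer.decode([t.item() for t in tokens]))
```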

## Benchmarks

123 changes: 123 additions & 0 deletions examples/llama_generation.py
@@ -0,0 +1,123 @@
import argparse
import os
import time
from pathlib import Path
from typing import Tuple

import mlx.core as mx
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

from mlx_transformers.models import LlamaForCausalLM as MlxLlamaForCausalLM
from mlx_transformers.models.utils import convert


def tic():
    "Return the current time in seconds."
    return time.time()


def toc(msg, start):
    "Return a message with the elapsed time since `start` in seconds."
    end = time.time()
    return f"[INFO] {msg}: {end - start:.3f} s"


def load_model(
model_name: str, model_weights: str, hgf_model_class, mlx_model_class
) -> Tuple[MlxLlamaForCausalLM, AutoTokenizer]:
"""
Load a llama model and tokenizer from the given model name and weights.
Args:
model_name (str): Name of the llama model to load
model_weights (str): Path to the model weights
hgf_model_class: Huggingface model class
mlx_model_class: Mlx model class
Returns:
_type_: _description_
"""
config = LlamaConfig.from_pretrained(model_name)
current_directory = os.path.dirname(os.path.realpath(__file__))

model = mlx_model_class(config)
model.load_weights(model_weights, strict=False)

tokenizer = AutoTokenizer.from_pretrained(model_name)

return model, tokenizer


def generate(model: MlxLlamaForCausalLM, tokenizer: AutoTokenizer, args):
print(args.prompt)
inputs = tokenizer(args.prompt, return_tensors="np", truncation=True)

inputs = {key: mx.array(v) for key, v in inputs.items()}
skip = 0
prompt_processing = None
tokens = []
start = tic()
for token in model.generate(inputs, args.temp):
tokens.append(token)

if len(tokens) == 1:
# Actually perform the computation to measure the prompt processing time
mx.eval(token)
prompt_processing = toc("Prompt processing", start)

if len(tokens) >= args.max_tokens:
break

elif (len(tokens) % args.write_every) == 0:
# It is perfectly ok to eval things we have already eval-ed.
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s[skip:], end="", flush=True)
skip = len(s)

mx.eval(tokens)
full_gen = toc("Full generation", start)
s = tokenizer.decode([t.item() for t in tokens])
print(s[skip:], flush=True)
print("------")
print(prompt_processing)
print(full_gen)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Llama inference script")
parser.add_argument(
"--model-name",
help="The model name to load",
default="meta-llama/Llama-2-7b-hf",
)
parser.add_argument(
"--model-path",
help="Path to the model weights",
default="mlx_model",
)
parser.add_argument(
"--prompt",
help="The message to be processed by the model. Ignored when --few-shot is provided.",
default="In the beginning the Universe was created.",
)
parser.add_argument(
"--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
)
parser.add_argument(
"--write-every", type=int, default=1, help="After how many tokens to detokenize"
)
parser.add_argument(
"--temp", type=float, default=0.0, help="The sampling temperature"
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

args = parser.parse_args()

mx.random.seed(args.seed)

model, tokenizer = load_model(
args.model_name, args.model_path, LlamaForCausalLM, MlxLlamaForCausalLM
)

generate(model, tokenizer, args)
18 changes: 9 additions & 9 deletions requirements.txt
@@ -1,9 +1,9 @@
huggingface-hub==0.22.2
mlx==0.8.1
numpy==1.26.4
safetensors==0.4.3
sentencepiece==0.2.0
tokenizers==0.19.1
torch==2.2.2
tqdm==4.66.2
transformers==4.40.0
huggingface-hub>=0.22.2
mlx>=0.8.1
numpy>=1.26.4
safetensors>=0.4.3
sentencepiece>=0.2.0
tokenizers>=0.19.1
torch>=2.2.2
tqdm>=4.66.2
transformers>=4.40.0
1 change: 1 addition & 0 deletions src/mlx_transformers/models/__init__.py
@@ -4,6 +4,7 @@
BertForTokenClassification,
BertForQuestionAnswering
)
from .llama import LlamaForCausalLM, LlamaModel
from .roberta import (
RobertaModel,
RobertaForSequenceClassification,
4 changes: 3 additions & 1 deletion src/mlx_transformers/models/bert.py
@@ -471,7 +471,9 @@ def __call__(

pooled_output = outputs.pooler_output

pooled_output = self.dropout(pooled_output)
if self.config.train:
pooled_output = self.dropout(pooled_output)

logits = self.classifier(pooled_output)

loss = None
168 changes: 168 additions & 0 deletions src/mlx_transformers/models/cache.py
@@ -0,0 +1,168 @@
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import mlx.core as mx

logger = logging.getLogger(__name__)


@dataclass
class Cache:
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
"""

def update(
self,
key_states: mx.array,
value_states: mx.array,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[mx.array, mx.array]:
raise NotImplementedError("Make sure to implement `update` in a subclass.")

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
raise NotImplementedError(
"Make sure to implement `get_seq_length` in a subclass."
)

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states, if there is any."""
raise NotImplementedError(
"Make sure to implement `get_max_length` in a subclass."
)

def get_usable_length(
self, new_seq_length: int, layer_idx: Optional[int] = 0
) -> int:
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all of the cache is usable
        # Cache with size limit -> if the cache length plus the length of the new inputs is larger than the maximum
        # cache length, we will need to evict part of the cache (and thus not all of the cache is usable)
max_length = self.get_max_length()
previous_seq_length = self.get_seq_length(layer_idx)
if max_length is not None and previous_seq_length + new_seq_length > max_length:
return max_length - new_seq_length
return previous_seq_length

@property
def seen_tokens(self):
        logger.warning(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
if hasattr(self, "_seen_tokens"):
return self._seen_tokens
else:
return None


class DynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
"""

def __init__(self) -> None:
self.key_cache: List[mx.array] = []
self.value_cache: List[mx.array] = []
        # Used in `generate` to keep a tally of how many tokens the cache has seen
        self._seen_tokens = 0

def __getitem__(self, layer_idx: int) -> List[Tuple[mx.array]]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.
"""
if layer_idx < len(self):
return (self.key_cache[layer_idx], self.value_cache[layer_idx])
else:
raise KeyError(
f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}"
)

def __iter__(self):
"""
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
keys and values
"""
for layer_idx in range(len(self)):
yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
return len(self.key_cache)

def update(
self,
key_states: mx.array,
value_states: mx.array,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[mx.array, mx.array]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`mx.array`):
The new key states to cache.
value_states (`mx.array`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
Return:
A tuple containing the updated key and value states.
"""
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]

# Update the cache
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
else:
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states

return self.key_cache[layer_idx], self.value_cache[layer_idx]

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.key_cache) <= layer_idx:
return 0
return self.key_cache[layer_idx].shape[-2]

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return None

def to_legacy_cache(self) -> Tuple[Tuple[mx.array], Tuple[mx.array]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache

@classmethod
def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[mx.array]]] = None
) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache
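
A minimal sketch of how `DynamicCache` might be exercised on its own, using only the methods defined above; the import path and the tensor shapes are assumptions for illustration:

```python
import mlx.core as mx
from mlx_transformers.models.cache import DynamicCache  # assumed import path

cache = DynamicCache()

# Pretend a 2-layer model just processed a prompt of 4 tokens:
# key/value states have shape [batch_size, num_heads, seq_len, head_dim].
for layer_idx in range(2):
    keys = mx.zeros((1, 8, 4, 64))
    values = mx.zeros((1, 8, 4, 64))
    cache.update(keys, values, layer_idx)

print(len(cache))                  # 2 layers cached
print(cache.get_seq_length(0))     # 4 tokens cached for layer 0
print(cache.get_usable_length(1))  # no size limit, so all 4 cached tokens are usable

legacy = cache.to_legacy_cache()   # tuple of (key, value) pairs, one per layer
```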