Improved Model Loading (#24)
* improved loading

* run pre-commit and ruff formatting
ToluClassics authored Nov 19, 2024
1 parent e2fdbce commit cfc53f9
Showing 12 changed files with 182 additions and 195 deletions.
17 changes: 10 additions & 7 deletions chat/app.py
@@ -195,13 +195,16 @@ def load_model_and_cache(ref)
)
tokenizer.chat_template = chat_template
else:
chat_template = tokenizer.chat_template or (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant\n' }}"
"{% endif %}"
chat_template = (
tokenizer.chat_template
or (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant\n' }}"
"{% endif %}"
)
)

supports_system_role = "system role not supported" not in chat_template.lower()
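Editor's note, not part of the diff: a minimal sketch of what the fallback ChatML template produces once applied through `apply_chat_template`. The model name is only a placeholder, and a transformers release that supports `apply_chat_template` is assumed.

# Illustration only: rendering the fallback ChatML template from chat/app.py.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '<|im_start|>assistant\n' }}"
    "{% endif %}"
)

messages = [{"role": "user", "content": "Hello!"}]
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
# <|im_start|>user
# Hello!<|im_end|>
# <|im_start|>assistant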
34 changes: 34 additions & 0 deletions examples/bert/sentence_transformers.py
@@ -0,0 +1,34 @@
import mlx.core as mx
import numpy as np

from transformers import AutoConfig, AutoTokenizer
from mlx_transformers.models import BertModel as MLXBertModel


def _mean_pooling(last_hidden_state: mx.array, attention_mask: mx.array):
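    """Average token embeddings over non-padded positions, using the attention mask as weights."""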
token_embeddings = last_hidden_state
input_mask_expanded = mx.expand_dims(attention_mask, -1)
input_mask_expanded = mx.broadcast_to(
input_mask_expanded, token_embeddings.shape
).astype(mx.float32)
sum_embeddings = mx.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = mx.clip(input_mask_expanded.sum(axis=1), 1e-9, None)
return sum_embeddings / sum_mask


sentences = ["This is an example sentence", "Each sentence is converted"]

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
config = AutoConfig.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

model = MLXBertModel(config)
model.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

inputs = tokenizer(sentences, return_tensors="np", padding=True, truncation=True)
inputs = {key: mx.array(v) for key, v in inputs.items()}

outputs = model(**inputs)

sentence_embeddings = _mean_pooling(outputs.last_hidden_state, inputs["attention_mask"])

print(sentence_embeddings)
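A possible follow-up to this example (not part of the committed file): L2-normalize the pooled embeddings and compare the two sentences with cosine similarity, using only basic mx ops.

norms = mx.sqrt(mx.sum(sentence_embeddings * sentence_embeddings, axis=1, keepdims=True))
normalized = sentence_embeddings / norms
print(mx.sum(normalized[0] * normalized[1]))  # cosine similarity of the two example sentences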
@@ -43,7 +43,7 @@ def load_model(
model_name,
huggingface_model_architecture="AutoModelForCausalLM",
trust_remote_code=True,
fp16=fp16,
float16=fp16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
2 changes: 1 addition & 1 deletion examples/text_generation/phi_generation.py
@@ -39,7 +39,7 @@ def load_model(
os.path.dirname(os.path.realpath(__file__))

model = mlx_model_class(config)
model.from_pretrained(model_name, fp16=fp16)
model.from_pretrained(model_name, float16=fp16)

tokenizer = AutoTokenizer.from_pretrained(model_name)

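Both generation examples above reflect the same rename: `from_pretrained` now takes `float16` instead of `fp16`. A hedged usage sketch of the new keyword (the Phi class name is assumed here, not verified against the package):

from transformers import AutoConfig
from mlx_transformers.models import PhiForCausalLM  # assumed export name

config = AutoConfig.from_pretrained("microsoft/phi-2")
model = PhiForCausalLM(config)
model.from_pretrained("microsoft/phi-2", float16=True)  # weights cast to mx.float16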
107 changes: 70 additions & 37 deletions src/mlx_transformers/models/base.py
@@ -1,65 +1,98 @@
import importlib
import os
import logging
from typing import Callable, Optional
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from huggingface_hub import snapshot_download
from huggingface_hub import HfFileSystem

import mlx.core as mx
from mlx.utils import tree_unflatten

CONFIG_FILE = "config.json"
WEIGHTS_FILE_NAME = "model.safetensors"

logger = logging.getLogger(__name__)
fs = HfFileSystem()

HF_TOKEN = os.getenv("HF_TOKEN", None)


@dataclass
class ModelLoadingConfig:
"""Configuration for model loading parameters."""

model_name_or_path: str
cache_dir: Optional[str] = None
revision: str = "main"
float16: bool = False
trust_remote_code: bool = False
max_workers: int = 4


class MlxPretrainedMixin:
"""Mixin class for loading pretrained models in MLX format."""

def from_pretrained(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
revision: Optional[str] = "main",
revision: str = "main",
float16: bool = False,
huggingface_model_architecture: Optional[Callable] = None,
trust_remote_code: bool = False,
max_workers: int = 4,
):
if huggingface_model_architecture:
architecture = huggingface_model_architecture
elif hasattr(self.config, "architectures"):
architecture = self.config.architectures[0]
else:
raise ValueError("No architecture found for loading this model")
) -> "MlxPretrainedMixin":
"""
Load a pretrained model from HuggingFace Hub or local path.
transformers_module = importlib.import_module("transformers")
_class = getattr(transformers_module, architecture, None)
Args:
model_name_or_path: HuggingFace model name or path to local model
cache_dir: Directory to store downloaded models
revision: Git revision to use when downloading
float16: Whether to convert model to float16
huggingface_model_architecture: Custom model architecture class
trust_remote_code: Whether to trust remote code when loading
max_workers: Number of worker threads for tensor conversion
if not _class:
raise ValueError(f"Could not find the class for {architecture}")
Returns:
Self with loaded model weights
"""
config = ModelLoadingConfig(
model_name_or_path=model_name_or_path,
cache_dir=cache_dir,
revision=revision,
float16=float16,
trust_remote_code=trust_remote_code,
max_workers=max_workers,
)

dtype = mx.float16 if float16 else mx.float32
logger.info(f"Loading model using the following configuration {config}")

logger.info(f"Loading model from {model_name_or_path}")
model = _class.from_pretrained(
model_name_or_path, trust_remote_code=trust_remote_code
safe_tensor_files = fs.glob(
f"{config.model_name_or_path}/*.safetensors",
**{"revision": config.revision},
)
safe_tensor_files = [f.split("/")[-1] for f in safe_tensor_files]

# # save the tensors
logger.info("Converting model tensors to Mx arrays")
import concurrent.futures
if not safe_tensor_files:
raise ValueError("No safe tensor files found for this model")

def convert_tensor(key, tensor, dtype):
return key, mx.array(tensor.numpy()).astype(dtype)
download_path = snapshot_download(
repo_id=config.model_name_or_path,
allow_patterns="*.safetensors",
max_workers=config.max_workers,
revision=config.revision,
token=HF_TOKEN,
)
dtype = mx.float16 if config.float16 else mx.float32

tensors = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for key, tensor in model.state_dict().items():
future = executor.submit(convert_tensor, key, tensor, dtype)
futures.append(future)
for file in safe_tensor_files:
file_path = Path(download_path) / file
with file_path.open("rb") as f:
state_dict = mx.load(f)

for future in concurrent.futures.as_completed(futures):
key, converted_tensor = future.result()
tensors[key] = converted_tensor
tensors.update(state_dict)

tensors = [(key, tensor) for key, tensor in tensors.items()]
tensors = {k: v.astype(dtype) for k, v in tensors.items()}

self.update(tree_unflatten(tensors))
# Update model weights
self.update(tree_unflatten(list(tensors.items())))
return self
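In short, `from_pretrained` no longer instantiates the PyTorch model and converts its `state_dict`; it globs the repo's `*.safetensors` files with `HfFileSystem`, downloads them via `snapshot_download`, loads each shard with `mx.load`, casts to the requested dtype, and applies `tree_unflatten`. A minimal usage sketch mirroring the example and tests in this commit:

from transformers import AutoConfig
from mlx_transformers.models import BertModel as MlxBertModel

config = AutoConfig.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = MlxBertModel(config)
model.from_pretrained(
    "sentence-transformers/all-MiniLM-L6-v2",
    revision="main",   # any Hub revision that contains *.safetensors weights
    float16=False,     # set True to cast weights to mx.float16
    max_workers=4,     # parallel download workers passed to snapshot_download
)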
6 changes: 3 additions & 3 deletions src/mlx_transformers/models/bert.py
@@ -166,9 +166,9 @@ def __call__(
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[
1:
] # add attentions if we output them
outputs = (
attention_output,
) + self_outputs[1:] # add attentions if we output them
return outputs


4 changes: 1 addition & 3 deletions src/mlx_transformers/models/cache.py
@@ -74,9 +74,7 @@ class DynamicCache(Cache):
def __init__(self) -> None:
self.key_cache: List[mx.array] = []
self.value_cache: List[mx.array] = []
self._seen_tokens = (
0 # Used in `generate` to keep tally of how many tokens the cache has seen
)
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

def __getitem__(self, layer_idx: int) -> List[Tuple[mx.array]]:
"""
6 changes: 3 additions & 3 deletions src/mlx_transformers/models/roberta.py
@@ -211,9 +211,9 @@ def __call__(
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[
1:
] # add attentions if we output them
outputs = (
attention_output,
) + self_outputs[1:] # add attentions if we output them
return outputs


8 changes: 5 additions & 3 deletions src/mlx_transformers/models/xlm_roberta.py
@@ -210,9 +210,9 @@ def __call__(
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[
1:
] # add attentions if we output them
outputs = (
attention_output,
) + self_outputs[1:] # add attentions if we output them
return outputs


@@ -425,6 +425,7 @@ def __call__(
position_ids=position_ids,
token_type_ids=token_type_ids,
)

encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
@@ -507,6 +508,7 @@ def __call__(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs.last_hidden_state
logits = self.classifier(sequence_output)

38 changes: 28 additions & 10 deletions tests/test_bert.py
@@ -6,14 +6,14 @@
AutoConfig,
AutoTokenizer,
BertConfig,
BertForMaskedLM,
BertModel,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
BertTokenizer,
)

from src.mlx_transformers.models import BertForMaskedLM as MlxBertForMaskedLM
from src.mlx_transformers.models import BertModel as MlxBertModel
from src.mlx_transformers.models import (
BertForQuestionAnswering as MlxBertForQuestionAnswering,
)
@@ -33,14 +33,13 @@ def load_hgf_model(model_name: str, hgf_model_class):
class TestMlxBert(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.model_name = "bert-base-uncased"
cls.model_name = "sentence-transformers/all-MiniLM-L6-v2"
cls.config = BertConfig.from_pretrained(cls.model_name)
cls.tokenizer = BertTokenizer.from_pretrained(cls.model_name)
cls.hgf_model_class = BertForMaskedLM
cls.hgf_model_class = BertModel

# cls.model_class = MlxBertForMaskedLM
cls.model = MlxBertForMaskedLM(cls.config)
cls.model.from_pretrained(cls.model_name)
cls.model = MlxBertModel(cls.config)
cls.model.from_pretrained(cls.model_name, revision="main")

cls.input_text = "Hello, my dog is cute"

@@ -51,14 +50,14 @@ def test_model_output_hgf(self):

inputs_mlx = {key: mx.array(v) for key, v in inputs_mlx.items()}
outputs_mlx = self.model(**inputs_mlx)
outputs_mlx = np.array(outputs_mlx.logits)
outputs_mlx = np.array(outputs_mlx.last_hidden_state)

inputs_hgf = self.tokenizer(
self.input_text, return_tensors="pt", padding=True, truncation=True
)
hgf_model = load_hgf_model(self.model_name, self.hgf_model_class)
outputs_hgf = hgf_model(**inputs_hgf)
outputs_hgf = outputs_hgf.logits.detach().numpy()
outputs_hgf = outputs_hgf.last_hidden_state.detach().numpy()

self.assertTrue(np.allclose(outputs_mlx, outputs_hgf, atol=1e-4))

@@ -72,7 +71,7 @@ def setUpClass(cls) -> None:

cls.hgf_model_class = BertForSequenceClassification
cls.model = MlxBertForSequenceClassification(cls.config)
cls.model.from_pretrained(cls.model_name)
cls.model.from_pretrained(cls.model_name, revision="refs/pr/1")

cls.input_text = "Hello, my dog is cute"

@@ -91,6 +90,7 @@ def test_model_output_hgf(self):
)

inputs_mlx = {key: mx.array(v) for key, v in inputs_mlx.items()}

outputs_mlx = self.model(**inputs_mlx)
outputs_mlx = np.array(outputs_mlx.logits)
predicted_class_id = outputs_mlx.argmax().item()
Expand All @@ -100,6 +100,7 @@ def test_model_output_hgf(self):
self.input_text, return_tensors="pt", padding=True, truncation=True
)
hgf_model = load_hgf_model(self.model_name, self.hgf_model_class)

outputs_hgf = hgf_model(**inputs_hgf)
outputs_hgf = outputs_hgf.logits

@@ -171,6 +172,9 @@ def test_model_output_hgf(self):
]

self.assertEqual(mlx_predicted_tokens_classes, hgf_predicted_tokens_classes)
self.assertTrue(
np.allclose(np.array(outputs_mlx), outputs_hgf.detach().numpy(), atol=1e-4)
)


class TestMlxBertForQuestionAnswering(unittest.TestCase):
@@ -232,6 +236,20 @@ def test_model_output_hgf(self):
)

self.assertEqual(mlx_answer, hgf_answer)
self.assertTrue(
np.allclose(
np.array(outputs_mlx.start_logits),
outputs_hgf.start_logits.detach().numpy(),
atol=1e-4,
)
)
self.assertTrue(
np.allclose(
np.array(outputs_mlx.end_logits),
outputs_hgf.end_logits.detach().numpy(),
atol=1e-4,
)
)


if __name__ == "__main__":
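The new assertions in these tests all follow the same pattern: convert the MLX output to NumPy, detach the PyTorch output, and compare elementwise within an absolute tolerance of 1e-4. As a standalone sketch (the helper name is illustrative, not part of the test suite):

import numpy as np
import mlx.core as mx

def assert_outputs_close(mlx_out: mx.array, torch_out, atol: float = 1e-4) -> None:
    # Elementwise comparison between an MLX array and a PyTorch tensor.
    np.testing.assert_allclose(np.array(mlx_out), torch_out.detach().numpy(), atol=atol)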