Commit
cfc53f9

* improved loading
* run precommit and ruff formatting

1 parent e2fdbce

Showing 12 changed files with 182 additions and 195 deletions.
@@ -0,0 +1,34 @@
import mlx.core as mx
import numpy as np

from transformers import AutoConfig, AutoTokenizer
from mlx_transformers.models import BertModel as MLXBertModel


def _mean_pooling(last_hidden_state: mx.array, attention_mask: mx.array):
    # Average the token embeddings, using the attention mask so that
    # padding tokens do not contribute to the sentence embedding.
    token_embeddings = last_hidden_state
    input_mask_expanded = mx.expand_dims(attention_mask, -1)
    input_mask_expanded = mx.broadcast_to(
        input_mask_expanded, token_embeddings.shape
    ).astype(mx.float32)
    sum_embeddings = mx.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = mx.clip(input_mask_expanded.sum(axis=1), 1e-9, None)
    return sum_embeddings / sum_mask


sentences = ["This is an example sentence", "Each sentence is converted"]

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
config = AutoConfig.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

model = MLXBertModel(config)
model.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

inputs = tokenizer(sentences, return_tensors="np", padding=True, truncation=True)
inputs = {key: mx.array(v) for key, v in inputs.items()}

outputs = model(**inputs)

sentence_embeddings = _mean_pooling(outputs.last_hidden_state, inputs["attention_mask"])

print(sentence_embeddings)
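
The pooled vectors above are not length-normalized. The all-MiniLM-L6-v2 model card applies L2 normalization before computing cosine similarities; a minimal sketch of that optional follow-up step, not part of this commit, reusing the variables from the file above:

# Optional, not part of this commit: L2-normalize the pooled embeddings
# so that dot products between sentences behave as cosine similarities.
norms = mx.sqrt(mx.sum(sentence_embeddings * sentence_embeddings, axis=1, keepdims=True))
sentence_embeddings = sentence_embeddings / mx.maximum(norms, 1e-9)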
@@ -1,65 +1,98 @@
-import importlib
 import os
 import logging
-from typing import Callable, Optional
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+from huggingface_hub import snapshot_download
+from huggingface_hub import HfFileSystem

 import mlx.core as mx
 from mlx.utils import tree_unflatten

 CONFIG_FILE = "config.json"
 WEIGHTS_FILE_NAME = "model.safetensors"

 logger = logging.getLogger(__name__)
+fs = HfFileSystem()
+
+HF_TOKEN = os.getenv("HF_TOKEN", None)
+
+
+@dataclass
+class ModelLoadingConfig:
+    """Configuration for model loading parameters."""
+
+    model_name_or_path: str
+    cache_dir: Optional[str] = None
+    revision: str = "main"
+    float16: bool = False
+    trust_remote_code: bool = False
+    max_workers: int = 4


 class MlxPretrainedMixin:
+    """Mixin class for loading pretrained models in MLX format."""
+
     def from_pretrained(
         self,
         model_name_or_path: str,
         cache_dir: Optional[str] = None,
-        revision: Optional[str] = "main",
+        revision: str = "main",
         float16: bool = False,
-        huggingface_model_architecture: Optional[Callable] = None,
         trust_remote_code: bool = False,
         max_workers: int = 4,
-    ):
-        if huggingface_model_architecture:
-            architecture = huggingface_model_architecture
-        elif hasattr(self.config, "architectures"):
-            architecture = self.config.architectures[0]
-        else:
-            raise ValueError("No architecture found for loading this model")
-
-        transformers_module = importlib.import_module("transformers")
-        _class = getattr(transformers_module, architecture, None)
-
-        if not _class:
-            raise ValueError(f"Could not find the class for {architecture}")
+    ) -> "MlxPretrainedMixin":
+        """
+        Load a pretrained model from HuggingFace Hub or local path.
+
+        Args:
+            model_name_or_path: HuggingFace model name or path to local model
+            cache_dir: Directory to store downloaded models
+            revision: Git revision to use when downloading
+            float16: Whether to convert model to float16
+            huggingface_model_architecture: Custom model architecture class
+            trust_remote_code: Whether to trust remote code when loading
+            max_workers: Number of worker threads for tensor conversion
+
+        Returns:
+            Self with loaded model weights
+        """
+        config = ModelLoadingConfig(
+            model_name_or_path=model_name_or_path,
+            cache_dir=cache_dir,
+            revision=revision,
+            float16=float16,
+            trust_remote_code=trust_remote_code,
+            max_workers=max_workers,
+        )

-        dtype = mx.float16 if float16 else mx.float32
+        logger.info(f"Loading model using the following configuration {config}")

-        logger.info(f"Loading model from {model_name_or_path}")
-        model = _class.from_pretrained(
-            model_name_or_path, trust_remote_code=trust_remote_code
+        safe_tensor_files = fs.glob(
+            f"{config.model_name_or_path}/*.safetensors",
+            **{"revision": config.revision},
         )
+        safe_tensor_files = [f.split("/")[-1] for f in safe_tensor_files]

-        # # save the tensors
-        logger.info("Converting model tensors to Mx arrays")
-        import concurrent.futures
+        if not safe_tensor_files:
+            raise ValueError("No safe tensor files found for this model")

-        def convert_tensor(key, tensor, dtype):
-            return key, mx.array(tensor.numpy()).astype(dtype)
+        download_path = snapshot_download(
+            repo_id=config.model_name_or_path,
+            allow_patterns="*.safetensors",
+            max_workers=config.max_workers,
+            revision=config.revision,
+            token=HF_TOKEN,
+        )
+        dtype = mx.float16 if config.float16 else mx.float32

         tensors = {}
-        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
-            futures = []
-            for key, tensor in model.state_dict().items():
-                future = executor.submit(convert_tensor, key, tensor, dtype)
-                futures.append(future)
+        for file in safe_tensor_files:
+            file_path = Path(download_path) / file
+            with file_path.open("rb") as f:
+                state_dict = mx.load(f)

-            for future in concurrent.futures.as_completed(futures):
-                key, converted_tensor = future.result()
-                tensors[key] = converted_tensor
+            tensors.update(state_dict)

-        tensors = [(key, tensor) for key, tensor in tensors.items()]
+        tensors = {k: v.astype(dtype) for k, v in tensors.items()}

-        self.update(tree_unflatten(tensors))
+        # Update model weights
+        self.update(tree_unflatten(list(tensors.items())))
+        return self
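
For context, the reworked loader is invoked the same way as before; a minimal usage sketch, with the model name reused from the example file above and the float16 flag taken from the new signature:

from transformers import AutoConfig
from mlx_transformers.models import BertModel as MLXBertModel

# Build the MLX model from its HF config, then fetch safetensors weights
# from the Hub via the new from_pretrained path (float16 is optional).
config = AutoConfig.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = MLXBertModel(config)
model.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", float16=True)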