-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a mixing for loading models directly from Huggingface (#6)
* add a mixin for loading directly from HGF * refactor
- Loading branch information
1 parent
3751ecc
commit bcbb2e4
Showing
11 changed files
with
299 additions
and
196 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,15 @@ | ||
from .bert import ( | ||
BertModel, | ||
BertForMaskedLM, | ||
BertForQuestionAnswering, | ||
BertForSequenceClassification, | ||
BertForTokenClassification, | ||
BertForQuestionAnswering | ||
BertModel, | ||
) | ||
from .llama import LlamaForCausalLM, LlamaModel | ||
from .roberta import ( | ||
RobertaModel, | ||
RobertaForQuestionAnswering, | ||
RobertaForSequenceClassification, | ||
RobertaForTokenClassification, | ||
RobertaForQuestionAnswering | ||
) | ||
RobertaModel, | ||
) | ||
from .xlm_roberta import XLMRobertaModel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import importlib | ||
import os | ||
from typing import Optional | ||
|
||
import mlx.core as mx | ||
from huggingface_hub import HfFileSystem, hf_hub_download | ||
from mlx.utils import tree_unflatten | ||
from safetensors.numpy import load_file | ||
from transformers import AutoConfig | ||
from transformers.utils.import_utils import is_safetensors_available | ||
|
||
CONFIG_FILE = "config.json" | ||
WEIGHTS_FILE_NAME = "model.safetensors" | ||
|
||
|
||
def _sanitize_keys(key): | ||
keys = key.split(".") | ||
return ".".join(keys[1:]) | ||
|
||
|
||
class MlxPretrainedMixin: | ||
|
||
def from_pretrained( | ||
self, | ||
model_name_or_path: str, | ||
cache_dir: Optional[str] = None, | ||
revision: Optional[str] = "main", | ||
float16: bool = False, | ||
): | ||
|
||
architecture = self.config.architectures[0] | ||
transformers_module = importlib.import_module("transformers") | ||
|
||
_class = getattr(transformers_module, architecture, None) | ||
|
||
model = _class.from_pretrained(model_name_or_path) | ||
# # save the tensors | ||
tensors = { | ||
key: mx.array(tensor.numpy()) for key, tensor in model.state_dict().items() | ||
} | ||
|
||
tensors = [(key, tensor) for key, tensor in tensors.items()] | ||
|
||
self.update(tree_unflatten(tensors)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.