Add a mixin for loading models directly from Huggingface (#6)
* add a mixin for loading directly from HGF

* refactor
ToluClassics authored Apr 22, 2024
1 parent 3751ecc commit bcbb2e4
Showing 11 changed files with 299 additions and 196 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/python-app.yml
@@ -33,6 +33,7 @@ jobs:
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest and unittest
run: |
python -m unittest
# Commenting out until I figure out how to install mlx in the CI
# - name: Test with pytest and unittest
# run: |
# python -m unittest
25 changes: 5 additions & 20 deletions README.md
@@ -20,32 +20,17 @@ pip install mlx-transformers

A list of the available models can be found in the `mlx_transformers.models` module; they are also listed in the [section below](#available-model-architectures). The following example demonstrates how to load a model and use it for inference:

- First, you need to download and convert the model checkpoint to MLX format. To do this from Hugging Face:

```python

from transformers import BertModel
from mlx_transformers.models.utils import convert

model_name_or_path = "bert-base-uncased"
mlx_checkpoint = "bert-base-uncased.npz"

convert(model_name_or_path, mlx_checkpoint, BertModel)
```
This will download the model checkpoint from Hugging Face and convert it to MLX format.

- Now you can load the model using MLX transformers in a few lines of code
- You can load the model using MLX transformers in a few lines of code

```python
from transformers import BertConfig, BertTokenizer
from mlx_transformers.models import BertModel as MLXBertModel
from mlx_transformers.models import BertForMaskedLM as MLXBertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
config = BertConfig.from_pretrained("bert-base-uncased")
model = MLXBertModel(config)

model.load_weights("bert-base-uncased.npz", strict=True)
model = MLXBertForMaskedLM(config)
model.from_pretrained("bert-base-uncased")

sample_input = "Hello, world!"
inputs = tokenizer(sample_input, return_tensors="np")
@@ -76,7 +61,7 @@ tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
config = AutoConfig.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

model = MLXBertModel(config)
model.load_weights("all-MiniLM-L6-v2.npz", strict=True)
model.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

inputs = tokenizer(sentences, return_tensors="np", padding=True, truncation=True)
outputs = model(**inputs)
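Putting the README changes together, the new end-to-end flow looks roughly like this. This is a sketch based on the snippets above, not text from the diff; the `[MASK]` example sentence and the decoding step at the end are illustrative additions.

```python
from transformers import BertConfig, BertTokenizer

from mlx_transformers.models import BertForMaskedLM as MLXBertForMaskedLM

config = BertConfig.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Build the MLX module, then pull the weights straight from the Hugging Face Hub
# via the new mixin method (no intermediate .npz conversion step).
model = MLXBertForMaskedLM(config)
model.from_pretrained("bert-base-uncased")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="np")
outputs = model(**inputs)

# outputs.logits has shape (batch, sequence_length, vocab_size); take the
# most likely token at every position (illustrative, not from the README).
predicted_ids = outputs.logits.argmax(axis=-1)
print(tokenizer.decode(predicted_ids[0].tolist()))
```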
11 changes: 6 additions & 5 deletions src/mlx_transformers/models/__init__.py
@@ -1,14 +1,15 @@
from .bert import (
    BertForMaskedLM,
    BertForQuestionAnswering,
    BertForSequenceClassification,
    BertForTokenClassification,
    BertModel,
)
from .llama import LlamaForCausalLM, LlamaModel
from .roberta import (
    RobertaForQuestionAnswering,
    RobertaForSequenceClassification,
    RobertaForTokenClassification,
    RobertaModel,
)
from .xlm_roberta import XLMRobertaModel
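As a quick sanity check that the re-sorted exports still resolve from the package root, one might run a snippet like the following (hypothetical, not part of the diff):

```python
from mlx_transformers.models import (
    BertForMaskedLM,
    LlamaForCausalLM,
    RobertaForQuestionAnswering,
    XLMRobertaModel,
)

# each name should resolve to a class defined in the corresponding submodule
print(BertForMaskedLM.__module__)  # e.g. "mlx_transformers.models.bert"
```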
44 changes: 44 additions & 0 deletions src/mlx_transformers/models/base.py
@@ -0,0 +1,44 @@
import importlib
import os
from typing import Optional

import mlx.core as mx
from huggingface_hub import HfFileSystem, hf_hub_download
from mlx.utils import tree_unflatten
from safetensors.numpy import load_file
from transformers import AutoConfig
from transformers.utils.import_utils import is_safetensors_available

CONFIG_FILE = "config.json"
WEIGHTS_FILE_NAME = "model.safetensors"


def _sanitize_keys(key):
keys = key.split(".")
return ".".join(keys[1:])


class MlxPretrainedMixin:

def from_pretrained(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
revision: Optional[str] = "main",
float16: bool = False,
):

architecture = self.config.architectures[0]
transformers_module = importlib.import_module("transformers")

_class = getattr(transformers_module, architecture, None)

model = _class.from_pretrained(model_name_or_path)
# convert the torch state_dict tensors to MLX arrays
tensors = {
key: mx.array(tensor.numpy()) for key, tensor in model.state_dict().items()
}

tensors = [(key, tensor) for key, tensor in tensors.items()]

self.update(tree_unflatten(tensors))
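For orientation, here is a minimal, self-contained sketch (not repo code; `TinyClassifier` and the weight names are made up) of the mechanism `from_pretrained` relies on: a flat, dot-separated `state_dict` is converted to `mx.array`s, unflattened into a nested parameter tree, and pushed into the MLX module with `update`.

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten


class TinyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 2)

    def __call__(self, x):
        return self.dense(x)


model = TinyClassifier()

# Stand-in for `torch_model.state_dict()`: keys use dotted module paths.
flat_weights = {
    "dense.weight": mx.zeros((2, 4)),
    "dense.bias": mx.zeros((2,)),
}

# Same pattern as the mixin: list of (key, array) pairs -> nested tree -> update.
model.update(tree_unflatten(list(flat_weights.items())))
print(model(mx.ones((1, 4))).shape)  # (1, 2)
```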
123 changes: 107 additions & 16 deletions src/mlx_transformers/models/bert.py
@@ -8,6 +8,7 @@
import mlx.nn as nn
from transformers import BertConfig

from .base import MlxPretrainedMixin
from .modelling_outputs import *
from .utils import ACT2FN, get_extended_attention_mask

@@ -319,14 +320,14 @@ def __init__(self, config):
self.transform_act_fn = ACT2FN[config.hidden_act]
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

def __call__(self, hidden_states: mx.tanh) -> mx.tanh:
def __call__(self, hidden_states: mx.array) -> mx.array:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states


class BertModel(nn.Module):
class BertModel(nn.Module, MlxPretrainedMixin):

def __init__(self, config, add_pooling_layer=True):
super().__init__()
@@ -423,7 +424,94 @@ def __call__(
)


class BertForSequenceClassification(nn.Module):
class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = BertPredictionHeadTransform(config)

# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

self.bias = mx.zeros(config.vocab_size)

# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def __call__(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states


class BertOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)

def __call__(self, sequence_output: mx.array) -> mx.array:
prediction_scores = self.predictions(sequence_output)
return prediction_scores


class BertForMaskedLM(nn.Module, MlxPretrainedMixin):
def __init__(self, config):
super().__init__()
self.config = config
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)

def __call__(
self,
input_ids: Optional[mx.array] = None,
attention_mask: Optional[mx.array] = None,
token_type_ids: Optional[mx.array] = None,
position_ids: Optional[mx.array] = None,
use_cache: Optional[bool] = None,
labels: Optional[mx.array] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):

return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs.last_hidden_state
prediction_scores = self.cls(sequence_output)

masked_lm_loss = None
if labels is not None:
loss_fct = nn.losses.cross_entropy # -100 index = padding token
masked_lm_loss = loss_fct(
prediction_scores.reshape(-1, self.config.vocab_size), labels.reshape(-1)
)

if not return_dict:
output = (prediction_scores,) + outputs[2:]
return (
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
)

return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


class BertForSequenceClassification(nn.Module, MlxPretrainedMixin):
def __init__(self, config):
super().__init__()
self.num_labels = config.num_labels
@@ -437,6 +525,7 @@ def __init__(self, config):
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.train = config.train if hasattr(config, "train") else False

def __call__(
self,
@@ -471,7 +560,7 @@ def __call__(

pooled_output = outputs.pooler_output

if self.config.train:
if self.train:
pooled_output = self.dropout(pooled_output)

logits = self.classifier(pooled_output)
@@ -512,20 +601,21 @@ def __call__(
)


class BertForTokenClassification(nn.Module):
class BertForTokenClassification(nn.Module, MlxPretrainedMixin):
def __init__(self, config):
super().__init__()
self.num_labels = config.num_labels

self.config = config
self.bert = BertModel(config, add_pooling_layer=False)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
config.classifier_dropout
if config.classifier_dropout is not None
else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)


def __call__(
self,
input_ids: Optional[mx.array] = None,
@@ -541,8 +631,9 @@ def __call__(
labels (`mx.array` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)

outputs = self.bert(
input_ids,
@@ -561,7 +652,7 @@

loss = None
if labels is not None:
loss_fct = nn.losses.cross_entropy()
loss_fct = nn.losses.cross_entropy
loss = loss_fct(logits.reshape(-1, self.num_labels), labels.reshape(-1))

if not return_dict:
@@ -573,10 +664,10 @@ def __call__(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
)


class BertForQuestionAnswering(nn.Module):
class BertForQuestionAnswering(nn.Module, MlxPretrainedMixin):
def __init__(self, config):
super().__init__()
self.num_labels = config.num_labels
@@ -585,7 +676,6 @@ def __init__(self, config):
self.bert = BertModel(config, add_pooling_layer=False)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)


def __call__(
self,
input_ids: Optional[mx.array] = None,
@@ -608,7 +698,9 @@ def __call__(
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)

outputs = self.bert(
input_ids,
@@ -640,7 +732,7 @@ def __call__(
start_positions = mx.clip(start_positions, 0, ignored_index)
end_positions = mx.clip(end_positions, 0, ignored_index)

loss_fct = nn.losses.cross_entropy(ignore_index=ignored_index)
loss_fct = nn.losses.cross_entropy
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
@@ -656,4 +748,3 @@ def __call__(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
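Several of the heads above flatten their logits and labels before calling `nn.losses.cross_entropy` (which this commit references as a function rather than instantiating). A minimal sketch of that call pattern, with made-up shapes and an explicit `reduction="mean"` as an assumption (the code above keeps MLX's default `reduction="none"`):

```python
import mlx.core as mx
import mlx.nn as nn

vocab_size = 7
logits = mx.random.normal((2, 5, vocab_size))      # (batch, seq_len, vocab)
labels = mx.random.randint(0, vocab_size, (2, 5))  # (batch, seq_len)

loss_fct = nn.losses.cross_entropy
loss = loss_fct(
    logits.reshape(-1, vocab_size),  # flatten to (batch * seq_len, vocab)
    labels.reshape(-1),              # flatten to (batch * seq_len,)
    reduction="mean",
)
print(loss.item())
```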
