Commit
Merge pull request #2 from ToluClassics/sequence_classification
Add Bert Classification Model
ToluClassics authored Apr 19, 2024
2 parents 68e0523 + 1aaef6a commit 5cb5edb
Showing 4 changed files with 158 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/mlx_transformers/models/__init__.py
@@ -1,3 +1,3 @@
-from .bert import BertModel
+from .bert import BertForSequenceClassification, BertModel
from .roberta import RobertaModel
from .xlm_roberta import XLMRobertaModel
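
With this one-line change the classification head is exported alongside the base model. A minimal import sketch, assuming an installed package importable as `mlx_transformers` (the tests in this commit use the `src.mlx_transformers` path instead):

# A minimal sketch, assuming an importable `mlx_transformers` package.
from mlx_transformers.models import BertForSequenceClassification, BertModel

# Both symbols now resolve from the package's public namespace.
print(BertForSequenceClassification.__name__, BertModel.__name__)
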
84 changes: 83 additions & 1 deletion src/mlx_transformers/models/bert.py
@@ -425,4 +425,86 @@ def __call__(

class BertForSequenceClassification(nn.Module):
    def __init__(self, config):
-        pass
+        super().__init__()
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.bert = BertModel(config)
+        classifier_dropout = (
+            config.classifier_dropout
+            if config.classifier_dropout is not None
+            else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+    def __call__(
+        self,
+        input_ids: Optional[mx.array] = None,
+        attention_mask: Optional[mx.array] = None,
+        token_type_ids: Optional[mx.array] = None,
+        position_ids: Optional[mx.array] = None,
+        labels: Optional[mx.array] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[mx.array], SequenceClassifierOutput]:
+        r"""
+        labels (`array` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss.
+            Indices should be in `[0, ..., config.num_labels - 1]`. If
+            `config.num_labels == 1`, a regression loss is computed (mean-squared
+            error); if `config.num_labels > 1`, a classification loss is computed
+            (cross-entropy).
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooled_output = outputs.pooler_output
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (
+                    # MLX has no `mx.long`/`mx.int` aliases; check the integer dtypes
+                    labels.dtype == mx.int64 or labels.dtype == mx.int32
+                ):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                if self.num_labels == 1:
+                    loss = nn.losses.mse_loss(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = nn.losses.mse_loss(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                # mx.array has no `.view`; use `.reshape` instead
+                loss = nn.losses.cross_entropy(
+                    logits.reshape(-1, self.num_labels), labels.reshape(-1)
+                )
+            elif self.config.problem_type == "multi_label_classification":
+                loss = nn.losses.binary_cross_entropy(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
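
For context, a minimal end-to-end sketch of the new head in use. The checkpoint name and weight-loading step are illustrative (they mirror the test further down), not part of this diff; without loading converted weights the logits come from randomly initialized parameters:

import mlx.core as mx
from transformers import BertConfig, BertTokenizer

from mlx_transformers.models import BertForSequenceClassification

model_name = "bert-base-uncased"  # illustrative checkpoint
config = BertConfig.from_pretrained(model_name)  # num_labels defaults to 2
model = BertForSequenceClassification(config)
# For meaningful predictions, convert and load pretrained weights first, e.g.:
# model.load_weights("tests/model_checkpoints/bert-base-uncased.npz", strict=True)

tokenizer = BertTokenizer.from_pretrained(model_name)
inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
inputs = {k: mx.array(v) for k, v in inputs.items()}

outputs = model(**inputs)
print(outputs.logits.shape)  # (1, config.num_labels)
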
38 changes: 38 additions & 0 deletions src/mlx_transformers/models/modelling_outputs.py
@@ -21,3 +21,41 @@ class BaseModelOutputWithPoolingAndCrossAttentions:
    past_key_values: Optional[Tuple[Tuple[mx.array]]] = None
    attentions: Optional[Tuple[mx.array, ...]] = None
    cross_attentions: Optional[Tuple[mx.array, ...]] = None
+
+
+@dataclass
+class SequenceClassifierOutputWithPast:
+    loss: Optional[mx.array] = None
+    logits: mx.array = None
+    past_key_values: Optional[Tuple[Tuple[mx.array]]] = None
+    hidden_states: Optional[Tuple[mx.array, ...]] = None
+    attentions: Optional[Tuple[mx.array, ...]] = None
+
+
+@dataclass
+class Seq2SeqSequenceClassifierOutput:
+    loss: Optional[mx.array] = None
+    logits: mx.array = None
+    past_key_values: Optional[Tuple[Tuple[mx.array]]] = None
+    decoder_hidden_states: Optional[Tuple[mx.array, ...]] = None
+    decoder_attentions: Optional[Tuple[mx.array, ...]] = None
+    cross_attentions: Optional[Tuple[mx.array, ...]] = None
+    encoder_last_hidden_state: Optional[mx.array] = None
+    encoder_hidden_states: Optional[Tuple[mx.array, ...]] = None
+    encoder_attentions: Optional[Tuple[mx.array, ...]] = None
+
+
+@dataclass
+class TokenClassifierOutput:
+    loss: Optional[mx.array] = None
+    logits: mx.array = None
+    hidden_states: Optional[Tuple[mx.array, ...]] = None
+    attentions: Optional[Tuple[mx.array, ...]] = None
+
+
+@dataclass
+class SequenceClassifierOutput:
+    loss: Optional[mx.array] = None
+    logits: mx.array = None
+    hidden_states: Optional[Tuple[mx.array, ...]] = None
+    attentions: Optional[Tuple[mx.array, ...]] = None
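
These containers mirror the Hugging Face modeling outputs, holding `mx.array` fields instead of torch tensors. A small consumption sketch for the `SequenceClassifierOutput` returned by the new head, assuming the module is importable as `mlx_transformers.models.modelling_outputs`:

import mlx.core as mx

from mlx_transformers.models.modelling_outputs import SequenceClassifierOutput

out = SequenceClassifierOutput(loss=None, logits=mx.array([[0.1, 2.3]]))
pred = mx.argmax(out.logits, axis=-1)  # class index per example
print(pred.item())  # 1
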
44 changes: 36 additions & 8 deletions tests/test_bert.py
@@ -4,26 +4,34 @@

import mlx.core as mx
import numpy as np
-from transformers import BertConfig, BertModel, BertTokenizer
+from transformers import (
+    BertConfig,
+    BertForSequenceClassification,
+    BertModel,
+    BertTokenizer,
+)

+from src.mlx_transformers.models import (
+    BertForSequenceClassification as MlxBertForSequenceClassification,
+)
from src.mlx_transformers.models import BertModel as MlxBertModel
from src.mlx_transformers.models.utils import convert


-def load_model(model_name: str) -> MlxBertModel:
+def load_model(
+    model_name: str, config, hgf_model_class, mlx_model_class
+) -> MlxBertModel:
    current_directory = os.path.dirname(os.path.realpath(__file__))
    weights_path = os.path.join(
        current_directory, "model_checkpoints", model_name.replace("/", "-") + ".npz"
    )

    if not os.path.exists(weights_path):
-        convert(model_name, weights_path, BertModel)
+        convert(model_name, weights_path, hgf_model_class)

-    config = BertConfig.from_pretrained(model_name)
-    model = MlxBertModel(config)
+    model = mlx_model_class(config)

    model.load_weights(weights_path, strict=True)

    return model


@@ -37,7 +37,9 @@ class TestMlxRoberta(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        cls.model_name = "bert-base-uncased"
-        cls.model = load_model(cls.model_name)
+        config = BertConfig.from_pretrained(cls.model_name)
+
+        cls.model = load_model(cls.model_name, config, BertModel, MlxBertModel)
        cls.tokenizer = BertTokenizer.from_pretrained(cls.model_name)
        cls.input_text = "Hello, my dog is cute"

@@ -68,6 +78,24 @@ def test_model_output_hgf(self):

        self.assertTrue(np.allclose(outputs_mlx, outputs_hgf, atol=1e-4))

+    def test_sequence_classification(self):
+        inputs = self.tokenizer(
+            self.input_text, return_tensors="np", padding=True, truncation=True
+        )
+
+        inputs = {key: mx.array(v) for key, v in inputs.items()}
+        model_name = "textattack/bert-base-uncased-imdb"
+        config = BertConfig.from_pretrained(model_name)
+        model = load_model(
+            model_name,
+            config,
+            BertForSequenceClassification,
+            MlxBertForSequenceClassification,
+        )
+        outputs = model(**inputs)
+
+        self.assertIsInstance(outputs.logits, mx.array)
+

if __name__ == "__main__":
    unittest.main()
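
The new test pulls `textattack/bert-base-uncased-imdb` through `transformers` and caches a converted `.npz` under `tests/model_checkpoints` on first run. A sketch for invoking just this module programmatically, equivalent to `python -m unittest tests.test_bert` from the repository root:

import unittest

# Load and run this module's tests; assumes execution from the repository root.
suite = unittest.defaultTestLoader.loadTestsFromName("tests.test_bert")
unittest.TextTestRunner(verbosity=2).run(suite)
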
