diff --git a/README.md b/README.md
index 356a290..5e4ff35 100644
--- a/README.md
+++ b/README.md
@@ -45,6 +45,8 @@ A list of the available models can be found in the `mlx_transformers.models` mod
 config = BertConfig.from_pretrained("bert-base-uncased")
 model = MLXBertModel(config)
+model.load_weights("bert-base-uncased.npz", strict=True)
+
 sample_input = "Hello, world!"
 inputs = tokenizer(sample_input, return_tensors="np")
 outputs = model(**inputs)
diff --git a/setup.py b/setup.py
index 1c5ecb3..c0a36f8 100644
--- a/setup.py
+++ b/setup.py
@@ -12,7 +12,7 @@ setuptools.setup(
     name="mlx-transformers",
-    version="0.0.1",
+    version="0.1.0",
     author="Ogundepo Odunayo",
     author_email="ogundepoodunayo@gmail.com",
     description=description,
diff --git a/src/mlx_transformers/models/__init__.py b/src/mlx_transformers/models/__init__.py
index 77c259d..9f8cc1d 100644
--- a/src/mlx_transformers/models/__init__.py
+++ b/src/mlx_transformers/models/__init__.py
@@ -1,4 +1,4 @@
-from .bert import BertModel
+from .bert import BertForSequenceClassification, BertModel
 from .roberta import (
     RobertaModel,
     RobertaForSequenceClassification,
diff --git a/src/mlx_transformers/models/bert.py b/src/mlx_transformers/models/bert.py
index 347d535..c4ffad5 100644
--- a/src/mlx_transformers/models/bert.py
+++ b/src/mlx_transformers/models/bert.py
@@ -425,4 +425,86 @@ def __call__(
 
 class BertForSequenceClassification(nn.Module):
     def __init__(self, config):
-        pass
+        super().__init__()
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.bert = BertModel(config)
+        classifier_dropout = (
+            config.classifier_dropout
+            if config.classifier_dropout is not None
+            else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+    def __call__(
+        self,
+        input_ids: Optional[mx.array] = None,
+        attention_mask: Optional[mx.array] = None,
+        token_type_ids: Optional[mx.array] = None,
+        position_ids: Optional[mx.array] = None,
+        labels: Optional[mx.array] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[mx.array], SequenceClassifierOutput]:
+        r"""
+        labels (`array` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == mx.long or labels.dtype == mx.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + if self.num_labels == 1: + loss = nn.losses.mse_loss(logits.squeeze(), labels.squeeze()) + else: + loss = nn.losses.mse_loss(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss = nn.losses.cross_entropy( + logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss = nn.losses.binary_cross_entropy(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/mlx_transformers/models/modelling_outputs.py b/src/mlx_transformers/models/modelling_outputs.py index 6e23ca1..f61839a 100644 --- a/src/mlx_transformers/models/modelling_outputs.py +++ b/src/mlx_transformers/models/modelling_outputs.py @@ -29,7 +29,27 @@ class SequenceClassifierOutput: logits: mx.array = None hidden_states: Optional[Tuple[mx.array, ...]] = None attentions: Optional[Tuple[mx.array, ...]] = None - + +@dataclass +class SequenceClassifierOutputWithPast: + loss: Optional[mx.array] = None + logits: mx.array = None + past_key_values: Optional[Tuple[Tuple[mx.array]]] = None + hidden_states: Optional[Tuple[mx.array, ...]] = None + attentions: Optional[Tuple[mx.array, ...]] = None + + +@dataclass +class Seq2SeqSequenceClassifierOutput: + loss: Optional[mx.array] = None + logits: mx.array = None + past_key_values: Optional[Tuple[Tuple[mx.array]]] = None + decoder_hidden_states: Optional[Tuple[mx.array, ...]] = None + decoder_attentions: Optional[Tuple[mx.array, ...]] = None + cross_attentions: Optional[Tuple[mx.array, ...]] = None + encoder_last_hidden_state: Optional[mx.array] = None + encoder_hidden_states: Optional[Tuple[mx.array, ...]] = None + encoder_attentions: Optional[Tuple[mx.array, ...]] = None @dataclass class TokenClassifierOutput: @@ -46,4 +66,3 @@ class QuestionAnsweringModelOutput: end_logits: mx.array = None hidden_states: Optional[Tuple[mx.array, ...]] = None attentions: Optional[Tuple[mx.array, ...]] = None - \ No newline at end of file diff --git a/tests/test_bert.py b/tests/test_bert.py index 9b4c882..e66cc8a 100644 --- a/tests/test_bert.py +++ b/tests/test_bert.py @@ -4,26 +4,34 @@ import mlx.core as mx import numpy as np -from transformers import BertConfig, BertModel, BertTokenizer - +from transformers import ( + BertConfig, + BertForSequenceClassification, + BertModel, + BertTokenizer, +) + +from src.mlx_transformers.models import ( + BertForSequenceClassification as 
+)
 from src.mlx_transformers.models import BertModel as MlxBertModel
 from src.mlx_transformers.models.utils import convert
 
 
-def load_model(model_name: str) -> MlxBertModel:
+def load_model(
+    model_name: str, config, hgf_model_class, mlx_model_class
+) -> MlxBertModel:
     current_directory = os.path.dirname(os.path.realpath(__file__))
     weights_path = os.path.join(
         current_directory, "model_checkpoints", model_name.replace("/", "-") + ".npz"
     )
 
     if not os.path.exists(weights_path):
-        convert(model_name, weights_path, BertModel)
+        convert(model_name, weights_path, hgf_model_class)
 
-    config = BertConfig.from_pretrained(model_name)
-    model = MlxBertModel(config)
+    model = mlx_model_class(config)
     model.load_weights(weights_path, strict=True)
-
     return model
@@ -37,7 +45,9 @@ class TestMlxRoberta(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         cls.model_name = "bert-base-uncased"
-        cls.model = load_model(cls.model_name)
+        config = BertConfig.from_pretrained(cls.model_name)
+
+        cls.model = load_model(cls.model_name, config, BertModel, MlxBertModel)
         cls.tokenizer = BertTokenizer.from_pretrained(cls.model_name)
 
         cls.input_text = "Hello, my dog is cute"
@@ -68,6 +78,24 @@ def test_model_output_hgf(self):
 
         self.assertTrue(np.allclose(outputs_mlx, outputs_hgf, atol=1e-4))
 
+    def test_sequence_classification(self):
+        inputs = self.tokenizer(
+            self.input_text, return_tensors="np", padding=True, truncation=True
+        )
+
+        inputs = {key: mx.array(v) for key, v in inputs.items()}
+        model_name = "textattack/bert-base-uncased-imdb"
+        config = BertConfig.from_pretrained(model_name)
+        model = load_model(
+            model_name,
+            config,
+            BertForSequenceClassification,
+            MlxBertForSequenceClassification,
+        )
+        outputs = model(**inputs)
+
+        self.assertIsInstance(outputs.logits, mx.array)
+
 
 if __name__ == "__main__":
     unittest.main()