diff --git a/src/mlx_transformers/models/__init__.py b/src/mlx_transformers/models/__init__.py
index 3f8b5a3..6b5d80b 100644
--- a/src/mlx_transformers/models/__init__.py
+++ b/src/mlx_transformers/models/__init__.py
@@ -13,4 +13,9 @@
     RobertaForTokenClassification,
     RobertaModel,
 )
-from .xlm_roberta import XLMRobertaModel
+from .xlm_roberta import (
+    XLMRobertaForQuestionAnswering,
+    XLMRobertaForSequenceClassification,
+    XLMRobertaForTokenClassification,
+    XLMRobertaModel,
+)
diff --git a/src/mlx_transformers/models/base.py b/src/mlx_transformers/models/base.py
index 2979bc6..4b30322 100644
--- a/src/mlx_transformers/models/base.py
+++ b/src/mlx_transformers/models/base.py
@@ -2,7 +2,6 @@
 import os
 from typing import Callable, Optional
-
 import mlx.core as mx
 from huggingface_hub import HfFileSystem, hf_hub_download
 from mlx.utils import tree_unflatten
diff --git a/src/mlx_transformers/models/roberta.py b/src/mlx_transformers/models/roberta.py
index 272b2c1..54a7b3a 100644
--- a/src/mlx_transformers/models/roberta.py
+++ b/src/mlx_transformers/models/roberta.py
@@ -522,7 +522,9 @@ def __call__(
         if self.config.problem_type is None:
             if self.num_labels == 1:
                 self.config.problem_type = "regression"
-            elif self.num_labels > 1 and (labels.dtype == mx.array):
+            elif self.num_labels > 1 and (
+                labels.dtype == mx.int32 or labels.dtype == mx.int64
+            ):
                 self.config.problem_type = "single_label_classification"
             else:
                 self.config.problem_type = "multi_label_classification"
diff --git a/src/mlx_transformers/models/xlm_roberta.py b/src/mlx_transformers/models/xlm_roberta.py
index 313b92d..52f2dbf 100644
--- a/src/mlx_transformers/models/xlm_roberta.py
+++ b/src/mlx_transformers/models/xlm_roberta.py
@@ -8,6 +8,7 @@
 import mlx.nn as nn
 from transformers import XLMRobertaConfig
 
+from .base import MlxPretrainedMixin
 from .modelling_outputs import *
 from .utils import ACT2FN, get_extended_attention_mask
 
@@ -354,7 +355,7 @@ def __call__(self, hidden_states: mx.array) -> mx.array:
         return pooled_output
 
 
-class XLMRobertaModel(nn.Module):
+class XLMRobertaModel(nn.Module, MlxPretrainedMixin):
 
     def __init__(self, config, add_pooling_layer=True):
         super().__init__()
@@ -451,6 +452,241 @@ def __call__(
         )
 
 
-class XLMRobertaForSequenceClassification(nn.Module):
+class XLMRobertaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        classifier_dropout = (
+            config.classifier_dropout
+            if config.classifier_dropout is not None
+            else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def __call__(self, features, **kwargs):
+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = mx.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
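+# Review note: as in HF's RobertaClassificationHead, the head above pools by
+# taking the hidden state of the first (<s>) token rather than using the
+# XLMRobertaPooler, which is why the task models below construct the encoder
+# with add_pooling_layer=False.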
+class XLMRobertaForSequenceClassification(nn.Module, MlxPretrainedMixin):
+    def __init__(self, config):
+        super().__init__()
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
+        self.classifier = XLMRobertaClassificationHead(config)
+
+    def __call__(
+        self,
+        input_ids: Optional[mx.array] = None,
+        attention_mask: Optional[mx.array] = None,
+        token_type_ids: Optional[mx.array] = None,
+        position_ids: Optional[mx.array] = None,
+        labels: Optional[mx.array] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[mx.array], SequenceClassifierOutput]:
+
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs.last_hidden_state
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            # no device transfer is needed here: MLX arrays live in unified
+            # memory, unlike the PyTorch reference implementation
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (
+                    labels.dtype == mx.int32 or labels.dtype == mx.int64
+                ):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
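+            # MLX losses take a reduction argument (defaulting to "none" for
+            # several of them), so "mean" is passed explicitly below to match
+            # the scalar losses of the PyTorch reference implementation.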
+            if self.config.problem_type == "regression":
+                loss_fct = nn.losses.mse_loss
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze(), reduction="mean")
+                else:
+                    loss = loss_fct(logits, labels, reduction="mean")
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = nn.losses.cross_entropy
+                loss = loss_fct(
+                    logits.reshape(-1, self.num_labels),
+                    labels.reshape(-1),
+                    reduction="mean",
+                )
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = nn.losses.binary_cross_entropy
+                loss = loss_fct(logits, labels, reduction="mean")
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class XLMRobertaForTokenClassification(nn.Module, MlxPretrainedMixin):
+    def __init__(self, config):
+        super().__init__()
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
+        classifier_dropout = (
+            config.classifier_dropout
+            if config.classifier_dropout is not None
+            else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+    def __call__(
+        self,
+        input_ids: Optional[mx.array] = None,
+        attention_mask: Optional[mx.array] = None,
+        token_type_ids: Optional[mx.array] = None,
+        position_ids: Optional[mx.array] = None,
+        labels: Optional[mx.array] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[mx.array], TokenClassifierOutput]:
+
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs.last_hidden_state
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            # no device transfer needed: MLX arrays live in unified memory
+            loss_fct = nn.losses.cross_entropy
+            loss = loss_fct(
+                logits.reshape(-1, self.num_labels),
+                labels.reshape(-1),
+                reduction="mean",
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class XLMRobertaForQuestionAnswering(nn.Module, MlxPretrainedMixin):
     def __init__(self, config):
-        pass
+        super().__init__()
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+    def __call__(
+        self,
+        input_ids: Optional[mx.array] = None,
+        attention_mask: Optional[mx.array] = None,
+        token_type_ids: Optional[mx.array] = None,
+        position_ids: Optional[mx.array] = None,
+        start_positions: Optional[mx.array] = None,
+        end_positions: Optional[mx.array] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[mx.array], QuestionAnsweringModelOutput]:
+
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs.last_hidden_state
+
+        logits = self.qa_outputs(sequence_output)
+        # split the (batch, seq, 2) logits into start and end halves
+        splits = logits.split(2, axis=-1)
+        start_logits, end_logits = splits[0], splits[1]
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # squeeze away any extra trailing dimension on the position arrays
+            if start_positions.ndim > 1:
+                start_positions = start_positions.squeeze(-1)
+            if end_positions.ndim > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs,
+            # we ignore these terms
+            ignored_index = start_logits.shape[1]
+            start_positions = mx.clip(start_positions, 0, ignored_index)
+            end_positions = mx.clip(end_positions, 0, ignored_index)
+
+            loss_fct = nn.losses.cross_entropy
+            start_loss = loss_fct(start_logits, start_positions, reduction="mean")
+            end_loss = loss_fct(end_logits, end_positions, reduction="mean")
+            total_loss = (start_loss + end_loss) / 2
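+            # caveat: HF uses CrossEntropyLoss(ignore_index=ignored_index), so
+            # positions clamped to ignored_index contribute no loss there;
+            # MLX's cross_entropy has no ignore_index, so this is only an
+            # approximation for out-of-span positions.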
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
diff --git a/tests/test_roberta.py b/tests/test_roberta.py
index 8ae8128..efa250a 100644
--- a/tests/test_roberta.py
+++ b/tests/test_roberta.py
@@ -5,30 +5,33 @@
 import mlx.core as mx
 import numpy as np
 from transformers import (
+    AutoTokenizer,
     RobertaConfig,
-    RobertaModel,
-    RobertaTokenizer,
+    RobertaForQuestionAnswering,
     RobertaForSequenceClassification,
     RobertaForTokenClassification,
-    RobertaForQuestionAnswering,
-    AutoTokenizer
+    RobertaModel,
+    RobertaTokenizer,
 )
 
+from src.mlx_transformers.models import (
+    RobertaForQuestionAnswering as MlxRobertaForQuestionAnswering,
+)
+from src.mlx_transformers.models import (
+    RobertaForSequenceClassification as MlxRobertaForSequenceClassification,
+)
+from src.mlx_transformers.models import (
+    RobertaForTokenClassification as MlxRobertaForTokenClassification,
+)
 from src.mlx_transformers.models import RobertaModel as MlxRobertaModel
-from src.mlx_transformers.models import RobertaForSequenceClassification as MlxRobertaForSequenceClassification
-from src.mlx_transformers.models import RobertaForTokenClassification as MlxRobertaForTokenClassification
-from src.mlx_transformers.models import RobertaForQuestionAnswering as MlxRobertaForQuestionAnswering
 from src.mlx_transformers.models.utils import convert
 
 
 def load_model(model_name: str, mlx_model_class, hgf_model_class):
     current_directory = os.path.dirname(os.path.realpath(__file__))
     weights_path = os.path.join(
-        current_directory,
-        "model_checkpoints",
-        model_name.replace(
-            "/",
-            "-") + ".npz")
+        current_directory, "model_checkpoints", model_name.replace("/", "-") + ".npz"
+    )
 
     if not os.path.exists(weights_path):
         convert(model_name, weights_path, hgf_model_class)
@@ -54,10 +57,7 @@ def setUpClass(cls) -> None:
         cls.model_name = "FacebookAI/roberta-base"
         cls.model_class = MlxRobertaModel
         cls.hgf_model_class = RobertaModel
-        cls.model = load_model(
-            cls.model_name,
-            cls.model_class,
-            cls.hgf_model_class)
+        cls.model = load_model(cls.model_name, cls.model_class, cls.hgf_model_class)
         cls.tokenizer = RobertaTokenizer.from_pretrained(cls.model_name)
 
         cls.input_text = "Hello, my dog is cute"
@@ -96,10 +96,7 @@ def setUpClass(cls) -> None:
         cls.model_name = "cardiffnlp/twitter-roberta-base-emotion"
         cls.model_class = MlxRobertaForSequenceClassification
         cls.hgf_model_class = RobertaForSequenceClassification
-        cls.model = load_model(
-            cls.model_name,
-            cls.model_class,
-            cls.hgf_model_class)
+        cls.model = load_model(cls.model_name, cls.model_class, cls.hgf_model_class)
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
 
         cls.input_text = "Hello, my dog is cute"
@@ -143,10 +140,7 @@ def setUpClass(cls) -> None:
         cls.model_name = "Jean-Baptiste/roberta-large-ner-english"
         cls.model_class = MlxRobertaForTokenClassification
         cls.hgf_model_class = RobertaForTokenClassification
-        cls.model = load_model(
-            cls.model_name,
-            cls.model_class,
-            cls.hgf_model_class)
+        cls.model = load_model(cls.model_name, cls.model_class, cls.hgf_model_class)
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
 
         cls.input_text = "HuggingFace is a company based in Paris and New York"
@@ -156,7 +150,8 @@ def test_forward(self) -> None:
             return_tensors="np",
             padding=True,
             truncation=True,
-            add_special_tokens=False)
+            add_special_tokens=False,
+        )
 
         inputs = {key: mx.array(v) for key, v in inputs.items()}
         outputs = self.model(**inputs)
@@ -168,32 +163,36 @@ def test_model_output_hgf(self):
             return_tensors="np",
             padding=True,
             truncation=True,
-            add_special_tokens=False)
+            add_special_tokens=False,
+        )
 
         inputs_mlx = {key: mx.array(v) for key, v in inputs_mlx.items()}
         outputs_mlx = self.model(**inputs_mlx)
         outputs_mlx = np.array(outputs_mlx.logits)
         mlx_predicted_token_class_ids = outputs_mlx.argmax(-1)
         mlx_predicted_tokens_classes = [
-            self.model.config.id2label[t.item()] for t in mlx_predicted_token_class_ids[0]]
+            self.model.config.id2label[t.item()]
+            for t in mlx_predicted_token_class_ids[0]
+        ]
 
         inputs_hgf = self.tokenizer(
             self.input_text,
             return_tensors="pt",
             padding=True,
             truncation=True,
-            add_special_tokens=False)
+            add_special_tokens=False,
+        )
 
         hgf_model = load_hgf_model(self.model_name, self.hgf_model_class)
         outputs_hgf = hgf_model(**inputs_hgf)
         outputs_hgf = outputs_hgf.logits
 
         hgf_predicted_token_class_ids = outputs_hgf.argmax(-1)
         hgf_predicted_tokens_classes = [
-            hgf_model.config.id2label[t.item()] for t in hgf_predicted_token_class_ids[0]]
+            hgf_model.config.id2label[t.item()]
+            for t in hgf_predicted_token_class_ids[0]
+        ]
 
-        self.assertEqual(
-            mlx_predicted_tokens_classes,
-            hgf_predicted_tokens_classes)
+        self.assertEqual(mlx_predicted_tokens_classes, hgf_predicted_tokens_classes)
 
 
 class TestMlxRobertaForQuestionAnswering(unittest.TestCase):
@@ -203,10 +202,7 @@ def setUpClass(cls) -> None:
         cls.model_name = "deepset/roberta-base-squad2"
         cls.model_class = MlxRobertaForQuestionAnswering
         cls.hgf_model_class = RobertaForQuestionAnswering
-        cls.model = load_model(
-            cls.model_name,
-            cls.model_class,
-            cls.hgf_model_class)
+        cls.model = load_model(cls.model_name, cls.model_class, cls.hgf_model_class)
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
         cls.input_question = "Who was Jim Henson?"
         cls.input_text = "Jim Henson was a nice puppet"
@@ -231,11 +227,13 @@ def test_model_output_hgf(self):
 
         mlx_answer_start_index = outputs_mlx.start_logits.argmax().item()
         mlx_answer_end_index = outputs_mlx.end_logits.argmax().item()
-        mlx_predict_answer_tokens = inputs_mlx['input_ids'].tolist()
-        mlx_predict_answer_tokens = mlx_predict_answer_tokens[
-            0][mlx_answer_start_index: mlx_answer_end_index + 1]
+        mlx_predict_answer_tokens = inputs_mlx["input_ids"].tolist()
+        mlx_predict_answer_tokens = mlx_predict_answer_tokens[0][
+            mlx_answer_start_index : mlx_answer_end_index + 1
+        ]
         mlx_answer = self.tokenizer.decode(
-            mlx_predict_answer_tokens, skip_special_tokens=True)
+            mlx_predict_answer_tokens, skip_special_tokens=True
+        )
 
         inputs_hgf = self.tokenizer(
             self.input_question, self.input_text, return_tensors="pt"
@@ -246,10 +244,12 @@ def test_model_output_hgf(self):
 
         hgf_answer_start_index = outputs_hgf.start_logits.argmax()
         hgf_answer_end_index = outputs_hgf.end_logits.argmax()
-        hgf_predict_answer_tokens = inputs_hgf.input_ids[0,
-                                                         hgf_answer_start_index: hgf_answer_end_index + 1]
+        hgf_predict_answer_tokens = inputs_hgf.input_ids[
+            0, hgf_answer_start_index : hgf_answer_end_index + 1
+        ]
         hgf_answer = self.tokenizer.decode(
-            hgf_predict_answer_tokens, skip_special_tokens=True)
+            hgf_predict_answer_tokens, skip_special_tokens=True
+        )
 
         self.assertEqual(mlx_answer, hgf_answer)
diff --git a/tests/test_xlm_roberta.py b/tests/test_xlm_roberta.py
index 5b6f1ed..0a94a12 100644
--- a/tests/test_xlm_roberta.py
+++ b/tests/test_xlm_roberta.py
@@ -4,31 +4,48 @@
 import mlx.core as mx
 import numpy as np
-from transformers import XLMRobertaConfig, XLMRobertaModel, XLMRobertaTokenizer
+from transformers import (
+    AutoTokenizer,
+    XLMRobertaConfig,
+    XLMRobertaForQuestionAnswering,
+    XLMRobertaForSequenceClassification,
+    XLMRobertaForTokenClassification,
+    XLMRobertaModel,
+    XLMRobertaTokenizer,
+)
 
+from src.mlx_transformers.models import (
+    XLMRobertaForQuestionAnswering as MlxXLMRobertaForQuestionAnswering,
+)
+from src.mlx_transformers.models import (
+    XLMRobertaForSequenceClassification as MlxXLMRobertaForSequenceClassification,
+)
+from src.mlx_transformers.models import (
+    XLMRobertaForTokenClassification as MlxXLMRobertaForTokenClassification,
+)
 from src.mlx_transformers.models import XLMRobertaModel as MlxXLMRobertaModel
 from src.mlx_transformers.models.utils import convert
 
 
-def load_model(model_name: str) -> MlxXLMRobertaModel:
+def load_model(model_name: str, mlx_model_class, hgf_model_class):
     current_directory = os.path.dirname(os.path.realpath(__file__))
     weights_path = os.path.join(
         current_directory, "model_checkpoints", model_name.replace("/", "-") + ".npz"
     )
 
     if not os.path.exists(weights_path):
-        convert(model_name, weights_path, XLMRobertaModel)
+        convert(model_name, weights_path, hgf_model_class)
 
     config = XLMRobertaConfig.from_pretrained(model_name)
-    model = MlxXLMRobertaModel(config)
+    model = mlx_model_class(config)
 
     model.load_weights(weights_path, strict=True)
 
     return model
 
 
-def load_hgf_model(model_name: str) -> XLMRobertaModel:
-    model = XLMRobertaModel.from_pretrained(model_name)
+def load_hgf_model(model_name: str, hgf_model_class):
+    model = hgf_model_class.from_pretrained(model_name)
     return model
 
 
@@ -37,7 +54,9 @@ class TestMlxXLMRoberta(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         cls.model_name = "FacebookAI/xlm-roberta-base"
-        cls.model = load_model(cls.model_name)
+        cls.model_class = MlxXLMRobertaModel
+        cls.hgf_model_class = XLMRobertaModel
+        cls.model = load_model(cls.model_name, cls.model_class, cls.hgf_model_class)
         cls.tokenizer = XLMRobertaTokenizer.from_pretrained(cls.model_name)
 
         cls.input_text = "Hello, my dog is cute"
@@ -62,12 +81,177 @@ def test_model_output_hgf(self):
         inputs_hgf = self.tokenizer(
             self.input_text, return_tensors="pt", padding=True, truncation=True
         )
-        hgf_model = load_hgf_model(self.model_name)
+        hgf_model = load_hgf_model(self.model_name, self.hgf_model_class)
         outputs_hgf = hgf_model(**inputs_hgf)
         outputs_hgf = outputs_hgf.last_hidden_state.detach().numpy()
 
         self.assertTrue(np.allclose(outputs_mlx, outputs_hgf, atol=1e-4))
 
 
+class TestMlxXLMRobertaForSequenceClassification(unittest.TestCase):
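+
+    # Note: this and the tests below reuse RoBERTa checkpoints through the
+    # XLM-R classes; that works because the two architectures are identical,
+    # but dedicated XLM-R task checkpoints would exercise these code paths
+    # more faithfully.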
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.model_name = "cardiffnlp/twitter-roberta-base-emotion"
+        cls.model_class = MlxXLMRobertaForSequenceClassification
+        cls.hgf_model_class = XLMRobertaForSequenceClassification
+        cls.model = load_model(cls.model_name, cls.model_class, cls.hgf_model_class)
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+        cls.input_text = "Hello, my dog is cute"
+
+    def test_forward(self) -> None:
+        inputs = self.tokenizer(
+            self.input_text, return_tensors="np", padding=True, truncation=True
+        )
+
+        inputs = {key: mx.array(v) for key, v in inputs.items()}
+        outputs = self.model(**inputs)
+        self.assertIsInstance(outputs.logits, mx.array)
+
+    def test_model_output_hgf(self):
+        inputs_mlx = self.tokenizer(
+            self.input_text, return_tensors="np", padding=True, truncation=True
+        )
+
+        inputs_mlx = {key: mx.array(v) for key, v in inputs_mlx.items()}
+        outputs_mlx = self.model(**inputs_mlx)
+        outputs_mlx = np.array(outputs_mlx.logits)
+        predicted_class_id = outputs_mlx.argmax().item()
+        mlx_label = self.model.config.id2label[predicted_class_id]
+
+        inputs_hgf = self.tokenizer(
+            self.input_text, return_tensors="pt", padding=True, truncation=True
+        )
+        hgf_model = load_hgf_model(self.model_name, self.hgf_model_class)
+        outputs_hgf = hgf_model(**inputs_hgf)
+        outputs_hgf = outputs_hgf.logits
+
+        predicted_class_id = outputs_hgf.argmax().item()
+        hgf_label = hgf_model.config.id2label[predicted_class_id]
+
+        self.assertEqual(mlx_label, hgf_label)
+
+
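+# The two test classes below compare decoded predictions (label strings and
+# answer text) rather than raw logits, so they tolerate the small numerical
+# differences expected between the MLX and PyTorch implementations.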
+class TestMlxXLMRobertaForTokenClassification(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.model_name = "Jean-Baptiste/roberta-large-ner-english"
+        cls.model_class = MlxXLMRobertaForTokenClassification
+        cls.hgf_model_class = XLMRobertaForTokenClassification
+        cls.model = load_model(cls.model_name, cls.model_class, cls.hgf_model_class)
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+        cls.input_text = "HuggingFace is a company based in Paris and New York"
+
+    def test_forward(self) -> None:
+        inputs = self.tokenizer(
+            self.input_text,
+            return_tensors="np",
+            padding=True,
+            truncation=True,
+            add_special_tokens=False,
+        )
+
+        inputs = {key: mx.array(v) for key, v in inputs.items()}
+        outputs = self.model(**inputs)
+        self.assertIsInstance(outputs.logits, mx.array)
+
+    def test_model_output_hgf(self):
+        inputs_mlx = self.tokenizer(
+            self.input_text,
+            return_tensors="np",
+            padding=True,
+            truncation=True,
+            add_special_tokens=False,
+        )
+
+        inputs_mlx = {key: mx.array(v) for key, v in inputs_mlx.items()}
+        outputs_mlx = self.model(**inputs_mlx)
+        outputs_mlx = np.array(outputs_mlx.logits)
+        mlx_predicted_token_class_ids = outputs_mlx.argmax(-1)
+        mlx_predicted_tokens_classes = [
+            self.model.config.id2label[t.item()]
+            for t in mlx_predicted_token_class_ids[0]
+        ]
+
+        inputs_hgf = self.tokenizer(
+            self.input_text,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        hgf_model = load_hgf_model(self.model_name, self.hgf_model_class)
+        outputs_hgf = hgf_model(**inputs_hgf)
+        outputs_hgf = outputs_hgf.logits
+
+        hgf_predicted_token_class_ids = outputs_hgf.argmax(-1)
+        hgf_predicted_tokens_classes = [
+            hgf_model.config.id2label[t.item()]
+            for t in hgf_predicted_token_class_ids[0]
+        ]
+
+        self.assertEqual(mlx_predicted_tokens_classes, hgf_predicted_tokens_classes)
+
+
+class TestMlxXLMRobertaForQuestionAnswering(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.model_name = "deepset/roberta-base-squad2"
+        cls.model_class = MlxXLMRobertaForQuestionAnswering
+        cls.hgf_model_class = XLMRobertaForQuestionAnswering
+        cls.model = load_model(cls.model_name, cls.model_class, cls.hgf_model_class)
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+        cls.input_question = "Who was Jim Henson?"
+        cls.input_text = "Jim Henson was a nice puppet"
+
+    def test_forward(self) -> None:
+        inputs = self.tokenizer(
+            self.input_question, self.input_text, return_tensors="np"
+        )
+
+        inputs = {key: mx.array(v) for key, v in inputs.items()}
+        outputs = self.model(**inputs)
+        self.assertIsInstance(outputs.start_logits, mx.array)
+        self.assertIsInstance(outputs.end_logits, mx.array)
+
+    def test_model_output_hgf(self):
+        inputs_mlx = self.tokenizer(
+            self.input_question, self.input_text, return_tensors="np"
+        )
+
+        inputs_mlx = {key: mx.array(v) for key, v in inputs_mlx.items()}
+        outputs_mlx = self.model(**inputs_mlx)
+
+        mlx_answer_start_index = outputs_mlx.start_logits.argmax().item()
+        mlx_answer_end_index = outputs_mlx.end_logits.argmax().item()
+        mlx_predict_answer_tokens = inputs_mlx["input_ids"].tolist()
+        mlx_predict_answer_tokens = mlx_predict_answer_tokens[0][
+            mlx_answer_start_index : mlx_answer_end_index + 1
+        ]
+        mlx_answer = self.tokenizer.decode(
+            mlx_predict_answer_tokens, skip_special_tokens=True
+        )
+
+        inputs_hgf = self.tokenizer(
+            self.input_question, self.input_text, return_tensors="pt"
+        )
+
+        hgf_model = load_hgf_model(self.model_name, self.hgf_model_class)
+        outputs_hgf = hgf_model(**inputs_hgf)
+
+        hgf_answer_start_index = outputs_hgf.start_logits.argmax()
+        hgf_answer_end_index = outputs_hgf.end_logits.argmax()
+        hgf_predict_answer_tokens = inputs_hgf.input_ids[
+            0, hgf_answer_start_index : hgf_answer_end_index + 1
+        ]
+        hgf_answer = self.tokenizer.decode(
+            hgf_predict_answer_tokens, skip_special_tokens=True
+        )
+
+        self.assertEqual(mlx_answer, hgf_answer)
+
+
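+# Assumed invocation (these tests import from src.*, so run from the repo
+# root): python -m unittest tests.test_xlm_roberta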
 if __name__ == "__main__":
     unittest.main()