diff --git a/src/mlx_transformers/models/__init__.py b/src/mlx_transformers/models/__init__.py index b794cc8..5edc111 100644 --- a/src/mlx_transformers/models/__init__.py +++ b/src/mlx_transformers/models/__init__.py @@ -10,9 +10,19 @@ from .llama import LlamaForCausalLM, LlamaModel from .m2m_100 import M2M100ForConditionalGeneration from .openelm import OpenELMForCausalLM, OpenELMModel -from .phi import PhiForCausalLM, PhiModel -from .phi3 import Phi3ForCausalLM, Phi3Model -from .persimmon import PersimmonForCausalLM +from .phi import ( + PhiForCausalLM, + PhiModel, + PhiForSequenceClassification, + PhiForTokenClassification, +) +from .phi3 import ( + Phi3ForCausalLM, + Phi3Model, + Phi3ForSequenceClassification, + Phi3ForTokenClassification, +) +from .persimmon import PersimmonForCausalLM, PersimmonForSequenceClassification from .fuyu import FuyuForCausalLM from .roberta import ( RobertaForQuestionAnswering, diff --git a/src/mlx_transformers/models/persimmon.py b/src/mlx_transformers/models/persimmon.py index a13f588..4c5a309 100644 --- a/src/mlx_transformers/models/persimmon.py +++ b/src/mlx_transformers/models/persimmon.py @@ -1,6 +1,6 @@ import math import logging -from typing import Optional, Tuple, Dict +from typing import Optional, Tuple, Dict, List, Union import mlx.core as mx import mlx.nn as nn @@ -9,7 +9,11 @@ from .base import MlxPretrainedMixin from .cache import Cache, DynamicCache -from .modelling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from .modelling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) from .utils import ACT2FN logger = logging.getLogger(__name__) @@ -764,3 +768,120 @@ def sample(logits): next_token = sample(next_token_logits) yield next_token + + +class PersimmonForSequenceClassification(nn.Module, MlxPretrainedMixin): + def __init__(self, config): + super().__init__() + self.config = config + self.num_labels = config.num_labels + self.model = PersimmonModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def __call__( + self, + input_ids: mx.array = None, + attention_mask: Optional[mx.array] = None, + position_ids: Optional[mx.array] = None, + past_key_values: Optional[Union[Cache, List[mx.array]]] = None, + inputs_embeds: Optional[mx.array] = None, + labels: Optional[mx.array] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`mx.array` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        model_outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = model_outputs.last_hidden_state
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError(
+                "Cannot handle batch sizes > 1 if no padding token is defined."
+            )
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                # index of the last non-pad token; if no pad token is found,
+                # argmax returns 0 and the modulo wraps -1 to the last position
+                sequence_lengths = (
+                    (input_ids == self.config.pad_token_id).astype(mx.int32).argmax(-1)
+                    - 1
+                )
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+            else:
+                sequence_lengths = -1
+
+        pooled_logits = logits[mx.arange(batch_size), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (
+                    labels.dtype == mx.int32 or labels.dtype == mx.int64
+                ):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = nn.losses.mse_loss
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss = nn.losses.cross_entropy(
+                    pooled_logits.reshape(-1, self.num_labels),
+                    labels.reshape(-1),
+                    reduction="mean",
+                )
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = nn.losses.binary_cross_entropy
+                loss = loss_fct(pooled_logits, labels)
+
+        if not return_dict:
+            output = (pooled_logits,) + model_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=model_outputs.past_key_values,
+            hidden_states=model_outputs.hidden_states,
+            attentions=model_outputs.attentions,
+        )
diff --git a/src/mlx_transformers/models/phi.py b/src/mlx_transformers/models/phi.py
index 46e85ed..4f3c2c1 100644
--- a/src/mlx_transformers/models/phi.py
+++ b/src/mlx_transformers/models/phi.py
@@ -1,6 +1,6 @@
 import math
 import logging
-from typing import Optional, Tuple, Dict
+from typing import Optional, Tuple, Dict, Union, List
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -9,7 +9,12 @@
 from .base import MlxPretrainedMixin
 from .cache import Cache, DynamicCache
 
-from .modelling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from .modelling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
+    TokenClassifierOutput,
+)
 from .utils import ACT2FN
 
 logger = logging.getLogger(__name__)
@@ -896,3 +901,200 @@ def sample(logits):
         next_token = sample(next_token_logits)
 
         yield next_token
+
+
+class PhiForSequenceClassification(nn.Module, MlxPretrainedMixin):
+    def __init__(self, config):
+ super().__init__() + self.config = config + self.num_labels = config.num_labels + self.model = PhiModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def __call__( + self, + input_ids: mx.array = None, + attention_mask: Optional[mx.array] = None, + position_ids: Optional[mx.array] = None, + past_key_values: Optional[Union[Cache, List[mx.array]]] = None, + inputs_embeds: Optional[mx.array] = None, + labels: Optional[mx.array] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`mx.array` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + model_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = model_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." 
+            )
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                # index of the last non-pad token; if no pad token is found,
+                # argmax returns 0 and the modulo wraps -1 to the last position
+                sequence_lengths = (
+                    (input_ids == self.config.pad_token_id).astype(mx.int32).argmax(-1)
+                    - 1
+                )
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+            else:
+                sequence_lengths = -1
+
+        pooled_logits = logits[mx.arange(batch_size), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (
+                    labels.dtype == mx.int32 or labels.dtype == mx.int64
+                ):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = nn.losses.mse_loss
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss = nn.losses.cross_entropy(
+                    pooled_logits.reshape(-1, self.num_labels),
+                    labels.reshape(-1),
+                    reduction="mean",
+                )
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = nn.losses.binary_cross_entropy
+                loss = loss_fct(pooled_logits, labels)
+
+        if not return_dict:
+            output = (pooled_logits,) + model_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=model_outputs.past_key_values,
+            hidden_states=model_outputs.hidden_states,
+            attentions=model_outputs.attentions,
+        )
+
+
+class PhiForTokenClassification(nn.Module, MlxPretrainedMixin):
+    def __init__(self, config: PhiConfig):
+        super().__init__()
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.model = PhiModel(config)
+        if (
+            hasattr(config, "classifier_dropout")
+            and config.classifier_dropout is not None
+        ):
+            classifier_dropout = config.classifier_dropout
+        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+    def __call__(
+        self,
+        input_ids: Optional[mx.array] = None,
+        past_key_values: Optional[Tuple[Tuple[mx.array, mx.array], ...]] = None,
+        attention_mask: Optional[mx.array] = None,
+        inputs_embeds: Optional[mx.array] = None,
+        labels: Optional[mx.array] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **deprecated_arguments,
+    ) -> Union[Tuple[mx.array], TokenClassifierOutput]:
+        r"""
+        labels (`mx.array` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in
+            `[0, ..., config.num_labels - 1]`.
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        model_outputs = self.model(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = model_outputs.last_hidden_state
+        hidden_states = self.dropout(hidden_states)
+        logits = self.classifier(hidden_states)
+
+        loss = None
+        if labels is not None:
+            batch_size, seq_length = labels.shape
+            loss = nn.losses.cross_entropy(
+                logits.reshape(batch_size * seq_length, self.num_labels),
+                labels.reshape(batch_size * seq_length),
+                reduction="mean",
+            )
+
+        if not return_dict:
+            output = (logits,) + model_outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=model_outputs.hidden_states,
+            attentions=model_outputs.attentions,
+        )
diff --git a/src/mlx_transformers/models/phi3.py b/src/mlx_transformers/models/phi3.py
index 56f13d1..2134f47 100644
--- a/src/mlx_transformers/models/phi3.py
+++ b/src/mlx_transformers/models/phi3.py
@@ -1,6 +1,6 @@
 import math
 import logging
-from typing import Optional, Dict, Tuple
+from typing import Optional, Dict, Tuple, Union, List
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -12,6 +12,8 @@
 from .modelling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
+    TokenClassifierOutput,
 )
 from .utils import ACT2FN
 
@@ -929,3 +931,200 @@ def sample(logits):
         next_token = sample(next_token_logits)
 
         yield next_token
+
+
+class Phi3ForSequenceClassification(nn.Module, MlxPretrainedMixin):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.num_labels = config.num_labels
+        self.model = Phi3Model(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def __call__(
+        self,
+        input_ids: mx.array = None,
+        attention_mask: Optional[mx.array] = None,
+        position_ids: Optional[mx.array] = None,
+        past_key_values: Optional[Union[Cache, List[mx.array]]] = None,
+        inputs_embeds: Optional[mx.array] = None,
+        labels: Optional[mx.array] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`mx.array` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        model_outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = model_outputs.last_hidden_state
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError(
+                "Cannot handle batch sizes > 1 if no padding token is defined."
+            )
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                # index of the last non-pad token; if no pad token is found,
+                # argmax returns 0 and the modulo wraps -1 to the last position
+                sequence_lengths = (
+                    (input_ids == self.config.pad_token_id).astype(mx.int32).argmax(-1)
+                    - 1
+                )
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+            else:
+                sequence_lengths = -1
+
+        pooled_logits = logits[mx.arange(batch_size), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (
+                    labels.dtype == mx.int32 or labels.dtype == mx.int64
+                ):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = nn.losses.mse_loss
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss = nn.losses.cross_entropy(
+                    pooled_logits.reshape(-1, self.num_labels),
+                    labels.reshape(-1),
+                    reduction="mean",
+                )
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = nn.losses.binary_cross_entropy
+                loss = loss_fct(pooled_logits, labels)
+
+        if not return_dict:
+            output = (pooled_logits,) + model_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=model_outputs.past_key_values,
+            hidden_states=model_outputs.hidden_states,
+            attentions=model_outputs.attentions,
+        )
+
+
+class Phi3ForTokenClassification(nn.Module, MlxPretrainedMixin):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.num_labels = config.num_labels
+
+        self.model = Phi3Model(config)
+        if (
+            hasattr(config, "classifier_dropout")
+            and config.classifier_dropout is not None
+        ):
+            classifier_dropout = config.classifier_dropout
+        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+    def __call__(
+        self,
+        input_ids: Optional[mx.array] = None,
+        past_key_values: Optional[Tuple[Tuple[mx.array, mx.array], ...]] = None,
+        attention_mask: Optional[mx.array] = None,
+        inputs_embeds: Optional[mx.array] = None,
+        labels: Optional[mx.array] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **deprecated_arguments,
+    ) -> Union[Tuple[mx.array], TokenClassifierOutput]:
+        r"""
+        labels (`mx.array` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in
+            `[0, ..., config.num_labels - 1]`.
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        model_outputs = self.model(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = model_outputs.last_hidden_state
+        hidden_states = self.dropout(hidden_states)
+        logits = self.classifier(hidden_states)
+
+        loss = None
+        if labels is not None:
+            batch_size, seq_length = labels.shape
+            loss = nn.losses.cross_entropy(
+                logits.reshape(batch_size * seq_length, self.num_labels),
+                labels.reshape(batch_size * seq_length),
+                reduction="mean",
+            )
+
+        if not return_dict:
+            output = (logits,) + model_outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=model_outputs.hidden_states,
+            attentions=model_outputs.attentions,
+        )
diff --git a/tests/test_persimmon.py b/tests/test_persimmon.py
new file mode 100644
index 0000000..9cddb17
--- /dev/null
+++ b/tests/test_persimmon.py
@@ -0,0 +1,82 @@
+import unittest
+import numpy as np
+import mlx.core as mx
+from transformers import (
+    AutoTokenizer,
+    PersimmonConfig,
+    PersimmonForCausalLM,
+    PersimmonForSequenceClassification,
+)
+
+from src.mlx_transformers.models import PersimmonForCausalLM as MlxPersimmonForCausalLM
+from src.mlx_transformers.models import (
+    PersimmonForSequenceClassification as MlxPersimmonForSequenceClassification,
+)
+
+
+def load_hgf_model(model_name: str, hgf_model_class):
+    model = hgf_model_class.from_pretrained(model_name)
+    return model
+
+
+class TestMlxPersimmon(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.model_name = "adept/persimmon-8b-base"
+        config = PersimmonConfig.from_pretrained(cls.model_name)
+        cls.hgf_model_class = PersimmonForCausalLM
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+        cls.model = MlxPersimmonForCausalLM(config)
+        cls.model.from_pretrained(cls.model_name)
+
+        cls.input_text = "human: Hey, what should I eat for dinner?"
+
+    def test_forward(self) -> None:
+        inputs = self.tokenizer(self.input_text, return_tensors="np", truncation=True)
+
+        inputs = {key: mx.array(v) for key, v in inputs.items()}
+        outputs = self.model(**inputs, use_cache=True)
+
+        assert type(outputs.logits) == mx.array
+
+
+class TestMlxPersimmonForSequenceClassification(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.model_name = "adept/persimmon-8b-base"
+        config = PersimmonConfig.from_pretrained(cls.model_name)
+        cls.hgf_model_class = PersimmonForSequenceClassification
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+        cls.model = MlxPersimmonForSequenceClassification(config)
+        cls.model.from_pretrained(cls.model_name)
+
+        cls.input_text = "human: Hey, what should I eat for dinner?"
+
+    def test_forward(self) -> None:
+        inputs = self.tokenizer(self.input_text, return_tensors="np", truncation=True)
+
+        inputs = {key: mx.array(v) for key, v in inputs.items()}
+        outputs = self.model(**inputs, use_cache=True)
+
+        assert type(outputs.logits) == mx.array
+
+    def test_model_output_hgf(self):
+        inputs_mlx = self.tokenizer(
+            self.input_text, return_tensors="np", padding=True, truncation=True
+        )
+
+        inputs_mlx = {key: mx.array(v) for key, v in inputs_mlx.items()}
+        outputs_mlx = self.model(**inputs_mlx)
+        outputs_mlx = np.array(outputs_mlx.logits)
+        predicted_class_id = outputs_mlx.argmax().item()
+        mlx_label = self.model.config.id2label[predicted_class_id]
+
+        inputs_hgf = self.tokenizer(
+            self.input_text, return_tensors="pt", padding=True, truncation=True
+        )
+        hgf_model = load_hgf_model(self.model_name, self.hgf_model_class)
+        outputs_hgf = hgf_model(**inputs_hgf)
+        outputs_hgf = outputs_hgf.logits
+
+        predicted_class_id = outputs_hgf.argmax().item()
+        hgf_label = hgf_model.config.id2label[predicted_class_id]
+
+        self.assertEqual(mlx_label, hgf_label)
diff --git a/tests/test_phi.py b/tests/test_phi.py
index 6d4f39e..99be7a4 100644
--- a/tests/test_phi.py
+++ b/tests/test_phi.py
@@ -1,13 +1,25 @@
 import unittest
-
+import numpy as np
 import mlx.core as mx
-from transformers import AutoTokenizer, PhiConfig, PhiForCausalLM
+from transformers import (
+    AutoTokenizer,
+    PhiConfig,
+    PhiForCausalLM,
+    PhiForSequenceClassification,
+    PhiForTokenClassification,
+)
 
 from src.mlx_transformers.models import PhiForCausalLM as MlxPhiForCausalLM
+from src.mlx_transformers.models import (
+    PhiForSequenceClassification as MlxPhiForSequenceClassification,
+)
+from src.mlx_transformers.models import (
+    PhiForTokenClassification as MlxPhiForTokenClassification,
+)
 
 
-def load_hgf_model(model_name: str) -> PhiForCausalLM:
-    model = PhiForCausalLM.from_pretrained(model_name)
+def load_hgf_model(model_name: str, hgf_model_class):
+    model = hgf_model_class.from_pretrained(model_name)
     return model
 
 
@@ -16,6 +28,7 @@ class TestMlxPhi(unittest.TestCase):
     def setUpClass(cls) -> None:
         cls.model_name = "microsoft/phi-2"
         config = PhiConfig.from_pretrained(cls.model_name)
+        cls.hgf_model_class = PhiForCausalLM
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
         cls.model = MlxPhiForCausalLM(config)
         cls.model.from_pretrained(cls.model_name)
@@ -29,3 +42,105 @@ def test_forward(self) -> None:
         outputs = self.model(**inputs, use_cache=True)
 
         assert type(outputs.logits) == mx.array
+
+
+class TestMlxPhiForTokenClassification(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.model_name = "microsoft/phi-2"
+        config = PhiConfig.from_pretrained(cls.model_name)
+        cls.hgf_model_class = PhiForTokenClassification
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+        cls.model = MlxPhiForTokenClassification(config)
+        cls.model.from_pretrained(cls.model_name)
+
+        cls.input_text = "Hey, are you conscious? Can you talk to me?"
+
+    def test_forward(self) -> None:
+        inputs = self.tokenizer(self.input_text, return_tensors="np", truncation=True)
+
+        inputs = {key: mx.array(v) for key, v in inputs.items()}
+        outputs = self.model(**inputs, use_cache=True)
+
+        assert type(outputs.logits) == mx.array
+
+    def test_model_output_hgf(self):
+        inputs_mlx = self.tokenizer(
+            self.input_text,
+            return_tensors="np",
+            padding=True,
+            truncation=True,
+            add_special_tokens=False,
+        )
+
+        inputs_mlx = {key: mx.array(v) for key, v in inputs_mlx.items()}
+        outputs_mlx = self.model(**inputs_mlx)
+        outputs_mlx = np.array(outputs_mlx.logits)
+        mlx_predicted_token_class_ids = outputs_mlx.argmax(-1)
+        mlx_predicted_tokens_classes = [
+            self.model.config.id2label[t.item()]
+            for t in mlx_predicted_token_class_ids[0]
+        ]
+
+        inputs_hgf = self.tokenizer(
+            self.input_text,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        hgf_model = load_hgf_model(self.model_name, self.hgf_model_class)
+        outputs_hgf = hgf_model(**inputs_hgf)
+        outputs_hgf = outputs_hgf.logits
+
+        hgf_predicted_token_class_ids = outputs_hgf.argmax(-1)
+        hgf_predicted_tokens_classes = [
+            hgf_model.config.id2label[t.item()]
+            for t in hgf_predicted_token_class_ids[0]
+        ]
+
+        self.assertEqual(mlx_predicted_tokens_classes, hgf_predicted_tokens_classes)
+
+
+class TestMlxPhiForSequenceClassification(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.model_name = "microsoft/phi-2"
+        config = PhiConfig.from_pretrained(cls.model_name)
+        cls.hgf_model_class = PhiForSequenceClassification
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+        cls.model = MlxPhiForSequenceClassification(config)
+        cls.model.from_pretrained(cls.model_name)
+
+        cls.input_text = "Hey, are you conscious? Can you talk to me?"
+
+    def test_forward(self) -> None:
+        inputs = self.tokenizer(self.input_text, return_tensors="np", truncation=True)
+
+        inputs = {key: mx.array(v) for key, v in inputs.items()}
+        outputs = self.model(**inputs, use_cache=True)
+
+        assert type(outputs.logits) == mx.array
+
+    def test_model_output_hgf(self):
+        inputs_mlx = self.tokenizer(
+            self.input_text, return_tensors="np", padding=True, truncation=True
+        )
+
+        inputs_mlx = {key: mx.array(v) for key, v in inputs_mlx.items()}
+        outputs_mlx = self.model(**inputs_mlx)
+        outputs_mlx = np.array(outputs_mlx.logits)
+        predicted_class_id = outputs_mlx.argmax().item()
+        mlx_label = self.model.config.id2label[predicted_class_id]
+
+        inputs_hgf = self.tokenizer(
+            self.input_text, return_tensors="pt", padding=True, truncation=True
+        )
+        hgf_model = load_hgf_model(self.model_name, self.hgf_model_class)
+        outputs_hgf = hgf_model(**inputs_hgf)
+        outputs_hgf = outputs_hgf.logits
+
+        predicted_class_id = outputs_hgf.argmax().item()
+        hgf_label = hgf_model.config.id2label[predicted_class_id]
+
+        self.assertEqual(mlx_label, hgf_label)
diff --git a/tests/test_phi3.py b/tests/test_phi3.py
index 0a0c7b6..a62c08e 100644
--- a/tests/test_phi3.py
+++ b/tests/test_phi3.py
@@ -1,13 +1,25 @@
 import unittest
-
+import numpy as np
 import mlx.core as mx
-from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
+from transformers import (
+    AutoTokenizer,
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
+)
 
 from src.mlx_transformers.models import Phi3ForCausalLM as MlxPhi3ForCausalLM
+from src.mlx_transformers.models import (
+    Phi3ForSequenceClassification as MlxPhi3ForSequenceClassification,
+)
+from src.mlx_transformers.models import (
+    Phi3ForTokenClassification as MlxPhi3ForTokenClassification,
+)
 
 
-def load_hgf_model(model_name: str) -> AutoModelForCausalLM:
-    model = AutoModelForCausalLM.from_pretrained(model_name)
+def load_hgf_model(model_name: str, hgf_model_class):
+    model = hgf_model_class.from_pretrained(model_name)
     return model
 
 
@@ -16,6 +28,7 @@ class TestMlxPhi3(unittest.TestCase):
     def setUpClass(cls) -> None:
         cls.model_name = "microsoft/Phi-3-mini-4k-instruct"
         config = AutoConfig.from_pretrained(cls.model_name)
+        cls.hgf_model_class = AutoModelForCausalLM
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
         cls.model = MlxPhi3ForCausalLM(config)
         cls.model.from_pretrained(
@@ -33,3 +46,113 @@ def test_forward(self) -> None:
         outputs = self.model(**inputs, use_cache=True)
 
         assert type(outputs.logits) == mx.array
+
+
+class TestMlxPhi3ForTokenClassification(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.model_name = "microsoft/Phi-3-mini-4k-instruct"
+        config = AutoConfig.from_pretrained(cls.model_name)
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+        cls.hgf_model_class = AutoModelForTokenClassification
+        cls.model = MlxPhi3ForTokenClassification(config)
+        cls.model.from_pretrained(
+            cls.model_name,
+            huggingface_model_architecture="AutoModelForTokenClassification",
+            trust_remote_code=True,
+        )
+
+        cls.input_text = "Hey, are you conscious? Can you talk to me?"
+ + def test_forward(self) -> None: + inputs = self.tokenizer(self.input_text, return_tensors="np", truncation=True) + + inputs = {key: mx.array(v) for key, v in inputs.items()} + outputs = self.model(**inputs, use_cache=True) + + assert type(outputs.logits) == mx.array + + def test_model_output_hgf(self): + inputs_mlx = self.tokenizer( + self.input_text, + return_tensors="np", + padding=True, + truncation=True, + add_special_tokens=False, + ) + + inputs_mlx = {key: mx.array(v) for key, v in inputs_mlx.items()} + outputs_mlx = self.model(**inputs_mlx) + outputs_mlx = np.array(outputs_mlx.logits) + mlx_predicted_token_class_ids = outputs_mlx.argmax(-1) + mlx_predicted_tokens_classes = [ + self.model.config.id2label[t.item()] + for t in mlx_predicted_token_class_ids[0] + ] + + inputs_hgf = self.tokenizer( + self.input_text, + return_tensors="pt", + padding=True, + truncation=True, + add_special_tokens=False, + ) + hgf_model = load_hgf_model(self.model_name, self.hgf_model_class) + outputs_hgf = hgf_model(**inputs_hgf) + outputs_hgf = outputs_hgf.logits + + hgf_predicted_token_class_ids = outputs_hgf.argmax(-1) + hgf_predicted_tokens_classes = [ + hgf_model.config.id2label[t.item()] + for t in hgf_predicted_token_class_ids[0] + ] + + self.assertEqual(mlx_predicted_tokens_classes, hgf_predicted_tokens_classes) + + +class TestMlxPhi3ForSequenceClassification(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls.model_name = "microsoft/Phi-3-mini-4k-instruct" + config = AutoConfig.from_pretrained(cls.model_name) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) + cls.hgf_model_class = AutoModelForSequenceClassification + cls.model = MlxPhi3ForSequenceClassification(config) + cls.model.from_pretrained( + cls.model_name, + huggingface_model_architecture="AutoModelForSequenceClassification", + trust_remote_code=True, + ) + + cls.input_text = "Hey, are you conscious? Can you talk to me?" 
+ + def test_forward(self) -> None: + inputs = self.tokenizer(self.input_text, return_tensors="np", truncation=True) + + inputs = {key: mx.array(v) for key, v in inputs.items()} + outputs = self.model(**inputs, use_cache=True) + + assert type(outputs.logits) == mx.array + + def test_model_output_hgf(self): + inputs_mlx = self.tokenizer( + self.input_text, return_tensors="np", padding=True, truncation=True + ) + + inputs_mlx = {key: mx.array(v) for key, v in inputs_mlx.items()} + outputs_mlx = self.model(**inputs_mlx) + outputs_mlx = np.array(outputs_mlx.logits) + predicted_class_id = outputs_mlx.argmax().item() + mlx_label = self.model.config.id2label[predicted_class_id] + + inputs_hgf = self.tokenizer( + self.input_text, return_tensors="pt", padding=True, truncation=True + ) + hgf_model = load_hgf_model(self.model_name, self.hgf_model_class) + outputs_hgf = hgf_model(**inputs_hgf) + outputs_hgf = outputs_hgf.logits + + predicted_class_id = outputs_hgf.argmax().item() + hgf_label = hgf_model.config.id2label[predicted_class_id] + + self.assertEqual(mlx_label, hgf_label) diff --git a/tests/test_xlm_roberta.py b/tests/test_xlm_roberta.py index b4fc88e..03539ee 100644 --- a/tests/test_xlm_roberta.py +++ b/tests/test_xlm_roberta.py @@ -87,7 +87,6 @@ def test_model_output_hgf(self): class TestMlxXLMRobertaForSequenceClassification(unittest.TestCase): - @classmethod def setUpClass(cls) -> None: cls.model_name = "cardiffnlp/twitter-roberta-base-emotion" @@ -131,7 +130,6 @@ def test_model_output_hgf(self): class TestMlxXLMRobertaForTokenClassification(unittest.TestCase): - @classmethod def setUpClass(cls) -> None: cls.model_name = "Jean-Baptiste/roberta-large-ner-english" @@ -193,7 +191,6 @@ def test_model_output_hgf(self): class TestMlxXLMRobertaForQuestionAnswering(unittest.TestCase): - @classmethod def setUpClass(cls) -> None: cls.model_name = "deepset/roberta-base-squad2"
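
A note on the pooling logic shared by the three new `*ForSequenceClassification` heads: the head scores every position, then selects the logits at each sequence's last non-pad token. The sketch below is not part of the patch; it is a toy reproduction (invented shapes, `pad_token_id=0`, and it assumes MLX's `%` follows the NumPy floored-modulo convention) of why the `argmax`/modulo pair finds that position even for rows without padding.

```python
import mlx.core as mx

batch_size, seq_len, num_labels = 2, 5, 3
pad_token_id = 0  # assumed for this toy example

# Row 0 is right-padded after 3 real tokens; row 1 has no padding at all.
input_ids = mx.array([[7, 8, 9, 0, 0], [4, 5, 6, 7, 8]])
logits = mx.random.normal((batch_size, seq_len, num_labels))

# argmax over the pad mask returns the FIRST pad position, so subtracting 1
# points at the last real token. In a row with no pad token the mask is all
# zeros, argmax returns 0, and the modulo wraps -1 around to seq_len - 1.
first_pad = (input_ids == pad_token_id).astype(mx.int32).argmax(-1)  # [3, 0]
sequence_lengths = (first_pad - 1) % input_ids.shape[-1]             # [2, 4]

pooled_logits = logits[mx.arange(batch_size), sequence_lengths]
print(sequence_lengths.tolist(), pooled_logits.shape)  # [2, 4] (2, 3)
```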
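The `problem_type` dispatch in the same heads maps onto `mlx.nn.losses`, and one MLX-specific detail is worth flagging: `nn.losses.cross_entropy` defaults to `reduction="none"`, unlike PyTorch's `CrossEntropyLoss`, which is why the calls in the diff pass `reduction="mean"` explicitly. A hedged sketch of the three branches on toy arrays:

```python
import mlx.core as mx
import mlx.nn as nn

batch_size, num_labels = 4, 3
pooled_logits = mx.random.normal((batch_size, num_labels))

# single_label_classification: integer class ids, explicit mean reduction
labels = mx.array([0, 2, 1, 2])
ce = nn.losses.cross_entropy(
    pooled_logits.reshape(-1, num_labels), labels.reshape(-1), reduction="mean"
)

# multi_label_classification: float multi-hot targets; binary_cross_entropy
# treats its first argument as raw logits by default (with_logits=True)
multi_hot = mx.array([[1.0, 0.0, 1.0]] * batch_size)
bce = nn.losses.binary_cross_entropy(pooled_logits, multi_hot)

# regression with num_labels == 1: MSE on the squeezed logits
reg_logits = mx.random.normal((batch_size, 1))
targets = mx.random.normal((batch_size,))
mse = nn.losses.mse_loss(reg_logits.squeeze(), targets)

print(ce.item(), bce.item(), mse.item())
```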