Commit
Added sub-tasks for phi, phi3 and persimmon
Seun-Ajayi committed May 13, 2024
1 parent 780a060 commit ec6a6f8
Showing 8 changed files with 868 additions and 19 deletions.
16 changes: 13 additions & 3 deletions src/mlx_transformers/models/__init__.py
@@ -10,9 +10,19 @@
from .llama import LlamaForCausalLM, LlamaModel
from .m2m_100 import M2M100ForConditionalGeneration
from .openelm import OpenELMForCausalLM, OpenELMModel
from .phi import PhiForCausalLM, PhiModel
from .phi3 import Phi3ForCausalLM, Phi3Model
from .persimmon import PersimmonForCausalLM
from .phi import (
PhiForCausalLM,
PhiModel,
PhiForSequenceClassification,
PhiForTokenClassification,
)
from .phi3 import (
Phi3ForCausalLM,
Phi3Model,
Phi3ForSequenceClassification,
Phi3ForTokenClassification,
)
from .persimmon import PersimmonForCausalLM, PersimmonForSequenceClassification
from .fuyu import FuyuForCausalLM
from .roberta import (
RobertaForQuestionAnswering,
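For quick orientation, here is a minimal usage sketch of one of the newly exported heads. It assumes the `from_pretrained` loader provided by `MlxPretrainedMixin`, and `microsoft/phi-2` is used purely as an illustrative checkpoint; neither detail is part of this commit.

# Hypothetical usage sketch; the checkpoint name and loader call are assumptions.
import mlx.core as mx
from transformers import AutoConfig, AutoTokenizer

from mlx_transformers.models import PhiForSequenceClassification

config = AutoConfig.from_pretrained("microsoft/phi-2", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")

model = PhiForSequenceClassification(config)
model.from_pretrained("microsoft/phi-2")  # assumed MlxPretrainedMixin API

inputs = tokenizer("MLX runs natively on Apple silicon.", return_tensors="np")
outputs = model(mx.array(inputs["input_ids"]))
print(outputs.logits.shape)  # (batch_size, num_labels) pooled logits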
125 changes: 123 additions & 2 deletions src/mlx_transformers/models/persimmon.py
@@ -1,6 +1,6 @@
import math
import logging
from typing import Optional, Tuple, Dict
from typing import Optional, Tuple, Dict, List, Union

import mlx.core as mx
import mlx.nn as nn
@@ -9,7 +9,11 @@

from .base import MlxPretrainedMixin
from .cache import Cache, DynamicCache
from .modelling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from .modelling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from .utils import ACT2FN

logger = logging.getLogger(__name__)
@@ -764,3 +768,120 @@ def sample(logits):
next_token = sample(next_token_logits)

yield next_token


class PersimmonForSequenceClassification(nn.Module, MlxPretrainedMixin):
def __init__(self, config):
super().__init__()
self.config = config
self.num_labels = config.num_labels
self.model = PersimmonModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

def __call__(
self,
input_ids: mx.array = None,
attention_mask: Optional[mx.array] = None,
position_ids: Optional[mx.array] = None,
past_key_values: Optional[Union[Cache, List[mx.array]]] = None,
inputs_embeds: Optional[mx.array] = None,
labels: Optional[mx.array] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`mx.array` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)

model_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = model_outputs.last_hidden_state
logits = self.score(hidden_states)

if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]

if self.config.pad_token_id is None and batch_size != 1:
raise ValueError(
"Cannot handle batch sizes > 1 if no padding token is defined."
)
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
            if input_ids is not None:
                # argmax over the pad mask finds the first padding position; the
                # modulo keeps the index in range when a sequence has no padding
                sequence_lengths = (
                    (input_ids == self.config.pad_token_id).astype(mx.int32).argmax(-1)
                    - 1
                )
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
else:
sequence_lengths = -1

        pooled_logits = logits[mx.arange(batch_size), sequence_lengths]

loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (mx.int32, mx.int64):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"

if self.config.problem_type == "regression":
loss_fct = nn.losses.mse_loss
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = nn.losses.cross_entropy
                loss = loss_fct(
                    pooled_logits.reshape(-1, self.num_labels), labels.reshape(-1)
                )
elif self.config.problem_type == "multi_label_classification":
loss_fct = nn.losses.binary_cross_entropy
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + model_outputs[1:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=model_outputs.past_key_values,
hidden_states=model_outputs.hidden_states,
attentions=model_outputs.attentions,
)
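The pooling step in `PersimmonForSequenceClassification` (mirrored in the Phi heads below) scores only the last non-padding token of each sequence. A standalone sketch of that index computation, with a synthetic batch and an assumed pad id of 0:

# Standalone sketch of the last-non-pad-token selection; values are illustrative.
import mlx.core as mx

pad_token_id = 0  # assumed pad id for this example
input_ids = mx.array([[5, 7, 9, 0, 0],
                      [3, 4, 0, 0, 0]])

# argmax over the pad mask returns the first padding position; subtracting 1
# gives the last real token, and the modulo keeps the index valid when a
# sequence contains no padding at all (argmax then returns 0).
first_pad = (input_ids == pad_token_id).astype(mx.int32).argmax(-1)
sequence_lengths = (first_pad - 1) % input_ids.shape[-1]
print(sequence_lengths)  # array([2, 1]) -> indices of the last real tokens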
206 changes: 204 additions & 2 deletions src/mlx_transformers/models/phi.py
@@ -1,6 +1,6 @@
import math
import logging
from typing import Optional, Tuple, Dict
from typing import Optional, Tuple, Dict, Union, List

import mlx.core as mx
import mlx.nn as nn
@@ -9,7 +9,12 @@

from .base import MlxPretrainedMixin
from .cache import Cache, DynamicCache
from .modelling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from .modelling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from .utils import ACT2FN

logger = logging.getLogger(__name__)
@@ -896,3 +901,200 @@ def sample(logits):
next_token = sample(next_token_logits)

yield next_token


class PhiForSequenceClassification(nn.Module, MlxPretrainedMixin):
def __init__(self, config):
super().__init__()
self.config = config
self.num_labels = config.num_labels
self.model = PhiModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

def get_input_embeddings(self):
return self.model.embed_tokens

def set_input_embeddings(self, value):
self.model.embed_tokens = value

def __call__(
self,
input_ids: mx.array = None,
attention_mask: Optional[mx.array] = None,
position_ids: Optional[mx.array] = None,
past_key_values: Optional[Union[Cache, List[mx.array]]] = None,
inputs_embeds: Optional[mx.array] = None,
labels: Optional[mx.array] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`mx.array` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)

model_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = model_outputs.last_hidden_state
logits = self.score(hidden_states)

if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]

if self.config.pad_token_id is None and batch_size != 1:
raise ValueError(
"Cannot handle batch sizes > 1 if no padding token is defined."
)
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
            if input_ids is not None:
                # argmax over the pad mask finds the first padding position; the
                # modulo keeps the index in range when a sequence has no padding
                sequence_lengths = (
                    (input_ids == self.config.pad_token_id).astype(mx.int32).argmax(-1)
                    - 1
                )
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
else:
sequence_lengths = -1

        pooled_logits = logits[mx.arange(batch_size), sequence_lengths]

loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (mx.int32, mx.int64):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"

if self.config.problem_type == "regression":
loss_fct = nn.losses.mse_loss
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = nn.losses.cross_entropy
                loss = loss_fct(
                    pooled_logits.reshape(-1, self.num_labels), labels.reshape(-1)
                )
elif self.config.problem_type == "multi_label_classification":
loss_fct = nn.losses.binary_cross_entropy
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + model_outputs[1:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=model_outputs.past_key_values,
hidden_states=model_outputs.hidden_states,
attentions=model_outputs.attentions,
)


class PhiForTokenClassification(nn.Module, MlxPretrainedMixin):
def __init__(self, config: PhiConfig):
super().__init__()
self.num_labels = config.num_labels
self.config = config

self.model = PhiModel(config)
if (
hasattr(config, "classifier_dropout")
and config.classifier_dropout is not None
):
classifier_dropout = config.classifier_dropout
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

def __call__(
self,
input_ids: Optional[mx.array] = None,
past_key_values: Optional[Tuple[Tuple[mx.array, mx.array], ...]] = None,
attention_mask: Optional[mx.array] = None,
inputs_embeds: Optional[mx.array] = None,
labels: Optional[mx.array] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[mx.array], TokenClassifierOutput]:
r"""
        labels (`mx.array` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)

model_outputs = self.model(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

hidden_states = model_outputs.last_hidden_state
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)

loss = None
if labels is not None:
batch_size, seq_length = labels.shape
loss_fct = nn.losses.cross_entropy
loss = loss_fct(
                logits.reshape(batch_size * seq_length, self.num_labels),
                labels.reshape(batch_size * seq_length),
)

if not return_dict:
output = (logits,) + model_outputs[2:]
return ((loss,) + output) if loss is not None else output

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=model_outputs.hidden_states,
attentions=model_outputs.attentions,
)
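As a sanity check on the token-classification loss path above, here is a minimal sketch of the flatten-and-cross-entropy step with dummy shapes. Note that `mlx.nn.losses.cross_entropy` defaults to `reduction="none"`, so an explicit reduction is needed if a scalar loss is wanted:

# Minimal sketch of the token-classification loss with synthetic data.
import mlx.core as mx
import mlx.nn as nn

batch_size, seq_length, num_labels = 2, 4, 3
logits = mx.random.normal((batch_size, seq_length, num_labels))
labels = mx.array([[0, 2, 1, 0], [1, 1, 0, 2]])

# Flatten so every token position is scored independently, then average.
loss = nn.losses.cross_entropy(
    logits.reshape(batch_size * seq_length, num_labels),
    labels.reshape(batch_size * seq_length),
    reduction="mean",
)
print(loss.item())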