Skip to content

Commit

Permalink
Merge pull request #1 from Seun-Ajayi/debruyne
Browse files Browse the repository at this point in the history
  • Loading branch information
ToluClassics authored Apr 20, 2024
2 parents 5cb5edb + a29dc59 commit 9eed187
Show file tree
Hide file tree
Showing 4 changed files with 468 additions and 38 deletions.
7 changes: 6 additions & 1 deletion src/mlx_transformers/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from .bert import BertForSequenceClassification, BertModel
from .roberta import RobertaModel
from .roberta import (
RobertaModel,
RobertaForSequenceClassification,
RobertaForTokenClassification,
RobertaForQuestionAnswering
)
from .xlm_roberta import XLMRobertaModel
13 changes: 10 additions & 3 deletions src/mlx_transformers/models/modelling_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ class BaseModelOutputWithPoolingAndCrossAttentions:
cross_attentions: Optional[Tuple[mx.array, ...]] = None


@dataclass
class SequenceClassifierOutput:
    """Container for the outputs of a sequence-classification model.

    Every field defaults to ``None``, so an instance can be created with only
    the values that were actually produced.
    """

    # NOTE(review): field meanings are inferred from their names (they appear
    # to mirror the Hugging Face Transformers output class of the same name);
    # confirm against the models that populate this dataclass.
    loss: Optional[mx.array] = None
    # Annotated Optional because the default is None, matching the other fields.
    logits: Optional[mx.array] = None
    hidden_states: Optional[Tuple[mx.array, ...]] = None
    attentions: Optional[Tuple[mx.array, ...]] = None

@dataclass
class SequenceClassifierOutputWithPast:
loss: Optional[mx.array] = None
Expand All @@ -44,7 +51,6 @@ class Seq2SeqSequenceClassifierOutput:
encoder_hidden_states: Optional[Tuple[mx.array, ...]] = None
encoder_attentions: Optional[Tuple[mx.array, ...]] = None


@dataclass
class TokenClassifierOutput:
loss: Optional[mx.array] = None
Expand All @@ -54,8 +60,9 @@ class TokenClassifierOutput:


@dataclass
class SequenceClassifierOutput:
class QuestionAnsweringModelOutput:
loss: Optional[mx.array] = None
logits: mx.array = None
start_logits: mx.array = None
end_logits: mx.array = None
hidden_states: Optional[Tuple[mx.array, ...]] = None
attentions: Optional[Tuple[mx.array, ...]] = None
Loading

0 comments on commit 9eed187

Please sign in to comment.