Skip to content

Commit

Permalink
Adds RoBERTa model (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonahgeorge authored Dec 5, 2024
1 parent 0858607 commit 6694a3e
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion lib/informers/models.rb
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,18 @@ class RobertaPreTrainedModel < PreTrainedModel
class RobertaModel < RobertaPreTrainedModel
end

class RobertaForTokenClassification < RobertaPreTrainedModel
def call(model_inputs)
TokenClassifierOutput.new(*super(model_inputs))
end
end

class RobertaForSequenceClassification < RobertaPreTrainedModel
def call(model_inputs)
SequenceClassifierOutput.new(*super(model_inputs))
end
end

class RobertaForMaskedLM < RobertaPreTrainedModel
def call(model_inputs)
MaskedLMOutput.new(*super(model_inputs))
Expand Down Expand Up @@ -1224,12 +1236,14 @@ class ClapModel < ClapPreTrainedModel
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
"bert" => ["BertForSequenceClassification", BertForSequenceClassification],
"distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification],
"roberta" => ["RobertaForSequenceClassification", RobertaForSequenceClassification],
"xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification],
"bart" => ["BartForSequenceClassification", BartForSequenceClassification]
}

MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
"bert" => ["BertForTokenClassification", BertForTokenClassification]
"bert" => ["BertForTokenClassification", BertForTokenClassification],
"roberta" => ["RobertaForTokenClassification", RobertaForTokenClassification]
}

MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = {
Expand Down

0 comments on commit 6694a3e

Please sign in to comment.