Skip to content

Commit

Permalink
feat: add support for ModernBERT
Browse files Browse the repository at this point in the history
  • Loading branch information
default-anton committed Jan 29, 2025
1 parent 332d5f8 commit ce465a9
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
28 changes: 28 additions & 0 deletions lib/informers/models.rb
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,30 @@ def call(model_inputs)
end
end

# Common base class for the ModernBERT model classes.
# Defines no behavior of its own; everything is inherited from PreTrainedModel.
class ModernBertPreTrainedModel < PreTrainedModel
end

# ModernBERT model without a task-specific output wrapper.
# Adds nothing on top of ModernBertPreTrainedModel, so `call` is the inherited one.
class ModernBertModel < ModernBertPreTrainedModel
end

# ModernBERT variant whose `call` result is packaged as a MaskedLMOutput.
class ModernBertForMaskedLM < ModernBertPreTrainedModel
  # Runs the inherited `call` and wraps what it returns.
  #
  # @param model_inputs the inputs, forwarded unchanged to the parent implementation
  # @return [MaskedLMOutput] built by splatting the parent's return value into `new`
  def call(model_inputs)
    raw_outputs = super(model_inputs)
    MaskedLMOutput.new(*raw_outputs)
  end
end

# ModernBERT variant whose `call` result is packaged as a SequenceClassifierOutput.
class ModernBertForSequenceClassification < ModernBertPreTrainedModel
  # Runs the inherited `call` and wraps what it returns.
  #
  # @param model_inputs the inputs, forwarded unchanged to the parent implementation
  # @return [SequenceClassifierOutput] built by splatting the parent's return value into `new`
  def call(model_inputs)
    raw_outputs = super(model_inputs)
    SequenceClassifierOutput.new(*raw_outputs)
  end
end

# ModernBERT variant whose `call` result is packaged as a TokenClassifierOutput.
class ModernBertForTokenClassification < ModernBertPreTrainedModel
  # Runs the inherited `call` and wraps what it returns.
  #
  # @param model_inputs the inputs, forwarded unchanged to the parent implementation
  # @return [TokenClassifierOutput] built by splatting the parent's return value into `new`
  def call(model_inputs)
    raw_outputs = super(model_inputs)
    TokenClassifierOutput.new(*raw_outputs)
  end
end

# Common base class for the NomicBert model classes.
# Defines no behavior of its own; everything is inherited from PreTrainedModel.
class NomicBertPreTrainedModel < PreTrainedModel
end

Expand Down Expand Up @@ -1198,6 +1222,7 @@ class ClapModel < ClapPreTrainedModel

MODEL_MAPPING_NAMES_ENCODER_ONLY = {
"bert" => ["BertModel", BertModel],
"modernbert" => ["ModernBertModel", ModernBertModel],
"nomic_bert" => ["NomicBertModel", NomicBertModel],
"electra" => ["ElectraModel", ElectraModel],
"convbert" => ["ConvBertModel", ConvBertModel],
Expand Down Expand Up @@ -1235,6 +1260,7 @@ class ClapModel < ClapPreTrainedModel

MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
"bert" => ["BertForSequenceClassification", BertForSequenceClassification],
"modernbert" => ["ModernBertForSequenceClassification", ModernBertForSequenceClassification],
"distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification],
"roberta" => ["RobertaForSequenceClassification", RobertaForSequenceClassification],
"xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification],
Expand All @@ -1243,6 +1269,7 @@ class ClapModel < ClapPreTrainedModel

# Registry of token-classification models: each key maps to a [class name, class] pair.
# NOTE(review): keys are presumably the model-type identifiers from a model's config —
# confirm against the code that performs the lookup.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
"bert" => ["BertForTokenClassification", BertForTokenClassification],
"modernbert" => ["ModernBertForTokenClassification", ModernBertForTokenClassification],
"roberta" => ["RobertaForTokenClassification", RobertaForTokenClassification]
}

Expand All @@ -1259,6 +1286,7 @@ class ClapModel < ClapPreTrainedModel

# Registry of masked-language-model classes: each key maps to a [class name, class] pair.
# NOTE(review): keys are presumably the model-type identifiers from a model's config —
# confirm against the code that performs the lookup.
MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
"bert" => ["BertForMaskedLM", BertForMaskedLM],
"modernbert" => ["ModernBertForMaskedLM", ModernBertForMaskedLM],
"roberta" => ["RobertaForMaskedLM", RobertaForMaskedLM]
}

Expand Down
11 changes: 11 additions & 0 deletions test/model_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@ def test_gte_small
assert_elements_in_delta [-0.05246907, 0.03752426, 0.07344585], embeddings[-1][..2]
end

# https://huggingface.co/Alibaba-NLP/gte-modernbert-base
def test_gte_modernbert_base
  # Embeds two paraphrased questions and compares the leading three
  # components of the first and last embedding against reference values.
  inputs = ["How is the weather today?", "What is the weather like today?"]
  embedder = Informers.pipeline("embedding", "Alibaba-NLP/gte-modernbert-base")

  vectors = embedder.call(inputs)
  expected_first = [0.027, -0.0228, -0.0105]
  expected_last = [0.0553, -0.0261, -0.04309]

  assert_elements_in_delta expected_first, vectors[0][..2]
  assert_elements_in_delta expected_last, vectors[-1][..2]
end

# https://huggingface.co/intfloat/e5-base-v2
def test_e5_base
doc_prefix = "passage: "
Expand Down

0 comments on commit ce465a9

Please sign in to comment.