feat: Implement MPNet model (#363) (#447)
Co-authored-by: Hyeongchan Kim <[email protected]>
OlivierDehaene and kozistr authored Dec 12, 2024
1 parent 0462171 commit 01d0fbd
Showing 11 changed files with 29,281 additions and 3 deletions.
4 changes: 2 additions & 2 deletions README.md
```diff
@@ -65,7 +65,7 @@ Ember, GTE and E5. TEI implements many features such as:
 #### Text Embeddings
 
 Text Embeddings Inference currently supports Nomic, BERT, CamemBERT, XLM-RoBERTa models with absolute positions, JinaBERT
-model with Alibi positions and Mistral, Alibaba GTE and Qwen2 models with Rope positions.
+model with Alibi positions and Mistral, Alibaba GTE, Qwen2 models with Rope positions, and MPNet.
 
 Below are some examples of the currently supported models:
 
@@ -81,7 +81,7 @@ Below are some examples of the currently supported models:
 | N/A | 0.1B | NomicBert | [nomic-ai/nomic-embed-text-v1.5](https://hf.co/nomic-ai/nomic-embed-text-v1.5) |
 | N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
 | N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-code](https://hf.co/jinaai/jina-embeddings-v2-base-code) |
-
+| N/A | 0.1B | MPNet | [sentence-transformers/all-mpnet-base-v2](https://hf.co/sentence-transformers/all-mpnet-base-v2) |
 
 To explore the list of best performing text embeddings models, visit the
 [Massive Text Embedding Benchmark (MTEB) Leaderboard](https://huggingface.co/spaces/mteb/leaderboard).
```
9 changes: 8 additions & 1 deletion backends/candle/src/lib.rs
```diff
@@ -12,7 +12,8 @@ use crate::compute_cap::{
 };
 use crate::models::{
     BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
-    JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig, Qwen2Config,
+    JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, NomicBertModel, NomicConfig,
+    Qwen2Config,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
```
```diff
@@ -60,6 +61,8 @@ enum Config {
     #[serde(rename = "new")]
     Gte(GTEConfig),
     Qwen2(Qwen2Config),
+    #[serde(rename = "mpnet")]
+    MPNet(MPNetConfig),
 }
 
 pub struct CandleBackend {
```
```diff
@@ -226,6 +229,10 @@ impl CandleBackend {
                 "Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
             )),
+            (Config::MPNet(config), _) => {
+                tracing::info!("Starting MPNet model on {:?}", device);
+                Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
+            }
             #[cfg(feature = "cuda")]
             (Config::Bert(config), Device::Cuda(_)) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
```
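Note the wildcard `_` in the device position of the new arm: MPNet is loaded on any device, whereas Qwen2 is rejected outside CUDA. A minimal, std-only sketch of this (config, device) dispatch pattern (the types and return values here are hypothetical; the real backend returns boxed model trait objects, not strings):

```rust
// Hypothetical, pared-down mirror of CandleBackend's model dispatch:
// the backend matches on the (config, device) pair, and MPNet accepts
// any device via the `_` wildcard.
#[derive(Debug)]
enum Config {
    Qwen2,
    MPNet,
}

#[derive(Debug)]
enum Device {
    Cpu,
    Cuda,
}

fn start(config: Config, device: Device) -> Result<&'static str, String> {
    match (config, device) {
        // Qwen2 is CUDA-only in the real backend.
        (Config::Qwen2, Device::Cpu) => {
            Err("Qwen2 is only supported on Cuda devices".to_string())
        }
        (Config::Qwen2, Device::Cuda) => Ok("qwen2"),
        // MPNet runs on any device, so the device slot is a wildcard.
        (Config::MPNet, _) => Ok("mpnet"),
    }
}

fn main() {
    assert_eq!(start(Config::MPNet, Device::Cpu), Ok("mpnet"));
    assert_eq!(start(Config::MPNet, Device::Cuda), Ok("mpnet"));
    assert!(start(Config::Qwen2, Device::Cpu).is_err());
    println!("dispatch ok");
}
```

Keeping the device constraint in the match pattern means an unsupported combination fails fast at load time with a clear error, rather than at inference time.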
Expand Down
2 changes: 2 additions & 0 deletions backends/candle/src/models/mod.rs
```diff
@@ -34,6 +34,7 @@ mod flash_mistral;
 #[cfg(feature = "cuda")]
 mod flash_qwen2;
 mod gte;
+mod mpnet;
 mod qwen2;
 
 pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
@@ -44,6 +45,7 @@ pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, GTEMLP};
 pub use jina::JinaBertModel;
 pub use jina_code::JinaCodeBertModel;
 pub use mistral::MistralConfig;
+pub use mpnet::{MPNetConfig, MPNetModel};
 pub use nomic::{NomicBertModel, NomicConfig};
 pub use qwen2::Qwen2Config;
 use text_embeddings_backend_core::Batch;
```
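Wiring the new model into the crate takes two lines in `mod.rs`: declare the module, then re-export its public types so callers can write `models::MPNetModel` instead of `models::mpnet::MPNetModel`. A self-contained sketch of the same pattern (the types here are hypothetical simplifications, not the real TEI implementations):

```rust
// Hypothetical mini-layout mirroring backends/candle/src/models/mod.rs:
// each model lives in its own module, and the parent re-exports the
// public types.
mod mpnet {
    pub struct MPNetConfig {
        pub hidden_size: usize,
    }

    pub struct MPNetModel {
        pub config: MPNetConfig,
    }

    impl MPNetModel {
        // Stand-in for the real `MPNetModel::load(vb, &config, model_type)`.
        pub fn load(config: MPNetConfig) -> Self {
            MPNetModel { config }
        }
    }
}

// The re-export flattens the module path for callers.
pub use mpnet::{MPNetConfig, MPNetModel};

fn main() {
    let model = MPNetModel::load(MPNetConfig { hidden_size: 768 });
    assert_eq!(model.config.hidden_size, 768);
    println!("loaded mpnet with hidden_size {}", model.config.hidden_size);
}
```

This is why the earlier `lib.rs` hunk can import `MPNetConfig` and `MPNetModel` directly from `crate::models` alongside the other model types.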