Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add cls pooling as default for BERT variants #426

Merged
merged 1 commit into from
Oct 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,20 @@ pub async fn run(
// Optionally download the pooling config.
if pooling.is_none() {
// If a pooling config exists, download it
let _ = download_pool_config(&api_repo).await;
let _ = download_pool_config(&api_repo).await.map_err(|err| {
tracing::warn!("Download failed: {err}");
err
});
}

// Download sentence transformers config
// Download legacy sentence transformers config
// We don't warn on failure as it is a legacy file
let _ = download_st_config(&api_repo).await;
// Download new sentence transformers config
let _ = download_new_st_config(&api_repo).await;
let _ = download_new_st_config(&api_repo).await.map_err(|err| {
tracing::warn!("Download failed: {err}");
err
});

// Download model from the Hub
download_artifacts(&api_repo)
Expand Down Expand Up @@ -387,10 +394,21 @@ fn get_backend_model_type(
None => {
// Load pooling config
let config_path = model_root.join("1_Pooling/config.json");
let config = fs::read_to_string(config_path).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.")?;
let config: PoolConfig =
serde_json::from_str(&config).context("Failed to parse `1_Pooling/config.json`")?;
Pool::try_from(config)?

match fs::read_to_string(config_path) {
Ok(config) => {
let config: PoolConfig = serde_json::from_str(&config)
.context("Failed to parse `1_Pooling/config.json`")?;
Pool::try_from(config)?
}
Err(err) => {
if !config.model_type.to_lowercase().contains("bert") {
return Err(err).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.");
}
tracing::warn!("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model but the model is a BERT variant. Defaulting to `CLS` pooling.");
text_embeddings_backend::Pool::Cls
}
}
}
};
Ok(text_embeddings_backend::ModelType::Embedding(pool))
Expand Down
Loading