Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene committed Dec 12, 2024
1 parent 9a3f29d commit 2d06fdb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 22 deletions.
23 changes: 3 additions & 20 deletions backends/candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use crate::compute_cap::{
};
use crate::models::{
BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, NomicBertModel, NomicConfig, Qwen2Config,
JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, NomicBertModel, NomicConfig,
Qwen2Config,
};
#[cfg(feature = "cuda")]
use crate::models::{
Expand Down Expand Up @@ -228,7 +229,7 @@ impl CandleBackend {
"Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
.to_string(),
)),
(Config::MPNet(config), Device::Cpu | Device::Metal(_)) => {
(Config::MPNet(config), _) => {
tracing::info!("Starting MPNet model on {:?}", device);
Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
}
Expand Down Expand Up @@ -374,24 +375,6 @@ impl CandleBackend {
FlashQwen2Model::load(vb, &config, model_type).s()?,
))
}
#[cfg(feature = "cuda")]
(Config::MPNet(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
&& dtype == DType::F16
// Allow disabling because of flash attention v1 precision problems
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
&& &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
{
// TODO: FLASH ATTENTION does not support (additive) `attention bias` for now.
// See: https://github.com/Dao-AILab/flash-attention/issues/342
return Err(BackendError::Start(
"MPNet is only supported on Cuda devices in fp32.".to_string(),
));
} else {
tracing::info!("Starting MPNet model on {:?}", device);
Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
}
}
};

Ok(Self {
Expand Down
4 changes: 2 additions & 2 deletions backends/candle/tests/test_mpnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn test_mini() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
model_root,
&model_root,
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
)?;
Expand Down Expand Up @@ -73,7 +73,7 @@ fn test_mini_pooled_raw() -> Result<()> {
let tokenizer = load_tokenizer(&model_root)?;

let backend = CandleBackend::new(
model_root,
&model_root,
"float32".to_string(),
ModelType::Embedding(Pool::Cls),
)?;
Expand Down

0 comments on commit 2d06fdb

Please sign in to comment.