feat: Implement GTE model to support the non-flash-attn version (#446)
Co-authored-by: Hyeongchan Kim <[email protected]>
OlivierDehaene and kozistr authored Dec 12, 2024
1 parent e27a4fb commit 0462171
Showing 13 changed files with 3,964 additions and 203 deletions.
2 changes: 2 additions & 0 deletions backends/candle/src/layers/mod.rs
@@ -4,9 +4,11 @@ mod layer_norm;
mod linear;
#[allow(dead_code, unused)]
mod rms_norm;
mod rotary;

pub use cublaslt::get_cublas_lt_wrapper;
pub use layer_norm::LayerNorm;
pub use linear::{HiddenAct, Linear};
#[allow(unused_imports)]
pub use rms_norm::RMSNorm;
pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeScaling};
73 changes: 73 additions & 0 deletions backends/candle/src/layers/rotary.rs
@@ -0,0 +1,73 @@
use candle::{DType, Device, Result, Tensor, D};
use serde::Deserialize;

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct NTKScaling {
pub factor: f32,
}

#[derive(Debug, Clone, PartialEq, Deserialize)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum RopeScaling {
Ntk(NTKScaling),
}

pub fn get_inv_freqs(
dim: usize,
base: f32,
device: &Device,
rope_scaling: Option<&RopeScaling>,
) -> Result<Tensor> {
let get_inv_freqs_inner = |dim: usize, base: f32, device: &Device| {
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / base.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
Tensor::from_vec(inv_freq, (1, inv_freq_len), device)
};

if let Some(rope_scaling) = rope_scaling {
match rope_scaling {
RopeScaling::Ntk(ntk_scaling) => {
let inv_freqs = get_inv_freqs_inner(dim, base * ntk_scaling.factor, device)?;
let s = ntk_scaling.factor.powf(2.0 / dim as f32) as f64;
return inv_freqs / s;
}
}
}
get_inv_freqs_inner(dim, base, device)
}

pub fn get_cos_sin(
length: usize,
inv_freqs: &Tensor,
dtype: DType,
repeat_freqs: bool,
) -> Result<(Tensor, Tensor)> {
let t = Tensor::arange(0u32, length as u32, inv_freqs.device())?
.to_dtype(DType::F32)?
.reshape((length, 1))?;
let mut freqs = t.matmul(inv_freqs)?;
if repeat_freqs {
freqs = Tensor::cat(&[&freqs, &freqs], 1)?;
}

let cos = freqs.cos()?.to_dtype(dtype)?;
let sin = freqs.sin()?.to_dtype(dtype)?;
Ok((cos, sin))
}

pub fn apply_rotary(
x: &Tensor,
cos: &Tensor,
sin: &Tensor,
attention_head_size: usize,
) -> Result<Tensor> {
let dim = attention_head_size / 2;
let x1 = x.narrow(D::Minus1, 0, dim)?;
let x2 = x.narrow(D::Minus1, dim, dim)?;
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
let rope = (x.broadcast_mul(cos)? + rotate_x.broadcast_mul(sin)?)?;
Ok(rope)
}
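
A minimal usage sketch of the new rotary helpers (not part of the commit), assuming it sits inside the backends/candle crate so the crate::layers re-exports above resolve; the shapes and constants below are made up for illustration:

use candle::{DType, Device, Result, Tensor};
use crate::layers::{apply_rotary, get_cos_sin, get_inv_freqs};

fn rotary_demo() -> Result<Tensor> {
    let device = Device::Cpu;
    let (batch, heads, seq_len, head_size) = (1, 2, 8, 64);

    // Inverse frequencies for one attention head; pass Some(&RopeScaling::Ntk(..))
    // instead of None to apply NTK scaling.
    let inv_freqs = get_inv_freqs(head_size, 10_000.0, &device, None)?;

    // repeat_freqs = true doubles the last dim to head_size, matching the
    // rotate-half split that apply_rotary performs on its input.
    let (cos, sin) = get_cos_sin(seq_len, &inv_freqs, DType::F32, true)?;

    // Queries shaped [batch, heads, seq, head_size]; cos/sin of shape
    // [seq, head_size] broadcast over the leading batch and head dimensions.
    let q = Tensor::randn(0f32, 1f32, (batch, heads, seq_len, head_size), &device)?;
    apply_rotary(&q, &cos, &sin, head_size)
}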
18 changes: 10 additions & 8 deletions backends/candle/src/lib.rs
@@ -11,7 +11,7 @@ use crate::compute_cap::{
compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
};
use crate::models::{
BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, JinaBertModel,
BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig, Qwen2Config,
};
#[cfg(feature = "cuda")]
@@ -218,10 +218,10 @@ impl CandleBackend {
"Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
.to_string(),
)),
(Config::Gte(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
"GTE is only supported on Cuda devices in fp16 with flash attention enabled"
.to_string(),
)),
(Config::Gte(config), Device::Cpu | Device::Metal(_)) => {
tracing::info!("Starting GTE model on {:?}", device);
Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
}
(Config::Qwen2(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
"Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
.to_string(),
@@ -349,10 +349,12 @@ impl CandleBackend {
if dtype != DType::F16
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
{
return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention enabled".to_string()));
tracing::info!("Starting GTE model on {:?}", device);
Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
} else {
tracing::info!("Starting FlashGTE model on {:?}", device);
Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
}
tracing::info!("Starting FlashGTE model on {:?}", device);
Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
}
#[cfg(feature = "cuda")]
(Config::Qwen2(config), Device::Cuda(_)) => {
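
Net effect of the lib.rs changes above: GTE no longer returns a startup error when fp16 flash attention is unavailable. A tiny standalone illustration of the selection rule (hypothetical helper, not part of the crate):

fn gte_variant(on_cuda: bool, is_f16: bool, flash_attn_compiled: bool) -> &'static str {
    // Mirrors the new dispatch: FlashGTEModel still requires CUDA + fp16 + a
    // compiled flash-attn feature; every other combination (CPU, Metal, or
    // CUDA without flash attention) now loads the plain GTEModel instead.
    if on_cuda && is_f16 && flash_attn_compiled {
        "FlashGTEModel"
    } else {
        "GTEModel"
    }
}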
145 changes: 18 additions & 127 deletions backends/candle/src/models/flash_gte.rs
@@ -1,8 +1,9 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{HiddenAct, LayerNorm, Linear};
use crate::models::{GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling};
use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
use crate::models::{GTEClassificationHead, GTEConfig, Model, PositionEmbeddingType, GTEMLP};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use candle_rotary::apply_rotary_inplace;
use text_embeddings_backend_core::{Batch, ModelType, Pool};

struct GTEAttention {
@@ -72,7 +73,7 @@ impl GTEAttention {
let k = qkv.narrow(1, self.num_attention_heads, self.num_attention_heads)?;
let v = qkv.narrow(1, self.num_attention_heads * 2, self.num_attention_heads)?;

candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
apply_rotary_inplace(&q, &k, &cos, &sin, true)?;

let attention = flash_attn_varlen(
&q,
@@ -93,60 +94,7 @@ impl GTEAttention {
}
}

struct GTEMLP {
up_gate_proj: Linear,
down_proj: Linear,

act: HiddenAct,
intermediate_size: usize,

span: tracing::Span,
}

impl GTEMLP {
pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
let intermediate_size = config.intermediate_size;

let up_gate_proj_weight = vb
.pp("up_gate_proj")
.get((intermediate_size * 2, config.hidden_size), "weight")?;

let up_gate_proj = Linear::new(up_gate_proj_weight, None, None);

let down_proj_weight = vb
.pp("down_proj")
.get((config.hidden_size, intermediate_size), "weight")?;
let down_proj_bias = vb.pp("down_proj").get(config.hidden_size, "bias")?;
let down_proj = Linear::new(down_proj_weight, Some(down_proj_bias), None);

Ok(Self {
up_gate_proj,
down_proj,
intermediate_size,
act: config.hidden_act.clone(),
span: tracing::span!(tracing::Level::TRACE, "mlp"),
})
}

pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();

let up_gate_states = self.up_gate_proj.forward(hidden_states)?;
let up_states = up_gate_states.narrow(1, 0, self.intermediate_size)?;
let gate_states =
up_gate_states.narrow(1, self.intermediate_size, self.intermediate_size)?;

let gate_states = match self.act {
HiddenAct::Gelu => gate_states.gelu(),
HiddenAct::Relu => gate_states.relu(),
HiddenAct::Swiglu => gate_states.silu(),
}?;
let r = self.down_proj.forward(&(gate_states * up_states)?);
r
}
}

struct GTELayer {
pub struct GTELayer {
attention: GTEAttention,
mlp: GTEMLP,
attention_layer_norm: LayerNorm,
@@ -198,58 +146,6 @@ impl GTELayer {
}
}

pub struct GTEClassificationHead {
pooler: Option<Linear>,
classifier: Linear,
span: tracing::Span,
}

impl GTEClassificationHead {
#[allow(dead_code)]
pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
let n_classes = match &config.id2label {
None => candle::bail!("`id2label` must be set for classifier models"),
Some(id2label) => id2label.len(),
};

let pooler = if let Ok(pooler_weight) = vb
.pp("pooler.dense")
.get((config.hidden_size, config.hidden_size), "weight")
{
let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?;
Some(Linear::new(pooler_weight, Some(pooler_bias), None))
} else {
None
};

let classifier_weight = vb
.pp("classifier")
.get((n_classes, config.hidden_size), "weight")?;
let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);

Ok(Self {
classifier,
pooler,
span: tracing::span!(tracing::Level::TRACE, "classifier"),
})
}

pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();

let mut hidden_states = hidden_states.unsqueeze(1)?;
if let Some(pooler) = self.pooler.as_ref() {
hidden_states = pooler.forward(&hidden_states)?;
hidden_states = hidden_states.tanh()?;
}

let hidden_states = self.classifier.forward(&hidden_states)?;
let hidden_states = hidden_states.squeeze(1)?;
Ok(hidden_states)
}
}

pub struct FlashGTEModel {
word_embeddings: Embedding,
token_type_embeddings: Option<Embedding>,
@@ -322,24 +218,19 @@ impl FlashGTEModel {
config.layer_norm_eps,
)?;

let inv_freqs = if let Some(RopeScaling::Ntk(NTKScaling { factor })) = config.rope_scaling {
let inv_freqs = candle_rotary::inv_freqs(
layers[0].attention.attention_head_size,
config.rope_theta * factor,
vb.device(),
)?;
let s = factor.powf(2.0 / layers[0].attention.attention_head_size as f32) as f64;
inv_freqs / s
} else {
candle_rotary::inv_freqs(
layers[0].attention.attention_head_size,
config.rope_theta,
vb.device(),
)
}?;

let (cos_cache, sin_cache) =
candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?;
let inv_freqs = get_inv_freqs(
layers[0].attention.attention_head_size,
config.rope_theta,
vb.device(),
config.rope_scaling.as_ref(),
)?;

let (cos_cache, sin_cache) = get_cos_sin(
config.max_position_embeddings,
&inv_freqs,
vb.dtype(),
false,
)?;

Ok(Self {
word_embeddings,
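
For reference, the gated feed-forward computed by the relocated GTEMLP (transcribed from the load/forward code removed from this file above; notation is ours) is

$$ \mathrm{GTEMLP}(x) \;=\; W_{\text{down}}\bigl(\sigma(g) \odot u\bigr) + b_{\text{down}}, \qquad [\,u \,\|\, g\,] \;=\; x\,W_{\text{up\_gate}}^{\top} $$

where u is the first intermediate_size slice of the fused up_gate projection, g is the second slice, and σ is GELU, ReLU, or SiLU according to config.hidden_act.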
16 changes: 11 additions & 5 deletions backends/candle/src/models/flash_mistral.rs
@@ -1,8 +1,9 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{HiddenAct, Linear, RMSNorm};
use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm};
use crate::models::{MistralConfig, Model};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use candle_rotary::apply_rotary_inplace;
use text_embeddings_backend_core::{Batch, ModelType, Pool};

struct MistralAttention {
@@ -90,7 +91,7 @@ impl MistralAttention {
self.num_key_value_heads,
)?;

candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
apply_rotary_inplace(&q, &k, &cos, &sin, true)?;

let attention = flash_attn_varlen(
&q,
@@ -267,13 +268,18 @@ impl FlashMistralModel {

let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;

let inv_freqs = candle_rotary::inv_freqs(
let inv_freqs = get_inv_freqs(
layers[0].attention.attention_head_size,
config.rope_theta,
vb.device(),
None,
)?;
let (cos_cache, sin_cache) = get_cos_sin(
config.max_position_embeddings,
&inv_freqs,
vb.dtype(),
false,
)?;
let (cos_cache, sin_cache) =
candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?;

Ok(Self {
embeddings,
14 changes: 8 additions & 6 deletions backends/candle/src/models/flash_nomic.rs
@@ -1,9 +1,10 @@
use crate::flash_attn::flash_attn_varlen;
use crate::layers::{LayerNorm, Linear};
use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
use crate::models::nomic::{NomicBertEmbeddings, NomicBertGatedMLP};
use crate::models::{Model, NomicConfig};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::VarBuilder;
use candle_rotary::apply_rotary_inplace;
use text_embeddings_backend_core::{Batch, ModelType, Pool};

struct NomicAttention {
@@ -68,7 +69,7 @@ impl NomicAttention {
let qkv = qkv.reshape(new_qkv_shape.as_slice())?;
let qkv = qkv.chunk(3, 1)?;

candle_rotary::apply_rotary_inplace(&qkv[0], &qkv[1], &cos, &sin, true)?;
apply_rotary_inplace(&qkv[0], &qkv[1], &cos, &sin, true)?;

let attention = flash_attn_varlen(
&qkv[0],
@@ -221,20 +222,21 @@ impl FlashNomicBertModel {
let encoder = NomicBertEncoder::load(vb.pp("encoder"), config)?;

let rotary_dim = encoder.layers[0].attention.attention_head_size;
let inv_freqs = candle_rotary::inv_freqs(rotary_dim, config.rotary_emb_base, vb.device())?;
let rotary_cache = candle_rotary::cos_sin(config.n_positions, &inv_freqs, vb.dtype())?;
let inv_freqs = get_inv_freqs(rotary_dim, config.rotary_emb_base, vb.device(), None)?;
let rotary_cache = get_cos_sin(config.n_positions, &inv_freqs, vb.dtype(), false)?;

let scaled_rotary_cache = if let Some(scaling_factor) = config.rotary_scaling_factor {
let new_base = (config.rotary_emb_base
* ((scaling_factor * config.n_positions as f32
/ config.max_trained_positions as f32)
- (scaling_factor - 1.0)))
.powi((rotary_dim as f32 / (rotary_dim as f32 - 2.0)) as i32);
let inv_freqs = candle_rotary::inv_freqs(rotary_dim, new_base, vb.device())?;
Some(candle_rotary::cos_sin(
let inv_freqs = get_inv_freqs(rotary_dim, new_base, vb.device(), None)?;
Some(get_cos_sin(
config.n_positions,
&inv_freqs,
vb.dtype(),
false,
)?)
} else {
None
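
As written in the surrounding Nomic code (context lines above, unchanged by this commit), the dynamic scaling path recomputes the rotary base before rebuilding the cos/sin cache with the same get_inv_freqs/get_cos_sin helpers:

$$ b' \;=\; \Bigl(b\,\bigl(\tfrac{s\,n_{\text{positions}}}{n_{\text{trained}}} - (s - 1)\bigr)\Bigr)^{\lfloor d/(d-2) \rfloor} $$

where b is rotary_emb_base, s is rotary_scaling_factor, n_positions and n_trained come from the config, d is the rotary dimension, and the exponent is truncated to an integer by powi.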