From 0462171b37f5f0b9e7a892044d0bb4fff0ef500d Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Thu, 12 Dec 2024 14:04:42 +0100
Subject: [PATCH] feat: Implement GTE model to support the non-flash-attn version (#446)

Co-authored-by: Hyeongchan Kim
---
 backends/candle/src/layers/mod.rs             |    2 +
 backends/candle/src/layers/rotary.rs          |   73 +
 backends/candle/src/lib.rs                    |   18 +-
 backends/candle/src/models/flash_gte.rs       |  145 +-
 backends/candle/src/models/flash_mistral.rs   |   16 +-
 backends/candle/src/models/flash_nomic.rs     |   14 +-
 backends/candle/src/models/flash_qwen2.rs     |   16 +-
 backends/candle/src/models/gte.rs             |  694 ++++-
 backends/candle/src/models/mod.rs             |    2 +-
 backends/candle/src/models/nomic.rs           |   55 +-
 .../tests/snapshots/test_gte__gte_batch.snap  | 2309 +++++++++++++++++
 .../tests/snapshots/test_gte__gte_single.snap |  773 ++++++
 backends/candle/tests/test_gte.rs             |   50 +
 13 files changed, 3964 insertions(+), 203 deletions(-)
 create mode 100644 backends/candle/src/layers/rotary.rs
 create mode 100644 backends/candle/tests/snapshots/test_gte__gte_batch.snap
 create mode 100644 backends/candle/tests/snapshots/test_gte__gte_single.snap
 create mode 100644 backends/candle/tests/test_gte.rs

diff --git a/backends/candle/src/layers/mod.rs b/backends/candle/src/layers/mod.rs
index 81f63310..76386eaa 100644
--- a/backends/candle/src/layers/mod.rs
+++ b/backends/candle/src/layers/mod.rs
@@ -4,9 +4,11 @@ mod layer_norm;
 mod linear;
 #[allow(dead_code, unused)]
 mod rms_norm;
+mod rotary;
 
 pub use cublaslt::get_cublas_lt_wrapper;
 pub use layer_norm::LayerNorm;
 pub use linear::{HiddenAct, Linear};
 #[allow(unused_imports)]
 pub use rms_norm::RMSNorm;
+pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeScaling};
diff --git a/backends/candle/src/layers/rotary.rs b/backends/candle/src/layers/rotary.rs
new file mode 100644
index 00000000..147f6b33
--- /dev/null
+++ b/backends/candle/src/layers/rotary.rs
@@ -0,0 +1,73 @@
+use candle::{DType, Device, Result, Tensor, D};
+use serde::Deserialize;
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct NTKScaling {
+    pub factor: f32,
+}
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+#[serde(tag = "type", rename_all = "kebab-case")]
+pub enum RopeScaling {
+    Ntk(NTKScaling),
+}
+
+pub fn get_inv_freqs(
+    dim: usize,
+    base: f32,
+    device: &Device,
+    rope_scaling: Option<&RopeScaling>,
+) -> Result<Tensor> {
+    let get_inv_freqs_inner = |dim: usize, base: f32, device: &Device| {
+        let inv_freq: Vec<_> = (0..dim)
+            .step_by(2)
+            .map(|i| 1f32 / base.powf(i as f32 / dim as f32))
+            .collect();
+        let inv_freq_len = inv_freq.len();
+        Tensor::from_vec(inv_freq, (1, inv_freq_len), device)
+    };
+
+    if let Some(rope_scaling) = rope_scaling {
+        match rope_scaling {
+            RopeScaling::Ntk(ntk_scaling) => {
+                let inv_freqs = get_inv_freqs_inner(dim, base * ntk_scaling.factor, device)?;
+                let s = ntk_scaling.factor.powf(2.0 / dim as f32) as f64;
+                return inv_freqs / s;
+            }
+        }
+    }
+    get_inv_freqs_inner(dim, base, device)
+}
+
+pub fn get_cos_sin(
+    length: usize,
+    inv_freqs: &Tensor,
+    dtype: DType,
+    repeat_freqs: bool,
+) -> Result<(Tensor, Tensor)> {
+    let t = Tensor::arange(0u32, length as u32, inv_freqs.device())?
+        .to_dtype(DType::F32)?
+        .reshape((length, 1))?;
+    let mut freqs = t.matmul(inv_freqs)?;
+    if repeat_freqs {
+        freqs = Tensor::cat(&[&freqs, &freqs], 1)?;
+    }
+
+    let cos = freqs.cos()?.to_dtype(dtype)?;
+    let sin = freqs.sin()?.to_dtype(dtype)?;
+    Ok((cos, sin))
+}
+
+pub fn apply_rotary(
+    x: &Tensor,
+    cos: &Tensor,
+    sin: &Tensor,
+    attention_head_size: usize,
+) -> Result<Tensor> {
+    let dim = attention_head_size / 2;
+    let x1 = x.narrow(D::Minus1, 0, dim)?;
+    let x2 = x.narrow(D::Minus1, dim, dim)?;
+    let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
+    let rope = (x.broadcast_mul(cos)? + rotate_x.broadcast_mul(sin)?)?;
+    Ok(rope)
+}
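
A quick usage sketch for reviewers (commentary, not part of the patch): how the three helpers above compose. It assumes a test target inside `backends/candle` so the `crate::layers` re-exports resolve; the sizes are toy values.

```rust
use candle::{DType, Device, Result, Tensor};
// From this patch: crate::layers::{apply_rotary, get_cos_sin, get_inv_freqs}

fn rotary_round_trip() -> Result<()> {
    let device = Device::Cpu;
    let head_size = 8; // toy head dimension
    // (1, head_size / 2) inverse frequencies; `None` means no NTK scaling
    let inv_freqs = get_inv_freqs(head_size, 10_000.0, &device, None)?;
    // repeat_freqs = true duplicates the frequencies so cos/sin cover the
    // full head dimension, matching apply_rotary's rotate-half layout
    let (cos, sin) = get_cos_sin(4, &inv_freqs, DType::F32, true)?;
    // x: (batch, heads, seq_len, head_size); cos/sin broadcast over heads
    let x = Tensor::randn(0f32, 1f32, (1, 1, 4, head_size), &device)?;
    let cos = cos.reshape((1, 1, 4, head_size))?;
    let sin = sin.reshape((1, 1, 4, head_size))?;
    let _rotated = apply_rotary(&x, &cos, &sin, head_size)?;
    Ok(())
}
```
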
diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs
index 944457e7..aec84375 100644
--- a/backends/candle/src/lib.rs
+++ b/backends/candle/src/lib.rs
@@ -11,7 +11,7 @@ use crate::compute_cap::{
     compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
 };
 use crate::models::{
-    BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, JinaBertModel,
+    BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
     JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig, Qwen2Config,
 };
 #[cfg(feature = "cuda")]
@@ -218,10 +218,10 @@ impl CandleBackend {
                 "Mistral is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
             )),
-            (Config::Gte(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
-                "GTE is only supported on Cuda devices in fp16 with flash attention enabled"
-                    .to_string(),
-            )),
+            (Config::Gte(config), Device::Cpu | Device::Metal(_)) => {
+                tracing::info!("Starting GTE model on {:?}", device);
+                Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
+            }
             (Config::Qwen2(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
                 "Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
@@ -349,10 +349,12 @@ impl CandleBackend {
                 if dtype != DType::F16
                     || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                 {
-                    return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention enabled".to_string()));
+                    tracing::info!("Starting GTE model on {:?}", device);
+                    Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
+                } else {
+                    tracing::info!("Starting FlashGTE model on {:?}", device);
+                    Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
                 }
-                tracing::info!("Starting FlashGTE model on {:?}", device);
-                Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
             }
             #[cfg(feature = "cuda")]
             (Config::Qwen2(config), Device::Cuda(_)) => {
diff --git a/backends/candle/src/models/flash_gte.rs b/backends/candle/src/models/flash_gte.rs
index 53e62f6d..b9bb4cdf 100644
--- a/backends/candle/src/models/flash_gte.rs
+++ b/backends/candle/src/models/flash_gte.rs
@@ -1,8 +1,9 @@
 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{HiddenAct, LayerNorm, Linear};
-use crate::models::{GTEConfig, Model, NTKScaling, PositionEmbeddingType, RopeScaling};
+use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
+use crate::models::{GTEClassificationHead, GTEConfig, Model, PositionEmbeddingType, GTEMLP};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
+use candle_rotary::apply_rotary_inplace;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
 
 struct GTEAttention {
@@ -72,7 +73,7 @@ impl GTEAttention {
         let k = qkv.narrow(1, self.num_attention_heads, self.num_attention_heads)?;
         let v = qkv.narrow(1, self.num_attention_heads * 2, self.num_attention_heads)?;
 
-        candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
+        apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
 
         let attention = flash_attn_varlen(
             &q,
@@ -93,60 +94,7 @@ impl GTEAttention {
     }
 }
 
-struct GTEMLP {
-    up_gate_proj: Linear,
-    down_proj: Linear,
-
-    act: HiddenAct,
-    intermediate_size: usize,
-
-    span: tracing::Span,
-}
-
-impl GTEMLP {
-    pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
-        let intermediate_size = config.intermediate_size;
-
-        let up_gate_proj_weight = vb
-            .pp("up_gate_proj")
-            .get((intermediate_size * 2, config.hidden_size), "weight")?;
-
-        let up_gate_proj = Linear::new(up_gate_proj_weight, None, None);
-
-        let down_proj_weight = vb
-            .pp("down_proj")
-            .get((config.hidden_size, intermediate_size), "weight")?;
-        let down_proj_bias = vb.pp("down_proj").get(config.hidden_size, "bias")?;
-        let down_proj = Linear::new(down_proj_weight, Some(down_proj_bias), None);
-
-        Ok(Self {
-            up_gate_proj,
-            down_proj,
-            intermediate_size,
-            act: config.hidden_act.clone(),
-            span: tracing::span!(tracing::Level::TRACE, "mlp"),
-        })
-    }
-
-    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-
-        let up_gate_states = self.up_gate_proj.forward(hidden_states)?;
-        let up_states = up_gate_states.narrow(1, 0, self.intermediate_size)?;
-        let gate_states =
-            up_gate_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
-
-        let gate_states = match self.act {
-            HiddenAct::Gelu => gate_states.gelu(),
-            HiddenAct::Relu => gate_states.relu(),
-            HiddenAct::Swiglu => gate_states.silu(),
-        }?;
-        let r = self.down_proj.forward(&(gate_states * up_states)?);
-        r
-    }
-}
-
-struct GTELayer {
+pub struct GTELayer {
     attention: GTEAttention,
     mlp: GTEMLP,
     attention_layer_norm: LayerNorm,
@@ -198,58 +146,6 @@ impl GTELayer {
     }
 }
 
-pub struct GTEClassificationHead {
-    pooler: Option<Linear>,
-    classifier: Linear,
-    span: tracing::Span,
-}
-
-impl GTEClassificationHead {
-    #[allow(dead_code)]
-    pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
-        let n_classes = match &config.id2label {
-            None => candle::bail!("`id2label` must be set for classifier models"),
-            Some(id2label) => id2label.len(),
-        };
-
-        let pooler = if let Ok(pooler_weight) = vb
-            .pp("pooler.dense")
-            .get((config.hidden_size, config.hidden_size), "weight")
-        {
-            let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?;
-            Some(Linear::new(pooler_weight, Some(pooler_bias), None))
-        } else {
-            None
-        };
-
-        let classifier_weight = vb
-            .pp("classifier")
-            .get((n_classes, config.hidden_size), "weight")?;
-        let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
-        let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);
-
-        Ok(Self {
-            classifier,
-            pooler,
-            span: tracing::span!(tracing::Level::TRACE, "classifier"),
-        })
-    }
-
-    pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-
-        let mut hidden_states = hidden_states.unsqueeze(1)?;
-        if let Some(pooler) = self.pooler.as_ref() {
-            hidden_states = pooler.forward(&hidden_states)?;
-            hidden_states = hidden_states.tanh()?;
-        }
-
-        let hidden_states = self.classifier.forward(&hidden_states)?;
-        let hidden_states = hidden_states.squeeze(1)?;
-        Ok(hidden_states)
-    }
-}
-
 pub struct FlashGTEModel {
     word_embeddings: Embedding,
     token_type_embeddings: Option<Embedding>,
@@ -322,24 +218,19 @@ impl FlashGTEModel {
             config.layer_norm_eps,
         )?;
 
-        let inv_freqs = if let 
Some(RopeScaling::Ntk(NTKScaling { factor })) = config.rope_scaling { - let inv_freqs = candle_rotary::inv_freqs( - layers[0].attention.attention_head_size, - config.rope_theta * factor, - vb.device(), - )?; - let s = factor.powf(2.0 / layers[0].attention.attention_head_size as f32) as f64; - inv_freqs / s - } else { - candle_rotary::inv_freqs( - layers[0].attention.attention_head_size, - config.rope_theta, - vb.device(), - ) - }?; - - let (cos_cache, sin_cache) = - candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?; + let inv_freqs = get_inv_freqs( + layers[0].attention.attention_head_size, + config.rope_theta, + vb.device(), + config.rope_scaling.as_ref(), + )?; + + let (cos_cache, sin_cache) = get_cos_sin( + config.max_position_embeddings, + &inv_freqs, + vb.dtype(), + false, + )?; Ok(Self { word_embeddings, diff --git a/backends/candle/src/models/flash_mistral.rs b/backends/candle/src/models/flash_mistral.rs index eb94c913..70538269 100644 --- a/backends/candle/src/models/flash_mistral.rs +++ b/backends/candle/src/models/flash_mistral.rs @@ -1,8 +1,9 @@ use crate::flash_attn::flash_attn_varlen; -use crate::layers::{HiddenAct, Linear, RMSNorm}; +use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm}; use crate::models::{MistralConfig, Model}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; +use candle_rotary::apply_rotary_inplace; use text_embeddings_backend_core::{Batch, ModelType, Pool}; struct MistralAttention { @@ -90,7 +91,7 @@ impl MistralAttention { self.num_key_value_heads, )?; - candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?; + apply_rotary_inplace(&q, &k, &cos, &sin, true)?; let attention = flash_attn_varlen( &q, @@ -267,13 +268,18 @@ impl FlashMistralModel { let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?; - let inv_freqs = candle_rotary::inv_freqs( + let inv_freqs = get_inv_freqs( layers[0].attention.attention_head_size, config.rope_theta, vb.device(), + None, + )?; + let (cos_cache, sin_cache) = get_cos_sin( + config.max_position_embeddings, + &inv_freqs, + vb.dtype(), + false, )?; - let (cos_cache, sin_cache) = - candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?; Ok(Self { embeddings, diff --git a/backends/candle/src/models/flash_nomic.rs b/backends/candle/src/models/flash_nomic.rs index 7f558b9e..057db768 100644 --- a/backends/candle/src/models/flash_nomic.rs +++ b/backends/candle/src/models/flash_nomic.rs @@ -1,9 +1,10 @@ use crate::flash_attn::flash_attn_varlen; -use crate::layers::{LayerNorm, Linear}; +use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear}; use crate::models::nomic::{NomicBertEmbeddings, NomicBertGatedMLP}; use crate::models::{Model, NomicConfig}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::VarBuilder; +use candle_rotary::apply_rotary_inplace; use text_embeddings_backend_core::{Batch, ModelType, Pool}; struct NomicAttention { @@ -68,7 +69,7 @@ impl NomicAttention { let qkv = qkv.reshape(new_qkv_shape.as_slice())?; let qkv = qkv.chunk(3, 1)?; - candle_rotary::apply_rotary_inplace(&qkv[0], &qkv[1], &cos, &sin, true)?; + apply_rotary_inplace(&qkv[0], &qkv[1], &cos, &sin, true)?; let attention = flash_attn_varlen( &qkv[0], @@ -221,8 +222,8 @@ impl FlashNomicBertModel { let encoder = NomicBertEncoder::load(vb.pp("encoder"), config)?; let rotary_dim = encoder.layers[0].attention.attention_head_size; - let inv_freqs = 
candle_rotary::inv_freqs(rotary_dim, config.rotary_emb_base, vb.device())?;
-        let rotary_cache = candle_rotary::cos_sin(config.n_positions, &inv_freqs, vb.dtype())?;
+        let inv_freqs = get_inv_freqs(rotary_dim, config.rotary_emb_base, vb.device(), None)?;
+        let rotary_cache = get_cos_sin(config.n_positions, &inv_freqs, vb.dtype(), false)?;
 
         let scaled_rotary_cache = if let Some(scaling_factor) = config.rotary_scaling_factor {
             let new_base = (config.rotary_emb_base
@@ -230,11 +231,12 @@ impl FlashNomicBertModel {
                 * ((scaling_factor * config.n_positions as f32
                     / config.max_trained_positions as f32)
                     - (scaling_factor - 1.0)))
             .powi((rotary_dim as f32 / (rotary_dim as f32 - 2.0)) as i32);
-            let inv_freqs = candle_rotary::inv_freqs(rotary_dim, new_base, vb.device())?;
-            Some(candle_rotary::cos_sin(
+            let inv_freqs = get_inv_freqs(rotary_dim, new_base, vb.device(), None)?;
+            Some(get_cos_sin(
                 config.n_positions,
                 &inv_freqs,
                 vb.dtype(),
+                false,
             )?)
         } else {
             None
diff --git a/backends/candle/src/models/flash_qwen2.rs b/backends/candle/src/models/flash_qwen2.rs
index 4dba7868..c6662047 100644
--- a/backends/candle/src/models/flash_qwen2.rs
+++ b/backends/candle/src/models/flash_qwen2.rs
@@ -1,8 +1,9 @@
 use crate::flash_attn::flash_attn_varlen;
-use crate::layers::{HiddenAct, Linear, RMSNorm};
+use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm};
 use crate::models::{Model, Qwen2Config};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
+use candle_rotary::apply_rotary_inplace;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
 
 struct Qwen2Attention {
@@ -98,7 +99,7 @@ impl Qwen2Attention {
             self.num_key_value_heads,
         )?;
 
-        candle_rotary::apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
+        apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
 
         let attention = flash_attn_varlen(
             &q,
@@ -277,13 +278,18 @@ impl FlashQwen2Model {
 
         let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
 
-        let inv_freqs = candle_rotary::inv_freqs(
+        let inv_freqs = get_inv_freqs(
             layers[0].attention.attention_head_size,
             config.rope_theta,
             vb.device(),
+            None,
+        )?;
+        let (cos_cache, sin_cache) = get_cos_sin(
+            config.max_position_embeddings,
+            &inv_freqs,
+            vb.dtype(),
+            false,
         )?;
-        let (cos_cache, sin_cache) =
-            candle_rotary::cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype())?;
 
         Ok(Self {
             embeddings,
diff --git a/backends/candle/src/models/gte.rs b/backends/candle/src/models/gte.rs
index bc4bfdce..4dca1620 100644
--- a/backends/candle/src/models/gte.rs
+++ b/backends/candle/src/models/gte.rs
@@ -1,18 +1,13 @@
-use crate::layers::HiddenAct;
-use crate::models::PositionEmbeddingType;
+use crate::layers::{
+    apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, LayerNorm, Linear,
+    RopeScaling,
+};
+use crate::models::{Model, PositionEmbeddingType};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
-
-#[derive(Debug, Clone, PartialEq, Deserialize)]
-pub struct NTKScaling {
-    pub factor: f32,
-}
-
-#[derive(Debug, Clone, PartialEq, Deserialize)]
-#[serde(tag = "type", rename_all = "kebab-case")]
-pub enum RopeScaling {
-    Ntk(NTKScaling),
-}
+use text_embeddings_backend_core::{Batch, ModelType, Pool};
 
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct GTEConfig {
@@ -35,3 +30,676 @@ pub struct GTEConfig {
     pub logn_attention_clip1: bool,
     pub id2label: Option<HashMap<String, String>>,
 }
+
+struct GTEAttention {
+    qkv_linear: Linear,
+    o_proj: Linear,
+
+    num_attention_heads: usize,
+    attention_head_size: usize,
+
+    softmax_scale: f64,
+
+    span: tracing::Span,
+}
+
+impl GTEAttention {
+    pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
+        let num_attention_heads = config.num_attention_heads;
+        let attention_head_size = config.hidden_size / config.num_attention_heads;
+        let hidden_size = config.hidden_size;
+
+        let qkv_weight = vb
+            .pp("qkv_proj")
+            .get((hidden_size * 3, hidden_size), "weight")?;
+        let qkv_bias = vb.pp("qkv_proj").get(hidden_size * 3, "bias")?;
+
+        let qkv_linear = Linear::new(qkv_weight, Some(qkv_bias), None);
+
+        let o_proj_weight = vb.pp("o_proj").get((hidden_size, hidden_size), "weight")?;
+        let o_proj_bias = vb.pp("o_proj").get(hidden_size, "bias")?;
+
+        let o_proj = Linear::new(o_proj_weight, Some(o_proj_bias), None);
+
+        let softmax_scale = 1. / (attention_head_size as f64).sqrt();
+
+        Ok(Self {
+            qkv_linear,
+            o_proj,
+            num_attention_heads,
+            attention_head_size,
+            softmax_scale,
+            span: tracing::span!(tracing::Level::TRACE, "attention"),
+        })
+    }
+
+    pub fn forward(
+        &self,
+        hidden_states: &Tensor,
+        attention_bias: Option<&Tensor>,
+        cos: &Tensor,
+        sin: &Tensor,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let device = hidden_states.device();
+
+        let qkv = self.qkv_linear.forward(hidden_states)?;
+
+        let mut new_qkv_shape = qkv.dims().to_vec();
+        new_qkv_shape.pop();
+        new_qkv_shape.push(self.num_attention_heads * 3);
+        new_qkv_shape.push(self.attention_head_size);
+        let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
+
+        let qkv = qkv.chunk(3, 1)?;
+        let query_layer = &qkv[0].contiguous()?;
+        let key_layer = &qkv[1].contiguous()?;
+        let value_layer = &qkv[2];
+
+        let query_layer = apply_rotary(query_layer, cos, sin, self.attention_head_size)?;
+        let key_layer = apply_rotary(key_layer, cos, sin, self.attention_head_size)?;
+
+        #[allow(unused_variables)]
+        let context_layer = if let (Device::Cuda(_), Some(cublaslt)) =
+            (device, get_cublas_lt_wrapper())
+        {
+            #[cfg(feature = "cuda")]
+            {
+                // cuBLASLt batch matmul implementation requires inputs to be dims3
+                let (batch_size, _, seq_len, _) = key_layer.shape().dims4()?;
+                let key_layer = key_layer.flatten(0, 1)?;
+                let query_layer = query_layer.flatten(0, 1)?;
+                let value_layer = value_layer.flatten(0, 1)?;
+                let attention_bias = attention_bias.map(|mask| mask.flatten(0, 1)).transpose()?;
+
+                // If attention_bias is set, we fuse the add by giving it as the output matrix
+                // and setting beta to 1.0
+                let beta = match attention_bias.is_some() {
+                    true => Some(1.0),
+                    false => None,
+                };
+
+                // Batch matrix multiplication
+                // Fuse softmax scale and attention_bias add
+                let attention_scores = cublaslt.batch_matmul(
+                    &key_layer,
+                    &query_layer,
+                    attention_bias.as_ref(),
+                    Some(self.softmax_scale as f32),
+                    beta,
+                    None,
+                    None,
+                )?;
+                let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+
+                let context_layer = cublaslt.batch_matmul(
+                    &value_layer.t()?.contiguous()?,
+                    &attention_probs,
+                    // We save one allocation
+                    Some(&query_layer),
+                    None,
+                    None,
+                    None,
+                    None,
+                )?;
+
+                // Reshape to dims4
+                context_layer.reshape((
+                    batch_size,
+                    self.num_attention_heads,
+                    seq_len,
+                    self.attention_head_size,
+                ))
+            }
+            #[cfg(not(feature = "cuda"))]
+            {
+                candle::bail!("`cuda` feature is not enabled")
+            }
+        } else {
+            let attention_scores = query_layer.matmul(&key_layer.t()?)?;
+            let mut attention_scores = (attention_scores * self.softmax_scale)?;
+
+            if let Some(attention_bias) = attention_bias {
+                attention_scores = attention_scores.add(attention_bias)?;
+            }
+
+            let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+            attention_probs.matmul(&value_layer.contiguous()?)
+        }?;
+
+        let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;
+
+        let hidden_states = self.o_proj.forward(&context_layer)?;
+
+        Ok(hidden_states)
+    }
+}
+
+#[allow(clippy::upper_case_acronyms)]
+pub struct GTEMLP {
+    up_gate_proj: Linear,
+    down_proj: Linear,
+
+    act: HiddenAct,
+    intermediate_size: usize,
+
+    span: tracing::Span,
+}
+
+impl GTEMLP {
+    pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
+        let intermediate_size = config.intermediate_size;
+
+        let up_gate_proj_weight = vb
+            .pp("up_gate_proj")
+            .get((intermediate_size * 2, config.hidden_size), "weight")?;
+
+        let up_gate_proj = Linear::new(up_gate_proj_weight, None, None);
+
+        let down_proj_weight = vb
+            .pp("down_proj")
+            .get((config.hidden_size, intermediate_size), "weight")?;
+        let down_proj_bias = vb.pp("down_proj").get(config.hidden_size, "bias")?;
+        let down_proj = Linear::new(down_proj_weight, Some(down_proj_bias), None);
+
+        Ok(Self {
+            up_gate_proj,
+            down_proj,
+            intermediate_size,
+            act: config.hidden_act.clone(),
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
+        })
+    }
+
+    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let up_gate_states = self.up_gate_proj.forward(hidden_states)?;
+        let up_states = up_gate_states.narrow(D::Minus1, 0, self.intermediate_size)?;
+        let gate_states =
+            up_gate_states.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
+
+        let gate_states = match self.act {
+            HiddenAct::Gelu => gate_states.gelu(),
+            HiddenAct::Relu => gate_states.relu(),
+            HiddenAct::Swiglu => gate_states.silu(),
+        }?;
+
+        self.down_proj.forward(&(gate_states * up_states)?)
+    }
+}
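
Reviewer aside (not part of the diff): `up_gate_proj` packs the up and gate projections into one matmul whose output is split in half. A standalone sketch of the same computation with hypothetical toy shapes, following the swiglu branch above:

```rust
use candle::{Device, Result, Tensor};

fn gated_mlp_sketch() -> Result<()> {
    let device = Device::Cpu;
    let (tokens, hidden, inter) = (2, 4, 6); // hypothetical sizes
    let x = Tensor::randn(0f32, 1f32, (tokens, hidden), &device)?;
    // One fused weight of shape (2 * inter, hidden), as in GTEMLP::load
    let up_gate = Tensor::randn(0f32, 1f32, (inter * 2, hidden), &device)?;
    let down = Tensor::randn(0f32, 1f32, (hidden, inter), &device)?;

    let up_gate_states = x.matmul(&up_gate.t()?)?; // (tokens, 2 * inter)
    let up = up_gate_states.narrow(1, 0, inter)?;
    let gate = up_gate_states.narrow(1, inter, inter)?;
    // silu(gate) elementwise-multiplied with up, then the down projection
    let y = (gate.silu()? * up)?.matmul(&down.t()?)?;
    assert_eq!(y.dims(), &[tokens, hidden]);
    Ok(())
}
```
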
+
+pub struct GTELayer {
+    attention: GTEAttention,
+    mlp: GTEMLP,
+    attention_layer_norm: LayerNorm,
+    mlp_layer_norm: LayerNorm,
+
+    span: tracing::Span,
+}
+
+impl GTELayer {
+    pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
+        let attention = GTEAttention::load(vb.pp("attention"), config)?;
+        let mlp = GTEMLP::load(vb.pp("mlp"), config)?;
+
+        let attention_layer_norm =
+            LayerNorm::load(vb.pp("attn_ln"), config.hidden_size, config.layer_norm_eps)?;
+        let mlp_layer_norm =
+            LayerNorm::load(vb.pp("mlp_ln"), config.hidden_size, config.layer_norm_eps)?;
+
+        Ok(Self {
+            attention,
+            mlp,
+            attention_layer_norm,
+            mlp_layer_norm,
+            span: tracing::span!(tracing::Level::TRACE, "layer"),
+        })
+    }
+
+    pub fn forward(
+        &self,
+        hidden_states: &Tensor,
+        attention_bias: Option<&Tensor>,
+        cos: &Tensor,
+        sin: &Tensor,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let attn_output = self
+            .attention
+            .forward(hidden_states, attention_bias, cos, sin)?;
+
+        let normed_attn_res_output = self
+            .attention_layer_norm
+            .forward(&attn_output, Some(hidden_states))?;
+
+        let mlp_output = self.mlp.forward(&normed_attn_res_output)?;
+        let normed_mlp_res_output = self
+            .mlp_layer_norm
+            .forward(&mlp_output, Some(&normed_attn_res_output))?;
+        Ok(normed_mlp_res_output)
+    }
+}
+
+pub struct GTEClassificationHead {
+    pooler: Option<Linear>,
+    classifier: Linear,
+    span: tracing::Span,
+}
+
+impl GTEClassificationHead {
+    #[allow(dead_code)]
+    pub(crate) fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
+        let n_classes = match &config.id2label {
+            None => candle::bail!("`id2label` must be set for classifier models"),
+            Some(id2label) => id2label.len(),
+        };
+
+        let pooler = if let Ok(pooler_weight) = vb
+            .pp("pooler.dense")
+            .get((config.hidden_size, config.hidden_size), "weight")
+        {
+            let pooler_bias = vb.pp("pooler.dense").get(config.hidden_size, "bias")?;
+            Some(Linear::new(pooler_weight, Some(pooler_bias), None))
+        } else {
+            None
+        };
+
+        let classifier_weight = vb
+            .pp("classifier")
+            .get((n_classes, config.hidden_size), "weight")?;
+        let classifier_bias = vb.pp("classifier").get(n_classes, "bias")?;
+        let classifier = Linear::new(classifier_weight, Some(classifier_bias), None);
+
+        Ok(Self {
+            classifier,
+            pooler,
+            span: tracing::span!(tracing::Level::TRACE, "classifier"),
+        })
+    }
+
+    pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let mut hidden_states = hidden_states.unsqueeze(1)?;
+        if let Some(pooler) = self.pooler.as_ref() {
+            hidden_states = pooler.forward(&hidden_states)?;
+            hidden_states = hidden_states.tanh()?;
+        }
+
+        let hidden_states = self.classifier.forward(&hidden_states)?;
+        let hidden_states = hidden_states.squeeze(1)?;
+        Ok(hidden_states)
+    }
+}
+
+struct GTEEncoder {
+    layers: Vec<GTELayer>,
+    span: tracing::Span,
+}
+
+impl GTEEncoder {
+    pub fn load(vb: VarBuilder, config: &GTEConfig) -> Result<Self> {
+        let layers = (0..config.num_hidden_layers)
+            .map(|index| GTELayer::load(vb.pp(format!("layer.{index}")), config))
+            .collect::<Result<Vec<_>>>()?;
+
+        let span = tracing::span!(tracing::Level::TRACE, "encoder");
+        Ok(GTEEncoder { layers, span })
+    }
+
+    fn forward(
+        &self,
+        hidden_states: &Tensor,
+        attention_bias: Option<&Tensor>,
+        cos: &Tensor,
+        sin: &Tensor,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let mut hidden_states = hidden_states.clone();
+
+        // Use a loop rather than a fold as it's easier to modify when adding debug/...
+        for layer in self.layers.iter() {
+            hidden_states = layer.forward(&hidden_states, attention_bias, cos, sin)?
+        }
+
+        Ok(hidden_states)
+    }
+}
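
One structural note (commentary, not patch content): `GTELayer` above is post-norm, and the residual add is fused into `LayerNorm::forward`'s second argument. A generic sketch of that wiring, with the sublayers stubbed out as closures:

```rust
use candle::{Result, Tensor};

// h   = attn_ln(x + attention(x))
// out = mlp_ln(h + mlp(h))
fn post_norm_block(
    x: &Tensor,
    attn: impl Fn(&Tensor) -> Result<Tensor>,
    mlp: impl Fn(&Tensor) -> Result<Tensor>,
    attn_ln: impl Fn(&Tensor) -> Result<Tensor>,
    mlp_ln: impl Fn(&Tensor) -> Result<Tensor>,
) -> Result<Tensor> {
    let h = attn_ln(&(attn(x)? + x)?)?;
    mlp_ln(&(mlp(&h)? + &h)?)
}
```
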
+
+pub struct GTEModel {
+    word_embeddings: Embedding,
+    token_type_embeddings: Option<Embedding>,
+    embeddings_norm: LayerNorm,
+    encoder: GTEEncoder,
+    dtype: DType,
+    rotary_cache: (Tensor, Tensor),
+    rotary_dim: usize,
+    classifier: Option<GTEClassificationHead>,
+    pool: Pool,
+    pub device: Device,
+
+    num_attention_heads: usize,
+
+    span: tracing::Span,
+}
+
+impl GTEModel {
+    pub fn load(vb: VarBuilder, config: &GTEConfig, model_type: ModelType) -> Result<Self> {
+        if config.logn_attention_clip1 {
+            candle::bail!("`logn_attention_clip1` is not supported");
+        }
+        if config.logn_attention_scale {
+            candle::bail!("`logn_attention_scale` is not supported");
+        }
+
+        if config.position_embedding_type != PositionEmbeddingType::Rope {
+            candle::bail!("Only `PositionEmbeddingType::Rope` is supported");
+        }
+
+        let (pool, classifier) = match model_type {
+            ModelType::Classifier => {
+                let pool = Pool::Cls;
+
+                let classifier = GTEClassificationHead::load(vb.clone(), config)?;
+                (pool, Some(classifier))
+            }
+            ModelType::Embedding(pool) => (pool, None),
+        };
+
+        let word_embeddings = Embedding::new(
+            vb.pp("embeddings.word_embeddings")
+                .get((config.vocab_size, config.hidden_size), "weight")?,
+            config.hidden_size,
+        );
+
+        let token_type_embeddings = if config.type_vocab_size > 0 {
+            Some(Embedding::new(
+                vb.pp("embeddings.token_type_embeddings")
+                    .get((config.type_vocab_size, config.hidden_size), "weight")?,
+                config.hidden_size,
+            ))
+        } else {
+            None
+        };
+
+        let encoder = GTEEncoder::load(vb.pp("encoder"), config)?;
+
+        let embeddings_norm = LayerNorm::load(
+            vb.pp("embeddings.LayerNorm"),
+            config.hidden_size,
+            config.layer_norm_eps,
+        )?;
+
+        let rotary_dim = encoder.layers[0].attention.attention_head_size;
+        let inv_freqs = get_inv_freqs(
+            rotary_dim,
+            config.rope_theta,
+            vb.device(),
+            config.rope_scaling.as_ref(),
+        )?;
+
+        let rotary_cache =
+            get_cos_sin(config.max_position_embeddings, &inv_freqs, vb.dtype(), true)?;
+
+        Ok(Self {
+            word_embeddings,
+            token_type_embeddings,
+            encoder,
+            embeddings_norm,
+            rotary_cache,
+            classifier,
+            pool,
+            num_attention_heads: config.num_attention_heads,
+            device: vb.device().clone(),
+            dtype: vb.dtype(),
+            span: tracing::span!(tracing::Level::TRACE, "model"),
+            rotary_dim,
+        })
+    }
+
+    pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        let _enter = self.span.enter();
+
+        let batch_size = batch.len();
+        let max_length = batch.max_length as usize;
+
+        let shape = (batch_size, max_length);
+
+        let (input_ids, type_ids, position_ids, input_lengths, attention_bias, attention_mask) =
+            if batch_size > 1 {
+                // Prepare padded batch
+                let elems = batch_size * max_length;
+
+                let mut input_ids = Vec::with_capacity(elems);
+                let mut type_ids = Vec::with_capacity(elems);
+                let mut position_ids = Vec::with_capacity(elems);
+                let mut attention_mask = Vec::with_capacity(elems);
+                let mut attention_bias = Vec::with_capacity(elems);
+                let mut input_lengths = Vec::with_capacity(batch_size);
+                // Bool to know if we need to use the attention mask
+                let mut masking = false;
+
+                for i in 0..batch_size {
+                    let start = batch.cumulative_seq_lengths[i] as usize;
+                    let end = batch.cumulative_seq_lengths[i + 1] as usize;
+                    let seq_length = (end - start) as u32;
+                    input_lengths.push(seq_length as f32);
+
+                    // Copy values
+                    for j in start..end {
+                        input_ids.push(batch.input_ids[j]);
+                        type_ids.push(batch.token_type_ids[j]);
+                        position_ids.push(batch.position_ids[j]);
+                        attention_mask.push(1.0_f32);
+                        attention_bias.push(0.0);
+                    }
+
+                    // Add padding if needed
+                    let padding = batch.max_length - seq_length;
+                    if padding > 0 {
+                        // Set bool to use attention mask
+                        masking = true;
+                        for _ in 0..padding {
+                            input_ids.push(0);
+                            type_ids.push(0);
+                            position_ids.push(0);
+                            attention_mask.push(0.0_f32);
+                            attention_bias.push(f32::NEG_INFINITY);
+                        }
+                    }
+                }
+
+                let (attention_bias, attention_mask) = match masking {
+                    true => {
+                        // We only need the mask if we use mean pooling
+                        // For CLS pooling, the bias is enough
+                        let attention_mask = if self.pool == Pool::Mean {
+                            let attention_mask = Tensor::from_vec(
+                                attention_mask,
+                                (batch_size, max_length, 1),
+                                &self.device,
+                            )?
+                            .to_dtype(self.dtype)?;
+
+                            Some(attention_mask)
+                        } else {
+                            None
+                        };
+
+                        let attention_bias = Tensor::from_vec(
+                            attention_bias,
+                            (batch_size, 1, 1, max_length),
+                            &self.device,
+                        )?
+                        .to_dtype(self.dtype)?;
+                        // Broadcast once instead of at every layer
+                        let attention_bias = attention_bias
+                            .broadcast_as((
+                                batch_size,
+                                self.num_attention_heads,
+                                max_length,
+                                max_length,
+                            ))?
+                            .contiguous()?;
+                        (Some(attention_bias), attention_mask)
+                    }
+                    false => (None, None),
+                };
+
+                (
+                    input_ids,
+                    type_ids,
+                    position_ids,
+                    input_lengths,
+                    attention_bias,
+                    attention_mask,
+                )
+            } else {
+                (
+                    batch.input_ids,
+                    batch.token_type_ids,
+                    batch.position_ids,
+                    vec![batch.max_length as f32],
+                    None,
+                    None,
+                )
+            };
+
+        // Create CPU tensors
+        let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
+        let type_ids = Tensor::from_vec(type_ids, shape, &self.device)?;
+        let position_ids = Tensor::from_vec(position_ids, batch_size * max_length, &self.device)?;
+        let input_lengths =
+            Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;
+
+        let cos = self.rotary_cache.0.index_select(&position_ids, 0)?;
+        let sin = self.rotary_cache.1.index_select(&position_ids, 0)?;
+
+        let cos = cos.reshape((batch_size, 1, max_length, self.rotary_dim))?;
+        let sin = sin.reshape((batch_size, 1, max_length, self.rotary_dim))?;
+
+        let word_embeddings = self.word_embeddings.forward(&input_ids)?;
+        let token_type_embeddings = self
+            .token_type_embeddings
+            .as_ref()
+            .map(|emb| emb.forward(&type_ids))
+            .transpose()?;
+
+        let embedding_output = self
+            .embeddings_norm
+            .forward(&word_embeddings, token_type_embeddings.as_ref())?;
+
+        let outputs =
+            self.encoder
+                .forward(&embedding_output, attention_bias.as_ref(), &cos, &sin)?;
+
+        let has_pooling_requests = !batch.pooled_indices.is_empty();
+        let has_raw_requests = !batch.raw_indices.is_empty();
+
+        let pooled_embeddings = if has_pooling_requests {
+            let pooled_indices_length = batch.pooled_indices.len();
+            let mut outputs = outputs.clone();
+
+            // Only use pooled_indices if at least one member of the batch asks for raw embeddings
+            let pooled_indices = if has_raw_requests {
+                let pooled_indices =
+                    Tensor::from_vec(batch.pooled_indices, pooled_indices_length, &self.device)?;
+
+                // Select values in the batch
+                outputs = outputs.index_select(&pooled_indices, 0)?;
+                Some(pooled_indices)
+            } else {
+                None
+            };
+
+            let pooled_embeddings = match self.pool {
+                // CLS pooling
+                Pool::Cls => outputs.i((.., 0))?,
+                // Last token pooling is not supported for this model
+                Pool::LastToken => unreachable!(),
+                // Mean pooling
+                Pool::Mean => {
+                    if let Some(ref attention_mask) = attention_mask {
+                        let mut attention_mask = attention_mask.clone();
+
+                        if let Some(pooled_indices) = pooled_indices {
+                            // Select values in the batch
+                            attention_mask = attention_mask.index_select(&pooled_indices, 0)?;
+                        };
+
+                        // Mask padded values
+                        outputs = outputs.broadcast_mul(&attention_mask)?;
+                    }
+
+                    (outputs.sum(1)?.broadcast_div(&input_lengths))?
+                }
+                Pool::Splade => unreachable!(),
+            };
+            Some(pooled_embeddings)
+        } else {
+            None
+        };
+
+        let raw_embeddings = if has_raw_requests {
+            // Reshape outputs
+            let (b, l, h) = outputs.shape().dims3()?;
+            let outputs = outputs.reshape((b * l, h))?;
+
+            // We need to remove the padding tokens only if batch_size > 1 and there are some
+            // members of the batch that require pooling
+            // or if batch_size > 1 and the members of the batch have different lengths
+            if (attention_mask.is_some() || has_pooling_requests) && batch_size > 1 {
+                let mut final_indices: Vec<u32> = Vec::with_capacity(batch_size * max_length);
+
+                for i in batch.raw_indices.into_iter() {
+                    let start = i * batch.max_length;
+                    let i = i as usize;
+                    let length =
+                        batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i];
+
+                    for j in start..start + length {
+                        // Add indices for the tokens of this specific member of the batch
+                        final_indices.push(j);
+                    }
+                }
+
+                let final_indices_length = final_indices.len();
+                let final_indices =
+                    Tensor::from_vec(final_indices, final_indices_length, &self.device)?;
+
+                // Select the tokens with final indices
+                Some(outputs.index_select(&final_indices, 0)?)
+            } else {
+                Some(outputs)
+            }
+        } else {
+            None
+        };
+
+        Ok((pooled_embeddings, raw_embeddings))
+    }
+}
+
+impl Model for GTEModel {
+    fn is_padded(&self) -> bool {
+        true
+    }
+
+    fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
+        self.forward(batch)
+    }
+
+    fn predict(&self, batch: Batch) -> Result<Tensor> {
+        match &self.classifier {
+            None => candle::bail!("`predict` is not implemented for this model"),
+            Some(classifier) => {
+                let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
+                let pooled_embeddings =
+                    pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
+                classifier.forward(&pooled_embeddings)
+            }
+        }
+    }
+}
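
Why the padding loop above pushes `0.0` for real tokens and `f32::NEG_INFINITY` for padding (commentary, not patch content): the bias is added to the raw attention scores before softmax, so padded key positions receive zero probability mass while the remaining positions renormalize. A tiny self-contained check:

```rust
use candle::{Device, Result, Tensor};

fn masked_softmax_check() -> Result<()> {
    let device = Device::Cpu;
    // Scores of one query against 4 keys; the last key is padding
    let scores = Tensor::new(&[1.0f32, 0.5, 0.2, 0.9], &device)?;
    let bias = Tensor::new(&[0.0f32, 0.0, 0.0, f32::NEG_INFINITY], &device)?;
    let probs = candle_nn::ops::softmax_last_dim(&(scores + bias)?)?;
    let probs = probs.to_vec1::<f32>()?;
    assert_eq!(probs[3], 0.0); // exp(-inf) contributes nothing
    assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-6);
    Ok(())
}
```
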
This is a bug."); + classifier.forward(&pooled_embeddings) + } + } + } +} diff --git a/backends/candle/src/models/mod.rs b/backends/candle/src/models/mod.rs index b1e9f937..c1107d66 100644 --- a/backends/candle/src/models/mod.rs +++ b/backends/candle/src/models/mod.rs @@ -40,7 +40,7 @@ pub use bert::{BertConfig, BertModel, PositionEmbeddingType}; use candle::{Result, Tensor}; pub use distilbert::{DistilBertConfig, DistilBertModel}; #[allow(unused_imports)] -pub use gte::{GTEConfig, NTKScaling, RopeScaling}; +pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, GTEMLP}; pub use jina::JinaBertModel; pub use jina_code::JinaCodeBertModel; pub use mistral::MistralConfig; diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs index cdaaea92..c860cfbc 100644 --- a/backends/candle/src/models/nomic.rs +++ b/backends/candle/src/models/nomic.rs @@ -1,4 +1,6 @@ -use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear}; +use crate::layers::{ + apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, LayerNorm, Linear, +}; use crate::models::Model; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Embedding, VarBuilder}; @@ -176,15 +178,6 @@ impl NomicAttention { }) } - fn apply_rotary(&self, x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { - let dim = self.attention_head_size / 2; - let x1 = x.narrow(D::Minus1, 0, dim)?; - let x2 = x.narrow(D::Minus1, dim, dim)?; - let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; - let rope = (x.broadcast_mul(cos)? + rotate_x.broadcast_mul(sin)?)?; - Ok(rope) - } - pub fn forward( &self, hidden_states: &Tensor, @@ -208,8 +201,8 @@ impl NomicAttention { let key_layer = &qkv[1].contiguous()?; let value_layer = &qkv[2]; - let query_layer = self.apply_rotary(query_layer, cos, sin)?; - let key_layer = self.apply_rotary(key_layer, cos, sin)?; + let query_layer = apply_rotary(query_layer, cos, sin, self.attention_head_size)?; + let key_layer = apply_rotary(key_layer, cos, sin, self.attention_head_size)?; #[allow(unused_variables)] let context_layer = if let (Device::Cuda(_), Some(cublaslt)) = @@ -416,8 +409,9 @@ impl NomicBertModel { let encoder = NomicBertEncoder::load(vb.pp("encoder"), config)?; let rotary_dim = encoder.layers[0].attention.attention_head_size; - let inv_freqs_tensor = inv_freqs(rotary_dim, config.rotary_emb_base, vb.device())?; - let rotary_cache = cos_sin(config.n_positions, &inv_freqs_tensor, vb.dtype())?; + let inv_freqs_tensor = + get_inv_freqs(rotary_dim, config.rotary_emb_base, vb.device(), None)?; + let rotary_cache = get_cos_sin(config.n_positions, &inv_freqs_tensor, vb.dtype(), true)?; let scaled_rotary_cache = if let Some(scaling_factor) = config.rotary_scaling_factor { let new_base = (config.rotary_emb_base @@ -425,8 +419,13 @@ impl NomicBertModel { / config.max_trained_positions as f32) - (scaling_factor - 1.0))) .powi((rotary_dim as f32 / (rotary_dim as f32 - 2.0)) as i32); - let inv_freqs_tensor = inv_freqs(rotary_dim, new_base, vb.device())?; - Some(cos_sin(config.n_positions, &inv_freqs_tensor, vb.dtype())?) + let inv_freqs_tensor = get_inv_freqs(rotary_dim, new_base, vb.device(), None)?; + Some(get_cos_sin( + config.n_positions, + &inv_freqs_tensor, + vb.dtype(), + true, + )?) 
@@ -678,31 +677,11 @@ impl NomicBertModel {
     }
 }
 
-pub fn inv_freqs(dim: usize, base: f32, device: &Device) -> Result<Tensor> {
-    let inv_freq: Vec<_> = (0..dim)
-        .step_by(2)
-        .map(|i| 1f32 / base.powf(i as f32 / dim as f32))
-        .collect();
-    let inv_freq_len = inv_freq.len();
-    Tensor::from_vec(inv_freq, (1, inv_freq_len), device)
-}
-
-pub fn cos_sin(length: usize, inv_freqs: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> {
-    let t = Tensor::arange(0u32, length as u32, inv_freqs.device())?
-        .to_dtype(DType::F32)?
-        .reshape((length, 1))?;
-    let freqs = t.matmul(inv_freqs)?;
-    let freqs = Tensor::cat(&[&freqs, &freqs], 1)?;
-
-    let cos = freqs.cos()?.to_dtype(dtype)?;
-    let sin = freqs.sin()?.to_dtype(dtype)?;
-    Ok((cos, sin))
-}
-
 impl Model for NomicBertModel {
     fn is_padded(&self) -> bool {
-        false
+        true
     }
+
     fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         self.forward(batch)
     }
diff --git a/backends/candle/tests/snapshots/test_gte__gte_batch.snap b/backends/candle/tests/snapshots/test_gte__gte_batch.snap
new file mode 100644
index 00000000..5cbdcbd1
--- /dev/null
+++ b/backends/candle/tests/snapshots/test_gte__gte_batch.snap
@@ -0,0 +1,2309 @@
+---
+source: backends/candle/tests/test_gte.rs
+assertion_line: 34
+expression: embeddings_batch
+---
+- - 0.44259304
+ - 0.61034733
+ - 0.06906447
+ - -0.5930369
+ - 0.88121295
+ - -0.108871564
+ - 0.089603834
+ - -0.20720902
+ - 1.9351525
+ - 0.6233138
+ - 0.5944155
+ - 0.18624856
+ - 0.029520601
+ - -0.077266246
+ - 0.7963188
+ - -0.23349458
+ - -0.872522
+ - 1.692661
+ - 0.12729777
+ - -1.925918
+ - -0.3450846
+ - -1.9348283
+ - 0.90229857
+ - -0.19892806
+ - 0.024524756
+ - 0.28779918
+ - 0.36588693
+ - 0.17927459
+ - -0.5102686
+ - 0.02657038
+ - -0.89667296
+ - 0.14085183
+ - -0.39250132
+ - -0.8478426
+ - 0.004012823
+ - 0.42249972
+ - 0.23077106
+ - 1.1616367
+ - -1.1474063
+ - 0.3378647
+ - -2.1478326
+ - -1.0142069
+ - -0.28349385
+ - 0.021243036
+ - 0.1642712
+ - 0.39428475
+ - -0.9525956
+ - 0.4098242
+ - -0.15686427
+ - -0.4859926
+ - -0.071725324
+ - 0.072733976
+ - 0.5592042
+ - 0.027051464
+ - -0.028886497
+ - 0.4373799
+ - 1.0145298
+ - 0.7788421
+ - -0.56286293
+ - -1.9803355
+ - -0.010913014
+ - 0.5070983
+ - -0.39306438
+ - -0.33362845
+ - -1.1662877
+ - -0.07186741
+ - 0.008138694
+ - 0.77821803
+ - -0.10029823
+ - -0.30778
+ - 0.3581222
+ - -0.44790483
+ - -0.08573072
+ - -1.0794868
+ - 1.3211484
+ - -1.0267495
+ - 0.34170142
+ - -0.45228508
+ - 0.80183166
+ - 0.50574994
+ - 0.70227575
+ - -0.96536624
+ - -0.042180635
+ - -1.0936062
+ - 0.35129213
+ - -0.5666409
+ - -0.5346388
+ - 0.23489185
+ - -2.0245805
+ - -0.26439658
+ - 0.19770235
+ - 0.7733333
+ - 0.17970699
+ - 1.2246317
+ - 0.9814428
+ - -0.2903474
+ - 1.2074271
+ - 0.501272
+ - 0.24749207
+ - -0.21642533
+ - 0.2906228
+ - -0.49979872
+ - 0.26415095
+ - 0.501433
+ - -0.042981416
+ - -0.014059275
+ - -0.63753927
+ - 0.64429754
+ - -1.7964667
+ - 0.25490916
+ - -0.20513767
+ - -0.12913927
+ - 0.1964828
+ - 0.8092668
+ - -0.9696904
+ - 0.26530862
+ - -0.56467396
+ - -0.5872178
+ - -0.15478744
+ - 0.7404878
+ - 0.5340329
+ - 0.60305643
+ - 1.1866584
+ - -0.3774311
+ - 0.81875145
+ - -0.24098384
+ - 0.4060312
+ - 0.17040129
+ - -0.32154626
+ - -0.7550312
+ - 0.057789713
+ - 0.16519901
+ - 0.68956393
+ - -0.617511
+ - -0.0531684
+ - 0.42995512
+ - -0.607028
+ - 0.040061504
+ - -0.86182916
+ - -0.7180543
+ - 0.7401805
+ - -0.42198357
+ - 0.26968056
+ - -1.5444765
+ - 0.16945669
+ - 1.217624
+ - -0.17086306
+ - -0.1421303
+ - -0.5469396
+ - 
-0.59001184 + - -0.7784361 + - -0.18679349 + - -0.34887093 + - 0.6221198 + - -0.58943474 + - -0.9040499 + - -0.73156124 + - -0.82022953 + - 0.48355055 + - -1.1594696 + - 1.2104084 + - -1.0509471 + - 0.2527444 + - 1.2852336 + - 0.06373514 + - -0.7195401 + - -0.3677724 + - 0.073669694 + - -0.021352403 + - -0.37145522 + - -0.0073406324 + - 0.3583467 + - -0.8348328 + - -0.59604347 + - -0.40265262 + - -0.02693659 + - -0.31148252 + - -1.1090808 + - -0.030301034 + - 0.13976327 + - -1.2136879 + - 0.035320282 + - 0.3535554 + - -1.5688139 + - 0.17595965 + - -0.6659802 + - -1.9802649 + - -0.35432598 + - -0.29395905 + - -1.2038764 + - -0.41608042 + - 0.36699823 + - -0.05091913 + - -0.8751462 + - -0.9087274 + - 0.4801116 + - -1.4586933 + - 0.78772914 + - -0.27553022 + - 0.08574577 + - 0.22304389 + - 0.001901783 + - -1.0914774 + - -0.0072548836 + - -0.62113667 + - -0.68177116 + - 0.45813638 + - -0.6160619 + - -0.7856389 + - -1.3864756 + - 1.5577792 + - 0.26990858 + - 0.84612846 + - -0.25666517 + - -0.7653059 + - 0.64927864 + - -0.05115719 + - -0.13208972 + - -0.11831948 + - -0.6354545 + - -0.1780964 + - 0.9480155 + - -0.4961671 + - 0.017041028 + - -0.23073995 + - -1.6752698 + - -0.8910312 + - 0.8672527 + - 1.0692105 + - 0.34349495 + - 0.7191764 + - -0.54913974 + - -0.4529444 + - -0.30141753 + - 0.45244497 + - -0.2881953 + - -0.76241916 + - 0.01701805 + - 0.6307543 + - 0.47191522 + - -0.07387722 + - 1.7862662 + - -0.41695195 + - -0.89335895 + - 0.07749042 + - -0.29598492 + - -1.0674714 + - 1.3289616 + - 1.417427 + - 0.20779747 + - -0.10695306 + - 0.15057519 + - 0.37317067 + - 0.05173081 + - -0.68831 + - 0.73900473 + - 0.15694031 + - 0.55931574 + - -0.819237 + - -0.17727043 + - -0.24614926 + - -0.06498201 + - 0.4203459 + - 1.0529735 + - -1.6634095 + - 0.902164 + - 0.22976208 + - -0.5899435 + - -0.7980481 + - -0.0048475564 + - -0.47341445 + - 0.19137529 + - -0.38998026 + - 0.6349899 + - -0.28744707 + - -0.87318814 + - 0.07867032 + - -1.2608685 + - 0.2286636 + - 0.5503384 + - -0.47462064 + - -0.20548862 + - 0.75510746 + - -1.0513641 + - 0.99237406 + - 1.2279786 + - 0.4101724 + - -0.36645734 + - 0.5490203 + - -0.3264212 + - 0.21975031 + - 1.1309171 + - -0.01674299 + - -0.6710228 + - 0.7215167 + - 0.20943573 + - 0.555139 + - 0.71108764 + - -0.16624965 + - 0.5822153 + - 1.149294 + - -0.16024849 + - -0.4955896 + - 0.016667455 + - 1.5621985 + - -0.50284207 + - 1.2017282 + - -1.002291 + - -0.35571888 + - -0.31098118 + - -0.12681499 + - 0.821267 + - 0.7572354 + - 1.1620345 + - 1.6657145 + - -0.80179685 + - -0.3032667 + - 0.3809554 + - -0.7134949 + - -0.74646837 + - 1.1901687 + - -0.9586203 + - 0.30220306 + - 0.1696356 + - 0.8158542 + - -0.8794036 + - 1.1385472 + - -0.45451683 + - 0.6137817 + - -0.47705573 + - 1.1183856 + - 0.014434105 + - -0.49392825 + - -0.039757848 + - 0.28681982 + - -0.71450067 + - 0.8194344 + - -0.0384483 + - 0.6028884 + - 0.75745547 + - -0.33422586 + - 0.86565375 + - -0.71517825 + - 0.7594802 + - -1.3496755 + - 0.18592727 + - -0.39923215 + - 0.3111456 + - 0.16714048 + - 1.2608763 + - -0.42922565 + - 0.24328378 + - 0.83976495 + - -1.3976977 + - 0.3450549 + - -0.6845992 + - -0.28764033 + - 0.46697864 + - 0.681257 + - -0.48346725 + - -0.10998453 + - -0.48293424 + - 0.17039993 + - 0.29646662 + - -0.72920346 + - -0.8712541 + - 0.84642005 + - 0.57345843 + - 0.16414955 + - -0.17650697 + - -0.1654857 + - 0.8680194 + - -0.38802695 + - 0.56318635 + - -0.30581528 + - 1.8246955 + - 0.041201115 + - 0.1730811 + - -1.0117795 + - -0.20816453 + - 1.1007582 + - 0.7987681 + - -1.0011376 + - 0.6699697 + - 
-0.96139586 + - 0.4520412 + - -0.3227459 + - -0.9787295 + - 1.7064301 + - -0.045380563 + - 0.44003445 + - 0.2881473 + - 0.32060444 + - -0.53330564 + - -0.7458609 + - -1.2799888 + - -0.20681454 + - -0.5788567 + - 0.87884283 + - -0.21842682 + - -0.5110508 + - 0.20337737 + - -0.9529271 + - -0.8920894 + - 0.62817466 + - -0.089243844 + - 1.2506601 + - -0.30859876 + - 0.52464736 + - -1.577458 + - 0.86506957 + - -0.16180237 + - -0.6451876 + - -0.8997455 + - -0.689568 + - -0.5550809 + - 0.54879844 + - -0.20567037 + - 1.0816867 + - 0.53907245 + - 0.24640642 + - -0.00500758 + - 0.0994609 + - -0.082028165 + - 0.6633665 + - -0.06771461 + - -0.8126046 + - -0.08276771 + - 0.122273535 + - -1.4077682 + - 0.65344125 + - 0.49604583 + - 0.48285925 + - -0.61326075 + - 0.059502434 + - 0.4968239 + - -0.3398038 + - -0.85002804 + - -0.8513342 + - 0.10036299 + - 0.64530313 + - 0.62831557 + - 0.78928226 + - -0.85008234 + - -0.766243 + - 0.39012492 + - 0.5910649 + - -0.07126428 + - -0.13246185 + - -0.89592844 + - -0.7658062 + - 1.312253 + - 0.9795044 + - 0.37078738 + - 0.047516555 + - -1.287435 + - 0.17157432 + - -0.16686267 + - -0.10688098 + - -0.11172397 + - 0.20267075 + - -0.6552357 + - 0.2800237 + - -1.8037435 + - 0.1508436 + - 0.24947366 + - 0.67023605 + - 0.67530817 + - -1.033367 + - -0.41353673 + - -0.59349984 + - 0.028830133 + - -1.3523995 + - 0.12497628 + - 0.87255096 + - 0.87087107 + - -0.18449068 + - 0.79595155 + - -0.70120776 + - -0.28769892 + - 0.3902542 + - 0.43118823 + - -0.051839836 + - -0.6050477 + - 0.6396151 + - 0.20942315 + - 0.09991099 + - -0.7391164 + - 0.24729279 + - -0.41480076 + - -0.49494857 + - -0.78128123 + - -0.35835093 + - -0.53758746 + - -0.13555299 + - 0.03073129 + - 0.47043937 + - 0.16701278 + - 1.0457106 + - 0.702593 + - 0.32327354 + - 0.018711254 + - -0.2596096 + - 1.4011946 + - 0.33755624 + - -0.06543507 + - 0.8141793 + - -0.7645023 + - -0.532987 + - -0.1951265 + - 0.19335829 + - -0.61174417 + - 0.077225186 + - -0.6988332 + - -0.3897977 + - 0.37915218 + - 0.59615886 + - -1.9831455 + - -0.3679344 + - -2.023395 + - 0.9688585 + - 0.46854872 + - -0.49805546 + - -0.208342 + - -0.15163574 + - -0.1377097 + - 1.878253 + - -0.18728009 + - 0.38123053 + - -0.13824378 + - -0.2954961 + - 0.42763495 + - -0.24603358 + - 0.1628322 + - 0.842581 + - -0.34907857 + - 0.5294959 + - -0.14574759 + - 0.77288246 + - -0.11421237 + - -0.23344792 + - 1.2086263 + - -1.5588763 + - 0.09942496 + - -1.3481518 + - 0.61870027 + - -1.1478713 + - -0.46976572 + - -0.10636251 + - -1.3085406 + - 0.43214953 + - 0.05524394 + - 0.8000935 + - 0.43331698 + - -0.22318701 + - 0.5597524 + - -8.14653 + - -0.7232454 + - -0.4829777 + - -0.21355215 + - 0.51938415 + - 0.54633397 + - 0.85073006 + - -0.88464475 + - 0.36808544 + - -0.16894472 + - -0.32347974 + - 0.12091856 + - 0.9254223 + - 0.38929504 + - 0.79355645 + - 0.46859992 + - 0.92062116 + - 0.9964036 + - -0.32657716 + - 1.2908907 + - -0.15256053 + - 0.18657881 + - -0.6105832 + - 0.061345264 + - 0.35378 + - -0.792165 + - -0.39395148 + - 1.2144314 + - -0.804891 + - 0.13545582 + - 0.21137342 + - -0.6654251 + - 1.5949675 + - 0.4424338 + - -0.24143839 + - -0.39792645 + - -1.0507987 + - -0.36413527 + - -0.5976953 + - 0.6054765 + - 1.2283214 + - 0.21500719 + - -0.18326366 + - -0.0065856427 + - -0.08184518 + - 0.8007397 + - -0.8180905 + - -1.6634116 + - -1.2854487 + - -0.85243595 + - 0.15832597 + - 0.621912 + - 0.6053898 + - 0.5224174 + - -0.93616796 + - 0.75235915 + - 0.06615754 + - -0.0023372173 + - -0.9213737 + - -0.2548948 + - -0.5085285 + - -0.7929628 + - -0.4497529 + - 
-0.6456441 + - 0.85625625 + - 0.120143905 + - -0.08442418 + - -0.5129668 + - 0.16538239 + - 1.6077311 + - -0.3928744 + - 0.34487766 + - -0.43184227 + - -0.04122834 + - 0.29670107 + - -0.51762897 + - 0.65680164 + - 0.57277584 + - 0.66948354 + - 0.1398776 + - 0.9790336 + - 0.16417754 + - -0.1196354 + - -0.8737719 + - 0.058854774 + - 0.24216175 + - -0.01265154 + - 0.40640974 + - -0.3477264 + - -0.19956756 + - 0.4118901 + - -0.17779568 + - -0.011484221 + - 0.10376629 + - -0.10691896 + - 0.037767142 + - 1.1812952 + - -0.689515 + - 0.73829395 + - -0.76737964 + - -0.35432225 + - 0.61250883 + - -0.9324362 + - -0.49049675 + - -0.29375598 + - -1.473715 + - -1.0492535 + - -1.2494478 + - -0.5478609 + - -1.4896125 + - 0.4817301 + - -0.79216903 + - -0.24467322 + - -0.55176955 + - -0.083230525 + - -0.47753918 + - 0.18810439 + - 0.47883877 + - -0.2106834 + - -0.83237624 + - -0.10659266 + - 0.043049127 + - -0.098483756 + - 0.5833407 + - -0.8094338 + - -1.1513048 + - 0.19241913 + - 0.50195277 + - -0.8929118 + - -0.03716117 + - -0.44107184 + - 0.5030349 + - 1.4838967 + - -0.876173 + - -0.3227269 + - -0.98616767 + - -0.09206166 + - 0.48493135 + - -0.27781898 + - 0.80406904 + - -0.56890416 + - -0.25707406 + - 0.750425 + - -0.37303916 + - 0.21937534 + - 0.37374908 + - -0.032264948 + - -0.40764412 + - -0.18038614 + - -0.48258096 + - -1.5081406 + - -0.673682 + - 0.8215298 + - 0.25156042 + - -0.82655954 + - 0.3397631 + - -1.2644277 + - 0.8678608 + - -0.27307308 + - -0.34458882 + - -0.4250061 + - -0.4120714 + - 0.8038123 + - 1.1515579 + - -1.2057272 + - 0.3560321 + - 0.0020849928 + - -1.9514849 + - -0.60190904 + - 0.5822501 + - 0.085682176 + - 0.48351315 + - -0.3535313 + - -0.11059061 + - 0.65916437 + - -1.491356 + - -0.14152524 + - -0.21427855 + - 1.0252383 + - -1.4469419 + - 0.56110775 + - 1.7371856 + - 0.095697284 + - -0.5412806 + - -0.13305002 + - 0.7577597 + - -0.27158695 + - 1.0489044 + - 1.0496688 + - 0.23123285 + - -0.15137509 + - 0.17805772 + - 0.09740275 + - -1.6383355 + - -0.60187536 + - -0.9171309 + - -0.93706274 + - 1.0573325 + - 0.702143 + - 0.7050025 + - 0.8193582 + - -2.2506168 + - -0.6361895 + - 0.35078734 + - 0.750742 + - -0.12234919 + - -1.0858344 + - 0.5054128 + - -0.68131065 + - 0.7491392 + - -0.27770603 + - -0.13498452 +- - 0.040542047 + - 0.42446017 + - 0.19612727 + - -0.49574697 + - 0.59192526 + - -0.17090613 + - 1.3935621 + - 0.14561012 + - 1.8730415 + - 0.5819394 + - 0.75708574 + - 0.2834056 + - 0.026276648 + - -0.612317 + - 0.2435061 + - -0.13633023 + - -0.44703823 + - 1.1765193 + - 0.5575247 + - -1.8688344 + - -0.08647424 + - -1.2636399 + - 0.7175725 + - -0.32006907 + - 0.11502649 + - 0.9962714 + - 0.095648795 + - 0.47192067 + - 0.20051816 + - 0.5051677 + - -0.07757187 + - -0.027173162 + - -0.56047416 + - -0.12764451 + - -0.79288316 + - 0.19989747 + - 0.9721345 + - 0.75865805 + - -1.6696091 + - -0.16326287 + - -1.5734959 + - -0.82405263 + - 0.3784384 + - -0.10479588 + - -0.024306625 + - 0.5201199 + - -1.2190377 + - -0.67159575 + - -0.28263557 + - -0.20817214 + - 0.19708586 + - 0.6477342 + - 1.0561608 + - 0.2456415 + - 0.33866727 + - 0.76540446 + - 0.8597696 + - 0.20934343 + - -1.0002661 + - -1.0500708 + - -0.0119538605 + - 0.55093664 + - -0.47561547 + - -0.85928965 + - -0.9383173 + - -0.46921828 + - 0.21268028 + - 1.3986766 + - 0.66930616 + - -0.48944074 + - -0.20923229 + - -0.5341331 + - 0.22264881 + - -1.1056857 + - 0.73507303 + - -1.2791632 + - 0.24249494 + - -0.070151746 + - 1.0916647 + - 0.6843605 + - 0.946941 + - -0.51761365 + - -0.14272289 + - -0.64535344 + - 0.32108337 + - 
+  - 0.2883522
+  - -0.08589348
+  - -0.12803522
[… remainder of test_gte__gte_batch.snap elided; per the hunk header the file holds three 768-dimension embeddings (2,304 value rows), and its third embedding matches the single-sentence snapshot below …]
diff --git a/backends/candle/tests/snapshots/test_gte__gte_single.snap b/backends/candle/tests/snapshots/test_gte__gte_single.snap
new file mode 100644
index 00000000..6b03d6be
--- /dev/null
+++ b/backends/candle/tests/snapshots/test_gte__gte_single.snap
@@ -0,0 +1,773 @@
+---
+source: backends/candle/tests/test_gte.rs
+assertion_line: 45
+expression: embeddings_single
+---
+- - 0.44259304
+  - 0.61034733
+  - 0.06906447
[… intervening 763 value rows of the 768-dimension embedding elided …]
+  - -0.27770603
+  - -0.13498452
diff --git a/backends/candle/tests/test_gte.rs b/backends/candle/tests/test_gte.rs
new file mode 100644
index 00000000..164c5c27
--- /dev/null
+++ b/backends/candle/tests/test_gte.rs
@@ -0,0 +1,50 @@
+mod common;
+
+use crate::common::{sort_embeddings, SnapshotEmbeddings};
+use anyhow::Result;
+use common::{batch, cosine_matcher, download_artifacts, load_tokenizer};
+use text_embeddings_backend_candle::CandleBackend;
+use text_embeddings_backend_core::{Backend, ModelType, Pool};
+
+#[test]
+fn test_gte() -> Result<()> {
+    let model_root = download_artifacts("Alibaba-NLP/gte-base-en-v1.5", None)?;
+    let tokenizer = load_tokenizer(&model_root)?;
+
+    let backend = CandleBackend::new(
+        &model_root,
+        "float32".to_string(),
+        ModelType::Embedding(Pool::Cls),
+    )?;
+
+    let input_batch = batch(
+        vec![
+            tokenizer.encode("What is Deep Learning?", true).unwrap(),
+            tokenizer.encode("Deep Learning is...", true).unwrap(),
+            tokenizer.encode("What is Deep Learning?", true).unwrap(),
+        ],
+        [0, 1, 2].to_vec(),
+        vec![],
+    );
+
+    let matcher = cosine_matcher();
+
+    let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?);
+    let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings);
+    insta::assert_yaml_snapshot!("gte_batch", embeddings_batch, &matcher);
+
+    let input_single = batch(
+        vec![tokenizer.encode("What is Deep Learning?", true).unwrap()],
+        [0].to_vec(),
+        vec![],
+    );
+
+    let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?);
+    let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings);
+
+    insta::assert_yaml_snapshot!("gte_single", embeddings_single, &matcher);
+    assert_eq!(embeddings_batch[0], embeddings_single[0]);
+    assert_eq!(embeddings_batch[2], embeddings_single[0]);
+
+    Ok(())
+}
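
A note for reviewers reproducing the snapshots above: they are insta YAML snapshots, so the new test can be run and any snapshot drift reviewed from the repository root. A minimal sketch, assuming the candle backend keeps the `text-embeddings-backend-candle` package name implied by its import path, and that the optional `cargo-insta` helper is installed:

    # run only the new GTE test in the candle backend crate
    cargo test -p text-embeddings-backend-candle --test test_gte
    # interactively accept or reject changed snapshots
    cargo insta review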