Skip to content

Commit

Permalink
Merge branch 'main' into user_guide_update
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuajerin authored Nov 11, 2024
2 parents a891bb3 + aac3999 commit a79a8a9
Show file tree
Hide file tree
Showing 36 changed files with 711 additions and 428 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/extension_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ jobs:
CO_API_KEY: ${{ secrets.CO_API_KEY }}
PORTKEY_API_KEY: ${{ secrets.PORTKEY_API_KEY }}
PORTKEY_VIRTUAL_KEY_OPENAI: ${{ secrets.PORTKEY_VIRTUAL_KEY_OPENAI }}
VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
run: |
cd ../core && cargo test
- name: Restore cached binaries
Expand All @@ -132,6 +133,7 @@ jobs:
CO_API_KEY: ${{ secrets.CO_API_KEY }}
PORTKEY_API_KEY: ${{ secrets.PORTKEY_API_KEY }}
PORTKEY_VIRTUAL_KEY_OPENAI: ${{ secrets.PORTKEY_VIRTUAL_KEY_OPENAI }}
VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
run: |
echo "\q" | make run
make test-integration
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/extension_upgrade.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ on:
jobs:
test:
name: Upgrade Test
runs-on: ubuntu-latest
runs-on: ubuntu-24.04
services:
vector-serve:
image: quay.io/tembo/vector-serve:latest
Expand Down
6 changes: 6 additions & 0 deletions .github/workflows/pg-image-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ jobs:
shell: bash
run: |
set -xe
sudo apt-get update
sudo apt-get install -y wget
wget https://github.com/freshautomations/stoml/releases/download/v0.7.1/stoml_linux_amd64 &> /dev/null
mv stoml_linux_amd64 stoml
chmod +x stoml
Expand Down Expand Up @@ -94,6 +96,8 @@ jobs:
shell: bash
run: |
set -xe
sudo apt-get update
sudo apt-get install -y wget
wget https://github.com/freshautomations/stoml/releases/download/v0.7.1/stoml_linux_armv7 &> /dev/null
mv stoml_linux_armv7 stoml
chmod +x stoml
Expand Down Expand Up @@ -164,6 +168,8 @@ jobs:
shell: bash
run: |
set -xe
sudo apt-get update
sudo apt-get install -y wget
wget https://github.com/freshautomations/stoml/releases/download/v0.7.1/stoml_linux_amd64 &> /dev/null
mv stoml_linux_amd64 stoml
chmod +x stoml
Expand Down
8 changes: 7 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ export PG_VERSION=17

#### 3.3. Install dependencies

From within the pg_vectorize directory, run the following, which will install `pg_cron`, `pgmq`, and `pgvector`:
From within the pg_vectorize/extension directory, run the following, which will install `pg_cron`, `pgmq`, and `pgvector`:

```bash
make setup
Expand All @@ -72,6 +72,12 @@ make run

Once the above command is run, you will be brought into Postgres via `psql`.

Run the following command inside the `psql` console to enable the extensions:

```sql
create extension vectorize cascade;
```

To list out the enabled extensions, run:

```sql
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ Alternatively, `schedule => 'realtime` creates triggers on the source table and
Statements below will result in new embeddings being generated either immediately (`schedule => 'realtime'`) or within the cron schedule set in the `schedule` parameter.

```sql
INSERT INTO products (product_id, product_name, description)
VALUES (12345, 'pizza', 'dish of Italian origin consisting of a flattened disk of bread');
INSERT INTO products (product_id, product_name, description, product_category, price)
VALUES (12345, 'pizza', 'dish of Italian origin consisting of a flattened disk of bread', 'food', 5.99);

UPDATE products
SET description = 'sling made of fabric, rope, or netting, suspended between two or more points, used for swinging, sleeping, or resting'
Expand Down
2 changes: 1 addition & 1 deletion core/src/transformers/providers/cohere.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ mod integration_tests {
);
assert!(
embeddings.embeddings[0].len() == 384,
"Embeddings should have length 384"
"Embeddings should have dimension 384"
);
}
}
4 changes: 4 additions & 0 deletions core/src/transformers/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod ollama;
pub mod openai;
pub mod portkey;
pub mod vector_serve;
pub mod voyage;

use anyhow::Result;
use async_trait::async_trait;
Expand Down Expand Up @@ -66,6 +67,9 @@ pub fn get_provider(
api_key,
virtual_key,
))),
ModelSource::Voyage => Ok(Box::new(providers::voyage::VoyageProvider::new(
url, api_key,
))),
ModelSource::SentenceTransformers => Ok(Box::new(
providers::vector_serve::VectorServeProvider::new(url, api_key),
)),
Expand Down
2 changes: 1 addition & 1 deletion core/src/transformers/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ mod integration_tests {
);
assert!(
embeddings.embeddings[0].len() == 1536,
"Embeddings should have length 1536"
"Embeddings should have dimension 1536"
);
}
}
Expand Down
2 changes: 1 addition & 1 deletion core/src/transformers/providers/portkey.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ mod portkey_integration_tests {
);
assert!(
embeddings.embeddings[0].len() == 1536,
"Embeddings should have length 1536"
"Embeddings should have dimension 1536"
);

let dim = provider.model_dim("text-embedding-ada-002").await.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion core/src/transformers/providers/vector_serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ mod integration_tests {
);
assert!(
embeddings.embeddings[0].len() == 384,
"Embeddings should have length 1536"
"Embeddings should have dimension 1536"
);

let model_dim = provider
Expand Down
144 changes: 144 additions & 0 deletions core/src/transformers/providers/voyage.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
use reqwest::Client;
use serde::{Deserialize, Serialize};

use super::{EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse};
use crate::errors::VectorizeError;
use crate::transformers::http_handler::handle_response;
use async_trait::async_trait;
use std::env;

/// Default base URL for the Voyage AI API; `VoyageProvider::new` falls back to it when no URL is supplied.
pub const VOYAGE_BASE_URL: &str = "https://api.voyageai.com/v1";

/// Embedding provider backed by a Voyage AI style HTTP API.
pub struct VoyageProvider {
/// Base URL of the API; `/embeddings` is appended to it when a request is sent.
pub url: String,
/// API key sent as a `Bearer` token in the `Authorization` header.
pub api_key: String,
}

/// JSON body posted to the `/embeddings` endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VoyageEmbeddingBody {
/// Texts to generate embeddings for.
pub input: Vec<String>,
/// Name of the embedding model, copied from the generic request.
pub model: String,
/// Input type sent to the API; the conversion from `GenericEmbeddingRequest` always sets it to `"document"`.
// NOTE(review): queries are presumably also sent as "document" through this path; confirm that is intended.
pub input_type: String,
}

impl From<GenericEmbeddingRequest> for VoyageEmbeddingBody {
fn from(request: GenericEmbeddingRequest) -> Self {
VoyageEmbeddingBody {
input: request.input,
model: request.model,
input_type: "document".to_string(),
}
}
}

/// Response payload deserialized from the `/embeddings` endpoint.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VoyageEmbeddingResponse {
/// Returned embedding objects.
// NOTE(review): presumably one entry per input text, in input order; confirm against the API reference.
pub data: Vec<EmbeddingObject>,
}

/// A single element of the `data` array in a `VoyageEmbeddingResponse`.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct EmbeddingObject {
/// The embedding vector; its length is what `model_dim` reports as the model dimension.
pub embedding: Vec<f64>,
}

impl From<VoyageEmbeddingResponse> for GenericEmbeddingResponse {
    /// Converts a Voyage response into the provider-agnostic response type.
    ///
    /// The order of `response.data` is preserved in the resulting
    /// `embeddings` list.
    fn from(response: VoyageEmbeddingResponse) -> Self {
        GenericEmbeddingResponse {
            // `response` is owned here, so each vector is moved out instead of
            // being deep-cloned.
            embeddings: response
                .data
                .into_iter()
                .map(|object| object.embedding)
                .collect(),
        }
    }
}

impl VoyageProvider {
    /// Creates a provider, filling in defaults for anything not supplied.
    ///
    /// * `url` - base URL of the API; defaults to [`VOYAGE_BASE_URL`].
    /// * `api_key` - API key; defaults to the `VOYAGE_API_KEY` environment
    ///   variable.
    ///
    /// # Panics
    ///
    /// Panics with `"VOYAGE_API_KEY not set"` when `api_key` is `None` and the
    /// `VOYAGE_API_KEY` environment variable cannot be read.
    pub fn new(url: Option<String>, api_key: Option<String>) -> Self {
        let url = url.unwrap_or_else(|| VOYAGE_BASE_URL.to_string());
        // The environment is only consulted when no key was passed in.
        let api_key = api_key
            .unwrap_or_else(|| env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set"));
        Self { url, api_key }
    }
}

#[async_trait]
impl EmbeddingProvider for VoyageProvider {
async fn generate_embedding<'a>(
&self,
request: &'a GenericEmbeddingRequest,
) -> Result<GenericEmbeddingResponse, VectorizeError> {
let client = Client::new();

let req_body = VoyageEmbeddingBody::from(request.clone());
let embedding_url = format!("{}/embeddings", self.url);

let response = client
.post(&embedding_url)
.timeout(std::time::Duration::from_secs(120_u64))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&req_body)
.send()
.await?;

let embeddings = handle_response::<VoyageEmbeddingResponse>(response, "embeddings").await?;
Ok(GenericEmbeddingResponse {
embeddings: embeddings
.data
.iter()
.map(|x| x.embedding.clone())
.collect(),
})
}

async fn model_dim(&self, model_name: &str) -> Result<u32, VectorizeError> {
// determine embedding dim by generating an embedding and getting length of array
let req = GenericEmbeddingRequest {
input: vec!["hello world".to_string()],
model: model_name.to_string(),
};
let embedding = self.generate_embedding(&req).await?;
let dim = embedding.embeddings[0].len();
Ok(dim as u32)
}
}

#[cfg(test)]
mod integration_tests {
    use super::*;
    use std::env;

    /// Live round-trip against the Voyage API; requires `VOYAGE_API_KEY` to be
    /// set in the environment.
    #[tokio::test]
    async fn test_voyage_ai_embedding() {
        let key = env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY must be set");
        let provider = VoyageProvider::new(Some(VOYAGE_BASE_URL.to_string()), Some(key));

        let model_name = "voyage-3-lite";
        let request = GenericEmbeddingRequest {
            input: vec![String::from("hello world")],
            model: String::from(model_name),
        };

        let response = provider.generate_embedding(&request).await.unwrap();
        println!("{:?}", response);

        // One input text should yield exactly one 512-dimensional vector.
        let vectors = &response.embeddings;
        assert!(!vectors.is_empty(), "Embeddings should not be empty");
        assert!(vectors.len() == 1, "Embeddings should have length 1");
        assert!(
            vectors[0].len() == 512,
            "Embeddings should have dimension 512"
        );

        // `model_dim` must agree with the length observed above.
        let dim = provider.model_dim(model_name).await.unwrap();
        assert_eq!(dim, 512);
    }
}
16 changes: 14 additions & 2 deletions core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,6 @@ pub struct VectorizeMeta {
pub name: String,
pub index_dist_type: IndexDist,
pub transformer: Model,
// search_alg and SimilarityAlg are now deprecated
pub search_alg: SimilarityAlg,
pub params: serde_json::Value,
#[serde(deserialize_with = "from_tsopt")]
pub last_completion: Option<chrono::DateTime<Utc>>,
Expand All @@ -162,6 +160,7 @@ impl Model {
ModelSource::Tembo => self.name.clone(),
ModelSource::Cohere => self.name.clone(),
ModelSource::Portkey => self.name.clone(),
ModelSource::Voyage => self.name.clone(),
}
}
}
Expand Down Expand Up @@ -239,6 +238,7 @@ pub enum ModelSource {
Tembo,
Cohere,
Portkey,
Voyage,
}

impl FromStr for ModelSource {
Expand All @@ -252,6 +252,7 @@ impl FromStr for ModelSource {
"tembo" => Ok(ModelSource::Tembo),
"cohere" => Ok(ModelSource::Cohere),
"portkey" => Ok(ModelSource::Portkey),
"voyage" => Ok(ModelSource::Voyage),
_ => Ok(ModelSource::SentenceTransformers),
}
}
Expand All @@ -266,6 +267,7 @@ impl Display for ModelSource {
ModelSource::Tembo => write!(f, "tembo"),
ModelSource::Cohere => write!(f, "cohere"),
ModelSource::Portkey => write!(f, "portkey"),
ModelSource::Voyage => write!(f, "voyage"),
}
}
}
Expand All @@ -279,6 +281,7 @@ impl From<String> for ModelSource {
"tembo" => ModelSource::Tembo,
"cohere" => ModelSource::Cohere,
"portkey" => ModelSource::Portkey,
"voyage" => ModelSource::Voyage,
// other cases are assumed to be private sentence-transformer compatible model
// and can be hot-loaded
_ => ModelSource::SentenceTransformers,
Expand All @@ -300,6 +303,15 @@ mod model_tests {
assert_eq!(model.api_name(), "text-embedding-ada-002");
}

/// A `voyage/`-prefixed model string parses to the Voyage source, with the
/// prefix stripped from both the model name and the API name.
#[test]
fn test_voyage_parsing() {
    let fullname = "voyage/voyage-3-lite";
    let parsed = Model::new(fullname).unwrap();
    assert_eq!(parsed.source, ModelSource::Voyage);
    assert_eq!(parsed.fullname, fullname);
    assert_eq!(parsed.name, "voyage-3-lite");
    assert_eq!(parsed.api_name(), "voyage-3-lite");
}

#[test]
fn test_tembo_parsing() {
let model = Model::new("tembo/meta-llama/Meta-Llama-3-8B-Instruct").unwrap();
Expand Down
Loading

0 comments on commit a79a8a9

Please sign in to comment.