-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from StarlightSearch/dev
Add Clip and Local embedding Model
- Loading branch information
Showing 23 changed files with 1,060 additions and 701 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,47 @@ | ||
from .embed_anything import * | ||
from embed_anything import EmbedData | ||
|
||
def embed_file(file_path: str) -> list[EmbedData]: | ||
def embed_query(query: list[str], embeder: str) -> list[EmbedData]:
    """
    Embeds each of the given query strings and returns the results.

    ### Arguments:
    - `query`: The list of query strings to embed.
    - `embeder`: The name of the embedding model to use. Choose between "OpenAI" and "AllMiniLmL12V2"

    ### Returns:
    - A list of EmbedData objects, one per query string.
    """
def embed_file(file_path: str, embeder: str) -> list[EmbedData]:
    """
    Embeds the contents of a single file and returns the resulting embeddings.

    ### Arguments:
    - `file_path`: The path to the file to embed.
    - `embeder`: The name of the embedding model to use. Choose between "OpenAI" and "AllMiniLmL12V2"

    ### Returns:
    - A list of EmbedData objects.
    """
|
||
def embed_directory(file_path: str, embeder: str) -> list[EmbedData]:
    """
    Walks the given directory and embeds every file inside it.

    ### Arguments:
    - `file_path`: The path to the directory containing the files to embed.
    - `embeder`: The name of the embedding model to use. Choose between "OpenAI" and "AllMiniLmL12V2"

    ### Returns:
    - A list of EmbedData objects.
    """
|
||
class EmbedData:
    """
    The embedding computed for one piece of text.

    ### Attributes:
    - `embedding`: The embedding of the file.
    - `text`: The text for which the embedding is generated for.
    """
    # Dense embedding vector.
    embedding: list[float]
    # Source text the embedding was generated from.
    text: str
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
use candle_core::{Device, Tensor}; | ||
use embed_anything::{embed_directory, embed_query}; | ||
use std::path::PathBuf; | ||
|
||
fn main() { | ||
// let out = embed_file("test_files/TUe_SOP_AI_2.pdf", "Bert").unwrap(); | ||
let out = embed_directory(PathBuf::from("test_files"), "Clip").unwrap(); | ||
let query_emb_data = embed_query(vec!["Photo of a monkey".to_string()], "Clip").unwrap(); | ||
let n_vectors = out.len(); | ||
let out_embeddings = Tensor::from_vec( | ||
out.iter() | ||
.map(|embed| embed.embedding.clone()) | ||
.collect::<Vec<_>>() | ||
.iter() | ||
.flatten() | ||
.cloned() | ||
.collect::<Vec<f32>>(), | ||
(n_vectors, out[0].embedding.len()), | ||
&Device::Cpu, | ||
) | ||
.unwrap(); | ||
|
||
let image_paths = out | ||
.iter() | ||
.map(|embed| embed.text.clone().unwrap()) | ||
.collect::<Vec<_>>(); | ||
|
||
let query_embeddings = Tensor::from_vec( | ||
query_emb_data | ||
.iter() | ||
.map(|embed| embed.embedding.clone()) | ||
.collect::<Vec<_>>() | ||
.iter() | ||
.flatten() | ||
.cloned() | ||
.collect::<Vec<f32>>(), | ||
(1, query_emb_data[0].embedding.len()), | ||
&Device::Cpu, | ||
) | ||
.unwrap(); | ||
|
||
let similarities = out_embeddings | ||
.matmul(&query_embeddings.transpose(0, 1).unwrap()) | ||
.unwrap() | ||
.detach() | ||
.squeeze(1) | ||
.unwrap() | ||
.to_vec1::<f32>() | ||
.unwrap(); | ||
let mut indices: Vec<usize> = (0..similarities.len()).collect(); | ||
indices.sort_by(|a, b| similarities[*b].partial_cmp(&similarities[*a]).unwrap()); | ||
|
||
let top_3_indices = indices[0..3].to_vec(); | ||
let top_3_image_paths = top_3_indices | ||
.iter() | ||
.map(|i| image_paths[*i].clone()) | ||
.collect::<Vec<String>>(); | ||
|
||
let similar_image =top_3_image_paths[0].clone(); | ||
|
||
println!("{:?}", similar_image) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
extern crate intel_mkl_src;

use super::embed::{Embed, EmbedData};
use anyhow::Error as E;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
|
||
pub struct BertEmbeder { | ||
pub model: BertModel, | ||
pub tokenizer: Tokenizer, | ||
} | ||
impl BertEmbeder { | ||
pub fn default() -> anyhow::Result<Self> { | ||
let device = Device::Cpu; | ||
let default_model = "sentence-transformers/all-MiniLM-L12-v2".to_string(); | ||
let default_revision = "refs/pr/21".to_string(); | ||
let (model_id, _revision) = (default_model, default_revision); | ||
let repo = Repo::model(model_id); | ||
let (config_filename, tokenizer_filename, weights_filename) = { | ||
let api = Api::new()?; | ||
let api = api.repo(repo); | ||
let config = api.get("config.json")?; | ||
let tokenizer = api.get("tokenizer.json")?; | ||
let weights = api.get("model.safetensors")?; | ||
|
||
(config, tokenizer, weights) | ||
}; | ||
let config = std::fs::read_to_string(config_filename)?; | ||
let mut config: Config = serde_json::from_str(&config)?; | ||
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; | ||
|
||
let pp = PaddingParams { | ||
strategy: tokenizers::PaddingStrategy::BatchLongest, | ||
..Default::default() | ||
}; | ||
tokenizer.with_padding(Some(pp)); | ||
|
||
let vb = | ||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }; | ||
|
||
config.hidden_act = HiddenAct::GeluApproximate; | ||
|
||
let model = BertModel::load(vb, &config)?; | ||
Ok(BertEmbeder { model, tokenizer }) | ||
} | ||
|
||
pub fn tokenize_batch(&self, text_batch: &[String], device: &Device) -> anyhow::Result<Tensor> { | ||
let tokens = self | ||
.tokenizer | ||
.encode_batch(text_batch.to_vec(), true) | ||
.map_err(E::msg)?; | ||
let token_ids = tokens | ||
.iter() | ||
.map(|tokens| { | ||
let tokens = tokens.get_ids().to_vec(); | ||
Tensor::new(tokens.as_slice(), device) | ||
}) | ||
.collect::<candle_core::Result<Vec<_>>>()?; | ||
|
||
Ok(Tensor::stack(&token_ids, 0)?) | ||
} | ||
} | ||
|
||
impl Embed for BertEmbeder { | ||
async fn embed(&self, text_batch: &[String]) -> Result<Vec<EmbedData>, reqwest::Error> { | ||
let token_ids = self.tokenize_batch(text_batch, &self.model.device).unwrap(); | ||
let token_type_ids = token_ids.zeros_like().unwrap(); | ||
let embeddings = self.model.forward(&token_ids, &token_type_ids).unwrap(); | ||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3().unwrap(); | ||
let embeddings = (embeddings.sum(1).unwrap() / (n_tokens as f64)).unwrap(); | ||
let embeddings = normalize_l2(&embeddings).unwrap(); | ||
let encodings = embeddings.to_vec2::<f32>().unwrap(); | ||
let final_embeddings = encodings | ||
.iter() | ||
.zip(text_batch) | ||
.map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone()))) | ||
.collect::<Vec<_>>(); | ||
Ok(final_embeddings) | ||
} | ||
} | ||
|
||
pub fn normalize_l2(v: &Tensor) -> candle_core::Result<Tensor> { | ||
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?) | ||
} |
Oops, something went wrong.