Merge pull request #6 from StarlightSearch/dev
Add Clip and Local embedding Model
akshayballal95 authored Apr 12, 2024
2 parents b72dc13 + 00d5dcc commit 1e6e5c3
Showing 23 changed files with 1,060 additions and 701 deletions.
830 changes: 311 additions & 519 deletions Cargo.lock

Large diffs are not rendered by default.

13 changes: 6 additions & 7 deletions Cargo.toml
@@ -3,10 +3,7 @@ name = "embed_anything"
version = "0.1.5"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "embed_anything"
crate-type = ["cdylib"]


[dependencies]
pyo3 = { version = "0.20", features = ["extension-module"] }
@@ -17,13 +14,15 @@ reqwest = { version = "0.12.2", features = ["json"] }
futures = "0.3.30"
serde = {version = "1.0.196", features = ["rc", "derive"]}
pdf-extract = "0.7.4"
rust-bert = { git = "https://github.com/guillaume-be/rust-bert.git" }
walkdir = "2.4.0"
regex = "1.10.3"
rayon = "1.8.1"
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
image = "0.25.1"
hf-hub = "0.3.2"
tokenizers = "0.15.2"
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.5.0", features = ["mkl"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.5.0", features = ["mkl"] }
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.5.0", features = ["mkl"] }
anyhow = "1.0.81"
intel-mkl-src = "0.8.1"

37 changes: 35 additions & 2 deletions embed_anything.pyi
@@ -1,14 +1,47 @@
from .embed_anything import *
from embed_anything import EmbedData

def embed_file(file_path: str) -> list[EmbedData]:
def embed_query(query: list[str], embeder: str) -> list[EmbedData]:
    """
    Embeds the given query and returns a list of EmbedData objects.
    ### Arguments:
    - `query`: The list of queries to embed.
    - `embeder`: The name of the embedding model to use. Choose between "OpenAI" and "AllMiniLmL12V2".
    ### Returns:
    - A list of EmbedData objects.
    """

def embed_file(file_path: str, embeder: str) -> list[EmbedData]:
    """
    Embeds the file at the given path and returns a list of EmbedData objects.
    ### Arguments:
    - `file_path`: The path to the file to embed.
    - `embeder`: The name of the embedding model to use. Choose between "OpenAI" and "AllMiniLmL12V2".
    ### Returns:
    - A list of EmbedData objects.
    """

def embed_directory(file_path: str, embeder: str) -> list[EmbedData]:
    """
    Embeds all the files in the given directory and returns a list of EmbedData objects.
    ### Arguments:
    - `file_path`: The path to the directory containing the files to embed.
    - `embeder`: The name of the embedding model to use. Choose between "OpenAI" and "AllMiniLmL12V2".
    ### Returns:
    - A list of EmbedData objects.
    """

class EmbedData:
    """
    Represents the data of an embedded file.
    ### Attributes:
    - `embedding`: The embedding of the file.
    - `text`: The text for which the embedding was generated.
    """
    embedding: list[float]
    text: str
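
A minimal usage sketch of these stubs from Python (file paths here are hypothetical; assumes the extension module is installed as embed_anything and that the "AllMiniLmL12V2" local model needs no API key):

import embed_anything

# Embed a file, a directory, and a free-text query with the same local model
# so the resulting vectors live in a comparable space.
file_data = embed_anything.embed_file("test_files/attention.pdf", "AllMiniLmL12V2")
dir_data = embed_anything.embed_directory("test_files", "AllMiniLmL12V2")
query_data = embed_anything.embed_query(["What is attention?"], "AllMiniLmL12V2")

print(len(file_data), len(dir_data), len(query_data[0].embedding))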
62 changes: 62 additions & 0 deletions examples/clip.rs
@@ -0,0 +1,62 @@
use candle_core::{Device, Tensor};
use embed_anything::{embed_directory, embed_query};
use std::path::PathBuf;

fn main() {
    // let out = embed_file("test_files/TUe_SOP_AI_2.pdf", "Bert").unwrap();

    // Embed every image under test_files and the text query with the same
    // CLIP model, so both land in a shared embedding space.
    let out = embed_directory(PathBuf::from("test_files"), "Clip").unwrap();
    let query_emb_data = embed_query(vec!["Photo of a monkey".to_string()], "Clip").unwrap();
    let n_vectors = out.len();

    // Flatten the per-image embeddings into one (n_vectors, dim) tensor.
    let out_embeddings = Tensor::from_vec(
        out.iter()
            .map(|embed| embed.embedding.clone())
            .collect::<Vec<_>>()
            .iter()
            .flatten()
            .cloned()
            .collect::<Vec<f32>>(),
        (n_vectors, out[0].embedding.len()),
        &Device::Cpu,
    )
    .unwrap();

    let image_paths = out
        .iter()
        .map(|embed| embed.text.clone().unwrap())
        .collect::<Vec<_>>();

    // The query becomes a (1, dim) tensor.
    let query_embeddings = Tensor::from_vec(
        query_emb_data
            .iter()
            .map(|embed| embed.embedding.clone())
            .collect::<Vec<_>>()
            .iter()
            .flatten()
            .cloned()
            .collect::<Vec<f32>>(),
        (1, query_emb_data[0].embedding.len()),
        &Device::Cpu,
    )
    .unwrap();

    // Rank images by the dot product between image and query embeddings.
    let similarities = out_embeddings
        .matmul(&query_embeddings.transpose(0, 1).unwrap())
        .unwrap()
        .detach()
        .squeeze(1)
        .unwrap()
        .to_vec1::<f32>()
        .unwrap();
    let mut indices: Vec<usize> = (0..similarities.len()).collect();
    indices.sort_by(|a, b| similarities[*b].partial_cmp(&similarities[*a]).unwrap());

    let top_3_indices = indices[0..3].to_vec();
    let top_3_image_paths = top_3_indices
        .iter()
        .map(|i| image_paths[*i].clone())
        .collect::<Vec<String>>();

    let similar_image = top_3_image_paths[0].clone();

    println!("{:?}", similar_image)
}
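
A note on the ranking above: a raw matmul yields dot products, which equal cosine similarities only when the vectors are unit-length. If the Clip embedder does not already L2-normalize its output (not verified here), a minimal sketch of normalizing both tensors first, using the same formula as normalize_l2 in src/embedding_model/bert.rs below:

use candle_core::Tensor;

// L2-normalize each row of a 2-D tensor so row-by-row dot products
// become cosine similarities.
fn normalize_rows(v: &Tensor) -> candle_core::Result<Tensor> {
    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}

// Then, inside main:
// let similarities = normalize_rows(&out_embeddings)?
//     .matmul(&normalize_rows(&query_embeddings)?.transpose(0, 1)?)?;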
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -11,8 +11,12 @@ classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)"

]
dynamic = ["version"]
license = {file = "LICENSE"}

[tool.maturin]
features = ["pyo3/extension-module"]

144 changes: 0 additions & 144 deletions src/embed.rs

This file was deleted.

87 changes: 87 additions & 0 deletions src/embedding_model/bert.rs
@@ -0,0 +1,87 @@
extern crate intel_mkl_src;

use anyhow::Error as E;
use candle_core::{Device, Tensor};
use tokenizers::{PaddingParams, Tokenizer};
use super::embed::{Embed, EmbedData};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
use hf_hub::{api::sync::Api, Repo};

/// Local Bert embedder built on candle; loads
/// sentence-transformers/all-MiniLM-L12-v2 from the Hugging Face Hub.
pub struct BertEmbeder {
    pub model: BertModel,
    pub tokenizer: Tokenizer,
}
impl BertEmbeder {
    pub fn default() -> anyhow::Result<Self> {
        let device = Device::Cpu;
        let default_model = "sentence-transformers/all-MiniLM-L12-v2".to_string();
        let default_revision = "refs/pr/21".to_string();
        let (model_id, _revision) = (default_model, default_revision);
        let repo = Repo::model(model_id);

        // Fetch config, tokenizer, and weights from the Hub.
        let (config_filename, tokenizer_filename, weights_filename) = {
            let api = Api::new()?;
            let api = api.repo(repo);
            let config = api.get("config.json")?;
            let tokenizer = api.get("tokenizer.json")?;
            let weights = api.get("model.safetensors")?;

            (config, tokenizer, weights)
        };
        let config = std::fs::read_to_string(config_filename)?;
        let mut config: Config = serde_json::from_str(&config)?;
        let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

        // Pad each batch to its longest sequence.
        let pp = PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        };
        tokenizer.with_padding(Some(pp));

        let vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };

        config.hidden_act = HiddenAct::GeluApproximate;

        let model = BertModel::load(vb, &config)?;
        Ok(BertEmbeder { model, tokenizer })
    }

    pub fn tokenize_batch(&self, text_batch: &[String], device: &Device) -> anyhow::Result<Tensor> {
        let tokens = self
            .tokenizer
            .encode_batch(text_batch.to_vec(), true)
            .map_err(E::msg)?;
        let token_ids = tokens
            .iter()
            .map(|tokens| {
                let tokens = tokens.get_ids().to_vec();
                Tensor::new(tokens.as_slice(), device)
            })
            .collect::<candle_core::Result<Vec<_>>>()?;

        Ok(Tensor::stack(&token_ids, 0)?)
    }
}

impl Embed for BertEmbeder {
    async fn embed(&self, text_batch: &[String]) -> Result<Vec<EmbedData>, reqwest::Error> {
        let token_ids = self.tokenize_batch(text_batch, &self.model.device).unwrap();
        let token_type_ids = token_ids.zeros_like().unwrap();
        let embeddings = self.model.forward(&token_ids, &token_type_ids).unwrap();
        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3().unwrap();
        // Mean-pool over the token dimension, then L2-normalize each sentence vector.
        let embeddings = (embeddings.sum(1).unwrap() / (n_tokens as f64)).unwrap();
        let embeddings = normalize_l2(&embeddings).unwrap();
        let encodings = embeddings.to_vec2::<f32>().unwrap();
        let final_embeddings = encodings
            .iter()
            .zip(text_batch)
            .map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone())))
            .collect::<Vec<_>>();
        Ok(final_embeddings)
    }
}

/// L2-normalize each row of a 2-D tensor.
pub fn normalize_l2(v: &Tensor) -> candle_core::Result<Tensor> {
    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}
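
A minimal sketch of driving BertEmbeder directly from Rust. The module paths and the public embedding field are assumptions based on the file layout and the EmbedData stub above; futures is already a dependency, so its executor can block on the async embed call:

// Hypothetical paths: assumes the crate exposes these modules publicly.
use embed_anything::embedding_model::bert::BertEmbeder;
use embed_anything::embedding_model::embed::Embed;

fn main() -> anyhow::Result<()> {
    // The first call downloads config.json, tokenizer.json, and
    // model.safetensors from the Hugging Face Hub.
    let embeder = BertEmbeder::default()?;
    let texts = vec!["an example sentence".to_string()];
    // Embed::embed is async; block on it from synchronous code.
    let data = futures::executor::block_on(embeder.embed(&texts))?;
    // The embedding field is assumed public, as in the EmbedData stub.
    println!("dim = {}", data[0].embedding.len());
    Ok(())
}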