Commit 5acf1b9

Merge pull request #12 from akshayballal95/main
New Features
akshayballal95 authored Apr 21, 2024
2 parents 7def47e + ae429b8 commit 5acf1b9
Showing 27 changed files with 1,678 additions and 167 deletions.
646 changes: 631 additions & 15 deletions Cargo.lock

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions Cargo.toml
@@ -6,7 +6,7 @@ edition = "2021"

 [dependencies]
 serde_json = "1.0.112"
-reqwest = { version = "0.12.2", features = ["json"] }
+reqwest = { version = "0.12.2", features = ["json", "blocking"] }
 serde = {version = "1.0.196", features = ["derive"]}
 pdf-extract = "0.7.4"
 walkdir = "2.4.0"

@@ -19,9 +19,16 @@
 candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
 candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
 anyhow = "1.0.81"
-tokio = {version = "1.37.0", features=["rt-multi-thread"]}
+tokio = {version = "1.37.0", features=["rt-multi-thread", "macros"]}
 pyo3 = { version = "0.21" }
 intel-mkl-src = {version = "0.8.1", optional = true }
+markdown-parser = "0.1.2"
+markdown_to_text = "1.0.0"
+scraper = "0.19.0"
+text-cleaner = "0.1.0"
+
+[dev-dependencies]
+tempdir = "0.3.7"

 [features]
 mkl = ["dep:intel-mkl-src", "candle-nn/mkl", "candle-transformers/mkl", "candle-core/mkl"]
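
Note: the "macros" feature added to tokio is what enables the #[tokio::main] attribute used by the new examples/web_embed.rs below, and the "blocking" feature added to reqwest presumably backs synchronous HTTP fetching in the file processors. The new example can be run with the standard command: cargo run --example web_embed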
2 changes: 1 addition & 1 deletion embed_anything.pyi
@@ -23,7 +23,7 @@ def embed_file(file_path: str, embeder: str) -> list[EmbedData]:
     - A list of EmbedData objects.
     """

-def embed_directory(file_path: str, embeder: str) -> list[EmbedData]:
+def embed_directory(file_path: str, embeder: str, extensions: list[str]=None) -> list[EmbedData]:
     """
     Embeds all the files in the given directory and returns a list of EmbedData objects.
2 changes: 1 addition & 1 deletion examples/bert.py
@@ -1,7 +1,7 @@
 import embed_anything
 import time
 start_time = time.time()
-data = embed_anything.embed_directory("test_files", embeder= "Bert")
+data = embed_anything.embed_directory("test_files", embeder= "Bert", extensions=["pdf"])
 print(data[0])
 end_time = time.time()
 print("Time taken: ", end_time-start_time)
4 changes: 1 addition & 3 deletions examples/bert.rs
@@ -2,10 +2,8 @@ use embed_anything::embed_directory;
 use std::{path::PathBuf, time::Instant};

 fn main() {
-    // let out = embed_file("test_files/TUe_SOP_AI_2.pdf", "Bert").unwrap();
-
     let now = Instant::now();
-    let out = embed_directory(PathBuf::from("test_files"), "Bert").unwrap();
+    let out = embed_directory(PathBuf::from("test_files"), "Bert", Some(vec!["md".to_string()])).unwrap();
     println!("{:?}", out);
     let elapsed_time = now.elapsed();
     println!("Elapsed Time: {}", elapsed_time.as_secs_f32());
2 changes: 1 addition & 1 deletion examples/clip.rs
@@ -6,7 +6,7 @@ fn main() {
     // let out = embed_file("test_files/TUe_SOP_AI_2.pdf", "Bert").unwrap();

     let now = Instant::now();
-    let out = embed_directory(PathBuf::from("test_files"), "Clip").unwrap();
+    let out = embed_directory(PathBuf::from("test_files"), "Clip", None).unwrap();
     let query_emb_data = embed_query(vec!["Photo of a monkey".to_string()], "Clip").unwrap();
     let n_vectors = out.len();
     let out_embeddings = Tensor::from_vec(
44 changes: 44 additions & 0 deletions examples/web_embed.rs
@@ -0,0 +1,44 @@
use embed_anything::file_processor::website_processor;
use candle_core::Tensor;

#[tokio::main]
async fn main() {
    let url = "https://en.wikipedia.org/wiki/Long_short-term_memory";

    let website_processor = website_processor::WebsiteProcesor;
    let webpage = website_processor.process_website(url).await.unwrap();
    let embeder = embed_anything::embedding_model::bert::BertEmbeder::default();
    let embed_data = webpage.embed_webpage(&embeder).await.unwrap();
    let embeddings: Vec<Vec<f32>> = embed_data.iter().map(|data| data.embedding.clone()).collect();

    let embeddings = Tensor::from_vec(
        embeddings.iter().flatten().cloned().collect::<Vec<f32>>(),
        (embeddings.len(), embeddings[0].len()),
        &candle_core::Device::Cpu,
    ).unwrap();

    let query = vec!["how to use lstm for nlp".to_string()];
    let query_embedding: Vec<f32> = embeder
        .embed(&query, None)
        .await
        .unwrap()
        .iter()
        .map(|data| data.embedding.clone())
        .flatten()
        .collect();

    let query_embedding_tensor = Tensor::from_vec(
        query_embedding.clone(),
        (1, query_embedding.len()),
        &candle_core::Device::Cpu,
    ).unwrap();

    let similarities = embeddings
        .matmul(&query_embedding_tensor.transpose(0, 1).unwrap())
        .unwrap()
        .detach()
        .squeeze(1)
        .unwrap()
        .to_vec1::<f32>()
        .unwrap();

    let max_similarity_index = similarities
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .unwrap()
        .0;
    let data = &embed_data[max_similarity_index];

    println!("{:?}", data);
}
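
The matmul against the transposed query tensor scores every page chunk by dot product; assuming the Bert embeddings are L2-normalized (bert.rs below defines a normalize_l2 helper), that dot product is exactly cosine similarity. A minimal sketch of the same ranking step on plain Vec<f32> values, not part of this commit:

// Dot product of two equal-length vectors; equals cosine similarity
// when both vectors are L2-normalized (an assumption here).
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

// Index of the best-scoring embedding for the query, or None if empty.
fn best_match(embeddings: &[Vec<f32>], query: &[f32]) -> Option<usize> {
    embeddings
        .iter()
        .map(|e| dot(e, query))
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
        .map(|(i, _)| i)
}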
29 changes: 24 additions & 5 deletions src/embedding_model/bert.rs
@@ -1,10 +1,12 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;

+use std::collections::HashMap;
+
 use anyhow::Error as E;
 use candle_core::{Device, Tensor};
 use tokenizers::{PaddingParams, Tokenizer};
-use super::embed::{Embed, EmbedData};
+use super::embed::{Embed, EmbedData, TextEmbed};
 use candle_nn::VarBuilder;
 use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
 use hf_hub::{api::sync::Api, Repo};

@@ -65,10 +67,8 @@ impl BertEmbeder {

         Ok(Tensor::stack(&token_ids, 0)?)
     }
-}
-
-impl Embed for BertEmbeder {
-    async fn embed(&self, text_batch: &[String]) -> Result<Vec<EmbedData>, reqwest::Error> {
+
+    pub async fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
         let token_ids = self.tokenize_batch(text_batch, &self.model.device).unwrap();
         let token_type_ids = token_ids.zeros_like().unwrap();
         let embeddings = self.model.forward(&token_ids, &token_type_ids).unwrap();

@@ -79,12 +79,31 @@
         let final_embeddings = encodings
             .iter()
             .zip(text_batch)
-            .map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone())))
+            .map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone()), metadata.clone()))
             .collect::<Vec<_>>();
         Ok(final_embeddings)
     }
 }

+impl Embed for BertEmbeder {
+    fn embed(
+        &self,
+        text_batch: &[String], metadata: Option<HashMap<String, String>>
+    ) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
+        self.embed(text_batch, metadata)
+    }
+}
+
+impl TextEmbed for BertEmbeder {
+    fn embed(
+        &self,
+        text_batch: &[String],
+        metadata: Option<HashMap<String, String>>
+    ) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
+        self.embed(text_batch, metadata)
+    }
+}
+
 pub fn normalize_l2(v: &Tensor) -> candle_core::Result<Tensor> {
     v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
 }
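
The commit moves the actual embedding logic into an inherent async method and has both the Embed and TextEmbed trait impls delegate to it, so downstream code can stay generic over text embedders. A minimal sketch of such a caller, assuming the TextEmbed trait and EmbedData type as declared in this module (the helper name embed_texts is hypothetical, not part of this commit):

use std::collections::HashMap;

use super::embed::{EmbedData, TextEmbed};

// Hypothetical helper: works with BertEmbeder or any other TextEmbed implementor.
async fn embed_texts<E: TextEmbed>(
    embeder: &E,
    texts: &[String],
    metadata: Option<HashMap<String, String>>,
) -> Result<Vec<EmbedData>, reqwest::Error> {
    embeder.embed(texts, metadata).await
}

For example, embed_texts(&BertEmbeder::default(), &chunks, None).await would produce the same result as calling the inherent method directly.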
126 changes: 104 additions & 22 deletions src/embedding_model/clip.rs
@@ -1,8 +1,9 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;

-use anyhow::Error as E;
+use std::collections::HashMap;
+
+use anyhow::Error as E;

 use candle_core::{DType, Device, Tensor};
 use candle_transformers::models::clip;

@@ -38,8 +39,8 @@ impl Default for ClipEmbeder {
 }

 impl ClipEmbeder {
-    pub fn new(model: clip::ClipModel, tokenizer: Tokenizer) -> Self {
-        ClipEmbeder { model, tokenizer }
+    pub fn new(model: clip::ClipModel, tokenizer: Tokenizer) -> Result<Self, E> {
+        Ok(ClipEmbeder { model, tokenizer })
     }

     pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {

@@ -100,17 +101,21 @@ impl ClipEmbeder {
         Ok((input_ids, vec_seq))
     }

-    fn load_image<T: AsRef<std::path::Path>>(&self, path: T, image_size: usize) -> anyhow::Result<Tensor> {
+    fn load_image<T: AsRef<std::path::Path>>(
+        &self,
+        path: T,
+        image_size: usize,
+    ) -> anyhow::Result<Tensor> {
         let img = image::io::Reader::open(path)?.decode()?;
         let (height, width) = (image_size, image_size);
         let img = img.resize_to_fill(
             width as u32,
             height as u32,
             image::imageops::FilterType::Triangle,
         );

         let img = img.to_rgb8();

         let img = img.into_raw();
         let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
             .permute((2, 0, 1))?

@@ -119,48 +124,74 @@ impl ClipEmbeder {
         // .unsqueeze(0)?;
         Ok(img)
     }

     fn load_images<T: AsRef<std::path::Path>>(
         &self,
         paths: &[T],
         image_size: usize,
     ) -> anyhow::Result<Tensor> {
         let mut images = vec![];

         for path in paths {
             let tensor = self.load_image(path, image_size)?;
             images.push(tensor);
         }

         let images = Tensor::stack(&images, 0)?;

         Ok(images)
     }
-
-
-
-
 }

-impl EmbedImage for ClipEmbeder{
-    fn embed_image_batch<T: AsRef<std::path::Path>>(&self, image_paths:&[T]) -> anyhow::Result<Vec<EmbedData>> {
+impl EmbedImage for ClipEmbeder {
+    fn embed_image_batch<T: AsRef<std::path::Path>>(
+        &self,
+        image_paths: &[T],
+    ) -> anyhow::Result<Vec<EmbedData>> {
         let config = clip::ClipConfig::vit_base_patch32();

         let images = self.load_images(image_paths, config.image_size).unwrap();
-        let encodings = self.model.get_image_features(&images).unwrap().to_vec2::<f32>().unwrap();
+        let encodings = self
+            .model
+            .get_image_features(&images)
+            .unwrap()
+            .to_vec2::<f32>()
+            .unwrap();

         let embeddings = encodings
             .iter()
             .zip(image_paths)
-            .map(|(data, path)| EmbedData::new(data.to_vec(), Some(path.as_ref().to_str().unwrap().to_string())))
+            .map(|(data, path)| {
+                EmbedData::new(
+                    data.to_vec(),
+                    Some(path.as_ref().to_str().unwrap().to_string()),
+                    None,
+                )
+            })
             .collect::<Vec<_>>();
         Ok(embeddings)
     }

+    fn embed_image<T: AsRef<std::path::Path>>(&self, image_path: T, metadata: Option<HashMap<String, String>>) -> anyhow::Result<EmbedData> {
+        let config = clip::ClipConfig::vit_base_patch32();
+        let image = self
+            .load_image(&image_path, config.image_size)
+            .unwrap()
+            .unsqueeze(0)
+            .unwrap();
+        let encoding = &self
+            .model
+            .get_image_features(&image)
+            .unwrap()
+            .to_vec2::<f32>()
+            .unwrap()[0];
+        Ok(EmbedData::new(encoding.to_vec(), None, metadata.clone()))
+    }
 }

 impl Embed for ClipEmbeder {
-    async fn embed(&self, text_batch: &[String]) -> Result<Vec<EmbedData>, reqwest::Error> {
+    async fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
         let (input_ids, _vec_seq) = ClipEmbeder::tokenize_sequences(
             Some(text_batch.to_vec()),
             &self.tokenizer,

@@ -172,13 +203,64 @@ impl Embed for ClipEmbeder {
             .model
             .get_text_features(&input_ids)
             .unwrap()
-            .to_vec2::<f32>().unwrap();
+            .to_vec2::<f32>()
+            .unwrap();
         let embeddings = encodings
             .iter()
             .zip(text_batch)
-            .map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone())))
+            .map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone()), metadata.clone()))
             .collect::<Vec<_>>();
         Ok(embeddings)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Initializes a new ClipEmbeder with default values.
+    #[test]
+    fn test_default_initialization() {
+        let _clip_embeder = ClipEmbeder::default();
+    }
+
+    // Tests the tokenize_sequences method.
+    #[test]
+    fn test_tokenize_sequences() {
+        let clip_embeder = ClipEmbeder::default();
+        let sequences = Some(vec![
+            "Hey there how are you?".to_string(),
+            "EmbedAnything is the best!".to_string(),
+        ]);
+        let (input_ids, vec_seq) = ClipEmbeder::tokenize_sequences(sequences, &clip_embeder.tokenizer, &Device::Cpu).unwrap();
+        assert_eq!(vec_seq, vec![
+            "Hey there how are you?".to_string(),
+            "EmbedAnything is the best!".to_string(),
+        ]);
+        assert_eq!(input_ids.shape().clone().into_dims(), &[2, 8]);
+    }
+
+    // Tests the load_image method.
+    #[test]
+    fn test_load_image() {
+        let clip_embeder = ClipEmbeder::default();
+        let image = clip_embeder.load_image("test_files/clip/cat1.jpg", 224).unwrap();
+        assert_eq!(image.shape().clone().into_dims(), &[3, 224, 224]);
+    }
+
+    // Tests the load_images method.
+    #[test]
+    fn test_load_images() {
+        let clip_embeder = ClipEmbeder::default();
+        let images = clip_embeder.load_images(&["test_files/clip/cat1.jpg", "test_files/clip/cat2.jpeg"], 224).unwrap();
+        assert_eq!(images.shape().clone().into_dims(), &[2, 3, 224, 224]);
+    }
+
+    // Tests the embed_image_batch method.
+    #[test]
+    fn test_embed_image_batch() {
+        let clip_embeder = ClipEmbeder::default();
+        let embeddings = clip_embeder.embed_image_batch(&["test_files/clip/cat1.jpg", "test_files/clip/cat2.jpeg"]).unwrap();
+        assert_eq!(embeddings.len(), 2);
+    }
+}
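
The new embed_image method embeds a single image and attaches the optional metadata map to the returned EmbedData. A hypothetical usage sketch, assuming ClipEmbeder is exposed at embed_anything::embedding_model::clip::ClipEmbeder (mirroring the bert path used in examples/web_embed.rs) and that the EmbedImage trait lives in the embed module; both paths are assumptions, and the image path reuses the test file referenced above:

use std::collections::HashMap;

use embed_anything::embedding_model::clip::ClipEmbeder; // assumed path
use embed_anything::embedding_model::embed::EmbedImage; // assumed path

fn main() -> anyhow::Result<()> {
    let clip_embeder = ClipEmbeder::default();

    // Optional metadata travels through to the returned EmbedData.
    let mut metadata = HashMap::new();
    metadata.insert("source".to_string(), "test_files/clip/cat1.jpg".to_string());

    let data = clip_embeder.embed_image("test_files/clip/cat1.jpg", Some(metadata))?;
    println!("embedding length: {}", data.embedding.len());
    Ok(())
}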