Skip to content

Commit

Permalink
Update file_processor/mod.rs, file_processor/pdf_processor.rs, Cargo.…
Browse files Browse the repository at this point in the history
…toml, test.py, examples/web_embed.rs, and embedding_model/jina.rs
  • Loading branch information
akshayballal95 committed Apr 21, 2024
1 parent f6340f0 commit ae429b8
Show file tree
Hide file tree
Showing 14 changed files with 833 additions and 70 deletions.
418 changes: 417 additions & 1 deletion Cargo.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ edition = "2021"

[dependencies]
serde_json = "1.0.112"
reqwest = { version = "0.12.2", features = ["json"] }
reqwest = { version = "0.12.2", features = ["json", "blocking"] }
serde = {version = "1.0.196", features = ["derive"]}
pdf-extract = "0.7.4"
walkdir = "2.4.0"
Expand All @@ -24,6 +24,8 @@ pyo3 = { version = "0.21" }
intel-mkl-src = {version = "0.8.1", optional = true }
markdown-parser = "0.1.2"
markdown_to_text = "1.0.0"
scraper = "0.19.0"
text-cleaner = "0.1.0"

[dev-dependencies]
tempdir = "0.3.7"
Expand Down
44 changes: 44 additions & 0 deletions examples/web_embed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use embed_anything::file_processor::website_processor;
use candle_core::Tensor;

#[tokio::main]
async fn main() {
    // Example: embed a webpage, then rank its chunks against a text query
    // by cosine-style dot-product similarity.
    let url = "https://en.wikipedia.org/wiki/Long_short-term_memory";

    // NOTE: "WebsiteProcesor" (sic) is the upstream type's actual spelling.
    let website_processor = website_processor::WebsiteProcesor;
    let webpage = website_processor.process_website(url).await.unwrap();
    let embeder = embed_anything::embedding_model::bert::BertEmbeder::default();
    let embed_data = webpage.embed_webpage(&embeder).await.unwrap();

    // Flatten the per-chunk embeddings into one (num_chunks, dim) tensor.
    // Fail with a clear message instead of an index panic when nothing was embedded.
    let dim = embed_data
        .first()
        .map(|data| data.embedding.len())
        .expect("webpage produced no embeddings");
    let flat: Vec<f32> = embed_data
        .iter()
        .flat_map(|data| data.embedding.iter().copied())
        .collect();
    let embeddings =
        Tensor::from_vec(flat, (embed_data.len(), dim), &candle_core::Device::Cpu).unwrap();

    let query = vec!["how to use lstm for nlp".to_string()];
    let query_embedding: Vec<f32> = embeder
        .embed(&query, None)
        .await
        .unwrap()
        .iter()
        .flat_map(|data| data.embedding.iter().copied())
        .collect();

    // Take the length before moving the vector into the tensor — no clone needed.
    let query_dim = query_embedding.len();
    let query_embedding_tensor =
        Tensor::from_vec(query_embedding, (1, query_dim), &candle_core::Device::Cpu).unwrap();

    // (num_chunks, dim) x (dim, 1) -> (num_chunks, 1) -> squeeze to (num_chunks,)
    let similarities = embeddings
        .matmul(&query_embedding_tensor.transpose(0, 1).unwrap())
        .unwrap()
        .detach()
        .squeeze(1)
        .unwrap()
        .to_vec1::<f32>()
        .unwrap();

    // `total_cmp` is a total order over f32, so a NaN score cannot panic the comparator
    // the way `partial_cmp(..).unwrap()` would.
    let max_similarity_index = similarities
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(index, _)| index)
        .expect("no similarity scores computed");
    let data = &embed_data[max_similarity_index];

    println!("{:?}", data);
}
29 changes: 24 additions & 5 deletions src/embedding_model/bert.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use std::collections::HashMap;

use anyhow::Error as E;
use candle_core::{Device, Tensor};
use tokenizers::{PaddingParams, Tokenizer};
use super::embed::{Embed, EmbedData};
use super::embed::{Embed, EmbedData, TextEmbed};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
use hf_hub::{api::sync::Api, Repo};
Expand Down Expand Up @@ -65,10 +67,8 @@ impl BertEmbeder {

Ok(Tensor::stack(&token_ids, 0)?)
}
}

impl Embed for BertEmbeder {
async fn embed(&self, text_batch: &[String]) -> Result<Vec<EmbedData>, reqwest::Error> {
pub async fn embed(&self, text_batch: &[String],metadata:Option<HashMap<String,String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
let token_ids = self.tokenize_batch(text_batch, &self.model.device).unwrap();
let token_type_ids = token_ids.zeros_like().unwrap();
let embeddings = self.model.forward(&token_ids, &token_type_ids).unwrap();
Expand All @@ -79,12 +79,31 @@ impl Embed for BertEmbeder {
let final_embeddings = encodings
.iter()
.zip(text_batch)
.map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone())))
.map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone()), metadata.clone()))
.collect::<Vec<_>>();
Ok(final_embeddings)
}
}

impl Embed for BertEmbeder {
fn embed(
&self,
text_batch: &[String],metadata: Option<HashMap<String,String>>
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
self.embed(text_batch, metadata)
}
}

impl TextEmbed for BertEmbeder {
fn embed(
&self,
text_batch: &[String],
metadata: Option<HashMap<String,String>>
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
self.embed(text_batch, metadata)
}
}

/// Divides each row of `v` by its L2 norm (row-wise normalization along dim 1).
pub fn normalize_l2(v: &Tensor) -> candle_core::Result<Tensor> {
    let row_norms = v.sqr()?.sum_keepdim(1)?.sqrt()?;
    v.broadcast_div(&row_norms)
}
12 changes: 8 additions & 4 deletions src/embedding_model/clip.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use std::collections::HashMap;

use anyhow::Error as E;

use candle_core::{DType, Device, Tensor};
Expand Down Expand Up @@ -145,6 +147,7 @@ impl EmbedImage for ClipEmbeder {
fn embed_image_batch<T: AsRef<std::path::Path>>(
&self,
image_paths: &[T],

) -> anyhow::Result<Vec<EmbedData>> {
let config = clip::ClipConfig::vit_base_patch32();

Expand All @@ -163,13 +166,14 @@ impl EmbedImage for ClipEmbeder {
EmbedData::new(
data.to_vec(),
Some(path.as_ref().to_str().unwrap().to_string()),
None,
)
})
.collect::<Vec<_>>();
Ok(embeddings)
}

fn embed_image<T: AsRef<std::path::Path>>(&self, image_path: T) -> anyhow::Result<EmbedData> {
fn embed_image<T: AsRef<std::path::Path>>(&self, image_path: T, metadata: Option<HashMap<String, String>>) -> anyhow::Result<EmbedData> {
let config = clip::ClipConfig::vit_base_patch32();
let image = self
.load_image(&image_path, config.image_size)
Expand All @@ -182,12 +186,12 @@ impl EmbedImage for ClipEmbeder {
.unwrap()
.to_vec2::<f32>()
.unwrap()[0];
Ok(EmbedData::new(encoding.to_vec(), None))
Ok(EmbedData::new(encoding.to_vec(), None, metadata.clone()))
}
}

impl Embed for ClipEmbeder {
async fn embed(&self, text_batch: &[String]) -> Result<Vec<EmbedData>, reqwest::Error> {
async fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
let (input_ids, _vec_seq) = ClipEmbeder::tokenize_sequences(
Some(text_batch.to_vec()),
&self.tokenizer,
Expand All @@ -204,7 +208,7 @@ impl Embed for ClipEmbeder {
let embeddings = encodings
.iter()
.zip(text_batch)
.map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone())))
.map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone()), metadata.clone() ))
.collect::<Vec<_>>();
Ok(embeddings)
}
Expand Down
40 changes: 30 additions & 10 deletions src/embedding_model/embed.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use pyo3::prelude::*;
use serde::Deserialize;
use std::collections::HashMap;
use std::fmt::Debug;


use super::jina::JinaEmbeder;
Expand All @@ -12,26 +13,40 @@ pub struct EmbedResponse {
pub data: Vec<EmbedData>,
pub usage: HashMap<String, usize>,
}


#[pyclass]
#[derive(Deserialize, Debug, Clone)]
/// A single embedding result, exposed to Python via pyo3.
pub struct EmbedData {
    // The embedding vector itself.
    #[pyo3(get, set)]
    pub embedding: Vec<f32>,
    // The source text this embedding was computed from, when available.
    #[pyo3(get, set)]
    pub text: Option<String>,
    // Optional caller-supplied key/value metadata attached to the embedding.
    #[pyo3(get, set)]
    pub metadata: Option<HashMap<String, String>>,
}

impl Default for EmbedData {
fn default() -> Self {
Self {
embedding: Vec::new(),
text: None,
metadata: None,
}
}
}

#[pymethods]
impl EmbedData {
#[new]
pub fn new(embedding: Vec<f32>, text: Option<String>) -> Self {
Self { embedding, text }
pub fn new(embedding: Vec<f32>, text: Option<String>, metadata:Option<HashMap<String, String>>) -> Self {
Self { embedding, text, metadata }
}

pub fn __str__(&self) -> String {
format!(
"EmbedData(embedding: {:?}, text: {:?})",
self.embedding, self.text
"EmbedData(embedding: {:?}, text: {:?}, metadata: {:?})",
self.embedding, self.text, self.metadata.clone()
)
}
}
Expand All @@ -45,12 +60,12 @@ pub enum Embeder {
}

impl Embeder {
pub async fn embed(&self, text_batch: &[String]) -> Result<Vec<EmbedData>, reqwest::Error> {
pub async fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
match self {
Embeder::OpenAI(embeder) => embeder.embed(text_batch).await,
Embeder::Jina(embeder) => embeder.embed(text_batch).await,
Embeder::Clip(embeder) => embeder.embed(text_batch).await,
Embeder::Bert(embeder) => embeder.embed(text_batch).await,
Embeder::OpenAI(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await,
Embeder::Jina(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await,
Embeder::Clip(embeder) => Embed::embed(embeder, text_batch, metadata).await,
Embeder::Bert(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await,
}
}
}
Expand All @@ -60,11 +75,16 @@ pub trait Embed {
fn embed(
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>>;

}

/// Text-only embedding interface: turns a batch of strings into [`EmbedData`],
/// attaching the same optional metadata map to every produced item.
// NOTE(review): signature mirrors `Embed::embed` — presumably kept separate so
// image-capable models can implement `Embed` without `TextEmbed`; confirm intent.
pub trait TextEmbed {
    fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>>;
}

pub trait EmbedImage {
fn embed_image<T: AsRef<std::path::Path>>(&self, image_path: T) -> anyhow::Result<EmbedData>;
fn embed_image<T: AsRef<std::path::Path>>(&self, image_path: T, metadata: Option<HashMap<String, String>>) -> anyhow::Result<EmbedData>;
fn embed_image_batch<T: AsRef<std::path::Path>>(&self, image_paths:&[T]) -> anyhow::Result<Vec<EmbedData>>;
}
32 changes: 27 additions & 5 deletions src/embedding_model/jina.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use super::embed::{Embed, EmbedData};
use std::collections::HashMap;

use super::embed::{Embed, EmbedData, TextEmbed};
use anyhow::Error as E;
use candle_core::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};
Expand Down Expand Up @@ -68,10 +70,8 @@ impl JinaEmbeder {
.collect::<candle_core::Result<Vec<_>>>()?;
Ok(Tensor::stack(&token_ids, 0)?)
}
}

impl Embed for JinaEmbeder {
async fn embed(&self, text_batch: &[String]) -> Result<Vec<EmbedData>, reqwest::Error> {
async fn embed(&self, text_batch: &[String], metadata:Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
let token_ids = self.tokenize_batch(text_batch, &self.model.device).unwrap();
let embeddings = self.model.forward(&token_ids).unwrap();

Expand All @@ -84,10 +84,32 @@ impl Embed for JinaEmbeder {
let final_embeddings = encodings
.iter()
.zip(text_batch)
.map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone())))
.map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone()), metadata.clone()))
.collect::<Vec<_>>();
Ok(final_embeddings)
}


}

impl Embed for JinaEmbeder {
fn embed(
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
self.embed(text_batch, metadata)
}
}

impl TextEmbed for JinaEmbeder {
fn embed(
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
self.embed(text_batch, metadata)
}
}

pub fn normalize_l2(v: &Tensor) -> candle_core::Result<Tensor> {
Expand Down
Loading

0 comments on commit ae429b8

Please sign in to comment.