diff --git a/Cargo.lock b/Cargo.lock index c0cca263..cfb7cc68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,30 @@ dependencies = [ "pom", ] +[[package]] +name = "ahash" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" +dependencies = [ + "getrandom", + "once_cell", + "version_check", +] + +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -529,6 +553,29 @@ dependencies = [ "typenum", ] +[[package]] +name = "cssparser" +version = "0.31.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b3df4f93e5fbbe73ec01ec8d3f68bba73107993a5b1e7519273c32db9b0d5be" +dependencies = [ + "cssparser-macros", + "dtoa-short", + "itoa", + "phf 0.11.2", + "smallvec", +] + +[[package]] +name = "cssparser-macros" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13b588ba4ac1a99f7f2964d24b3d896ddc6bf847ee3855dbd4366f058cfcd331" +dependencies = [ + "quote", + "syn 2.0.58", +] + [[package]] name = "darling" version = "0.14.4" @@ -670,6 +717,23 @@ dependencies = [ "syn 2.0.58", ] +[[package]] +name = "derive_more" +version = "0.99.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "diacritics" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "720807774f004a558d2b6d88e17b27a9c7cedccd541f4851446976631da2a151" + [[package]] name = "digest" version = "0.10.7" @@ -710,6 +774,21 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "dtoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcbb2bf8e87535c23f7a8a321e364ce21462d0ff10cb6407820e8e96dfff6653" + +[[package]] +name = "dtoa-short" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbaceec3c6e4211c79e7b1800fb9680527106beb2f9c51904a3210c03a448c74" +dependencies = [ + "dtoa", +] + [[package]] name = "dyn-stack" version = "0.10.0" @@ -720,6 +799,12 @@ dependencies = [ "reborrow", ] +[[package]] +name = "ego-tree" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a68a4904193147e0a8dec3314640e6db742afd5f6e634f428a6af230d9b3591" + [[package]] name = "either" version = "1.10.0" @@ -744,9 +829,11 @@ dependencies = [ "rayon", "regex", "reqwest", + "scraper", "serde", "serde_json", "tempdir", + "text-cleaner", "tokenizers", "tokio", "walkdir", @@ -1006,6 +1093,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba" +[[package]] +name = "futf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df420e2e84819663797d1ec6544b13c5be84629e7bb00dc960d6917db2987843" +dependencies = [ + "mac", + "new_debug_unreachable", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -1013,6 +1110,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -1021,6 +1119,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + [[package]] name = "futures-sink" version = "0.3.30" @@ -1040,9 +1144,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-core", + "futures-io", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", +] + +[[package]] +name = "fxhash" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" +dependencies = [ + "byteorder", ] [[package]] @@ -1254,6 +1371,15 @@ dependencies = [ "rand_distr", ] +[[package]] +name = "hashbrown" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +dependencies = [ + "ahash 0.7.8", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -1301,6 +1427,20 @@ dependencies = [ "ureq", ] +[[package]] +name = "html5ever" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bea68cab48b8459f17cf1c944c67ddc572d272d9f2b274140f223ecb1da4a3b7" +dependencies = [ + "log", + "mac", + "markup5ever", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "http" version = "1.1.0" @@ -1426,6 +1566,17 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" +[[package]] +name = "idna" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8" +dependencies = [ + "matches", + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "idna" version = "0.5.0" @@ -1661,6 +1812,15 @@ version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" +[[package]] +name = "linkify" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d9967eb7d0bc31c39c6f52e8fce42991c0cd1f7a2078326f0b7a399a584c8d" +dependencies = [ + "memchr", +] + [[package]] name = "linux-raw-sys" version = "0.4.13" @@ -1709,6 +1869,12 @@ dependencies = [ "weezl", ] +[[package]] +name = "mac" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" + [[package]] name = "macro_rules_attribute" version = "0.2.0" @@ -1752,6 +1918,26 @@ dependencies = [ "pulldown-cmark", ] +[[package]] +name = "markup5ever" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2629bb1404f3d34c2e921f21fd34ba00b206124c81f65c50b43b6aaefeb016" +dependencies = [ + "log", + "phf 0.10.1", + "phf_codegen", + "string_cache", + "string_cache_codegen", + "tendril", +] + +[[package]] +name = "matches" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" + [[package]] name = "maybe-rayon" version = "0.1.1" @@ -2149,6 +2335,86 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "phf" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fabbf1ead8a5bcbc20f5f8b939ee3f5b0f6f281b6ad3468b84656b658b455259" +dependencies = [ + "phf_shared 0.10.0", +] + +[[package]] +name = "phf" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +dependencies = [ + "phf_macros", + "phf_shared 0.11.2", +] + +[[package]] +name = "phf_codegen" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb1c3a8bc4dd4e5cfce29b44ffc14bedd2ee294559a294e2a4d4c9e9a6a13cd" +dependencies = [ + "phf_generator 0.10.0", + "phf_shared 0.10.0", +] + +[[package]] +name = "phf_generator" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d5285893bb5eb82e6aaf5d59ee909a06a16737a8970984dd7746ba9283498d6" +dependencies = [ + "phf_shared 0.10.0", + "rand 0.8.5", +] + +[[package]] +name = "phf_generator" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +dependencies = [ + "phf_shared 0.11.2", + "rand 0.8.5", +] + +[[package]] +name = "phf_macros" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b" +dependencies = [ + "phf_generator 0.11.2", + "phf_shared 0.11.2", + "proc-macro2", + "quote", + "syn 2.0.58", +] + +[[package]] +name = "phf_shared" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" +dependencies = [ + "siphasher", +] + +[[package]] +name = "phf_shared" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project" version = "1.1.5" @@ -2230,6 +2496,12 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "precomputed-hash" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -2633,6 +2905,7 @@ dependencies = [ "base64 0.22.0", "bytes", "encoding_rs", + "futures-channel", "futures-core", "futures-util", "h2", @@ -2790,6 +3063,22 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scraper" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b80b33679ff7a0ea53d37f3b39de77ea0c75b12c5805ac43ec0c33b3051af1b" +dependencies = [ + "ahash 0.8.11", + "cssparser", + "ego-tree", + "getopts", + "html5ever", + "once_cell", + "selectors", + "tendril", +] + [[package]] name = "security-framework" version = "2.10.0" @@ -2813,6 +3102,25 @@ dependencies = [ "libc", ] +[[package]] +name = "selectors" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4eb30575f3638fc8f6815f448d50cb1a2e255b0897985c8c59f4d37b72a07b06" +dependencies = [ + "bitflags 2.5.0", + "cssparser", + "derive_more", + "fxhash", + "log", + "new_debug_unreachable", + "phf 0.10.1", + "phf_codegen", + "precomputed-hash", + "servo_arc", + "smallvec", +] + [[package]] name = "seq-macro" version = "0.3.5" @@ -2892,6 +3200,15 @@ dependencies = [ "yaml-rust", ] +[[package]] +name = "servo_arc" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d036d71a959e00c77a63538b90a6c2390969f9772b096ea837205c6bd0491a44" +dependencies = [ + "stable_deref_trait", +] + [[package]] name = "sha2" version = "0.10.8" @@ -2918,6 +3235,12 @@ dependencies = [ "quote", ] +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + [[package]] name = "slab" version = "0.4.9" @@ -2970,6 +3293,32 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "string_cache" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" +dependencies = [ + "new_debug_unreachable", + "once_cell", + "parking_lot", + "phf_shared 0.10.0", + "precomputed-hash", + "serde", +] + +[[package]] +name = "string_cache_codegen" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bb30289b722be4ff74a408c3cc27edeaad656e06cb1fe8fa9231fa59c728988" +dependencies = [ + "phf_generator 0.10.0", + "phf_shared 0.10.0", + "proc-macro2", + "quote", +] + [[package]] name = "strsim" version = "0.10.0" @@ -3114,6 +3463,32 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "tendril" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d24a120c5fc464a3458240ee02c299ebcb9d67b5249c8848b09d639dca8d7bb0" +dependencies = [ + "futf", + "mac", + "utf-8", +] + +[[package]] +name = "text-cleaner" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29ba3bc9a57b04ba93202d084e285b6ff4c59ab44581167726d1aa76c30ce37f" +dependencies = [ + "diacritics", + "hashbrown 0.11.2", + "lazy_static", + "linkify", + "regex", + "unicode-normalization", + "validator", +] + [[package]] name = "thiserror" version = "1.0.58" @@ -3496,10 +3871,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", - "idna", + "idna 0.5.0", "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8parse" version = "0.2.1" @@ -3526,6 +3907,21 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "validator" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f07b0a1390e01c0fc35ebb26b28ced33c9a3808f7f9fbe94d3cc01e233bfeed5" +dependencies = [ + "idna 0.2.3", + "lazy_static", + "regex", + "serde", + "serde_derive", + "serde_json", + "url", +] + [[package]] name = "vcpkg" version = "0.2.15" @@ -3895,6 +4291,26 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zerocopy" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.58", +] + [[package]] name = "zerofrom" version = "0.1.3" diff --git a/Cargo.toml b/Cargo.toml index 72922e66..200eaf0b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] serde_json = "1.0.112" -reqwest = { version = "0.12.2", features = ["json"] } +reqwest = { version = "0.12.2", features = ["json", "blocking"] } serde = {version = "1.0.196", features = ["derive"]} pdf-extract = "0.7.4" walkdir = "2.4.0" @@ -24,6 +24,8 @@ pyo3 = { version = "0.21" } intel-mkl-src = {version = "0.8.1", optional = true } markdown-parser = "0.1.2" markdown_to_text = "1.0.0" +scraper = "0.19.0" +text-cleaner = "0.1.0" [dev-dependencies] tempdir = "0.3.7" diff --git a/examples/web_embed.rs b/examples/web_embed.rs new file mode 100644 index 00000000..76a17bf0 --- /dev/null +++ b/examples/web_embed.rs @@ -0,0 +1,44 @@ +use embed_anything::file_processor::website_processor; +use candle_core::Tensor; + +#[tokio::main] +async fn main() { + let url = "https://en.wikipedia.org/wiki/Long_short-term_memory"; + + let website_processor = website_processor::WebsiteProcesor; + let webpage = website_processor.process_website(url).await.unwrap(); + let embeder = embed_anything::embedding_model::bert::BertEmbeder::default(); + let embed_data = webpage.embed_webpage(&embeder).await.unwrap(); + let embeddings: Vec> = embed_data.iter().map(|data| data.embedding.clone()).collect(); + + let embeddings = Tensor::from_vec( + embeddings.iter().flatten().cloned().collect::>(), + (embeddings.len(), embeddings[0].len()), + &candle_core::Device::Cpu, + ).unwrap(); + + let query = vec!["how to use lstm for nlp".to_string()]; + let query_embedding: Vec = embeder.embed(&query, None).await.unwrap().iter().map(|data| data.embedding.clone()).flatten().collect(); + + let query_embedding_tensor = Tensor::from_vec( + query_embedding.clone(), + (1, query_embedding.len()), + &candle_core::Device::Cpu, + ).unwrap(); + + + let similarities = embeddings + .matmul(&query_embedding_tensor.transpose(0, 1).unwrap()) + .unwrap() + .detach() + .squeeze(1) + .unwrap() + .to_vec1::() + .unwrap(); + + let max_similarity_index = similarities.iter().enumerate().max_by(|a, b| a.1.partial_cmp(b.1).unwrap()).unwrap().0; + let data = &embed_data[max_similarity_index]; + + println!("{:?}", data); + +} \ No newline at end of file diff --git a/src/embedding_model/bert.rs b/src/embedding_model/bert.rs index 6232ff35..5b09d9ed 100644 --- a/src/embedding_model/bert.rs +++ b/src/embedding_model/bert.rs @@ -1,10 +1,12 @@ #[cfg(feature = "mkl")] extern crate intel_mkl_src; +use std::collections::HashMap; + use anyhow::Error as E; use candle_core::{Device, Tensor}; use tokenizers::{PaddingParams, Tokenizer}; -use super::embed::{Embed, EmbedData}; +use super::embed::{Embed, EmbedData, TextEmbed}; use candle_nn::VarBuilder; use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE}; use hf_hub::{api::sync::Api, Repo}; @@ -65,10 +67,8 @@ impl BertEmbeder { Ok(Tensor::stack(&token_ids, 0)?) } -} -impl Embed for BertEmbeder { - async fn embed(&self, text_batch: &[String]) -> Result, reqwest::Error> { + pub async fn embed(&self, text_batch: &[String],metadata:Option>) -> Result, reqwest::Error> { let token_ids = self.tokenize_batch(text_batch, &self.model.device).unwrap(); let token_type_ids = token_ids.zeros_like().unwrap(); let embeddings = self.model.forward(&token_ids, &token_type_ids).unwrap(); @@ -79,12 +79,31 @@ impl Embed for BertEmbeder { let final_embeddings = encodings .iter() .zip(text_batch) - .map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone()))) + .map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone()), metadata.clone())) .collect::>(); Ok(final_embeddings) } } +impl Embed for BertEmbeder { + fn embed( + &self, + text_batch: &[String],metadata: Option> + ) -> impl std::future::Future, reqwest::Error>> { + self.embed(text_batch, metadata) + } +} + +impl TextEmbed for BertEmbeder { + fn embed( + &self, + text_batch: &[String], + metadata: Option> + ) -> impl std::future::Future, reqwest::Error>> { + self.embed(text_batch, metadata) + } +} + pub fn normalize_l2(v: &Tensor) -> candle_core::Result { v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?) } diff --git a/src/embedding_model/clip.rs b/src/embedding_model/clip.rs index dda16ee4..de18053e 100644 --- a/src/embedding_model/clip.rs +++ b/src/embedding_model/clip.rs @@ -1,6 +1,8 @@ #[cfg(feature = "mkl")] extern crate intel_mkl_src; +use std::collections::HashMap; + use anyhow::Error as E; use candle_core::{DType, Device, Tensor}; @@ -145,6 +147,7 @@ impl EmbedImage for ClipEmbeder { fn embed_image_batch>( &self, image_paths: &[T], + ) -> anyhow::Result> { let config = clip::ClipConfig::vit_base_patch32(); @@ -163,13 +166,14 @@ impl EmbedImage for ClipEmbeder { EmbedData::new( data.to_vec(), Some(path.as_ref().to_str().unwrap().to_string()), + None, ) }) .collect::>(); Ok(embeddings) } - fn embed_image>(&self, image_path: T) -> anyhow::Result { + fn embed_image>(&self, image_path: T, metadata: Option>) -> anyhow::Result { let config = clip::ClipConfig::vit_base_patch32(); let image = self .load_image(&image_path, config.image_size) @@ -182,12 +186,12 @@ impl EmbedImage for ClipEmbeder { .unwrap() .to_vec2::() .unwrap()[0]; - Ok(EmbedData::new(encoding.to_vec(), None)) + Ok(EmbedData::new(encoding.to_vec(), None, metadata.clone())) } } impl Embed for ClipEmbeder { - async fn embed(&self, text_batch: &[String]) -> Result, reqwest::Error> { + async fn embed(&self, text_batch: &[String], metadata: Option>) -> Result, reqwest::Error> { let (input_ids, _vec_seq) = ClipEmbeder::tokenize_sequences( Some(text_batch.to_vec()), &self.tokenizer, @@ -204,7 +208,7 @@ impl Embed for ClipEmbeder { let embeddings = encodings .iter() .zip(text_batch) - .map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone()))) + .map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone()), metadata.clone() )) .collect::>(); Ok(embeddings) } diff --git a/src/embedding_model/embed.rs b/src/embedding_model/embed.rs index a984bd29..73d1b657 100644 --- a/src/embedding_model/embed.rs +++ b/src/embedding_model/embed.rs @@ -1,6 +1,7 @@ use pyo3::prelude::*; use serde::Deserialize; use std::collections::HashMap; +use std::fmt::Debug; use super::jina::JinaEmbeder; @@ -12,6 +13,8 @@ pub struct EmbedResponse { pub data: Vec, pub usage: HashMap, } + + #[pyclass] #[derive(Deserialize, Debug, Clone)] pub struct EmbedData { @@ -19,19 +22,31 @@ pub struct EmbedData { pub embedding: Vec, #[pyo3(get, set)] pub text: Option, + #[pyo3(get, set)] + pub metadata: Option>, +} + +impl Default for EmbedData { + fn default() -> Self { + Self { + embedding: Vec::new(), + text: None, + metadata: None, + } + } } #[pymethods] impl EmbedData { #[new] - pub fn new(embedding: Vec, text: Option) -> Self { - Self { embedding, text } + pub fn new(embedding: Vec, text: Option, metadata:Option>) -> Self { + Self { embedding, text, metadata } } pub fn __str__(&self) -> String { format!( - "EmbedData(embedding: {:?}, text: {:?})", - self.embedding, self.text + "EmbedData(embedding: {:?}, text: {:?}, metadata: {:?})", + self.embedding, self.text, self.metadata.clone() ) } } @@ -45,12 +60,12 @@ pub enum Embeder { } impl Embeder { - pub async fn embed(&self, text_batch: &[String]) -> Result, reqwest::Error> { + pub async fn embed(&self, text_batch: &[String], metadata: Option>) -> Result, reqwest::Error> { match self { - Embeder::OpenAI(embeder) => embeder.embed(text_batch).await, - Embeder::Jina(embeder) => embeder.embed(text_batch).await, - Embeder::Clip(embeder) => embeder.embed(text_batch).await, - Embeder::Bert(embeder) => embeder.embed(text_batch).await, + Embeder::OpenAI(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await, + Embeder::Jina(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await, + Embeder::Clip(embeder) => Embed::embed(embeder, text_batch, metadata).await, + Embeder::Bert(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await, } } } @@ -60,11 +75,16 @@ pub trait Embed { fn embed( &self, text_batch: &[String], + metadata: Option>, ) -> impl std::future::Future, reqwest::Error>>; } +pub trait TextEmbed { + fn embed(&self, text_batch: &[String], metadata: Option>) -> impl std::future::Future, reqwest::Error>>; +} + pub trait EmbedImage { - fn embed_image>(&self, image_path: T) -> anyhow::Result; + fn embed_image>(&self, image_path: T, metadata: Option>) -> anyhow::Result; fn embed_image_batch>(&self, image_paths:&[T]) -> anyhow::Result>; } \ No newline at end of file diff --git a/src/embedding_model/jina.rs b/src/embedding_model/jina.rs index fa6826a5..bb20bbee 100644 --- a/src/embedding_model/jina.rs +++ b/src/embedding_model/jina.rs @@ -1,4 +1,6 @@ -use super::embed::{Embed, EmbedData}; +use std::collections::HashMap; + +use super::embed::{Embed, EmbedData, TextEmbed}; use anyhow::Error as E; use candle_core::{DType, Device, Tensor}; use candle_nn::{Module, VarBuilder}; @@ -68,10 +70,8 @@ impl JinaEmbeder { .collect::>>()?; Ok(Tensor::stack(&token_ids, 0)?) } -} -impl Embed for JinaEmbeder { - async fn embed(&self, text_batch: &[String]) -> Result, reqwest::Error> { + async fn embed(&self, text_batch: &[String], metadata:Option>) -> Result, reqwest::Error> { let token_ids = self.tokenize_batch(text_batch, &self.model.device).unwrap(); let embeddings = self.model.forward(&token_ids).unwrap(); @@ -84,10 +84,32 @@ impl Embed for JinaEmbeder { let final_embeddings = encodings .iter() .zip(text_batch) - .map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone()))) + .map(|(data, text)| EmbedData::new(data.to_vec(), Some(text.clone()), metadata.clone())) .collect::>(); Ok(final_embeddings) } + + +} + +impl Embed for JinaEmbeder { + fn embed( + &self, + text_batch: &[String], + metadata: Option>, + ) -> impl std::future::Future, reqwest::Error>> { + self.embed(text_batch, metadata) + } +} + +impl TextEmbed for JinaEmbeder { + fn embed( + &self, + text_batch: &[String], + metadata: Option>, + ) -> impl std::future::Future, reqwest::Error>> { + self.embed(text_batch, metadata) + } } pub fn normalize_l2(v: &Tensor) -> candle_core::Result { diff --git a/src/embedding_model/openai.rs b/src/embedding_model/openai.rs index 44fd1024..44e302d2 100644 --- a/src/embedding_model/openai.rs +++ b/src/embedding_model/openai.rs @@ -1,10 +1,12 @@ +use std::collections::HashMap; + use reqwest::Client; use serde::Deserialize; use serde_json::json; use crate::embedding_model::embed::{EmbedData, EmbedResponse}; -use super::embed::Embed; +use super::embed::{Embed, TextEmbed}; /// Represents an OpenAIEmbeder struct that contains the URL and API key for making requests to the OpenAI API. #[derive(Deserialize, Debug)] @@ -20,7 +22,32 @@ impl Default for OpenAIEmbeder { } impl Embed for OpenAIEmbeder { - async fn embed(&self, text_batch: &[String]) -> Result, reqwest::Error> { + fn embed(&self, text_batch: &[String], metadata: Option>) -> impl std::future::Future, reqwest::Error>> { + self.embed(text_batch, metadata) + } +} + +impl TextEmbed for OpenAIEmbeder { + fn embed( + &self, + text_batch: &[String], + metadata: Option>, + ) -> impl std::future::Future, reqwest::Error>> { + self.embed(text_batch, metadata) + } +} + +impl OpenAIEmbeder { + pub fn new(api_key: Option) -> Self { + let api_key = api_key.unwrap_or_else(|| std::env::var("OPENAI_API_KEY").unwrap()); + + Self { + url: "https://api.openai.com/v1/embeddings".to_string(), + api_key, + } + } + + async fn embed(&self, text_batch: &[String], metadata: Option>) -> Result, reqwest::Error> { let client = Client::new(); let response = client @@ -41,23 +68,11 @@ impl Embed for OpenAIEmbeder { .data .iter() .zip(text_batch) - .map(move |(data, text)| EmbedData::new(data.embedding.clone(), Some(text.clone()))) + .map(move |(data, text)| EmbedData::new(data.embedding.clone(), Some(text.clone()), metadata.clone())) .collect::>(); Ok(emb_data) } - -} - -impl OpenAIEmbeder { - pub fn new(api_key: Option) -> Self { - let api_key = api_key.unwrap_or_else(|| std::env::var("OPENAI_API_KEY").unwrap()); - - Self { - url: "https://api.openai.com/v1/embeddings".to_string(), - api_key, - } - } } #[cfg(test)] @@ -72,7 +87,7 @@ mod tests { "The quick brown fox jumps over the lazy dog".to_string(), ]; - let embeddings = openai.embed(&text_batch).await.unwrap(); + let embeddings = openai.embed(&text_batch, None).await.unwrap(); assert_eq!(embeddings.len(), 2); } -} \ No newline at end of file +} diff --git a/src/file_embed.rs b/src/file_embed.rs index b1d381f8..380cd8f7 100644 --- a/src/file_embed.rs +++ b/src/file_embed.rs @@ -1,4 +1,4 @@ -use std::fmt::Debug; +use std::{collections::HashMap, fmt::Debug}; use anyhow::Error; @@ -25,13 +25,20 @@ impl FileEmbeder { embeddings: Vec::new(), } } - pub fn split_into_chunks(&mut self, text: &str, chunk_size: usize) { + pub fn split_into_chunks(&mut self, text: &str, chunk_size: usize) -> Option> { + let mut chunk = Vec::new(); + let mut chunks = Vec::new(); + + if text == "" { + return None; + } + let sentences: Vec<&str> = text.split_terminator('.').collect(); for sentence in sentences { if text.len() < chunk_size { - self.chunks.push(text.to_owned()); + chunks.push(text.to_owned()); break; } @@ -44,14 +51,17 @@ impl FileEmbeder { chunk.extend(words); if chunk.len() >= chunk_size { - self.chunks.push(chunk.join(" ")); + chunks.push(chunk.join(" ")); chunk.clear(); } } + self.chunks = chunks; + Some(self.chunks.clone()) + } - pub async fn embed(&mut self, embeder: &Embeder) -> Result<(), reqwest::Error> { - self.embeddings = embeder.embed(&self.chunks).await?; + pub async fn embed(&mut self, embeder: &Embeder, metadata: Option>) -> Result<(), reqwest::Error> { + self.embeddings = embeder.embed(&self.chunks, metadata).await?; Ok(()) } @@ -81,7 +91,7 @@ mod tests { let embeder = Embeder::Bert(BertEmbeder::default()); let mut file_embeder = FileEmbeder::new(file_path.to_string_lossy().to_string()); file_embeder.split_into_chunks(&text, 100); - file_embeder.embed(&embeder).await.unwrap(); + file_embeder.embed(&embeder, None).await.unwrap(); assert_eq!(file_embeder.chunks.len(), 5); assert_eq!(file_embeder.embeddings.len(), 5); } @@ -90,7 +100,7 @@ mod tests { async fn test_image_embeder() { let file_path = PathBuf::from("test_files/clip/cat1.jpg"); let embeder = ClipEmbeder::default(); - let emb_data = embeder.embed_image(&file_path).unwrap(); + let emb_data = embeder.embed_image(&file_path, None).unwrap(); assert_eq!(emb_data.embedding.len(), 512); } } diff --git a/src/file_processor/mod.rs b/src/file_processor/mod.rs index 7a211219..3820d572 100644 --- a/src/file_processor/mod.rs +++ b/src/file_processor/mod.rs @@ -6,4 +6,7 @@ pub mod pdf_processor; pub mod markdown_processor; /// This module contains the file processor for text files. -pub mod txt_processor; \ No newline at end of file +pub mod txt_processor; + +/// This module contains the processor to process web links. +pub mod website_processor; \ No newline at end of file diff --git a/src/file_processor/pdf_processor.rs b/src/file_processor/pdf_processor.rs index 86d28b36..9e94599e 100644 --- a/src/file_processor/pdf_processor.rs +++ b/src/file_processor/pdf_processor.rs @@ -1,5 +1,4 @@ use std:: path::PathBuf; -use pdf_extract::OutputError; use anyhow::Error; /// A struct for processing PDF files. @@ -26,6 +25,7 @@ impl PdfProcessor { #[cfg(test)] mod tests { + use pdf_extract::OutputError; use super::*; use std::fs::File; use tempdir::TempDir; diff --git a/src/file_processor/website_processor.rs b/src/file_processor/website_processor.rs new file mode 100644 index 00000000..e5edad07 --- /dev/null +++ b/src/file_processor/website_processor.rs @@ -0,0 +1,177 @@ +use std::collections::{HashMap, HashSet}; + +use anyhow::{Error, Ok}; +use regex::Regex; +use scraper::Selector; +use serde_json::json; +use text_cleaner::clean::Clean; + +use crate::{ + embedding_model::embed::{EmbedData, TextEmbed}, + file_embed::FileEmbeder, +}; + +#[derive(Debug)] +pub struct WebPage { + pub url: String, + pub title: Option, + pub headers: Option>, + pub paragraphs: Option>, + pub codes: Option>, + pub links: Option>, +} + +impl WebPage { + pub async fn embed_webpage(&self, embeder: &T) -> Result, Error>{ + let mut embed_data = Vec::new(); + let paragraph_embeddings = if let Some(paragraphs) = &self.paragraphs { + self.embed_tag::("p", paragraphs.to_vec(), &embeder).await.unwrap_or(Vec::new()) + } else { + Vec::new() + }; + + let header_embeddings = if let Some(headers) = &self.headers { + self.embed_tag::("h1", headers.to_vec(), &embeder).await.unwrap_or(Vec::new()) + } else { + Vec::new() + }; + + let code_embeddings = if let Some(codes) = &self.codes { + self.embed_tag::("code", codes.to_vec(), &embeder).await.unwrap_or(Vec::new()) + } else { + Vec::new() + }; + + embed_data.extend(paragraph_embeddings); + embed_data.extend(header_embeddings); + embed_data.extend(code_embeddings); + Ok(embed_data) + } + + pub async fn embed_tag(&self,tag: &str, tag_content: Vec, embeder: &T) -> Result, Error> { + let mut embed_data = Vec::new(); + for content in tag_content { + let mut file_embeder = FileEmbeder::new(self.url.to_string()); + + let chunks = match file_embeder.split_into_chunks(&content, 1000) { + Some(chunks) => chunks, + None => continue, + }; + + match chunks.len() { + 0 => continue, + _ => (), + } + + let tag_type = match tag { + "h1" => "header", + "h2" => "subheader", + "h3" => "subsubheader", + "p" => "paragraph", + "code" => "code", + _ => "paragraph", + }; + + let metadata = json!({ + "url": self.url, + "type": tag_type, + "full_text": content, + }); + + let metadata_hashmap: HashMap = + serde_json::from_value(metadata).unwrap(); + + + let embeddings = embeder + .embed(&chunks, Some(metadata_hashmap)) + .await + .unwrap_or(Vec::new()); + for embedding in embeddings { + embed_data.push(embedding); + + } + } + Ok(embed_data) + } +} + +/// A struct for processing websites. +pub struct WebsiteProcesor; + + +impl WebsiteProcesor { + pub fn new() -> Self { + Self {} + } + + pub async fn process_website(&self, website: &str) -> Result { + let response = reqwest::get(website).await?.text().await?; + let document = scraper::Html::parse_document(&response); + let headers = self.get_text_from_tag("h1,h2,h3", &document)?; + let paragraphs = self.get_text_from_tag("p", &document)?; + let codes = self.get_text_from_tag("code", &document)?; + let links = self.extract_links(website, &document)?; + let binding = self.get_text_from_tag("h1", &document)?; + let title = binding.first(); + let web_page = WebPage { + url: website.to_string(), + title: title.map(|s| s.to_string()), + headers: Some(headers), + paragraphs: Some(paragraphs), + codes: Some(codes), + links: Some(links), + }; + + Ok(web_page) + } + + pub fn get_text_from_tag( + &self, + tag: &str, + document: &scraper::Html, + ) -> Result, Error> { + let selector = Selector::parse(tag).map_err(|e| Error::msg(e.to_string()))?; + Ok(document + .select(&selector) + .map(|element| element.text().collect::().trim()) + .collect::>()) + } + + pub fn extract_links( + &self, + website: &str, + document: &scraper::Html, + ) -> Result, Error> { + let mut links = HashSet::new(); + let _ = document + .select(&Selector::parse("a").unwrap()) + .map(|element| { + let link = element.value().attr("href").unwrap_or_default().to_string(); + let regex: Regex = Regex::new( + r"^((https?|ftp|smtp):\/\/)?(www.)?[a-z0-9]+\.[a-z]+(\/[a-zA-Z0-9#]+\/?)*$", + ) + .unwrap(); + // Check if the link is a valid URL using regex. If not append the website URL to the beginning of the link. + if !regex.is_match(&link) { + links.insert(format!("{}{}", website, link)); + } else { + links.insert(link); + } + }); + + Ok(links) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_process_website() { + let website_processor = WebsiteProcesor; + let website = "https://www.scrapingbee.com/blog/web-scraping-rust/"; + let result = website_processor.process_website(website); + assert!(result.await.is_ok()); + } +} diff --git a/src/lib.rs b/src/lib.rs index 424b4c50..82622132 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,7 @@ pub mod parser; use std::path::PathBuf; -use embedding_model::embed::{EmbedData, EmbedImage, Embeder}; +use embedding_model::embed::{EmbedData, EmbedImage, Embeder, TextEmbed}; use file_embed::FileEmbeder; use parser::FileParser; use pyo3::{exceptions::PyValueError, prelude::*}; @@ -59,7 +59,7 @@ pub fn embed_query(query: Vec, embeder: &str) -> PyResult }; let runtime = Builder::new_multi_thread().enable_all().build().unwrap(); - let embeddings = runtime.block_on(embedding_model.embed(&query)).unwrap(); + let embeddings = runtime.block_on(embedding_model.embed(&query, None)).unwrap(); Ok(embeddings) } /// Embeds the text from a file using the specified embedding model. @@ -169,12 +169,35 @@ pub fn embed_directory( Ok(embeddings) } +#[pyfunction] +pub fn emb_webpage(url: String, embeder: &str) -> PyResult> { + let website_processor = file_processor::website_processor::WebsiteProcesor::new(); + let runtime = Builder::new_multi_thread().enable_all().build().unwrap(); + let webpage = runtime.block_on(website_processor.process_website(url.as_ref())).unwrap(); + + let embeddings = match embeder { + "OpenAI" => runtime.block_on(webpage.embed_webpage(&embedding_model::openai::OpenAIEmbeder::default())).unwrap(), + "Jina" => runtime.block_on(webpage.embed_webpage(&embedding_model::jina::JinaEmbeder::default())).unwrap(), + "Bert" => runtime.block_on(webpage.embed_webpage(&embedding_model::bert::BertEmbeder::default())).unwrap(), + _ => { + return Err(PyValueError::new_err( + "Invalid embedding model. Choose between OpenAI and AllMiniLmL12V2.", + )) + } + }; + + Ok(embeddings) + +} + #[pymodule] fn embed_anything(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(embed_file, m)?)?; m.add_function(wrap_pyfunction!(embed_directory, m)?)?; m.add_function(wrap_pyfunction!(embed_query, m)?)?; + m.add_function(wrap_pyfunction!(emb_webpage, m)?)?; m.add_class::()?; + Ok(()) } @@ -195,7 +218,7 @@ fn emb_directory( file_embeder.split_into_chunks(&text, 100); let runtime = Builder::new_multi_thread().enable_all().build().unwrap(); runtime - .block_on(file_embeder.embed(&embedding_model)) + .block_on(file_embeder.embed(&embedding_model, None)) .unwrap(); file_embeder.embeddings }) @@ -211,14 +234,14 @@ fn emb_text> (file: T, embedding_model: Embeder) -> Py file_embeder.split_into_chunks(&text, 100); let runtime = Builder::new_multi_thread().enable_all().build().unwrap(); runtime - .block_on(file_embeder.embed(&embedding_model)) + .block_on(file_embeder.embed(&embedding_model, None)) .unwrap(); Ok(file_embeder.embeddings[0].clone()) } fn emb_image, U: EmbedImage>(image_path: T, embedding_model: U) -> PyResult { - let embedding = embedding_model.embed_image(image_path).unwrap(); + let embedding = embedding_model.embed_image(image_path, None).unwrap(); Ok(embedding) } @@ -231,3 +254,4 @@ fn emb_image_directory(directory: PathBuf, embedding_model: T) -> .unwrap(); Ok(embeddings) } + diff --git a/test.py b/test.py index 33b0f784..fe2e3bbd 100644 --- a/test.py +++ b/test.py @@ -4,22 +4,29 @@ from PIL import Image import time -start = time.time() -data= embed_anything.embed_file("test_files/clip/cat1.jpg", embeder= "Clip") +# start = time.time() +# data= embed_anything.embed_file("test_files/clip/cat1.jpg", embeder= "Clip") -embeddings = np.array([data.embedding for data in data]) +# embeddings = np.array([data.embedding for data in data]) -print(data[0]) +# print(data[0]) -query = ["Photo of a dog?"] -query_embedding = np.array(embed_anything.embed_query(query, embeder= "Clip")[0].embedding) +# query = ["Photo of a dog?"] +# query_embedding = np.array(embed_anything.embed_query(query, embeder= "Clip")[0].embedding) -similarities = np.dot(embeddings, query_embedding) +# similarities = np.dot(embeddings, query_embedding) -max_index = np.argmax(similarities) +# max_index = np.argmax(similarities) -# Image.open(data[max_index].text).show() -print(data[max_index].text) -end = time.time() -print("Time taken: ", end-start) \ No newline at end of file +# # Image.open(data[max_index].text).show() +# print(data[max_index].text) +# end = time.time() +# print("Time taken: ", end-start) + + +url = "https://www.akshaymakes.com/blogs/3d_convolution" + +data = embed_anything.emb_webpage(url, embeder= "Bert") + +print(data[0]) \ No newline at end of file