Skip to content

Commit

Permalink
Update embedding_model::embed.rs and file_embed.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
akshayballal95 committed Apr 20, 2024
1 parent 26bfa46 commit f6340f0
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/embedding_model/embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub struct EmbedResponse {
pub usage: HashMap<String, usize>,
}
#[pyclass]
#[derive(Deserialize, Debug)]
#[derive(Deserialize, Debug, Clone)]
pub struct EmbedData {
#[pyo3(get, set)]
pub embedding: Vec<f32>,
Expand Down
2 changes: 1 addition & 1 deletion src/file_embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
use super::file_processor::pdf_processor::PdfProcessor;
use std::path::PathBuf;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct FileEmbeder {
pub file: String,
pub chunks: Vec<String>,
Expand Down
50 changes: 30 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,26 +89,19 @@ pub fn embed_query(query: Vec<String>, embeder: &str) -> PyResult<Vec<EmbedData>
/// This will output the embeddings of the file using the OpenAI embedding model.
#[pyfunction]
pub fn embed_file(file_name: &str, embeder: &str) -> PyResult<Vec<EmbedData>> {
let embedding_model = match embeder {
"OpenAI" => Embeder::OpenAI(embedding_model::openai::OpenAIEmbeder::default()),
"Jina" => Embeder::Jina(embedding_model::jina::JinaEmbeder::default()),
"Clip" => Embeder::Clip(embedding_model::clip::ClipEmbeder::default()),
"Bert" => Embeder::Bert(embedding_model::bert::BertEmbeder::default()),
let embeddings = match embeder {
"OpenAI" => emb_text(file_name, Embeder::OpenAI(embedding_model::openai::OpenAIEmbeder::default()))?,
"Jina" => emb_text(file_name, Embeder::Jina(embedding_model::jina::JinaEmbeder::default()))?,
"Bert" => emb_text(file_name, Embeder::Bert(embedding_model::bert::BertEmbeder::default()))?,
"Clip" => emb_image(file_name, embedding_model::clip::ClipEmbeder::default())?,
_ => {
return Err(PyValueError::new_err(
"Invalid embedding model. Choose between OpenAI and AllMiniLmL12V2.",
"Invalid embedding model. Choose between OpenAI and Bert for text files and Clip for image files.",
))
}
};

let mut file_embeder = FileEmbeder::new(file_name.to_string());
let text = file_embeder.extract_text().unwrap();
file_embeder.split_into_chunks(&text, 100);
let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
runtime
.block_on(file_embeder.embed(&embedding_model))
.unwrap();
Ok(file_embeder.embeddings)
Ok(vec![embeddings])
}

/// Embeds the text from files in a directory using the specified embedding model.
Expand Down Expand Up @@ -146,25 +139,25 @@ pub fn embed_directory(
extensions: Option<Vec<String>>,
) -> PyResult<Vec<EmbedData>> {
let embeddings = match embeder {
"OpenAI" => emb(
"OpenAI" => emb_directory(
directory,
Embeder::OpenAI(embedding_model::openai::OpenAIEmbeder::default()),
extensions,
)
.unwrap(),
"Jina" => emb(
"Jina" => emb_directory(
directory,
Embeder::Jina(embedding_model::jina::JinaEmbeder::default()),
extensions,
)
.unwrap(),
"Bert" => emb(
"Bert" => emb_directory(
directory,
Embeder::Bert(embedding_model::bert::BertEmbeder::default()),
extensions,
)
.unwrap(),
"Clip" => emb_image(directory, embedding_model::clip::ClipEmbeder::default()).unwrap(),
"Clip" => emb_image_directory(directory, embedding_model::clip::ClipEmbeder::default())?,

_ => {
return Err(PyValueError::new_err(
Expand All @@ -185,7 +178,7 @@ fn embed_anything(m: &Bound<'_, PyModule>) -> PyResult<()> {
Ok(())
}

fn emb(
fn emb_directory(
directory: PathBuf,
embedding_model: Embeder,
extensions: Option<Vec<String>>,
Expand All @@ -212,7 +205,24 @@ fn emb(
Ok(embeddings)
}

fn emb_image<T: EmbedImage>(directory: PathBuf, embedding_model: T) -> PyResult<Vec<EmbedData>> {
fn emb_text<T: AsRef<std::path::Path>> (file: T, embedding_model: Embeder) -> PyResult<EmbedData> {
let mut file_embeder = FileEmbeder::new(file.as_ref().to_str().unwrap().to_string());
let text = file_embeder.extract_text().unwrap();
file_embeder.split_into_chunks(&text, 100);
let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
runtime
.block_on(file_embeder.embed(&embedding_model))
.unwrap();
Ok(file_embeder.embeddings[0].clone())

}

fn emb_image<T: AsRef<std::path::Path>, U: EmbedImage>(image_path: T, embedding_model: U) -> PyResult<EmbedData> {
let embedding = embedding_model.embed_image(image_path).unwrap();
Ok(embedding)
}

fn emb_image_directory<T: EmbedImage>(directory: PathBuf, embedding_model: T) -> PyResult<Vec<EmbedData>> {
let mut file_parser = FileParser::new();
file_parser.get_image_paths(&directory).unwrap();

Expand Down
3 changes: 2 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import time

start = time.time()
data= embed_anything.embed_directory("test_files", embeder= "Clip")
data= embed_anything.embed_file("test_files/clip/cat1.jpg", embeder= "Clip")


embeddings = np.array([data.embedding for data in data])

Expand Down

0 comments on commit f6340f0

Please sign in to comment.