Commit ab7100a

feat: add default prompts

OlivierDehaene committed Jun 27, 2024
1 parent 7c9b7cb commit ab7100a
Showing 7 changed files with 217 additions and 22 deletions.
7 changes: 7 additions & 0 deletions core/src/download.rs
@@ -102,3 +102,10 @@ pub async fn download_st_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {

Err(err)
}

#[instrument(skip_all)]
pub async fn download_new_st_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {
tracing::info!("Downloading `config_sentence_transformers.json`");
let pool_config_path = api.get("config_sentence_transformers.json").await?;
Ok(pool_config_path)
}
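
The new helper mirrors `download_st_config` and fetches `config_sentence_transformers.json`, the file where Sentence Transformers models declare their named prompts. As a rough sketch of what that file contains and how it could be parsed (the `StConfig` struct and `load_st_config` helper below are illustrative, not part of this commit):

```rust
use std::collections::HashMap;
use std::path::Path;

use serde::Deserialize;

/// Prompt-related subset of `config_sentence_transformers.json`.
/// Field names follow the Sentence Transformers convention, e.g.
/// {"prompts": {"query": "query: "}, "default_prompt_name": "query"}.
#[derive(Debug, Deserialize)]
struct StConfig {
    #[serde(default)]
    prompts: Option<HashMap<String, String>>,
    #[serde(default)]
    default_prompt_name: Option<String>,
}

fn load_st_config(path: &Path) -> Result<StConfig, Box<dyn std::error::Error>> {
    let raw = std::fs::read_to_string(path)?;
    Ok(serde_json::from_str(&raw)?)
}
```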
15 changes: 12 additions & 3 deletions core/src/infer.rs
@@ -60,9 +60,10 @@ impl Infer {
&self,
inputs: I,
add_special_tokens: bool,
prompt_name: Option<String>,
) -> Result<RawEncoding, TextEmbeddingsError> {
self.tokenization
.tokenize(inputs.into(), add_special_tokens)
.tokenize(inputs.into(), add_special_tokens, prompt_name)
.await
.map_err(|err| {
let counter = metrics::counter!("te_request_failure", "err" => "tokenization");
@@ -119,6 +120,7 @@ impl Infer {
inputs: I,
truncate: bool,
truncation_direction: TruncationDirection,
prompt_name: Option<String>,
permit: OwnedSemaphorePermit,
) -> Result<AllEmbeddingsInferResponse, TextEmbeddingsError> {
let start_time = Instant::now();
@@ -138,6 +140,7 @@ impl Infer {
inputs,
truncate,
truncation_direction,
prompt_name,
false,
&start_time,
permit,
@@ -172,6 +175,7 @@ impl Infer {
inputs: I,
truncate: bool,
truncation_direction: TruncationDirection,
prompt_name: Option<String>,
permit: OwnedSemaphorePermit,
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
let start_time = Instant::now();
@@ -191,6 +195,7 @@ impl Infer {
inputs,
truncate,
truncation_direction,
prompt_name,
true,
&start_time,
permit,
@@ -225,6 +230,7 @@ impl Infer {
inputs: I,
truncate: bool,
truncation_direction: TruncationDirection,
prompt_name: Option<String>,
normalize: bool,
permit: OwnedSemaphorePermit,
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
@@ -245,6 +251,7 @@ impl Infer {
inputs,
truncate,
truncation_direction,
prompt_name,
true,
&start_time,
permit,
@@ -290,11 +297,13 @@ impl Infer {
Ok(response)
}

#[allow(clippy::too_many_arguments)]
async fn embed<I: Into<EncodingInput> + std::fmt::Debug>(
&self,
inputs: I,
truncate: bool,
truncation_direction: TruncationDirection,
prompt_name: Option<String>,
pooling: bool,
start_time: &Instant,
_permit: OwnedSemaphorePermit,
@@ -315,7 +324,7 @@ impl Infer {
// Tokenization
let encoding = self
.tokenization
.encode(inputs.into(), truncate, truncation_direction)
.encode(inputs.into(), truncate, truncation_direction, prompt_name)
.await
.map_err(|err| {
let counter = metrics::counter!("te_request_failure", "err" => "tokenization");
@@ -381,7 +390,7 @@ impl Infer {
// Tokenization
let encoding = self
.tokenization
.encode(inputs.into(), truncate, truncation_direction)
.encode(inputs.into(), truncate, truncation_direction, None)
.await
.map_err(|err| {
let counter = metrics::counter!("te_request_failure", "err" => "tokenization");
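Every public embedding entry point in `infer.rs` now accepts an optional `prompt_name` and forwards it to tokenization; one internal call site in the final hunk hard-codes `None`. A hedged call-site sketch for the pooled path (the `Infer` instance and permit are assumed to come from the server's state; the argument list is abridged from the hunks above):

```rust
use tokenizers::TruncationDirection;

// Sketch of a caller embedding a single query with the model's
// named "query" prompt; not a verbatim excerpt from the server.
async fn embed_query(
    infer: &Infer,
    permit: tokio::sync::OwnedSemaphorePermit,
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
    infer
        .embed_pooled(
            "What is Deep Learning?".to_string(), // inputs
            true,                                 // truncate
            TruncationDirection::Right,           // truncation_direction
            Some("query".to_string()),            // prompt_name: named ST prompt
            true,                                 // normalize
            permit,                               // concurrency permit
        )
        .await
}
```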
99 changes: 90 additions & 9 deletions core/src/tokenization.rs
@@ -1,5 +1,6 @@
/// Payload tokenization logic
use crate::TextEmbeddingsError;
use std::collections::HashMap;
use tokenizers::tokenizer::Tokenizer;
pub use tokenizers::Encoding as RawEncoding;
use tokenizers::{TruncationDirection, TruncationParams, TruncationStrategy};
@@ -19,6 +20,8 @@ impl Tokenization {
tokenizer: Tokenizer,
max_input_length: usize,
position_offset: usize,
default_prompt_name: Option<String>,
prompts: Option<HashMap<String, String>>,
) -> Self {
tracing::info!("Starting {workers} tokenization workers");

@@ -29,12 +32,16 @@
for _ in 0..workers {
let tokenizer_clone = tokenizer.clone();
let receiver_clone = receiver.clone();
let default_prompt_name_clone = default_prompt_name.clone();
let prompts_clone = prompts.clone();
// Spawn worker
std::thread::spawn(move || {
tokenizer_worker(
tokenizer_clone,
max_input_length,
position_offset,
default_prompt_name_clone,
prompts_clone,
receiver_clone,
)
});
@@ -49,6 +56,7 @@ impl Tokenization {
inputs: EncodingInput,
truncate: bool,
truncation_direction: TruncationDirection,
prompt_name: Option<String>,
) -> Result<ValidEncoding, TextEmbeddingsError> {
// Check if inputs is empty
if inputs.is_empty() {
@@ -66,6 +74,7 @@
inputs,
truncate,
truncation_direction,
prompt_name,
response_sender,
Span::current(),
))
@@ -82,6 +91,7 @@
&self,
inputs: EncodingInput,
add_special_tokens: bool,
prompt_name: Option<String>,
) -> Result<RawEncoding, TextEmbeddingsError> {
// Check if inputs is empty
if inputs.is_empty() {
@@ -98,6 +108,7 @@
.send(TokenizerRequest::Tokenize(
inputs,
add_special_tokens,
prompt_name,
response_sender,
Span::current(),
))
@@ -147,6 +158,8 @@ fn tokenizer_worker(
mut tokenizer: Tokenizer,
max_input_length: usize,
position_offset: usize,
default_prompt_name: Option<String>,
prompts: Option<HashMap<String, String>>,
receiver: async_channel::Receiver<TokenizerRequest>,
) {
// Loop over requests
@@ -156,10 +169,13 @@
inputs,
truncate,
truncation_direction,
prompt_name,
response_tx,
parent_span,
) => {
parent_span.in_scope(|| {
let prompt_name = prompt_name.or(default_prompt_name.clone());

if !response_tx.is_closed() {
// It's possible that the user dropped their request, resulting in a send error.
// We just discard the error
@@ -169,12 +185,22 @@
truncation_direction,
max_input_length,
position_offset,
prompt_name,
prompts.as_ref(),
&mut tokenizer,
));
}
})
}
TokenizerRequest::Tokenize(inputs, add_special_tokens, response_tx, parent_span) => {
TokenizerRequest::Tokenize(
inputs,
add_special_tokens,
prompt_name,
response_tx,
parent_span,
) => {
let prompt_name = prompt_name.or(default_prompt_name.clone());

parent_span.in_scope(|| {
if !response_tx.is_closed() {
// It's possible that the user dropped their request, resulting in a send error.
@@ -183,6 +209,8 @@
inputs,
add_special_tokens,
None,
prompt_name,
prompts.as_ref(),
&mut tokenizer,
));
}
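
In both worker arms the per-request `prompt_name` takes precedence and the configured `default_prompt_name` is only a fallback, via `prompt_name.or(default_prompt_name.clone())`. A standalone sketch of that resolution rule:

```rust
/// Per-request prompt name wins; the server-wide default is a fallback.
fn resolve_prompt_name(request: Option<String>, default: Option<String>) -> Option<String> {
    request.or(default)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn request_overrides_default() {
        let resolved = resolve_prompt_name(Some("query".into()), Some("passage".into()));
        assert_eq!(resolved, Some("query".to_string()));
    }

    #[test]
    fn default_applies_when_request_is_empty() {
        assert_eq!(
            resolve_prompt_name(None, Some("passage".into())),
            Some("passage".to_string())
        );
    }
}
```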
@@ -216,36 +244,80 @@ fn tokenize_input(
inputs: EncodingInput,
add_special_tokens: bool,
truncate_params: Option<TruncationParams>,
prompt_name: Option<String>,
prompts: Option<&HashMap<String, String>>,
tokenizer: &mut Tokenizer,
) -> Result<RawEncoding, TextEmbeddingsError> {
let pre_prompt = if let Some(prompt_name) = prompt_name.as_ref() {
match prompts {
None => {
return Err(TextEmbeddingsError::Validation(format!("`default-prompt-name` is set to `{prompt_name}` but no prompts were found in the Sentence Transformers configuration")));
}
Some(prompts) if !prompts.contains_key(prompt_name) => {
return Err(TextEmbeddingsError::Validation(format!("`default-prompt-name` is set to `{prompt_name}` but it was not found in the Sentence Transformers prompts. Available prompts: {:?}", prompts.keys())));
}
Some(prompts) => prompts.get(prompt_name).cloned(),
}
} else {
None
};

let encoding = match inputs {
// encode input
EncodingInput::Single(s) => tokenizer
.with_truncation(truncate_params)?
.encode::<String>(s, add_special_tokens)?,
EncodingInput::Single(s) => {
let s = if let Some(mut pre_prompt) = pre_prompt {
pre_prompt.push_str(&s);
pre_prompt
} else {
s
};

tokenizer
.with_truncation(truncate_params)?
.encode::<String>(s, add_special_tokens)?
}
EncodingInput::Dual(s1, s2) => {
if pre_prompt.is_some() {
return Err(TextEmbeddingsError::Validation(
"`prompt_name` cannot be set with dual inputs".to_string(),
));
}

tokenizer
.with_truncation(truncate_params)?
.encode::<(String, String)>((s1, s2), add_special_tokens)?
}
// input is encoded -> convert to tokenizers Encoding
EncodingInput::Ids(ids) => {
let text = tokenizer.decode(&ids, false)?;
tokenizer
.with_truncation(truncate_params)?
.encode::<String>(text, false)?
if let Some(mut pre_prompt) = pre_prompt {
let text = tokenizer.decode(&ids, true)?;
pre_prompt.push_str(&text);

tokenizer
.with_truncation(truncate_params)?
.encode::<String>(pre_prompt, false)?
} else {
let text = tokenizer.decode(&ids, false)?;

tokenizer
.with_truncation(truncate_params)?
.encode::<String>(text, false)?
}
}
};
Ok(encoding)
}
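
Concretely, the prompt is plain string prepending before encoding; for pre-tokenized `Ids` input the ids are first decoded back to text (skipping special tokens) so the prompt can be prepended, then the whole string is re-encoded. A minimal illustration of the `Single` case, assuming a prompts map like the one shipped in `config_sentence_transformers.json` (values are illustrative):

```rust
use std::collections::HashMap;

fn main() {
    // Illustrative prompts map; in the server it comes from the model config.
    let prompts: HashMap<String, String> =
        HashMap::from([("query".to_string(), "query: ".to_string())]);

    let input = "What is Deep Learning?".to_string();

    // Same push_str pattern as the `Single` arm of `tokenize_input` above.
    let text = match prompts.get("query") {
        Some(pre_prompt) => {
            let mut s = pre_prompt.clone();
            s.push_str(&input);
            s
        }
        None => input,
    };

    assert_eq!(text, "query: What is Deep Learning?");
}
```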

/// Get input length and optionally truncate it
#[allow(clippy::too_many_arguments)]
fn encode_input(
inputs: EncodingInput,
truncate: bool,
truncation_direction: TruncationDirection,
max_input_length: usize,
position_offset: usize,
prompt_name: Option<String>,
prompts: Option<&HashMap<String, String>>,
tokenizer: &mut Tokenizer,
) -> Result<ValidEncoding, TextEmbeddingsError> {
// Default truncation params
@@ -256,7 +328,14 @@ fn encode_input(
stride: 0,
});

let encoding = tokenize_input(inputs, true, truncate_params, tokenizer)?;
let encoding = tokenize_input(
inputs,
true,
truncate_params,
prompt_name,
prompts,
tokenizer,
)?;
let seq_len = encoding.len();

if seq_len > max_input_length {
@@ -315,12 +394,14 @@ enum TokenizerRequest {
EncodingInput,
bool,
TruncationDirection,
Option<String>,
oneshot::Sender<Result<ValidEncoding, TextEmbeddingsError>>,
Span,
),
Tokenize(
EncodingInput,
bool,
Option<String>,
oneshot::Sender<Result<RawEncoding, TextEmbeddingsError>>,
Span,
),
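
Finally, both request enum variants grow an `Option<String>` slot for the prompt name. Wiring everything together, a hedged construction sketch for the extended `Tokenization::new` (the worker count, lengths, prompt values, and tokenizer path are illustrative; in the server they come from CLI arguments and the downloaded `config_sentence_transformers.json`):

```rust
use std::collections::HashMap;
use tokenizers::Tokenizer;

fn build_tokenization() -> Tokenization {
    let tokenizer = Tokenizer::from_file("tokenizer.json").expect("tokenizer");

    let prompts = Some(HashMap::from([
        ("query".to_string(), "query: ".to_string()),
        ("passage".to_string(), "passage: ".to_string()),
    ]));

    Tokenization::new(
        8,                         // workers
        tokenizer,                 // shared tokenizer
        512,                       // max_input_length
        0,                         // position_offset
        Some("query".to_string()), // default_prompt_name
        prompts,                   // named prompts
    )
}
```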
[Diffs for the remaining four changed files are not shown.]
