From c5f14804733d24aa6b850f2450e325a666734766 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Erik=20Kaunism=C3=A4ki?=
Date: Wed, 3 Jul 2024 15:03:27 +0200
Subject: [PATCH] tokenizer max limit on input size (#324)

Reject inputs whose size exceeds `max_input_length * MAX_CHAR_MULTIPLIER`
before they are handed to the tokenizer. The limit is computed with a
saturating multiply so that a very large `max_input_length` cannot overflow
`usize` (panic in debug builds, silent wrap-around in release builds).

---
 core/src/tokenization.rs | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs
index c33bfafc..fece3a98 100644
--- a/core/src/tokenization.rs
+++ b/core/src/tokenization.rs
@@ -7,6 +7,10 @@ use tokenizers::{TruncationDirection, TruncationParams, TruncationStrategy};
 use tokio::sync::oneshot;
 use tracing::{instrument, Span};
 
+/// Multiplier applied to `max_input_length` to obtain the largest input size
+/// (see `EncodingInput::count_chars`) accepted by `tokenize_input`.
+const MAX_CHAR_MULTIPLIER: usize = 250;
+
 /// Validation
 #[derive(Debug, Clone)]
 pub struct Tokenization {
@@ -215,6 +219,7 @@ fn tokenizer_worker(
                         let _ = response_tx.send(tokenize_input(
                             inputs,
                             add_special_tokens,
+                            max_input_length,
                             None,
                             default_prompt_clone,
                             prompt_name,
@@ -269,9 +274,11 @@ fn prepare_pre_prompt(
     Ok(pre_prompt)
 }
 
+#[allow(clippy::too_many_arguments)]
 fn tokenize_input(
     inputs: EncodingInput,
     add_special_tokens: bool,
+    max_input_length: usize,
     truncate_params: Option,
     default_prompt: Option,
     prompt_name: Option,
@@ -280,6 +287,16 @@ fn tokenize_input(
 ) -> Result<(Option, RawEncoding), TextEmbeddingsError> {
     let pre_prompt = prepare_pre_prompt(default_prompt, prompt_name, prompts)?;
 
+    // Pre-check before any encoding work; an input of exactly `limit` is accepted.
+    // Saturating multiply: a huge `max_input_length` must not overflow `usize`.
+    let input_chars = inputs.count_chars();
+    let limit = max_input_length.saturating_mul(MAX_CHAR_MULTIPLIER);
+    if input_chars > limit {
+        return Err(TextEmbeddingsError::Validation(format!(
+            "`inputs` must have less than {limit} characters. Given: {input_chars}"
+        )));
+    }
+
     let encoding = match inputs {
         // encode input
         EncodingInput::Single(s) => {
@@ -359,6 +376,7 @@ fn encode_input(
     let (_, encoding) = tokenize_input(
         inputs,
         true,
+        max_input_length,
         truncate_params,
         default_prompt,
         prompt_name,
@@ -404,6 +422,16 @@ impl EncodingInput {
             EncodingInput::Ids(v) => v.is_empty(),
         }
     }
+
+    /// Size used for the input-length pre-check: `char` count for text inputs
+    /// (both strings summed for `Dual`), number of elements for `Ids`.
+    fn count_chars(&self) -> usize {
+        match self {
+            EncodingInput::Single(s) => s.chars().count(),
+            EncodingInput::Dual(s1, s2) => s1.chars().count() + s2.chars().count(),
+            EncodingInput::Ids(v) => v.len(),
+        }
+    }
 }
 
 impl From for EncodingInput {