Skip to content

Commit

Permalink
tokenizer max limit on input size (huggingface#324)
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikKaum authored Jul 3, 2024
1 parent ca68c08 commit c5f1480
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions core/src/tokenization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use tokenizers::{TruncationDirection, TruncationParams, TruncationStrategy};
use tokio::sync::oneshot;
use tracing::{instrument, Span};

/// Factor applied to `max_input_length` to obtain the maximum number of
/// characters `tokenize_input` accepts; larger inputs are rejected with a
/// validation error.
const MAX_CHAR_MULTIPLIER: usize = 250;

/// Validation
#[derive(Debug, Clone)]
pub struct Tokenization {
Expand Down Expand Up @@ -215,6 +217,7 @@ fn tokenizer_worker(
let _ = response_tx.send(tokenize_input(
inputs,
add_special_tokens,
max_input_length,
None,
default_prompt_clone,
prompt_name,
Expand Down Expand Up @@ -269,9 +272,11 @@ fn prepare_pre_prompt(
Ok(pre_prompt)
}

#[allow(clippy::too_many_arguments)]
fn tokenize_input(
inputs: EncodingInput,
add_special_tokens: bool,
max_input_length: usize,
truncate_params: Option<TruncationParams>,
default_prompt: Option<String>,
prompt_name: Option<String>,
Expand All @@ -280,6 +285,14 @@ fn tokenize_input(
) -> Result<(Option<String>, RawEncoding), TextEmbeddingsError> {
let pre_prompt = prepare_pre_prompt(default_prompt, prompt_name, prompts)?;

let input_chars = inputs.count_chars();
let limit = max_input_length * MAX_CHAR_MULTIPLIER;
if input_chars > limit {
return Err(TextEmbeddingsError::Validation(format!(
"`inputs` must have less than {limit} characters. Given: {input_chars}"
)));
}

let encoding = match inputs {
// encode input
EncodingInput::Single(s) => {
Expand Down Expand Up @@ -359,6 +372,7 @@ fn encode_input(
let (_, encoding) = tokenize_input(
inputs,
true,
max_input_length,
truncate_params,
default_prompt,
prompt_name,
Expand Down Expand Up @@ -404,6 +418,14 @@ impl EncodingInput {
EncodingInput::Ids(v) => v.is_empty(),
}
}

/// Returns the size of this input as used by the input-size limit check.
///
/// Text inputs are measured in `char`s (Unicode scalar values, not bytes);
/// a pair of texts is the sum of both parts. Pre-tokenized ids are measured
/// by the number of ids.
fn count_chars(&self) -> usize {
    match self {
        EncodingInput::Ids(ids) => ids.len(),
        EncodingInput::Single(text) => text.chars().count(),
        EncodingInput::Dual(first, second) => [first, second]
            .iter()
            .map(|part| part.chars().count())
            .sum(),
    }
}
}

impl From<String> for EncodingInput {
Expand Down

0 comments on commit c5f1480

Please sign in to comment.