diff --git a/crates/llm-local/src/llama.rs b/crates/llm-local/src/llama.rs
index 1fd66a680..06166e3e2 100644
--- a/crates/llm-local/src/llama.rs
+++ b/crates/llm-local/src/llama.rs
@@ -162,7 +162,6 @@ impl CachedInferencingModel for LlamaModels {
             }
             // Decode the token and add it to the output.
             if let Some(t) = tokenizer.next_token(next_token)? {
-                print!("{}", t);
                 output_text.push_str(&t);
             }
         }
@@ -183,7 +182,7 @@ impl CachedInferencingModel for LlamaModels {
 }
 
 /// Loads a list of SafeTensors file paths from a given model directory and
-/// path to the model index JSON file.
+/// path to the model index JSON file relative to the model folder.
 fn load_safetensors(model_dir: &Path, json_file: &str) -> Result<Vec<PathBuf>> {
     let json_file = model_dir.join(json_file);
     let json_file = std::fs::File::open(json_file)?;
diff --git a/crates/llm-local/src/token_output_stream.rs b/crates/llm-local/src/token_output_stream.rs
index 5d77fbc07..6f082cdd5 100644
--- a/crates/llm-local/src/token_output_stream.rs
+++ b/crates/llm-local/src/token_output_stream.rs
@@ -1,9 +1,9 @@
-/// Implementation for TokenOutputStream Code is borrowed from
-/// https://github.com/huggingface/candle/blob/main/candle-examples/src/token_output_stream.rs
-/// (Commit SHA 4fd00b890036ef67391a9cc03f896247d0a75711)
-///
 /// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
 /// streaming way rather than having to wait for the full decoding.
+///
+/// Implementation borrowed from
+/// https://github.com/huggingface/candle/blob/main/candle-examples/src/token_output_stream.rs
+/// (Commit SHA 4fd00b890036ef67391a9cc03f896247d0a75711)
 pub struct TokenOutputStream {
     tokenizer: tokenizers::Tokenizer,
     tokens: Vec<u32>,
@@ -21,16 +21,10 @@ impl TokenOutputStream {
         }
     }
 
-    fn decode(&self, tokens: &[u32]) -> anyhow::Result<String> {
-        match self.tokenizer.decode(tokens, true) {
-            Ok(str) => Ok(str),
-            Err(err) => anyhow::bail!("cannot decode: {err}"),
-        }
-    }
-
     /// Processes the next token in the sequence, decodes the current token stream,
     /// and returns any newly decoded text.
-    /// https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
+    ///
+    /// Based on the following code: <https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68>
     pub fn next_token(&mut self, token: u32) -> anyhow::Result<Option<String>> {
         let prev_text = if self.tokens.is_empty() {
             String::new()
@@ -50,6 +44,12 @@ impl TokenOutputStream {
         }
     }
 
+    /// Decodes the remaining tokens and returns any new text found.
+    ///
+    /// This function decodes tokens from `self.prev_index` to the end and
+    /// compares it with the previously decoded portion (from `self.prev_index`
+    /// to `self.current_index`). If new text is found, it returns the
+    /// additional part as `Some(String)`. Otherwise, returns `None`.
    pub fn decode_rest(&self) -> anyhow::Result<Option<String>> {
         let prev_text = if self.tokens.is_empty() {
             String::new()
@@ -65,4 +65,11 @@ impl TokenOutputStream {
             Ok(None)
         }
     }
+
+    fn decode(&self, tokens: &[u32]) -> anyhow::Result<String> {
+        match self.tokenizer.decode(tokens, true) {
+            Ok(str) => Ok(str),
+            Err(err) => anyhow::bail!("cannot decode: {err}"),
+        }
+    }
 }
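
For context, a minimal usage sketch of the streaming API this patch touches: llama.rs now appends only the newly decoded text returned by `next_token` instead of printing each fragment to stdout, and `decode_rest` flushes anything still buffered once generation stops. The helper name `stream_decode` and the `TokenOutputStream::new(tokenizer)` constructor are assumptions for illustration (the constructor comes from the borrowed candle code and is not shown in this diff).

// Reviewer sketch, not part of the patch: how the streaming decode loop is
// expected to consume TokenOutputStream after this change.
use crate::token_output_stream::TokenOutputStream;

fn stream_decode(
    tokenizer: tokenizers::Tokenizer,
    sampled_tokens: &[u32],
) -> anyhow::Result<String> {
    // `TokenOutputStream::new` is assumed from the borrowed candle example.
    let mut stream = TokenOutputStream::new(tokenizer);
    let mut output_text = String::new();

    for &token in sampled_tokens {
        // `next_token` yields text only once the pending tokens decode to a
        // clean fragment, so the caller appends complete pieces rather than
        // printing raw tokens as they arrive.
        if let Some(text) = stream.next_token(token)? {
            output_text.push_str(&text);
        }
    }

    // Flush whatever is still buffered after the last sampled token.
    if let Some(rest) = stream.decode_rest()? {
        output_text.push_str(&rest);
    }

    Ok(output_text)
}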