Skip to content

Commit

Permalink
Resolve all clippy lints
Browse files Browse the repository at this point in the history
  • Loading branch information
nkoppel committed Jan 21, 2025
1 parent 6942e13 commit 5dacedf
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 67 deletions.
4 changes: 2 additions & 2 deletions llama-cpp-2/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ impl<'model> LlamaContext<'model> {
}
}

/// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to n_ubatch.
/// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to [`Self::n_ubatch`].
#[must_use]
pub fn n_batch(&self) -> u32 {
unsafe { llama_cpp_sys_2::llama_n_batch(self.context.as_ptr()) }
}

/// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to n_batch.
/// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to [`Self::n_batch`].
#[must_use]
pub fn n_ubatch(&self) -> u32 {
unsafe { llama_cpp_sys_2::llama_n_ubatch(self.context.as_ptr()) }
Expand Down
46 changes: 27 additions & 19 deletions llama-cpp-2/src/context/kv_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::num::{NonZeroU8, TryFromIntError};

/// Errors that can occur when attempting to prepare values for the kv cache
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
#[allow(clippy::module_name_repetitions)]
pub enum KvCacheConversionError {
/// Sequence id conversion to i32 failed
#[error("Provided sequence id is too large for a i32")]
Expand Down Expand Up @@ -33,15 +34,16 @@ impl LlamaContext<'_> {
/// Copy the cache from one sequence to another.
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If the either position exceeds
/// the maximum i32 value, no copy is attempted and an `Err` is returned.
/// A `Result` indicating whether the operation was successful.
///
/// # Parameters
///
/// * `src` - The sequence id to copy the cache from.
/// * `dest` - The sequence id to copy the cache to.
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
///
/// # Errors
/// If either position exceeds [`i32::MAX`].
pub fn copy_kv_cache_seq(
&mut self,
src: i32,
Expand All @@ -51,10 +53,10 @@ impl LlamaContext<'_> {
) -> Result<(), KvCacheConversionError> {
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
.map_err(KvCacheConversionError::P0TooLarge)?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
.map_err(KvCacheConversionError::P1TooLarge)?;
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1);
}
Expand All @@ -69,10 +71,12 @@ impl LlamaContext<'_> {
/// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
///
/// # Parameters
///
/// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
///
/// # Errors
/// If the sequence id or either position exceeds [`i32::MAX`].
pub fn clear_kv_cache_seq(
&mut self,
src: Option<u32>,
Expand All @@ -81,13 +85,13 @@ impl LlamaContext<'_> {
) -> Result<bool, KvCacheConversionError> {
let src = src
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::SeqIdTooLarge(e))?;
.map_err(KvCacheConversionError::SeqIdTooLarge)?;
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
.map_err(KvCacheConversionError::P0TooLarge)?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
.map_err(KvCacheConversionError::P1TooLarge)?;
Ok(unsafe { llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1) })
}

Expand Down Expand Up @@ -118,15 +122,17 @@ impl LlamaContext<'_> {
/// - explicitly with [`Self::kv_cache_update`]
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If either position
/// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
/// A `Result` indicating whether the operation was successful.
///
/// # Parameters
///
/// * `seq_id` - The sequence id to update
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
/// * `delta` - The relative position to add to the tokens
///
/// # Errors
/// If either position exceeds [`i32::MAX`].
pub fn kv_cache_seq_add(
&mut self,
seq_id: i32,
Expand All @@ -136,10 +142,10 @@ impl LlamaContext<'_> {
) -> Result<(), KvCacheConversionError> {
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
.map_err(KvCacheConversionError::P0TooLarge)?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
.map_err(KvCacheConversionError::P1TooLarge)?;
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta);
}
Expand All @@ -152,15 +158,17 @@ impl LlamaContext<'_> {
/// - explicitly with [`Self::kv_cache_update`]
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If either position
/// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
/// A `Result` indicating whether the operation was successful.
///
/// # Parameters
///
/// * `seq_id` - The sequence id to update
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
/// * `d` - The factor to divide the positions by
///
/// # Errors
/// If either position exceeds [`i32::MAX`].
pub fn kv_cache_seq_div(
&mut self,
seq_id: i32,
Expand All @@ -170,10 +178,10 @@ impl LlamaContext<'_> {
) -> Result<(), KvCacheConversionError> {
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
.map_err(KvCacheConversionError::P0TooLarge)?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
.map_err(KvCacheConversionError::P1TooLarge)?;
let d = c_int::from(d.get());
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
Ok(())
Expand Down Expand Up @@ -239,7 +247,7 @@ pub struct KVCacheView<'a> {
view: llama_cpp_sys_2::llama_kv_cache_view,
}

impl<'a> KVCacheView<'a> {
impl KVCacheView<'_> {
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
pub fn update(&mut self) {
unsafe {
Expand Down Expand Up @@ -314,7 +322,7 @@ impl<'a> KVCacheView<'a> {
}
}

impl<'a> Drop for KVCacheView<'a> {
impl Drop for KVCacheView<'_> {
fn drop(&mut self) {
unsafe {
llama_cpp_sys_2::llama_kv_cache_view_free(&mut self.view);
Expand Down
24 changes: 19 additions & 5 deletions llama-cpp-2/src/llama_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub struct LlamaBatch {
allocated: usize,
/// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed.
pub(crate) initialized_logits: Vec<i32>,
#[allow(clippy::doc_markdown)]
/// The llama_cpp batch. always initialize by `llama_cpp_sys_2::llama_batch_init(allocated, <unknown>, <unknown>)`
pub(crate) llama_batch: llama_batch,
}
Expand All @@ -20,7 +21,7 @@ pub enum BatchAddError {
/// There was not enough space in the batch to add the token.
#[error("Insufficient Space of {0}")]
InsufficientSpace(usize),
/// Empty buffer is provided for get_one
/// Empty buffer is provided for [`LlamaBatch::get_one`]
#[error("Empty buffer")]
EmptyBuffer,
}
Expand Down Expand Up @@ -152,22 +153,35 @@ impl LlamaBatch {
}
}

/// llama_batch_get_one
/// Return batch for single sequence of tokens starting at pos_0
/// ``llama_batch_get_one``
/// Return batch for single sequence of tokens
///
/// NOTE: this is a helper function to facilitate transition to the new batch API
///
/// # Errors
/// If the provided token buffer is empty.
///
/// # Panics
/// If the number of tokens in ``tokens`` exceeds [`i32::MAX`].
pub fn get_one(tokens: &[LlamaToken]) -> Result<Self, BatchAddError> {
if tokens.is_empty() {
return Err(BatchAddError::EmptyBuffer);
}
let batch = unsafe {
let ptr = tokens.as_ptr() as *mut i32;
llama_cpp_sys_2::llama_batch_get_one(ptr, tokens.len() as i32)
llama_cpp_sys_2::llama_batch_get_one(
ptr,
tokens
.len()
.try_into()
.expect("number of tokens exceeds i32::MAX"),
)
};
let batch = Self {
allocated: 0,
initialized_logits: vec![(tokens.len() - 1) as i32],
initialized_logits: vec![(tokens.len() - 1)
.try_into()
.expect("number of tokens exceeds i32::MAX + 1")],
llama_batch: batch,
};
Ok(batch)
Expand Down
44 changes: 26 additions & 18 deletions llama-cpp-2/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ pub struct LlamaChatMessage {

impl LlamaChatMessage {
/// Create a new `LlamaChatMessage`
///
/// # Errors
/// If either of ``role`` or ``content`` contain null bytes.
pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
Ok(Self {
role: CString::new(role)?,
Expand Down Expand Up @@ -152,17 +155,24 @@ impl LlamaModel {
/// Convert single token to bytes.
///
/// # Errors
///
/// See [`TokenToStringError`] for more information.
///
/// # Panics
/// If a [`TokenToStringError::InsufficientBufferSpace`] error returned by
/// [`Self::token_to_bytes_with_size`] contains a positive nonzero value. This should never
/// happen.
pub fn token_to_bytes(
&self,
token: LlamaToken,
special: Special,
) -> Result<Vec<u8>, TokenToStringError> {
match self.token_to_bytes_with_size(token, 8, special, None) {
Err(TokenToStringError::InsufficientBufferSpace(i)) => {
self.token_to_bytes_with_size(token, -i as usize, special, None)
}
Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_bytes_with_size(
token,
(-i).try_into().expect("Error buffer size is positive"),
special,
None,
),
x => x,
}
}
Expand Down Expand Up @@ -232,7 +242,7 @@ impl LlamaModel {
self.vocab_ptr(),
c_string.as_ptr(),
c_int::try_from(c_string.as_bytes().len())?,
buffer.as_mut_ptr() as *mut llama_cpp_sys_2::llama_token,
buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
buffer_capacity,
add_bos,
true,
Expand All @@ -248,7 +258,7 @@ impl LlamaModel {
self.vocab_ptr(),
c_string.as_ptr(),
c_int::try_from(c_string.as_bytes().len())?,
buffer.as_mut_ptr() as *mut llama_cpp_sys_2::llama_token,
buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
-size,
add_bos,
true,
Expand Down Expand Up @@ -323,18 +333,16 @@ impl LlamaModel {
lstrip: Option<NonZeroU16>,
) -> Result<Vec<u8>, TokenToStringError> {
if token == self.token_nl() {
return Ok(String::from("\n").into_bytes());
return Ok(b"\n".to_vec());
}

// unsure what to do with this in the face of the 'special' arg + attr changes
let attrs = self.token_attr(token);
if attrs.contains(LlamaTokenAttr::Control)
&& (token == self.token_bos() || token == self.token_eos())
{
return Ok(Vec::new());
} else if attrs.is_empty()
if attrs.is_empty()
|| attrs
.intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
|| attrs.contains(LlamaTokenAttr::Control)
&& (token == self.token_bos() || token == self.token_eos())
{
return Ok(Vec::new());
}
Expand Down Expand Up @@ -558,12 +566,12 @@ impl LlamaModel {
chat.len(),
add_ass,
buff.as_mut_ptr().cast::<i8>(),
buff.len() as i32,
buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
)
};

if res > buff.len() as i32 {
buff.resize(res as usize, 0);
if res > buff.len().try_into().expect("Buffer size exceeds i32::MAX") {
buff.resize(res.try_into().expect("res is negative"), 0);

let res = unsafe {
llama_cpp_sys_2::llama_chat_apply_template(
Expand All @@ -572,12 +580,12 @@ impl LlamaModel {
chat.len(),
add_ass,
buff.as_mut_ptr().cast::<i8>(),
buff.len() as i32,
buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
)
};
assert_eq!(res, buff.len() as i32);
assert_eq!(Ok(res), buff.len().try_into());
}
buff.truncate(res as usize);
buff.truncate(res.try_into().expect("res is negative"));
Ok(String::from_utf8(buff)?)
}
}
Expand Down
2 changes: 1 addition & 1 deletion llama-cpp-2/src/model/params/kv_overrides.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ pub struct KvOverrideValueIterator<'a> {
current: usize,
}

impl<'a> Iterator for KvOverrideValueIterator<'a> {
impl Iterator for KvOverrideValueIterator<'_> {
type Item = (CString, ParamOverrideValue);

fn next(&mut self) -> Option<Self::Item> {
Expand Down
9 changes: 6 additions & 3 deletions llama-cpp-2/src/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@ impl Debug for LlamaSampler {
}

// this is needed for the dry sampler to typecheck on android
// ...because what is normally an i8, is an u8
// ...because what is normally an i8, is a u8
#[cfg(target_os = "android")]
type CChar = u8;

#[cfg(not(target_os = "android"))]
type CChar = i8;


impl LlamaSampler {
/// Sample and accept a token from the idx-th output of the last evaluation
#[must_use]
Expand Down Expand Up @@ -129,6 +128,7 @@ impl LlamaSampler {
Self::chain(samplers, false)
}

#[allow(clippy::doc_markdown)]
/// Updates the logits l_i' = l_i/t. When t <= 0.0f, the maximum logit is kept at its original
/// value, the rest are set to -inf
///
Expand Down Expand Up @@ -272,7 +272,10 @@ impl LlamaSampler {
let sampler = unsafe {
llama_cpp_sys_2::llama_sampler_init_dry(
model.vocab_ptr(),
model.n_ctx_train().try_into().expect("n_ctx_train is greater than two billion"),
model
.n_ctx_train()
.try_into()
.expect("n_ctx_train exceeds i32::MAX"),
multiplier,
base,
allowed_length,
Expand Down
Loading

0 comments on commit 5dacedf

Please sign in to comment.