Update llama.cpp to latest version
nkoppel committed Jan 21, 2025
1 parent 904fbda commit 6942e13
Showing 4 changed files with 24 additions and 65 deletions.
4 changes: 2 additions & 2 deletions llama-cpp-2/src/context.rs
@@ -318,7 +318,7 @@ impl<'model> LlamaContext<'model> {
scale: f32,
) -> Result<(), LlamaLoraAdapterSetError> {
let err_code = unsafe {
llama_cpp_sys_2::llama_lora_adapter_set(
llama_cpp_sys_2::llama_set_adapter_lora(
self.context.as_ptr(),
adapter.lora_adapter.as_ptr(),
scale,
@@ -342,7 +342,7 @@ impl<'model> LlamaContext<'model> {
adapter: &mut LlamaLoraAdapter,
) -> Result<(), LlamaLoraAdapterRemoveError> {
let err_code = unsafe {
llama_cpp_sys_2::llama_lora_adapter_remove(
llama_cpp_sys_2::llama_rm_adapter_lora(
self.context.as_ptr(),
adapter.lora_adapter.as_ptr(),
)
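The two unsafe calls above follow upstream's rename of the LoRA entry points (`llama_lora_adapter_set` → `llama_set_adapter_lora`, `llama_lora_adapter_remove` → `llama_rm_adapter_lora`); the safe API surface is untouched. A minimal caller-side sketch, assuming the wrappers on `LlamaContext` keep their existing names (`lora_adapter_set`, `lora_adapter_remove`) and that their error types convert into `Box<dyn Error>`:

```rust
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::model::LlamaLoraAdapter;

// Sketch only: caller code is unaffected by the C-symbol rename.
// `lora_adapter_set` / `lora_adapter_remove` are the assumed safe wrappers.
fn with_lora(
    ctx: &mut LlamaContext<'_>,
    adapter: &mut LlamaLoraAdapter,
) -> Result<(), Box<dyn std::error::Error>> {
    ctx.lora_adapter_set(adapter, 1.0)?; // apply the adapter at full strength
    // ... run decoding with the adapter active ...
    ctx.lora_adapter_remove(adapter)?;
    Ok(())
}
```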
31 changes: 17 additions & 14 deletions llama-cpp-2/src/model.rs
@@ -31,7 +31,7 @@ pub struct LlamaModel {
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLoraAdapter {
pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_lora_adapter>,
pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
}

/// A Safe wrapper around `llama_chat_message`
@@ -74,6 +74,10 @@ unsafe impl Send for LlamaModel {}
unsafe impl Sync for LlamaModel {}

impl LlamaModel {
pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab {
unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) }
}

/// get the number of tokens the model was trained on
///
/// # Panics
@@ -99,28 +103,28 @@ impl LlamaModel {
/// Get the beginning of stream token.
#[must_use]
pub fn token_bos(&self) -> LlamaToken {
let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.model.as_ptr()) };
let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) };
LlamaToken(token)
}

/// Get the end of stream token.
#[must_use]
pub fn token_eos(&self) -> LlamaToken {
let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.model.as_ptr()) };
let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) };
LlamaToken(token)
}

/// Get the newline token.
#[must_use]
pub fn token_nl(&self) -> LlamaToken {
let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.model.as_ptr()) };
let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) };
LlamaToken(token)
}

/// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
#[must_use]
pub fn is_eog_token(&self, token: LlamaToken) -> bool {
unsafe { llama_cpp_sys_2::llama_token_is_eog(self.model.as_ptr(), token.0) }
unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) }
}

/// Get the decoder start token.
@@ -225,7 +229,7 @@ impl LlamaModel {

let size = unsafe {
llama_cpp_sys_2::llama_tokenize(
self.model.as_ptr(),
self.vocab_ptr(),
c_string.as_ptr(),
c_int::try_from(c_string.as_bytes().len())?,
buffer.as_mut_ptr() as *mut llama_cpp_sys_2::llama_token,
@@ -241,7 +245,7 @@
buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
unsafe {
llama_cpp_sys_2::llama_tokenize(
self.model.as_ptr(),
self.vocab_ptr(),
c_string.as_ptr(),
c_int::try_from(c_string.as_bytes().len())?,
buffer.as_mut_ptr() as *mut llama_cpp_sys_2::llama_token,
@@ -268,7 +272,7 @@
/// If the token type is not known to this library.
#[must_use]
pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.model.as_ptr(), id) };
let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) };
LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
}

@@ -347,7 +351,7 @@ impl LlamaModel {
let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
let size = unsafe {
llama_cpp_sys_2::llama_token_to_piece(
self.model.as_ptr(),
self.vocab_ptr(),
token.0,
buf,
len,
@@ -374,7 +378,7 @@
/// without issue.
#[must_use]
pub fn n_vocab(&self) -> i32 {
unsafe { llama_cpp_sys_2::llama_n_vocab(self.model.as_ptr()) }
unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) }
}

/// The type of vocab the model was trained on.
@@ -384,7 +388,8 @@
/// If llama-cpp emits a vocab type that is not known to this library.
#[must_use]
pub fn vocab_type(&self) -> VocabType {
let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.model.as_ptr()) };
// llama_cpp_sys_2::llama_model_get_vocab
let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) };
VocabType::try_from(vocab_type).expect("invalid vocab type")
}

@@ -479,7 +484,7 @@ impl LlamaModel {

let cstr = CString::new(path)?;
let adapter =
unsafe { llama_cpp_sys_2::llama_lora_adapter_init(self.model.as_ptr(), cstr.as_ptr()) };
unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };

let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;

@@ -548,7 +553,6 @@ impl LlamaModel {

let res = unsafe {
llama_cpp_sys_2::llama_chat_apply_template(
self.model.as_ptr(),
tmpl_ptr,
chat.as_ptr(),
chat.len(),
Expand All @@ -563,7 +567,6 @@ impl LlamaModel {

let res = unsafe {
llama_cpp_sys_2::llama_chat_apply_template(
self.model.as_ptr(),
tmpl_ptr,
chat.as_ptr(),
chat.len(),
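Most of the model.rs churn is mechanical: token and vocab helpers now route through the new `vocab_ptr()` accessor (a thin wrapper over `llama_model_get_vocab`) instead of passing the model pointer to the sys functions, and `llama_chat_apply_template` no longer takes the model pointer at all. Caller-facing behavior should be unchanged; a hedged sketch using only methods visible in this diff (`token_bos`, `token_eos`, `is_eog_token`, `n_vocab`), with everything else assumed:

```rust
use llama_cpp_2::model::LlamaModel;

// Sketch only: the vocab_ptr() routing is an internal detail of the bindings;
// callers keep using the same LlamaModel methods.
fn inspect_vocab(model: &LlamaModel) {
    let bos = model.token_bos();
    let eos = model.token_eos();
    println!("bos = {bos:?}, eos = {eos:?}, n_vocab = {}", model.n_vocab());
    // End-of-generation checks still take a LlamaToken.
    assert!(model.is_eog_token(eos));
}
```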
52 changes: 4 additions & 48 deletions llama-cpp-2/src/sampling.rs
@@ -238,7 +238,7 @@ impl LlamaSampler {

let sampler = unsafe {
llama_cpp_sys_2::llama_sampler_init_grammar(
model.model.as_ptr(),
model.vocab_ptr(),
grammar_str.as_ptr(),
grammar_root.as_ptr(),
)
@@ -264,14 +264,15 @@
) -> Self {
let seq_breakers: Vec<CString> = seq_breakers
.into_iter()
.map(|s| CString::new(s.as_ref()).unwrap())
.map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes"))
.collect();
let mut seq_breaker_pointers: Vec<*const CChar> =
seq_breakers.iter().map(|s| s.as_ptr()).collect();

let sampler = unsafe {
llama_cpp_sys_2::llama_sampler_init_dry(
model.model.as_ptr(),
model.vocab_ptr(),
model.n_ctx_train().try_into().expect("n_ctx_train is greater than two billion"),
multiplier,
base,
allowed_length,
@@ -286,74 +287,29 @@
/// Penalizes tokens for being present in the context.
///
/// Parameters:
/// - ``n_vocab``: [`LlamaModel::n_vocab`]
/// - ``special_eos_id``: [`LlamaModel::token_eos`]
/// - ``linefeed_id``: [`LlamaModel::token_nl`]
/// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
/// - ``penalty_repeat``: 1.0 = disabled
/// - ``penalty_freq``: 0.0 = disabled
/// - ``penalty_present``: 0.0 = disabled
/// - ``penalize_nl``: consider newlines as a repeatable token
/// - ``ignore_eos``: ignore the end-of-sequence token
#[allow(clippy::too_many_arguments)]
#[must_use]
pub fn penalties(
n_vocab: i32,
special_eos_id: i32,
linefeed_id: i32,
penalty_last_n: i32,
penalty_repeat: f32,
penalty_freq: f32,
penalty_present: f32,
penalize_nl: bool,
ignore_eos: bool,
) -> Self {
let sampler = unsafe {
llama_cpp_sys_2::llama_sampler_init_penalties(
n_vocab,
special_eos_id,
linefeed_id,
penalty_last_n,
penalty_repeat,
penalty_freq,
penalty_present,
penalize_nl,
ignore_eos,
)
};
Self { sampler }
}

/// Same as [`Self::penalties`], but with `n_vocab`, `special_eos_id`, and `linefeed_id`
/// initialized from `model`, `penalize_nl = false`, and `ignore_eos = true`.
///
/// Parameters:
/// - ``model``: The model's tokenizer to use to initialize the sampler.
/// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
/// - ``penalty_repeat``: 1.0 = disabled
/// - ``penalty_freq``: 0.0 = disabled
/// - ``penalty_present``: 0.0 = disabled
#[must_use]
pub fn penalties_simple(
model: &LlamaModel,
penalty_last_n: i32,
penalty_repeat: f32,
penalty_freq: f32,
penalty_present: f32,
) -> Self {
Self::penalties(
model.n_vocab(),
model.token_eos().0,
model.token_nl().0,
penalty_last_n,
penalty_repeat,
penalty_freq,
penalty_present,
false,
true,
)
}

/// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
///
/// # Parameters:
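`penalties` drops the vocab, EOS, and newline arguments along with `penalize_nl`/`ignore_eos`, matching the slimmer upstream `llama_sampler_init_penalties`, and `penalties_simple` is removed since it no longer adds anything over the four remaining parameters. A sketch of building a chain against the new signature; `chain_simple` and `greedy` are assumed from the crate's existing sampler API and are not part of this diff:

```rust
use llama_cpp_2::sampling::LlamaSampler;

// Sketch only: penalties() now takes just the four penalty parameters.
fn build_sampler() -> LlamaSampler {
    LlamaSampler::chain_simple([
        // penalize the last 64 tokens; 1.1 repeat penalty;
        // frequency and presence penalties disabled
        LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
        LlamaSampler::greedy(),
    ])
}
```

Callers that previously went through `penalties_simple(model, ...)` should be able to drop the `model` argument and call `penalties` directly.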
2 changes: 1 addition & 1 deletion llama-cpp-sys-2/llama.cpp
