diff --git a/llama-cpp-2/src/llama_backend.rs b/llama-cpp-2/src/llama_backend.rs
index 938356f7..1cc3fa3d 100644
--- a/llama-cpp-2/src/llama_backend.rs
+++ b/llama-cpp-2/src/llama_backend.rs
@@ -70,6 +70,21 @@ impl LlamaBackend {
         Ok(LlamaBackend {})
     }
 
+    /// Was the code built for a GPU backend & is a supported one available.
+    pub fn supports_gpu_offload(&self) -> bool {
+        unsafe { llama_cpp_sys_2::llama_supports_gpu_offload() }
+    }
+
+    /// Does this platform support loading the model via mmap.
+    pub fn supports_mmap(&self) -> bool {
+        unsafe { llama_cpp_sys_2::llama_supports_mmap() }
+    }
+
+    /// Does this platform support locking the model in RAM.
+    pub fn supports_mlock(&self) -> bool {
+        unsafe { llama_cpp_sys_2::llama_supports_mlock() }
+    }
+
     /// Change the output of llama.cpp's logging to be voided instead of pushed to `stderr`.
     pub fn void_logs(&mut self) {
         unsafe extern "C" fn void_log(
diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs
index 8b19c4bb..6425dc79 100644
--- a/llama-cpp-2/src/model.rs
+++ b/llama-cpp-2/src/model.rs
@@ -92,6 +92,15 @@ impl LlamaChatMessage {
     }
 }
 
+/// The Rope type that's used within the model.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum RopeType {
+    Norm,
+    NeoX,
+    MRope,
+    Vision,
+}
+
 /// How to determine if we should prepend a bos token to tokens
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum AddBos {
@@ -446,6 +455,50 @@ impl LlamaModel {
         unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
     }
 
+    /// Returns the total size of all the tensors in the model in bytes.
+    pub fn size(&self) -> u64 {
+        unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) }
+    }
+
+    /// Returns the number of parameters in the model.
+    pub fn n_params(&self) -> u64 {
+        unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) }
+    }
+
+    /// Returns whether the model is a recurrent network (Mamba, RWKV, etc)
+    pub fn is_recurrent(&self) -> bool {
+        unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) }
+    }
+
+    /// Returns the number of layers within the model.
+    pub fn n_layer(&self) -> u32 {
+        // It's never possible for this to panic because while the API interface is defined as an int32_t,
+        // the field it's accessing is a uint32_t.
+        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) }).unwrap()
+    }
+
+    /// Returns the number of attention heads within the model.
+    pub fn n_head(&self) -> u32 {
+        // It's never possible for this to panic because while the API interface is defined as an int32_t,
+        // the field it's accessing is a uint32_t.
+        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) }).unwrap()
+    }
+
+    /// Returns the rope type of the model.
+    pub fn rope_type(&self) -> Option<RopeType> {
+        match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
+            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None,
+            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
+            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
+            llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
+            llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
+            rope_type => {
+                tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
+                None
+            }
+        }
+    }
+
     fn get_chat_template_impl(
         &self,
         capacity: usize,
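
For context, a rough usage sketch of the new accessors follows. It assumes the crate's existing `LlamaBackend::init`, `LlamaModelParams::default`, and `LlamaModel::load_from_file` APIs, and `model.gguf` is a placeholder path; this is illustrative only and not part of the change itself.

```rust
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Initialise the backend and probe the platform capabilities exposed above.
    let backend = LlamaBackend::init()?;
    println!("GPU offload supported: {}", backend.supports_gpu_offload());
    println!("mmap supported:        {}", backend.supports_mmap());
    println!("mlock supported:       {}", backend.supports_mlock());

    // "model.gguf" is a placeholder; substitute a real GGUF model path.
    let params = LlamaModelParams::default();
    let model = LlamaModel::load_from_file(&backend, "model.gguf", &params)?;

    // Query the new model metadata accessors.
    println!("size:      {} bytes", model.size());
    println!("params:    {}", model.n_params());
    println!("layers:    {}", model.n_layer());
    println!("heads:     {}", model.n_head());
    println!("recurrent: {}", model.is_recurrent());
    println!("rope type: {:?}", model.rope_type());

    Ok(())
}
```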