Expose model & backend informational methods #666

Merged · 2 commits · Feb 21, 2025
15 changes: 15 additions & 0 deletions llama-cpp-2/src/llama_backend.rs
@@ -70,6 +70,21 @@ impl LlamaBackend {
Ok(LlamaBackend {})
}

/// Returns whether the library was built with a GPU backend and a supported device is available.
pub fn supports_gpu_offload(&self) -> bool {
unsafe { llama_cpp_sys_2::llama_supports_gpu_offload() }
}

/// Returns whether this platform supports loading the model via `mmap`.
pub fn supports_mmap(&self) -> bool {
unsafe { llama_cpp_sys_2::llama_supports_mmap() }
}

/// Returns whether this platform supports locking the model in RAM via `mlock`.
pub fn supports_mlock(&self) -> bool {
unsafe { llama_cpp_sys_2::llama_supports_mlock() }
}

/// Change the output of llama.cpp's logging to be voided instead of pushed to `stderr`.
pub fn void_logs(&mut self) {
unsafe extern "C" fn void_log(
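Taken together, these three additions let callers probe platform capabilities before choosing load parameters. A minimal usage sketch, assuming the crate's existing `LlamaBackend::init()` constructor (the layer count of 99 is an arbitrary illustrative value):

```rust
use llama_cpp_2::llama_backend::LlamaBackend;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let backend = LlamaBackend::init()?;

    // Offload layers only when the build/platform actually supports it;
    // 99 here is just an illustrative "offload everything" value.
    let gpu_layers = if backend.supports_gpu_offload() { 99 } else { 0 };

    println!(
        "gpu offload: {}, mmap: {}, mlock: {}, offloading {} layers",
        backend.supports_gpu_offload(),
        backend.supports_mmap(),
        backend.supports_mlock(),
        gpu_layers,
    );
    Ok(())
}
```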
53 changes: 53 additions & 0 deletions llama-cpp-2/src/model.rs
@@ -92,6 +92,15 @@ impl LlamaChatMessage {
}
}

/// The RoPE (rotary position embedding) type used within the model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeType {
Norm,
NeoX,
MRope,
Vision,
}

/// How to determine if we should prepend a bos token to tokens
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
@@ -446,6 +455,50 @@ impl LlamaModel {
unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
}

/// Returns the total size of all the tensors in the model in bytes.
pub fn size(&self) -> u64 {
unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) }
}

/// Returns the number of parameters in the model.
pub fn n_params(&self) -> u64 {
unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) }
}

/// Returns whether the model is a recurrent network (Mamba, RWKV, etc.).
pub fn is_recurrent(&self) -> bool {
unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) }
}

/// Returns the number of layers within the model.
pub fn n_layer(&self) -> u32 {
// It's never possible for this to panic because while the API interface is defined as an int32_t,
// the field it's accessing is a uint32_t.
u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) }).unwrap()
}

/// Returns the number of attention heads within the model.
pub fn n_head(&self) -> u32 {
// It's never possible for this to panic because while the API interface is defined as an int32_t,
// the field it's accessing is a uint32_t.
u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) }).unwrap()
}

/// Returns the rope type of the model.
pub fn rope_type(&self) -> Option<RopeType> {
match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None,
llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
rope_type => {
tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
None
}
}
}

fn get_chat_template_impl(
&self,
capacity: usize,
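On the model side, the new accessors make it easy to log a summary after loading. A sketch under the same hedged assumptions, using the crate's existing `LlamaModel::load_from_file` (the `model.gguf` path and `LlamaModelParams::default()` are placeholders; adjust to your setup):

```rust
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{LlamaModel, RopeType};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let backend = LlamaBackend::init()?;
    let params = LlamaModelParams::default();
    // "model.gguf" is a placeholder path for illustration.
    let model = LlamaModel::load_from_file(&backend, "model.gguf", &params)?;

    println!(
        "size: {} bytes, params: {}, layers: {}, heads: {}, recurrent: {}",
        model.size(),
        model.n_params(),
        model.n_layer(),
        model.n_head(),
        model.is_recurrent(),
    );

    // `rope_type()` returns `None` when the model reports LLAMA_ROPE_TYPE_NONE.
    match model.rope_type() {
        Some(RopeType::Norm) => println!("rope: norm"),
        Some(RopeType::NeoX) => println!("rope: neox"),
        Some(RopeType::MRope) => println!("rope: mrope"),
        Some(RopeType::Vision) => println!("rope: vision"),
        None => println!("rope: none"),
    }
    Ok(())
}
```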