diff --git a/llama-cpp-2/Cargo.toml b/llama-cpp-2/Cargo.toml index a927efc4..9cd9e1ff 100644 --- a/llama-cpp-2/Cargo.toml +++ b/llama-cpp-2/Cargo.toml @@ -18,7 +18,7 @@ tracing = { workspace = true } encoding_rs = { workspace = true } [features] -default = ["openmp"] +default = ["openmp", "android-shared-stdcxx"] cuda = ["llama-cpp-sys-2/cuda"] metal = ["llama-cpp-sys-2/metal"] dynamic-link = ["llama-cpp-sys-2/dynamic-link"] @@ -26,6 +26,8 @@ vulkan = ["llama-cpp-sys-2/vulkan"] native = ["llama-cpp-sys-2/native"] openmp = ["llama-cpp-sys-2/openmp"] sampler = [] +# Only has an impact on Android. +android-shared-stdcxx = ["llama-cpp-sys-2/shared-stdcxx"] [target.'cfg(all(target_os = "macos", any(target_arch = "aarch64", target_arch = "arm64")))'.dependencies] diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs index 85927ec6..3dc02ee9 100644 --- a/llama-cpp-2/src/model.rs +++ b/llama-cpp-2/src/model.rs @@ -1,5 +1,5 @@ //! A safe wrapper around `llama_model`. -use std::ffi::CString; +use std::ffi::{c_char, CString}; use std::num::NonZeroU16; use std::os::raw::c_int; use std::path::Path; @@ -565,7 +565,7 @@ impl LlamaModel { chat.as_ptr(), chat.len(), add_ass, - buff.as_mut_ptr().cast::(), + buff.as_mut_ptr().cast::(), buff.len().try_into().expect("Buffer size exceeds i32::MAX"), ) }; @@ -579,7 +579,7 @@ impl LlamaModel { chat.as_ptr(), chat.len(), add_ass, - buff.as_mut_ptr().cast::(), + buff.as_mut_ptr().cast::(), buff.len().try_into().expect("Buffer size exceeds i32::MAX"), ) }; diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs index b3a2cf4f..d79f351b 100644 --- a/llama-cpp-2/src/sampling.rs +++ b/llama-cpp-2/src/sampling.rs @@ -1,7 +1,7 @@ //! Safe wrapper around `llama_sampler`. use std::borrow::Borrow; -use std::ffi::CString; +use std::ffi::{c_char, CString}; use std::fmt::{Debug, Formatter}; use crate::context::LlamaContext; @@ -20,14 +20,6 @@ impl Debug for LlamaSampler { } } -// this is needed for the dry sampler to typecheck on android -// ...because what is normally an i8, is an u8 -#[cfg(target_os = "android")] -type CChar = u8; - -#[cfg(not(target_os = "android"))] -type CChar = i8; - impl LlamaSampler { /// Sample and accept a token from the idx-th output of the last evaluation #[must_use] @@ -266,7 +258,7 @@ impl LlamaSampler { .into_iter() .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes")) .collect(); - let mut seq_breaker_pointers: Vec<*const CChar> = + let mut seq_breaker_pointers: Vec<*const c_char> = seq_breakers.iter().map(|s| s.as_ptr()).collect(); let sampler = unsafe { diff --git a/llama-cpp-sys-2/Cargo.toml b/llama-cpp-sys-2/Cargo.toml index 42b6f026..efcb7099 100644 --- a/llama-cpp-sys-2/Cargo.toml +++ b/llama-cpp-sys-2/Cargo.toml @@ -71,3 +71,5 @@ dynamic-link = [] vulkan = [] native = [] openmp = [] +# Only has an impact on Android. +shared-stdcxx = [] \ No newline at end of file diff --git a/llama-cpp-sys-2/build.rs b/llama-cpp-sys-2/build.rs index 7fff6bba..2d8e9630 100644 --- a/llama-cpp-sys-2/build.rs +++ b/llama-cpp-sys-2/build.rs @@ -229,22 +229,44 @@ fn main() { config.static_crt(static_crt); } - if target.contains("android") && target.contains("aarch64") { + if target.contains("android") { // build flags for android taken from this doc // https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md let android_ndk = env::var("ANDROID_NDK") .expect("Please install Android NDK and ensure that ANDROID_NDK env variable is set"); + + println!("cargo::rerun-if-env-changed=ANDROID_NDK"); + config.define( "CMAKE_TOOLCHAIN_FILE", format!("{android_ndk}/build/cmake/android.toolchain.cmake"), ); - config.define("ANDROID_ABI", "arm64-v8a"); - config.define("ANDROID_PLATFORM", "android-28"); - config.define("CMAKE_SYSTEM_PROCESSOR", "arm64"); - config.define("CMAKE_C_FLAGS", "-march=armv8.7a"); - config.define("CMAKE_CXX_FLAGS", "-march=armv8.7a"); - config.define("GGML_OPENMP", "OFF"); + if env::var("ANDROID_PLATFORM").is_ok() { + println!("cargo::rerun-if-env-changed=ANDROID_PLATFORM"); + } else { + config.define("ANDROID_PLATFORM", "android-28"); + } + if target.contains("aarch64") { + config.cflag("-march=armv8.7a"); + config.cxxflag("-march=armv8.7a"); + } else if target.contains("armv7") { + config.cflag("-march=armv8.7a"); + config.cxxflag("-march=armv8.7a"); + } else if target.contains("x86_64") { + config.cflag("-march=x86-64"); + config.cxxflag("-march=x86-64"); + } else if target.contains("i686") { + config.cflag("-march=i686"); + config.cxxflag("-march=i686"); + } else { + // Rather than guessing just fail. + panic!("Unsupported Android target {target}"); + } config.define("GGML_LLAMAFILE", "OFF"); + if cfg!(feature = "shared-stdcxx") { + println!("cargo:rustc-link-lib=dylib=stdc++"); + println!("cargo:rustc-link-lib=c++_shared"); + } } if cfg!(feature = "vulkan") { @@ -266,8 +288,13 @@ fn main() { config.define("GGML_CUDA", "ON"); } - if cfg!(feature = "openmp") { + // Android doesn't have OpenMP support AFAICT and openmp is a default feature. Do this here + // rather than modifying the defaults in Cargo.toml just in case someone enables the OpenMP feature + // and tries to build for Android anyway. + if cfg!(feature = "openmp") && !target.contains("android") { config.define("GGML_OPENMP", "ON"); + } else { + config.define("GGML_OPENMP", "OFF"); } // General