Merge pull request #65 from utilityai/8-metal-on-mac
attempt to add metal on mac
MarcusDunn authored Feb 25, 2024
2 parents 2e05e66 + b6e0bf7 commit e9e80e2
Showing 5 changed files with 150 additions and 19 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 4 additions & 0 deletions llama-cpp-2/Cargo.toml
@@ -26,6 +26,10 @@ anyhow = "1.0.80"
name = "grammar_bias"
harness = false

[[bench]]
name = "generate"
harness = false

[features]
cublas = ["llama-cpp-sys-2/cublas"]

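The new [[bench]] target mirrors the existing grammar_bias entry: harness = false disables the default libtest benchmark harness so that Criterion can supply its own main function (via criterion_main!) and accept its own command-line flags.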
61 changes: 61 additions & 0 deletions llama-cpp-2/benches/generate.rs
@@ -0,0 +1,61 @@
use anyhow::Context;
use criterion::{Criterion, criterion_group, criterion_main};
use pprof::criterion::{Output, PProfProfiler};
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::{AddBos, LlamaModel};
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;

fn generate(c: &mut Criterion) {
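// One-time setup (outside the measured closure): fetch the GGUF model from
// Hugging Face, initialise the llama.cpp backend, and create a context.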
let api = hf_hub::api::sync::ApiBuilder::new()
.with_progress(true)
.build()
.unwrap();
let file = api
.model("TheBloke/Llama-2-7B-Chat-GGUF".to_string())
.get("llama-2-7b-chat.Q4_K_M.gguf")
.unwrap();
let backend = LlamaBackend::init().unwrap();
let model_params = LlamaModelParams::default();
let model = LlamaModel::load_from_file(&backend, &file, &model_params).unwrap();
let mut ctx = model
.new_context(&backend, LlamaContextParams::default())
.unwrap();

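// Each measured iteration: tokenize the prompt, decode it in one batch,
// then greedily sample up to 50 tokens, one decode call per token.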
c.bench_function("generate 50 tokens", |b| {
b.iter(|| {
let tokens_list = model.str_to_token("Hello, my name is", AddBos::Always).unwrap();
let mut n_ctx = tokens_list.len() as i32;
let mut batch = LlamaBatch::new(512, 1);
let last_index: i32 = (tokens_list.len() - 1) as i32;
for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
let is_last = i == last_index;
batch.add(token, i, &[0], is_last).unwrap();
}
ctx.decode(&mut batch).unwrap();

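// Generation loop: greedily pick the most likely token, append it at the
// next position, and decode again; stop at end-of-stream.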
for _ in 0..50 {
let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
let new_token_id = ctx.sample_token_greedy(candidates_p);
if new_token_id == model.token_eos() {
break;
}
batch.clear();
batch.add(new_token_id, n_ctx, &[0], true).unwrap();
n_ctx += 1;
ctx.decode(&mut batch).unwrap();
}
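// Clear sequence 0 so the next iteration starts from an empty KV cache.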
ctx.clear_kv_cache_seq(0, None, None)
});
});
}

criterion_group!(
name = benches;
config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
targets = generate
);
criterion_main!(benches);
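Because the Criterion config attaches pprof's profiler, cargo bench --bench generate runs the benchmark as usual; passing -- --profile-time <seconds> should additionally emit a flamegraph under target/criterion/<bench>/profile/ (this is the behaviour of pprof's criterion integration, assumed here rather than configured anywhere in this diff).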
1 change: 1 addition & 0 deletions llama-cpp-sys-2/Cargo.toml
@@ -43,3 +43,4 @@ cc = { workspace = true }

[features]
cublas = []

99 changes: 82 additions & 17 deletions llama-cpp-sys-2/build.rs
@@ -7,24 +7,23 @@ fn main() {

let cublas_enabled = env::var("CARGO_FEATURE_CUBLAS").is_ok();

let mut ggml_cuda = if cublas_enabled { Some(cc::Build::new()) } else { None };

if !Path::new("llama.cpp/ggml.c").exists() {
panic!("llama.cpp seems to not be populated, try running `git submodule update --init --recursive` to init.")
}

let mut ggml = cc::Build::new();
let mut ggml_cuda = if cublas_enabled {
Some(cc::Build::new())
} else {
None
};
let mut llama_cpp = cc::Build::new();

ggml.cpp(false);
llama_cpp.cpp(true);

// https://github.com/ggerganov/llama.cpp/blob/a836c8f534ab789b02da149fbdaf7735500bff74/Makefile#L364-L368
if let Some(ggml_cuda) = &mut ggml_cuda {
for lib in ["cuda", "cublas", "cudart", "cublasLt"] {
for lib in [
"cuda", "cublas", "culibos", "cudart", "cublasLt", "pthread", "dl", "rt",
] {
println!("cargo:rustc-link-lib={}", lib);
}
if !ggml_cuda.get_compiler().is_like_msvc() {
@@ -39,23 +38,27 @@ fn main() {
ggml_cuda
.flag_if_supported("-mfp16-format=ieee")
.flag_if_supported("-mno-unaligned-access");
ggml.flag_if_supported("-mfp16-format=ieee")
.flag_if_supported("-mno-unaligned-access");
llama_cpp
.flag_if_supported("-mfp16-format=ieee")
.flag_if_supported("-mno-unaligned-access");
ggml_cuda
.flag_if_supported("-mfp16-format=ieee")
ggml.flag_if_supported("-mfp16-format=ieee")
.flag_if_supported("-mno-unaligned-access");
}

ggml_cuda
.cuda(true)
.flag("-arch=all")
.file("llama.cpp/ggml-cuda.cu");
.file("llama.cpp/ggml-cuda.cu")
.include("llama.cpp");

if ggml_cuda.get_compiler().is_like_msvc() {
ggml_cuda.std("c++14");
} else {
ggml_cuda.std("c++17");
ggml_cuda
.flag("-std=c++11")
.std("c++11");
}

ggml.define("GGML_USE_CUBLAS", None);
@@ -65,22 +68,36 @@

// https://github.com/ggerganov/llama.cpp/blob/191221178f51b6e81122c5bda0fd79620e547d07/Makefile#L133-L141
if cfg!(target_os = "macos") {
assert!(!cublas_enabled, "CUBLAS is not supported on macOS");

println!("cargo:rustc-link-lib=framework=Metal");
println!("cargo:rustc-link-lib=framework=Foundation");
println!("cargo:rustc-link-lib=framework=MetalPerformanceShaders");
println!("cargo:rustc-link-lib=framework=MetalKit");

llama_cpp.define("_DARWIN_C_SOURCE", None);

// https://github.com/ggerganov/llama.cpp/blob/3c0d25c4756742ebf15ad44700fabc0700c638bd/Makefile#L340-L343
llama_cpp.define("GGML_USE_METAL", None);
llama_cpp.define("GGML_USE_ACCELERATE", None);
llama_cpp.define("ACCELERATE_NEW_LAPACK", None);
llama_cpp.define("ACCELERATE_LAPACK_ILP64", None);
println!("cargo:rustc-link-arg=framework=Accelerate");

metal_hack(&mut ggml);
ggml.include("./llama.cpp/ggml-metal.h");
}

if cfg!(target_os = "dragonfly") {
llama_cpp.define("__BSD_VISIBLE", None);
}

if let Some(ggml_cuda) = ggml_cuda {
println!("compiling ggml-cuda");
ggml_cuda.compile("ggml-cuda");
}

if cfg!(target_os = "linux") {
ggml.define("_GNU_SOURCE", None);
}

ggml.std("c17")
ggml.std("c11")
.include("./llama.cpp")
.file("llama.cpp/ggml.c")
.file("llama.cpp/ggml-alloc.c")
.file("llama.cpp/ggml-backend.c")
@@ -89,14 +106,23 @@

llama_cpp
.define("_XOPEN_SOURCE", Some("600"))
.std("c++17")
.include("llama.cpp")
.std("c++11")
.file("llama.cpp/llama.cpp");

if let Some(ggml_cuda) = ggml_cuda {
println!("compiling ggml-cuda");
ggml_cuda.compile("ggml-cuda");
println!("compiled ggml-cuda");
}

println!("compiling ggml");
ggml.compile("ggml");
println!("compiled ggml");

println!("compiling llama");
llama_cpp.compile("llama");
println!("compiled llama");

let header = "llama.cpp/llama.h";

@@ -116,3 +142,42 @@
.write_to_file(out_path.join("bindings.rs"))
.expect("failed to write bindings to file");
}

// courtesy of https://github.com/rustformers/llm
fn metal_hack(build: &mut cc::Build) {
const GGML_METAL_METAL_PATH: &str = "llama.cpp/ggml-metal.metal";
const GGML_METAL_PATH: &str = "llama.cpp/ggml-metal.m";

let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is not defined"));

let ggml_metal_path = {
let ggml_metal_metal = std::fs::read_to_string(GGML_METAL_METAL_PATH)
.expect("Could not read ggml-metal.metal")
.replace('\\', "\\\\")
.replace('\n', "\\n")
.replace('\r', "\\r")
.replace('\"', "\\\"");

let ggml_metal =
std::fs::read_to_string(GGML_METAL_PATH).expect("Could not read ggml-metal.m");

let needle = r#"NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];"#;
if !ggml_metal.contains(needle) {
panic!("ggml-metal.m does not contain the needle to be replaced; the patching logic needs to be reinvestigated. Contact a `llama-cpp-sys-2` developer!");
}

// Replace the runtime read of the file with a compile-time string
let ggml_metal = ggml_metal.replace(
needle,
&format!(r#"NSString * src = @"{ggml_metal_metal}";"#),
);

let patched_ggml_metal_path = out_dir.join("ggml-metal.m");
std::fs::write(&patched_ggml_metal_path, ggml_metal)
.expect("Could not write temporary patched ggml-metal.m");

patched_ggml_metal_path
};

build.file(ggml_metal_path);
}
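A note on metal_hack: stock llama.cpp reads ggml-metal.metal from disk at run time, while the patch above splices the shader source into ggml-metal.m as an Objective-C string literal at build time, so the compiled library is self-contained. The escaping is order-sensitive: existing backslashes must be doubled before the newline, carriage-return, and quote replacements introduce new ones. A minimal standalone sketch of that transformation (hypothetical helper name, not part of this diff):

fn escape_for_objc_literal(src: &str) -> String {
    // Escape backslashes first; the later replacements insert fresh
    // backslashes that must not be doubled again.
    src.replace('\\', "\\\\")
        .replace('\n', "\\n")
        .replace('\r', "\\r")
        .replace('"', "\\\"")
}

fn main() {
    let shader = "kernel void f() {\n    // \"demo\"\n}";
    // Prints a single line that is valid inside ggml-metal.m:
    // NSString * src = @"kernel void f() {\n    // \"demo\"\n}";
    println!("NSString * src = @\"{}\";", escape_for_objc_literal(shader));
}

On the bindgen side, the bindings.rs written to OUT_DIR is conventionally pulled into the sys crate with include!(concat!(env!("OUT_DIR"), "/bindings.rs")); that include would live in the crate's lib.rs and is not part of this diff (assumed here, as standard bindgen practice).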
