Add metal support
EricLBuehler committed Feb 6, 2025
1 parent 1baec7c commit 25cbfca
Showing 3 changed files with 57 additions and 10 deletions.
candle-core/src/quantized/metal.rs (2 changes: 1 addition & 1 deletion)
@@ -391,7 +391,7 @@ impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
-GgmlDType::Iq4Xs => candle_metal_kernels::GgmlDType::Q8_0,
+GgmlDType::Iq4Xs => candle_metal_kernels::GgmlDType::Iq4Xs,
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
GgmlDType::BF16 => candle_metal_kernels::GgmlDType::F16,
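Note: the one-line fix above matters because the Metal backend selects kernels from the converted dtype, so IQ4_XS tensors were previously dispatched through the Q8_0 path. A minimal sketch of the conversion, using only types visible in this diff (assumes both crates are in scope and the fork's Iq4Xs variant is available):

    use candle_core::quantized::GgmlDType;

    // After the fix, the conversion maps to the dedicated variant.
    let metal_dtype: candle_metal_kernels::GgmlDType = GgmlDType::Iq4Xs.into();
    assert!(matches!(metal_dtype, candle_metal_kernels::GgmlDType::Iq4Xs));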
candle-core/tests/quantized_tests.rs (36 changes: 35 additions & 1 deletion)
@@ -5,7 +5,7 @@ use candle_core::{
test_utils::to_vec2_round,
DType, Device, IndexOp, Module, Result, Tensor, Var,
};
-use quantized::{k_quants, GgmlType};
+use quantized::{iq_quants, k_quants, GgmlType};
use rand::prelude::*;

const GGML_TEST_SIZE: usize = 32 * 128;
@@ -1117,6 +1117,7 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
GgmlDType::Q5_0 => 0.001353,
GgmlDType::Q5_1 => 0.00149,
GgmlDType::Q8_0 => 0.000092,
+GgmlDType::Iq4Xs => 0.001903,

// Not from the ggml repo.
GgmlDType::Q8K => 0.00065,
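For context, these constants are per-dtype bounds on the acceptable matmul error; ggml_matmul_error_test (called at the end of the new test below) compares the quantized result against an f32 reference. A minimal sketch of the kind of metric involved, assuming an RMS-style difference (the exact normalization in the real helper may differ):

    // Root-mean-square difference between a quantized matmul output
    // and its f32 reference; asserted to stay near the table above.
    fn rms(a: &[f32], b: &[f32]) -> f32 {
        let se: f32 = a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum();
        (se / a.len() as f32).sqrt()
    }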
@@ -1286,6 +1287,13 @@ quantized_matmul!(
quantized_matmul_q3k_metal,
GgmlDType::Q3K
);
+quantized_matmul!(
+    quantized_matmul_iq4xs_bis,
+    quantized_matmul_iq4xs_cpu,
+    quantized_matmul_iq4xs_cuda,
+    quantized_matmul_iq4xs_metal,
+    GgmlDType::Iq4Xs
+);
quantized_matmul!(
quantized_matmul_q4k_bis,
quantized_matmul_q4k_cpu,
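For orientation, a hedged sketch of what each quantized_matmul! invocation expands to; the macro body and test sizes here are assumptions (the real definition lives earlier in this file), but the shape is one device-generic function plus per-backend #[test] wrappers:

    macro_rules! quantized_matmul {
        ($fn:ident, $fn_cpu:ident, $fn_cuda:ident, $fn_metal:ident, $dtype:expr) => {
            // Device-generic body; test_matmul and test_device! are assumed helpers.
            fn $fn(device: &Device) -> Result<()> {
                test_matmul(device, (1, 3, 4, 256), $dtype)?;
                Ok(())
            }
            test_device!($fn, $fn_cpu, $fn_cuda, $fn_metal);
        };
    }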
@@ -1394,6 +1402,32 @@ fn quantized_matmul_q4k() -> Result<()> {
Ok(())
}

+#[test]
+fn quantized_matmul_iq4xs() -> Result<()> {
+    use iq_quants::BlockIQ4xs;
+
+    let cpu = &Device::Cpu;
+    let (m, k, n) = (11, 512, 21);
+    let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
+    assert_eq!(mm.dims(), [m, n]);
+    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
+    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
+    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
+
+    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Iq4Xs)?;
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
+    let mm = rhs.forward(&lhs)?;
+
+    assert_eq!(mm.dims(), [m, n]);
+    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
+    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
+    assert_eq!(dst, [1.442, 1.509, -0.293, 1.631]);
+
+    ggml_matmul_error_test::<BlockIQ4xs>()?;
+
+    Ok(())
+}
+
#[test]
fn quantized_matmul_q5k() -> Result<()> {
use k_quants::BlockQ5K;
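The new tests can be filtered by name, e.g. cargo test quantized_matmul_iq4xs runs every matching variant that is compiled in for the host's enabled backends.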
candle-metal-kernels/src/lib.rs (29 changes: 21 additions & 8 deletions)
@@ -2447,6 +2447,7 @@ pub enum GgmlDType {
F16,
F32,
BF16,
+    Iq4Xs,
}

#[allow(clippy::too_many_arguments)]
@@ -2486,7 +2487,7 @@ pub fn call_quantized_matmul_mv_t(
let r2: u32 = (ne12 / ne02) as u32;
let r3: u32 = (ne13 / ne03) as u32;

-let (nth0, nth1, align) = match dtype {
+let (nth0, nth1, align, mem_size_bytes) = match dtype {
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
@@ -2496,46 +2497,52 @@
let nth0 = 8;
let nth1 = 8;
let align = 8;
-(nth0, nth1, align)
+(nth0, nth1, align, None)
}
GgmlDType::Q2K => {
// Fixing a bug in Metal for GGML
// https://github.com/ggerganov/llama.cpp/blob/b8109bc0139f15a5b321909f47510b89dca47ffc/ggml-metal.m#L1576
let nth0 = 2;
let nth1 = 32;
let align = 4;
-(nth0, nth1, align)
+(nth0, nth1, align, None)
}
GgmlDType::Q4K => {
let nth0 = 4;
let nth1 = 8;
let align = 4;
-(nth0, nth1, align)
+(nth0, nth1, align, None)
}
GgmlDType::Q3K | GgmlDType::Q5K => {
let nth0 = 2;
let nth1 = 32;
let align = 4;
-(nth0, nth1, align)
+(nth0, nth1, align, None)
}
GgmlDType::Q6K => {
let nth0 = 2;
let nth1 = 32;
let align = 2;
-(nth0, nth1, align)
+(nth0, nth1, align, None)
}
GgmlDType::F16 | GgmlDType::BF16 | GgmlDType::Q8K => {
// The original implementation uses rows
let nth0 = 32;
let nth1 = 1;
let align = 8;
-(nth0, nth1, align)
+(nth0, nth1, align, None)
}
GgmlDType::F32 => {
let nth0 = 32;
let nth1 = 1;
let align = 8;
-(nth0, nth1, align)
+(nth0, nth1, align, None)
}
+GgmlDType::Iq4Xs => {
+    let nth0 = 4;
+    let nth1 = 16;
+    let align = 4;
+    (nth0, nth1, align, Some(32 * std::mem::size_of::<f32>()))
+}
};
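A hedged note on the new fourth tuple element: IQ4_XS is the only dtype in this match that requests threadgroup (shared) memory, and the reserved size appears to mirror llama.cpp's Metal dispatch for its iq4_xs mat-vec kernel:

    // 32 f32 scratch values per threadgroup, i.e. 32 * 4 = 128 bytes.
    let mem_size_bytes = 32 * std::mem::size_of::<f32>();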
let thread_groups_count = MTLSize {
@@ -2564,13 +2571,18 @@
GgmlDType::F16 => "kernel_mul_mv_f16_f32",
GgmlDType::BF16 => "kernel_mul_mv_bf16_f32",
GgmlDType::F32 => "kernel_mul_mv_f32_f32",
GgmlDType::Iq4Xs => "kernel_mul_mm_iq4_xs_f32",
};

let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);

+if let Some(mem_size_bytes) = mem_size_bytes {
+    encoder.set_threadgroup_memory_length(0, mem_size_bytes as u64);
+}
+
set_params!(
encoder,
(
@@ -2672,6 +2684,7 @@ pub fn call_quantized_matmul_mm_t(
GgmlDType::F16 => "kernel_mul_mm_f16_f32",
GgmlDType::BF16 => "kernel_mul_mm_bf16_f32",
GgmlDType::F32 => "kernel_mul_mm_f32_f32",
GgmlDType::Iq4Xs => "kernel_mul_mm_iq4_xs_f32",
};

let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
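Taken together, the three files make IQ4_XS matmul usable end to end on Metal. A hedged usage sketch mirroring the new CPU test but on a Metal device; the API names are as used in the tests above, and quantizing a Metal-resident tensor is assumed to work after this commit, as the metal test variants suggest:

    use candle_core::quantized::{self, GgmlDType};
    use candle_core::{Device, Tensor};

    fn iq4xs_matmul_metal() -> candle_core::Result<()> {
        let dev = Device::new_metal(0)?;
        let lhs = Tensor::randn(0f32, 1f32, (11, 512), &dev)?;
        let rhs = Tensor::randn(0f32, 1f32, (21, 512), &dev)?;
        // Quantize the weight to IQ4_XS and run x @ w^T via QMatMul.
        let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Iq4Xs)?;
        let rhs = quantized::QMatMul::from_qtensor(rhs)?;
        let mm = rhs.forward(&lhs)?;
        assert_eq!(mm.dims(), [11, 21]);
        Ok(())
    }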
