diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
index 511a5b6ae2..d4bf2afeda 100644
--- a/candle-core/src/quantized/metal.rs
+++ b/candle-core/src/quantized/metal.rs
@@ -391,7 +391,7 @@ impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
             GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
             GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
             GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
-            GgmlDType::Iq4Xs => candle_metal_kernels::GgmlDType::Q8_0,
+            GgmlDType::Iq4Xs => candle_metal_kernels::GgmlDType::Iq4Xs,
             GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
             GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
             GgmlDType::BF16 => candle_metal_kernels::GgmlDType::F16,
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index 593c6b9bd6..bce0840f84 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -5,7 +5,7 @@ use candle_core::{
     test_utils::to_vec2_round,
     DType, Device, IndexOp, Module, Result, Tensor, Var,
 };
-use quantized::{k_quants, GgmlType};
+use quantized::{iq_quants, k_quants, GgmlType};
 use rand::prelude::*;
 
 const GGML_TEST_SIZE: usize = 32 * 128;
@@ -1117,6 +1117,7 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
         GgmlDType::Q5_0 => 0.001353,
         GgmlDType::Q5_1 => 0.00149,
         GgmlDType::Q8_0 => 0.000092,
+        GgmlDType::Iq4Xs => 0.001903,
         // Not from the ggml repo.
         GgmlDType::Q8K => 0.00065,
@@ -1286,6 +1287,13 @@ quantized_matmul!(
     quantized_matmul_q3k_metal,
     GgmlDType::Q3K
 );
+quantized_matmul!(
+    quantized_matmul_iq4xs_bis,
+    quantized_matmul_iq4xs_cpu,
+    quantized_matmul_iq4xs_cuda,
+    quantized_matmul_iq4xs_metal,
+    GgmlDType::Iq4Xs
+);
 quantized_matmul!(
     quantized_matmul_q4k_bis,
     quantized_matmul_q4k_cpu,
@@ -1394,6 +1402,32 @@ fn quantized_matmul_q4k() -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn quantized_matmul_iq4xs() -> Result<()> {
+    use iq_quants::BlockIQ4xs;
+
+    let cpu = &Device::Cpu;
+    let (m, k, n) = (11, 512, 21);
+    let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
+    assert_eq!(mm.dims(), [m, n]);
+    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
+    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
+    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
+
+    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Iq4Xs)?;
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
+    let mm = rhs.forward(&lhs)?;
+
+    assert_eq!(mm.dims(), [m, n]);
+    let dst = mm.flatten_all()?.to_vec1::<f32>()?;
+    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
+    assert_eq!(dst, [1.442, 1.509, -0.293, 1.631]);
+
+    ggml_matmul_error_test::<BlockIQ4xs>()?;
+
+    Ok(())
+}
+
 #[test]
 fn quantized_matmul_q5k() -> Result<()> {
     use k_quants::BlockQ5K;
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 706e07c1e4..c6710be672 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -2447,6 +2447,7 @@ pub enum GgmlDType {
     F16,
     F32,
     BF16,
+    Iq4Xs,
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -2486,7 +2487,7 @@ pub fn call_quantized_matmul_mv_t(
     let r2: u32 = (ne12 / ne02) as u32;
     let r3: u32 = (ne13 / ne03) as u32;
 
-    let (nth0, nth1, align) = match dtype {
+    let (nth0, nth1, align, mem_size_bytes) = match dtype {
         GgmlDType::Q4_0
         | GgmlDType::Q4_1
         | GgmlDType::Q5_0
@@ -2496,7 +2497,7 @@ pub fn call_quantized_matmul_mv_t(
             let nth0 = 8;
             let nth1 = 8;
             let align = 8;
-            (nth0, nth1, align)
+            (nth0, nth1, align, None)
         }
         GgmlDType::Q2K => {
             // Fixing a bug in Metal for GGML
@@ -2504,38 +2505,44 @@ pub fn call_quantized_matmul_mv_t(
             let nth0 = 2;
             let nth1 = 32;
             let align = 4;
-            (nth0, nth1, align)
+            (nth0, nth1, align, None)
         }
         GgmlDType::Q4K => {
             let nth0 = 4;
             let nth1 = 8;
             let align = 4;
-            (nth0, nth1, align)
+            (nth0, nth1, align, None)
         }
         GgmlDType::Q3K | GgmlDType::Q5K => {
             let nth0 = 2;
             let nth1 = 32;
             let align = 4;
-            (nth0, nth1, align)
+            (nth0, nth1, align, None)
         }
         GgmlDType::Q6K => {
             let nth0 = 2;
             let nth1 = 32;
             let align = 2;
-            (nth0, nth1, align)
+            (nth0, nth1, align, None)
         }
         GgmlDType::F16 | GgmlDType::BF16 | GgmlDType::Q8K => {
             // Original implem uses rows
             let nth0 = 32;
             let nth1 = 1;
             let align = 8;
-            (nth0, nth1, align)
+            (nth0, nth1, align, None)
         }
         GgmlDType::F32 => {
             let nth0 = 32;
             let nth1 = 1;
             let align = 8;
-            (nth0, nth1, align)
+            (nth0, nth1, align, None)
+        }
+        GgmlDType::Iq4Xs => {
+            let nth0 = 4;
+            let nth1 = 16;
+            let align = 4;
+            (nth0, nth1, align, Some(32 * std::mem::size_of::<f32>()))
         }
     };
     let thread_groups_count = MTLSize {
@@ -2564,6 +2571,7 @@ pub fn call_quantized_matmul_mv_t(
         GgmlDType::F16 => "kernel_mul_mv_f16_f32",
         GgmlDType::BF16 => "kernel_mul_mv_bf16_f32",
         GgmlDType::F32 => "kernel_mul_mv_f32_f32",
+        GgmlDType::Iq4Xs => "kernel_mul_mv_iq4_xs_f32",
     };
 
     let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
@@ -2571,6 +2579,10 @@ pub fn call_quantized_matmul_mv_t(
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
 
+    if let Some(mem_size_bytes) = mem_size_bytes {
+        encoder.set_threadgroup_memory_length(0, mem_size_bytes as u64);
+    }
+
     set_params!(
         encoder,
         (
@@ -2672,6 +2684,7 @@ pub fn call_quantized_matmul_mm_t(
         GgmlDType::F16 => "kernel_mul_mm_f16_f32",
         GgmlDType::BF16 => "kernel_mul_mm_bf16_f32",
         GgmlDType::F32 => "kernel_mul_mm_f32_f32",
+        GgmlDType::Iq4Xs => "kernel_mul_mm_iq4_xs_f32",
     };
 
     let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
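Reviewer note, not part of the patch: a minimal sketch of how the new Iq4Xs path can be exercised end to end. It only uses candle APIs that already appear in the tests above (QTensor::quantize, QMatMul::from_qtensor, Module::forward); the shapes, device choice, and main function are illustrative assumptions, not taken from the PR.

// Sketch only: exercises the Iq4Xs quantized matmul added in this patch.
// Shapes are illustrative; k must be a multiple of the IQ4_XS super-block
// size (256) for quantization to succeed.
use candle_core::quantized::{self, GgmlDType};
use candle_core::{Device, Module, Result, Tensor};

fn main() -> Result<()> {
    // Device::Cpu here; Device::new_metal(0)? would instead route through
    // the kernel_mul_mv_iq4_xs_f32 / kernel_mul_mm_iq4_xs_f32 Metal kernels.
    let dev = Device::Cpu;
    let (m, k, n) = (4, 512, 8);
    let lhs = Tensor::randn(0f32, 1.0, (m, k), &dev)?;
    let rhs = Tensor::randn(0f32, 1.0, (n, k), &dev)?;

    // Quantize the (n, k) weight matrix to IQ4_XS, then compute lhs @ rhs^T,
    // mirroring the shape convention of get_random_tensors in the tests.
    let qrhs = quantized::QTensor::quantize(&rhs, GgmlDType::Iq4Xs)?;
    let qmm = quantized::QMatMul::from_qtensor(qrhs)?;
    let out = qmm.forward(&lhs)?;
    assert_eq!(out.dims(), [m, n]);
    Ok(())
}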