From a6639d0b16f2b3baa022a021c3ca6fda03846e6a Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Tue, 14 Jan 2025 10:44:51 -0500
Subject: [PATCH] Metal support

---
 candle-metal-kernels/src/lib.rs       | 10 ++++++++++
 candle-metal-kernels/src/reduce.metal | 20 ++++++++++++++------
 candle-nn/src/ops.rs                  |  6 ++++--
 candle-nn/tests/ops.rs                |  2 +-
 4 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 9614d14db0..a21b0c9598 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -738,6 +738,7 @@ pub fn call_last_attn_softmax(
     mask: &Buffer,
     mask_offset: usize,
     input_shape: &[usize],
+    mask_shape: &[usize],
     scale: f32,
     ty: SdpaDType,
     output: &Buffer,
@@ -749,6 +750,14 @@ pub fn call_last_attn_softmax(
     let ne02 = input_shape[input_shape.len() - 3] as i64;
     let ne03 = input_shape[input_shape.len() - 4] as i64;
 
+    let elem_per_batch = if mask_shape.len() == 2 {
+        0
+    } else {
+        let bs = input_shape[0];
+        let el: usize = input_shape.iter().product();
+        el / bs
+    };
+
     let mut nth = 32; // SIMD width
     let name = if ne00 % 4 == 0 {
         while nth < ne00 / 4 && nth * ne01 * ne02 * ne03 < 256 {
@@ -784,6 +793,7 @@ pub fn call_last_attn_softmax(
             ne00,
             ne01,
             ne02,
+            elem_per_batch as i64,
             scale
         )
     );
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index db37ead5ec..a6f677bde9 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -827,6 +827,7 @@ kernel void attn_soft_max(
     constant int64_t & ne00,
     constant int64_t & ne01,
     constant int64_t & ne02,
+    constant int64_t & elem_per_batch,
     constant float & scale,
     threadgroup float * buf [[threadgroup(0)]],
     uint tgpig[[threadgroup_position_in_grid]],
@@ -838,9 +839,12 @@ kernel void attn_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const T * psrc0 = (device const T *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
-    device T * pdst = (device T *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    const int64_t src_offset = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    const int64_t b_idx = elem_per_batch > 0 ? src_offset / elem_per_batch : 0;
+    const int64_t mask_offset = b_idx * (ne00*ne01) + i01*ne00;
+    device const T * psrc0 = (device const T *) src0 + src_offset;
+    device const T * pmask = src1 != src0 ? (device const T *) src1 + mask_offset : nullptr;
+    device T * pdst = (device T *) dst + src_offset;
 
     float slope = 1.0f;
 
@@ -916,6 +920,7 @@ kernel void attn_soft_max_4(
     constant int64_t & ne00,
     constant int64_t & ne01,
     constant int64_t & ne02,
+    constant int64_t & elem_per_batch,
     constant float & scale,
     threadgroup float * buf [[threadgroup(0)]],
     uint tgpig[[threadgroup_position_in_grid]],
@@ -927,9 +932,12 @@ kernel void attn_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const T * psrc4 = (device const T *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
-    device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
-    device T * pdst4 = (device T *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    const int64_t src_offset = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    const int64_t b_idx = elem_per_batch > 0 ? src_offset / elem_per_batch : 0;
+    const int64_t mask_offset = b_idx * (ne00*ne01) + i01*ne00;
+    device const T * psrc4 = (device const T *) src0 + src_offset / 4;
+    device const T * pmask = src1 != src0 ? (device const T *) src1 + mask_offset / 4 : nullptr;
+    device T * pdst4 = (device T *) dst + src_offset / 4;
 
     float slope = 1.0f;
 
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index a89904b3c4..6ee685fef2 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -634,6 +634,7 @@ impl candle::InplaceOp2 for AttnSoftmaxLastDim {
                 mask_s.buffer(),
                 mask_l.start_offset() * mask_s.dtype().size_in_bytes(),
                 a_l.dims(),
+                mask_l.dims(),
                 self.scale,
                 ty,
                 &a_s.buffer(),
@@ -713,7 +714,7 @@ impl candle::InplaceOp2 for AttnSoftmaxLastDim {
             0
         } else {
             let bs = dims[0];
-            el / bs;
+            el / bs
         };
 
         let (nrows_x, ncols_x) = (el / dim_m1, dim_m1);
@@ -827,6 +828,7 @@ impl candle::CustomOp2 for AttnSoftmaxLastDim {
                 mask_s.buffer(),
                 mask_l.start_offset() * mask_s.dtype().size_in_bytes(),
                 a_l.dims(),
+                mask_l.dims(),
                 self.scale,
                 ty,
                 &output,
@@ -905,7 +907,7 @@ impl candle::CustomOp2 for AttnSoftmaxLastDim {
             0
         } else {
             let bs = dims[0];
-            el / bs;
+            el / bs
         };
 
         let (nrows_x, ncols_x) = (el / dim_m1, dim_m1);
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index ce2b701c3d..313fcd49db 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-use candle::{test_device, test_utils::to_vec3_round, DType, Device, IndexOp, Result, Tensor};
+use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
 
 fn softmax(device: &Device) -> Result<()> {
     let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
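Note on the indexing scheme above: `elem_per_batch` is the number of input elements per batch entry. It is set to 0 when the mask is 2D, which disables the batch term so one mask is broadcast across all batch entries and heads; otherwise each threadgroup recovers its batch index by dividing its flat source offset by `elem_per_batch` and reads the matching `ne01 x ne00` mask slice. Below is a minimal standalone Rust sketch of that arithmetic, not part of the patch: `mask_offset_for_row` and its parameters are hypothetical names, assuming an input of shape [bs, heads, seq_q, seq_k] and a mask of shape [seq_q, seq_k] (broadcast) or [bs, seq_q, seq_k] (per batch).

// Standalone sketch of the mask-offset arithmetic performed by the patched
// attn_soft_max kernels. The function name and parameters are illustrative
// only; they do not appear in the patch.
fn mask_offset_for_row(input_shape: &[usize], mask_shape: &[usize], row: usize) -> usize {
    let ne00 = input_shape[input_shape.len() - 1]; // softmax length (seq_k)
    let ne01 = input_shape[input_shape.len() - 2]; // rows per head (seq_q)
    // Mirrors the host-side computation: 0 disables the batch term so a 2D
    // mask is shared by every batch entry and head.
    let elem_per_batch = if mask_shape.len() == 2 {
        0
    } else {
        let bs = input_shape[0];
        let el: usize = input_shape.iter().product();
        el / bs
    };
    // Mirrors the kernel: each threadgroup handles one row of ne00 elements.
    let src_offset = row * ne00;
    let b_idx = if elem_per_batch > 0 { src_offset / elem_per_batch } else { 0 };
    let i01 = row % ne01; // row index within the current head
    b_idx * (ne00 * ne01) + i01 * ne00
}

fn main() {
    // Input [bs=2, heads=4, seq_q=3, seq_k=8] with a per-batch [2, 3, 8] mask:
    // flat row 13 lives in batch 1, row 1 of its head, so it reads the second
    // mask slice at element 3*8 + 1*8 = 32.
    assert_eq!(mask_offset_for_row(&[2, 4, 3, 8], &[2, 3, 8], 13), 32);
    // A 2D [3, 8] mask is broadcast: the batch term vanishes.
    assert_eq!(mask_offset_for_row(&[2, 4, 3, 8], &[3, 8], 13), 8);
}

Deriving the batch index from the source offset rather than passing extra strides keeps the kernel signature small and leaves the 2D broadcast path byte-for-byte identical to the previous behavior.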