diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 27b911ef2..6cd3b15a6 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1974,188 +1974,6 @@ impl candle::CustomOp2 for MulAndAct { } } -impl candle::InplaceOp2 for MulAndAct { - fn name(&self) -> &'static str { - "mul-and-act" - } - - fn cpu_fwd( - &self, - _a_s: &mut CpuStorage, - _a_l: &Layout, - _mask_s: &CpuStorage, - _mask_l: &Layout, - ) -> Result<()> { - candle::bail!("cpu mul-and-act is not implemented"); - } - - #[cfg(feature = "metal")] - fn metal_fwd( - &self, - a_s: &mut candle::MetalStorage, - a_l: &Layout, - b_s: &candle::MetalStorage, - b_l: &Layout, - ) -> Result<(candle::MetalStorage, Shape)> { - use candle::backend::BackendStorage; - use candle_metal_kernels::BufferOffset; - let device = a_s.device(); - let command_buffer = device.command_buffer()?; - let kernels = device.kernels(); - - let elem_count = a_l.shape().elem_count(); - if a_l.shape() != b_l.shape() { - candle::bail!( - "a and b shapes must match: {:?} vs {:?}", - a_l.dims(), - b_l.dims() - ); - } - if a_s.dtype() != b_s.dtype() { - candle::bail!( - "a and b dtypes must match: {:?} vs {:?}", - a_s.dtype(), - b_s.dtype() - ); - } - - let output = device.new_buffer(elem_count, a_s.dtype(), "mul-and-act")?; - if a_l.is_contiguous() && b_l.is_contiguous() { - let name = match (a_s.dtype(), self.act) { - (DType::F32, Activation::Gelu) => "mul_act_f32_gelu", - (DType::F32, Activation::Relu) => "mul_act_f32_relu", - (DType::F32, Activation::Silu) => "mul_act_f32_silu", - (DType::F16, Activation::Gelu) => "mul_act_f16_gelu", - (DType::F16, Activation::Relu) => "mul_act_f16_relu", - (DType::F16, Activation::Silu) => "mul_act_f16_silu", - (DType::BF16, Activation::Gelu) => "mul_act_bf16_gelu", - (DType::BF16, Activation::Relu) => "mul_act_bf16_relu", - (DType::BF16, Activation::Silu) => "mul_act_bf16_silu", - (dtype, act) => candle::bail!("Expected dtype one of f32/f16/bf16 ({dtype:?}), activation one of gelu/relu/silu ({act:?}"), - }; - candle_metal_kernels::call_mul_and_act_contiguous( - device.metal_device(), - &command_buffer, - kernels, - name, - elem_count, - BufferOffset { - buffer: a_s.buffer(), - offset_in_bytes: a_l.start_offset() * a_s.dtype().size_in_bytes(), - }, - BufferOffset { - buffer: b_s.buffer(), - offset_in_bytes: b_l.start_offset() * b_s.dtype().size_in_bytes(), - }, - &output, - ) - .map_err(candle::Error::wrap)?; - } else { - let name = match (a_s.dtype(), self.act) { - (DType::F32, Activation::Gelu) => "mul_act_f32_strided_gelu", - (DType::F32, Activation::Relu) => "mul_act_f32_strided_relu", - (DType::F32, Activation::Silu) => "mul_act_f32_strided_silu", - (DType::F16, Activation::Gelu) => "mul_act_f16_strided_gelu", - (DType::F16, Activation::Relu) => "mul_act_f16_strided_relu", - (DType::F16, Activation::Silu) => "mul_act_f16_strided_silu", - (DType::BF16, Activation::Gelu) => "mul_act_bf16_strided_gelu", - (DType::BF16, Activation::Relu) => "mul_act_bf16_strided_relu", - (DType::BF16, Activation::Silu) => "mul_act_bf16_strided_silu", - (dtype, act) => candle::bail!("Expected dtype one of f32/f16/bf16 ({dtype:?}), activation one of gelu/relu/silu ({act:?}"), - }; - candle_metal_kernels::call_mul_and_act_strided( - device.metal_device(), - &command_buffer, - kernels, - name, - a_l.dims(), - BufferOffset { - buffer: a_s.buffer(), - offset_in_bytes: a_l.start_offset() * a_s.dtype().size_in_bytes(), - }, - a_l.stride(), - BufferOffset { - buffer: b_s.buffer(), - offset_in_bytes: b_l.start_offset() * b_s.dtype().size_in_bytes(), - }, - b_l.stride(), - &output, - ) - .map_err(candle::Error::wrap)?; - } - - let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, a_s.dtype()); - Ok((newstorage, a_l.shape().clone())) - } - - #[cfg(feature = "cuda")] - fn cuda_fwd( - &self, - a_s: &mut candle::CudaStorage, - a_l: &Layout, - b_s: &candle::CudaStorage, - b_l: &Layout, - ) -> Result<()> { - use candle::cuda::{Map2InPlace, SlicePtrOrNull}; - use candle::cuda_backend::cudarc::driver::{ - CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, - }; - use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr}; - use candle::{CudaDevice, WithDType}; - - struct S { - act: Activation, - lhs_l: Layout, - } - impl Map2InPlace for S { - fn f( - &self, - lhs: &mut CudaSlice, - shape: &Shape, - rhs: &CudaSlice, - rhs_l: &Layout, - dev: &CudaDevice, - ) -> Result<()> { - let dims = shape.dims(); - let lhs_l = &self.lhs_l; - let elem_count = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(elem_count as u32); - let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() { - SlicePtrOrNull::Null - } else { - SlicePtrOrNull::Ptr( - dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat()) - .w()?, - ) - }; - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - let name = match self.act { - Activation::Gelu => "mul_act_gelu", - Activation::Silu => "mul_act_silu", - Activation::Relu => "mul_act_relu", - act => candle::bail!("Expected activation one of gelu/relu/silu ({act:?}"), - }; - let func = dev.get_or_load_func(&kernel_name::(name), kernels::MUL_AND_ACT)?; - // SAFETY: Set later by running the kernel. - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, lhs); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }.w()?; - Ok(()) - } - } - - use candle::backend::BackendStorage; - let dev = a_s.device().clone(); - S { - act: self.act, - lhs_l: a_l.clone(), - } - .map(&mut a_s.slice, a_l.shape(), &b_s.slice, b_l, &dev)?; - Ok(()) - } -} - /// Elementwise multiply and activation. The following activations are supported: /// - `gelu` /// - `silu` @@ -2170,14 +1988,3 @@ pub fn mul_and_act(a: &Tensor, b: &Tensor, act: Activation) -> Result { a.apply_op2(b, MulAndAct { act }) } } - -/// Inplace elementwise multiply and activation, the counterpart to `mul_and_act`. -pub fn inplace_mul_and_act(a: &mut Tensor, b: &Tensor, act: Activation) -> Result<()> { - if a.device().is_cpu() || b.device().is_cpu() { - *a = (a.apply(&act)? * b)?; - } else { - a.inplace_op2(b, &MulAndAct { act })?; - } - - Ok(()) -}