
Commit

Remove inplace
EricLBuehler committed Feb 15, 2025
1 parent 382118a commit 63324fb
Showing 1 changed file with 0 additions and 193 deletions.
193 changes: 0 additions & 193 deletions candle-nn/src/ops.rs
@@ -1974,188 +1974,6 @@ impl candle::CustomOp2 for MulAndAct {
    }
}

impl candle::InplaceOp2 for MulAndAct {
    fn name(&self) -> &'static str {
        "mul-and-act"
    }

    fn cpu_fwd(
        &self,
        _a_s: &mut CpuStorage,
        _a_l: &Layout,
        _b_s: &CpuStorage,
        _b_l: &Layout,
    ) -> Result<()> {
        candle::bail!("cpu mul-and-act is not implemented");
    }

#[cfg(feature = "metal")]
fn metal_fwd(
&self,
a_s: &mut candle::MetalStorage,
a_l: &Layout,
b_s: &candle::MetalStorage,
b_l: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
use candle_metal_kernels::BufferOffset;
let device = a_s.device();
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();

let elem_count = a_l.shape().elem_count();
if a_l.shape() != b_l.shape() {
candle::bail!(
"a and b shapes must match: {:?} vs {:?}",
a_l.dims(),
b_l.dims()
);
}
if a_s.dtype() != b_s.dtype() {
candle::bail!(
"a and b dtypes must match: {:?} vs {:?}",
a_s.dtype(),
b_s.dtype()
);
}

let output = device.new_buffer(elem_count, a_s.dtype(), "mul-and-act")?;
if a_l.is_contiguous() && b_l.is_contiguous() {
let name = match (a_s.dtype(), self.act) {
(DType::F32, Activation::Gelu) => "mul_act_f32_gelu",
(DType::F32, Activation::Relu) => "mul_act_f32_relu",
(DType::F32, Activation::Silu) => "mul_act_f32_silu",
(DType::F16, Activation::Gelu) => "mul_act_f16_gelu",
(DType::F16, Activation::Relu) => "mul_act_f16_relu",
(DType::F16, Activation::Silu) => "mul_act_f16_silu",
(DType::BF16, Activation::Gelu) => "mul_act_bf16_gelu",
(DType::BF16, Activation::Relu) => "mul_act_bf16_relu",
(DType::BF16, Activation::Silu) => "mul_act_bf16_silu",
(dtype, act) => candle::bail!("Expected dtype one of f32/f16/bf16 ({dtype:?}), activation one of gelu/relu/silu ({act:?}"),
};
candle_metal_kernels::call_mul_and_act_contiguous(
device.metal_device(),
&command_buffer,
kernels,
name,
elem_count,
BufferOffset {
buffer: a_s.buffer(),
offset_in_bytes: a_l.start_offset() * a_s.dtype().size_in_bytes(),
},
BufferOffset {
buffer: b_s.buffer(),
offset_in_bytes: b_l.start_offset() * b_s.dtype().size_in_bytes(),
},
&output,
)
.map_err(candle::Error::wrap)?;
} else {
let name = match (a_s.dtype(), self.act) {
(DType::F32, Activation::Gelu) => "mul_act_f32_strided_gelu",
(DType::F32, Activation::Relu) => "mul_act_f32_strided_relu",
(DType::F32, Activation::Silu) => "mul_act_f32_strided_silu",
(DType::F16, Activation::Gelu) => "mul_act_f16_strided_gelu",
(DType::F16, Activation::Relu) => "mul_act_f16_strided_relu",
(DType::F16, Activation::Silu) => "mul_act_f16_strided_silu",
(DType::BF16, Activation::Gelu) => "mul_act_bf16_strided_gelu",
(DType::BF16, Activation::Relu) => "mul_act_bf16_strided_relu",
(DType::BF16, Activation::Silu) => "mul_act_bf16_strided_silu",
(dtype, act) => candle::bail!("Expected dtype one of f32/f16/bf16 ({dtype:?}), activation one of gelu/relu/silu ({act:?}"),
};
candle_metal_kernels::call_mul_and_act_strided(
device.metal_device(),
&command_buffer,
kernels,
name,
a_l.dims(),
BufferOffset {
buffer: a_s.buffer(),
offset_in_bytes: a_l.start_offset() * a_s.dtype().size_in_bytes(),
},
a_l.stride(),
BufferOffset {
buffer: b_s.buffer(),
offset_in_bytes: b_l.start_offset() * b_s.dtype().size_in_bytes(),
},
b_l.stride(),
&output,
)
.map_err(candle::Error::wrap)?;
}

let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, a_s.dtype());
Ok((newstorage, a_l.shape().clone()))
}

#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
a_s: &mut candle::CudaStorage,
a_l: &Layout,
b_s: &candle::CudaStorage,
b_l: &Layout,
) -> Result<()> {
use candle::cuda::{Map2InPlace, SlicePtrOrNull};
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
use candle::{CudaDevice, WithDType};

struct S {
act: Activation,
lhs_l: Layout,
}
impl Map2InPlace for S {
fn f<T: DeviceRepr + WithDType>(
&self,
lhs: &mut CudaSlice<T>,
shape: &Shape,
rhs: &CudaSlice<T>,
rhs_l: &Layout,
dev: &CudaDevice,
) -> Result<()> {
let dims = shape.dims();
let lhs_l = &self.lhs_l;
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
SlicePtrOrNull::Null
} else {
SlicePtrOrNull::Ptr(
dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?,
)
};
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let name = match self.act {
Activation::Gelu => "mul_act_gelu",
Activation::Silu => "mul_act_silu",
Activation::Relu => "mul_act_relu",
act => candle::bail!("Expected activation one of gelu/relu/silu ({act:?}"),
};
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::MUL_AND_ACT)?;
// SAFETY: Set later by running the kernel.
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, lhs);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
}

use candle::backend::BackendStorage;
let dev = a_s.device().clone();
S {
act: self.act,
lhs_l: a_l.clone(),
}
.map(&mut a_s.slice, a_l.shape(), &b_s.slice, b_l, &dev)?;
Ok(())
}
}

/// Elementwise multiply and activation. The following activations are supported:
/// - `gelu`
/// - `silu`
@@ -2170,14 +1988,3 @@ pub fn mul_and_act(a: &Tensor, b: &Tensor, act: Activation) -> Result<Tensor> {
        a.apply_op2(b, MulAndAct { act })
    }
}

/// In-place elementwise multiply and activation, the counterpart to `mul_and_act`.
pub fn inplace_mul_and_act(a: &mut Tensor, b: &Tensor, act: Activation) -> Result<()> {
    if a.device().is_cpu() || b.device().is_cpu() {
        *a = (a.apply(&act)? * b)?;
    } else {
        a.inplace_op2(b, &MulAndAct { act })?;
    }

    Ok(())
}
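
For context, here is a minimal usage sketch of the out-of-place `mul_and_act` that survives this commit, with the removed in-place form shown commented out. The device choice, shapes, and activation are illustrative assumptions, not taken from the diff; the CPU run assumes `mul_and_act` falls back to `apply` + multiply the way the removed in-place function does above.

use candle::{Device, Result, Tensor};
use candle_nn::ops::mul_and_act;
use candle_nn::Activation;

fn main() -> Result<()> {
    // Illustrative device: the fused kernels target CUDA/Metal,
    // so CPU relies on the fallback path.
    let device = Device::Cpu;
    let a = Tensor::randn(0f32, 1f32, (2, 4), &device)?;
    let b = Tensor::randn(0f32, 1f32, (2, 4), &device)?;

    // Out-of-place fused multiply + activation, still available after this commit.
    let y = mul_and_act(&a, &b, Activation::Silu)?;
    println!("{y}");

    // Before this commit, an in-place variant wrote the result back into `a`:
    // let mut a = a;
    // inplace_mul_and_act(&mut a, &b, Activation::Silu)?;
    Ok(())
}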
