Skip to content

Commit

Permalink
Remove apply_mask_scale
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Jan 14, 2025
1 parent f9d1d50 commit 072c715
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 223 deletions.
49 changes: 2 additions & 47 deletions candle-kernels/src/reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ __device__ void attn_soft_max(const T * x, const T * mask, T * dst, const int nc

const int64_t ix = (int64_t)rowx*ncols + col;

const int64_t b_idx = elem_per_batch > 0 ? ix / elem_per_batch : 1;
const int64_t b_idx = elem_per_batch > 0 ? ix / elem_per_batch : 0;
const int64_t iy = (int64_t)b_idx * (ncols*nrows_y) + rowy*ncols + col;

const float val = float(x[ix]) * scale + (mask ? float(mask[iy]) : 0.0f);
Expand Down Expand Up @@ -328,34 +328,6 @@ __device__ void attn_soft_max(const T * x, const T * mask, T * dst, const int nc
}
}

// Fused scale + mask add, one block per row of `x`:
//   dst[ix] = T(float(x[ix]) * scale + float(mask[iy]))      (mask term is 0 when `mask` is null)
//
// - `x` / `dst`: indexed as rows of `ncols` elements; blockIdx.x selects the row.
// - `mask`: indexed as `b_idx * (ncols * nrows_y) + rowy * ncols + col`, i.e. one
//   (nrows_y x ncols) slab per batch index, reused for every row with the same `rowx % nrows_y`.
// - `elem_per_batch`: number of `x` elements sharing one mask slab; used to derive `b_idx`.
//
// Threads of the block walk the row's columns with a stride of blockDim.x.
template <typename T>
__device__ void apply_mask_scale(const T * x, const T * mask, T * dst, const int ncols, const int nrows_y, const int elem_per_batch, const float scale) {
    const int tid = threadIdx.x;
    const int rowx = blockIdx.x;
    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension

    const int block_size = blockDim.x;

    #pragma unroll
    for (int col = tid; col < ncols; col += block_size) {
        const int64_t ix = (int64_t)rowx*ncols + col;

        // Guard the division: a non-positive elem_per_batch falls back to batch 0
        // (same convention as attn_soft_max) instead of dividing by zero.
        const int64_t b_idx = elem_per_batch > 0 ? ix / elem_per_batch : 0;
        // Widen before multiplying so ncols*nrows_y cannot overflow 32-bit int.
        const int64_t iy = b_idx * ((int64_t)ncols*nrows_y) + (int64_t)rowy*ncols + col;

        const float val = float(x[ix]) * scale + (mask ? float(mask[iy]) : 0.0f);
        dst[ix] = T(val);
    }
}

template <typename T>
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
Expand Down Expand Up @@ -674,20 +646,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
const float scale \
) { \
attn_soft_max<TYPENAME>(x, mask, dst, ncols, nrows_y, elem_per_batch, scale); \
}

// MASK_SCALE_OP(TYPENAME, FN_NAME)
// Defines an `extern "C" __global__` kernel named FN_NAME whose body forwards all of its
// arguments unchanged to the device template apply_mask_scale<TYPENAME>.
// The final line keeps a trailing line-continuation, so the line following the macro
// must stay blank.
#define MASK_SCALE_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME * x, \
const TYPENAME * mask, \
TYPENAME * dst, \
const int ncols, \
const int nrows_y, \
const int elem_per_batch, \
const float scale \
) { \
apply_mask_scale<TYPENAME>(x, mask, dst, ncols, nrows_y, elem_per_batch, scale); \
} \
} \

#define RMSNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
Expand Down Expand Up @@ -739,7 +698,6 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
#include "cuda_bf16.h"
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
ATTN_SOFTMAX_OP(__nv_bfloat16, attn_soft_max_bf16)
MASK_SCALE_OP(__nv_bfloat16, mask_scale_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
Expand All @@ -758,7 +716,6 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
ATTN_SOFTMAX_OP(__half, attn_soft_max_f16)
MASK_SCALE_OP(__half, mask_scale_f16)
RMSNORM_OP(__half, rmsnorm_f16)
LAYERNORM_OP(__half, layernorm_f16)
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
Expand All @@ -773,8 +730,6 @@ SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
ATTN_SOFTMAX_OP(float, attn_soft_max_f32)
ATTN_SOFTMAX_OP(double, attn_soft_max_f64)
MASK_SCALE_OP(float, mask_scale_f32)
MASK_SCALE_OP(double, mask_scale_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
LAYERNORM_OP(float, layernorm_f32)
Expand Down
149 changes: 0 additions & 149 deletions candle-nn/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -980,155 +980,6 @@ pub fn inplace_attn_softmax_last_dim(xs: &mut Tensor, mask: &Tensor, scale: f32)
Ok(())
}

/// In-place binary op that rewrites its first operand using a mask and `scale`
/// by launching the `mask_scale` CUDA kernel.
// `scale` is only read inside `cuda_fwd`, which is compiled out without the `cuda` feature.
#[allow(dead_code)]
struct ApplyMaskScale {
    scale: f32,
}

impl candle::InplaceOp2 for ApplyMaskScale {
    fn name(&self) -> &'static str {
        "apply-mask-scale-inplace"
    }

    /// Always fails: this op has no CPU implementation.
    fn cpu_fwd(
        &self,
        _a_s: &mut CpuStorage,
        _a_l: &Layout,
        _mask_s: &CpuStorage,
        _mask_l: &Layout,
    ) -> Result<()> {
        candle::bail!("cpu apply-mask-scale-inplace is not implemented");
    }

    /// Validates the layouts, then launches the `mask_scale` kernel with `a_s` as both
    /// input and output.
    ///
    /// # Errors
    /// Fails when either operand is non-contiguous, when `a_l` is not rank 4, when
    /// `mask_l` is not rank 3, when the last two dims or the leading (batch) dim of the
    /// two layouts differ, when a dimension does not fit the kernel's `i32` / `u32`
    /// arguments, or when loading / launching the kernel fails.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        a_s: &mut candle::CudaStorage,
        a_l: &Layout,
        mask_s: &candle::CudaStorage,
        mask_l: &Layout,
    ) -> Result<()> {
        use candle::backend::BackendStorage;

        use candle::cuda::Map2InPlace;
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        };
        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
        use candle::{CudaDevice, WithDType};

        if !a_l.is_contiguous() {
            candle::bail!("Non contiguous xs for apply-mask-scale is not implemented");
        }
        if !mask_l.is_contiguous() {
            candle::bail!("Non contiguous mask for apply-mask-scale is not implemented");
        }

        if a_l.dims().len() != 4 {
            // The message previously said "rank 2", contradicting the check above.
            candle::bail!("apply-mask-scale expects xs of rank 4");
        }
        if mask_l.dims().len() != 3 {
            candle::bail!("apply-mask-scale expects mask of rank 3");
        }
        if mask_l.dim(D::Minus1)? != a_l.dim(D::Minus1)?
            || mask_l.dim(D::Minus2)? != a_l.dim(D::Minus2)?
        {
            candle::bail!("apply-mask-scale expects last 2 dims to match xs last 2 dims");
        }
        if mask_l.dim(0)? != a_l.dim(0)? {
            candle::bail!("apply-mask-scale expects mask bs to match xs bs");
        }

        /// Dtype-generic launcher; carries the scale and the layout of the in-place operand.
        struct S<'a> {
            scale: f32,
            a_l: &'a Layout,
        }
        impl Map2InPlace for S<'_> {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                a_s: &mut CudaSlice<T>,
                _a_shape: &Shape,
                mask_s: &CudaSlice<T>,
                mask_l: &Layout,
                dev: &CudaDevice,
            ) -> Result<()> {
                let el = self.a_l.shape().elem_count();
                if el == 0 {
                    // Nothing to compute. Returning early also avoids the division by a zero
                    // batch size / last dim below and a launch with an empty grid.
                    return Ok(());
                }

                let a = match self.a_l.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => a_s.slice(o1..o2),
                };
                let mask = match mask_l.contiguous_offsets() {
                    None => candle::bail!("mask has to be contiguous"),
                    Some((o1, o2)) => mask_s.slice(o1..o2),
                };

                let dims = self.a_l.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let nrows_y = dims[dims.len() - 2];
                let bs = dims[0];
                let elem_per_batch = el / bs;

                // One block per row of the last dimension.
                let (nrows_x, ncols_x) = (el / dim_m1, dim_m1);

                // The kernel takes `int` sizes: refuse values that would silently truncate.
                let to_i32 = |v: usize, what: &str| -> Result<i32> {
                    match i32::try_from(v) {
                        Ok(v) => Ok(v),
                        Err(_) => {
                            candle::bail!("apply-mask-scale: {what} ({v}) does not fit in i32")
                        }
                    }
                };
                let ncols_arg = to_i32(ncols_x, "ncols")?;
                let nrows_y_arg = to_i32(nrows_y, "nrows_y")?;
                let elem_per_batch_arg = to_i32(elem_per_batch, "elem_per_batch")?;
                let grid_x = match u32::try_from(nrows_x) {
                    Ok(v) => v,
                    Err(_) => {
                        candle::bail!("apply-mask-scale: row count ({nrows_x}) does not fit in u32")
                    }
                };

                const WARP_SIZE: usize = 32;
                const MAX_BLOCK_SIZE: usize = 1024;
                // Smallest power-of-two multiple of the warp size covering a row, capped at
                // MAX_BLOCK_SIZE.
                let mut nth = WARP_SIZE;
                while nth < ncols_x && nth < MAX_BLOCK_SIZE {
                    nth *= 2;
                }

                let cfg = LaunchConfig {
                    grid_dim: (grid_x, 1, 1),
                    // `nth` is at most MAX_BLOCK_SIZE, so this cast cannot truncate.
                    block_dim: (nth as u32, 1, 1),
                    // The `mask_scale` kernel is purely element-wise and uses no shared memory.
                    shared_mem_bytes: 0,
                };
                let func =
                    dev.get_or_load_func(&kernel_name::<T>("mask_scale"), kernels::REDUCE)?;
                // `a` is passed as both source and destination: the op runs in place.
                let params = (
                    &a,
                    &mask,
                    &a,
                    ncols_arg,
                    nrows_y_arg,
                    elem_per_batch_arg,
                    self.scale,
                );
                // SAFETY: ffi.
                unsafe { func.launch(cfg, params) }.w()?;

                Ok(())
            }
        }

        let dev = a_s.device().clone();
        S {
            scale: self.scale,
            a_l,
        }
        .map(&mut a_s.slice, a_l.shape(), &mask_s.slice, mask_l, &dev)?;

        Ok(())
    }
}

/// Fused broadcast addition of a mask and scaling, applied to `xs` in place.
///
/// On CUDA devices this runs the `ApplyMaskScale` in-place op. On every other device
/// it falls back to regular tensor ops, equivalent to:
/// ```ignore
/// (xs.broadcast_add(&mask.unsqueeze(1)?)? * scale as f64)?
/// ```
/// - `xs` must be a rank-4 tensor.
/// - `mask` should be a 3-dimensional tensor (bs, s1, s2).
/// - The last 2 dimensions of `xs` must match the last 2 dimensions of `mask`.
///
/// NOTE(review): the `mask_scale` CUDA kernel computes `x * scale + mask`, whereas the
/// fallback above computes `(x + mask) * scale`. The two differ for finite, non-zero
/// mask entries when `scale != 1` — confirm which ordering is intended.
///
/// # Errors
/// Propagates any error from the in-place op or from the fallback tensor ops.
pub fn inplace_apply_mask_scale(xs: &mut Tensor, mask: &Tensor, scale: f32) -> Result<()> {
    // `ApplyMaskScale` only implements `cpu_fwd` (which bails) and `cuda_fwd`; it has no
    // Metal forward, so Metal tensors must use the fallback rather than the custom op.
    if xs.device().is_cuda() {
        xs.inplace_op2(mask, &ApplyMaskScale { scale })?;
    } else {
        *xs = (xs.broadcast_add(&mask.unsqueeze(1)?)? * scale as f64)?;
    }
    Ok(())
}

#[derive(Debug, Clone)]
struct RmsNorm {
eps: f32,
Expand Down
27 changes: 0 additions & 27 deletions candle-nn/tests/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,27 +68,6 @@ fn inplace_softmax(device: &Device) -> Result<()> {
Ok(())
}

/// Checks that `inplace_apply_mask_scale` on `device` produces the same values as on the CPU
/// for a zero-filled (2, 2, 2, 2) input, a per-batch constant mask and a scale of 1.
fn apply_mask_scale(device: &Device) -> Result<()> {
    type Batch = Vec<Vec<Vec<f32>>>;

    // Runs the op on `dev` and returns the two batch entries of the result.
    let masked_on = |dev: &Device| -> Result<(Batch, Batch)> {
        let mut xs = Tensor::zeros((2, 2, 2, 2), DType::F32, dev)?;
        let mask = Tensor::new(&[[[1f32, 1.], [1., 1.]], [[2., 2.], [2., 2.]]], dev)?;
        candle_nn::ops::inplace_apply_mask_scale(&mut xs, &mask, 1.0)?;
        let first = xs.i(0)?.to_vec3::<f32>()?;
        let second = xs.i(1)?.to_vec3::<f32>()?;
        Ok((first, second))
    };

    // The CPU result is the reference the device under test is compared against.
    let (truth_0, truth_1) = masked_on(&Device::Cpu)?;
    let (xs_0, xs_1) = masked_on(device)?;

    assert_eq!(xs_0, truth_0);
    assert_eq!(xs_1, truth_1);
    Ok(())
}

fn rms_norm(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
Expand Down Expand Up @@ -296,12 +275,6 @@ test_device!(
inplace_softmax_gpu,
inplace_softmax_metal
);
test_device!(
apply_mask_scale,
apply_mask_scale_cpu,
apply_mask_scale_gpu,
apply_mask_scale_metal
);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
test_device!(rms_norml, rms_norml_cpu, rms_norml_gpu, rms_norml_metal);
test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
Expand Down

0 comments on commit 072c715

Please sign in to comment.