Cuda support for attn softmax
EricLBuehler committed Jan 11, 2025
1 parent 96279d5 commit 8f434ac
Showing 2 changed files with 317 additions and 2 deletions.
124 changes: 124 additions & 0 deletions candle-kernels/src/reduce.cu
@@ -5,6 +5,14 @@
#define WARP_SIZE 32
const int BLOCK_SIZE = 1024;

static __device__ __forceinline__ float warp_reduce_max(float x) {
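// XOR-shuffle butterfly reduction: the exchange stride halves each step (16, 8, 4, 2, 1),
// so after log2(WARP_SIZE) = 5 steps every lane holds the warp-wide maximum.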
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
}
return x;
}

// TODO: Maybe add a fast_sum_f16_f32 variant that not only accumulates in f32
// but also expects an f32 output, so that it can be used for normalization, e.g.
// in softmax.
@@ -218,6 +226,106 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
}
}

template <typename T>
__device__ void attn_soft_max(const T * x, const T * mask, T * dst, const int ncols, const int nrows_y, const float scale) {
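// Fused attention softmax over the last dimension; one CUDA block processes one row of x.
// Pass 1: compute x*scale + mask (mask rows broadcast via rowx % nrows_y) and track the row max.
// Pass 2: exponentiate (value - max) and accumulate the row sum.
// Pass 3: normalize by the sum. Intermediates are staged in dst, so x and dst may alias
// (the in-place path in ops.rs relies on this).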
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension

const int block_size = blockDim.x;

const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;

extern __shared__ float smem[];
float * buf_iw = smem; // shared memory buffer for inter-warp communication
// buffer to cache the scaled + masked values between passes (stored directly in dst):
T * vals = dst + (int64_t)rowx*ncols;

float max_val = -INFINITY;

#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;

if (col >= ncols) {
break;
}

const int64_t ix = (int64_t)rowx*ncols + col;
const int64_t iy = (int64_t)rowy*ncols + col;

const float val = float(x[ix]) * scale + (mask ? float(mask[iy]) : 0.0f);

vals[col] = val;
max_val = max(max_val, val);
}

// find the max value in the block
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf_iw[lane_id] = -INFINITY;
}
__syncthreads();

if (lane_id == 0) {
buf_iw[warp_id] = max_val;
}
__syncthreads();

max_val = buf_iw[lane_id];
max_val = warp_reduce_max(max_val);
}

float tmp = 0.0f; // partial sum

#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;

if (col >= ncols) {
break;
}

const float val = expf(float(vals[col]) - max_val);
tmp += val;
vals[col] = val;
}

// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__syncthreads();
if (warp_id == 0) {
buf_iw[lane_id] = 0.0f;
}
__syncthreads();

if (lane_id == 0) {
buf_iw[warp_id] = tmp;
}
__syncthreads();

tmp = buf_iw[lane_id];
tmp = warp_reduce_sum(tmp);
}

const float inv_sum = 1.0f / tmp;

#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;

if (col >= ncols) {
return;
}

const int64_t idst = (int64_t)rowx*ncols + col;
dst[idst] = float(vals[col]) * inv_sum;
}
}
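For reference, with s = scale and the mask row broadcast as r mod nrows_y, the three passes above evaluate the numerically stable masked softmax per row r:

$$ m_r = \max_k\left(s\,x_{r,k} + \mathrm{mask}_{r \bmod n_y,\,k}\right), \qquad \mathrm{dst}_{r,c} = \frac{\exp\left(s\,x_{r,c} + \mathrm{mask}_{r \bmod n_y,\,c} - m_r\right)}{\sum_k \exp\left(s\,x_{r,k} + \mathrm{mask}_{r \bmod n_y,\,k} - m_r\right)} $$

(the mask terms are dropped when `mask` is null).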

template <typename T>
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -523,6 +631,18 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
const TYPENAME *src, TYPENAME *dst, \
const int n_cols) { \
softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \
}

#define ATTN_SOFTMAX_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME * x, \
const TYPENAME * mask, \
TYPENAME * dst, \
const int ncols, \
const int nrows_y, \
const float scale \
) { \
attn_soft_max<TYPENAME>(x, mask, dst, ncols, nrows_y, scale); \
} \

#define RMSNORM_OP(TYPENAME, FN_NAME) \
@@ -574,6 +694,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
#if __CUDA_ARCH__ >= 800
#include "cuda_bf16.h"
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
ATTN_SOFTMAX_OP(__nv_bfloat16, attn_soft_max_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
@@ -591,6 +712,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm

#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
ATTN_SOFTMAX_OP(__half, attn_soft_max_f16)
RMSNORM_OP(__half, rmsnorm_f16)
LAYERNORM_OP(__half, layernorm_f16)
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
@@ -603,6 +725,8 @@ SUM_OP(double, sum_f64)
SUM_OP(uint32_t, sum_u32)
SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
ATTN_SOFTMAX_OP(float, attn_soft_max_f32)
ATTN_SOFTMAX_OP(double, attn_soft_max_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
LAYERNORM_OP(float, layernorm_f32)
195 changes: 193 additions & 2 deletions candle-nn/src/ops.rs
@@ -640,6 +640,101 @@ impl candle::InplaceOp2 for AttnSoftmaxLastDim {

Ok(())
}

#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
a_s: &mut candle::CudaStorage,
a_l: &Layout,
mask_s: &candle::CudaStorage,
mask_l: &Layout,
) -> Result<()> {
use candle::backend::BackendStorage;

use candle::cuda::Map2InPlace;
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};

if !a_l.is_contiguous() {
candle::bail!("Non contiguous xs for attn-softmax-last-dim is not implemented");
}
if !mask_l.is_contiguous() {
candle::bail!("Non contiguous mask for attn-softmax-last-dim is not implemented");
}

if a_l.dims().len() != 4 {
candle::bail!("attn-softmax-last-dim expects xs of rank 4");
}
if mask_l.dims().len() != 2 {
candle::bail!("attn-softmax-last-dim expects mask of rank 2");
}
if mask_l.dim(D::Minus1)? != a_l.dim(D::Minus1)?
|| mask_l.dim(D::Minus2)? != a_l.dim(D::Minus2)?
{
candle::bail!("attn-softmax-last-dim expects the mask's last 2 dims to match xs' last 2 dims");
}

struct S<'a> {
scale: f32,
a_l: &'a Layout,
}
impl Map2InPlace for S<'_> {
fn f<T: DeviceRepr + WithDType>(
&self,
a_s: &mut CudaSlice<T>,
_a_shape: &Shape,
mask_s: &CudaSlice<T>,
mask_l: &Layout,
dev: &CudaDevice,
) -> Result<()> {
let a = match self.a_l.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => a_s.slice(o1..o2),
};
let mask = match mask_l.contiguous_offsets() {
None => candle::bail!("mask has to be contiguous"),
Some((o1, o2)) => mask_s.slice(o1..o2),
};

let el = self.a_l.shape().elem_count();
let dims = self.a_l.shape().dims();
let dim_m1 = dims[dims.len() - 1];
let nrows_y = dims[dims.len() - 2];
let (nrows_x, ncols_x) = (el / dim_m1, dim_m1);

const WARP_SIZE: usize = 32;
let mut nth = WARP_SIZE;
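// Use the smallest power-of-two block width that covers the row, capped at the
// 1024-thread block limit (BLOCK_SIZE in reduce.cu).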
while nth < ncols_x && nth < 1024 {
nth *= 2;
}

let cfg = LaunchConfig {
// One block per row (the kernel reads the row index from blockIdx.x), `nth` threads per block.
grid_dim: (nrows_x as u32, 1, 1),
block_dim: (nth as u32, 1, 1),
shared_mem_bytes: (WARP_SIZE * std::mem::size_of::<f32>()) as u32,
};
let func =
dev.get_or_load_func(&kernel_name::<T>("attn_soft_max"), kernels::REDUCE)?;
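// `a` is passed as both the kernel's input and output: attn_soft_max stages its
// intermediate values in dst, and each element is read before the same thread
// overwrites it, so running it in place is safe.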
let params = (&a, &mask, &a, ncols_x as i32, nrows_y as i32, self.scale);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;

Ok(())
}
}

let dev = a_s.device().clone();
S {
scale: self.scale,
a_l,
}
.map(&mut a_s.slice, a_l.shape(), &mask_s.slice, mask_l, &dev)?;

Ok(())
}
}

impl candle::CustomOp2 for AttnSoftmaxLastDim {
@@ -716,6 +811,102 @@ impl candle::CustomOp2 for AttnSoftmaxLastDim {
let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, a_s.dtype());
Ok((newstorage, a_l.shape().clone()))
}

#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
a_s: &candle::CudaStorage,
a_l: &Layout,
mask_s: &candle::CudaStorage,
mask_l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::backend::BackendStorage;

use candle::cuda::Map2;
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};

if !a_l.is_contiguous() {
candle::bail!("Non contiguous xs for attn-softmax-last-dim is not implemented");
}
if !mask_l.is_contiguous() {
candle::bail!("Non contiguous mask for attn-softmax-last-dim is not implemented");
}

if a_l.dims().len() != 4 {
candle::bail!("attn-softmax-last-dim expects xs of rank 4");
}
if mask_l.dims().len() != 2 {
candle::bail!("attn-softmax-last-dim expects mask of rank 2");
}
if mask_l.dim(D::Minus1)? != a_l.dim(D::Minus1)?
|| mask_l.dim(D::Minus2)? != a_l.dim(D::Minus2)?
{
candle::bail!("attn-softmax-last-dim expects the mask's last 2 dims to match xs' last 2 dims");
}

struct S {
scale: f32,
}
impl Map2 for S {
fn f<T: DeviceRepr + WithDType>(
&self,
a_s: &CudaSlice<T>,
a_l: &Layout,
mask_s: &CudaSlice<T>,
mask_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let a = match a_l.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => a_s.slice(o1..o2),
};
let mask = match mask_l.contiguous_offsets() {
None => candle::bail!("mask has to be contiguous"),
Some((o1, o2)) => mask_s.slice(o1..o2),
};

let el = a_l.shape().elem_count();
let dims = a_l.shape().dims();
let dim_m1 = dims[dims.len() - 1];
let nrows_y = dims[dims.len() - 2];
let (nrows_x, ncols_x) = (el / dim_m1, dim_m1);

const WARP_SIZE: usize = 32;
let mut nth = WARP_SIZE;
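// Use the smallest power-of-two block width that covers the row, capped at the
// 1024-thread block limit (BLOCK_SIZE in reduce.cu).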
while nth < ncols_x && nth < 1024 {
nth *= 2;
}

let cfg = LaunchConfig {
// One block per row (the kernel reads the row index from blockIdx.x), `nth` threads per block.
grid_dim: (nrows_x as u32, 1, 1),
block_dim: (nth as u32, 1, 1),
shared_mem_bytes: (WARP_SIZE * std::mem::size_of::<f32>()) as u32,
};
let func =
dev.get_or_load_func(&kernel_name::<T>("attn_soft_max"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&a, &mask, &dst, ncols_x as i32, nrows_y as i32, self.scale);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;

Ok(dst)
}
}

let dev = a_s.device().clone();
let slice = S { scale: self.scale }.map(&a_s.slice, a_l, &mask_s.slice, mask_l, &dev)?;

let dst = candle::cuda_backend::CudaStorage {
slice,
device: dev.clone(),
};
Ok((dst, a_l.shape().clone()))
}
}

/// Softmax with fused broadcast addition of a mask and scale.
@@ -729,7 +920,7 @@ impl candle::CustomOp2 for AttnSoftmaxLastDim {
///
/// Note: if the last dim of `xs` is a multiple of 4, a vectorized implementation will be used.
pub fn attn_softmax_last_dim(xs: &Tensor, mask: &Tensor, scale: f32) -> Result<Tensor> {
- if xs.device().is_metal() {
+ if xs.device().is_metal() || xs.device().is_cuda() {
xs.apply_op2_no_bwd(mask, &AttnSoftmaxLastDim { scale })
} else {
softmax_last_dim(&(xs.broadcast_add(mask)? * scale as f64)?)
@@ -738,7 +929,7 @@ pub fn attn_softmax_last_dim(xs: &Tensor, mask: &Tensor, scale: f32) -> Result<T

/// Inplace equivalent of `attn_softmax_last_dim`
pub fn inplace_attn_softmax_last_dim(xs: &mut Tensor, mask: &Tensor, scale: f32) -> Result<()> {
- if xs.device().is_metal() {
+ if xs.device().is_metal() || xs.device().is_cuda() {
xs.inplace_op2(mask, &AttnSoftmaxLastDim { scale })?;
} else {
*xs = softmax_last_dim(&(xs.broadcast_add(mask)? * scale as f64)?)?;
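For context beyond the diff itself, here is a minimal sketch of exercising the new fused op. It assumes the crate names used in this workspace (`candle`, `candle_nn`), a standalone `main` harness, and a 0/-inf causal mask; the shapes follow the rank checks above, with `xs` of shape `[b, h, seq_q, seq_k]` and `mask` of shape `[seq_q, seq_k]` broadcast over `b` and `h`.

use candle::{Device, Result, Tensor};
use candle_nn::ops;

fn main() -> Result<()> {
    // Falls back to the unfused CPU path when no CUDA (or Metal) device is present.
    let dev = Device::cuda_if_available(0)?;
    let (b, h, seq_q, seq_k) = (2usize, 8, 16, 16);
    let scale = 1.0 / (64f32).sqrt();

    let xs = Tensor::randn(0f32, 1f32, (b, h, seq_q, seq_k), &dev)?;
    // Causal mask: 0 where attention is allowed, -inf above the diagonal.
    let mask_vec: Vec<f32> = (0..seq_q)
        .flat_map(|i| (0..seq_k).map(move |j| if j <= i { 0.0 } else { f32::NEG_INFINITY }))
        .collect();
    let mask = Tensor::from_vec(mask_vec, (seq_q, seq_k), &dev)?;

    // Fused: softmax over the last dim of xs*scale + mask (mask broadcast over b and h).
    let fused = ops::attn_softmax_last_dim(&xs, &mask, scale)?;
    // Unfused reference, identical to the non-Metal/CUDA branch above.
    let reference = ops::softmax_last_dim(&(xs.broadcast_add(&mask)? * scale as f64)?)?;

    let max_diff = (&fused - &reference)?
        .abs()?
        .flatten_all()?
        .max(0)?
        .to_scalar::<f32>()?;
    println!("out {:?}, max |fused - reference| = {max_diff:e}", fused.dims());
    Ok(())
}

Note that the kernel applies `scale` to `xs` before adding the mask, while the fallback scales `xs + mask`; for the usual 0/-inf attention masks the two coincide, which is what the comparison above checks.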
