Mul-and-act fused op (#72)
* Mul-and-act

* Oops

* Add cuda kernel

* Add inplace op for cuda/cpu

* Remove inplace
EricLBuehler authored Feb 15, 2025
1 parent fb62f6f commit 4fea87a
Showing 6 changed files with 748 additions and 0 deletions.
1 change: 1 addition & 0 deletions candle-kernels/src/lib.rs
@@ -7,6 +7,7 @@ pub const FUSED_RMS_NORM: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_rms_norm.ptx"));
pub const FUSED_ROPE: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_rope.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const KVCONCAT: &str = include_str!(concat!(env!("OUT_DIR"), "/kvconcat.ptx"));
pub const MUL_AND_ACT: &str = include_str!(concat!(env!("OUT_DIR"), "/mul_and_act.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
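The PTX constants above are filled in at build time: each .cu file under candle-kernels/src is compiled to PTX in OUT_DIR, which include_str! then embeds into the crate. As a hedged illustration only (candle-kernels' actual build script, which handles GPU architectures and caching, will differ), a minimal hypothetical build.rs that shells out to nvcc could look like this:

// Hypothetical build.rs sketch: compile each .cu kernel to PTX in OUT_DIR so that
// include_str!(concat!(env!("OUT_DIR"), "/mul_and_act.ptx")) resolves at compile time.
use std::{env, path::PathBuf, process::Command};

fn main() {
    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
    for kernel in ["mul_and_act"] {
        let src = format!("src/{kernel}.cu");
        println!("cargo:rerun-if-changed={src}");
        let ptx = out_dir.join(format!("{kernel}.ptx"));
        let status = Command::new("nvcc")
            .args(["--ptx", src.as_str(), "-o"])
            .arg(&ptx)
            .status()
            .expect("failed to spawn nvcc");
        assert!(status.success(), "nvcc failed for {src}");
    }
}
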
107 changes: 107 additions & 0 deletions candle-kernels/src/mul_and_act.cu
@@ -0,0 +1,107 @@
#include "cuda_utils.cuh"

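// Tanh-based GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// where M_2_SQRTPI * M_SQRT1_2 == (2/sqrt(pi)) * (1/sqrt(2)) == sqrt(2/pi).
// tanhg, maxg, and expg below are the generic math helpers from cuda_utils.cuh.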
template<typename T>
__device__ __forceinline__ T gelu(T x) {
    T x_sq = x * x;
    T x_cube = x_sq * x;
    T alpha = x + static_cast<T>(0.044715) * x_cube;
    return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + tanhg(static_cast<T>(M_2_SQRTPI * M_SQRT1_2) * alpha));
}

template<typename T>
__device__ __forceinline__ T relu(T x) {
    T zero = 0.;
    return maxg(x, zero);
}

template<typename T>
__device__ __forceinline__ T silu(T x) {
    return x / (static_cast<T>(1) + expg(-x));
}


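// dims_and_strides packs 3 * num_dims values: [dims.., lhs_strides.., rhs_strides..];
// a null pointer means both inputs are contiguous. Each branch below specializes the
// grid-stride loop for one contiguous/strided combination of lhs and rhs; the
// activation and product are evaluated in f32 and cast back to the output type.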
#define MUL_ACT_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, ACT) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *dims_and_strides, \
    const TYPENAME *lhs, \
    const TYPENAME *rhs, \
    OUT_TYPENAME *out \
) { \
    const size_t *dims = dims_and_strides; \
    const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
    const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
    bool lhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, lhs_strides); \
    bool rhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, rhs_strides); \
    if (lhs_cont && rhs_cont) { \
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
            TYPENAME x = lhs[i]; \
            TYPENAME y = rhs[i]; \
            out[i] = TYPENAME(ACT(float(x)) * float(y)); \
        } \
    } else if (lhs_cont) { \
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
            unsigned int tmp_i = i; \
            unsigned int rhs_i = 0; \
            for (int d = num_dims - 1; d >= 0; d--) { \
                unsigned int i_dim = tmp_i % dims[d]; \
                rhs_i += i_dim * rhs_strides[d]; \
                tmp_i /= dims[d]; \
            } \
            TYPENAME x = lhs[i]; \
            TYPENAME y = rhs[rhs_i]; \
            out[i] = TYPENAME(ACT(float(x)) * float(y)); \
        } \
    } else if (rhs_cont) { \
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
            unsigned int tmp_i = i; \
            unsigned int lhs_i = 0; \
            for (int d = num_dims - 1; d >= 0; d--) { \
                unsigned int i_dim = tmp_i % dims[d]; \
                lhs_i += i_dim * lhs_strides[d]; \
                tmp_i /= dims[d]; \
            } \
            TYPENAME x = lhs[lhs_i]; \
            TYPENAME y = rhs[i]; \
            out[i] = TYPENAME(ACT(float(x)) * float(y)); \
        } \
    } else { \
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
            unsigned int tmp_i = i; \
            unsigned int lhs_i = 0; \
            unsigned int rhs_i = 0; \
            for (int d = num_dims - 1; d >= 0; d--) { \
                unsigned int i_dim = tmp_i % dims[d]; \
                lhs_i += i_dim * lhs_strides[d]; \
                rhs_i += i_dim * rhs_strides[d]; \
                tmp_i /= dims[d]; \
            } \
            TYPENAME x = lhs[lhs_i]; \
            TYPENAME y = rhs[rhs_i]; \
            out[i] = TYPENAME(ACT(float(x)) * float(y)); \
        } \
    } \
} \


#define MUL_ACT_OP(TYPENAME, FN_NAME, ACT) \
MUL_ACT_OP_OUT(TYPENAME, TYPENAME, FN_NAME, ACT)

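// bf16 math requires compute capability 8.0 (Ampere) and f16 requires 5.3, hence the guards.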
#if __CUDA_ARCH__ >= 800
#include "cuda_bf16.h"

MUL_ACT_OP(__nv_bfloat16, mul_act_gelu_bf16, gelu)
MUL_ACT_OP(__nv_bfloat16, mul_act_relu_bf16, relu)
MUL_ACT_OP(__nv_bfloat16, mul_act_silu_bf16, silu)
#endif

#if __CUDA_ARCH__ >= 530
MUL_ACT_OP(__half, mul_act_gelu_f16, gelu)
MUL_ACT_OP(__half, mul_act_relu_f16, relu)
MUL_ACT_OP(__half, mul_act_silu_f16, silu)
#endif

MUL_ACT_OP(float, mul_act_gelu_f32, gelu)
MUL_ACT_OP(float, mul_act_relu_f32, relu)
MUL_ACT_OP(float, mul_act_silu_f32, silu)
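
For reference, every kernel above computes out[i] = ACT(lhs[i]) * rhs[i] elementwise, the gate-times-up pattern used in GLU-style MLP blocks. A minimal CPU sketch of the contiguous f32 path, handy as a ground truth when testing the kernels; Act and mul_and_act_ref are hypothetical names for illustration, not part of this commit:

// CPU reference for the contiguous f32 path: out[i] = act(lhs[i]) * rhs[i].
#[derive(Clone, Copy)]
enum Act { Gelu, Relu, Silu }

fn act(a: Act, x: f32) -> f32 {
    match a {
        // Same tanh-based GELU approximation as the CUDA kernel above.
        Act::Gelu => {
            let alpha = x + 0.044715 * x * x * x;
            0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * alpha).tanh())
        }
        Act::Relu => x.max(0.0),
        Act::Silu => x / (1.0 + (-x).exp()),
    }
}

fn mul_and_act_ref(a: Act, lhs: &[f32], rhs: &[f32], out: &mut [f32]) {
    for ((o, &x), &y) in out.iter_mut().zip(lhs).zip(rhs) {
        *o = act(a, x) * y;
    }
}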
75 changes: 75 additions & 0 deletions candle-metal-kernels/src/lib.rs
@@ -30,6 +30,7 @@ const SORT: &str = include_str!("sort.metal");
const TERNARY: &str = include_str!("ternary.metal");
const UNARY: &str = include_str!("unary.metal");
const SDPA: &str = include_str!("scaled_dot_product_attention.metal");
const MUL_AND_ACT: &str = include_str!("mul_and_act.metal");

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
@@ -48,6 +49,7 @@ pub enum Source {
    Ternary,
    Unary,
    Sdpa,
    MulAndAct,
}

pub mod copy2d {
@@ -239,6 +241,7 @@ impl Kernels {
            Source::Ternary => TERNARY,
            Source::Unary => UNARY,
            Source::Sdpa => SDPA,
            Source::MulAndAct => MUL_AND_ACT,
            Source::Mfa => panic!("Invalid lib"),
        }
    }
@@ -3315,5 +3318,77 @@ pub fn call_const_fill(
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_mul_and_act_contiguous(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    length: usize,
    left: BufferOffset,
    right: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::MulAndAct, name)?;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, &left, &right, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_mul_and_act_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    left_input: BufferOffset,
    left_strides: &[usize],
    right_input: BufferOffset,
    right_strides: &[usize],
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::MulAndAct, name)?;

    let num_dims: usize = shape.len();
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    let length: usize = shape.iter().product();
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            length,
            num_dims,
            shape,
            left_strides,
            right_strides,
            &left_input,
            &right_input,
            output
        )
    );
    encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);

    Ok(())
}
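
A rough call-site sketch for the contiguous entry point above. It is an assumption-laden illustration: the kernel name mirrors the CUDA naming (mul_act_silu_f32), but the actual Metal function names live in mul_and_act.metal, which this excerpt does not show:

// Hypothetical wrapper; the kernel name and the zero offsets are illustrative only.
fn silu_mul_f32(
    device: &Device,
    command_buffer: &metal::CommandBufferRef,
    kernels: &Kernels,
    lhs: &Buffer,
    rhs: &Buffer,
    out: &Buffer,
    numel: usize,
) -> Result<(), MetalKernelError> {
    let left = BufferOffset { buffer: lhs, offset_in_bytes: 0 };
    let right = BufferOffset { buffer: rhs, offset_in_bytes: 0 };
    call_mul_and_act_contiguous(
        device,
        command_buffer,
        kernels,
        "mul_act_silu_f32",
        numel,
        left,
        right,
        out,
    )
}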

#[cfg(test)]
mod tests;
