Skip to content

Commit

Permalink
Add the kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Aug 8, 2024
1 parent 7e530ca commit 87229dc
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 3 deletions.
4 changes: 3 additions & 1 deletion candle-kernels/src/cast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,15 @@ CAST_OP(uint8_t, double, cast_u8_f64)

CAST_OP(int64_t, uint32_t, cast_i64_u32)
CAST_OP(int64_t, uint8_t, cast_i64_u8 )
CAST_OP(int64_t, int32_t, cast_i64_i32 )
CAST_OP(int64_t, int64_t, cast_i64_i64 )
CAST_OP(int64_t, float, cast_i64_f32)
CAST_OP(int64_t, double, cast_i64_f64)

CAST_OP(int32_t, uint32_t, cast_i32_u32)
CAST_OP(int32_t, uint8_t, cast_i32_u8 )
CAST_OP(int32_t, int32_t, cast_i32_i64 )
CAST_OP(int32_t, int64_t, cast_i32_i64 )
CAST_OP(int32_t, int32_t, cast_i32_i32 )
CAST_OP(int32_t, float, cast_i32_f32)
CAST_OP(int32_t, double, cast_i32_f64)

Expand Down
6 changes: 4 additions & 2 deletions candle-metal-kernels/src/binary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@ kernel void FN_NAME_STRIDED( \
BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \
BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \
BINARY(FN, int32_t, int32_t, NAME##_i32, NAME##_i32_strided);

#define BINARY_OP_OUT(NAME, FN) \
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \
BINARY(FN, int32_t, uint8_t, NAME##_i32, NAME##_i32_strided);

#define INT64_BINARY_OP(NAME, FN) \
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
Expand Down
16 changes: 16 additions & 0 deletions candle-metal-kernels/src/cast.metal
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ kernel void FN_NAME_STRIDED( \
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
CAST(cast_u32_i32, cast_u32_i32_strided, uint32_t, int32_t)
#if __METAL_VERSION__ >= 220
CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
#endif
Expand All @@ -87,6 +88,7 @@ CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half)
CAST(cast_u8_i32, cast_u8_i32_strided, uint8_t, int64_t)
#if __METAL_VERSION__ >= 220
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
#endif
Expand All @@ -98,6 +100,7 @@ CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t)
CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t)
CAST(cast_f16_i32, cast_f16_i32_strided, half, int64_t)
CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t)
#if defined(__HAVE_BFLOAT__)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
Expand All @@ -107,15 +110,27 @@ CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t)
CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t)
CAST(cast_i64_i32, cast_i64_i32_strided, int64_t, int32_t)
CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half)
#if defined(__HAVE_BFLOAT__)
CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float)
#endif

// i32
CAST(cast_i32_f32, cast_i32_f32_strided, int32_t, float)
CAST(cast_i32_u8, cast_i32_u8_strided, int32_t, uint8_t)
CAST(cast_i32_u32, cast_i32_u32_strided, int32_t, uint32_t)
CAST(cast_i32_i64, cast_i32_i64_strided, int32_t, int64_t)
CAST(cast_i32_f16, cast_i32_f16_strided, int32_t, half)
#if defined(__HAVE_BFLOAT__)
CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float)
#endif

// f32
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t)
CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t)
CAST(cast_f32_i32, cast_f32_i32_strided, float, int32_t)
CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t)
#if defined(__HAVE_BFLOAT__)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
Expand All @@ -124,6 +139,7 @@ CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
// bf16
#if defined(__HAVE_BFLOAT__)
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
CAST(cast_bf16_i32, cast_bf16_i32_strided, bfloat, int32_t)
CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
Expand Down
22 changes: 22 additions & 0 deletions candle-metal-kernels/src/indexing.metal
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ INDEX_OP(is_i64_f16, int64_t, half)
INDEX_OP(is_i64_bf16, int64_t, bfloat)
#endif

INDEX_OP(is_i32_f32, int32_t, float)
INDEX_OP(is_i32_f16, int32_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_i32_bf16, int32_t, bfloat)
#endif

INDEX_OP(is_u32_f32, uint32_t, float)
INDEX_OP(is_u32_f16, uint32_t, half)
#if defined(__HAVE_BFLOAT__)
Expand All @@ -213,9 +219,11 @@ GATHER_OP(gather_u32_bf16, uint, bfloat)

SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
SCATTER_ADD_OP(sa_i32_f32, int32_t, float)
SCATTER_ADD_OP(sa_i64_f32, int64_t, float)
SCATTER_ADD_OP(sa_u32_f16, uint32_t, half)
SCATTER_ADD_OP(sa_u8_f16, uint8_t, half)
SCATTER_ADD_OP(sa_i32_f16, int32_t, half)
SCATTER_ADD_OP(sa_i64_f16, int64_t, half)
#if defined(__HAVE_BFLOAT__)
SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat)
Expand All @@ -226,16 +234,29 @@ SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
// i64
INDEX_ADD_OP(ia_i64_f16, int64_t, half)
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
INDEX_ADD_OP(ia_i64_i32, int64_t, int32_t)
INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
#if defined(__HAVE_BFLOAT__)
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
#endif

// i64
INDEX_ADD_OP(ia_i32_f16, int32_t, half)
INDEX_ADD_OP(ia_i32_f32, int32_t, float)
INDEX_ADD_OP(ia_i32_i64, int32_t, int64_t)
INDEX_ADD_OP(ia_i32_i32, int32_t, int32_t)
INDEX_ADD_OP(ia_i32_u32, int32_t, uint32_t)
INDEX_ADD_OP(ia_i32_u8, int32_t, uint8_t)
#if defined(__HAVE_BFLOAT__)
INDEX_ADD_OP(ia_i32_bf16, int32_t, bfloat)
#endif

// u32
INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
INDEX_ADD_OP(ia_u32_i32, uint32_t, int32_t)
INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
Expand All @@ -246,6 +267,7 @@ INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
// u8
INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
INDEX_ADD_OP(ia_u8_i32, uint8_t, int32_t)
INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)
INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
Expand Down
7 changes: 7 additions & 0 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pub mod copy2d {
pub const HALF: Kernel = Kernel("copy2d_f16");
pub const BFLOAT: Kernel = Kernel("copy2d_bf16");
pub const I64: Kernel = Kernel("copy2d_i64");
pub const I32: Kernel = Kernel("copy2d_i32");
pub const U32: Kernel = Kernel("copy2d_u32");
pub const U8: Kernel = Kernel("copy2d_u8");
}
Expand All @@ -62,6 +63,7 @@ macro_rules! ops{
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64"));
pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32"));
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32"));
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8"));
}
Expand All @@ -72,6 +74,7 @@ macro_rules! ops{
pub const HALF: Kernel = Kernel("copy_f16");
pub const BFLOAT: Kernel = Kernel("copy_bf16");
pub const I64: Kernel = Kernel("copy_i64");
pub const I32: Kernel = Kernel("copy_i32");
pub const U32: Kernel = Kernel("copy_u32");
pub const U8: Kernel = Kernel("copy_u8");
}
Expand All @@ -86,6 +89,7 @@ macro_rules! ops{
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled"));
pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_tiled"));
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled"));
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled"));
}
Expand All @@ -96,6 +100,7 @@ macro_rules! ops{
pub const HALF: Kernel = Kernel("copy_f16_tiled");
pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled");
pub const I64: Kernel = Kernel("copy_i64_tiled");
pub const I32: Kernel = Kernel("copy_i32_tiled");
pub const U32: Kernel = Kernel("copy_u32_tiled");
pub const U8: Kernel = Kernel("copy_u8_tiled");
}
Expand All @@ -110,6 +115,7 @@ macro_rules! ops{
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided"));
pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_strided"));
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided"));
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided"));
}
Expand All @@ -120,6 +126,7 @@ macro_rules! ops{
pub const HALF: Kernel = Kernel("copy_f16_strided");
pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
pub const I64: Kernel = Kernel("copy_i64_strided");
pub const I32: Kernel = Kernel("copy_i32_strided");
pub const U32: Kernel = Kernel("copy_u32_strided");
pub const U8: Kernel = Kernel("copy_u8_strided");
}
Expand Down
6 changes: 6 additions & 0 deletions candle-metal-kernels/src/reduce.metal
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,12 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
#endif

REDUCE(x + y, fast_sum_i32_strided, int32_t, 0)
REDUCE(MIN(x, y), fast_min_i32_strided, int32_t, INT_MAX)
REDUCE(MAX(x, y), fast_max_i32_strided, int32_t, INT_MIN)
ARGMIN(fast_argmin_i32_strided, int32_t, INT_MAX)
ARGMAX(fast_argmax_i32_strided, int32_t, INT_MIN)

#if defined(__HAVE_BFLOAT__)
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x + y, fast_sum_bf16_strided, half, 0)
Expand Down
1 change: 1 addition & 0 deletions candle-metal-kernels/src/sort.metal
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ ARGSORT(float, f32)
ARGSORT(half, f16)
ARGSORT(uint8_t, u8)
ARGSORT(uint32_t, u32)
ARGSORT(int32_t, i32)

#if __METAL_VERSION__ >= 220
ARGSORT(int64_t, i64)
Expand Down
14 changes: 14 additions & 0 deletions candle-metal-kernels/src/ternary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,25 @@ WHERE_OP(float, int64_t, where_i64_f32)
WHERE_OP(uint8_t, int64_t, where_i64_u8)
WHERE_OP(uint32_t, int64_t, where_i64_u32)
WHERE_OP(int64_t, int64_t, where_i64_i64)
WHERE_OP(int64_t, int32_t, where_i64_i32)
#if defined(__HAVE_BFLOAT__)
WHERE_OP(bfloat, int64_t, where_i64_bf16)
#endif
#endif

WHERE_OP(int64_t, uint8_t, where_u8_i32)
WHERE_OP(int64_t, uint32_t, where_u32_i32)

WHERE_OP(half, int32_t, where_i32_f16)
WHERE_OP(float, int32_t, where_i32_f32)
WHERE_OP(uint8_t, int32_t, where_i32_u8)
WHERE_OP(uint32_t, int32_t, where_i32_u32)
WHERE_OP(int64_t, int32_t, where_i32_i64)
WHERE_OP(int32_t, int32_t, where_i32_i32)
#if defined(__HAVE_BFLOAT__)
WHERE_OP(bfloat, int32_t, where_i32_bf16)
#endif

#if defined(__HAVE_BFLOAT__)
WHERE_OP(bfloat, uint8_t, where_u8_bf16)
WHERE_OP(bfloat, uint32_t, where_u32_bf16)
Expand Down
3 changes: 3 additions & 0 deletions candle-metal-kernels/src/unary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ UNARY(id, int64_t, copy_i64, copy_i64_strided)
COPY2D(copy2d_i64, int64_t)
#endif

UNARY(id, int32_t, copy_i32, copy_i32_strided)
COPY2D(copy2d_i32, int32_t)

#if defined(__HAVE_BFLOAT__)
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
Expand Down

0 comments on commit 87229dc

Please sign in to comment.