From 589b40810adf584b764da28a5ea0f3d4cdee1be0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Feb 2025 16:32:48 +0200 Subject: [PATCH 01/64] ci : dummy commit to trigger CI --- src/whisper.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/whisper.cpp b/src/whisper.cpp index 11077d5b687..069aa6e73b9 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -33,6 +33,8 @@ #include #include +// dummy + #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif From f12559d590073f1ff8868bdd35c35842c271b262 Mon Sep 17 00:00:00 2001 From: issixx <46835150+issixx@users.noreply.github.com> Date: Fri, 17 Jan 2025 21:29:08 +0900 Subject: [PATCH 02/64] ggml-cpu : fix ggml_graph_compute_thread did not terminate on abort. (ggml/1065) some threads kept looping and failed to terminate properly after an abort during CPU execution. Co-authored-by: issi --- ggml/src/ggml-cpu/ggml-cpu.c | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 2966ff7682d..424c9781f44 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1302,7 +1302,7 @@ struct ggml_threadpool { // these are atomic as an annotation for thread-sanitizer atomic_bool stop; // Used for stopping the threadpool altogether atomic_bool pause; // Used for pausing the threadpool or individual threads - atomic_bool abort; // Used for aborting processing of a graph + atomic_int abort; // Used for aborting processing of a graph struct ggml_compute_state * workers; // per thread state int n_threads_max; // number of threads in the pool @@ -13778,14 +13778,14 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { /*.threadpool=*/ tp, }; - for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) { + for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) { struct ggml_tensor * node = cgraph->nodes[node_n]; ggml_compute_forward(¶ms, node); if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { - tp->abort = true; + atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed); tp->ec = GGML_STATUS_ABORTED; } @@ -13958,7 +13958,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl( threadpool->current_chunk = 0; threadpool->stop = false; threadpool->pause = tpp->paused; - threadpool->abort = false; + threadpool->abort = -1; threadpool->workers = NULL; threadpool->n_threads_max = tpp->n_threads; threadpool->n_threads_cur = tpp->n_threads; @@ -14037,7 +14037,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl threadpool->cgraph = cgraph; threadpool->cplan = cplan; threadpool->current_chunk = 0; - threadpool->abort = false; + threadpool->abort = -1; threadpool->ec = GGML_STATUS_SUCCESS; } From 8e0143e205c275623d0193a8c0d8d1f1184e7d7f Mon Sep 17 00:00:00 2001 From: William Tambellini Date: Thu, 23 Jan 2025 11:59:08 -0800 Subject: [PATCH 03/64] ggml : add option to not print stack on abort (ggml/1081) * Add option to not print stack on abort Add option/envvar to disable stack printing on abort. Also link some unittests with Threads to fix link errors on ubuntu/g++11. * Update ggml/src/ggml.c --------- Co-authored-by: Diego Devesa --- ggml/src/ggml.c | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index da5b817e156..ae8147dd7d2 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -128,6 +128,10 @@ static void ggml_print_backtrace_symbols(void) { #endif static void ggml_print_backtrace(void) { + const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE"); + if (GGML_NO_BACKTRACE) { + return; + } char attach[32]; snprintf(attach, sizeof(attach), "attach %d", getpid()); int pid = fork(); From 9700cfb0a3ea5a0789b3c55e9e4217a4f34ba2f6 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Wed, 15 Jan 2025 08:50:17 +0530 Subject: [PATCH 04/64] SYCL: Add gated linear attention kernel (llama/11175) * SYCL: Add Gated Linear attention kernel * glahpp: add a space at the end of file * gla: Put the barrier inside the main logic loop --- ggml/src/ggml-sycl/backend.hpp | 1 + ggml/src/ggml-sycl/ggml-sycl.cpp | 4 ++ ggml/src/ggml-sycl/gla.cpp | 105 +++++++++++++++++++++++++++++++ ggml/src/ggml-sycl/gla.hpp | 8 +++ 4 files changed, 118 insertions(+) create mode 100644 ggml/src/ggml-sycl/gla.cpp create mode 100644 ggml/src/ggml-sycl/gla.hpp diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 85748a5b4c1..b1df4e5db17 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -29,5 +29,6 @@ #include "wkv6.hpp" #include "outprod.hpp" #include "element_wise.hpp" +#include "gla.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 037c8093eef..5272ca45423 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4040,6 +4040,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens case GGML_OP_RWKV_WKV6: ggml_sycl_op_rwkv_wkv6(ctx, dst); break; + case GGML_OP_GATED_LINEAR_ATTN: + ggml_sycl_op_gated_linear_attn(ctx, dst); + break; default: return false; } @@ -4507,6 +4510,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_LEAKY_RELU: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_RWKV_WKV6: + case GGML_OP_GATED_LINEAR_ATTN: return true; default: return false; diff --git a/ggml/src/ggml-sycl/gla.cpp b/ggml/src/ggml-sycl/gla.cpp new file mode 100644 index 00000000000..eedb4748643 --- /dev/null +++ b/ggml/src/ggml-sycl/gla.cpp @@ -0,0 +1,105 @@ +#include + +#include "common.hpp" + +template +static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B, u_int T, u_int C, u_int H, float scale, + const float * k, const float * v, const float * r, const float * td, + const float * s, float * dst) { + const u_int head_size = HEAD_SIZE; + const u_int state_size = C * head_size; + const u_int n_seq_tokens = T / B; + sycl::range<1> block_dims((C / H)); + sycl::range<1> grid_dims((B * H)); + stream->submit([&](sycl::handler & cgh) { + /* local memory accessors*/ + auto _k = sycl::local_accessor(sycl::range<1>(head_size), cgh); + auto _r = sycl::local_accessor(sycl::range<1>(head_size), cgh); + auto _td = sycl::local_accessor(sycl::range<1>(head_size), cgh); + + cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) { + u_int tid = item.get_local_id(0); + u_int bid = item.get_group(0); + + u_int batch_i = bid / H; + u_int head_i = bid % H; + + float state[head_size]; + +#pragma unroll + for (u_int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; + } + + for (u_int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; + t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) { + + item.barrier(sycl::access::fence_space::local_space); //sync threads + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + item.barrier(sycl::access::fence_space::local_space); //sync threads + + const float _v = v[t]; + float y = 0; + + for (u_int j = 0; j < head_size; j += 4) { + const sycl::float4 & k = (sycl::float4 &) (_k[j]); + const sycl::float4 & r = (sycl::float4 &) (_r[j]); + const sycl::float4 & td = (sycl::float4 &) (_td[j]); + sycl::float4 & s = (sycl::float4 &) (state[j]); + sycl::float4 kv; + + kv.x() = k.x() * _v; + kv.y() = k.y() * _v; + kv.z() = k.z() * _v; + kv.w() = k.w() * _v; + + s.x() = s.x() * td.x() + kv.x(); + s.y() = s.y() * td.y() + kv.y(); + s.z() = s.z() * td.z() + kv.z(); + s.w() = s.w() * td.w() + kv.w(); + + y += r.x() * s.x(); + y += r.y() * s.y(); + y += r.z() * s.z(); + y += r.w() * s.w(); + } + dst[t] = y * scale; + } +#pragma unroll + for (u_int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; + } + }); + }); +} + +void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const float * k_d = static_cast(dst->src[0]->data); + const float * v_d = static_cast(dst->src[1]->data); + const float * r_d = static_cast(dst->src[2]->data); + const float * td_d = static_cast(dst->src[3]->data); + const float * s_d = static_cast(dst->src[4]->data); + + const int64_t B = dst->src[4]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + dpct::queue_ptr stream = ctx.stream(); + GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == 64 || C / H == 128); + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + float * dst_d = (float *) dst->data; + + if (C / H == 64) { + gated_linear_attn_f32_kernel<64>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d); + } else { + gated_linear_attn_f32_kernel<128>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d); + } +} diff --git a/ggml/src/ggml-sycl/gla.hpp b/ggml/src/ggml-sycl/gla.hpp new file mode 100644 index 00000000000..607cf3a7f30 --- /dev/null +++ b/ggml/src/ggml-sycl/gla.hpp @@ -0,0 +1,8 @@ +#ifndef GGML_SYCL_GLA_HPP +#define GGML_SYCL_GLA_HPP + +#include "common.hpp" + +void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_GLA_HPP From 54a2ee648fdb021c19a1f22e3d5016251b30edf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 15 Jan 2025 12:51:37 +0100 Subject: [PATCH 05/64] RoPE: fix back, CUDA support for back + noncont. (llama/11240) * RoPE: fix back, CUDA support for back + noncont. * fix comments reg. non-cont. RoPE support [no-ci] --- ggml/include/ggml.h | 19 +- ggml/src/ggml-cpu/ggml-cpu.c | 1 + ggml/src/ggml-cpu/ggml-cpu.cpp | 2 - ggml/src/ggml-cuda/ggml-cuda.cu | 10 +- ggml/src/ggml-cuda/rope.cu | 366 ++++++++++++++------------------ ggml/src/ggml-cuda/rope.cuh | 2 + ggml/src/ggml.c | 58 ++--- 7 files changed, 225 insertions(+), 233 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 8f8cb9e1aa1..a9c051cd5d6 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1500,7 +1500,7 @@ extern "C" { // rotary position embedding backward, i.e compute dx from dy // a - dy - GGML_API struct ggml_tensor * ggml_rope_back( + GGML_API struct ggml_tensor * ggml_rope_ext_back( struct ggml_context * ctx, struct ggml_tensor * a, // gradients of ggml_rope result struct ggml_tensor * b, // positions @@ -1515,6 +1515,23 @@ extern "C" { float beta_fast, float beta_slow); + GGML_API struct ggml_tensor * ggml_rope_multi_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[4], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + // clamp // in-place, returns view(a) GGML_API struct ggml_tensor * ggml_clamp( diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 424c9781f44..8bf5f781a59 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -13668,6 +13668,7 @@ struct ggml_cplan ggml_graph_plan( } break; case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: { cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; } break; diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index f11399cc628..5c47ceb7314 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -403,8 +403,6 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float case GGML_OP_MUL_MAT: return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type; - case GGML_OP_ROPE_BACK: - return op->src[2] == NULL && (op->op_params[2] & 4) == 0; case GGML_OP_IM2COL_BACK: return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; case GGML_OP_OUT_PROD: diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 1dac397c4b0..9118edc7282 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2141,6 +2141,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ROPE: ggml_cuda_op_rope(ctx, dst); break; + case GGML_OP_ROPE_BACK: + ggml_cuda_op_rope_back(ctx, dst); + break; case GGML_OP_IM2COL: ggml_cuda_op_im2col(ctx, dst); break; @@ -3025,7 +3028,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SOFT_MAX: return true; case GGML_OP_ROPE: - return ggml_is_contiguous(op->src[0]); + case GGML_OP_ROPE_BACK: { + const size_t ts = ggml_type_size(op->src[0]->type); + const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2]; + return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts; + } case GGML_OP_IM2COL: case GGML_OP_POOL_2D: case GGML_OP_SUM: @@ -3081,6 +3088,7 @@ static int64_t get_op_batch_size(const ggml_tensor * op) { return op->ne[1]; case GGML_OP_MUL_MAT_ID: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: return op->ne[2]; default: return ggml_nrows(op); diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 2c84778d29c..e1912fee1f9 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -16,9 +16,10 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +template static __device__ void rope_yarn( - float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta) { + const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor, + float mscale, float & cos_theta, float & sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -29,24 +30,28 @@ static __device__ void rope_yarn( // Get n-d magnitude scaling corrected for interpolation mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); } - *cos_theta = cosf(theta) * mscale; - *sin_theta = sinf(theta) * mscale; + cos_theta = cosf(theta) * mscale; + sin_theta = sinf(theta) * mscale; + if (!forward) { + sin_theta *= -1.0f; + } } -template +template static __global__ void rope_norm( - const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { + const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { return; } - const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= n_dims) { - const int i = row*ne0 + i0; + const int i = row_dst*ne0 + i0; dst[i + 0] = x[i + 0]; dst[i + 1] = x[i + 1]; @@ -54,39 +59,43 @@ static __global__ void rope_norm( return; } - const int i = row*ne0 + i0; - const int i2 = row/p_delta_rows; + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + + const int idst = row_dst*ne0 + i0; + const int ix = channel_x*s2 + row_x*s1 + i0; - const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; float cos_theta; float sin_theta; - rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); - const float x0 = x[i + 0]; - const float x1 = x[i + 1]; + const float x0 = x[ix + 0]; + const float x1 = x[ix + 1]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + 1] = x0*sin_theta + x1*cos_theta; + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + 1] = x0*sin_theta + x1*cos_theta; } -template +template static __global__ void rope_neox( - const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { + const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { return; } - const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= n_dims) { - const int i = row*ne0 + i0; + const int i = row_dst*ne0 + i0; dst[i + 0] = x[i + 0]; dst[i + 1] = x[i + 1]; @@ -94,39 +103,43 @@ static __global__ void rope_neox( return; } - const int i = row*ne0 + i0/2; - const int i2 = row/p_delta_rows; + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + + const int idst = row_dst*ne0 + i0/2; + const int ix = channel_x*s2 + row_x*s1 + i0/2; - const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; float cos_theta; float sin_theta; - rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); - const float x0 = x[i + 0]; - const float x1 = x[i + n_dims/2]; + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims/2]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; } -template +template static __global__ void rope_multi( - const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) { + const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, + const int n_dims, const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { return; } - const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= n_dims) { - const int i = row*ne0 + i0; + const int i = row_dst*ne0 + i0; dst[i + 0] = x[i + 0]; dst[i + 1] = x[i + 1]; @@ -134,25 +147,28 @@ static __global__ void rope_multi( return; } - const int i = row*ne0 + i0/2; - const int i2 = row/p_delta_rows; + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; - int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; - int sec_w = sections.v[1] + sections.v[0]; - int sector = (i0 / 2) % sect_dims; + const int idst = row_dst*ne0 + i0/2; + const int ix = channel_x*s2 + row_x*s1 + i0/2; + + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0; if (sector < sections.v[0]) { - theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); } else if (sector >= sections.v[0] && sector < sec_w) { - theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f); + theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); } else if (sector >= sec_w && sector < sec_w + sections.v[2]) { - theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f); + theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); } else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f); + theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); } const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -160,42 +176,46 @@ static __global__ void rope_multi( float cos_theta; float sin_theta; - rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); - const float x0 = x[i + 0]; - const float x1 = x[i + n_dims/2]; + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims/2]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; } -template +template static __global__ void rope_vision( - const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) { + const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, + const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { return; } - const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; + + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; - const int i = row*ne0 + i0/2; - const int i2 = row/p_delta_rows; // i2-th tokens + const int idst = row_dst*ne0 + i0/2; + const int ix = channel_x*s2 + row_x*s1 + i0/2; - int sect_dims = sections.v[0] + sections.v[1]; - int sec_w = sections.v[1] + sections.v[0]; - int sector = (i0 / 2) % sect_dims; + const int sect_dims = sections.v[0] + sections.v[1]; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0; if (sector < sections.v[0]) { const int p = sector; - theta_base = pos[i2]*powf(theta_scale, p); + theta_base = pos[channel_x]*powf(theta_scale, p); } else if (sector >= sections.v[0] && sector < sec_w) { const int p = sector - sections.v[0]; - theta_base = pos[i2 + ne2]*powf(theta_scale, p); + theta_base = pos[channel_x + ne2]*powf(theta_scale, p); } const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -203,19 +223,20 @@ static __global__ void rope_vision( float cos_theta; float sin_theta; - rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); - const float x0 = x[i + 0]; - const float x1 = x[i + n_dims]; + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + n_dims] = x0*sin_theta + x1*cos_theta; + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + n_dims] = x0*sin_theta + x1*cos_theta; } -template +template static void rope_norm_cuda( - const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -224,22 +245,21 @@ static void rope_norm_cuda( const float theta_scale = powf(freq_base, -2.0f/n_dims); if (freq_factors == nullptr) { - rope_norm<<>>( - x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); + rope_norm<<>>( + x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); } else { - rope_norm<<>>( - x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); + rope_norm<<>>( + x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); } } -template +template static void rope_neox_cuda( - const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -248,22 +268,21 @@ static void rope_neox_cuda( const float theta_scale = powf(freq_base, -2.0f/n_dims); if (freq_factors == nullptr) { - rope_neox<<>>( - x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); + rope_neox<<>>( + x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); } else { - rope_neox<<>>( - x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); + rope_neox<<>>( + x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); } } -template +template static void rope_multi_cuda( - const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) { + const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -272,22 +291,21 @@ static void rope_multi_cuda( const float theta_scale = powf(freq_base, -2.0f/n_dims); if (freq_factors == nullptr) { - rope_multi<<>>( - x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, sections - ); + rope_multi<<>>( + x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); } else { - rope_multi<<>>( - x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, sections - ); + rope_multi<<>>( + x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); } } -template +template static void rope_vision_cuda( - const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) { + const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -298,80 +316,18 @@ static void rope_vision_cuda( const float theta_scale = powf(freq_base, -2.0f/n_dims); if (freq_factors == nullptr) { - rope_vision<<>>( - x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, sections - ); + rope_vision<<>>( + x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); } else { - rope_vision<<>>( - x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, sections - ); + rope_vision<<>>( + x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); } } -static void rope_norm_cuda_f16( - const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - - rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); -} - -static void rope_norm_cuda_f32( - const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - - rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); -} - -static void rope_neox_cuda_f16( - const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - - rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); -} - -static void rope_neox_cuda_f32( - const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream -) { - - rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); -} - -static void rope_multi_cuda_f16( - const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream -) { - - rope_multi_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); -} - -static void rope_multi_cuda_f32( - const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream -) { - - rope_multi_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); -} - -static void rope_vision_cuda_f16( - const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream -) { - - rope_vision_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); -} - -static void rope_vision_cuda_f32( - const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream -) { - - rope_vision_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); -} - -void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +template +void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * src2 = dst->src[2]; @@ -382,7 +338,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == dst->type); @@ -392,6 +347,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t ne02 = src0->ne[2]; // num heads const int64_t nr = ggml_nrows(src0); + const size_t s01 = src0->nb[1] / ggml_type_size(src0->type); + const size_t s02 = src0->nb[2] / ggml_type_size(src0->type); + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -440,59 +398,59 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // compute if (is_neox) { if (src0->type == GGML_TYPE_F32) { - rope_neox_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); + rope_neox_cuda( + (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_neox_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); + rope_neox_cuda( + (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } else { GGML_ABORT("fatal error"); } } else if (is_mrope && !is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_multi_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, sections, stream - ); + rope_multi_cuda( + (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_multi_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, sections, stream - ); + rope_multi_cuda( + (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); } else { GGML_ABORT("fatal error"); } } else if (is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_vision_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, sections, stream - ); + rope_vision_cuda( + (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_vision_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, sections, stream - ); + rope_vision_cuda( + (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); } else { GGML_ABORT("fatal error"); } } else { if (src0->type == GGML_TYPE_F32) { - rope_norm_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); + rope_norm_cuda( + (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_norm_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); + rope_norm_cuda( + (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } else { GGML_ABORT("fatal error"); } } } + +void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_rope_impl(ctx, dst); +} + +void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_rope_impl(ctx, dst); +} diff --git a/ggml/src/ggml-cuda/rope.cuh b/ggml/src/ggml-cuda/rope.cuh index 0f787a0b2f7..9139f3b220d 100644 --- a/ggml/src/ggml-cuda/rope.cuh +++ b/ggml/src/ggml-cuda/rope.cuh @@ -3,3 +3,5 @@ #define CUDA_ROPE_BLOCK_SIZE 256 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ae8147dd7d2..0a0dcc7e3b1 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3699,7 +3699,7 @@ void ggml_rope_yarn_corr_dims( // ggml_rope_back -struct ggml_tensor * ggml_rope_back( +struct ggml_tensor * ggml_rope_ext_back( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -3713,29 +3713,32 @@ struct ggml_tensor * ggml_rope_back( float attn_factor, float beta_fast, float beta_slow) { - GGML_ASSERT(ggml_is_vector(b)); - GGML_ASSERT(b->type == GGML_TYPE_I32); - GGML_ASSERT(a->ne[2] == b->ne[0]); - - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - - int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; - memcpy(params + 5, &freq_base, sizeof(float)); - memcpy(params + 6, &freq_scale, sizeof(float)); - memcpy(params + 7, &ext_factor, sizeof(float)); - memcpy(params + 8, &attn_factor, sizeof(float)); - memcpy(params + 9, &beta_fast, sizeof(float)); - memcpy(params + 10, &beta_slow, sizeof(float)); - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_ROPE_BACK; - result->src[0] = a; - result->src[1] = b; - result->src[2] = c; - + struct ggml_tensor * result = ggml_rope_ext( + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + result->op = GGML_OP_ROPE_BACK; return result; } +struct ggml_tensor * ggml_rope_multi_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[4], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + struct ggml_tensor * result = ggml_rope_multi( + ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + result->op = GGML_OP_ROPE_BACK; + return result; +} // ggml_clamp struct ggml_tensor * ggml_clamp( @@ -5598,6 +5601,7 @@ static void ggml_compute_backward( //const int n_ctx = ((int32_t *) tensor->op_params)[3]; const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4]; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4] = {0, 0, 0, 0}; memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float)); memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float)); @@ -5605,10 +5609,14 @@ static void ggml_compute_backward( memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float)); memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float)); memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float)); - - ggml_add_or_set(ctx, cgraph, isrc0, - ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base, - freq_scale, ext_factor, attn_factor, beta_fast, beta_slow)); + memcpy(§ions, tensor->op_params + 11, sizeof(sections)); + + struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ? + ggml_rope_ext_back(ctx, grad, src1, src2, n_dims, + mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) : + ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections, + mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + ggml_add_or_set(ctx, cgraph, isrc0, rope_back); } GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented"); } break; From 02aa86230abd5cceb8b7b88813643a1c8066c62a Mon Sep 17 00:00:00 2001 From: Junil Kim Date: Wed, 15 Jan 2025 22:17:42 +0900 Subject: [PATCH 06/64] fix: ggml: fix vulkan-shaders-gen build (llama/10448) * fix: ggml: fix vulkan-shaders-gen build The vulkan-shaders-gen target was not being built correctly in case of cross-compilation. Other outputs need to be built for the cross compile target, but vulkan-shaders-gen needs to be built for the host. * refactor: ggml: Improve vulkan-shaders-gen toolchain setup - Add GGML_SHADERS_GEN_TOOLCHAIN CMake option. - Auto-detect host toolchain if not set. * refactor: ggml: Improve vulkan-shaders-gen toolchain setup Use configure_file to generate host_toolchain.cmake from template * fix: ggml: Fix compile error Fix compile error not finding vulkan-shaders-gen * fix: vulkan-shaders-gen build and path handling Fix build issues with vulkan-shaders-gen: - Add target dependency for correct build order - Use CMAKE_HOST_SYSTEM_NAME for executable suffix - Fix MSVC output directory in host toolchain - Normalize path handling for cross-compilation * fix: improve host compiler detection in vulkan shader build Improve host compiler detection for vulkan shader generation: - Add NO_CMAKE_FIND_ROOT_PATH to all compiler searches - Consolidate compiler detection logic - Fix Windows-specific MSVC detection - Ensure correct compiler search in cross-compilation * refactor: Simplify CMake function for detecting host compiler Simplified the CMake function to improve the process of detecting the host compiler. * fix: Remove unnecessary Vulkan library linkage in CMakeLists.txt Since `vulkan-shader-gen.cpp` only requires the `glslc` executable and not the Vulkan headers or libraries, CMakeLists.txt needs to be corrected. (See: ecc93d0558fc3ecb8a5af69d2ece02fae4710ade) * refactor: Rename host_toolchain.cmake.in - Rename host_toolchain.cmake.in to cmake/host-toolchain.cmake.in * refactor: GGML_VULKAN_SHADERS_GEN_TOOLCHAIN Rename the macro GGML_SHADERS_GEN_TOOLCHAIN to GGML_VULKAN_SHADERS_GEN_TOOLCHAIN --- ggml/CMakeLists.txt | 3 + ggml/src/ggml-vulkan/CMakeLists.txt | 66 +++++++++++++++++-- .../ggml-vulkan/cmake/host-toolchain.cmake.in | 15 +++++ .../ggml-vulkan/vulkan-shaders/CMakeLists.txt | 6 +- .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 - 5 files changed, 81 insertions(+), 11 deletions(-) create mode 100644 ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index fe8acc8038b..185079aa47c 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -185,6 +185,9 @@ option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increas option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON) option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON) +# toolchain for vulkan-shaders-gen +set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen") + # extra artifacts option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE}) option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE}) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index c0ddaac827f..d970f7e20b4 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -1,5 +1,20 @@ +cmake_minimum_required(VERSION 3.19) +cmake_policy(SET CMP0114 NEW) + find_package(Vulkan COMPONENTS glslc REQUIRED) +function(detect_host_compiler) + if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH) + find_program(HOST_CXX_COMPILER NAMES cl g++ clang++ NO_CMAKE_FIND_ROOT_PATH) + else() + find_program(HOST_C_COMPILER NAMES gcc clang NO_CMAKE_FIND_ROOT_PATH) + find_program(HOST_CXX_COMPILER NAMES g++ clang++ NO_CMAKE_FIND_ROOT_PATH) + endif() + set(HOST_C_COMPILER "${HOST_C_COMPILER}" PARENT_SCOPE) + set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE) +endfunction() + if (Vulkan_FOUND) message(STATUS "Vulkan found") @@ -73,19 +88,56 @@ if (Vulkan_FOUND) add_compile_definitions(GGML_VULKAN_RUN_TESTS) endif() - add_subdirectory(vulkan-shaders) - - set (_ggml_vk_genshaders_cmd vulkan-shaders-gen) + if (NOT CMAKE_CROSSCOMPILING) + add_subdirectory(vulkan-shaders) + if (MSVC) + foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES}) + string(TOUPPER ${CONFIG} CONFIG) + set_target_properties(vulkan-shaders-gen PROPERTIES + RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + endforeach() + endif() + else() + if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN) + set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN}) + else() + detect_host_compiler() + if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER) + message(FATAL_ERROR "Host compiler not found") + else() + message(STATUS "Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}") + endif() + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY) + set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake) + endif() + message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") + + include(ExternalProject) + # Native build through ExternalProject_Add + ExternalProject_Add( + vulkan-shaders-gen + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders + CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE} + -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} + BUILD_COMMAND ${CMAKE_COMMAND} --build . + INSTALL_COMMAND ${CMAKE_COMMAND} --install . + INSTALL_DIR ${CMAKE_BINARY_DIR} + ) + ExternalProject_Add_StepTargets(vulkan-shaders-gen build install) + endif() + set (_ggml_vk_host_suffix $,.exe,>) + set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix}) set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp) set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders) set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv) file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") + set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen) - if (NOT CMAKE_CROSSCOMPILING) - set(_ggml_vk_genshaders_cmd "$/${_ggml_vk_genshaders_cmd}") - endif () + if (CMAKE_CROSSCOMPILING) + set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install) + endif() add_custom_command( OUTPUT ${_ggml_vk_header} @@ -99,7 +151,7 @@ if (Vulkan_FOUND) --target-cpp ${_ggml_vk_source} --no-clean - DEPENDS ${_ggml_vk_shader_deps} ${_ggml_vk_genshaders_cmd} + DEPENDS ${_ggml_vk_shader_deps} COMMENT "Generate vulkan shaders" ) diff --git a/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in b/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in new file mode 100644 index 00000000000..b6af747a500 --- /dev/null +++ b/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in @@ -0,0 +1,15 @@ +set(CMAKE_BUILD_TYPE Release) +set(CMAKE_C_FLAGS -O2) +set(CMAKE_CXX_FLAGS -O2) +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER) +set(CMAKE_C_COMPILER @HOST_C_COMPILER@) +set(CMAKE_CXX_COMPILER @HOST_CXX_COMPILER@) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY @CMAKE_RUNTIME_OUTPUT_DIRECTORY@) + +if("@CMAKE_C_COMPILER_ID@" STREQUAL "MSVC") + foreach(CONFIG IN ITEMS DEBUG RELEASE MINSIZEREL RELWITHDEBINFO) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + endforeach() +endif() diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index bd0c74cb1c7..074031087f4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -1,9 +1,11 @@ find_package (Threads REQUIRED) -find_package(Vulkan COMPONENTS glslc REQUIRED) +find_program(GLSLC_EXECUTABLE glslc) +if(NOT GLSLC_EXECUTABLE) + message(FATAL_ERROR "glslc not found.") +endif() set(TARGET vulkan-shaders-gen) add_executable(${TARGET} vulkan-shaders-gen.cpp) install(TARGETS ${TARGET} RUNTIME) target_compile_features(${TARGET} PRIVATE cxx_std_17) target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) -target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 7b5044798d7..2438399174d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -30,8 +30,6 @@ #include #endif -#include - #define ASYNCIO_CONCURRENCY 64 std::mutex lock; From 164f13c6a9054c704682dbcfb10c37718863e660 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Wed, 15 Jan 2025 19:50:13 +0000 Subject: [PATCH 07/64] vulkan: scale caching for k quants + misc fixes (llama/11081) * q6_k scale caching * 16 bit unpack * q4_k test (slow) * revert it * q3_k * q2_k * little stuff * try precalculating products of a and q2_k scales * Revert "try precalculating products of a and q2_k scales" This reverts commit 65110b81f23f66331a50c6e889a7c1ab9470a86b. * unpack should be u16, add vim swap to gitignore (about time) * better q4_k scales * q5_k * better q6_k with separate paths for all threads and partial threads in use, plus some more optimizations * q2_k better dequant * q3_k optimizations * q3_k use hmask simd from cpu avx version * make the caches happy * q3_k separate out calculation * q2_k separate out * little stuff * use calc_superblock everywhere * q2_k optimize scale calculation * more barriers --- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 154 ++++++------ .../vulkan-shaders/mul_mat_vec_q3_k.comp | 143 ++++++----- .../vulkan-shaders/mul_mat_vec_q4_k.comp | 169 ++++++------- .../vulkan-shaders/mul_mat_vec_q5_k.comp | 229 +++++++++--------- .../vulkan-shaders/mul_mat_vec_q6_k.comp | 144 ++++++----- 5 files changed, 454 insertions(+), 385 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 6a9b9b2d132..8cdc640e80e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -5,6 +5,80 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16]; +shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + barrier(); + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) { + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + } + barrier(); + + if (i >= num_blocks_per_row) + continue; + } else { + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + barrier(); + } + + const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); + FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ], + fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2], + fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ], + fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2], + fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ], + fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2], + fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ], + fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); + sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im], + fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im], + fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im], + fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im], + fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im], + fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im], + fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im], + fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2)))))))); + } + temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -14,88 +88,28 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; - - const uint step = 8; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; - const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = itid - step*v_im; // 0...15 or 0...7 + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...7 const uint l0 = 2*v_in; // 0...15 const uint q_offset = 32*v_im + l0; - const uint s_offset = 8*v_im; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - - uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0]; - uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1]; - - uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F; - uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F; - uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F; - uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F; - - uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32)); - uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32)); - uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32)); - uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32)); - - uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0]; - uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]; - uvec2 qs0 = uvec2(unpack8(qs0_u16)); - uvec2 qs16 = uvec2(unpack8(qs16_u16)); - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); - vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); - vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); - vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); - vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); - vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); - vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); - vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); - - FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); - FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 2; ++l) { - sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3), - fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3), - fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3), - fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3), - fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3), - fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3), - fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3), - fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); - sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]), - fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]), - fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]), - fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]), - fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]), - fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]), - fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]), - fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2)))))))); - } - temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); - } - } - } + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 96ef50fdda2..3116fad165b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -5,6 +5,74 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + if (!all_threads) { // when we don't have enough blocks to use all threads + barrier(); + if (i < num_blocks_per_row) + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16)); + const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2)); + const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2)); + const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2)); + const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2)); + + // 0, 1, 16, 17 + uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8); + qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16; + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + if (all_threads) { + barrier(); + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], + fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], + fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], + fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], + fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], + fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], + fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], + fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); + } + temp[j][n] = fma(d, sum, temp[j][n]); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -14,76 +82,37 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; - - const uint step = 8; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + const uint itid8 = itid%8; - const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = itid - step*v_im; // 0...15 or 0...7 + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_im4 = v_im*4; + const uint v_in = itid - 8*v_im; // 0...7 - const uint8_t m = uint8_t(1 << (4 * v_im)); + const uint32_t m = 0x01010101 << (4 * v_im); + uint32_t hm_m[4]; + [[unroll]] for (uint j = 0; j < 4; ++j) + hm_m[j] = m << j; const uint l0 = 2*v_in; // 0...15 const uint q_offset = 32*v_im + l0; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - const uint s_shift = 4 * v_im; - - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - - uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0]; - uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1]; - uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2]; - uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3]; - uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4]; - uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5]; - u8vec2 s0 = unpack8(s0_16); - u8vec2 s2 = unpack8(s2_16); - u8vec2 s4 = unpack8(s4_16); - u8vec2 s6 = unpack8(s6_16); - u8vec2 s8 = unpack8(s8_16); - u8vec2 s10 = unpack8(s10_16); - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - - vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); - vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); - vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); - vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); - vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); - vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); - vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); - vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); - - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 2; ++l) { - sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum)))))))); - } - temp[j][n] = fma(d, sum, temp[j][n]); - } - } - } + const uint s_shift = v_im4 + 2*(itid8/4); + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index f97eb8744fb..f9cde064887 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -6,6 +6,86 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; + const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; + + const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; + const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; + const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; + + const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4)); + const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4)); + const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4)); + const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_lo4.x; + const FLOAT_TYPE q4_1 = qs0_lo4.y; + const FLOAT_TYPE q4_2 = qs0_lo4.z; + const FLOAT_TYPE q4_3 = qs0_lo4.w; + const FLOAT_TYPE q4_4 = qs0_hi4.x; + const FLOAT_TYPE q4_5 = qs0_hi4.y; + const FLOAT_TYPE q4_6 = qs0_hi4.z; + const FLOAT_TYPE q4_7 = qs0_hi4.w; + const FLOAT_TYPE q4_8 = qs64_lo4.x; + const FLOAT_TYPE q4_9 = qs64_lo4.y; + const FLOAT_TYPE q4_10 = qs64_lo4.z; + const FLOAT_TYPE q4_11 = qs64_lo4.w; + const FLOAT_TYPE q4_12 = qs64_hi4.x; + const FLOAT_TYPE q4_13 = qs64_hi4.y; + const FLOAT_TYPE q4_14 = qs64_hi4.z; + const FLOAT_TYPE q4_15 = qs64_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]); + vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]); + vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]); + vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]); + + const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); + const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); + const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); + const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, + fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, + fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, + fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -15,13 +95,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; - const uint step = 4; - - const uint il = itid/step; // 0...3 - const uint ir = itid - step*il; // 0...7 or 0...3 + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...3 const uint n = 4; const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 @@ -31,89 +109,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint q_offset = 32*v_im + l0; const uint y_offset = 64*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y1_idx = i * QUANT_K + y_offset; - const uint y2_idx = y1_idx + 128; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - - uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; - uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; - uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - uvec4 scale0 = uvec4(unpack8(scale0_u32)); - uvec4 scale4 = uvec4(unpack8(scale4_u32)); - uvec4 scale8 = uvec4(unpack8(scale8_u32)); - - const uint32_t sc0 = ( scale0.x & 0x3f); - const uint32_t sc1 = ( scale0.y & 0x3f); - const uint32_t sc2 = ( scale4.x & 0x3f); - const uint32_t sc3 = ( scale4.y & 0x3f); - const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); - const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); - const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); - const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); - - uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; - uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; - - uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; - uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; - uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; - uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; - - uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4)); - uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4)); - uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4)); - uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4)); - - const uint32_t q4_0 = qs0_lo4.x; - const uint32_t q4_1 = qs0_lo4.y; - const uint32_t q4_2 = qs0_lo4.z; - const uint32_t q4_3 = qs0_lo4.w; - const uint32_t q4_4 = qs0_hi4.x; - const uint32_t q4_5 = qs0_hi4.y; - const uint32_t q4_6 = qs0_hi4.z; - const uint32_t q4_7 = qs0_hi4.w; - const uint32_t q4_8 = qs64_lo4.x; - const uint32_t q4_9 = qs64_lo4.y; - const uint32_t q4_10 = qs64_lo4.z; - const uint32_t q4_11 = qs64_lo4.w; - const uint32_t q4_12 = qs64_hi4.x; - const uint32_t q4_13 = qs64_hi4.y; - const uint32_t q4_14 = qs64_hi4.z; - const uint32_t q4_15 = qs64_hi4.w; - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]); - vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]); - vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]); - vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]); - - const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); - const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); - const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); - const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, - fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, - fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, - fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); - temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); - } - } - } + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 79d7db0e3e6..6c84ef3cde3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -6,6 +6,118 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); + + uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; + uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; + uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; + uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); + + const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; + const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; + const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); + const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; + + qs0_16_u32_lo4 += qs0_16_lo4_offset16; + qs0_16_u32_hi4 += qs0_16_hi4_offset16; + qs64_80_u32_lo4 += qs64_80_lo4_offset16; + qs64_80_u32_hi4 += qs64_80_hi4_offset16; + + const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4)); + const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4)); + const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4)); + const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_16_lo4.x; + const FLOAT_TYPE q4_1 = qs0_16_lo4.y; + const FLOAT_TYPE q4_2 = qs0_16_lo4.z; + const FLOAT_TYPE q4_3 = qs0_16_lo4.w; + const FLOAT_TYPE q4_4 = qs0_16_hi4.x; + const FLOAT_TYPE q4_5 = qs0_16_hi4.y; + const FLOAT_TYPE q4_6 = qs0_16_hi4.z; + const FLOAT_TYPE q4_7 = qs0_16_hi4.w; + const FLOAT_TYPE q4_8 = qs64_80_lo4.x; + const FLOAT_TYPE q4_9 = qs64_80_lo4.y; + const FLOAT_TYPE q4_10 = qs64_80_lo4.z; + const FLOAT_TYPE q4_11 = qs64_80_lo4.w; + const FLOAT_TYPE q4_12 = qs64_80_hi4.x; + const FLOAT_TYPE q4_13 = qs64_80_hi4.y; + const FLOAT_TYPE q4_14 = qs64_80_hi4.z; + const FLOAT_TYPE q4_15 = qs64_80_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]); + vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]); + vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]); + vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]); + vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]); + vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]); + vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]); + vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]); + + const FLOAT_TYPE sx = + fma(FLOAT_TYPE(by10.x), q4_0, + fma(FLOAT_TYPE(by10.y), q4_1, + fma(FLOAT_TYPE(by116.x), q4_2, + FLOAT_TYPE(by116.y) * q4_3))); + const FLOAT_TYPE sy = + fma(FLOAT_TYPE(by132.x), q4_4, + fma(FLOAT_TYPE(by132.y), q4_5, + fma(FLOAT_TYPE(by148.x), q4_6, + FLOAT_TYPE(by148.y) * q4_7))); + const FLOAT_TYPE sz = + fma(FLOAT_TYPE(by20.x), q4_8, + fma(FLOAT_TYPE(by20.y), q4_9, + fma(FLOAT_TYPE(by216.x), q4_10, + FLOAT_TYPE(by216.y) * q4_11))); + const FLOAT_TYPE sw = + fma(FLOAT_TYPE(by232.x), q4_12, + fma(FLOAT_TYPE(by232.y), q4_13, + fma(FLOAT_TYPE(by248.x), q4_14, + FLOAT_TYPE(by248.y) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, + fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, + fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, + (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -15,11 +127,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; const uint il = itid/4; // 0...3 - const uint ir = itid - 4*il; // 0...7 or 0...3 + const uint ir = itid - 4*il; // 0...3 const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const uint v_in = il % 2; @@ -28,121 +140,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint q_offset = 32*v_im + l0; const uint y_offset = 64*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y1_idx = i * QUANT_K + y_offset; - const uint y2_idx = y1_idx + 128; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); - - uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; - uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; - uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; - uvec4 scale0 = uvec4(unpack8(scale0_u32)); - uvec4 scale4 = uvec4(unpack8(scale4_u32)); - uvec4 scale8 = uvec4(unpack8(scale8_u32)); - - const uint32_t sc0 = ( scale0.x & 0x3f); - const uint32_t sc1 = ( scale0.y & 0x3f); - const uint32_t sc2 = ( scale4.x & 0x3f); - const uint32_t sc3 = ( scale4.y & 0x3f); - const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2)); - const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2)); - const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); - const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); - - uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); - uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); - - uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; - uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; - uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; - uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; - - uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); - - uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; - uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; - uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0; - uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; - - qs0_16_u32_lo4 += qs0_16_lo4_offset16; - qs0_16_u32_hi4 += qs0_16_hi4_offset16; - qs64_80_u32_lo4 += qs64_80_lo4_offset16; - qs64_80_u32_hi4 += qs64_80_hi4_offset16; - - uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4)); - uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4)); - uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4)); - uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4)); - - const uint32_t q4_0 = qs0_16_lo4.x; - const uint32_t q4_1 = qs0_16_lo4.y; - const uint32_t q4_2 = qs0_16_lo4.z; - const uint32_t q4_3 = qs0_16_lo4.w; - const uint32_t q4_4 = qs0_16_hi4.x; - const uint32_t q4_5 = qs0_16_hi4.y; - const uint32_t q4_6 = qs0_16_hi4.z; - const uint32_t q4_7 = qs0_16_hi4.w; - const uint32_t q4_8 = qs64_80_lo4.x; - const uint32_t q4_9 = qs64_80_lo4.y; - const uint32_t q4_10 = qs64_80_lo4.z; - const uint32_t q4_11 = qs64_80_lo4.w; - const uint32_t q4_12 = qs64_80_hi4.x; - const uint32_t q4_13 = qs64_80_hi4.y; - const uint32_t q4_14 = qs64_80_hi4.z; - const uint32_t q4_15 = qs64_80_hi4.w; - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]); - vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]); - vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]); - vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]); - vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]); - vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]); - vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]); - vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]); - - const FLOAT_TYPE sx = - fma(FLOAT_TYPE(by10.x), q4_0, - fma(FLOAT_TYPE(by10.y), q4_1, - fma(FLOAT_TYPE(by116.x), q4_2, - FLOAT_TYPE(by116.y) * q4_3))); - const FLOAT_TYPE sy = - fma(FLOAT_TYPE(by132.x), q4_4, - fma(FLOAT_TYPE(by132.y), q4_5, - fma(FLOAT_TYPE(by148.x), q4_6, - FLOAT_TYPE(by148.y) * q4_7))); - const FLOAT_TYPE sz = - fma(FLOAT_TYPE(by20.x), q4_8, - fma(FLOAT_TYPE(by20.y), q4_9, - fma(FLOAT_TYPE(by216.x), q4_10, - FLOAT_TYPE(by216.y) * q4_11))); - const FLOAT_TYPE sw = - fma(FLOAT_TYPE(by232.x), q4_12, - fma(FLOAT_TYPE(by232.y), q4_13, - fma(FLOAT_TYPE(by248.x), q4_14, - FLOAT_TYPE(by248.y) * q4_15))); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, - fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, - fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, - (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); - temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); - } - } - } + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); reduce_result(temp, d_offset, first_row, num_rows, tid); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index 041fd27c12b..f05f96b5efb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -6,7 +6,77 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { +shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + if (!all_threads) { // when we don't have enough blocks to use all threads + barrier(); + if (i < num_blocks_per_row) + sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); + const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); + + const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; + const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; + const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); + const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; + const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; + const uint32_t qh4_u32 = (qh_u32 & 0x30303030); + const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; + + const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; + const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; + const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; + const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; + + const vec4 q0 = vec4(unpack8(q0_u32)) - 32; + const vec4 q1 = vec4(unpack8(q1_u32)) - 32; + const vec4 q2 = vec4(unpack8(q2_u32)) - 32; + const vec4 q3 = vec4(unpack8(q3_u32)) - 32; + + if (all_threads) { + barrier(); + sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]); + vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]); + vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]); + vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]); + + FLOAT_TYPE sum[4] = {0, 0, 0, 0}; + [[unroll]] for (uint l = 0; l < 4; ++l) { + sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]); + sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]); + sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]); + sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]); + } + temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]); + } + } +} + +void compute_outputs(const uint first_row, const uint num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); @@ -15,13 +85,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; const uint tid = gl_LocalInvocationID.x; - const uint itid = tid%16; // 0...16 - const uint ix = tid/16; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; - const uint step = 8; - - const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = itid - step*v_im; // 0...15 or 0...7 + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...7 const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 const uint is = v_in / 4; @@ -31,68 +99,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint s_offset = 8*v_im + is; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { temp[j][i] = FLOAT_TYPE(0); } } - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - - FLOAT_TYPE scales[4]; - scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]); - scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]); - scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]); - scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]); - - uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); - uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); - - uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; - uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; - uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; - uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; - - uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); - uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; - uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; - uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0; - uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; - - uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; - uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; - uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; - uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; - - uvec4 q0 = uvec4(unpack8(q0_u32)); - uvec4 q1 = uvec4(unpack8(q1_u32)); - uvec4 q2 = uvec4(unpack8(q2_u32)); - uvec4 q3 = uvec4(unpack8(q3_u32)); - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]); - vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]); - vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]); - vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]); - - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 4; ++l) { - sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), - fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), - fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32), - fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); - } - temp[j][n] += sum * d; - } - } - } + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); reduce_result(temp, d_offset, first_row, num_rows, tid); } From db6383094c347474c3af1fddc83cc386177cfbd5 Mon Sep 17 00:00:00 2001 From: fj-y-saito <85871716+fj-y-saito@users.noreply.github.com> Date: Thu, 16 Jan 2025 18:11:49 +0900 Subject: [PATCH 08/64] ggml: aarch64: implement SVE kernels for q4_K_q8_K vector dot (llama/11227) * Add SVE support for q4_K_q8_K * Update ggml/src/ggml-cpu/ggml-cpu-quants.c change to use K_SCALE_SIZE Co-authored-by: Georgi Gerganov --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 83 ++++++++++++++++++++++++++++- 1 file changed, 82 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 8e14722667a..88303ff0e61 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -5573,7 +5573,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r uint32_t utmp[4]; -#ifdef __ARM_NEON +#ifdef __ARM_FEATURE_SVE + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, K_SCALE_SIZE); + + uint32x2_t mins8 = { 0 }; + mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); + mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const int vector_length = ggml_cpu_get_sve_cnt()*8; + const svuint8_t m4b = svdup_n_u8(0xf); + const svint32_t mzero = svdup_n_s32(0); + svint32_t sumi1 = svdup_n_s32(0); + svint32_t sumi1_1 = svdup_n_s32(0); + svint32_t sumi1_2 = svdup_n_s32(0); + svint32_t sumi2 = svdup_n_s32(0); + svint32_t sumi2_1 = svdup_n_s32(0); + svint32_t sumi2_2 = svdup_n_s32(0); + switch (vector_length) { + case 128: + { + for (int j = 0; j < QK_K/64; ++j) { + svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b)); + svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + q4 += 32; + } + sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2); + sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2); + sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2))); + } break; + case 256: + case 512: + { + for (int j = 0; j < QK_K/64; ++j) { + const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32; + svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b)); + svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; + sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4)); + q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; + sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + } + sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2))); + } break; + default: + assert(false && "Unsupported vector length"); + break; + } + } + *s = sumf; +#elif __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); const int32x4_t mzero = vdupq_n_s32(0); From de49024e49d5cca0b5bfc44d02f8e6271a4ab585 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 16 Jan 2025 16:43:38 +0100 Subject: [PATCH 09/64] CUDA: backwards pass for misc. ops, add tests (llama/11257) * CUDA: backwards pass for misc. ops, add tests * remove restrict from pointers --- ggml/include/ggml.h | 12 +- ggml/src/ggml-alloc.c | 5 + ggml/src/ggml-cpu/ggml-cpu.c | 133 +++++++++-------- ggml/src/ggml-cpu/ggml-cpu.cpp | 10 ++ ggml/src/ggml-cuda/cross-entropy-loss.cu | 181 +++++++++++++---------- ggml/src/ggml-cuda/getrows.cu | 159 +++++++++++++------- ggml/src/ggml-cuda/getrows.cuh | 3 + ggml/src/ggml-cuda/ggml-cuda.cu | 27 +++- ggml/src/ggml-cuda/norm.cu | 155 +++++++++++++++---- ggml/src/ggml-cuda/norm.cuh | 2 + ggml/src/ggml-cuda/out-prod.cu | 38 +++-- ggml/src/ggml-cuda/rope.cu | 48 +++--- ggml/src/ggml-cuda/softmax.cu | 134 ++++++++++++----- ggml/src/ggml-cuda/softmax.cuh | 2 + ggml/src/ggml-cuda/unary.cu | 36 +++++ ggml/src/ggml-cuda/unary.cuh | 3 + ggml/src/ggml.c | 49 +++--- 17 files changed, 687 insertions(+), 310 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index a9c051cd5d6..1198dc1fd93 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1384,16 +1384,20 @@ extern "C" { float scale, float max_bias); - GGML_API struct ggml_tensor * ggml_soft_max_back( + GGML_API struct ggml_tensor * ggml_soft_max_ext_back( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b, + float scale, + float max_bias); // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_soft_max_back_inplace( + GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b, + float scale, + float max_bias); // rotary position embedding // if (mode & 1) - skip n_past elements (NOT SUPPORTED) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 8dc8226ac49..9a3bf9f2923 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml return true; } +// ops that return true for this function must not use restrict pointers for their backend implementations static bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { case GGML_OP_SCALE: @@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) { case GGML_OP_LOG: case GGML_OP_UNARY: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_SILU_BACK: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: return true; default: diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 8bf5f781a59..8040e20fe9d 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -6691,20 +6691,20 @@ static void ggml_compute_forward_silu_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * grad = dst->src[1]; + const struct ggml_tensor * grad = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; assert(ggml_is_contiguous_1(grad)); - assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(src1)); assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - assert(ggml_are_same_shape(src0, grad)); + assert(ggml_are_same_shape(src1, dst)); + assert(ggml_are_same_shape(src1, grad)); const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); + const int nc = src1->ne[0]; + const int nr = ggml_nrows(src1); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -6716,7 +6716,7 @@ static void ggml_compute_forward_silu_back_f32( for (int i1 = ir0; i1 < ir1; i1++) { ggml_vec_silu_backward_f32(nc, (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1])), + (float *) ((char *) src1->data + i1*(src1->nb[1])), (float *) ((char *) grad->data + i1*(grad->nb[1]))); #ifndef NDEBUG @@ -6895,7 +6895,7 @@ static void ggml_compute_forward_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); - GGML_ASSERT(eps > 0.0f); + GGML_ASSERT(eps >= 0.0f); // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { @@ -6966,7 +6966,7 @@ static void ggml_compute_forward_rms_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); - GGML_ASSERT(eps > 0.0f); + GGML_ASSERT(eps >= 0.0f); // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { @@ -7018,12 +7018,13 @@ static void ggml_compute_forward_rms_norm_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output + const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1)); GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); const int ith = params->ith; const int nth = params->nth; @@ -7042,8 +7043,8 @@ static void ggml_compute_forward_rms_norm_back_f32( const int64_t i12 = i02; const int64_t i13 = i03; - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); + const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); ggml_float sum_xx = 0.0; ggml_float sum_xdz = 0.0; @@ -7066,9 +7067,9 @@ static void ggml_compute_forward_rms_norm_back_f32( { // z = rms_norm(x) // - // rms_norm(src0) = + // rms_norm(src1) = // scale( - // src0, + // src1, // div( // 1, // sqrt( @@ -7076,13 +7077,13 @@ static void ggml_compute_forward_rms_norm_back_f32( // scale( // sum( // sqr( - // src0)), + // src1)), // (1.0/N)), // eps)))); // postorder: // ## op args grad - // 00 param src0 grad[#00] + // 00 param src1 grad[#00] // 01 const 1 // 02 sqr (#00) grad[#02] // 03 sum (#02) grad[#03] @@ -7159,6 +7160,7 @@ static void ggml_compute_forward_rms_norm_back_f32( // dx := scale(dx, rrms) float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps) ggml_vec_cpy_f32 (ne00, dx, x); // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); @@ -7750,12 +7752,13 @@ static void ggml_compute_forward_out_prod_f32( const int ith = params->ith; const int nth = params->nth; - GGML_ASSERT(ne0 == ne00); - GGML_ASSERT(ne1 == ne10); - GGML_ASSERT(ne2 == ne02); - GGML_ASSERT(ne02 == ne12); - GGML_ASSERT(ne3 == ne13); - GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + GGML_ASSERT(ne2 % ne02 == 0); + GGML_ASSERT(ne3 % ne03 == 0); // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == sizeof(float)); @@ -7797,6 +7800,10 @@ static void ggml_compute_forward_out_prod_f32( const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32); const int64_t blck_1 = 16; + // dps == dst per src0, used for group query attention + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + for (int64_t bir = ir0; bir < ir1; bir += blck_1) { const int64_t bir1 = MIN(bir + blck_1, ir1); for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { @@ -7807,8 +7814,8 @@ static void ggml_compute_forward_out_prod_f32( const int64_t i2 = (ir - i3*ne2*ne1)/ne1; const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); - const int64_t i02 = i2; - const int64_t i03 = i3; + const int64_t i02 = i2 / dps2; + const int64_t i03 = i3 / dps3; //const int64_t i10 = i1; const int64_t i12 = i2; @@ -8906,9 +8913,9 @@ static void ggml_compute_forward_soft_max( } -// ggml_compute_forward_soft_max_back +// ggml_compute_forward_soft_max_ext_back -static void ggml_compute_forward_soft_max_back_f32( +static void ggml_compute_forward_soft_max_ext_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -8921,6 +8928,14 @@ static void ggml_compute_forward_soft_max_back_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src1, dst)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + GGML_ASSERT(max_bias == 0.0f); + // TODO: handle transposed/permuted matrices const int ith = params->ith; @@ -8969,10 +8984,11 @@ static void ggml_compute_forward_soft_max_back_f32( // linear runtime, no additional memory float dot_y_dy = 0; - ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1); - ggml_vec_cpy_f32 (nc, dx, dy); - ggml_vec_acc1_f32(nc, dx, -dot_y_dy); - ggml_vec_mul_f32 (nc, dx, dx, y); + ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1); + ggml_vec_cpy_f32 (nc, dx, dy); + ggml_vec_acc1_f32 (nc, dx, -dot_y_dy); + ggml_vec_mul_f32 (nc, dx, dx, y); + ggml_vec_scale_f32(nc, dx, scale); #ifndef NDEBUG for (int i = 0; i < nc; ++i) { @@ -8983,7 +8999,7 @@ static void ggml_compute_forward_soft_max_back_f32( } } -static void ggml_compute_forward_soft_max_back( +static void ggml_compute_forward_soft_max_ext_back( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -8992,7 +9008,7 @@ static void ggml_compute_forward_soft_max_back( switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_soft_max_back_f32(params, dst); + ggml_compute_forward_soft_max_ext_back_f32(params, dst); } break; default: { @@ -9985,9 +10001,10 @@ static void ggml_compute_forward_im2col_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output + const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -10009,11 +10026,11 @@ static void ggml_compute_forward_im2col_back_f32( const int64_t IH = is_2D ? ne1 : 1; const int64_t IW = ne0; - const int64_t KH = is_2D ? ne01 : 1; - const int64_t KW = ne00; + const int64_t KH = is_2D ? ne11 : 1; + const int64_t KW = ne10; - const int64_t OH = is_2D ? ne12 : 1; - const int64_t OW = ne11; + const int64_t OH = is_2D ? ne02 : 1; + const int64_t OW = ne01; int ofs0 = is_2D ? nb3 : nb2; int ofs1 = is_2D ? nb2 : nb1; @@ -10059,9 +10076,9 @@ static void ggml_compute_forward_im2col_back_f32( continue; } - const float * const src_data = (const float *) src1->data + const float * const grad_in = (const float *) src0->data + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - grad += src_data[iic*(KH*KW) + ikh*KW + ikw]; + grad += grad_in[iic*(KH*KW) + ikh*KW + ikw]; } } float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW] @@ -12484,22 +12501,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * opt0 = dst->src[2]; + const struct ggml_tensor * grad = dst->src[0]; // gradient of forward pass output + const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass + const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_contiguous(opt0)); - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous(src0f)); + GGML_ASSERT(ggml_is_contiguous(src1f)); + GGML_ASSERT(ggml_is_contiguous(grad)); + GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst)); const int64_t ith = params->ith; const int64_t nth = params->nth; // TODO: handle transposed/permuted matrices - const int64_t nc = src0->ne[0]; - const int64_t nr = ggml_nrows(src0); + const int64_t nc = src0f->ne[0]; + const int64_t nr = ggml_nrows(src0f); // rows per thread const int64_t dr = (nr + nth - 1)/nth; @@ -12508,12 +12525,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( const int64_t ir0 = dr*ith; const int64_t ir1 = MIN(ir0 + dr, nr); - const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr; + const float d_by_nr = ((const float *) grad->data)[0] / (float) nr; for (int64_t i1 = ir0; i1 < ir1; i1++) { - float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); - float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); - float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); + float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); + const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]); + const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]); #ifndef NDEBUG for (int64_t i = 0; i < nc; ++i) { @@ -12526,11 +12543,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( // soft_max float max = -INFINITY; ggml_vec_max_f32(nc, &max, s0); - ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max); + const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max); assert(sum > 0.0); ggml_vec_scale_f32(nc, ds0, 1.0/sum); - // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr + // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr ggml_vec_sub_f32(nc, ds0, ds0, s1); ggml_vec_scale_f32(nc, ds0, d_by_nr); @@ -12827,7 +12844,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_SOFT_MAX_BACK: { - ggml_compute_forward_soft_max_back(params, tensor); + ggml_compute_forward_soft_max_ext_back(params, tensor); } break; case GGML_OP_ROPE: { diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 5c47ceb7314..35a1c876c86 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -403,6 +403,16 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float case GGML_OP_MUL_MAT: return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type; + case GGML_OP_SOFT_MAX_BACK: { + if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) { + return false; + } + float max_bias = 0.0f; + + memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); + + return max_bias == 0.0f; + } case GGML_OP_IM2COL_BACK: return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; case GGML_OP_OUT_PROD: diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu index ed09406a88b..27599a2b038 100644 --- a/ggml/src/ggml-cuda/cross-entropy-loss.cu +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cu @@ -5,95 +5,89 @@ #include #include -static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) { - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE; - - const int ne_tmp = WARP_SIZE*nclasses; - - extern __shared__ float tmp_all[]; - float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp; - float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp; - - // Each warp first loads ne_tmp logits/labels into shared memory: - for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) { - const int ig = i0*nclasses + i; // ig == i global - - tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f; - tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f; - } +template +static __global__ void cross_entropy_loss_f32( + const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) { + extern __shared__ float tmp[]; - // Each thread in the warp then calculates the cross entropy loss for a single row. - // TODO: pad in order to avoid shared memory bank conflicts. + logits += int64_t(blockIdx.x)*nclasses; + labels += int64_t(blockIdx.x)*nclasses; // Find maximum for softmax: - float max = -INFINITY; - for (int i = 0; i < nclasses; ++i) { - max = fmaxf(max, tmp_logits[lane_id*nclasses + i]); + float max_logit = -INFINITY; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float val = logits[i]; + max_logit = fmaxf(max_logit, val); + + if (use_shared) { + tmp[i] = val; + } } + max_logit = warp_reduce_max(max_logit); // Calculate log(softmax(logits)) which is just logits - max: float sum = 0.0f; - for (int i = 0; i < nclasses; ++i) { - float val = tmp_logits[lane_id*nclasses + i] - max; - sum += expf(val); - tmp_logits[lane_id*nclasses + i] = val; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float logit_i = use_shared ? tmp[i] : logits[i]; + sum += expf(logit_i - max_logit); } + sum = warp_reduce_sum(sum); sum = logf(sum); // log(exp(logits - max) / sum) = (logits - max) - log(sum) float loss = 0.0f; - for (int i = 0; i < nclasses; ++i) { - loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i]; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float logit_i = use_shared ? tmp[i] : logits[i]; + loss += (logit_i - max_logit - sum) * labels[i]; } loss = -warp_reduce_sum(loss) / (float)k; - __syncthreads(); - - if (lane_id == 0) { - tmp_all[warp_id] = loss; - } - - __syncthreads(); - - if (warp_id != 0) { - return; - } - - loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f; - loss = warp_reduce_sum(loss); - - if (lane_id != 0) { + if (threadIdx.x != 0) { return; } dst[blockIdx.x] = loss; } -static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) { +template +static __global__ void cross_entropy_loss_back_f32( + const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels, + float * __restrict__ dst, const int nclasses) { extern __shared__ float tmp[]; + logits += int64_t(blockIdx.x)*nclasses; + labels += int64_t(blockIdx.x)*nclasses; + dst += int64_t(blockIdx.x)*nclasses; + float maxval = -INFINITY; for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { - const float val = logits[blockIdx.x*nclasses + i]; + const float val = logits[i]; maxval = fmaxf(maxval, val); - tmp[i] = val; + + if (use_shared) { + tmp[i] = val; + } } maxval = warp_reduce_max(maxval); float sum = 0.0f; for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { - const float val = expf(tmp[i] - maxval); + const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval); sum += val; - tmp[i] = val; + + if (use_shared) { + tmp[i] = val; + } else { + dst[i] = val; + } } sum = warp_reduce_sum(sum); const float sm_scale = 1.0f/sum; - const float d_by_nrows = *loss/gridDim.x; + const float d_by_nrows = *grad/gridDim.x; for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { - dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows; + const float val = use_shared ? tmp[i] : dst[i]; + dst[i] = (val*sm_scale - labels[i])*d_by_nrows; } } @@ -119,48 +113,77 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_pool & pool = ctx.pool(); cudaStream_t stream = ctx.stream(); - const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); - const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); - const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float); + const dim3 blocks_dim(WARP_SIZE, 1, 1); + const dim3 blocks_num(nrows, 1, 1); + const size_t nbytes_shared = ne00*sizeof(float); + + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; ggml_cuda_pool_alloc dst_tmp(pool, blocks_num.x); - cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + if (nbytes_shared <= smpbo) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + } else { + cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + } + CUDA_CHECK(cudaGetLastError()); // Combine results from individual blocks: sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream); } void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - const ggml_tensor * opt0 = dst->src[2]; - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(opt0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_contiguous(opt0)); + const ggml_tensor * grad = dst->src[0]; + const ggml_tensor * src0f = dst->src[1]; + const ggml_tensor * src1f = dst->src[2]; + + GGML_ASSERT(src0f->type == GGML_TYPE_F32); + GGML_ASSERT(src1f->type == GGML_TYPE_F32); + GGML_ASSERT( grad->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_scalar(grad)); + GGML_ASSERT(ggml_is_contiguous(src0f)); + GGML_ASSERT(ggml_is_contiguous(src1f)); GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src0f, src1f)); + GGML_ASSERT(ggml_are_same_shape(src0f, dst)); - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + const int64_t ne00 = src0f->ne[0]; + const int64_t nrows = ggml_nrows(src0f); - const float * src0_d = (const float *) src0->data; - const float * src1_d = (const float *) src1->data; - const float * opt0_d = (const float *) opt0->data; - float * dst_d = (float *) dst->data; + const float * grad_d = (const float *) grad->data; + const float * src0f_d = (const float *) src0f->data; + const float * src1f_d = (const float *) src1f->data; + float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); const dim3 blocks_dim(WARP_SIZE, 1, 1); const dim3 blocks_num(nrows, 1, 1); - const int shmem = ne00*sizeof(float); - - cross_entropy_loss_back_f32<<>>(src0_d, src1_d, opt0_d, dst_d, ne00); + const size_t nbytes_shared = ne00*sizeof(float); + + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + if (nbytes_shared <= smpbo) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); + } else { + cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); + } } diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 4c3703238cb..4cef53a98cf 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -3,15 +3,15 @@ template static __global__ void k_get_rows( - const void * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ - /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ - /*size_t s0,*/ size_t s1, size_t s2, size_t s3, - /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, - size_t s10, size_t s11, size_t s12/*, size_t s13*/) { + const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ + /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, + /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, + const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2; - const int i10 = blockDim.y*blockIdx.y + threadIdx.y; + const int i10 = blockDim.y*blockIdx.y + threadIdx.y; const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; @@ -22,10 +22,10 @@ static __global__ void k_get_rows( const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03; + const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; - const int ib = i00/qk; // block index - const int iqs = (i00%qk)/qr; // quant index + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index const int iybs = i00 - i00%qk; // dst block start index const int y_offset = qr == 1 ? 1 : qk/2; @@ -39,15 +39,15 @@ static __global__ void k_get_rows( template static __global__ void k_get_rows_float( - const src0_t * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ - /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ - /*size_t s0,*/ size_t s1, size_t s2, size_t s3, - /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, - size_t s10, size_t s11, size_t s12/*, size_t s13*/) { - - const int i00 = blockIdx.x*blockDim.x + threadIdx.x; - const int i10 = blockDim.y*blockIdx.y + threadIdx.y; + const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ + /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, + /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, + const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { + + const int i00 = blockIdx.x*blockDim.x + threadIdx.x; + const int i10 = blockDim.y*blockIdx.y + threadIdx.y; const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; @@ -58,14 +58,38 @@ static __global__ void k_get_rows_float( const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); + const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); dst_row[i00] = src0_row[i00]; } +template +static __global__ void k_get_rows_back_float( + const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) { + const int col = blockIdx.x*blockDim.x + threadIdx.x; + + if (col >= ncols) { + return; + } + + const int dst_row = blockIdx.y*blockDim.y + threadIdx.y; + + float sum = 0.0f; + + for (int64_t i = 0; i < nrows_grad; ++i) { + if (rows[i] != dst_row) { + continue; + } + sum += grad[i*ncols + col]; + } + + dst[dst_row*ncols + col] = sum; +} + template -static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { +static void get_rows_cuda( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { GGML_TENSOR_BINARY_OP_LOCALS @@ -87,22 +111,25 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg GGML_ASSERT(ne00 % 2 == 0); k_get_rows<<>>( - src0_dd, src1_dd, dst_dd, - ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ - /* s0,*/ s1, s2, s3, - /* nb00,*/ nb01, nb02, nb03, - s10, s11, s12/*, s13*/); + src0_dd, src1_dd, dst_dd, + ne00, /*ne01, ne02, ne03,*/ + /*ne10, ne11,*/ ne12, /*ne13,*/ + /* s0,*/ s1, s2, s3, + /* nb00,*/ nb01, nb02, nb03, + s10, s11, s12/*, s13*/); GGML_UNUSED(dst); } template -static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { +static void get_rows_cuda_float( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(ne13 == 1); + const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; const dim3 block_nums(block_num_x, ne10, ne11*ne12); @@ -119,12 +146,12 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr //const size_t s13 = nb13 / ggml_element_size(src1); k_get_rows_float<<>>( - src0_dd, src1_dd, dst_dd, - ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ - /* s0,*/ s1, s2, s3, - /* nb00,*/ nb01, nb02, nb03, - s10, s11, s12/*, s13*/); + src0_dd, src1_dd, dst_dd, + ne00, /*ne01, ne02, ne03,*/ + /*ne10, ne11,*/ ne12, /*ne13,*/ + /* s0,*/ s1, s2, s3, + /* nb00,*/ nb01, nb02, nb03, + s10, s11, s12/*, s13*/); GGML_UNUSED(dst); } @@ -132,42 +159,41 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); + const void * src0_d = (const void *) src0->data; + const int32_t * src1_d = (const int32_t *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); - GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); - - const int32_t * src1_i32 = (const int32_t *) src1_d; + GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); switch (src0->type) { case GGML_TYPE_F16: - get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream); + get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_F32: - get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q4_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q4_1: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q5_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q5_1: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q8_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; default: // TODO: k-quants @@ -175,3 +201,34 @@ void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { break; } } + +void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output + const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass + + GGML_TENSOR_BINARY_OP_LOCALS + + const float * src0_d = (const float *) src0->data; + const int32_t * src1_d = (const int32_t *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_ASSERT(ne02*ne03 == 1); + GGML_ASSERT(ne12*ne13 == 1); + GGML_ASSERT(ne2*ne3 == 1); + + const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1); + const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE; + const dim3 block_nums(block_num_x, ne1, 1); + + k_get_rows_back_float<<>>(src0_d, src1_d, dst_d, ne00, ne10); +} diff --git a/ggml/src/ggml-cuda/getrows.cuh b/ggml/src/ggml-cuda/getrows.cuh index bbf1302325c..a1ca643f1c5 100644 --- a/ggml/src/ggml-cuda/getrows.cuh +++ b/ggml/src/ggml-cuda/getrows.cuh @@ -1,5 +1,8 @@ #include "common.cuh" #define CUDA_GET_ROWS_BLOCK_SIZE 256 +#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256 void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 9118edc7282..7fd1fc85346 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2003,6 +2003,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GET_ROWS: ggml_cuda_op_get_rows(ctx, dst); break; + case GGML_OP_GET_ROWS_BACK: + ggml_cuda_op_get_rows_back(ctx, dst); + break; case GGML_OP_DUP: ggml_cuda_dup(ctx, dst); break; @@ -2091,9 +2094,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_LEAKY_RELU: ggml_cuda_op_leaky_relu(ctx, dst); break; + case GGML_OP_SILU_BACK: + ggml_cuda_op_silu_back(ctx, dst); + break; case GGML_OP_RMS_NORM: ggml_cuda_op_rms_norm(ctx, dst); break; + case GGML_OP_RMS_NORM_BACK: + ggml_cuda_op_rms_norm_back(ctx, dst); + break; case GGML_OP_MUL_MAT: if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]); @@ -2138,6 +2147,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SOFT_MAX: ggml_cuda_op_soft_max(ctx, dst); break; + case GGML_OP_SOFT_MAX_BACK: + ggml_cuda_op_soft_max_back(ctx, dst); + break; case GGML_OP_ROPE: ggml_cuda_op_rope(ctx, dst); break; @@ -2912,7 +2924,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } } break; case GGML_OP_OUT_PROD: - return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { @@ -2928,6 +2940,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return false; } } break; + case GGML_OP_GET_ROWS_BACK: + { + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; + } break; case GGML_OP_CPY: { ggml_type src0_type = op->src[0]->type; @@ -3001,8 +3017,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } return false; } break; + case GGML_OP_SILU_BACK: + return ggml_is_contiguous(op->src[0]); + break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0; break; case GGML_OP_NONE: @@ -3027,6 +3047,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: return true; + case GGML_OP_SOFT_MAX_BACK: { + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); + return max_bias == 0.0f; + } case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: { const size_t ts = ggml_type_size(op->src[0]->type); diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 133e219f0ae..d991ec97281 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -5,20 +5,24 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; - float2 mean_var = make_float2(0.f, 0.f); + x += int64_t(row)*ncols; + dst += int64_t(row)*ncols; + + float2 mean_var = make_float2(0.0f, 0.0f); for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row*ncols + col]; + const float xi = x[col]; mean_var.x += xi; mean_var.y += xi * xi; } // sum up partial sums mean_var = warp_reduce_sum(mean_var); - if (block_size > WARP_SIZE) { + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); __shared__ float2 s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = mean_var; } @@ -32,7 +36,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c const float inv_std = rsqrtf(var + eps); for (int col = tid; col < ncols; col += block_size) { - dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std; + dst[col] = (x[col] - mean) * inv_std; } } @@ -40,14 +44,8 @@ template static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) { // blockIdx.x: num_groups idx // threadIdx.x: block_size idx - int start = blockIdx.x * group_size; - int end = start + group_size; - - start += threadIdx.x; - - if (end >= ne_elements) { - end = ne_elements; - } + const int start = blockIdx.x*group_size + threadIdx.x; + const int end = min(blockIdx.x*group_size + group_size, ne_elements); float tmp = 0.0f; // partial sum for thread in warp @@ -56,10 +54,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); __shared__ float s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } @@ -68,11 +67,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp = warp_reduce_sum(tmp); } - float mean = tmp / group_size; + const float mean = tmp / group_size; tmp = 0.0f; for (int j = start; j < end; j += block_size) { - float xi = x[j] - mean; + const float xi = x[j] - mean; dst[j] = xi; tmp += xi * xi; } @@ -80,8 +79,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp = warp_reduce_sum(tmp); if (block_size > WARP_SIZE) { __shared__ float s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } @@ -90,8 +89,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp = warp_reduce_sum(tmp); } - float variance = tmp / group_size; - float scale = rsqrtf(variance + eps); + const float variance = tmp / group_size; + const float scale = rsqrtf(variance + eps); for (int j = start; j < end; j += block_size) { dst[j] *= scale; } @@ -102,19 +101,23 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; + x += int64_t(row)*ncols; + dst += int64_t(row)*ncols; + float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row*ncols + col]; + const float xi = x[col]; tmp += xi * xi; } // sum up partial sums tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); __shared__ float s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } @@ -127,12 +130,63 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol const float scale = rsqrtf(mean + eps); for (int col = tid; col < ncols; col += block_size) { - dst[row*ncols + col] = scale * x[row*ncols + col]; + dst[col] = scale * x[col]; + } +} + +template +static __global__ void rms_norm_back_f32( + const float * grad, const float * xf, float * dst, const int ncols, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + grad += int64_t(row)*ncols; + xf += int64_t(row)*ncols; + dst += int64_t(row)*ncols; + + float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass + float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs + + for (int col = tid; col < ncols; col += block_size) { + const float xfi = xf[col]; + sum_xx += xfi * xfi; + sum_xg += xfi * grad[col]; + } + + // sum up partial sums + sum_xx = warp_reduce_sum(sum_xx); + sum_xg = warp_reduce_sum(sum_xg); + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); + __shared__ float s_sum_xx[32]; + __shared__ float s_sum_xg[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum_xx[warp_id] = sum_xx; + s_sum_xg[warp_id] = sum_xg; + } + __syncthreads(); + + sum_xx = s_sum_xx[lane_id]; + sum_xx = warp_reduce_sum(sum_xx); + + sum_xg = s_sum_xg[lane_id]; + sum_xg = warp_reduce_sum(sum_xg); + } + + const float mean_eps = sum_xx / ncols + eps; + const float sum_eps = sum_xx + ncols*eps; + + const float scale_grad = rsqrtf(mean_eps); + const float scale_x = -scale_grad * sum_xg/sum_eps; + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = scale_grad*grad[col] + scale_x*xf[col]; } } static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { - GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); norm_f32<<>>(x, dst, ncols, eps); @@ -142,7 +196,8 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i } } -static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) { +static void group_norm_f32_cuda( + const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) { if (group_size < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); group_norm_f32<<>>(x, dst, group_size, ne_elements, eps); @@ -153,7 +208,6 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou } static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { - GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); rms_norm_f32<<>>(x, dst, ncols, eps); @@ -163,6 +217,16 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con } } +static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + rms_norm_back_f32<<>>(grad, xf, dst, ncols, eps); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_back_f32<1024><<>>(grad, xf, dst, ncols, eps); + } +} + void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -179,6 +243,7 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); } @@ -198,6 +263,7 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) float eps; memcpy(&eps, dst->op_params + 1, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream); @@ -219,6 +285,33 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); } + +void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * grad = dst->src[0]; // gradients + const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass + + const float * grad_d = (const float *) grad->data; + const float * src0f_d = (const float *) src0f->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(grad)); + + GGML_ASSERT( grad->type == GGML_TYPE_F32); + GGML_ASSERT(src0f->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0f->ne[0]; + const int64_t nrows = ggml_nrows(src0f); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); + + rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream); +} diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index 431a8f74d55..d63d34380b0 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu index 619cfdcb589..73e3e2c47f2 100644 --- a/ggml/src/ggml-cuda/out-prod.cu +++ b/ggml/src/ggml-cuda/out-prod.cu @@ -11,16 +11,15 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ne01 == ne11); GGML_ASSERT(ne0 == ne00); GGML_ASSERT(ne1 == ne10); - GGML_ASSERT(ne2 == src0->ne[2]); + GGML_ASSERT(ne2 % src0->ne[2] == 0); + GGML_ASSERT(ne3 % src0->ne[3] == 0); + GGML_ASSERT(ne2 == src1->ne[2]); - GGML_ASSERT(ne3 == src0->ne[3]); GGML_ASSERT(ne3 == src1->ne[3]); const float * src0_d = (const float *) src0->data; @@ -33,8 +32,6 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const float alpha = 1.0f; const float beta = 0.0f; - GGML_ASSERT(ne2 == 1); - GGML_ASSERT(ne3 == 1); CUBLAS_CHECK(cublasSetStream(handle, stream)); const bool src1_T = ggml_is_transposed(src1); @@ -42,10 +39,27 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float)); - CUBLAS_CHECK( - cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, - ne0, ne1, ne01, - &alpha, src0_d, ne00, - src1_d, ldb, - &beta, dst_d, ne0)); + // data strides in dimensions 2/3 + const size_t s02 = nb02 / sizeof(float); + const size_t s03 = nb03 / sizeof(float); + const size_t s12 = nb12 / sizeof(float); + const size_t s13 = nb13 / sizeof(float); + const size_t s2 = nb2 / sizeof(float); + const size_t s3 = nb3 / sizeof(float); + + // dps == dst per src0, used for group query attention + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + + // TODO batched matrix multiplication + for (int64_t i3 = 0; i3 < ne3; ++i3) { + for (int64_t i2 = 0; i2 < ne2; ++i2) { + CUBLAS_CHECK( + cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, + ne0, ne1, ne01, + &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00, + src1_d + i3 *s13 + i2 *s12, ldb, + &beta, dst_d + i3 *s3 + i2 *s2, ne0)); + } + } } diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index e1912fee1f9..18f691b2d31 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -39,9 +39,9 @@ static __device__ void rope_yarn( template static __global__ void rope_norm( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, - const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { @@ -83,9 +83,9 @@ static __global__ void rope_norm( template static __global__ void rope_neox( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, - const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { @@ -127,9 +127,9 @@ static __global__ void rope_neox( template static __global__ void rope_multi( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, - const int n_dims, const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, + const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { @@ -187,9 +187,9 @@ static __global__ void rope_multi( template static __global__ void rope_vision( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, - const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, - const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float * freq_factors, const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { @@ -234,9 +234,9 @@ static __global__ void rope_vision( template static void rope_norm_cuda( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -257,9 +257,9 @@ static void rope_norm_cuda( template static void rope_neox_cuda( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -280,9 +280,9 @@ static void rope_neox_cuda( template static void rope_multi_cuda( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -303,9 +303,9 @@ static void rope_multi_cuda( template static void rope_vision_cuda( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index c24abae1f13..9aa4b848938 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -1,5 +1,7 @@ #include "common.cuh" +#include "ggml.h" #include "softmax.cuh" +#include template static __device__ __forceinline__ float t2f32(T val) { @@ -11,14 +13,20 @@ __device__ float __forceinline__ t2f32(half val) { return __half2float(val); } -template -static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +template +static __global__ void soft_max_f32( + const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, + const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; const int rowx = blockIdx.x; const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension + x += int64_t(rowx)*ncols; + mask += int64_t(rowy)*ncols * (mask != nullptr); + dst += int64_t(rowx)*ncols; + const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; const int warp_id = threadIdx.x / WARP_SIZE; @@ -29,7 +37,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst extern __shared__ float data_soft_max_f32[]; float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication // shared memory buffer to cache values between iterations: - float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols; + float * vals = use_shared ? buf_iw + WARP_SIZE : dst; float max_val = -INFINITY; @@ -41,10 +49,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst break; } - const int64_t ix = (int64_t)rowx*ncols + col; - const int64_t iy = (int64_t)rowy*ncols + col; - - const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f); + const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -110,8 +115,29 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst return; } - const int64_t idst = (int64_t)rowx*ncols + col; - dst[idst] = vals[col] * inv_sum; + dst[col] = vals[col] * inv_sum; + } +} + +static __global__ void soft_max_back_f32( + const float * grad, const float * dstf, float * dst, const int ncols, const float scale) { + const int tid = threadIdx.x; + const int rowx = blockIdx.x; + + grad += int64_t(rowx)*ncols; + dstf += int64_t(rowx)*ncols; + dst += int64_t(rowx)*ncols; + + float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dgf_dot += dstf[col]*grad[col]; + } + + dgf_dot = warp_reduce_sum(dgf_dot); + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dst[col] = scale * (grad[col] - dgf_dot) * dstf[col]; } } @@ -121,7 +147,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); const dim3 block_nums(nrows_x, 1, 1); - const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); + const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); const uint32_t n_head = nrows_x/nrows_y; @@ -131,50 +157,68 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); // FIXME: this limit could be raised by ~2-4x on Ampere or newer - if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { + if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { switch (ncols_x) { case 32: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 64: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 128: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 256: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 512: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 1024: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 2048: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 4096: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; default: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; } } else { - const size_t shmem_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); } } +static void soft_max_back_f32_cuda( + const float * grad, const float * dstf, float * dst, + const int ncols, const int nrows, const float scale, cudaStream_t stream) { + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(nrows, 1, 1); + + soft_max_back_f32<<>>(grad, dstf, dst, ncols, scale); +} + void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const float * src0_d = (const float *)src0->data; - const void * src1_d = src1 ? (const void *)src1->data : nullptr; + const float * src0_d = (const float *) src0->data; + const void * src1_d = src1 ? (const void *) src1->data : nullptr; + float * dst_d = (float *) dst->data; - float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); @@ -189,18 +233,42 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float scale = 1.0f; float max_bias = 0.0f; - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); if (use_f16) { - const half * src1_dd = (const half *)src1_d; - - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } else { - const float * src1_dd = (const float *)src1_d; - - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } } + +void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // grad + const ggml_tensor * src1 = dst->src[1]; // forward pass output + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + GGML_ASSERT(max_bias == 0.0f); + + soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream); +} diff --git a/ggml/src/ggml-cuda/softmax.cuh b/ggml/src/ggml-cuda/softmax.cuh index 4ef4ff86c9c..93dfee835f6 100644 --- a/ggml/src/ggml-cuda/softmax.cuh +++ b/ggml/src/ggml-cuda/softmax.cuh @@ -3,3 +3,5 @@ #define CUDA_SOFT_MAX_BLOCK_SIZE 1024 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 81fc92202f2..6b21f407d80 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -51,6 +51,19 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) { dst[i] = x[i] / (1.0f + expf(-x[i])); } +static __global__ void silu_back_f32( + const float * grad, const float * xf, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + const float xfi = xf[i]; + const float s = 1.0f / (1.0f + expf(-xfi)); + dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s)); +} + static __global__ void tanh_f32(const float * x, float * dst, int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -173,6 +186,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_ silu_f32<<>>(x, dst, k); } +static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + silu_back_f32<<>>(grad, x, dst, k); +} + static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE; tanh_f32<<>>(x, dst, k); @@ -284,6 +302,24 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } +void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // input from forward pass + const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream); +} + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index c91936728ba..e7f62643a2a 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -4,6 +4,7 @@ #define CUDA_STEP_BLOCK_SIZE 256 #define CUDA_GELU_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256 +#define CUDA_SILU_BACK_BLOCK_SIZE 256 #define CUDA_TANH_BLOCK_SIZE 256 #define CUDA_RELU_BLOCK_SIZE 256 #define CUDA_SIGMOID_BLOCK_SIZE 256 @@ -23,6 +24,8 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 0a0dcc7e3b1..d83b1b8ad6b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3454,12 +3454,14 @@ struct ggml_tensor * ggml_soft_max_ext( return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } -// ggml_soft_max_back +// ggml_soft_max_ext_back -static struct ggml_tensor * ggml_soft_max_back_impl( +static struct ggml_tensor * ggml_soft_max_ext_back_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + float scale, + float max_bias, bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -3467,21 +3469,28 @@ static struct ggml_tensor * ggml_soft_max_back_impl( result->src[0] = a; result->src[1] = b; + memcpy((float *) result->op_params + 0, &scale, sizeof(float)); + memcpy((float *) result->op_params + 1, &max_bias, sizeof(float)); + return result; } -struct ggml_tensor * ggml_soft_max_back( +struct ggml_tensor * ggml_soft_max_ext_back( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_soft_max_back_impl(ctx, a, b, false); + struct ggml_tensor * b, + float scale, + float max_bias) { + return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false); } -struct ggml_tensor * ggml_soft_max_back_inplace( +struct ggml_tensor * ggml_soft_max_ext_back_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_soft_max_back_impl(ctx, a, b, true); + struct ggml_tensor * b, + float scale, + float max_bias) { + return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true); } // ggml_rope @@ -5080,10 +5089,10 @@ struct ggml_tensor * ggml_cross_entropy_loss_back( struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_tensor * c) { - GGML_ASSERT(ggml_are_same_shape(a, b)); - GGML_ASSERT(ggml_is_scalar(c)); + GGML_ASSERT(ggml_is_scalar(a)); + GGML_ASSERT(ggml_are_same_shape(b, c)); - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_dup_tensor(ctx, b); result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK; result->src[0] = a; @@ -5262,7 +5271,7 @@ static void ggml_sub_or_set( } static void ggml_compute_backward( - struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) { + struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) { struct ggml_tensor * tensor = cgraph->nodes[i]; struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor); @@ -5406,7 +5415,7 @@ static void ggml_compute_backward( if (src0_needs_grads) { float eps; memcpy(&eps, tensor->op_params, sizeof(float)); - ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps)); } } break; case GGML_OP_MUL_MAT: { @@ -5589,7 +5598,13 @@ static void ggml_compute_backward( } break; case GGML_OP_SOFT_MAX: { if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float)); + + ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias)); } GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented"); } break; @@ -5630,7 +5645,7 @@ static void ggml_compute_backward( const int32_t d1 = ggml_get_op_params_i32(tensor, 5); const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1; - ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D)); + ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D)); } } break; case GGML_OP_POOL_2D: { @@ -5673,7 +5688,7 @@ static void ggml_compute_backward( } break; case GGML_UNARY_OP_SILU: { if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0)); } } break; case GGML_UNARY_OP_EXP: { @@ -5690,7 +5705,7 @@ static void ggml_compute_backward( } break; case GGML_OP_CROSS_ENTROPY_LOSS: { if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1)); } GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); } break; From 62e24146200c8d0ae2473389f8d9bbd4efbfafb7 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 16 Jan 2025 15:16:39 -0600 Subject: [PATCH 10/64] vulkan: optimize coopmat2 q2_k dequant function (llama/11130) --- .../vulkan-shaders/dequant_funcs_cm2.comp | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 94b78598ea2..e768b8930c4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -101,19 +101,25 @@ layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_ block_q2_K block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 { + block_q2_K_packed16 block; +}; + float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { + decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); const f16vec2 d = bl.block.d; const uint idx = coordInBlock[1]; - const uint iqs = idx; - const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31 - const uint scalesi = iqs / 16; // 0..15 - const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6 + const uint scalesi = (idx & 0xF0) >> 4; // 0..15 + const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6 + + uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); + qs = (qs >> qsshift) & 0x0303; + qs = unpack8(qs)[idx & 1]; - uint32_t qs = bl.block.qs[qsi]; const uint scales = bl.block.scales[scalesi]; - float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4); + float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4); return ret; } From 09f3c6664870dad525448e8da43e92d8c8f6fd20 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 16 Jan 2025 15:23:49 -0600 Subject: [PATCH 11/64] vulkan: optimize coopmat2 q4_k/q5_k dequant functions. (llama/11206) Do masking on whole dwords, fetch all scales at once. --- .../vulkan-shaders/dequant_funcs_cm2.comp | 80 +++++++++++-------- .../src/ggml-vulkan/vulkan-shaders/types.comp | 10 +++ 2 files changed, 57 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index e768b8930c4..175e31fa76e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -163,39 +163,47 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4 block_q4_K_packed16 block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 { + block_q4_K_packed128 block; +}; + float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); + decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl); const uint idx = coordInBlock[1]; const uint b = (idx & 0x20) >> 5; // 0,1 const uint is = (idx & 0xE0) >> 5; // 0..7 - const f16vec2 loadd = bl.block.d; + uvec4 v = bl128.block.q4k[0]; + + const f16vec2 loadd = unpackFloat2x16(v.x); uint32_t sc; uint32_t mbyte; - uint32_t scidx0 = (is < 4) ? is : (is + 4); - uint32_t scidx1 = (is < 4) ? is : (is - 4); - uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; - uint32_t scidxshift1 = (is < 4) ? 0 : 2; - uint32_t mbidx0 = is + 4; - uint32_t mbidx1 = (is < 4) ? is + 4 : is; - uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; - uint32_t mbidxshift0 = (is < 4) ? 0 : 4; - uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - uint32_t mbidxshift1 = (is < 4) ? 0 : 2; + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; - sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); - mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; const float16_t d = loadd.x * float16_t(sc); const float16_t m = loadd.y * float16_t(mbyte); uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); - qs = (qs >> (b * 4)) & 0x0F0F; - qs = unpack8(qs)[idx & 1]; + qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF; float16_t ret = d * float16_t(qs) - m; @@ -210,47 +218,53 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5 block_q5_K_packed16 block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 { + block_q5_K_packed128 block; +}; + float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); + decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl); const uint idx = coordInBlock[1]; const uint b = (idx & 0x20) >> 5; // 0,1 const uint is = (idx & 0xE0) >> 5; // 0..7 - const uint32_t hm = 0x0101 << is; + uvec4 v = bl128.block.q5k[0]; - const f16vec2 loadd = bl.block.d; + const f16vec2 loadd = unpackFloat2x16(v.x); uint32_t sc; uint32_t mbyte; - uint32_t scidx0 = (is < 4) ? is : (is + 4); - uint32_t scidx1 = (is < 4) ? is : (is - 4); - uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0; - uint32_t scidxshift1 = (is < 4) ? 0 : 2; - uint32_t mbidx0 = is + 4; - uint32_t mbidx1 = (is < 4) ? is + 4 : is; - uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0; - uint32_t mbidxshift0 = (is < 4) ? 0 : 4; - uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - uint32_t mbidxshift1 = (is < 4) ? 0 : 2; + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; - sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1)); - mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; const float16_t d = loadd.x * float16_t(sc); const float16_t m = loadd.y * float16_t(mbyte); uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); - qh = qh & hm; - qh = unpack8(qh)[idx & 1]; + qh = ((qh >> is) & 0x101) << 4; uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); qs = (qs >> (b * 4)) & 0x0F0F; - qs = unpack8(qs)[idx & 1]; + qs = unpack8(qs | qh)[idx & 1]; - float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m; + float16_t ret = d * (float16_t(qs)) - m; return ret; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index f12e61bbe10..1e35b665283 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -227,6 +227,11 @@ struct block_q4_K_packed32 uint32_t qs[QUANT_K_Q4_K/2/4]; }; +struct block_q4_K_packed128 +{ + uvec4 q4k[9]; +}; + #if defined(DATA_A_Q4_K) #define QUANT_K QUANT_K_Q4_K #define A_TYPE block_q4_K @@ -252,6 +257,11 @@ struct block_q5_K_packed16 uint16_t qs[QUANT_K_Q5_K/2/2]; }; +struct block_q5_K_packed128 +{ + uvec4 q5k[11]; +}; + #if defined(DATA_A_Q5_K) #define QUANT_K QUANT_K_Q5_K #define A_TYPE block_q5_K From 7183a1eb72bb4415e3b98a83884ac5d69f9eb6bc Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 16 Jan 2025 15:47:10 -0600 Subject: [PATCH 12/64] vulkan: support copy from f32 to q4_0/q4_1/q5_0/q5_1/q8_0/iq4_nl (llama/11166) * vulkan: support copy from f32 to q4_0/q4_1/q5_0/q5_1/q8_0/iq4_nl Shaders are based on cpy.cu. * vulkan: support copy from q4_0/q4_1/q5_0/q5_1/q8_0/iq4_nl to f32 * ggml: copy q->f32 assumes some contiguity in the destination --- ggml/src/ggml-cpu/ggml-cpu.c | 55 ++++ ggml/src/ggml-vulkan/ggml-vulkan.cpp | 77 +++++- .../vulkan-shaders/copy_from_quant.comp | 51 ++++ .../vulkan-shaders/copy_to_quant.comp | 237 ++++++++++++++++++ .../vulkan-shaders/generic_unary_head.comp | 20 ++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 5 + 6 files changed, 440 insertions(+), 5 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 8040e20fe9d..cd3f815b85f 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -3967,6 +3967,57 @@ static void ggml_compute_forward_dup_bytes( } } +static void ggml_compute_forward_dup_q( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const enum ggml_type type = src0->type; + ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float; + + size_t qk = ggml_blck_size(type); + const int64_t nr = ggml_nelements(src1) / qk; + + // destination must be contiguous in the first dimension + GGML_ASSERT(nb10 == ggml_type_size(dst->type)); + // must either have first dimension large enough to hold a row, or fully contiguous + GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + + uint32_t i = ir * qk; + + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + + dequantize_row_q( + (const void *) ((char *) src0->data + x_offset), + (float *) ((char *) dst->data + dst_offset), qk); + } +} + static void ggml_compute_forward_dup( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -3993,6 +4044,10 @@ static void ggml_compute_forward_dup( } break; default: { + if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) { + ggml_compute_forward_dup_q(params, dst); + break; + } GGML_ABORT("fatal error"); } } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 649146d7b45..8e3e9149575 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -228,6 +228,8 @@ struct vk_device_struct { vk_pipeline pipeline_repeat_f32; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16; + vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; + vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; @@ -1965,6 +1967,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); @@ -3689,6 +3705,33 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f16_f16; } } + if (src->type == GGML_TYPE_F32) { + switch (to) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return ctx->device->pipeline_cpy_f32_quant[to]; + default: + break; + } + } + + if (to == GGML_TYPE_F32) { + switch (src->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return ctx->device->pipeline_cpy_quant_f32[src->type]; + default: + break; + } + } std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; GGML_ABORT("fatal error"); @@ -5160,7 +5203,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")"); - GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT + GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT GGML_ASSERT(dst->buffer != nullptr); const uint64_t ne00 = src0->ne[0]; @@ -7905,12 +7948,36 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm { ggml_type src0_type = op->src[0]->type; ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type; - if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { - return true; + + if (src0_type == GGML_TYPE_F32) { + switch (src1_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + break; + } } - if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { - return true; + if (src1_type == GGML_TYPE_F32) { + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + break; + } } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { return true; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp new file mode 100644 index 00000000000..c09bf496bf3 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" +#include "dequant_funcs.comp" + +#if defined(DATA_A_IQ4_NL) +// 16 invocations needed for init_iq4nl_shmem +layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#else +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +#endif + +void main() { +#if defined(DATA_A_IQ4_NL) + init_iq4nl_shmem(); + if (gl_LocalInvocationIndex.x != 0) { + return; + } +#endif + + const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint dst_idx = get_doffset() + dst_idx(idx); + uint src_idx = src0_idx_quant(idx, QUANT_K); + + const uint a_offset = 0; + const uint ib = src_idx; + const vec2 dm = get_dm(ib, a_offset); + + [[unroll]] for (int j = 0; j < QUANT_K; j += 4) { + vec4 v = dequantize4(ib, j / QUANT_R, a_offset); + v = v * dm.x + vec4(dm.y); + +#if QUANT_R == 2 + data_d[dst_idx + j/2 + 0] = v[0]; + data_d[dst_idx + j/2 + QUANT_K/2 + 0] = v[1]; + data_d[dst_idx + j/2 + 1] = v[2]; + data_d[dst_idx + j/2 + QUANT_K/2 + 1] = v[3]; +#else + data_d[dst_idx + j + 0] = v[0]; + data_d[dst_idx + j + 1] = v[1]; + data_d[dst_idx + j + 2] = v[2]; + data_d[dst_idx + j + 3] = v[3]; +#endif + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp new file mode 100644 index 00000000000..ccf5b980a8a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -0,0 +1,237 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +#if defined(DATA_A_IQ4_NL) +// 16 invocations needed for init_iq4nl_shmem +layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#else +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +#endif + +layout (binding = 0) readonly buffer S {float data_s[];}; +layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];}; + +#if defined(DATA_A_Q4_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + const float d = vmax / -8; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id; + + const uint xi0 = min(15, int(x0 + 8.5)); + const uint xi1 = min(15, int(x1 + 8.5)); + + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + } +} +#endif + +#if defined(DATA_A_Q4_1) +void quantize(uint dst_idx, uint src_idx) +{ + float vmin = 1.0/0.0; + float vmax = -vmin; + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) { + const float v = data_s[src_idx + j]; + + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + data_q[dst_idx].m = float16_t(vmin); + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) { + const float x0 = (data_s[src_idx + 0 + j] - vmin)*id; + const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id; + + const uint xi0 = min(15, int(x0 + 0.5)); + const uint xi1 = min(15, int(x1 + 0.5)); + + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + } +} +#endif + +#if defined(DATA_A_Q5_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + const float d = vmax / -16; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + uint32_t qh = 0; + [[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id; + + const uint xi0 = min(31, int(x0 + 16.5)); + const uint xi1 = min(31, int(x1 + 16.5)); + + data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4)); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2); + } + data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF); + data_q[dst_idx].qh[1] = uint16_t(qh >> 16); +} +#endif + +#if defined(DATA_A_Q5_1) +void quantize(uint dst_idx, uint src_idx) +{ + float min = data_s[src_idx + 0]; + float max = min; + + [[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) { + const float v = data_s[src_idx + j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = (d != 0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + data_q[dst_idx].m = float16_t(min); + + uint32_t qh = 0; + [[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) { + const float x0 = (data_s[src_idx + 0 + j] - min)*id; + const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id; + + const uint xi0 = uint(x0 + 0.5); + const uint xi1 = uint(x1 + 0.5); + + data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4)); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2); + } + data_q[dst_idx].qh = qh; +} +#endif + +#if defined(DATA_A_Q8_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; // absolute max + + [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) { + const float v = data_s[src_idx + j]; + amax = max(amax, abs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) { + const float x0 = data_s[src_idx + j]*id; + + data_q[dst_idx].qs[j] = int8_t(round(x0)); + } +} +#endif + +#if defined(DATA_A_IQ4_NL) +uint best_index(float x) { + if (x <= kvalues_iq4nl[0]) return 0; + if (x >= kvalues_iq4nl[15]) return 15; + int ml = 0, mu = 15; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav; + } + return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu; +} + +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + float d = vmax / kvalues_iq4nl[0]; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + float sumqx = 0, sumq2 = 0; + [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id; + const uint xi0 = best_index(x0); + const uint xi1 = best_index(x1); + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + const float v0 = kvalues_iq4nl[xi0]; + const float v1 = kvalues_iq4nl[xi1]; + const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j]; + const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j]; + sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + } + + data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d); + +} +#endif + +void main() { +#if defined(DATA_A_IQ4_NL) + init_iq4nl_shmem(); + if (gl_LocalInvocationIndex.x != 0) { + return; + } +#endif + + const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint dst_idx = dst_idx_quant(idx, QUANT_K); + uint src_idx = get_aoffset() + src0_idx(idx); + + quantize(dst_idx, src_idx); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp index 68d1bc9f145..8dc9d360d52 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp @@ -54,3 +54,23 @@ uint dst_idx(uint idx) { const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; } + +uint src0_idx_quant(uint idx, uint qk) { + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + (i00/qk)*p.nb00; +} + +uint dst_idx_quant(uint idx, uint qk) { + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + (i10/qk)*p.nb10; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 2438399174d..b7890ef1e74 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -417,6 +417,11 @@ void process_shaders() { string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + } + string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); From fdc21fc87b9ad152d76bede79a98313c4ceb4ac2 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Fri, 17 Jan 2025 10:57:09 +0200 Subject: [PATCH 13/64] rpc : early register backend devices (llama/11262) Early register RPC devices and do not propagate RPC specifics in the llama model structures. ref: #10609 --- ggml/include/ggml-backend.h | 2 ++ ggml/src/ggml-backend-impl.h | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 7221a083092..fc9571c82c9 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -203,6 +203,8 @@ extern "C" { // Backend registry // + GGML_API void ggml_backend_device_register(ggml_backend_dev_t device); + // Backend (reg) enumeration GGML_API size_t ggml_backend_reg_count(void); GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index); diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index 36d72e95f02..d1c2d76d897 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -208,7 +208,6 @@ extern "C" { // Internal backend registry API GGML_API void ggml_backend_register(ggml_backend_reg_t reg); - GGML_API void ggml_backend_device_register(ggml_backend_dev_t device); // Add backend dynamic loading support to the backend From 668306ff2ba2c5f46c4851b370bd15962be3fa91 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 18 Jan 2025 02:26:50 -0600 Subject: [PATCH 14/64] vulkan: fix coopmat2 flash attention for non-contiguous inputs (llama/11281) Add code similar to mul_mm_cm2 to force alignment of strides, to avoid a performance regression. Add noncontiguous FA tests in test-backend-ops. Fixes #11268. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 43 ++++++++++++++++--- .../vulkan-shaders/flash_attn_cm2.comp | 20 +++++++++ 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8e3e9149575..437e9cdcc35 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -386,10 +386,13 @@ struct vk_flash_attn_push_constants { uint32_t nev3; uint32_t nem1; + uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t nb21; uint32_t nb22; uint32_t nb23; uint32_t nb31; @@ -4809,7 +4812,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } assert(pipelines); - bool aligned = (KV % pipelines[1]->align) == 0; + const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); + const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); + const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); + + bool aligned = (KV % pipelines[1]->align) == 0 && + // the "aligned" shader variant will forcibly align strides, for performance + (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; + vk_pipeline pipeline = pipelines[aligned]; assert(pipeline); @@ -4845,15 +4855,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (ctx->device->uma) { ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); - ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset); - ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset); - ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset); + ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset); + ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset); + ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset); Q_uma = d_Q != nullptr; K_uma = d_K != nullptr; V_uma = d_V != nullptr; D_uma = d_D != nullptr; if (mask) { - ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset); + ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset); M_uma = d_M != nullptr; } } @@ -4891,7 +4901,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } } - const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 }; + const vk_flash_attn_push_constants pc = { N, KV, + (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, + (uint32_t)neq2, (uint32_t)neq3, + (uint32_t)nek2, (uint32_t)nek3, + (uint32_t)nev2, (uint32_t)nev3, + nem1, + q_stride, (uint32_t)nbq2, (uint32_t)nbq3, + k_stride, (uint32_t)nbk2, (uint32_t)nbk3, + v_stride, (uint32_t)nbv2, (uint32_t)nbv3, + nbm1, + scale, max_bias, logit_softcap, + mask != nullptr, n_head_log2, m0, m1 }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, @@ -8668,6 +8689,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { ggml_tensor * src0 = tensor->src[0]; ggml_tensor * src1 = tensor->src[1]; ggml_tensor * src2 = tensor->src[2]; + ggml_tensor * src3 = tensor->src[3]; void * tensor_data = tensor->data; @@ -8730,6 +8752,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { if (src2 != nullptr) { std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << std::endl << "Result:" << std::endl; ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); @@ -8774,6 +8799,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { if (src2 != nullptr) { std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << std::endl << "Result:" << std::endl; ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); @@ -8796,6 +8824,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { if (src2 != nullptr) { std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << std::endl << "Result:" << std::endl; ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index c5be8131b30..ca3a59b8faa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -42,10 +42,13 @@ layout (push_constant) uniform parameter { uint32_t nev3; uint32_t nem1; + uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t nb21; uint32_t nb22; uint32_t nb23; uint32_t nb31; @@ -146,6 +149,23 @@ void main() { tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); + // nb?1 are already divided by the type size and are in units of elements + uint32_t q_stride = p.nb01; + uint32_t k_stride = p.nb11; + uint32_t v_stride = p.nb21; + // hint to the compiler that strides are aligned for the aligned variant of the shader + if (Clamp != gl_CooperativeMatrixClampModeConstantNV) + { + q_stride &= ~7; +#if !defined(BLOCK_SIZE) + k_stride &= ~7; + v_stride &= ~7; +#endif + } + tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); + tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); + tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); + coopmat Q; coopmat Qf16; From 90171055f390cf206199e969abf3917ec9da8c27 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 18 Jan 2025 16:18:15 +0200 Subject: [PATCH 15/64] cmake : add sanitizer flags for llama.cpp (llama/11279) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cmake : add sanitizer flags for llama.cpp ggml-ci * tests : fix compile warnings ggml-ci * cmake : move sanitizer flags to llama_add_compile_flags ggml-ci * cmake : move llama.cpp compile flags to top level lists ggml-ci * cmake : apply only sanitizer flags at top level ggml-ci * tests : fix gguf context use in same_tensor_data * gguf-test: tensor data comparison * dummy : trigger ggml-ci * unicode : silence gcc warnings ggml-ci * ci : use sanitizer builds only in Debug mode ggml-ci * cmake : add status messages [no ci] --------- Co-authored-by: Johannes Gäßler --- ggml/src/gguf.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index 655ed600a17..ab13669c567 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -648,6 +648,10 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par ok = ok && data != nullptr; + if (ok) { + ggml_set_name(data, "GGUF tensor data binary blob"); + } + // read the binary blob with the tensor data ok = ok && gr.read(data->data, ctx->size); From d507b4cebe085da9b99aaebff82b5282d64279a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Scipione?= Date: Sun, 19 Jan 2025 14:33:34 +0100 Subject: [PATCH 16/64] SYCL: Introducing memory host pool (llama/11251) * Implement host pool for matrix_info Creating a new memory pool on the host to store memory location for matrix_info needed to launch gemm_batch from oneMKL/oneMath. Removing complex support in gemm_batch since it is not used in llama.cpp * Remove unnecessary headers and cast * Reorder member variable to avoid warning on initialization * Formatting * Remove unused variable * Address PR review feedback - remove warning --------- Signed-off-by: nscipione --- ggml/src/ggml-sycl/common.hpp | 13 +++ ggml/src/ggml-sycl/dpct/helper.hpp | 135 +++++++++-------------------- ggml/src/ggml-sycl/ggml-sycl.cpp | 92 ++++++++++++++++++-- 3 files changed, 137 insertions(+), 103 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index e9500f3a168..abad847ca81 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -333,8 +333,12 @@ struct ggml_backend_sycl_context { // pool std::unique_ptr pools[GGML_SYCL_MAX_DEVICES]; + std::unique_ptr host_pools[GGML_SYCL_MAX_DEVICES]; + static std::unique_ptr new_pool_for_device(queue_ptr qptr, int device); + static std::unique_ptr new_pool_for_host(queue_ptr qptr, int device); + ggml_sycl_pool & pool(int device) { if (pools[device] == nullptr) { pools[device] = new_pool_for_device(stream(device,0), device); @@ -345,6 +349,15 @@ struct ggml_backend_sycl_context { ggml_sycl_pool & pool() { return pool(device); } + + ggml_sycl_pool & host_pool(int device) { + if (host_pools[device] == nullptr) { + host_pools[device] = new_pool_for_host(stream(device, 0), device); + } + return *host_pools[device]; + } + + ggml_sycl_pool & host_pool() { return host_pool(device); } }; // common device functions diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index e167948e7a3..c96395be613 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) { return device_type.str(); } +template struct matrix_info_t { + oneapi::mkl::transpose transpose_info[2]; + Ts value_info[2]; + std::int64_t size_info[3]; + std::int64_t ld_info[3]; + std::int64_t groupsize_info; +}; + namespace dpct { typedef sycl::queue *queue_ptr; @@ -1727,26 +1735,13 @@ namespace dpct }; template - inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void **a, int lda, - const void **b, int ldb, const void *beta, void **c, - int ldc, int batch_size) - { - struct matrix_info_t - { - oneapi::mkl::transpose transpose_info[2]; - Ts value_info[2]; - std::int64_t size_info[3]; - std::int64_t ld_info[3]; - std::int64_t groupsize_info; - }; - + inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, + int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b, + int ldb, const void * beta, void ** c, int ldc, int batch_size, + matrix_info_t * matrix_info) { Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - matrix_info_t *matrix_info = - (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); matrix_info->transpose_info[0] = a_trans; matrix_info->transpose_info[1] = b_trans; matrix_info->value_info[0] = alpha_value; @@ -1763,23 +1758,18 @@ namespace dpct sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( oneapi::mkl::backend_selector{ q }, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, - matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast(a), - matrix_info->ld_info, reinterpret_cast(b), matrix_info->ld_info + 1, - matrix_info->value_info + 1, reinterpret_cast(c), matrix_info->ld_info + 2, 1, - &(matrix_info->groupsize_info)); + matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), + reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), + matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), + reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); #else sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, - matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info, + matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), - matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast(c), - matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); + matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), + reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); #endif - - q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(e); - cgh.host_task([=] { std::free(matrix_info); }); }); } template @@ -2422,25 +2412,11 @@ namespace dpct /// \param [in] ldc Leading dimension of C. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a[], - library_data_t a_type, int lda, const void *b[], - library_data_t b_type, int ldb, const void *beta, - void *c[], library_data_t c_type, int ldc, - int batch_size, library_data_t scaling_type) - { - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) - { - scaling_type = library_data_t::complex_float; - } - else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) - { - scaling_type = library_data_t::complex_double; - } - + inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, + int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda, + const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[], + library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type, + matrix_info_t * matrix_info) { std::uint64_t key = detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); switch (key) @@ -2449,48 +2425,24 @@ namespace dpct library_data_t::real_float, library_data_t::real_float, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_double, library_data_t::real_double, library_data_t::real_double, library_data_t::real_double): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_half): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } #ifdef __INTEL_MKL__ @@ -2498,19 +2450,16 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } #endif @@ -2522,10 +2471,9 @@ namespace dpct dpct::get_value(reinterpret_cast(alpha), q); float beta_float = dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, - a, lda, b, ldb, &beta_float, c, ldc, - batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size, + matrix_info); break; } case detail::get_type_combination_id( @@ -2533,8 +2481,7 @@ namespace dpct library_data_t::real_float, library_data_t::real_float): { detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2542,8 +2489,7 @@ namespace dpct library_data_t::real_float, library_data_t::real_float): { detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2557,8 +2503,7 @@ namespace dpct sycl::half alpha_half(alpha_value); sycl::half beta_half(beta_value); detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, - batch_size); + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info); break; } default: diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 5272ca45423..ed4d8bb8bae 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1173,6 +1173,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { } }; +struct ggml_sycl_pool_host : public ggml_sycl_pool { + queue_ptr qptr; + int device; + + inline static int counter{ 0 }; + + struct ggml_sycl_buffer { + void * ptr = nullptr; + size_t size = 0; + }; + + // Set arbitrarly to 64 + static constexpr int MAX_POOL_SIZE{ 64 }; + std::vector buffer_pool = std::vector(MAX_POOL_SIZE); + size_t pool_size = 0; + + explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {} + + ~ggml_sycl_pool_host() { + for (int i = 0; i < MAX_POOL_SIZE; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr != nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + b.ptr = nullptr; + pool_size -= b.size; + b.size = 0; + } + } + counter = 0; + } + + void * alloc(size_t size, size_t * actual_size) override { + if (counter == MAX_POOL_SIZE) { + ggml_sycl_buffer b = buffer_pool[0]; + void * ptr = b.ptr; + *actual_size = b.size; + counter = 1; + return ptr; + } + ggml_sycl_buffer & b = buffer_pool[counter]; + + if (b.ptr == nullptr) { + void * ptr; + + SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr))); + if (!ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size); + return nullptr; + } + pool_size += size; + *actual_size = size; + counter = counter + 1; + return ptr; + } else { + ++counter; + b.size = size; + return b.ptr; + } + } + + void free(void * ptr, size_t size) override { + // if the pool is not completed add the pointer to it in place of the first nullptr found. + // Otherwise do nothing, pointers will be freed once the pool is deallocated. + for (int i = 0; i < MAX_POOL_SIZE; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + } +}; + +std::unique_ptr ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) { + // return pool for the host to speed up memory management + return std::unique_ptr(new ggml_sycl_pool_host(qptr, device)); +} + std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { // TBD: NO VMM support // if (ggml_sycl_info().devices[device].vmm) { @@ -3363,6 +3442,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2*ne23); ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); + ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(), 1); sycl::range<3> block_dims(1, ne12, ne13); /* @@ -3391,14 +3471,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, }); } SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *main_stream, oneapi::mkl::transpose::trans, - oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, - (const void **)(ptrs_src.get() + 0 * ne23), - dpct::library_data_t::real_half, nb01 / nb00, - (const void **)(ptrs_src.get() + 1 * ne23), - dpct::library_data_t::real_half, nb11 / nb10, beta, - (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, - cu_compute_type))); + *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, + (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, + (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta, + (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get()))); } } catch (sycl::exception const &exc) { From 0dcada42d4af59d66bf7e48c4d9fac842e8285b8 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Mon, 20 Jan 2025 10:38:32 -0600 Subject: [PATCH 17/64] vulkan: fix coopmat2 validation failures (llama/11284) mul mat and flash attention shaders were loading f32 types directly into A/B matrices, which happens to work but is technically invalid usage. For FA, we can load it as an Accumulator matrix and convert and this is not in the inner loop and is cheap enough. For mul mat, it's more efficient to do this conversion in a separate pass and have the input(s) be f16. coopmat2 requires SPIR-V 1.6 (related using to LocalSizeId). LocalSizeId requires maintenance4 be enabled, and SPIR-V 1.6 requires Vulkan 1.3. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 65 +++++++++++-------- .../vulkan-shaders/flash_attn_cm2.comp | 2 +- .../vulkan-shaders/mul_mm_cm2.comp | 35 +++------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 7 +- 4 files changed, 53 insertions(+), 56 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 437e9cdcc35..78c2f5c45ad 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -29,8 +29,6 @@ #include "ggml-vulkan-shaders.hpp" -#define VK_API_VERSION VK_API_VERSION_1_2 - #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) #define VK_VENDOR_ID_AMD 0x1002 @@ -1614,11 +1612,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ - CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) - CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) - CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) @@ -1631,21 +1625,18 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) - - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) #undef CREATE_MM #undef CREATE_MM2 } else @@ -2287,6 +2278,14 @@ static vk_device ggml_vk_get_device(size_t idx) { } #endif + VkPhysicalDeviceMaintenance4Features maint4_features {}; + maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES; + if (maintenance4_support) { + last_struct->pNext = (VkBaseOutStructure *)&maint4_features; + last_struct = (VkBaseOutStructure *)&maint4_features; + device_extensions.push_back("VK_KHR_maintenance4"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->fp16 = device->fp16 && vk12_features.shaderFloat16; @@ -2662,7 +2661,14 @@ void ggml_vk_instance_init() { vk_instance_initialized = true; - vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION }; + uint32_t api_version = vk::enumerateInstanceVersion(); + + if (api_version < VK_API_VERSION_1_2) { + std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl; + GGML_ABORT("fatal error"); + } + + vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version }; const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions); @@ -2972,7 +2978,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co } } - GGML_ASSERT(src1_type == GGML_TYPE_F32); + GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); switch (src0_type) { case GGML_TYPE_Q4_0: @@ -3812,8 +3818,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub src1_uma = d_Qy != nullptr; } - const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); - // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || !ggml_vk_dim01_contiguous(src1); @@ -4393,8 +4400,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ids_uma = d_ids != nullptr; } - const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); - const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src1); const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; @@ -4404,7 +4414,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; if (qx_needs_dequant) { - GGML_ABORT("fatal error"); + // Fall back to dequant + f16 mulmat + mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); } // Not implemented diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index ca3a59b8faa..3735d0dbbae 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -166,7 +166,7 @@ void main() { tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); - coopmat Q; + coopmat Q; coopmat Qf16; uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index cbfa5dce1b0..57f9e724513 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -57,17 +57,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #if QUANT_K > 1 #define DECODEFUNCA , dequantFuncA -#define MAT_A_TYPE float16_t #include "dequant_funcs_cm2.comp" #else #define DECODEFUNCA -#define MAT_A_TYPE A_TYPE #endif -#define MAT_B_TYPE B_TYPE - #ifdef MUL_MAT_ID layout (binding = 3) readonly buffer IDS {int data_ids[];}; @@ -236,16 +232,13 @@ void main() { for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { - coopmat mat_a; - coopmat mat_b; + coopmat mat_a; + coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopmat mat_a_ft = coopmat(mat_a); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); - coopmat mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } } else #endif // !defined(MUL_MAT_ID) @@ -261,10 +254,8 @@ void main() { [[dont_unroll]] for (uint block_k = start_k; block_k < end_k; block_k += BK) { - coopmat mat_a; - coopmat mat_b; - coopmat mat_a_ft; - coopmat mat_b_ft; + coopmat mat_a; + coopmat mat_b; // Clamping is expensive, so detect different code paths for each combination // of A and B needing clamping. @@ -281,16 +272,12 @@ void main() { #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); #endif - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } else if (unclampedA && !unclampedB) { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } else if (!unclampedA && unclampedB) { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID @@ -298,16 +285,12 @@ void main() { #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); #endif - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } else if (!unclampedA && !unclampedB) { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); - mat_a_ft = coopmat(mat_a); - mat_b_ft = coopmat(mat_b); - sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index b7890ef1e74..8bcb64101b1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -316,8 +316,11 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool // For aligned matmul loads std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + // don't generate f32 variants for coopmat2 + if (!coopmat2) { + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } if (tname != "f16" && tname != "f32") { string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); From 5183a05e560aa46a8893f07a07941ab1b8ec838f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 21 Jan 2025 08:48:13 +0200 Subject: [PATCH 18/64] metal : fix out-of-bounds write (llama/11314) ggml-ci --- ggml/src/ggml-metal/ggml-metal.metal | 31 +++++++++++++++------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 8ba43904d0c..44f04c909bf 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4416,7 +4416,6 @@ void kernel_mul_mv_q2_K_f32_impl( device const half * dh = &x[ib].d; for (int row = 0; row < N_DST; row++) { - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; for (int i = 0; i < 8; i += 2) { @@ -4447,7 +4446,7 @@ void kernel_mul_mv_q2_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -4613,7 +4612,7 @@ void kernel_mul_mv_q3_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; if (tiisg == 0) { - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { dst_f32[first_row + row] = sumf1[row]; } } @@ -4729,7 +4728,7 @@ void kernel_mul_mv_q4_K_f32_impl( device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -4861,7 +4860,7 @@ void kernel_mul_mv_q5_K_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = tot; @@ -4906,6 +4905,10 @@ void kernel_mul_mv_q6_K_f32_impl( const int row = 2*r0 + sgitg; + if (row >= args.ne0) { + return; + } + const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5061,7 +5064,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum * 0.25f; @@ -5179,7 +5182,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum * 0.25f; @@ -5289,7 +5292,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum * 0.5f; @@ -5401,7 +5404,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -5514,7 +5517,7 @@ void kernel_mul_mv_iq2_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum * 0.25f; @@ -5614,7 +5617,7 @@ void kernel_mul_mv_iq1_s_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -5709,7 +5712,7 @@ void kernel_mul_mv_iq1_m_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < N_DST; ++row) { + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -5799,7 +5802,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; @@ -5888,7 +5891,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = all_sum; From 58640aa456f5a3eb4a74a2a5dce9dfb5044dc2ac Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Tue, 21 Jan 2025 15:06:41 +0200 Subject: [PATCH 19/64] rpc : better caching of the base buffer pointer (llama/11331) There is no need to use map, just store the base pointer in the buffer context. --- ggml/src/ggml-rpc/ggml-rpc.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 63da2b86b1b..3d0c465780a 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -181,7 +181,7 @@ struct ggml_backend_rpc_context { struct ggml_backend_rpc_buffer_context { std::shared_ptr sock; - std::unordered_map base_cache; + void * base_ptr; uint64_t remote_ptr; }; @@ -423,16 +423,15 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) { - return ctx->base_cache[buffer]; + if (ctx->base_ptr != nullptr) { + return ctx->base_ptr; } rpc_msg_buffer_get_base_req request = {ctx->remote_ptr}; rpc_msg_buffer_get_base_rsp response; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response)); GGML_ASSERT(status); - void * base_ptr = reinterpret_cast(response.base_ptr); - ctx->base_cache[buffer] = base_ptr; - return base_ptr; + ctx->base_ptr = reinterpret_cast(response.base_ptr); + return ctx->base_ptr; } static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { @@ -557,7 +556,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back if (response.remote_ptr != 0) { ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, ggml_backend_rpc_buffer_interface, - new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr}, + new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr}, response.remote_size); return buffer; } else { From 373670613961d0426d05626ff5ff8fff619a648d Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 23 Jan 2025 01:01:17 -0600 Subject: [PATCH 20/64] vulkan: fix diag_mask_inf (llama/11323) With robustbufferaccess disabled, this shader was showing OOB stores. There is a bounds check in the code, but the workgrouop dimensions were reversed vs CUDA and it was running the wrong number of threads. So fix the workgroup dimensions and disable robustness for this pipeline. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 78c2f5c45ad..eae4f0739a5 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2012,7 +2012,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp index 4e68742b516..26d8bc22ad7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp @@ -12,7 +12,7 @@ layout (push_constant) uniform parameter #include "types.comp" -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; From ba523d5e224ba2fae1871787e21f48d8f1b26a17 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 23 Jan 2025 01:07:50 -0600 Subject: [PATCH 21/64] vulkan: sort shaders for more deterministic binary (llama/11315) Fixes #11306. --- ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 8bcb64101b1..e9c6cb9d45d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -17,13 +17,13 @@ #include #include #include +#include #include #include #ifdef _WIN32 #include #include // For _mkdir on Windows - #include // For std::replace on w64devkit #else #include #include @@ -502,6 +502,7 @@ void write_output_files() { fprintf(hdr, "#include \n\n"); fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str()); + std::sort(shader_fnames.begin(), shader_fnames.end()); for (const auto& pair : shader_fnames) { const std::string& name = pair.first; #ifdef _WIN32 From 16eeb31933cfeb4d094ddad0f4ecad4bac5c2df6 Mon Sep 17 00:00:00 2001 From: amd-dwang Date: Thu, 23 Jan 2025 15:14:28 +0800 Subject: [PATCH 22/64] Vulkan-run-test: fix mmq_wg_denoms (llama/11343) There should be a copy-and-paste error here. *mmq_wg_denoms should be used together with *warptile_mmq, instead of wg_denoms. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 94 ++++++++++++++-------------- 1 file changed, 47 insertions(+), 47 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index eae4f0739a5..c325416d12d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1673,31 +1673,31 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); if (device->coopmat_acc_f16_support) { - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. @@ -1707,31 +1707,31 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); if (device->coopmat_acc_f16_support) { - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } else { - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } } #undef CREATE_MM2 From 30767b4c4e182e90c916417f259222a07cae6e05 Mon Sep 17 00:00:00 2001 From: "Bernhard M. Wiedemann" Date: Fri, 24 Jan 2025 12:21:35 +0100 Subject: [PATCH 23/64] cmake : avoid -march=native when reproducible build is wanted (llama/11366) See https://reproducible-builds.org/ for why this is good and https://reproducible-builds.org/specs/source-date-epoch/ for the definition of this variable. Without this patch, compiling on different machines produced different binaries, which made verification of results difficult. Fixes: #11317 This patch was done while working on reproducible builds for openSUSE. --- ggml/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 185079aa47c..ff68ddc21d9 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -58,7 +58,8 @@ else() set(GGML_BLAS_VENDOR_DEFAULT "Generic") endif() -if (CMAKE_CROSSCOMPILING) +if (CMAKE_CROSSCOMPILING OR DEFINED ENV{SOURCE_DATE_EPOCH}) + message(STATUS "Setting GGML_NATIVE_DEFAULT to OFF") set(GGML_NATIVE_DEFAULT OFF) else() set(GGML_NATIVE_DEFAULT ON) From c262dc80e2ee4306783cece21f881dfd16c3b039 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 24 Jan 2025 12:38:31 +0100 Subject: [PATCH 24/64] CPU/CUDA: fix (GQA) mul mat back, add CUDA support (llama/11380) --- ggml/src/ggml-cpu/ggml-cpu.c | 4 +-- ggml/src/ggml-cpu/ggml-cpu.cpp | 3 +- ggml/src/ggml-cuda/binbcast.cu | 54 ++++++++++++++++++--------------- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- ggml/src/ggml-cuda/out-prod.cu | 7 +++-- ggml/src/ggml.c | 32 +++++++++++-------- 6 files changed, 59 insertions(+), 43 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index cd3f815b85f..e809f05d217 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -7883,7 +7883,7 @@ static void ggml_compute_forward_out_prod_f32( float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1); } @@ -7892,7 +7892,7 @@ static void ggml_compute_forward_out_prod_f32( float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); ggml_vec_mad_f32(ne0, d, s0, *s1); } diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 35a1c876c86..2ccb4b472d6 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -416,7 +416,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st case GGML_OP_IM2COL_BACK: return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; case GGML_OP_OUT_PROD: - return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32; + return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) && + src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; default: return true; } diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index c7b6be4e290..ce4b9cfb51b 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s template static __global__ void k_repeat_back( - const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, - const int64_t ne0, const int64_t ne1, const int64_t ne2) { + const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const size_t s00, const size_t s01, const size_t s02, const size_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) { - const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x; - const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y; - const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z; + const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x; + const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y; + const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z; + const int64_t tid2 = tid23 % ne2; + const int64_t tid3 = tid23 / ne2; if (tid0 >= ne0) { return; } T sum = 0; - for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { - for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { - for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) { - sum += src[i2*ne01*ne00 + i1*ne00 + i0]; + for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) { + for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { + for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { + for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) { + sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00]; + } } } } - dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum; + dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum; } template @@ -274,12 +279,14 @@ struct bin_bcast_cuda { template static void repeat_back_cuda( - const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, - const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) { + const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const size_t s00, const size_t s01, const size_t s02, const size_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); - const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2); - k_repeat_back<<>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2); + const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3); + k_repeat_back<<>> + (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3); } template @@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst const ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(src0->type == dst->type); - GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ggml_can_repeat(dst, src0)); cudaStream_t stream = ctx.stream(); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - GGML_ASSERT(src0->ne[3] == 1); + GGML_TENSOR_UNARY_OP_LOCALS; + + GGML_ASSERT(ne2*ne3 <= (1 << 15)); - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - GGML_ASSERT(dst->ne[3] == 1); + const size_t ts = ggml_type_size(src0->type); + const size_t s00 = nb00 / ts; + const size_t s01 = nb01 / ts; + const size_t s02 = nb02 / ts; + const size_t s03 = nb03 / ts; switch (dst->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; float * dst_d = (float *) dst->data; - repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream); + repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream); } break; default: { GGML_ASSERT(false); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 7fd1fc85346..e602419bc57 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3002,7 +3002,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; } break; case GGML_OP_REPEAT_BACK: - return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1; + return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15); case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu index 73e3e2c47f2..c9b2b699c6a 100644 --- a/ggml/src/ggml-cuda/out-prod.cu +++ b/ggml/src/ggml-cuda/out-prod.cu @@ -34,6 +34,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { CUBLAS_CHECK(cublasSetStream(handle, stream)); + const int64_t lda = nb01 / sizeof(float); + const int64_t ldc = nb1 / sizeof(float); + const bool src1_T = ggml_is_transposed(src1); const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T; const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); @@ -57,9 +60,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { CUBLAS_CHECK( cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, ne0, ne1, ne01, - &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00, + &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda, src1_d + i3 *s13 + i2 *s12, ldb, - &beta, dst_d + i3 *s3 + i2 *s2, ne0)); + &beta, dst_d + i3 *s3 + i2 *s2, ldc)); } } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d83b1b8ad6b..3b486154211 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5343,7 +5343,7 @@ static void ggml_compute_backward( } break; case GGML_OP_MUL: { if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1)); } if (src1_needs_grads) { struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad); @@ -5435,21 +5435,25 @@ static void ggml_compute_backward( // src1.shape [n,p,qq,rr] if (src0_needs_grads) { - struct ggml_tensor * s1_tg = + GGML_ASSERT(grad->ne[2] == src1->ne[2]); + GGML_ASSERT(grad->ne[3] == src1->ne[3]); + struct ggml_tensor * tmp = ggml_out_prod(ctx, // [n,m,qq,rr] src1, // [n,p,qq,rr] grad); // [m,p,qq,rr] - const int64_t qq = s1_tg->ne[2]; - const int64_t rr = s1_tg->ne[3]; - const int64_t q1 = src0->ne[2]; - const int64_t r1 = src0->ne[3]; - const bool ne2_broadcasted = qq > q1; - const bool ne3_broadcasted = rr > r1; - if (ne2_broadcasted || ne3_broadcasted) { - // sum broadcast repetitions of s1_tg into shape of src0 - s1_tg = ggml_repeat_back(ctx, s1_tg, src0); + if (!ggml_are_same_shape(tmp, src0)) { + GGML_ASSERT(tmp->ne[0] == src0->ne[0]); + GGML_ASSERT(tmp->ne[1] == src0->ne[1]); + GGML_ASSERT(tmp->ne[3] == 1); + + const int64_t nr2 = tmp->ne[2] / src0->ne[2]; + const size_t nb2 = tmp->nb[2] * nr2; + const size_t nb3 = tmp->nb[2]; + + tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0); + tmp = ggml_repeat_back(ctx, tmp, src0); } - ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/); + ggml_add_or_set(ctx, cgraph, isrc0, tmp); } if (src1_needs_grads) { ggml_add_or_set(ctx, cgraph, isrc1, @@ -5518,7 +5522,9 @@ static void ggml_compute_backward( if (src0_needs_grads) { GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0])); GGML_ASSERT(ggml_is_contiguous(grad)); - ggml_add_or_set(ctx, cgraph, isrc0, grad); + GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0)); + ggml_add_or_set(ctx, cgraph, isrc0, + ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0)); } } break; case GGML_OP_RESHAPE: { From 727891d9bf9e51f167e552a35af4dd6bfa2cbc59 Mon Sep 17 00:00:00 2001 From: uvos Date: Fri, 24 Jan 2025 17:50:49 +0100 Subject: [PATCH 25/64] rocBLAS: Avoid fp32->fp16->fp32 conversion on cdna (llama/11356) --- ggml/src/ggml-cuda/ggml-cuda.cu | 60 ++++++++++++++++++++------------- ggml/src/ggml-cuda/mmvq.cu | 3 +- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e602419bc57..fb3d9e2d92e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1082,7 +1082,9 @@ static void ggml_cuda_op_mul_mat_cublas( const int compute_capability = ggml_cuda_info().devices[id].cc; - if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { + const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; + + if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 ggml_cuda_pool_alloc src0_as_f16(ctx.pool(id)); if (src0->type != GGML_TYPE_F16) { @@ -1103,28 +1105,38 @@ static void ggml_cuda_op_mul_mat_cublas( to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream); } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get(); - ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols); - const half alpha_f16 = 1.0f; - const half beta_f16 = 0.0f; + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F; - if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { - cu_compute_type = CUBLAS_COMPUTE_32F; - } + if (compute_capability == GGML_CUDA_CC_CDNA) { + const float alpha = 1.0f; + const float beta = 0.0f; + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta, dst_dd_i, CUDA_R_32F, ldc, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols); - CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - CUBLAS_CHECK( - cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, - row_diff, src1_ncols, ne10, - &alpha_f16, src0_ptr, CUDA_R_16F, ne00, - src1_ptr, CUDA_R_16F, ne10, - &beta_f16, dst_f16.get(), CUDA_R_16F, ldc, - cu_compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f16, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta_f16, dst_dd_i, CUDA_R_16F, ldc, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + } } else { ggml_cuda_pool_alloc src0_ddq_as_f32(ctx.pool(id)); ggml_cuda_pool_alloc src1_ddq_as_f32(ctx.pool(id)); @@ -1613,10 +1625,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F; cudaDataType_t cu_data_type = CUDA_R_16F; - if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { - cu_compute_type = CUBLAS_COMPUTE_32F; - } - // dst strides size_t nbd2 = dst->nb[2]; size_t nbd3 = dst->nb[3]; @@ -1645,6 +1653,12 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co beta = &beta_f32; } + if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { + cu_compute_type = CUBLAS_COMPUTE_32F; + alpha = &alpha_f32; + beta = &beta_f32; + } + GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index e3b912d875b..4fb466ca006 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda( int64_t nwarps = 1; int64_t rows_per_cuda_block = 1; - if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA + if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2 switch(ncols_y) { case 1: nwarps = 4; @@ -166,6 +166,7 @@ static void mul_mat_vec_q_cuda( break; } } + const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; const dim3 block_nums(nblocks, 1, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); From 9e467815d476bd247b86da174475416dc0e62efe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 24 Jan 2025 21:02:43 +0100 Subject: [PATCH 26/64] CUDA: fix FP16 cuBLAS GEMM (llama/11396) --- ggml/src/ggml-cuda/ggml-cuda.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fb3d9e2d92e..fbe889a0122 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1114,8 +1114,8 @@ static void ggml_cuda_op_mul_mat_cublas( CUBLAS_CHECK( cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, row_diff, src1_ncols, ne10, - &alpha, src0_ptr, CUDA_R_16F, ne00, - src1_ptr, CUDA_R_16F, ne10, + &alpha, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, &beta, dst_dd_i, CUDA_R_32F, ldc, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); @@ -1128,9 +1128,9 @@ static void ggml_cuda_op_mul_mat_cublas( CUBLAS_CHECK( cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, row_diff, src1_ncols, ne10, - &alpha_f16, src0_ptr, CUDA_R_16F, ne00, - src1_ptr, CUDA_R_16F, ne10, - &beta_f16, dst_dd_i, CUDA_R_16F, ldc, + &alpha_f16, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta_f16, dst_f16.get(), CUDA_R_16F, ldc, CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); From 0282ad8fd15115c981b66d0839c00daad5b00a1f Mon Sep 17 00:00:00 2001 From: uvos Date: Sat, 25 Jan 2025 00:02:23 +0100 Subject: [PATCH 27/64] hip : Add hipGraph and VMM support to ROCM (llama/11362) * Add hipGraph support * Enable VMM on rocm --- ggml/CMakeLists.txt | 1 + ggml/src/ggml-cuda/common.cuh | 2 +- ggml/src/ggml-cuda/ggml-cuda.cu | 58 +++++++++++++++++++++----------- ggml/src/ggml-cuda/vendors/hip.h | 43 +++++++++++++++++++++++ ggml/src/ggml-hip/CMakeLists.txt | 8 +++++ 5 files changed, 92 insertions(+), 20 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index ff68ddc21d9..123c755ace4 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -154,6 +154,7 @@ option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashA option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT}) option(GGML_HIP "ggml: use HIP" OFF) +option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF) option(GGML_VULKAN "ggml: use Vulkan" OFF) option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 2c0a5622672..a79fa83c5a2 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -588,7 +588,7 @@ struct ggml_tensor_extra_gpu { }; -#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS) +#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS) #define USE_CUDA_GRAPH #endif diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fbe889a0122..a53a1bbd05d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -62,7 +62,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); [[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails - cudaGetDevice(&id); + (void)cudaGetDevice(&id); GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg); GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line); @@ -152,7 +152,7 @@ static ggml_cuda_device_info ggml_cuda_init() { for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; -#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_CUDA_NO_VMM) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device)); @@ -164,7 +164,7 @@ static ggml_cuda_device_info ggml_cuda_init() { alloc_prop.location.id = id; CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#endif // !defined(GGML_CUDA_NO_VMM) info.devices[id].vmm = !!device_vmm; cudaDeviceProp prop; @@ -300,7 +300,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { }; // pool with virtual memory -#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_CUDA_NO_VMM) struct ggml_cuda_pool_vmm : public ggml_cuda_pool { static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB @@ -309,6 +309,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { size_t pool_used = 0; size_t pool_size = 0; size_t granularity; +#if defined(GGML_USE_HIP) + std::vector> mappings; +#endif explicit ggml_cuda_pool_vmm(int device) : device(device), @@ -317,7 +320,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { ~ggml_cuda_pool_vmm() { if (pool_addr != 0) { +#if defined(GGML_USE_HIP) + // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285 + for (std::pair & mapping : mappings) { + CU_CHECK(cuMemUnmap(mapping.first, mapping.second)); + } +#else CU_CHECK(cuMemUnmap(pool_addr, pool_size)); +#endif CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE)); } } @@ -350,7 +360,11 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { } // map at the end of the pool - CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0)); + CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size); + CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0)); +#if defined(GGML_USE_HIP) + mappings.push_back({start_ptr, reserve_size}); +#endif // the memory allocation handle is no longer needed after mapping CU_CHECK(cuMemRelease(handle)); @@ -360,7 +374,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; access.location.id = device; access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1)); + CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1)); // add to the pool pool_size += reserve_size; @@ -372,7 +386,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { GGML_ASSERT(pool_addr != 0); - void * ptr = (void *) (pool_addr + pool_used); + void * ptr = (void *) ((CUdeviceptr)((char *)(pool_addr) + pool_used)); *actual_size = size; pool_used += size; @@ -391,17 +405,17 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { pool_used -= size; // all deallocations must be in reverse order of the allocations - GGML_ASSERT(ptr == (void *) (pool_addr + pool_used)); + GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); } }; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#endif // !defined(GGML_CUDA_NO_VMM) std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { -#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_CUDA_NO_VMM) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr(new ggml_cuda_pool_vmm(device)); } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +#endif // !defined(GGML_CUDA_NO_VMM) return std::unique_ptr(new ggml_cuda_pool_leg(device)); } @@ -547,7 +561,7 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err)); return nullptr; } @@ -962,7 +976,7 @@ static void * ggml_cuda_host_malloc(size_t size) { cudaError_t err = cudaMallocHost((void **) &ptr, size); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, size / 1024.0 / 1024.0, cudaGetErrorString(err)); return nullptr; @@ -1209,7 +1223,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { CUDA_CHECK(err); } else { // reset the error - cudaGetLastError(); + (void)cudaGetLastError(); } } else { cudaError_t err = cudaDeviceDisablePeerAccess(id_other); @@ -1217,7 +1231,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { CUDA_CHECK(err); } else { // reset the error - cudaGetLastError(); + (void)cudaGetLastError(); } } } @@ -2452,7 +2466,7 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto if (stat == cudaErrorInvalidDeviceFunction) { // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. // We don't need to update blas nodes, so clear error and move on. - cudaGetLastError(); + (void)cudaGetLastError(); } else { GGML_ASSERT(stat == cudaSuccess); } @@ -2507,14 +2521,20 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { cudaGraphExecUpdateResultInfo result_info; +#ifdef __HIP_PLATFORM_AMD__ + hipGraphNode_t errorNode; + hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); +#else cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); +#endif if (stat == cudaErrorGraphExecUpdateFailure) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__); #endif + // The pre-existing graph exec cannot be updated due to violated constraints // so instead clear error and re-instantiate - cudaGetLastError(); + (void)cudaGetLastError(); CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); cuda_ctx->cuda_graph->instance = nullptr; CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); @@ -2742,7 +2762,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__, size / 1024.0 / 1024.0, cudaGetErrorString(err)); @@ -2762,7 +2782,7 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) { cudaError_t err = cudaHostUnregister(buffer); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); + (void)cudaGetLastError(); } } diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index c905b15d76c..8594093f052 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -19,6 +19,12 @@ #define CUBLAS_TF32_TENSOR_OP_MATH 0 #define CUDA_R_16F HIPBLAS_R_16F #define CUDA_R_32F HIPBLAS_R_32F +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended +#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned +#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite +#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 #define cublasCreate hipblasCreate @@ -74,6 +80,21 @@ #define cudaMemGetInfo hipMemGetInfo #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize #define cudaSetDevice hipSetDevice +#define cuDeviceGet hipDeviceGet +#define CUdevice hipDevice_t +#define CUdeviceptr hipDeviceptr_t +#define cuMemUnmap hipMemUnmap +#define CUmemAccessDesc hipMemAccessDesc +#define cuMemAddressFree hipMemAddressFree +#define cuMemRelease hipMemRelease +#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t +#define cuMemCreate hipMemCreate +#define cuMemAddressReserve hipMemAddressReserve +#define cuMemMap hipMemMap +#define cuMemSetAccess hipMemSetAccess +#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity +#define CUmemAllocationProp hipMemAllocationProp +#define cuDeviceGetAttribute hipDeviceGetAttribute #define cudaStreamCreateWithFlags hipStreamCreateWithFlags #define cudaStreamDestroy hipStreamDestroy #define cudaStreamFireAndForget hipStreamFireAndForget @@ -81,6 +102,28 @@ #define cudaStreamPerThread hipStreamPerThread #define cudaStreamSynchronize hipStreamSynchronize #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) +#define cudaGraphExec_t hipGraphExec_t +#define cudaGraphNode_t hipGraphNode_t +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaGraphExecDestroy hipGraphExecDestroy +#define cudaGraphLaunch hipGraphLaunch +#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure +#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult +#define cudaGraphNodeType hipGraphNodeType +#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel +#define cudaGraphInstantiate hipGraphInstantiate +#define cudaStreamEndCapture hipStreamEndCapture +#define cudaGraphDestroy hipGraphDestroy +#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams +#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction +#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams +#define cudaGraphNodeGetType hipGraphNodeGetType +#define cudaGraphGetNodes hipGraphGetNodes +#define cudaGraphExecUpdate hipGraphExecUpdate +#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed +#define cudaStreamBeginCapture hipStreamBeginCapture +#define cudaGraph_t hipGraph_t #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess #define __trap() do { abort(); __builtin_unreachable(); } while(0) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index d090ba9bd98..77994a6986b 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -92,6 +92,14 @@ if (GGML_CUDA_NO_PEER_COPY) add_compile_definitions(GGML_CUDA_NO_PEER_COPY) endif() +if (GGML_HIP_GRAPHS) + add_compile_definitions(GGML_HIP_GRAPHS) +endif() + +if (GGML_CUDA_NO_VMM) + add_compile_definitions(GGML_CUDA_NO_VMM) +endif() + if (CXX_IS_HIPCC) set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) target_link_libraries(ggml-hip PRIVATE hip::device) From a160fa0f3a494e0a3779f687dab4c5067039ee2a Mon Sep 17 00:00:00 2001 From: uvos Date: Sat, 25 Jan 2025 21:01:12 +0100 Subject: [PATCH 28/64] Hip: disable VMM on hip as it seams that it dosent work in some configurations (llama/11420) --- ggml/CMakeLists.txt | 1 + ggml/src/ggml-cuda/common.cuh | 4 ++++ ggml/src/ggml-cuda/ggml-cuda.cu | 14 +++++++------- ggml/src/ggml-hip/CMakeLists.txt | 4 ++-- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 123c755ace4..bbabb14de38 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -155,6 +155,7 @@ option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp on option(GGML_HIP "ggml: use HIP" OFF) option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) +option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF) option(GGML_VULKAN "ggml: use Vulkan" OFF) option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a79fa83c5a2..bb6120568f2 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -131,6 +131,10 @@ typedef float dfloat; // dequantize float typedef float2 dfloat2; #endif // GGML_CUDA_F16 +#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) +#define GGML_USE_VMM +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) + #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #define FP16_AVAILABLE #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index a53a1bbd05d..85178abd282 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -152,7 +152,7 @@ static ggml_cuda_device_info ggml_cuda_init() { for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; -#if !defined(GGML_CUDA_NO_VMM) +#if defined(GGML_USE_VMM) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device)); @@ -164,7 +164,7 @@ static ggml_cuda_device_info ggml_cuda_init() { alloc_prop.location.id = id; CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); } -#endif // !defined(GGML_CUDA_NO_VMM) +#endif // defined(GGML_USE_VMM) info.devices[id].vmm = !!device_vmm; cudaDeviceProp prop; @@ -300,7 +300,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { }; // pool with virtual memory -#if !defined(GGML_CUDA_NO_VMM) +#if defined(GGML_USE_VMM) struct ggml_cuda_pool_vmm : public ggml_cuda_pool { static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB @@ -408,14 +408,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); } }; -#endif // !defined(GGML_CUDA_NO_VMM) +#endif // defined(GGML_USE_VMM) std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { -#if !defined(GGML_CUDA_NO_VMM) +#if defined(GGML_USE_VMM) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr(new ggml_cuda_pool_vmm(device)); } -#endif // !defined(GGML_CUDA_NO_VMM) +#endif // defined(GGML_USE_VMM) return std::unique_ptr(new ggml_cuda_pool_leg(device)); } @@ -3250,7 +3250,7 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t features.push_back({ "FORCE_CUBLAS", "1" }); #endif - #ifdef GGML_CUDA_NO_VMM + #ifndef GGML_USE_VMM features.push_back({ "NO_VMM", "1" }); #endif diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 77994a6986b..ecc3bc66d44 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -96,8 +96,8 @@ if (GGML_HIP_GRAPHS) add_compile_definitions(GGML_HIP_GRAPHS) endif() -if (GGML_CUDA_NO_VMM) - add_compile_definitions(GGML_CUDA_NO_VMM) +if (GGML_HIP_NO_VMM) + add_compile_definitions(GGML_HIP_NO_VMM) endif() if (CXX_IS_HIPCC) From 7230a6e1c878c2c77704c7a99309facdd9a72508 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 25 Jan 2025 15:29:57 -0600 Subject: [PATCH 29/64] vulkan: compile shaders on-demand (llama/11406) Reduce first-run startup time and memory consumption. Should fix #11339. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 64 ++++++++++++++++++---------- 1 file changed, 41 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c325416d12d..a9d6b923cbf 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -85,6 +85,10 @@ struct vk_pipeline_struct { uint32_t parameter_count; std::array wg_denoms; uint32_t align; + // set to true to request the pipeline is compiled after the dryrun + bool needed {}; + // set to true when the shader has been compiled + bool compiled {}; }; typedef std::shared_ptr vk_pipeline; @@ -186,8 +190,11 @@ struct vk_device_struct { bool mul_mat_id_m; bool mul_mat_id_s; - vk_matmul_pipeline pipeline_matmul_f32; - vk_matmul_pipeline pipeline_matmul_f32_f16; + // set to true to indicate that some shaders need to be compiled after the dryrun + bool need_compiles {}; + + vk_matmul_pipeline pipeline_matmul_f32 {}; + vk_matmul_pipeline pipeline_matmul_f32_f16 {}; vk_matmul_pipeline2 pipeline_matmul_f16; vk_matmul_pipeline2 pipeline_matmul_f16_f32; vk_pipeline pipeline_matmul_split_k_reduce; @@ -195,7 +202,7 @@ struct vk_device_struct { vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; - vk_matmul_pipeline pipeline_matmul_id_f32; + vk_matmul_pipeline pipeline_matmul_id_f32 {}; vk_matmul_pipeline2 pipeline_matmul_id_f16; vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; @@ -776,13 +783,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin GGML_ASSERT(parameter_count > 0); GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT - pipeline = std::make_shared(); - pipeline->name = name; - pipeline->parameter_count = parameter_count; - pipeline->push_constant_size = push_constant_size; - pipeline->wg_denoms = wg_denoms; - pipeline->align = align; - vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); @@ -865,6 +865,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; + pipeline->compiled = true; { std::lock_guard guard(device->mutex); @@ -875,12 +876,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin std::lock_guard guard(compile_count_mutex); assert(compile_count > 0); compile_count--; - - // "Progress bar" for shader compiles - static uint32_t total_compile_count = 0; - if ((total_compile_count++ % 10) == 0) { - std::cerr << "."; - } } compile_count_cond.notify_all(); } @@ -906,6 +901,10 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) { VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); device->pipeline_descriptor_set_requirements[pipeline->name] += n; + if (!pipeline->compiled) { + pipeline->needed = true; + device->need_compiles = true; + } } static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) { @@ -1388,8 +1387,6 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); - std::cerr << "ggml_vulkan: Compiling shaders"; - // some shaders have a minimum subgroup size const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); @@ -1527,15 +1524,33 @@ static void ggml_vk_load_shaders(vk_device& device) { } } - device->pipeline_matmul_f32 = std::make_shared(); - device->pipeline_matmul_f32_f16 = std::make_shared(); - - device->pipeline_matmul_id_f32 = std::make_shared(); + if (!device->pipeline_matmul_f32) { + device->pipeline_matmul_f32 = std::make_shared(); + } + if (!device->pipeline_matmul_f32_f16) { + device->pipeline_matmul_f32_f16 = std::make_shared(); + } + if (!device->pipeline_matmul_id_f32) { + device->pipeline_matmul_id_f32 = std::make_shared(); + } std::vector> compiles; auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { + + if (!pipeline) { + pipeline = std::make_shared(); + pipeline->name = name; + pipeline->parameter_count = parameter_count; + pipeline->push_constant_size = push_constant_size; + pipeline->wg_denoms = wg_denoms; + pipeline->align = align; + } + + if (!pipeline->needed || pipeline->compiled) { + return; + } { // wait until fewer than N compiles are in progress uint32_t N = std::max(1u, std::thread::hardware_concurrency()); @@ -2050,7 +2065,7 @@ static void ggml_vk_load_shaders(vk_device& device) { for (auto &c : compiles) { c.wait(); } - std::cerr << "Done!" << std::endl; + device->need_compiles = false; } static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); @@ -7656,6 +7671,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg for (int i = 0; i < cgraph->n_nodes; i++) { ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false); } + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } ggml_vk_preallocate_buffers(ctx); ggml_pipeline_allocate_descriptor_sets(ctx->device); From d5d831da65c7be75d0840d4d29c30df4c5bcdf98 Mon Sep 17 00:00:00 2001 From: bandoti <141645996+bandoti@users.noreply.github.com> Date: Sun, 26 Jan 2025 12:07:48 -0400 Subject: [PATCH 30/64] cmake: add ggml find package (llama/11369) * Add initial ggml cmake package * Add build numbers to ggml find-package * Expand variables with GGML_ prefix * Guard against adding to cache variable twice * Add git to msys2 workflow * Handle ggml-cpu-* variants * Link ggml/ggml-base libraries to their targets * Replace main-cmake-pkg with simple-cmake-pkg * Interface features require c_std_90 * Fix typo * Removed unnecessary bracket from status message * Update examples/simple-cmake-pkg/README.md Co-authored-by: Georgi Gerganov * Update examples/simple-cmake-pkg/README.md Co-authored-by: Georgi Gerganov --------- Co-authored-by: Georgi Gerganov --- ggml/CMakeLists.txt | 71 +++++++++++++++++++++++++++++++++++++++++ ggml/src/CMakeLists.txt | 11 +++++++ 2 files changed, 82 insertions(+) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index bbabb14de38..7c069e42037 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -267,3 +267,74 @@ if (GGML_STANDALONE) install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc DESTINATION share/pkgconfig) endif() + +# +# Create CMake package +# + +# Generate version info based on git commit. + +find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH) +execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE GGML_BUILD_NUMBER + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +if(GGML_BUILD_NUMBER EQUAL 1) + message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.") +endif() + +execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE GGML_BUILD_COMMIT + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +# Capture variables prefixed with GGML_. + +set(variable_set_statements +" +####### Expanded from @GGML_VARIABLES_EXPANED@ by configure_package_config_file() ####### +####### Any changes to this file will be overwritten by the next CMake run ####### + +") + +set(GGML_SHARED_LIB ${BUILD_SHARED_LIBS}) + +get_cmake_property(all_variables VARIABLES) +foreach(variable_name IN LISTS all_variables) + if(variable_name MATCHES "^GGML_") + string(REPLACE ";" "\\;" + variable_value "${${variable_name}}") + + set(variable_set_statements + "${variable_set_statements}set(${variable_name} \"${variable_value}\")\n") + endif() +endforeach() + +set(GGML_VARIABLES_EXPANDED ${variable_set_statements}) + +# Create the CMake package and set install location. + +set(GGML_INSTALL_VERSION 0.0.${GGML_BUILD_NUMBER}) +set(GGML_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files") +set(GGML_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") +set(GGML_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") + +configure_package_config_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/ggml-config.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml + PATH_VARS GGML_INCLUDE_INSTALL_DIR + GGML_LIB_INSTALL_DIR + GGML_BIN_INSTALL_DIR) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake + VERSION ${GGML_INSTALL_VERSION} + COMPATIBILITY SameMajorVersion) + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index ae1cd23376c..8d2b948fb56 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -250,6 +250,17 @@ function(ggml_add_backend_library backend) target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD) target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED) endif() + + if(NOT GGML_AVAILABLE_BACKENDS) + set(GGML_AVAILABLE_BACKENDS "${backend}" + CACHE INTERNAL "List of backends for cmake package") + else() + list(FIND GGML_AVAILABLE_BACKENDS "${backend}" has_backend) + if(has_backend EQUAL -1) + set(GGML_AVAILABLE_BACKENDS "${GGML_AVAILABLE_BACKENDS};${backend}" + CACHE INTERNAL "List of backends for cmake package") + endif() + endif() endfunction() function(ggml_add_backend backend) From 8639c003a93704b9b3511bc4480bee98e50ec7c7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 26 Jan 2025 20:06:16 +0200 Subject: [PATCH 31/64] metal : use residency sets (llama/11427) * metal : use residency sets ggml-ci * metal : restore commandBufferWithUnretainedReferences calls [no ci] * metal : release descriptors ggml-ci * metal : check env GGML_METAL_NO_RESIDENCY ggml-ci * metal : fix build + clean-up ggml-ci --- ggml/src/ggml-metal/ggml-metal.m | 136 +++++++++++++++++++++++++++---- 1 file changed, 119 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index a85502ee089..c9474345da9 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -19,7 +19,10 @@ // max number of MTLCommandBuffer used to submit a graph for processing #define GGML_METAL_MAX_COMMAND_BUFFERS 8 -#define UNUSED(x) (void)(x) +// create residency sets only on macOS >= 15.0 +#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 +#define GGML_METAL_HAS_RESIDENCY_SETS 1 +#endif // globals @@ -39,6 +42,7 @@ bool has_simdgroup_reduction; bool has_simdgroup_mm; + bool has_residency_sets; bool has_bfloat; bool use_bfloat; @@ -48,6 +52,7 @@ /*.mtl_device_ref_count =*/ 0, /*.has_simdgroup_reduction =*/ false, /*.has_simdgroup_mm =*/ false, + /*.has_residency_sets =*/ false, /*.has_bfloat =*/ false, /*.use_bfloat =*/ false, /*.name =*/ "", @@ -65,6 +70,10 @@ ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL; +#endif + ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6]; @@ -483,6 +492,11 @@ @implementation GGMLMetalClass GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); ctx->queue = [device newCommandQueue]; + if (ctx->queue == nil) { + GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); + return NULL; + } + ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); id metal_library; @@ -649,6 +663,7 @@ @implementation GGMLMetalClass GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false"); GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false"); + GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false"); GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false"); GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false"); GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false"); @@ -1035,8 +1050,70 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap int n_buffers; struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; + + // optional MTLResidencySet + id rset; }; +// rset init +static bool ggml_backend_metal_buffer_rset_init( + struct ggml_backend_metal_buffer_context * ctx, + struct ggml_backend_metal_device_context * ctx_dev, + id device) { + ctx->rset = nil; + + if (!ctx_dev->has_residency_sets) { + return true; + } + +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + if (@available(macOS 15.0, *)) { + MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init]; + desc.label = @"ggml_backend_metal"; + desc.initialCapacity = ctx->n_buffers; + + NSError * error; + ctx->rset = [device newResidencySetWithDescriptor:desc error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + [desc release]; + return false; + } + + [desc release]; + + for (int i = 0; i < ctx->n_buffers; i++) { + [ctx->rset addAllocation:ctx->buffers[i].metal]; + } + + [ctx->rset commit]; + [ctx->rset requestResidency]; + + return true; + } +#else + GGML_UNUSED(ctx_dev); + GGML_UNUSED(device); +#endif + + return true; +} + +// rset free +static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) { +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + if (@available(macOS 15.0, *)) { + if (ctx->rset) { + [ctx->rset endResidency]; + [ctx->rset removeAllAllocations]; + [ctx->rset release]; + } + } +#else + GGML_UNUSED(ctx); +#endif +} + // finds the Metal buffer that contains the tensor data on the GPU device // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the // Metal buffer based on the host memory pointer @@ -4176,6 +4253,8 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) for (int i = 0; i < ctx->n_buffers; i++) { [ctx->buffers[i].metal release]; } + + ggml_backend_metal_buffer_rset_free(ctx); ggml_backend_metal_device_rel(buffer->buft->device->context); if (ctx->owned) { @@ -4198,19 +4277,19 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { memset((char *)tensor->data + offset, value, size); - UNUSED(buffer); + GGML_UNUSED(buffer); } static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { memcpy((char *)tensor->data + offset, data, size); - UNUSED(buffer); + GGML_UNUSED(buffer); } static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { memcpy(data, (const char *)tensor->data + offset, size); - UNUSED(buffer); + GGML_UNUSED(buffer); } static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { @@ -4220,7 +4299,7 @@ static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, c } return false; - UNUSED(buffer); + GGML_UNUSED(buffer); } static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { @@ -4246,7 +4325,7 @@ static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_ static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) { return "Metal"; - UNUSED(buft); + GGML_UNUSED(buft); } static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { @@ -4270,8 +4349,8 @@ static void ggml_backend_metal_log_allocated_size(id device, size_t s } #endif #endif - UNUSED(device); - UNUSED(size_aligned); + GGML_UNUSED(device); + GGML_UNUSED(size_aligned); } static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -4284,7 +4363,8 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba size_aligned += (size_page - (size_aligned % size_page)); } - id device = ggml_backend_metal_device_acq(buft->device->context); + struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context; + id device = ggml_backend_metal_device_acq(ctx_dev); ctx->all_data = ggml_metal_host_malloc(size_aligned); ctx->all_size = size_aligned; @@ -4307,7 +4387,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) { GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); free(ctx); - ggml_backend_metal_device_rel(buft->device->context); + ggml_backend_metal_device_rel(ctx_dev); + return NULL; + } + + if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(ctx); + ggml_backend_metal_device_rel(ctx_dev); return NULL; } @@ -4318,7 +4405,7 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { return 32; - UNUSED(buft); + GGML_UNUSED(buft); } static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { @@ -4328,13 +4415,13 @@ static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_ty return max_size; - UNUSED(buft); + GGML_UNUSED(buft); } static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) { return true; - UNUSED(buft); + GGML_UNUSED(buft); } ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { @@ -4357,7 +4444,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) { return "Metal_Mapped"; - UNUSED(buft); + GGML_UNUSED(buft); } static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) { @@ -4400,7 +4487,8 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz size_aligned += (size_page - (size_aligned % size_page)); } - id device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main); + struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main; + id device = ggml_backend_metal_device_acq(ctx_dev); // the buffer fits into the max buffer size allowed by the device if (size_aligned <= device.maxBufferLength) { @@ -4453,6 +4541,13 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz } } + if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(ctx); + ggml_backend_metal_device_rel(ctx_dev); + return NULL; + } + return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); } @@ -4461,7 +4556,7 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz static const char * ggml_backend_metal_name(ggml_backend_t backend) { return "Metal"; - UNUSED(backend); + GGML_UNUSED(backend); } static void ggml_backend_metal_free(ggml_backend_t backend) { @@ -4766,6 +4861,13 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back } } + if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(ctx); + ggml_backend_metal_device_rel(ctx_dev); + return NULL; + } + return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); } @@ -4779,7 +4881,7 @@ static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name || buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name; - UNUSED(dev); + GGML_UNUSED(dev); } static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { From 70c4038842597ee8d7f5c547f5b821f19fe58f52 Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Mon, 27 Jan 2025 02:41:59 -0500 Subject: [PATCH 32/64] metal: Handle null returned from MTLCreateSystemDefaultDevice() (llama/11441) This fixes segmentation fault error when running tests when no metal devices are available (for example, when not linked with Core Graphics framework or otherwise). --- ggml/src/ggml-metal/ggml-metal.m | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index c9474345da9..76f8e429178 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -64,7 +64,9 @@ if (ctx->mtl_device == nil) { ctx->mtl_device = MTLCreateSystemDefaultDevice(); + } + if (ctx->mtl_device) { ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; @@ -99,8 +101,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte ctx->mtl_device_ref_count--; if (ctx->mtl_device_ref_count == 0) { - [ctx->mtl_device release]; - ctx->mtl_device = nil; + if (ctx->mtl_device) { + [ctx->mtl_device release]; + ctx->mtl_device = nil; + } } } From 028511d349d4394fe12b5490cb733253b00d9371 Mon Sep 17 00:00:00 2001 From: Haus1 Date: Mon, 27 Jan 2025 08:58:17 -0500 Subject: [PATCH 33/64] AMD: parse the architecture as supplied by gcnArchName (llama/11244) The value provided by minor doesn't include stepping for AMD, parse the value returned by gcnArchName instead to retrieve an accurate ID. --- ggml/src/ggml-cuda/common.cuh | 20 +++++----- ggml/src/ggml-cuda/ggml-cuda.cu | 67 ++++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index bb6120568f2..a66322da05a 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -46,20 +46,20 @@ #define GGML_CUDA_CC_VOLTA 700 #define GGML_CUDA_CC_TURING 750 #define GGML_CUDA_CC_AMPERE 800 -#define GGML_CUDA_CC_OFFSET_AMD 1000000 +#define GGML_CUDA_CC_OFFSET_AMD 0x1000000 // GCN/CNDA, wave size is 64 -#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16 -#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue -#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a -#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers -#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 910) // MI210, minimum acc register renameing -#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 942) // MI300 +#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 +#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue +#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a +#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers +#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing +#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 // RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32 -#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000 -#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a -#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA +#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 +#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a +#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA #define GGML_CUDA_CC_QY1 210 #define GGML_CUDA_CC_QY2 220 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 85178abd282..402f37e85d4 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -119,6 +119,55 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) #endif } +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +static int ggml_cuda_parse_id(char devName[]) { + // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp + // these values are not stable so this is susceptible to breakage + // https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp + int archMajor = 0x0; + int archMinor = 0x0; + int archNum = GGML_CUDA_CC_OFFSET_AMD; + int archLen = strlen(devName); + char archName[archLen + 1]; + + // strip leading 'gfx' while copying into our buffer + if (archLen > 3) { + strcpy(archName, &devName[3]); + archLen -= 3; + } + + // trim trailing :xnack- or :sramecc- statuses + archLen = strcspn(archName, ":"); + archName[archLen] = '\0'; + + // tease out the version information + if (archLen > 8) { + // versions labeled generic use '-' as delimiter + // strip the trailing "-generic" then iterate through what remains + if ((strstr(archName, "-generic"))) { + archName[archLen - 8] = '\0'; + char * pch; + if ((pch = strtok(archName, "-"))) { + archMajor = (int)strtoul(pch, 0, 16); + if ((pch = strtok(NULL, "-"))) { + archMinor = 0x10 * (int)strtoul(pch, 0, 16); + } + } + } + } else if (archLen >= 3) { + // last two digits should be the minor * 0x10 + stepping + archMinor = (int)strtoul(&archName[archLen - 2], 0, 16); + archName[archLen - 2] = '\0'; + + // only the major version remains + archMajor = (int)strtoul(archName, 0, 16); + } + archNum += archMajor * 0x100; + archNum += archMinor; + return archNum; +} +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) + static ggml_cuda_device_info ggml_cuda_init() { #ifdef __HIP_PLATFORM_AMD__ // Workaround for a rocBLAS bug when using multiple graphics cards: @@ -169,7 +218,6 @@ static ggml_cuda_device_info ggml_cuda_init() { cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); - GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; @@ -178,10 +226,25 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].smpb = prop.sharedMemPerBlock; #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpbo = prop.sharedMemPerBlock; - info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD; + + info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName); + if ((info.devices[id].cc & 0xff00) == 0x0) { + GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s cc %d.%d\n", + id, prop.name, prop.gcnArchName, prop.major, prop.minor); + + // Fallback to prop.major and prop.minor + if (prop.major > 0) { + info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100; + info.devices[id].cc += prop.minor * 0x10; + } + } + GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s\n", + id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, device_vmm ? "yes" : "no"); #else info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) } From 22e3df0afacf76ad98cd382cbe88055367b06dd5 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Tue, 28 Jan 2025 15:26:58 +0530 Subject: [PATCH 34/64] SYCL : SOFTMAX F16 mask support and other fixes (llama/11261) Implemented ggml_sycl_op_soft_max() F16 src1(mask) support for which a pragma deprecation warning was added during #5021. To do this, had to decouple it from ggml_sycl_op_flatten which always considered src1 to be of fp32 type(many OP functions are dependent on it). * SYCL: SOFTMAX F16 mask support and other fixes * test-backend-ops: Add F16 mask test cases --- ggml/src/ggml-sycl/ggml-sycl.cpp | 6 +--- ggml/src/ggml-sycl/softmax.cpp | 56 +++++++++++++++++++------------- ggml/src/ggml-sycl/softmax.hpp | 6 +--- 3 files changed, 35 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index ed4d8bb8bae..2984ed82e8a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3878,10 +3878,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf); } -static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max); -} - static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope); @@ -4090,7 +4086,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens ggml_sycl_diag_mask_inf(ctx, dst); break; case GGML_OP_SOFT_MAX: - ggml_sycl_soft_max(ctx, dst); + ggml_sycl_op_soft_max(ctx, dst); break; case GGML_OP_ROPE: ggml_sycl_rope(ctx, dst); diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index a9b3fce0dc4..563e0655f55 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -1,7 +1,7 @@ -#include "norm.hpp" +#include "softmax.hpp" -template -static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par, +template +static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; @@ -29,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const slope = sycl::pow(base, float(exp)); } - float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols; + float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols; float max_val = -INFINITY; for (int col0 = 0; col0 < ncols; col0 += block_size) { @@ -42,7 +42,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f); + const float val = x[ix]*scale + (mask ? slope*static_cast(mask[iy]) : 0.0f); vals[col] = val; max_val = sycl::max(max_val, val); @@ -65,7 +65,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const item_ct1.barrier(sycl::access::fence_space::local_space); max_val = buf[lane_id]; for (size_t i = 1; i < nreduce; i += 1) { - max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]); + max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]); } max_val = warp_reduce_max(max_val, item_ct1); } @@ -122,8 +122,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const } } -template -static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par, +template +static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, const size_t n_local_scratch, queue_ptr stream) { @@ -141,7 +141,8 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float * }); } -static void soft_max_f32_sycl(const float * x, const float * mask, +template +static void soft_max_f32_sycl(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, queue_ptr stream, int device) { @@ -223,22 +224,16 @@ static void soft_max_f32_sycl(const float * x, const float * mask, } } -void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { +void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); -#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support") -#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional - const int64_t ne00 = src0->ne[0]; - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; + const int64_t ne00 = dst->src[0]->ne[0]; + const int64_t nrows_x = ggml_nrows(dst->src[0]); + const int64_t nrows_y = dst->src[0]->ne[1]; float scale = 1.0f; float max_bias = 0.0f; @@ -246,6 +241,21 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s memcpy(&scale, dst->op_params + 0, sizeof(float)); memcpy(&max_bias, dst->op_params + 1, sizeof(float)); - soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, - nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + ggml_sycl_set_device(ctx.device); + dpct::queue_ptr main_stream = ctx.stream(); + + if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) { + const sycl::half * src1_dd = static_cast(dst->src[1]->data); + soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, + main_stream, ctx.device); + } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) { + const float * src1_dd = static_cast(dst->src[1]->data); + soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + } else { + /* mask unavailable */ + soft_max_f32_sycl(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + } } diff --git a/ggml/src/ggml-sycl/softmax.hpp b/ggml/src/ggml-sycl/softmax.hpp index bdb8f712e32..2cf8582ec92 100644 --- a/ggml/src/ggml-sycl/softmax.hpp +++ b/ggml/src/ggml-sycl/softmax.hpp @@ -15,10 +15,6 @@ #include "common.hpp" -void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream); +void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst); #endif // GGML_SYCL_SOFTMAX_HPP From b2cfef655b103fd8257515cd67217e60aab6eb58 Mon Sep 17 00:00:00 2001 From: someone13574 <81528246+someone13574@users.noreply.github.com> Date: Tue, 28 Jan 2025 09:15:34 -0500 Subject: [PATCH 35/64] cmake : don't fail on `GGML_CPU=OFF` (llama/11457) --- ggml/src/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 8d2b948fb56..56670913521 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -308,7 +308,7 @@ if (GGML_CPU_ALL_VARIANTS) # MSVC doesn't support AMX ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) endif() -else () +elseif (GGML_CPU) ggml_add_cpu_backend_variant_impl("") endif() From 115716d10996cc4efe6a5d9e53cebe230cd4c7c0 Mon Sep 17 00:00:00 2001 From: Nikita Sarychev <42014488+sARY77@users.noreply.github.com> Date: Tue, 28 Jan 2025 07:42:20 -0800 Subject: [PATCH 36/64] HIP: Only call rocblas_initialize on rocblas versions with the multiple instantation bug (llama/11080) This disables the workaround on rocblas fixed versions (>=4.0.0) to eliminate the runtime cost and unnecessary VRAM allocation of loading all tensile objects. --- ggml/src/ggml-cuda/ggml-cuda.cu | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 402f37e85d4..de3f9c2ca1e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -172,8 +173,25 @@ static ggml_cuda_device_info ggml_cuda_init() { #ifdef __HIP_PLATFORM_AMD__ // Workaround for a rocBLAS bug when using multiple graphics cards: // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346 - rocblas_initialize(); - CUDA_CHECK(cudaDeviceSynchronize()); + { + int major_version = 0; + size_t version_length = 0; + if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) { + std::string version(version_length, '\0'); + if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) { + version.resize(::strlen(version.c_str())); + int parsed_value = 0; + if (std::from_chars(version.c_str(), version.c_str() + version.length(), parsed_value).ec == std::errc()) { + major_version = parsed_value; + } + } + } + if (major_version < 4) { + GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n"); + rocblas_initialize(); + CUDA_CHECK(cudaDeviceSynchronize()); + } + } #endif ggml_cuda_device_info info = {}; From 682a6f5f87565aaca583c3f03e12708dbdb3efae Mon Sep 17 00:00:00 2001 From: uvos Date: Tue, 28 Jan 2025 23:06:32 +0100 Subject: [PATCH 37/64] HIP: Supress transformation warning in softmax.cu loops with bounds not known at compile time can not be unrolled. when ncols_template == 0, the bounds of the loop are not constexpr, thus llvm cant unroll the loops here. --- ggml/src/ggml-cuda/softmax.cu | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index 9aa4b848938..da377200e01 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -13,6 +13,12 @@ __device__ float __forceinline__ t2f32(half val) { return __half2float(val); } +// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. +// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif template static __global__ void soft_max_f32( const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, @@ -118,6 +124,9 @@ static __global__ void soft_max_f32( dst[col] = vals[col] * inv_sum; } } +#ifdef __clang__ +#pragma clang diagnostic pop +#endif static __global__ void soft_max_back_f32( const float * grad, const float * dstf, float * dst, const int ncols, const float scale) { From 75e7d0585e863e20e9e06feb6a7390d6507436ff Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 29 Jan 2025 09:26:50 -0600 Subject: [PATCH 38/64] vulkan: Catch pipeline creation failure and print an error message (llama/11436) * vulkan: Catch pipeline creation failure and print an error message Also, fix some warnings from my on-demand compile change. * vulkan: fix pipeline creation logging --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a9d6b923cbf..6c7e606507b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -774,12 +774,12 @@ static uint32_t compile_count = 0; static std::mutex compile_count_mutex; static std::condition_variable compile_count_cond; -static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, - uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector specialization_constants, - uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { - VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << - ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << - ", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); +static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint, + uint32_t parameter_count, std::array wg_denoms, std::vector specialization_constants, + bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { + VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << pipeline->name << ", " << entrypoint << ", " << parameter_count << + ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << + disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); GGML_ASSERT(parameter_count > 0); GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT @@ -864,7 +864,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin compute_pipeline_create_info.setPNext(&rci); } - pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; + try { + pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Compute pipeline creation failed for " << pipeline->name << std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } pipeline->compiled = true; { @@ -1560,8 +1566,8 @@ static void ggml_vk_load_shaders(vk_device& device) { } compile_count++; } - compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, - parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size)); + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, + parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) From 80fa576254aa136231457236bb00ec085249ee80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9my=20Oudompheng?= Date: Wed, 29 Jan 2025 18:29:39 +0100 Subject: [PATCH 39/64] vulkan: implement initial support for IQ2 and IQ3 quantizations (llama/11360) * vulkan: initial support for IQ3_S * vulkan: initial support for IQ3_XXS * vulkan: initial support for IQ2_XXS * vulkan: initial support for IQ2_XS * vulkan: optimize Q3_K by removing branches * vulkan: implement dequantize variants for coopmat2 * vulkan: initial support for IQ2_S * vulkan: vertically realign code * port failing dequant callbacks from mul_mm * Fix array length mismatches * vulkan: avoid using workgroup size before it is referenced * tests: increase timeout for Vulkan llvmpipe backend --------- Co-authored-by: Jeff Bolz --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 157 +++- .../vulkan-shaders/copy_from_quant.comp | 4 +- .../vulkan-shaders/copy_to_quant.comp | 4 +- .../vulkan-shaders/dequant_funcs.comp | 218 +++++- .../vulkan-shaders/dequant_funcs_cm2.comp | 164 ++++ .../vulkan-shaders/dequant_iq2_s.comp | 44 ++ .../vulkan-shaders/dequant_iq2_xs.comp | 43 + .../vulkan-shaders/dequant_iq2_xxs.comp | 48 ++ .../vulkan-shaders/dequant_iq3_s.comp | 39 + .../vulkan-shaders/dequant_iq3_xxs.comp | 49 ++ .../vulkan-shaders/dequant_iq4_nl.comp | 2 +- .../vulkan-shaders/flash_attn_cm2.comp | 4 +- .../vulkan-shaders/get_rows_quant.comp | 4 +- .../vulkan-shaders/mul_mat_vec.comp | 4 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 122 ++- .../vulkan-shaders/mul_mm_cm2.comp | 4 +- .../src/ggml-vulkan/vulkan-shaders/types.comp | 738 +++++++++++++++++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 5 + 18 files changed, 1614 insertions(+), 39 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 6c7e606507b..9ca3959abf1 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1616,6 +1616,11 @@ static void ggml_vk_load_shaders(vk_device& device) { //CREATE_FA(GGML_TYPE_Q4_K, q4_k) //CREATE_FA(GGML_TYPE_Q5_K, q5_k) //CREATE_FA(GGML_TYPE_Q6_K, q6_k) + //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs) + //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs) + //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s) + //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs) + //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s) CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl) #undef CREATE_FA @@ -1644,7 +1649,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) @@ -1657,7 +1667,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) #undef CREATE_MM #undef CREATE_MM2 } else @@ -1705,7 +1720,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -1718,7 +1738,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. @@ -1739,7 +1764,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } else { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -1752,7 +1782,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } } #undef CREATE_MM2 @@ -1796,7 +1831,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { @@ -1815,7 +1855,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } #undef CREATE_MM2 #undef CREATE_MM @@ -1851,7 +1896,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { @@ -1870,7 +1920,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } #undef CREATE_MM } @@ -1901,7 +1956,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); @@ -1915,7 +1975,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); @@ -1930,7 +1995,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -1944,7 +2014,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -1954,7 +2029,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -1963,7 +2043,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); @@ -2890,6 +2975,11 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: break; default: @@ -2938,6 +3028,11 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: break; default: @@ -2969,6 +3064,11 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: break; default: @@ -3012,6 +3112,11 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: break; default: @@ -3038,6 +3143,11 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: break; default: @@ -7907,6 +8017,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: break; default: @@ -7975,6 +8090,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm //case GGML_TYPE_Q4_K: //case GGML_TYPE_Q5_K: //case GGML_TYPE_Q6_K: + //case GGML_TYPE_IQ2_XXS: + //case GGML_TYPE_IQ2_XS: + //case GGML_TYPE_IQ2_S: + //case GGML_TYPE_IQ3_XXS: + //case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: break; default: @@ -7992,6 +8112,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: return true; default: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index c09bf496bf3..aeae5400dfc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -12,8 +12,8 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; #endif void main() { -#if defined(DATA_A_IQ4_NL) - init_iq4nl_shmem(); +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) + init_iq_shmem(gl_WorkGroupSize); if (gl_LocalInvocationIndex.x != 0) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index ccf5b980a8a..d4b068e6186 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -217,8 +217,8 @@ void quantize(uint dst_idx, uint src_idx) #endif void main() { -#if defined(DATA_A_IQ4_NL) - init_iq4nl_shmem(); +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) + init_iq_shmem(gl_WorkGroupSize); if (gl_LocalInvocationIndex.x != 0) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index 91bb8f8db61..ee68775317b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -88,6 +88,222 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_IQ2_XXS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = (iqs / 8) % 4; + const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2], + data_a_packed16[a_offset + ib].qs[4 * ib32 + 3])); + const float db = 0.25 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = (iqs / 8) % 4; + const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2], + data_a_packed16[a_offset + ib].qs[4 * ib32 + 3])); + const float db = 0.25 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ2_XS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[iqs / 8]; + const float db = 0.25 * (0.5 + scale); + const uint sign7 = qs >> 9; + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[iqs / 8]; + const float db = 0.25 * (0.5 + scale); + const uint sign7 = qs >> 9; + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ2_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + + const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8); + + const float db = 0.25 * (0.5 + scale); + const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid[iqs % 4] * (sign0 ? -1.0 : 1.0), + grid[(iqs % 4) + 1] * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + + const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8); + + const float db = 0.25 * (0.5 + scale); + const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ3_XXS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint is = QUANT_K / 4 + 4 * ib32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2], + data_a_packed16[a_offset + ib].qs[is / 2 + 1])); + const float db = 0.5 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint is = QUANT_K / 4 + 4 * ib32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2], + data_a_packed16[a_offset + ib].qs[is / 2 + 1])); + const float db = 0.5 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq3xxs_grid[qs]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ3_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint qs = data_a[a_offset + ib].qs[iqs / 4]; + const uint qh = data_a[a_offset + ib].qh[iqs / 32]; + const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8); + const uint scale = data_a[a_offset + ib].scales[iqs / 64]; + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4)); + return db * vec2( + int(grid & 0xFF) * (sign0 ? -1.0 : 1.0), + int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8); + const uint scale = data_a[a_offset + ib].scales[ib32 / 2]; + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4)); + return db * vec4( + int(grid & 0xFF) * (sign0 ? -1.0 : 1.0), + int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0), + int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0), + int((grid >> 24) & 0xFF) * (sign3 ? -1.0 : 1.0) + ); +} +#endif + #if defined(DATA_A_IQ4_NL) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); @@ -105,7 +321,7 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) vec2 get_dm(uint ib, uint a_offset) { return vec2(float(data_a[a_offset + ib].d), 0); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 175e31fa76e..974efd3f9a6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -301,6 +301,160 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2 return ret; } +#if defined(DATA_A_IQ2_XXS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS { + block_iq2_xxs block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 { + block_iq2_xxs_packed16 block; +}; + +float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + const uint ib8 = (idx & 0x18) >> 3; // 0..3 + const uint iqs = 8 * ib32 + ib8; + + const uint8_t qs = bl.block.qs[iqs]; + const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); + + const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t(signscale >> 28)); + uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7); + sign |= bitCount(sign) << 7; + + const uint8_t g = unpack8(iq2xxs_grid[qs][(idx & 4) >> 2])[idx & 3]; + + float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + + return ret; +} +#endif + +#if defined(DATA_A_IQ2_XS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS { + block_iq2_xs block; +}; + +float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint is = (idx & 0xE0) >> 5; // 0..8 + const uint sshift = (idx & 0x10) >> 2; // 0,4 + const uint iqs = (idx & 0xF8) >> 3; // 0..63 + + const uint16_t qs = bl.block.qs[iqs]; + const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t((bl.block.scales[is] >> sshift) & 0xF)); + + uint sign = uint(qs >> 9); + sign |= bitCount(sign) << 7; + const uint8_t g = unpack8(iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2])[idx & 3]; + + float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + return ret; +} +#endif + +#if defined(DATA_A_IQ2_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S { + block_iq2_s block; +}; + +float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + uint idx = coordInBlock[1]; + uint lsb = idx & 1; + idx /= 2; + + const uint ib8 = (idx % 128) / 4; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 + + const uint scale = (bl.block.scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + + const float d = float(bl.block.d); + const float db = d * 0.25 * (0.5 + scale); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid)); + return float16_t(v[lsb]); +} +#endif + +#if defined(DATA_A_IQ3_XXS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS { + block_iq3_xxs block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 { + block_iq3_xxs_packed16 block; +}; + +float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + uint idx = coordInBlock[1]; + uint lsb = idx & 1; + idx /= 2; + + const uint iqs = (idx % 128) / 2; // 0..63 + const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint signs = pack32(u8vec4( + bl.block.qs[is+0], + bl.block.qs[is+1], + bl.block.qs[is+2], + bl.block.qs[is+3] + )); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + return float16_t(v[lsb]); +} +#endif + +#if defined(DATA_A_IQ3_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S { + block_iq3_s block; +}; + +float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + uint idx = coordInBlock[1]; + uint lsb = idx & 1; + idx /= 2; + + const uint iqs = (idx % 128) / 2; // 0..63 + const uint iqh = iqs / 8; + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint qh = bl.block.qh[iqh]; + const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (2 * (idx % 4))); + const uint scale = bl.block.scales[iqs / 16]; + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + return float16_t(v[lsb]); +} +#endif + + #if defined(DATA_A_IQ4_NL) layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { block_iq4_nl block; @@ -340,6 +494,16 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor #define dequantFuncA dequantFuncQ5_K #elif defined(DATA_A_Q6_K) #define dequantFuncA dequantFuncQ6_K +#elif defined(DATA_A_IQ2_XXS) +#define dequantFuncA dequantFuncIQ2_XXS +#elif defined(DATA_A_IQ2_XS) +#define dequantFuncA dequantFuncIQ2_XS +#elif defined(DATA_A_IQ2_S) +#define dequantFuncA dequantFuncIQ2_S +#elif defined(DATA_A_IQ3_XXS) +#define dequantFuncA dequantFuncIQ3_XXS +#elif defined(DATA_A_IQ3_S) +#define dequantFuncA dequantFuncIQ3_S #elif defined(DATA_A_IQ4_NL) #define dequantFuncA dequantFuncIQ4_NL #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp new file mode 100644 index 00000000000..48f6b65bc40 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp @@ -0,0 +1,44 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + const float d = float(data_a[ib].d); + const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4); + const vec2 db = d * (0.5 + scale) * 0.25; + + uint qh = data_a[ib].qh[ib32]; + [[unroll]] for (uint l = 0; l < 4; ++l) { + uint qs = data_a[ib].qs[4 * ib32 + l]; + const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l]; + qs |= (qh << (8 - 2 * l)) & 0x300; + const uvec2 grid = iq2s_grid[qs & 511]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp new file mode 100644 index 00000000000..a08331c40de --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp @@ -0,0 +1,43 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_xs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + const float d = float(data_a[ib].d); + const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4); + const vec2 db = d * (0.5 + scale) * 0.25; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + uint16_t qs = data_a[ib].qs[4 * ib32 + l]; + const uint sign7 = qs >> 9; + const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit + const uvec2 grid = iq2xs_grid[qs & 511]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp new file mode 100644 index 00000000000..e370690bcb0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp @@ -0,0 +1,48 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_xxs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale block (32 values) + // Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + + const float d = float(data_a[ib].d); + uint signscale = pack32(u8vec4( + data_a[ib].qs[8*is + 4], + data_a[ib].qs[8*is + 5], + data_a[ib].qs[8*is + 6], + data_a[ib].qs[8*is + 7] + )); + const float db = d * (0.5 + (signscale >> 28)) * 0.25; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); + const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit + const uvec2 grid = iq2xxs_grid[data_a[ib].qs[8 * is + l]]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp new file mode 100644 index 00000000000..c3f4bca5d95 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp @@ -0,0 +1,39 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq3_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale nibble. + // Each block contains 4 scale bytes (8 scales) for 256 output values. + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + + const float d = float(data_a[ib].d); + const float db = d * (1 + 2 * ((data_a[ib].scales[is] >> (4 * (is % 2))) & 0xf)); + + // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes. + uint qh = data_a[ib].qh[is]; + [[unroll]] for (uint l = 0; l < 8; ++l) { + uint qs = data_a[ib].qs[8 * is + l]; + uint gidx = qs | ((qh << (8 - l)) & 256); + uint8_t signs = data_a[ib].signs[8 * is + l / 2] >> (4 * (l & 1)); + u8vec4 grid = unpack8(iq3s_grid[gidx]); + data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp new file mode 100644 index 00000000000..a92b82961af --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale block (32 values) + // 8 threads handle 1 superblock + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + const uint s_idx = QUANT_K / 4 + 4 * is; + + const float d = float(data_a[ib].d); + uint signscale = pack32(u8vec4( + data_a[ib].qs[s_idx + 0], + data_a[ib].qs[s_idx + 1], + data_a[ib].qs[s_idx + 2], + data_a[ib].qs[s_idx + 3] + )); + const float db = d * (0.5 + (signscale >> 28)) * 0.5; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); + // Restore parity bit. + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const u8vec4 grid0 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l]]); + const u8vec4 grid1 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l + 1]]); + data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp index 8de14fc03f1..46d9ad15eba 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; - init_iq4nl_shmem(); + init_iq_shmem(gl_WorkGroupSize); const uint tid = gl_LocalInvocationID.x % 64; const uint il = tid/32; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 3735d0dbbae..043a5302388 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -104,8 +104,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele #endif void main() { -#if defined(DATA_A_IQ4_NL) - init_iq4nl_shmem(); +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) + init_iq_shmem(gl_WorkGroupSize); #endif const uint32_t N = p.N; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index 1426fde6597..09dc43d8dc3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -12,8 +12,8 @@ void main() { const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; -#if defined(DATA_A_IQ4_NL) - init_iq4nl_shmem(); +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) + init_iq_shmem(gl_WorkGroupSize); #endif if (i00 >= p.ne00) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 53902858de7..48156e7bab6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -133,8 +133,8 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { void main() { const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); -#if defined(DATA_A_IQ4_NL) - init_iq4nl_shmem(); +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) + init_iq_shmem(gl_WorkGroupSize); #endif // do NUM_ROWS at a time, unless there aren't enough remaining rows diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 48122cbef90..d0559aac8ec 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -95,8 +95,8 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif void main() { -#if defined(DATA_A_IQ4_NL) - init_iq4nl_shmem(); +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) + init_iq_shmem(gl_WorkGroupSize); #endif #ifdef MUL_MAT_ID @@ -343,10 +343,8 @@ void main() { const uint qsshift = halfsplit * 2; // 0,2,4,6 const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 - const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) : - is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) : - is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) : - (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4)); + const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) + | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); const float dl = float(data_a[ib].d) * float(us - 32); buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4))); @@ -439,6 +437,118 @@ void main() { buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +#elif defined(DATA_A_IQ2_XXS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint ib8 = (idx / 4) % 4; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[8 * ib32 + ib8]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[8*ib32 + 4], + data_a[ib].qs[8*ib32 + 5], + data_a[ib].qs[8*ib32 + 6], + data_a[ib].qs[8*ib32 + 7] + )); + const float db = d * 0.25 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ2_XS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint ib8 = (idx / 4) % 4; // 0..3 + + const float d = float(data_a[ib].d); + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const float db = d * 0.25 * (0.5 + scale); + const uint qs = data_a[ib].qs[4 * ib32 + ib8]; + const uint sign7 = qs >> 9; + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ2_S) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib8 = (idx % 128) / 4; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 + + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + + const float d = float(data_a[ib].d); + const float db = d * 0.25 * (0.5 + scale); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid)); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ3_XXS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = (idx % 128) / 2; // 0..63 + const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[is+0], + data_a[ib].qs[is+1], + data_a[ib].qs[is+2], + data_a[ib].qs[is+3] + )); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ3_S) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = (idx % 128) / 2; // 0..63 + const uint iqh = iqs / 8; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint qh = data_a[ib].qh[iqh]; + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4))); + const uint scale = data_a[ib].scales[iqs / 16]; + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 57f9e724513..27c5d68b3d9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -106,8 +106,8 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem #endif void main() { -#if defined(DATA_A_IQ4_NL) - init_iq4nl_shmem(); +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) + init_iq_shmem(gl_WorkGroupSize); #endif #ifdef MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index 1e35b665283..9e56a35300b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -294,6 +294,738 @@ struct block_q6_K_packed16 // IQuants +#define QUANT_K_IQ2_XXS 256 +#define QUANT_R_IQ2_XXS 1 + +struct block_iq2_xxs +{ + float16_t d; + uint8_t qs[QUANT_K_IQ2_XXS/4]; +}; + +struct block_iq2_xxs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XXS/8]; +}; + +#if defined(DATA_A_IQ2_XXS) + +const uvec2[256] iq2xxs_grid_const = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x082b0808, 0x08080808), + uvec2(0x082b082b, 0x08080808), uvec2(0x082b2b08, 0x08080808), uvec2(0x082b2b2b, 0x08080808), uvec2(0x19080819, 0x08080808), + uvec2(0x19081908, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), + uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b082b2b, 0x08080808), + uvec2(0x2b2b082b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), uvec2(0x08190808, 0x08080819), + uvec2(0x08191919, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x2b081908, 0x08080819), uvec2(0x2b192b08, 0x08080819), + uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x082b082b, 0x0808082b), uvec2(0x2b08082b, 0x0808082b), + uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x082b0819, 0x08081908), + uvec2(0x082b1908, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19082b08, 0x08081908), + uvec2(0x192b0808, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), + uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), uvec2(0x08082b08, 0x08081919), + uvec2(0x082b0808, 0x08081919), uvec2(0x1908192b, 0x08081919), uvec2(0x192b2b19, 0x08081919), uvec2(0x2b080808, 0x08081919), + uvec2(0x2b190819, 0x08081919), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x19080808, 0x0808192b), + uvec2(0x2b081908, 0x0808192b), uvec2(0x2b2b1908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x08081919, 0x08082b08), + uvec2(0x08082b08, 0x08082b08), uvec2(0x08191908, 0x08082b08), uvec2(0x082b2b08, 0x08082b08), uvec2(0x19080819, 0x08082b08), + uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x2b082b08, 0x08082b08), + uvec2(0x08081908, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x0808082b, 0x08082b2b), uvec2(0x08191908, 0x08082b2b), + uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x082b0819, 0x08190808), + uvec2(0x19080808, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), + uvec2(0x2b191919, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x082b0808, 0x08190819), + uvec2(0x19190808, 0x08190819), uvec2(0x19192b2b, 0x08190819), uvec2(0x2b080808, 0x08190819), uvec2(0x082b1908, 0x0819082b), + uvec2(0x19081919, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x08082b08, 0x08191908), uvec2(0x082b0808, 0x08191908), + uvec2(0x082b1919, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08192b08, 0x08191919), + uvec2(0x192b082b, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x0819192b, 0x0819192b), uvec2(0x08080819, 0x08192b08), + uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x2b2b0808, 0x08192b19), uvec2(0x19190819, 0x08192b2b), + uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x19081908, 0x082b0808), + uvec2(0x192b0819, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b08082b, 0x082b0808), uvec2(0x082b2b19, 0x082b0819), + uvec2(0x19082b08, 0x082b0819), uvec2(0x08080808, 0x082b082b), uvec2(0x0808082b, 0x082b082b), uvec2(0x08080819, 0x082b1908), + uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x19080808, 0x082b1908), uvec2(0x1919192b, 0x082b1908), + uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x192b1908, 0x082b1919), uvec2(0x2b190808, 0x082b192b), + uvec2(0x08082b08, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), uvec2(0x2b191908, 0x082b2b08), uvec2(0x19081908, 0x082b2b2b), + uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x08192b08, 0x19080808), + uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x19080808, 0x19080808), uvec2(0x19082b08, 0x19080808), + uvec2(0x1919192b, 0x19080808), uvec2(0x192b0808, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), + uvec2(0x2b190808, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x192b0819, 0x19080819), + uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08190808, 0x1908082b), + uvec2(0x19082b08, 0x1908082b), uvec2(0x1919192b, 0x1908082b), uvec2(0x192b2b08, 0x1908082b), uvec2(0x08080808, 0x19081908), + uvec2(0x08082b08, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b192b19, 0x19081908), + uvec2(0x0819082b, 0x19081919), uvec2(0x082b1908, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08080819, 0x19082b08), + uvec2(0x08081908, 0x19082b08), uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), + uvec2(0x08080808, 0x19082b19), uvec2(0x19192b08, 0x19082b19), uvec2(0x192b0819, 0x19082b19), uvec2(0x2b08082b, 0x19082b19), + uvec2(0x19081919, 0x19082b2b), uvec2(0x2b190808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x08082b08, 0x19190808), + uvec2(0x08190819, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x2b080808, 0x19190808), + uvec2(0x2b082b08, 0x19190808), uvec2(0x08081908, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x2b2b1908, 0x19190819), + uvec2(0x2b190819, 0x1919082b), uvec2(0x2b190808, 0x19191908), uvec2(0x2b19082b, 0x19191908), uvec2(0x08082b2b, 0x19191919), + uvec2(0x08080819, 0x1919192b), uvec2(0x19191908, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x08190819, 0x19192b08), + uvec2(0x08192b19, 0x19192b08), uvec2(0x192b1908, 0x19192b08), uvec2(0x19080808, 0x19192b19), uvec2(0x08082b08, 0x19192b2b), + uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x192b2b08, 0x192b0808), + uvec2(0x08080808, 0x192b0819), uvec2(0x19191919, 0x192b0819), uvec2(0x08192b08, 0x192b082b), uvec2(0x192b0808, 0x192b082b), + uvec2(0x08080808, 0x192b1908), uvec2(0x08081919, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x0819082b, 0x192b1919), + uvec2(0x2b081908, 0x192b1919), uvec2(0x1908082b, 0x192b2b08), uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), + uvec2(0x08082b2b, 0x2b080808), uvec2(0x19080819, 0x2b080808), uvec2(0x2b08082b, 0x2b080808), uvec2(0x08081908, 0x2b080819), + uvec2(0x08192b08, 0x2b080819), uvec2(0x19080808, 0x2b080819), uvec2(0x08190819, 0x2b08082b), uvec2(0x08080819, 0x2b081908), + uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), + uvec2(0x192b0808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x1908192b, 0x2b081919), uvec2(0x2b191908, 0x2b081919), + uvec2(0x08082b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x192b0808, 0x2b08192b), uvec2(0x0808082b, 0x2b082b08), + uvec2(0x08081908, 0x2b082b19), uvec2(0x08190819, 0x2b082b2b), uvec2(0x08081908, 0x2b190808), uvec2(0x08190808, 0x2b190808), + uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x2b2b0819, 0x2b190808), uvec2(0x0819192b, 0x2b190819), + uvec2(0x2b080808, 0x2b190819), uvec2(0x19081919, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x082b082b, 0x2b191908), + uvec2(0x19081908, 0x2b191908), uvec2(0x19190819, 0x2b191919), uvec2(0x2b080819, 0x2b192b08), uvec2(0x082b0808, 0x2b192b19), + uvec2(0x0808082b, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b081919, 0x2b2b0808), uvec2(0x08082b19, 0x2b2b0819), + uvec2(0x08080808, 0x2b2b082b), uvec2(0x08192b08, 0x2b2b1908), uvec2(0x19190808, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19) +}; + +shared uvec2 iq2xxs_grid[256]; + +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq2xxs_grid.length(); i += wgsize.x) { + iq2xxs_grid[i] = iq2xxs_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_XXS +#define QUANT_R QUANT_R_IQ2_XXS +#define A_TYPE block_iq2_xxs +#define A_TYPE_PACKED16 block_iq2_xxs_packed16 +#endif + +#define QUANT_K_IQ2_XS 256 +#define QUANT_R_IQ2_XS 1 + +struct block_iq2_xs +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XS/8]; + uint8_t scales[QUANT_K_IQ2_XS/32]; +}; + +struct block_iq2_xs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XS/8]; + uint16_t scales[QUANT_K_IQ2_XS/64]; +}; + +#if defined(DATA_A_IQ2_XS) + +const uvec2 iq2xs_grid_const[512] = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808), + uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808), + uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808), + uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808), + uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), + uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), + uvec2(0x2b191908, 0x08080808), uvec2(0x2b192b19, 0x08080808), uvec2(0x2b2b0808, 0x08080808), uvec2(0x08080819, 0x08080819), + uvec2(0x08081908, 0x08080819), uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), + uvec2(0x0819082b, 0x08080819), uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x08192b2b, 0x08080819), + uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), + uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), + uvec2(0x192b0808, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), uvec2(0x2b081908, 0x08080819), + uvec2(0x2b190808, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x08081919, 0x0808082b), + uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), uvec2(0x082b0808, 0x0808082b), + uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b), + uvec2(0x2b080808, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), + uvec2(0x0808192b, 0x08081908), uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), + uvec2(0x08191919, 0x08081908), uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), + uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), uvec2(0x19082b08, 0x08081908), + uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), uvec2(0x1919192b, 0x08081908), uvec2(0x192b0808, 0x08081908), + uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x08080808, 0x08081919), + uvec2(0x0808082b, 0x08081919), uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08190819, 0x08081919), + uvec2(0x08191908, 0x08081919), uvec2(0x082b0808, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919), + uvec2(0x19190808, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x2b080808, 0x08081919), uvec2(0x08080819, 0x0808192b), + uvec2(0x08081908, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x082b192b, 0x0808192b), uvec2(0x19080808, 0x0808192b), + uvec2(0x1908082b, 0x0808192b), uvec2(0x2b081908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08), + uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08082b2b, 0x08082b08), uvec2(0x08190819, 0x08082b08), + uvec2(0x08191908, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), uvec2(0x19080819, 0x08082b08), + uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x19192b08, 0x08082b08), uvec2(0x2b080808, 0x08082b08), + uvec2(0x2b2b0808, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), uvec2(0x08081908, 0x08082b19), + uvec2(0x08190808, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x2b080819, 0x08082b19), uvec2(0x2b082b19, 0x08082b19), + uvec2(0x08080808, 0x08082b2b), uvec2(0x082b0808, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x2b19192b, 0x08082b2b), + uvec2(0x2b2b0808, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x0808192b, 0x08190808), + uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), uvec2(0x08191919, 0x08190808), + uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), uvec2(0x19080808, 0x08190808), + uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), + uvec2(0x19191908, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b2b2b, 0x08190808), uvec2(0x2b080819, 0x08190808), + uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), + uvec2(0x08081919, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819), + uvec2(0x082b0808, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), uvec2(0x19190808, 0x08190819), + uvec2(0x2b080808, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x2b19192b, 0x08190819), uvec2(0x08080819, 0x0819082b), + uvec2(0x08081908, 0x0819082b), uvec2(0x0808192b, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x19080808, 0x0819082b), + uvec2(0x192b0808, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), + uvec2(0x08082b08, 0x08191908), uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x082b0808, 0x08191908), + uvec2(0x19080819, 0x08191908), uvec2(0x19081908, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908), + uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), + uvec2(0x08190808, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x08191908, 0x0819192b), + uvec2(0x19082b19, 0x0819192b), uvec2(0x08080819, 0x08192b08), uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), + uvec2(0x0819082b, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x19191908, 0x08192b08), uvec2(0x2b08192b, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x192b192b, 0x08192b19), uvec2(0x19190819, 0x08192b2b), + uvec2(0x2b2b2b19, 0x08192b2b), uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), + uvec2(0x08082b08, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), + uvec2(0x082b0808, 0x082b0808), uvec2(0x19080819, 0x082b0808), uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), + uvec2(0x2b080808, 0x082b0808), uvec2(0x2b2b0808, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), + uvec2(0x08190808, 0x082b0819), uvec2(0x19080808, 0x082b0819), uvec2(0x19082b08, 0x082b0819), uvec2(0x192b1919, 0x082b0819), + uvec2(0x08080808, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x2b080808, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), + uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x082b2b19, 0x082b1908), + uvec2(0x19080808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x1919082b, 0x082b1919), + uvec2(0x2b192b19, 0x082b1919), uvec2(0x08080819, 0x082b192b), uvec2(0x08192b2b, 0x082b192b), uvec2(0x2b2b192b, 0x082b192b), + uvec2(0x08080808, 0x082b2b08), uvec2(0x08082b08, 0x082b2b08), uvec2(0x08082b2b, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), + uvec2(0x19191919, 0x082b2b08), uvec2(0x2b082b08, 0x082b2b08), uvec2(0x2b2b082b, 0x082b2b08), uvec2(0x192b2b08, 0x082b2b19), + uvec2(0x2b190808, 0x082b2b19), uvec2(0x08082b08, 0x082b2b2b), uvec2(0x082b0808, 0x082b2b2b), uvec2(0x2b08082b, 0x082b2b2b), + uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), + uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x0819082b, 0x19080808), + uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), + uvec2(0x19080808, 0x19080808), uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), + uvec2(0x19082b2b, 0x19080808), uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x192b0808, 0x19080808), + uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), + uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), + uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x19080819, 0x19080819), + uvec2(0x19081908, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), + uvec2(0x2b2b082b, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), uvec2(0x08190808, 0x1908082b), + uvec2(0x0819082b, 0x1908082b), uvec2(0x082b2b19, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x08080808, 0x19081908), + uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), uvec2(0x08082b08, 0x19081908), uvec2(0x08190819, 0x19081908), + uvec2(0x08191908, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x19080819, 0x19081908), + uvec2(0x19081908, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b191908, 0x19081908), + uvec2(0x08080819, 0x19081919), uvec2(0x08081908, 0x19081919), uvec2(0x08190808, 0x19081919), uvec2(0x082b1908, 0x19081919), + uvec2(0x19080808, 0x19081919), uvec2(0x2b192b2b, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08082b2b, 0x1908192b), + uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), + uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), uvec2(0x19191908, 0x19082b08), + uvec2(0x192b082b, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x19081908, 0x19082b19), + uvec2(0x19190808, 0x19082b19), uvec2(0x192b2b19, 0x19082b19), uvec2(0x08081908, 0x19082b2b), uvec2(0x08080808, 0x19190808), + uvec2(0x0808082b, 0x19190808), uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), + uvec2(0x08191908, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), + uvec2(0x19081908, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x2b080808, 0x19190808), uvec2(0x08080819, 0x19190819), + uvec2(0x08081908, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x08191919, 0x19190819), uvec2(0x19080808, 0x19190819), + uvec2(0x1908082b, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x2b2b2b2b, 0x1919082b), + uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x082b0819, 0x19191908), + uvec2(0x19080808, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b2b0819, 0x19191908), + uvec2(0x08080808, 0x19191919), uvec2(0x08082b08, 0x19191919), uvec2(0x2b080808, 0x19191919), uvec2(0x2b082b08, 0x19191919), + uvec2(0x082b0819, 0x1919192b), uvec2(0x192b2b08, 0x1919192b), uvec2(0x2b2b0819, 0x1919192b), uvec2(0x08080808, 0x19192b08), + uvec2(0x08191908, 0x19192b08), uvec2(0x19080819, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x2b192b19, 0x19192b08), + uvec2(0x08192b2b, 0x19192b19), uvec2(0x19080808, 0x19192b19), uvec2(0x1908082b, 0x19192b19), uvec2(0x2b081919, 0x19192b2b), + uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), + uvec2(0x19191908, 0x192b0808), uvec2(0x192b082b, 0x192b0808), uvec2(0x2b08192b, 0x192b0808), uvec2(0x2b2b2b19, 0x192b0808), + uvec2(0x08080808, 0x192b0819), uvec2(0x082b1908, 0x192b082b), uvec2(0x19082b2b, 0x192b082b), uvec2(0x2b19082b, 0x192b082b), + uvec2(0x08080808, 0x192b1908), uvec2(0x0819192b, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x19080808, 0x192b1919), + uvec2(0x19081919, 0x192b1919), uvec2(0x2b2b1908, 0x192b1919), uvec2(0x08080819, 0x192b2b08), uvec2(0x192b2b2b, 0x192b2b08), + uvec2(0x082b1919, 0x192b2b19), uvec2(0x0808192b, 0x192b2b2b), uvec2(0x19191908, 0x192b2b2b), uvec2(0x192b082b, 0x192b2b2b), + uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), + uvec2(0x08190819, 0x2b080808), uvec2(0x08191908, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b2b2b, 0x2b080808), + uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x2b080808, 0x2b080808), + uvec2(0x2b08082b, 0x2b080808), uvec2(0x2b2b2b08, 0x2b080808), uvec2(0x2b2b2b2b, 0x2b080808), uvec2(0x08080819, 0x2b080819), + uvec2(0x08081908, 0x2b080819), uvec2(0x0808192b, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x19080808, 0x2b080819), + uvec2(0x19190819, 0x2b080819), uvec2(0x19192b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x082b0808, 0x2b08082b), + uvec2(0x2b080808, 0x2b08082b), uvec2(0x2b08082b, 0x2b08082b), uvec2(0x2b2b0808, 0x2b08082b), uvec2(0x2b2b2b08, 0x2b08082b), + uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908), + uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b082b19, 0x2b081908), + uvec2(0x08080808, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x2b2b1919, 0x2b081919), uvec2(0x08192b08, 0x2b08192b), + uvec2(0x192b2b2b, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08082b08, 0x2b082b08), uvec2(0x082b1919, 0x2b082b08), + uvec2(0x19192b2b, 0x2b082b08), uvec2(0x2b080808, 0x2b082b08), uvec2(0x2b08082b, 0x2b082b08), uvec2(0x2b2b2b08, 0x2b082b08), + uvec2(0x0808192b, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x2b080808, 0x2b082b2b), uvec2(0x2b082b08, 0x2b082b2b), + uvec2(0x2b19192b, 0x2b082b2b), uvec2(0x2b2b2b08, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), uvec2(0x08081908, 0x2b190808), + uvec2(0x08190808, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x1919192b, 0x2b190808), uvec2(0x2b081908, 0x2b190808), + uvec2(0x08080808, 0x2b190819), uvec2(0x082b082b, 0x2b190819), uvec2(0x192b1908, 0x2b190819), uvec2(0x1919192b, 0x2b19082b), + uvec2(0x2b082b19, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x08081919, 0x2b191908), uvec2(0x19081908, 0x2b191908), + uvec2(0x19190808, 0x2b191908), uvec2(0x19192b08, 0x2b191908), uvec2(0x082b2b19, 0x2b191919), uvec2(0x2b190808, 0x2b191919), + uvec2(0x2b19082b, 0x2b191919), uvec2(0x19080819, 0x2b19192b), uvec2(0x19190819, 0x2b192b08), uvec2(0x2b2b192b, 0x2b192b08), + uvec2(0x19082b19, 0x2b192b19), uvec2(0x08191919, 0x2b192b2b), uvec2(0x192b0808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), + uvec2(0x0808082b, 0x2b2b0808), uvec2(0x08082b08, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), uvec2(0x082b0808, 0x2b2b0808), + uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x2b2b0808, 0x2b2b0808), uvec2(0x19190819, 0x2b2b0819), uvec2(0x19192b19, 0x2b2b0819), + uvec2(0x2b2b192b, 0x2b2b0819), uvec2(0x08080808, 0x2b2b082b), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b08, 0x2b2b082b), + uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b080808, 0x2b2b082b), uvec2(0x2b2b0808, 0x2b2b082b), uvec2(0x19080808, 0x2b2b1908), + uvec2(0x2b191919, 0x2b2b1908), uvec2(0x192b1919, 0x2b2b192b), uvec2(0x2b192b08, 0x2b2b192b), uvec2(0x08082b2b, 0x2b2b2b08), + uvec2(0x082b0808, 0x2b2b2b08), uvec2(0x082b082b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b0808, 0x2b2b2b08), + uvec2(0x2b2b2b08, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19), uvec2(0x2b081908, 0x2b2b2b19), uvec2(0x2b08192b, 0x2b2b2b19), + uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x082b2b2b, 0x2b2b2b2b), uvec2(0x2b190819, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b), +}; + +shared uvec2 iq2xs_grid[512]; + +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq2xs_grid.length(); i += wgsize.x) { + iq2xs_grid[i] = iq2xs_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_XS +#define QUANT_R QUANT_R_IQ2_XS +#define A_TYPE block_iq2_xs +#define A_TYPE_PACKED16 block_iq2_xs_packed16 +#endif + +#define QUANT_K_IQ2_S 256 +#define QUANT_R_IQ2_S 1 + +struct block_iq2_s +{ + float16_t d; + uint8_t qs[QUANT_K_IQ2_S/4]; + uint8_t qh[QUANT_K_IQ2_S/32]; + uint8_t scales[QUANT_K_IQ2_S/32]; +}; + +#if defined(DATA_A_IQ2_S) + +const uvec2 iq2s_grid_const[1024] = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808), + uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808), + uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808), + uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808), + uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x192b192b, 0x08080808), + uvec2(0x192b2b19, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), + uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), uvec2(0x2b191908, 0x08080808), uvec2(0x2b2b0808, 0x08080808), + uvec2(0x2b2b1919, 0x08080808), uvec2(0x2b2b2b2b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), + uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), uvec2(0x0819082b, 0x08080819), + uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), + uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), + uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), uvec2(0x1919192b, 0x08080819), uvec2(0x19192b19, 0x08080819), + uvec2(0x192b0808, 0x08080819), uvec2(0x192b1919, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), + uvec2(0x2b081908, 0x08080819), uvec2(0x2b190808, 0x08080819), uvec2(0x2b19082b, 0x08080819), uvec2(0x2b191919, 0x08080819), + uvec2(0x2b2b0819, 0x08080819), uvec2(0x2b2b1908, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), + uvec2(0x08081919, 0x0808082b), uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), + uvec2(0x082b0808, 0x0808082b), uvec2(0x082b2b2b, 0x0808082b), uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), + uvec2(0x1908192b, 0x0808082b), uvec2(0x19082b19, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b), + uvec2(0x2b080808, 0x0808082b), uvec2(0x2b081919, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x2b191908, 0x0808082b), + uvec2(0x2b2b082b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x0808192b, 0x08081908), + uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), uvec2(0x08191919, 0x08081908), + uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), uvec2(0x082b192b, 0x08081908), + uvec2(0x082b2b19, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), + uvec2(0x19082b08, 0x08081908), uvec2(0x19082b2b, 0x08081908), uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), + uvec2(0x1919192b, 0x08081908), uvec2(0x19192b19, 0x08081908), uvec2(0x192b0808, 0x08081908), uvec2(0x192b082b, 0x08081908), + uvec2(0x192b1919, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b08192b, 0x08081908), + uvec2(0x2b082b19, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x2b191919, 0x08081908), uvec2(0x2b192b08, 0x08081908), + uvec2(0x2b2b0819, 0x08081908), uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), + uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08082b2b, 0x08081919), uvec2(0x08190819, 0x08081919), + uvec2(0x08191908, 0x08081919), uvec2(0x0819192b, 0x08081919), uvec2(0x08192b19, 0x08081919), uvec2(0x082b0808, 0x08081919), + uvec2(0x082b1919, 0x08081919), uvec2(0x082b2b08, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919), + uvec2(0x1908192b, 0x08081919), uvec2(0x19082b19, 0x08081919), uvec2(0x19190808, 0x08081919), uvec2(0x1919082b, 0x08081919), + uvec2(0x19191919, 0x08081919), uvec2(0x19192b08, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x192b1908, 0x08081919), + uvec2(0x2b080808, 0x08081919), uvec2(0x2b08082b, 0x08081919), uvec2(0x2b081919, 0x08081919), uvec2(0x2b082b08, 0x08081919), + uvec2(0x2b190819, 0x08081919), uvec2(0x2b191908, 0x08081919), uvec2(0x2b2b0808, 0x08081919), uvec2(0x08080819, 0x0808192b), + uvec2(0x08081908, 0x0808192b), uvec2(0x0808192b, 0x0808192b), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), + uvec2(0x08191919, 0x0808192b), uvec2(0x19080808, 0x0808192b), uvec2(0x19081919, 0x0808192b), uvec2(0x19082b08, 0x0808192b), + uvec2(0x19190819, 0x0808192b), uvec2(0x19191908, 0x0808192b), uvec2(0x192b0808, 0x0808192b), uvec2(0x2b080819, 0x0808192b), + uvec2(0x2b081908, 0x0808192b), uvec2(0x2b190808, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08), + uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08190819, 0x08082b08), uvec2(0x08191908, 0x08082b08), + uvec2(0x0819192b, 0x08082b08), uvec2(0x08192b19, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), + uvec2(0x082b2b2b, 0x08082b08), uvec2(0x19080819, 0x08082b08), uvec2(0x19081908, 0x08082b08), uvec2(0x1908192b, 0x08082b08), + uvec2(0x19082b19, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x19191919, 0x08082b08), + uvec2(0x19192b08, 0x08082b08), uvec2(0x192b0819, 0x08082b08), uvec2(0x192b1908, 0x08082b08), uvec2(0x2b080808, 0x08082b08), + uvec2(0x2b081919, 0x08082b08), uvec2(0x2b191908, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), + uvec2(0x08081908, 0x08082b19), uvec2(0x08190808, 0x08082b19), uvec2(0x0819082b, 0x08082b19), uvec2(0x08191919, 0x08082b19), + uvec2(0x08192b08, 0x08082b19), uvec2(0x082b0819, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x19081919, 0x08082b19), + uvec2(0x19082b08, 0x08082b19), uvec2(0x19190819, 0x08082b19), uvec2(0x19191908, 0x08082b19), uvec2(0x192b0808, 0x08082b19), + uvec2(0x2b080819, 0x08082b19), uvec2(0x2b190808, 0x08082b19), uvec2(0x08080808, 0x08082b2b), uvec2(0x08190819, 0x08082b2b), + uvec2(0x08191908, 0x08082b2b), uvec2(0x082b082b, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x082b2b2b, 0x08082b2b), + uvec2(0x19190808, 0x08082b2b), uvec2(0x2b192b19, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), + uvec2(0x0808192b, 0x08190808), uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), + uvec2(0x08191919, 0x08190808), uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), + uvec2(0x082b192b, 0x08190808), uvec2(0x19080808, 0x08190808), uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), + uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), uvec2(0x19191908, 0x08190808), uvec2(0x1919192b, 0x08190808), + uvec2(0x19192b19, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b082b, 0x08190808), uvec2(0x192b1919, 0x08190808), + uvec2(0x192b2b08, 0x08190808), uvec2(0x2b080819, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b08192b, 0x08190808), + uvec2(0x2b190808, 0x08190808), uvec2(0x2b191919, 0x08190808), uvec2(0x2b192b08, 0x08190808), uvec2(0x2b2b0819, 0x08190808), + uvec2(0x2b2b1908, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), uvec2(0x08081919, 0x08190819), + uvec2(0x08082b08, 0x08190819), uvec2(0x08082b2b, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819), + uvec2(0x0819192b, 0x08190819), uvec2(0x08192b19, 0x08190819), uvec2(0x082b0808, 0x08190819), uvec2(0x082b082b, 0x08190819), + uvec2(0x082b1919, 0x08190819), uvec2(0x082b2b08, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), + uvec2(0x1908192b, 0x08190819), uvec2(0x19082b19, 0x08190819), uvec2(0x19190808, 0x08190819), uvec2(0x1919082b, 0x08190819), + uvec2(0x19191919, 0x08190819), uvec2(0x19192b08, 0x08190819), uvec2(0x192b0819, 0x08190819), uvec2(0x192b1908, 0x08190819), + uvec2(0x2b080808, 0x08190819), uvec2(0x2b08082b, 0x08190819), uvec2(0x2b081919, 0x08190819), uvec2(0x2b082b08, 0x08190819), + uvec2(0x2b190819, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x08080819, 0x0819082b), uvec2(0x08081908, 0x0819082b), + uvec2(0x08082b19, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x08191919, 0x0819082b), uvec2(0x082b0819, 0x0819082b), + uvec2(0x082b1908, 0x0819082b), uvec2(0x19080808, 0x0819082b), uvec2(0x19081919, 0x0819082b), uvec2(0x19190819, 0x0819082b), + uvec2(0x19191908, 0x0819082b), uvec2(0x2b080819, 0x0819082b), uvec2(0x2b081908, 0x0819082b), uvec2(0x2b190808, 0x0819082b), + uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), uvec2(0x08082b08, 0x08191908), + uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x0819192b, 0x08191908), uvec2(0x08192b19, 0x08191908), + uvec2(0x082b0808, 0x08191908), uvec2(0x082b1919, 0x08191908), uvec2(0x082b2b08, 0x08191908), uvec2(0x19080819, 0x08191908), + uvec2(0x19081908, 0x08191908), uvec2(0x1908192b, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908), + uvec2(0x1919082b, 0x08191908), uvec2(0x19191919, 0x08191908), uvec2(0x19192b08, 0x08191908), uvec2(0x192b0819, 0x08191908), + uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x2b08082b, 0x08191908), uvec2(0x2b081919, 0x08191908), + uvec2(0x2b082b08, 0x08191908), uvec2(0x2b190819, 0x08191908), uvec2(0x2b191908, 0x08191908), uvec2(0x2b2b0808, 0x08191908), + uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), uvec2(0x0808192b, 0x08191919), uvec2(0x08082b19, 0x08191919), + uvec2(0x08190808, 0x08191919), uvec2(0x0819082b, 0x08191919), uvec2(0x08191919, 0x08191919), uvec2(0x08192b08, 0x08191919), + uvec2(0x082b0819, 0x08191919), uvec2(0x082b1908, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x1908082b, 0x08191919), + uvec2(0x19081919, 0x08191919), uvec2(0x19082b08, 0x08191919), uvec2(0x19190819, 0x08191919), uvec2(0x19191908, 0x08191919), + uvec2(0x192b0808, 0x08191919), uvec2(0x2b080819, 0x08191919), uvec2(0x2b081908, 0x08191919), uvec2(0x2b190808, 0x08191919), + uvec2(0x08080808, 0x0819192b), uvec2(0x08081919, 0x0819192b), uvec2(0x08082b08, 0x0819192b), uvec2(0x08190819, 0x0819192b), + uvec2(0x08191908, 0x0819192b), uvec2(0x082b0808, 0x0819192b), uvec2(0x19080819, 0x0819192b), uvec2(0x19081908, 0x0819192b), + uvec2(0x19190808, 0x0819192b), uvec2(0x2b080808, 0x0819192b), uvec2(0x2b2b2b2b, 0x0819192b), uvec2(0x08080819, 0x08192b08), + uvec2(0x08081908, 0x08192b08), uvec2(0x0808192b, 0x08192b08), uvec2(0x08082b19, 0x08192b08), uvec2(0x08190808, 0x08192b08), + uvec2(0x08191919, 0x08192b08), uvec2(0x08192b08, 0x08192b08), uvec2(0x082b0819, 0x08192b08), uvec2(0x19080808, 0x08192b08), + uvec2(0x1908082b, 0x08192b08), uvec2(0x19081919, 0x08192b08), uvec2(0x19082b08, 0x08192b08), uvec2(0x19190819, 0x08192b08), + uvec2(0x19191908, 0x08192b08), uvec2(0x192b0808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), uvec2(0x2b081908, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x0808082b, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x08082b08, 0x08192b19), + uvec2(0x08190819, 0x08192b19), uvec2(0x08191908, 0x08192b19), uvec2(0x082b0808, 0x08192b19), uvec2(0x19080819, 0x08192b19), + uvec2(0x19081908, 0x08192b19), uvec2(0x19190808, 0x08192b19), uvec2(0x192b2b19, 0x08192b19), uvec2(0x2b2b082b, 0x08192b19), + uvec2(0x08081908, 0x08192b2b), uvec2(0x08190808, 0x08192b2b), uvec2(0x19080808, 0x08192b2b), uvec2(0x1919192b, 0x08192b2b), + uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), uvec2(0x08082b08, 0x082b0808), + uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), uvec2(0x0819192b, 0x082b0808), uvec2(0x08192b19, 0x082b0808), + uvec2(0x082b0808, 0x082b0808), uvec2(0x082b1919, 0x082b0808), uvec2(0x082b2b2b, 0x082b0808), uvec2(0x19080819, 0x082b0808), + uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), uvec2(0x1919082b, 0x082b0808), uvec2(0x19191919, 0x082b0808), + uvec2(0x192b1908, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b082b2b, 0x082b0808), uvec2(0x2b191908, 0x082b0808), + uvec2(0x2b2b2b2b, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), uvec2(0x08190808, 0x082b0819), + uvec2(0x0819082b, 0x082b0819), uvec2(0x08191919, 0x082b0819), uvec2(0x082b0819, 0x082b0819), uvec2(0x19080808, 0x082b0819), + uvec2(0x1908082b, 0x082b0819), uvec2(0x19081919, 0x082b0819), uvec2(0x19190819, 0x082b0819), uvec2(0x19191908, 0x082b0819), + uvec2(0x192b0808, 0x082b0819), uvec2(0x2b080819, 0x082b0819), uvec2(0x2b081908, 0x082b0819), uvec2(0x2b190808, 0x082b0819), + uvec2(0x08080808, 0x082b082b), uvec2(0x08082b2b, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x082b2b08, 0x082b082b), + uvec2(0x082b2b2b, 0x082b082b), uvec2(0x19081908, 0x082b082b), uvec2(0x19190808, 0x082b082b), uvec2(0x2b082b08, 0x082b082b), + uvec2(0x2b082b2b, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), + uvec2(0x0808192b, 0x082b1908), uvec2(0x08082b19, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x08191919, 0x082b1908), + uvec2(0x08192b08, 0x082b1908), uvec2(0x082b0819, 0x082b1908), uvec2(0x082b1908, 0x082b1908), uvec2(0x19080808, 0x082b1908), + uvec2(0x1908082b, 0x082b1908), uvec2(0x19081919, 0x082b1908), uvec2(0x19082b08, 0x082b1908), uvec2(0x19190819, 0x082b1908), + uvec2(0x19191908, 0x082b1908), uvec2(0x192b0808, 0x082b1908), uvec2(0x2b080819, 0x082b1908), uvec2(0x2b081908, 0x082b1908), + uvec2(0x2b190808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x08081919, 0x082b1919), uvec2(0x08082b08, 0x082b1919), + uvec2(0x08190819, 0x082b1919), uvec2(0x08191908, 0x082b1919), uvec2(0x082b0808, 0x082b1919), uvec2(0x19080819, 0x082b1919), + uvec2(0x19081908, 0x082b1919), uvec2(0x19190808, 0x082b1919), uvec2(0x192b192b, 0x082b1919), uvec2(0x2b080808, 0x082b1919), + uvec2(0x08080819, 0x082b192b), uvec2(0x08081908, 0x082b192b), uvec2(0x08190808, 0x082b192b), uvec2(0x19080808, 0x082b192b), + uvec2(0x19192b19, 0x082b192b), uvec2(0x08080808, 0x082b2b08), uvec2(0x08081919, 0x082b2b08), uvec2(0x08190819, 0x082b2b08), + uvec2(0x08191908, 0x082b2b08), uvec2(0x19080819, 0x082b2b08), uvec2(0x19081908, 0x082b2b08), uvec2(0x19190808, 0x082b2b08), + uvec2(0x2b082b2b, 0x082b2b08), uvec2(0x2b2b2b2b, 0x082b2b08), uvec2(0x08080819, 0x082b2b19), uvec2(0x08081908, 0x082b2b19), + uvec2(0x08190808, 0x082b2b19), uvec2(0x2b191919, 0x082b2b19), uvec2(0x08082b2b, 0x082b2b2b), uvec2(0x082b082b, 0x082b2b2b), + uvec2(0x192b1908, 0x082b2b2b), uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), + uvec2(0x08081908, 0x19080808), uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), + uvec2(0x0819082b, 0x19080808), uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x08192b2b, 0x19080808), + uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x082b192b, 0x19080808), uvec2(0x19080808, 0x19080808), + uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), uvec2(0x19082b2b, 0x19080808), + uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x1919192b, 0x19080808), uvec2(0x19192b19, 0x19080808), + uvec2(0x192b0808, 0x19080808), uvec2(0x192b082b, 0x19080808), uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), + uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), uvec2(0x2b191919, 0x19080808), uvec2(0x2b192b08, 0x19080808), + uvec2(0x2b2b0819, 0x19080808), uvec2(0x2b2b1908, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), + uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), + uvec2(0x0819192b, 0x19080819), uvec2(0x08192b19, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x082b082b, 0x19080819), + uvec2(0x082b1919, 0x19080819), uvec2(0x19080819, 0x19080819), uvec2(0x19081908, 0x19080819), uvec2(0x1908192b, 0x19080819), + uvec2(0x19082b19, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x1919082b, 0x19080819), uvec2(0x19191919, 0x19080819), + uvec2(0x19192b08, 0x19080819), uvec2(0x192b0819, 0x19080819), uvec2(0x192b1908, 0x19080819), uvec2(0x2b080808, 0x19080819), + uvec2(0x2b08082b, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x2b082b08, 0x19080819), uvec2(0x2b190819, 0x19080819), + uvec2(0x2b191908, 0x19080819), uvec2(0x2b2b0808, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), + uvec2(0x08190808, 0x1908082b), uvec2(0x0819082b, 0x1908082b), uvec2(0x08191919, 0x1908082b), uvec2(0x08192b08, 0x1908082b), + uvec2(0x082b1908, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x19081919, 0x1908082b), uvec2(0x19082b08, 0x1908082b), + uvec2(0x19190819, 0x1908082b), uvec2(0x19191908, 0x1908082b), uvec2(0x192b0808, 0x1908082b), uvec2(0x2b080819, 0x1908082b), + uvec2(0x2b081908, 0x1908082b), uvec2(0x08080808, 0x19081908), uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), + uvec2(0x08082b08, 0x19081908), uvec2(0x08082b2b, 0x19081908), uvec2(0x08190819, 0x19081908), uvec2(0x08191908, 0x19081908), + uvec2(0x0819192b, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x082b082b, 0x19081908), + uvec2(0x082b1919, 0x19081908), uvec2(0x082b2b08, 0x19081908), uvec2(0x19080819, 0x19081908), uvec2(0x19081908, 0x19081908), + uvec2(0x1908192b, 0x19081908), uvec2(0x19082b19, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x1919082b, 0x19081908), + uvec2(0x19191919, 0x19081908), uvec2(0x19192b08, 0x19081908), uvec2(0x192b0819, 0x19081908), uvec2(0x192b1908, 0x19081908), + uvec2(0x2b080808, 0x19081908), uvec2(0x2b08082b, 0x19081908), uvec2(0x2b081919, 0x19081908), uvec2(0x2b082b08, 0x19081908), + uvec2(0x2b190819, 0x19081908), uvec2(0x2b191908, 0x19081908), uvec2(0x2b2b0808, 0x19081908), uvec2(0x08080819, 0x19081919), + uvec2(0x08081908, 0x19081919), uvec2(0x0808192b, 0x19081919), uvec2(0x08082b19, 0x19081919), uvec2(0x08190808, 0x19081919), + uvec2(0x0819082b, 0x19081919), uvec2(0x08191919, 0x19081919), uvec2(0x08192b08, 0x19081919), uvec2(0x082b0819, 0x19081919), + uvec2(0x082b1908, 0x19081919), uvec2(0x19080808, 0x19081919), uvec2(0x1908082b, 0x19081919), uvec2(0x19081919, 0x19081919), + uvec2(0x19082b08, 0x19081919), uvec2(0x19190819, 0x19081919), uvec2(0x19191908, 0x19081919), uvec2(0x192b0808, 0x19081919), + uvec2(0x192b2b2b, 0x19081919), uvec2(0x2b080819, 0x19081919), uvec2(0x2b081908, 0x19081919), uvec2(0x2b190808, 0x19081919), + uvec2(0x08080808, 0x1908192b), uvec2(0x0808082b, 0x1908192b), uvec2(0x08081919, 0x1908192b), uvec2(0x08082b08, 0x1908192b), + uvec2(0x08190819, 0x1908192b), uvec2(0x08191908, 0x1908192b), uvec2(0x082b0808, 0x1908192b), uvec2(0x19080819, 0x1908192b), + uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x2b080808, 0x1908192b), uvec2(0x2b2b1919, 0x1908192b), + uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), uvec2(0x08082b19, 0x19082b08), uvec2(0x08190808, 0x19082b08), + uvec2(0x0819082b, 0x19082b08), uvec2(0x08191919, 0x19082b08), uvec2(0x08192b08, 0x19082b08), uvec2(0x082b0819, 0x19082b08), + uvec2(0x082b1908, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x1908082b, 0x19082b08), uvec2(0x19081919, 0x19082b08), + uvec2(0x19082b08, 0x19082b08), uvec2(0x19190819, 0x19082b08), uvec2(0x19191908, 0x19082b08), uvec2(0x192b0808, 0x19082b08), + uvec2(0x2b081908, 0x19082b08), uvec2(0x2b190808, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x0808082b, 0x19082b19), + uvec2(0x08081919, 0x19082b19), uvec2(0x08082b08, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x08191908, 0x19082b19), + uvec2(0x082b0808, 0x19082b19), uvec2(0x19080819, 0x19082b19), uvec2(0x19081908, 0x19082b19), uvec2(0x19190808, 0x19082b19), + uvec2(0x2b080808, 0x19082b19), uvec2(0x2b19192b, 0x19082b19), uvec2(0x08080819, 0x19082b2b), uvec2(0x08081908, 0x19082b2b), + uvec2(0x08190808, 0x19082b2b), uvec2(0x19080808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x0808082b, 0x19190808), + uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), uvec2(0x08191908, 0x19190808), + uvec2(0x0819192b, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b082b, 0x19190808), + uvec2(0x082b1919, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), uvec2(0x19081908, 0x19190808), + uvec2(0x1908192b, 0x19190808), uvec2(0x19082b19, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x1919082b, 0x19190808), + uvec2(0x19191919, 0x19190808), uvec2(0x19192b08, 0x19190808), uvec2(0x192b0819, 0x19190808), uvec2(0x192b1908, 0x19190808), + uvec2(0x2b080808, 0x19190808), uvec2(0x2b08082b, 0x19190808), uvec2(0x2b081919, 0x19190808), uvec2(0x2b082b08, 0x19190808), + uvec2(0x2b190819, 0x19190808), uvec2(0x2b191908, 0x19190808), uvec2(0x08080819, 0x19190819), uvec2(0x08081908, 0x19190819), + uvec2(0x0808192b, 0x19190819), uvec2(0x08082b19, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x0819082b, 0x19190819), + uvec2(0x08191919, 0x19190819), uvec2(0x08192b08, 0x19190819), uvec2(0x082b0819, 0x19190819), uvec2(0x082b1908, 0x19190819), + uvec2(0x19080808, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x19081919, 0x19190819), uvec2(0x19082b08, 0x19190819), + uvec2(0x19190819, 0x19190819), uvec2(0x19191908, 0x19190819), uvec2(0x192b0808, 0x19190819), uvec2(0x2b080819, 0x19190819), + uvec2(0x2b081908, 0x19190819), uvec2(0x2b190808, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x08081919, 0x1919082b), + uvec2(0x08082b08, 0x1919082b), uvec2(0x08190819, 0x1919082b), uvec2(0x08191908, 0x1919082b), uvec2(0x082b0808, 0x1919082b), + uvec2(0x19080819, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x19190808, 0x1919082b), uvec2(0x192b2b19, 0x1919082b), + uvec2(0x2b080808, 0x1919082b), uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x0808192b, 0x19191908), + uvec2(0x08082b19, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x0819082b, 0x19191908), uvec2(0x08191919, 0x19191908), + uvec2(0x08192b08, 0x19191908), uvec2(0x082b0819, 0x19191908), uvec2(0x082b1908, 0x19191908), uvec2(0x19080808, 0x19191908), + uvec2(0x1908082b, 0x19191908), uvec2(0x19081919, 0x19191908), uvec2(0x19082b08, 0x19191908), uvec2(0x19190819, 0x19191908), + uvec2(0x19191908, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b081908, 0x19191908), + uvec2(0x2b190808, 0x19191908), uvec2(0x08080808, 0x19191919), uvec2(0x0808082b, 0x19191919), uvec2(0x08081919, 0x19191919), + uvec2(0x08082b08, 0x19191919), uvec2(0x08190819, 0x19191919), uvec2(0x08191908, 0x19191919), uvec2(0x082b0808, 0x19191919), + uvec2(0x19080819, 0x19191919), uvec2(0x19081908, 0x19191919), uvec2(0x19190808, 0x19191919), uvec2(0x2b080808, 0x19191919), + uvec2(0x08080819, 0x1919192b), uvec2(0x08081908, 0x1919192b), uvec2(0x08190808, 0x1919192b), uvec2(0x082b192b, 0x1919192b), + uvec2(0x19080808, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x0808082b, 0x19192b08), uvec2(0x08081919, 0x19192b08), + uvec2(0x08082b08, 0x19192b08), uvec2(0x08190819, 0x19192b08), uvec2(0x08191908, 0x19192b08), uvec2(0x082b0808, 0x19192b08), + uvec2(0x19080819, 0x19192b08), uvec2(0x19081908, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x19192b2b, 0x19192b08), + uvec2(0x2b080808, 0x19192b08), uvec2(0x08080819, 0x19192b19), uvec2(0x08081908, 0x19192b19), uvec2(0x08190808, 0x19192b19), + uvec2(0x19080808, 0x19192b19), uvec2(0x08080808, 0x19192b2b), uvec2(0x08192b19, 0x19192b2b), uvec2(0x2b081919, 0x19192b2b), + uvec2(0x2b2b2b08, 0x19192b2b), uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x0808192b, 0x192b0808), + uvec2(0x08190808, 0x192b0808), uvec2(0x0819082b, 0x192b0808), uvec2(0x08191919, 0x192b0808), uvec2(0x08192b08, 0x192b0808), + uvec2(0x082b0819, 0x192b0808), uvec2(0x082b1908, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x19081919, 0x192b0808), + uvec2(0x19082b08, 0x192b0808), uvec2(0x19190819, 0x192b0808), uvec2(0x19191908, 0x192b0808), uvec2(0x192b0808, 0x192b0808), + uvec2(0x2b081908, 0x192b0808), uvec2(0x2b190808, 0x192b0808), uvec2(0x08080808, 0x192b0819), uvec2(0x0808082b, 0x192b0819), + uvec2(0x08081919, 0x192b0819), uvec2(0x08082b08, 0x192b0819), uvec2(0x08190819, 0x192b0819), uvec2(0x08191908, 0x192b0819), + uvec2(0x082b0808, 0x192b0819), uvec2(0x19080819, 0x192b0819), uvec2(0x19081908, 0x192b0819), uvec2(0x19190808, 0x192b0819), + uvec2(0x2b080808, 0x192b0819), uvec2(0x2b192b19, 0x192b0819), uvec2(0x08081908, 0x192b082b), uvec2(0x08190808, 0x192b082b), + uvec2(0x19080808, 0x192b082b), uvec2(0x1919192b, 0x192b082b), uvec2(0x2b2b0819, 0x192b082b), uvec2(0x08080808, 0x192b1908), + uvec2(0x08081919, 0x192b1908), uvec2(0x08082b08, 0x192b1908), uvec2(0x08190819, 0x192b1908), uvec2(0x08191908, 0x192b1908), + uvec2(0x082b0808, 0x192b1908), uvec2(0x19080819, 0x192b1908), uvec2(0x19081908, 0x192b1908), uvec2(0x19190808, 0x192b1908), + uvec2(0x2b080808, 0x192b1908), uvec2(0x08080819, 0x192b1919), uvec2(0x08081908, 0x192b1919), uvec2(0x08190808, 0x192b1919), + uvec2(0x19080808, 0x192b1919), uvec2(0x19082b2b, 0x192b1919), uvec2(0x192b2b08, 0x192b1919), uvec2(0x2b19082b, 0x192b1919), + uvec2(0x08080808, 0x192b192b), uvec2(0x2b191908, 0x192b192b), uvec2(0x08080819, 0x192b2b08), uvec2(0x08081908, 0x192b2b08), + uvec2(0x08190808, 0x192b2b08), uvec2(0x192b1919, 0x192b2b08), uvec2(0x2b192b08, 0x192b2b08), uvec2(0x08080808, 0x192b2b19), + uvec2(0x082b2b2b, 0x192b2b19), uvec2(0x1908082b, 0x192b2b2b), uvec2(0x2b2b0819, 0x192b2b2b), uvec2(0x08080808, 0x2b080808), + uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), uvec2(0x08190819, 0x2b080808), + uvec2(0x08191908, 0x2b080808), uvec2(0x08192b19, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b1919, 0x2b080808), + uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x1919082b, 0x2b080808), + uvec2(0x19191919, 0x2b080808), uvec2(0x19192b08, 0x2b080808), uvec2(0x192b0819, 0x2b080808), uvec2(0x2b080808, 0x2b080808), + uvec2(0x2b081919, 0x2b080808), uvec2(0x2b190819, 0x2b080808), uvec2(0x2b191908, 0x2b080808), uvec2(0x08080819, 0x2b080819), + uvec2(0x08081908, 0x2b080819), uvec2(0x08082b19, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x0819082b, 0x2b080819), + uvec2(0x08191919, 0x2b080819), uvec2(0x08192b08, 0x2b080819), uvec2(0x082b0819, 0x2b080819), uvec2(0x082b1908, 0x2b080819), + uvec2(0x19080808, 0x2b080819), uvec2(0x1908082b, 0x2b080819), uvec2(0x19081919, 0x2b080819), uvec2(0x19082b08, 0x2b080819), + uvec2(0x19190819, 0x2b080819), uvec2(0x19191908, 0x2b080819), uvec2(0x2b080819, 0x2b080819), uvec2(0x2b081908, 0x2b080819), + uvec2(0x2b190808, 0x2b080819), uvec2(0x2b2b2b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x08081919, 0x2b08082b), + uvec2(0x08082b2b, 0x2b08082b), uvec2(0x08190819, 0x2b08082b), uvec2(0x08191908, 0x2b08082b), uvec2(0x19080819, 0x2b08082b), + uvec2(0x19081908, 0x2b08082b), uvec2(0x19190808, 0x2b08082b), uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), + uvec2(0x0808192b, 0x2b081908), uvec2(0x08082b19, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908), + uvec2(0x08191919, 0x2b081908), uvec2(0x08192b08, 0x2b081908), uvec2(0x082b0819, 0x2b081908), uvec2(0x19080808, 0x2b081908), + uvec2(0x1908082b, 0x2b081908), uvec2(0x19081919, 0x2b081908), uvec2(0x19082b08, 0x2b081908), uvec2(0x19190819, 0x2b081908), + uvec2(0x19191908, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b080819, 0x2b081908), uvec2(0x2b081908, 0x2b081908), + uvec2(0x2b190808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x0808082b, 0x2b081919), uvec2(0x08081919, 0x2b081919), + uvec2(0x08082b08, 0x2b081919), uvec2(0x08190819, 0x2b081919), uvec2(0x08191908, 0x2b081919), uvec2(0x082b0808, 0x2b081919), + uvec2(0x19080819, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x19190808, 0x2b081919), uvec2(0x2b080808, 0x2b081919), + uvec2(0x2b082b2b, 0x2b081919), uvec2(0x08080819, 0x2b08192b), uvec2(0x08081908, 0x2b08192b), uvec2(0x08190808, 0x2b08192b), + uvec2(0x082b2b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08081919, 0x2b082b08), + uvec2(0x08190819, 0x2b082b08), uvec2(0x08191908, 0x2b082b08), uvec2(0x19080819, 0x2b082b08), uvec2(0x19081908, 0x2b082b08), + uvec2(0x19190808, 0x2b082b08), uvec2(0x2b2b082b, 0x2b082b08), uvec2(0x08080819, 0x2b082b19), uvec2(0x08081908, 0x2b082b19), + uvec2(0x19080808, 0x2b082b19), uvec2(0x192b1919, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x19192b08, 0x2b082b2b), + uvec2(0x19192b2b, 0x2b082b2b), uvec2(0x2b08082b, 0x2b082b2b), uvec2(0x2b2b082b, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), + uvec2(0x08081908, 0x2b190808), uvec2(0x08082b19, 0x2b190808), uvec2(0x08190808, 0x2b190808), uvec2(0x0819082b, 0x2b190808), + uvec2(0x08191919, 0x2b190808), uvec2(0x08192b08, 0x2b190808), uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), + uvec2(0x1908082b, 0x2b190808), uvec2(0x19081919, 0x2b190808), uvec2(0x19082b08, 0x2b190808), uvec2(0x19190819, 0x2b190808), + uvec2(0x19191908, 0x2b190808), uvec2(0x192b0808, 0x2b190808), uvec2(0x2b080819, 0x2b190808), uvec2(0x2b081908, 0x2b190808), + uvec2(0x2b190808, 0x2b190808), uvec2(0x08080808, 0x2b190819), uvec2(0x08081919, 0x2b190819), uvec2(0x08190819, 0x2b190819), + uvec2(0x08191908, 0x2b190819), uvec2(0x19080819, 0x2b190819), uvec2(0x19081908, 0x2b190819), uvec2(0x19190808, 0x2b190819), + uvec2(0x19192b2b, 0x2b190819), uvec2(0x08080819, 0x2b19082b), uvec2(0x08081908, 0x2b19082b), uvec2(0x08190808, 0x2b19082b), + uvec2(0x19080808, 0x2b19082b), uvec2(0x2b2b192b, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x0808082b, 0x2b191908), + uvec2(0x08081919, 0x2b191908), uvec2(0x08082b08, 0x2b191908), uvec2(0x08190819, 0x2b191908), uvec2(0x08191908, 0x2b191908), + uvec2(0x082b0808, 0x2b191908), uvec2(0x19080819, 0x2b191908), uvec2(0x19081908, 0x2b191908), uvec2(0x19190808, 0x2b191908), + uvec2(0x2b080808, 0x2b191908), uvec2(0x2b19192b, 0x2b191908), uvec2(0x08080819, 0x2b191919), uvec2(0x08081908, 0x2b191919), + uvec2(0x08190808, 0x2b191919), uvec2(0x19080808, 0x2b191919), uvec2(0x2b192b08, 0x2b191919), uvec2(0x2b2b0819, 0x2b191919), + uvec2(0x08080808, 0x2b19192b), uvec2(0x1908192b, 0x2b19192b), uvec2(0x192b1908, 0x2b19192b), uvec2(0x08080819, 0x2b192b08), + uvec2(0x08081908, 0x2b192b08), uvec2(0x08190808, 0x2b192b08), uvec2(0x082b192b, 0x2b192b08), uvec2(0x19080808, 0x2b192b08), + uvec2(0x2b2b2b19, 0x2b192b08), uvec2(0x08080808, 0x2b192b19), uvec2(0x19082b19, 0x2b192b19), uvec2(0x1919082b, 0x2b192b19), + uvec2(0x2b190808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), uvec2(0x08081919, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), + uvec2(0x08191908, 0x2b2b0808), uvec2(0x082b082b, 0x2b2b0808), uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x19080819, 0x2b2b0808), + uvec2(0x19081908, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b2b082b, 0x2b2b0808), uvec2(0x2b2b2b2b, 0x2b2b0808), + uvec2(0x19080808, 0x2b2b0819), uvec2(0x192b1919, 0x2b2b0819), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b2b, 0x2b2b082b), + uvec2(0x082b082b, 0x2b2b082b), uvec2(0x082b2b08, 0x2b2b082b), uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b08082b, 0x2b2b082b), + uvec2(0x2b082b08, 0x2b2b082b), uvec2(0x2b082b2b, 0x2b2b082b), uvec2(0x2b2b2b08, 0x2b2b082b), uvec2(0x08080819, 0x2b2b1908), + uvec2(0x08081908, 0x2b2b1908), uvec2(0x08190808, 0x2b2b1908), uvec2(0x19080808, 0x2b2b1908), uvec2(0x2b082b19, 0x2b2b1908), + uvec2(0x2b2b1908, 0x2b2b1908), uvec2(0x08080808, 0x2b2b1919), uvec2(0x08192b19, 0x2b2b1919), uvec2(0x19190819, 0x2b2b192b), + uvec2(0x08082b2b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b082b, 0x2b2b2b08), uvec2(0x19191908, 0x2b2b2b19), + uvec2(0x2b08192b, 0x2b2b2b19), uvec2(0x08082b08, 0x2b2b2b2b), uvec2(0x08082b2b, 0x2b2b2b2b), uvec2(0x082b0808, 0x2b2b2b2b), + uvec2(0x082b082b, 0x2b2b2b2b), uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x2b082b08, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b) +}; + +shared uvec2 iq2s_grid[1024]; + +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq2s_grid.length(); i += wgsize.x) { + iq2s_grid[i] = iq2s_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_S +#define QUANT_R QUANT_R_IQ2_S +#define A_TYPE block_iq2_s +#endif + +#define QUANT_K_IQ3_XXS 256 +#define QUANT_R_IQ3_XXS 1 + +struct block_iq3_xxs +{ + float16_t d; + uint8_t qs[QUANT_K_IQ3_XXS/4 + QUANT_K_IQ3_XXS/8]; +}; + +struct block_iq3_xxs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ3_XXS/8 + QUANT_K_IQ3_XXS/16]; +}; + +#if defined(DATA_A_IQ3_XXS) + +const uint32_t iq3xxs_grid_const[256] = { + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, +}; + +shared uint32_t iq3xxs_grid[256]; + +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq3xxs_grid.length(); i += wgsize.x) { + iq3xxs_grid[i] = iq3xxs_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ3_XXS +#define QUANT_R QUANT_R_IQ3_XXS +#define A_TYPE block_iq3_xxs +#define A_TYPE_PACKED16 block_iq3_xxs_packed16 +#endif + +#define QUANT_K_IQ3_S 256 +#define QUANT_R_IQ3_S 1 + +struct block_iq3_s +{ + float16_t d; + uint8_t qs[QUANT_K_IQ3_S/4]; + uint8_t qh[QUANT_K_IQ3_S/32]; + uint8_t signs[QUANT_K_IQ3_S/8]; + uint8_t scales[QUANT_K_IQ3_S/64]; +}; + +struct block_iq3_s_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ3_S/4/2]; + uint16_t qh[QUANT_K_IQ3_S/32/2]; + uint16_t signs[QUANT_K_IQ3_S/8/2]; + uint16_t scales[QUANT_K_IQ3_S/64/2]; +}; + +#if defined(DATA_A_IQ3_S) + +const uint32_t iq3s_grid_const[512] = { + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, +}; + +shared uint32_t iq3s_grid[512]; + +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq3s_grid.length(); i += wgsize.x) { + iq3s_grid[i] = iq3s_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ3_S +#define QUANT_R QUANT_R_IQ3_S +#define A_TYPE block_iq3_s +#define A_TYPE_PACKED16 block_iq3_s_packed16 +#endif + #define QUANT_K_IQ4_NL 32 #define QUANT_R_IQ4_NL 2 @@ -318,11 +1050,11 @@ const int8_t kvalues_iq4nl_const[16] = { shared FLOAT_TYPE kvalues_iq4nl[16]; -void init_iq4nl_shmem() +void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - if (gl_LocalInvocationIndex.x < 16) { - kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]); + for (uint i = gl_LocalInvocationIndex.x; i < kvalues_iq4nl.length(); i += wgsize.x) { + kvalues_iq4nl[i] = FLOAT_TYPE(kvalues_iq4nl_const[i]); } barrier(); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index e9c6cb9d45d..93ddbfadc5f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -55,6 +55,11 @@ const std::vector type_names = { "q4_k", "q5_k", "q6_k", + "iq2_xxs", + "iq2_xs", + "iq2_s", + "iq3_xxs", + "iq3_s", "iq4_nl" }; From f41fdad200a8e82c20b34fbbef45aac67e88ec5a Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 29 Jan 2025 17:46:23 +0100 Subject: [PATCH 40/64] CUDA/HIP: add warp_size to cuda_device_info --- ggml/src/ggml-cuda/common.cuh | 1 + ggml/src/ggml-cuda/ggml-cuda.cu | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a66322da05a..eec227dce3a 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -520,6 +520,7 @@ struct ggml_cuda_device_info { bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory size_t total_vram; + int warp_size; // Number of threads in a dispatch }; cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {}; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index de3f9c2ca1e..ecf06fec408 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -242,6 +242,7 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].nsm = prop.multiProcessorCount; info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].warp_size = prop.warpSize; #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpbo = prop.sharedMemPerBlock; @@ -256,8 +257,9 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc += prop.minor * 0x10; } } - GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s\n", - id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, device_vmm ? "yes" : "no"); + GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n", + id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, + device_vmm ? "yes" : "no", prop.warpSize); #else info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; From fc2e44490d48efdfd2f65cf5e53ff12866e88ef4 Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 29 Jan 2025 19:12:42 +0100 Subject: [PATCH 41/64] HIP: Prepare reduction operators for wave 64 --- ggml/src/ggml-cuda/common.cuh | 59 +++++++++++++++------------------ ggml/src/ggml-cuda/ggml-cuda.cu | 4 +-- 2 files changed, 28 insertions(+), 35 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index eec227dce3a..8d8d3932e0e 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -190,53 +190,46 @@ static __device__ void no_device_code( #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.") #endif // __CUDA_ARCH__ +template static __device__ __forceinline__ int warp_reduce_sum(int x) { #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE return __reduce_add_sync(0xffffffff, x); #else #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - x += __shfl_xor_sync(0xffffffff, x, offset, 32); + for (int offset = width/2; offset > 0; offset >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, offset, width); } return x; #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE } +template static __device__ __forceinline__ float warp_reduce_sum(float x) { #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - x += __shfl_xor_sync(0xffffffff, x, offset, 32); + for (int offset = width/2; offset > 0; offset >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, offset, width); } return x; } +template static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32); - a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32); + for (int offset = width/2; offset > 0; offset >>= 1) { + a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width); + a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width); } return a; } +template static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #ifdef FP16_AVAILABLE - -#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) -#pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32); - reinterpret_cast(a.x) += __low2half(a_other); - reinterpret_cast(a.y) += __high2half(a_other); - } - return a; -#else #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32)); + for (int offset = width/2; offset > 0; offset >>= 1) { + a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width)); } return a; -#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #else NO_DEVICE_CODE; @@ -244,10 +237,11 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // FP16_AVAILABLE } +template static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32)); + for (int offset = width/2; offset > 0; offset >>= 1) { + x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width)); } return x; } @@ -269,35 +263,34 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b } static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) - -#if CUDART_VERSION >= CUDART_HMAX +#if defined(GGML_USE_HIP) && HIP_VERSION >= 50700000 + return half2(__hmax(a.x, b.x), __hmax(a.y, b.y)); +#elif !defined(GGML_USE_HIP) && CUDART_VERSION >= CUDART_HMAX return __hmax2(a, b); -#else +#elif !defined(GGML_USE_HIP) half2 ret; reinterpret_cast(ret.x) = __float2half(fmaxf( __low2float(a), __low2float(b))); reinterpret_cast(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b))); return ret; -#endif // CUDART_VERSION >= CUDART_HMAX - #else GGML_UNUSED(a); GGML_UNUSED(b); NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#endif } +template static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { - x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32)); + for (int offset = width/2; offset > 0; offset >>= 1) { + x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width)); } return x; #else GGML_UNUSED(x); NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) } #if CUDART_VERSION < CUDART_HMASK diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ecf06fec408..383131c7789 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -240,8 +240,8 @@ static ggml_cuda_device_info ggml_cuda_init() { info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; - info.devices[id].nsm = prop.multiProcessorCount; - info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].nsm = prop.multiProcessorCount; + info.devices[id].smpb = prop.sharedMemPerBlock; info.devices[id].warp_size = prop.warpSize; #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpbo = prop.sharedMemPerBlock; From 43c744ce8ba3bb169c7db74f1d87189dce38ae64 Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 29 Jan 2025 19:36:00 +0100 Subject: [PATCH 42/64] HIP: require at least HIP 5.5 --- ggml/src/ggml-hip/CMakeLists.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index ecc3bc66d44..7a877bdc11a 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -40,6 +40,10 @@ find_package(hip REQUIRED) find_package(hipblas REQUIRED) find_package(rocblas REQUIRED) +if (${hip_VERSION} VERSION_LESS 5.5) + message(FATAL_ERROR "At least ROCM/HIP V5.5 is required") +endif() + message(STATUS "HIP and hipBLAS found") file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") From 85451e3612c976f4fa52c3e84816c3118a2a4928 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Fri, 31 Jan 2025 17:12:40 +0000 Subject: [PATCH 43/64] `ci`: use sccache on windows instead of ccache (llama/11545) * Use sccache on ci for windows * Detect sccache in cmake --- ggml/src/CMakeLists.txt | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 56670913521..0002ac18a33 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -93,12 +93,18 @@ endif() if (GGML_CCACHE) find_program(GGML_CCACHE_FOUND ccache) + find_program(GGML_SCCACHE_FOUND sccache) - if (GGML_CCACHE_FOUND) + if (GGML_CCACHE_FOUND OR GGML_SCCACHE_FOUND) + if(GGML_CCACHE_FOUND) + set(GGML_CCACHE_VARIANT ccache) + else() + set(GGML_CCACHE_VARIANT sccache) + endif() # TODO: should not be set globally - set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache) + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}") set(ENV{CCACHE_SLOPPINESS} time_macros) - message(STATUS "ccache found, compilation results will be cached. Disable with GGML_CCACHE=OFF.") + message(STATUS "${GGML_CCACHE_VARIANT} found, compilation results will be cached. Disable with GGML_CCACHE=OFF.") else() message(STATUS "Warning: ccache not found - consider installing it for faster compilation or disable this warning with GGML_CCACHE=OFF") endif () From f8a831779e8f12253056b1f50f988f67f62f3b6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 2 Feb 2025 19:31:09 +0100 Subject: [PATCH 44/64] CUDA: use mma PTX instructions for FlashAttention (llama/11583) * CUDA: use mma PTX instructions for FlashAttention * __shfl_sync workaround for movmatrix * add __shfl_sync to HIP Co-authored-by: Diego Devesa --- ggml/include/ggml.h | 2 +- ggml/src/ggml-cuda/CMakeLists.txt | 2 +- ggml/src/ggml-cuda/common.cuh | 6 +- ggml/src/ggml-cuda/fattn-common.cuh | 179 ++++- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 637 +++++++++++++++++ ggml/src/ggml-cuda/fattn-tile-f16.cu | 24 +- ggml/src/ggml-cuda/fattn-tile-f32.cu | 19 +- ggml/src/ggml-cuda/fattn-vec-f16.cuh | 9 +- ggml/src/ggml-cuda/fattn-vec-f32.cuh | 8 +- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 648 ++++++++++++++++++ ggml/src/ggml-cuda/fattn-wmma-f16.cuh | 542 +-------------- ggml/src/ggml-cuda/fattn.cu | 174 ++--- ggml/src/ggml-cuda/mma.cuh | 335 +++++++-- ggml/src/ggml-cuda/mmq.cu | 2 +- ggml/src/ggml-cuda/mmq.cuh | 349 +++++----- .../fattn-mma-f16-instance-cpb16.cu | 10 + .../fattn-mma-f16-instance-cpb32.cu | 10 + .../fattn-mma-f16-instance-cpb64.cu | 10 + .../fattn-mma-f16-instance-cpb8.cu | 10 + .../fattn-wmma-f16-instance-kqfloat-cpb16.cu | 10 - .../fattn-wmma-f16-instance-kqfloat-cpb32.cu | 9 - .../fattn-wmma-f16-instance-kqhalf-cpb16.cu | 10 - .../fattn-wmma-f16-instance-kqhalf-cpb32.cu | 10 - .../fattn-wmma-f16-instance-kqhalf-cpb8.cu | 8 - .../template-instances/generate_cu_files.py | 24 +- ggml/src/ggml-cuda/vendors/hip.h | 1 + ggml/src/ggml-hip/CMakeLists.txt | 2 +- ggml/src/ggml-musa/CMakeLists.txt | 2 +- 28 files changed, 2056 insertions(+), 996 deletions(-) create mode 100644 ggml/src/ggml-cuda/fattn-mma-f16.cuh create mode 100644 ggml/src/ggml-cuda/fattn-wmma-f16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu delete mode 100644 ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1198dc1fd93..5bd8d9c8b50 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1775,7 +1775,7 @@ extern "C" { struct ggml_tensor * a, int k); -#define GGML_KQ_MASK_PAD 32 +#define GGML_KQ_MASK_PAD 64 // q: [n_embd, n_batch, n_head, 1] // k: [n_embd, n_kv, n_head_kv, 1] diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 14761650f8a..119fd39b8e4 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -28,7 +28,7 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_CUDA "*.cu") - file(GLOB SRCS "template-instances/fattn-wmma*.cu") + file(GLOB SRCS "template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmq*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 8d8d3932e0e..88be8fc8a6e 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -148,7 +148,7 @@ typedef float2 dfloat2; #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING -#define INT8_MMA_AVAILABLE +#define NEW_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) @@ -159,11 +159,13 @@ static constexpr bool fast_fp16_available(const int cc) { return cc >= GGML_CUDA_CC_PASCAL && cc != 610; } +// Any FP16 tensor cores are available. static constexpr bool fp16_mma_available(const int cc) { return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA; } -static constexpr bool int8_mma_available(const int cc) { +// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. +static constexpr bool new_mma_available(const int cc) { return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING; } diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index ee9752da6a8..cfd7c0f4475 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -516,6 +516,104 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } +template // D == head size +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_stream_k_fixup( + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { + const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); + + const int iter_k = ne11 / KQ_stride; + const int iter_j = (ne01 + (ncols - 1)) / ncols; + + const int bidx0 = blockIdx.x; + + const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x; + const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; + const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + const int channel = kbc0 / (iter_k*iter_j); + const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; + + dst += jt*ncols*ne02*D + channel*D; + + // Load the partial result that needs a fixup: + float dst_val[ncols] = {0.0f}; + float max_val[ncols] = {0.0f}; + float rowsum[ncols] = {0.0f}; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (jt*ncols + j >= ne01) { + break; + } + dst_val[j] = dst[j*ne02*D + threadIdx.x]; + + const float2 tmp = dst_fixup[bidx0*ncols + j]; + max_val[j] = tmp.x; + rowsum[j] = tmp.y; + } + + // Iterate over previous blocks and compute the combined results. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int bidx = bidx0 - 1; + int kbc_stop = kbc0; + while(true) { + const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x; + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (jt*ncols + j >= ne01) { + break; + } + const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x]; + + const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j]; + + // Scale the current and new value accumulators depending on the max. values. + const float max_val_new = fmaxf(max_val[j], tmp.x); + + const float diff_val = max_val[j] - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add; + rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y; + + max_val[j] = max_val_new; + } + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + break; + } + bidx--; + kbc_stop = kbc; + } + + // Write back final result: +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (jt*ncols + j >= ne01) { + return; + } + dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j]; + } +} + template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) @@ -581,10 +679,11 @@ static void on_no_fattn_vec_case(const int D) { } } -template +// parallel_blocks == 0 is stream-k decomposition +template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, - const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V + const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V ) { const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -603,20 +702,23 @@ void launch_fattn( GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + GGML_ASSERT(Q->ne[3] == 1); + ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; ggml_cuda_pool_alloc K_f16(pool); ggml_cuda_pool_alloc V_f16(pool); ggml_cuda_pool_alloc dst_tmp(pool); ggml_cuda_pool_alloc dst_tmp_meta(pool); - char * K_data = (char *) K->data; + const char * K_data = (const char *) K->data; size_t nb11 = K->nb[1]; size_t nb12 = K->nb[2]; size_t nb13 = K->nb[3]; - char * V_data = (char *) V->data; + const char * V_data = (const char *) V->data; size_t nb21 = V->nb[1]; size_t nb22 = V->nb[2]; size_t nb23 = V->nb[3]; @@ -649,39 +751,60 @@ void launch_fattn( nb23 = nb23*bs*sizeof(half)/ts; } - if (parallel_blocks > 1) { - dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); - dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); - } + const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block); + const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3]; const dim3 block_dim(WARP_SIZE, nwarps, 1); - const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); - const int shmem = 0; + dim3 blocks_num; + if (parallel_blocks == 0) { + // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. + const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm; + const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total; + const bool short_context = K->ne[1] < 4096; + + const int nblocks_stream_k = 2*nsm; + + blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k; + blocks_num.y = 1; + blocks_num.z = 1; + + dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float)); + } else { + blocks_num.x = parallel_blocks*ntiles_x; + blocks_num.y = Q->ne[2]; + blocks_num.z = Q->ne[3]; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + } + float scale = 1.0f; float max_bias = 0.0f; float logit_softcap = 0.0f; - memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float)); + memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); if (logit_softcap != 0.0f) { scale /= logit_softcap; } const uint32_t n_head = Q->ne[2]; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - fattn_kernel<<>>( + fattn_kernel<<>>( (const char *) Q->data, K_data, V_data, mask ? ((const char *) mask->data) : nullptr, - (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -693,16 +816,22 @@ void launch_fattn( ); CUDA_CHECK(cudaGetLastError()); - if ((parallel_blocks) == 1) { - return; - } + if constexpr (parallel_blocks == 0) { + if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles. + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine = blocks_num; - const dim3 block_dim_combine(D, 1, 1); - const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); - const int shmem_combine = 0; + flash_attn_stream_k_fixup + <<>> + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); + } + } else if constexpr (parallel_blocks > 1) { + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); - flash_attn_combine_results - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + } CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh new file mode 100644 index 00000000000..05bc91a3b8d --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -0,0 +1,637 @@ +#include "common.cuh" +#include "mma.cuh" +#include "fattn-common.cuh" + +template +static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ maskh, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3, + const int jt, + const int kb0_start, + const int kb0_stop) { +#ifdef NEW_MMA_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + typedef mma_A_I16K8 mma_A; + typedef mma_B_J8K8 mma_B; + typedef mma_C_I16J8 mma_C_KQ; + typedef mma_C_I16J8 mma_C_VKQ; + + static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps"); + constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column. + + static_assert(D % nwarps == 0, "bad D"); + static_assert(KQ_stride % nwarps == 0, "bad KQ_stride"); + + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements. + + const int stride_Q = nb01 / sizeof(float2); + const int stride_KV = nb11 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); + + mma_B Q_B[D/(2*mma_B::K)]; + mma_C_VKQ VKQ_C[D/mma_C_VKQ::I]; + + float2 KQ_rowsum = {0.0f, 0.0f}; + float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f}; + float2 KQ_max_scale = {0.0f, 0.0f}; + + // Temporarily load Q data into tile_KV, will be loaded into registers afterwards. + // The loading is done with decreasing granularity for D for better memory bandwidth. + const half2 scale_h2 = make_half2(scale, scale); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_j = WARP_SIZE / stride_k; + + if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { + break; + } + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) { + const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jt*ncols + j < ne01) { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k]; + tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + } + } else { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f); + } + } + } + } + + __syncthreads(); + + { + const int j0 = (threadIdx.y / np) * mma_B::J; + +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += mma_B::K) { + Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded); + } + } + + __syncthreads(); + + // Iterate over ne11 == previous tokens: + for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) { + const int k_VKQ_0 = kb0*KQ_stride; + mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)]; + + // Load K data into tile with decreasing granularity for D for better memory bandwidth: + static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) { + const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) { + const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ]; + } + } + } + + __syncthreads(); + + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) { + mma_A K_A; + K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded); + KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]); + } + } + + __syncthreads(); + + if (use_logit_softcap) { + static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) { +#pragma unroll + for (int l = 0; l < mma_C_KQ::ne; ++l) { + KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); + } + } + } + + if (maskh) { + static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size"); + static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size"); +#pragma unroll + for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I; +#pragma unroll + for (int l = 0; l < mma_C_KQ::ne; ++l) { + const int i = i0 + mma_C_KQ::get_i(l); + const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l); + + KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]); + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + float2 KQ_max_new = KQ_max; + static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) { +#pragma unroll + for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) { + KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]); + KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]); + } + } + + // Values per KQ column are spread across 8 threads, does not need full warp reduce: +#pragma unroll + for (int offset = 16; offset > 2; offset >>= 1) { + KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE)); + KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE)); + } + + { + const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y); + KQ_max_scale = make_float2(expf(diff.x), expf(diff.y)); + if (diff.x <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale.x = 0.0f; + } + if (diff.y <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale.y = 0.0f; + } + KQ_max = KQ_max_new; + } + + float2 KQ_rowsum_add = make_float2(0.0f, 0.0f); + static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < mma_C_KQ::ne; ++l) { + const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y; + const float diff = KQ_C[k].x[l] - KQ_max_l; + KQ_C[k].x[l] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_C[k].x[l] = 0.0f; + } + + if (l % 2 == 0) { + KQ_rowsum_add.x += KQ_C[k].x[l]; + } else { + KQ_rowsum_add.y += KQ_C[k].x[l]; + } + } + } + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x; + KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y; + + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y); +#pragma unroll + for (int i = 0; i < D/mma_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < mma_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + + // Convert KQ C tiles into B tiles for VKQ calculation: + mma_B B[KQ_stride/(np*2*mma_B::K)]; + static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) { + B[k] = KQ_C[k].to_mma_B(); + } + + // Load V data into tile with decreasing granularity for D for better memory bandwidth: + static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); +#pragma unroll + for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i); + const int i0_stop = D/2 - (D/2) % (1*stride_i); + const int stride_k = WARP_SIZE / stride_i; + +#pragma unroll + for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) { + const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i); + +#pragma unroll + for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) { + const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i); + + tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V]; + } + } + } + + __syncthreads(); + + // Calculate VKQ tile: +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) { + static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size"); +#pragma unroll + for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) { + const int k0 = k00 + (threadIdx.y % np)*mma_A::K; + + mma_A A; + A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); + VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]); + } + } + + __syncthreads(); + } + + // Finally, sum up partial KQ rowsums. + // The partial sums are spread across 8 threads each, does not need full reduce. +#pragma unroll + for (int offset = 16; offset > 2; offset >>= 1) { + KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE); + KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE); + } + + // Write VKQ accumulators to shared memory in column-major format. + // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. + // Also for np > 1 the combination is done via these values in shared memory. + const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += mma_B::K) { + const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format. + +#pragma unroll + for (int l = 0; l < mma_B::ne; ++l) { + const int k = k0 + mma_B::get_k(l); + + tile_KV[j_cwd*D2_padded + k] = B.x[l]; + } + } + + const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset + const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta + const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr; + } + + __syncthreads(); + + static_assert(np == 1 || np == 2 || np == 4, "bad np"); + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && threadIdx.x < mma_B::J) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[j_cwm] = KQ_cmr; + } + if (is_fixup && threadIdx.x < mma_B::J) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[j_cwm] = KQ_cmr; + } + } else if (threadIdx.y % np == 0) { + // Combine the meta data for parallel warps via shared memory. + // Warps with threadIdx.y % np != 0 must NOT return early. + // All threads must return simultaneously to avoid race conditions with work on the next tile. + + float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2; + + float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp. + if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + KQ_cm = meta_j[0]; + } + + float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps. +#pragma unroll + for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) { + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); + } + + const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp. + float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps. + if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + KQ_crs = KQ_cms*meta_j[1]; + } +#pragma unroll + for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) { + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); + } + + // Write back combined meta data: + if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + meta_j[0] = KQ_cmn; // Combined max. KQ values. + meta_j[1] = KQ_crs; // Combined KQ rowsums. + meta_j[2] = KQ_cms; // KQ max scales per parallel warp. + } + if (needs_fixup && threadIdx.x < mma_B::J) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + if (is_fixup && threadIdx.x < mma_B::J) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + } + + if (np > 1) { + __syncthreads(); + } + + if (np == 1 || threadIdx.y % np == 0) { + // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. + // The values after that are for the partial results of the individual blocks. + float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2)); + +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_j = WARP_SIZE / stride_k; + + if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { + break; + } + +#pragma unroll + for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) { + const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J; + + if (!is_fixup && jt*ncols + j_dst >= ne01) { + continue; + } + const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2; +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + float2 dstk_val = make_float2(0.0f, 0.0f); +#pragma unroll + for (int ip = 0; ip < np; ++ip) { + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2]; + const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]); + dstk_val.x += dstk_val_add.x*KQ_crs; + dstk_val.y += dstk_val_add.y*KQ_crs; + } + + if (!needs_fixup && !is_fixup) { + const float KQ_rowsum_j = meta_j[1]; + dstk_val.x /= KQ_rowsum_j; + dstk_val.y /= KQ_rowsum_j; + } + + if (is_fixup) { + dstk_fixup_data[j_dst*(D/2) + k] = dstk_val; + } else { + dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val; + } + } + } + } + } + + if (np > 1) { + __syncthreads(); + } +#else + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE +} + +template +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(nwarps*WARP_SIZE, 2) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + + static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride"); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + + const int iter_k = ne11 / KQ_stride; + const int iter_j = (ne01 + (ncols - 1)) / ncols; + + // kbc == k block continuous, current index in continuous ijk space. + int kbc = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x; + const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x; + + // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. + // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). + // In the most general case >2 seams can fall into the same tile. + + // kb0 == k start index when in the output tile. + int kb0_start = kbc % iter_k; + int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); + while (kbc < kbc_stop && kb0_stop == iter_k) { + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape + const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(D/2); + + const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1); + + constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. + if (kb0_start == 0) { + constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, + jt, kb0_start, kb0_stop); + } else { + constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, + jt, kb0_start, kb0_stop); + } + + kbc += iter_k; + kbc -= kbc % iter_k; + + kb0_start = 0; + kb0_stop = min(iter_k, kbc_stop - kbc); + } + + if (kbc >= kbc_stop) { + return; + } + + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape + const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(D/2); + + const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1); + + constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. + constexpr bool needs_fixup = false; + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, + jt, kb0_start, kb0_stop); +} + +template +void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + typedef mma_A_I16K8 mma_A; + typedef mma_B_J8K8 mma_B; + + static_assert(D % mma_B::K == 0, "bad D"); + static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block"); + + const ggml_tensor * KQV = dst; + + constexpr int KQ_stride = D <= 128 ? 64 : 32; + constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ? + cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8); + constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half); + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16; + } + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); +} + +#define DECL_FATTN_MMA_F16_CASE(D, cols_per_block) \ + template void ggml_cuda_flash_attn_ext_mma_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +extern DECL_FATTN_MMA_F16_CASE( 64, 8); +extern DECL_FATTN_MMA_F16_CASE( 80, 8); +extern DECL_FATTN_MMA_F16_CASE( 96, 8); +extern DECL_FATTN_MMA_F16_CASE(112, 8); +extern DECL_FATTN_MMA_F16_CASE(128, 8); +extern DECL_FATTN_MMA_F16_CASE(256, 8); + +extern DECL_FATTN_MMA_F16_CASE( 64, 16); +extern DECL_FATTN_MMA_F16_CASE( 80, 16); +extern DECL_FATTN_MMA_F16_CASE( 96, 16); +extern DECL_FATTN_MMA_F16_CASE(112, 16); +extern DECL_FATTN_MMA_F16_CASE(128, 16); +extern DECL_FATTN_MMA_F16_CASE(256, 16); + +extern DECL_FATTN_MMA_F16_CASE( 64, 32); +extern DECL_FATTN_MMA_F16_CASE( 80, 32); +extern DECL_FATTN_MMA_F16_CASE( 96, 32); +extern DECL_FATTN_MMA_F16_CASE(112, 32); +extern DECL_FATTN_MMA_F16_CASE(128, 32); +extern DECL_FATTN_MMA_F16_CASE(256, 32); + +extern DECL_FATTN_MMA_F16_CASE( 64, 64); +extern DECL_FATTN_MMA_F16_CASE( 80, 64); +extern DECL_FATTN_MMA_F16_CASE( 96, 64); +extern DECL_FATTN_MMA_F16_CASE(112, 64); +extern DECL_FATTN_MMA_F16_CASE(128, 64); +extern DECL_FATTN_MMA_F16_CASE(256, 64); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 4d314dacb19..d4edbad07f2 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -45,7 +45,17 @@ static __global__ void flash_attn_tile_ext_f16( const int ne2, const int ne3) { #ifdef FP16_AVAILABLE + +#ifndef FLASH_ATTN_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FLASH_ATTN_AVAILABLE + // Skip unused kernel variants for faster compilation: +#ifdef FP16_MMA_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FP16_MMA_AVAILABLE if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; return; @@ -288,16 +298,18 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; + constexpr int D = 64; + constexpr int nwarps = 8; + constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; + constexpr int D = 128; + constexpr int nwarps = 8; + constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index bb33604470f..0d274f33255 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -48,7 +48,12 @@ static __global__ void flash_attn_tile_ext_f32( NO_DEVICE_CODE; return; #endif // FLASH_ATTN_AVAILABLE + // Skip unused kernel variants for faster compilation: +#ifdef FP16_MMA_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FP16_MMA_AVAILABLE if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; return; @@ -287,16 +292,18 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; + constexpr int D = 64; + constexpr int nwarps = 8; + constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; + constexpr int D = 128; + constexpr int nwarps = 8; + constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 34a2992c769..d9ac4424606 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -42,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16( const int ne2, const int ne3) { #ifdef FP16_AVAILABLE + +#ifndef FLASH_ATTN_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FLASH_ATTN_AVAILABLE + // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -303,7 +309,8 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + constexpr size_t nbytes_shared = 0; + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); } template diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index a28fc8b7fc8..6ef8f9dcc27 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -41,6 +41,11 @@ static __global__ void flash_attn_vec_ext_f32( const int ne1, const int ne2, const int ne3) { +#ifndef FLASH_ATTN_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FLASH_ATTN_AVAILABLE + // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -284,7 +289,8 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + constexpr size_t nbytes_shared = 0; + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); } template diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu new file mode 100644 index 00000000000..1054ff95d22 --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -0,0 +1,648 @@ +// Old and deprecated WMMA FlashAttention implementation. +// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing. +// Long-term the WMMA code should be replaced with a dedicated Volta implementation. + +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-wmma-f16.cuh" + +#ifdef FP16_MMA_AVAILABLE +#include +#endif // FP16_MMA_AVAILABLE + +// D == head size, VKQ_stride == num VKQ rows calculated in parallel: +template +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(nwarps*WARP_SIZE, 1) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); + static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); + constexpr int frag_m = ncols == 8 ? 32 : 16; + constexpr int frag_n = ncols == 8 ? 8 : 16; + static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); + typedef nvcuda::wmma::fragment frag_a_K; + typedef nvcuda::wmma::fragment frag_a_V; + typedef nvcuda::wmma::fragment frag_b; + typedef nvcuda::wmma::fragment frag_c_KQ; + typedef nvcuda::wmma::fragment frag_c_VKQ; + + constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. + constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. + static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); + + // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: + constexpr int D_padded = D + 8; + constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; + constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); + const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; + const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); + + const int stride_Q = nb01 / sizeof(float); + const int stride_KV = nb11 / sizeof(half); + + const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + const half slopeh = __float2half(slopef); + const half2 slope2 = make_half2(slopef, slopef); + + const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap); + + frag_b Q_b[D/16][ncols/frag_n]; + + // A single buffer for temporarily holding tiles of KQ and VKQ parts: + constexpr int mem_KQ = ncols*kqs_padded*kqar; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; + float * KQ_f = (float *) KQ; + half2 * KQ2 = (half2 *) KQ; + + float KQ_rowsum_f[ncols/nwarps] = {0.0f}; + float KQ_max_f[ncols/nwarps]; + float KQ_max_scale_f[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_f[j] = -FLT_MAX/2.0f; + } + + half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max_h2[ncols/nwarps]; + half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); + } + + __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + half2 * VKQ2 = (half2 *) VKQ; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); + } + } + + // Convert Q to half and apply scale, temporarily store in KQ: +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } + } + + __syncthreads(); + + // Load Q into tensor core fragments/registers since it will be used frequently: +#pragma unroll + for (int i0 = 0; i0 < D; i0 += 16) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + } + } + + __syncthreads(); + + // Iterate over ne11 == previous tokens: + for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { + frag_c_KQ KQ_c[ncols/frag_n]; +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); + } +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); + } + } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (std::is_same::value) { + float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + + if (use_logit_softcap) { + KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); + } + } + + float KQ_max_new = KQ_max_f[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); + } + KQ_max_new = warp_reduce_max(KQ_max_new); + + const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; + KQ_max_scale_f[j0/nwarps] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale_f[j0/nwarps] = 0.0f; + } + KQ_max_f[j0/nwarps] = KQ_max_new; + + float KQ_rowsum_add = 0.0f; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/WARP_SIZE] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_f_tmp[k0/WARP_SIZE] = 0.0f; + } + KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; + } else { + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + + if (use_logit_softcap) { + // There is no dedicated tangens hyperbolicus function for half2. + KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); + KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) + /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); + + KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2; + } + } + + half2 KQ_max_new = KQ_max_h2[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + } + KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; + KQ_max_scale_h2[j0/nwarps] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + KQ_max_h2[j0/nwarps] = KQ_max_new; + + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; + } + } + + __syncthreads(); + + frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + nvcuda::wmma::load_matrix_sync( + KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], + KQ + j0*(kqar*kqs_padded) + k, + kqar*kqs_padded); + } + } + + frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); + } + +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + + frag_a_V v_a; + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); + } + } + } + + __syncthreads(); + + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync( + KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], + D_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + half2 VKQ_scale; + if (std::is_same::value) { + VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); + } else { + VKQ_scale = KQ_max_scale_h2[j0/nwarps]; + } + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + + half2 VKQ_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int l = 0; l < VKQ_ratio; ++l) { + VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + } + VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; + } + } + + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j_VKQ = j0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; + } + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + + float KQ_rowsum_j; + if (std::is_same::value) { + KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; + } else { + KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); + } + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + float dst_val = VKQ[j_VKQ*D_padded + i]; + if (parallel_blocks == 1) { + dst_val /= KQ_rowsum_j; + } + dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; + } + + if (parallel_blocks == 1 || threadIdx.x != 0) { + continue; + } + + float2 dst_meta_val; + if (std::is_same::value) { + dst_meta_val.x = KQ_max_f[j0/nwarps]; + } else { + dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); + } + dst_meta_val.y = KQ_rowsum_j; + dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; + } +#else + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +} + +constexpr int get_max_power_of_2(int x) { + return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; +} + +static_assert(get_max_power_of_2(1) == 1, "Test failed."); +static_assert(get_max_power_of_2(2) == 2, "Test failed."); +static_assert(get_max_power_of_2(4) == 4, "Test failed."); +static_assert(get_max_power_of_2(6) == 2, "Test failed."); + +// Number of VKQ rows calculated in parallel: +constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { + return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; +} + +static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); +static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); + +template +void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + + constexpr int nwarps = 4; + + constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; + const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (4*blocks_num_pb1 < 2*nsm) { + constexpr int parallel_blocks = 4; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); + return; + } + if (2*blocks_num_pb1 < 2*nsm) { + constexpr int parallel_blocks = 2; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); + return; + } + constexpr int parallel_blocks = 1; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); +} + +void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + + const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); + + if (prec != GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 32 || Q->ne[0] > 128) { + constexpr int cols_per_block = 16; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } + } else { + constexpr int cols_per_block = 32; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + break; + // case 256: + // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + // break; + default: + GGML_ABORT("fatal error"); + break; + } + } + return; + } + + if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { + constexpr int cols_per_block = 8; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } + return; + } + + if (Q->ne[1] <= 32) { + constexpr int cols_per_block = 16; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } + return; + } + + constexpr int cols_per_block = 32; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index 860d0e6dc2f..beeea95eb1d 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -1,543 +1,3 @@ #include "common.cuh" -#include "fattn-common.cuh" -#ifdef FP16_MMA_AVAILABLE -#include -#endif // FP16_MMA_AVAILABLE - -// D == head size, VKQ_stride == num VKQ rows calculated in parallel: -template -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -static __global__ void flash_attn_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { -#ifdef FP16_MMA_AVAILABLE - // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(D == 128 || D == 256)) { - NO_DEVICE_CODE; - return; - } - - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. - const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. - - static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); - static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); - constexpr int frag_m = ncols == 8 ? 32 : 16; - constexpr int frag_n = ncols == 8 ? 8 : 16; - static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); - typedef nvcuda::wmma::fragment frag_a_K; - typedef nvcuda::wmma::fragment frag_a_V; - typedef nvcuda::wmma::fragment frag_b; - typedef nvcuda::wmma::fragment frag_c_KQ; - typedef nvcuda::wmma::fragment frag_c_VKQ; - - constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. - constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. - static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); - - // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: - constexpr int D_padded = D + 8; - constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; - constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); - - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); - const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; - const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); - - const int stride_Q = nb01 / sizeof(float); - const int stride_KV = nb11 / sizeof(half); - - const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); - const half slopeh = __float2half(slopef); - const half2 slope2 = make_half2(slopef, slopef); - - const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap); - - frag_b Q_b[D/16][ncols/frag_n]; - - // A single buffer for temporarily holding tiles of KQ and VKQ parts: - constexpr int mem_KQ = ncols*kqs_padded*kqar; - constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; - __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; - float * KQ_f = (float *) KQ; - half2 * KQ2 = (half2 *) KQ; - - float KQ_rowsum_f[ncols/nwarps] = {0.0f}; - float KQ_max_f[ncols/nwarps]; - float KQ_max_scale_f[ncols/nwarps] = {0.0f}; - -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - KQ_max_f[j] = -FLT_MAX/2.0f; - } - - half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; - half2 KQ_max_h2[ncols/nwarps]; - half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; - -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); - } - - __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. - half2 * VKQ2 = (half2 *) VKQ; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { - break; - } - VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); - } - } - - // Convert Q to half and apply scale, temporarily store in KQ: -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; - } - KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; - } - } - - __syncthreads(); - - // Load Q into tensor core fragments/registers since it will be used frequently: -#pragma unroll - for (int i0 = 0; i0 < D; i0 += 16) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); - } - } - - __syncthreads(); - - // Iterate over ne11 == previous tokens: - for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { - // Calculate tile of KQ: -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { - frag_c_KQ KQ_c[ncols/frag_n]; -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); - } -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { - frag_a_K K_a; - nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); - } - } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); - } - } - - __syncthreads(); - - // Calculate softmax for each KQ column using the current max. value. - // The divisor is stored in KQ_rowsum and will be applied at the end. -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (std::is_same::value) { - float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; - - if (use_logit_softcap) { - KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); - } - } - - float KQ_max_new = KQ_max_f[j0/nwarps]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; - KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); - } - KQ_max_new = warp_reduce_max(KQ_max_new); - - const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; - KQ_max_scale_f[j0/nwarps] = expf(diff); - if (diff <= SOFTMAX_FTZ_THRESHOLD) { - KQ_max_scale_f[j0/nwarps] = 0.0f; - } - KQ_max_f[j0/nwarps] = KQ_max_new; - - float KQ_rowsum_add = 0.0f; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; - KQ_f_tmp[k0/WARP_SIZE] = expf(diff); - if (diff <= SOFTMAX_FTZ_THRESHOLD) { - KQ_f_tmp[k0/WARP_SIZE] = 0.0f; - } - KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; - KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; - } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); - - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; - } else { - half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; - - if (use_logit_softcap) { - // There is no dedicated tangens hyperbolicus function for half2. - KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); - KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) - /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); - - KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2; - } - } - - half2 KQ_max_new = KQ_max_h2[j0/nwarps]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); - } - KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); - const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; - KQ_max_scale_h2[j0/nwarps] = h2exp(diff); - const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; - KQ_max_h2[j0/nwarps] = KQ_max_new; - - half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; - KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); - const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; - KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; - KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; - } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); - - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; - } - } - - __syncthreads(); - - frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { - const int k = k0 + (threadIdx.y % VKQ_ratio)*16; - nvcuda::wmma::load_matrix_sync( - KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*(kqar*kqs_padded) + k, - kqar*kqs_padded); - } - } - - frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; -#pragma unroll - for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); - } - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { - const int k = k0 + (threadIdx.y % VKQ_ratio)*16; - - frag_a_V v_a; - nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); - } - } - } - - __syncthreads(); - - const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync( - KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), - VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], - D_padded, nvcuda::wmma::mem_col_major); - } - } - - __syncthreads(); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - half2 VKQ_scale; - if (std::is_same::value) { - VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); - } else { - VKQ_scale = KQ_max_scale_h2[j0/nwarps]; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { - break; - } - - half2 VKQ_add = make_half2(0.0f, 0.0f); -#pragma unroll - for (int l = 0; l < VKQ_ratio; ++l) { - VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; - } - VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; - } - } - - __syncthreads(); - } - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j_VKQ = j0 + threadIdx.y; - if (ic0 + j_VKQ >= ne01) { - return; - } - const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - - float KQ_rowsum_j; - if (std::is_same::value) { - KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; - } else { - KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); - } - -#pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; - } - float dst_val = VKQ[j_VKQ*D_padded + i]; - if (parallel_blocks == 1) { - dst_val /= KQ_rowsum_j; - } - dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; - } - - if (parallel_blocks == 1 || threadIdx.x != 0) { - continue; - } - - float2 dst_meta_val; - if (std::is_same::value) { - dst_meta_val.x = KQ_max_f[j0/nwarps]; - } else { - dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); - } - dst_meta_val.y = KQ_rowsum_j; - dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; - } -#else - NO_DEVICE_CODE; -#endif // FP16_MMA_AVAILABLE -} - -constexpr int get_max_power_of_2(int x) { - return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; -} - -static_assert(get_max_power_of_2(1) == 1, "Test failed."); -static_assert(get_max_power_of_2(2) == 2, "Test failed."); -static_assert(get_max_power_of_2(4) == 4, "Test failed."); -static_assert(get_max_power_of_2(6) == 2, "Test failed."); - -// Number of VKQ rows calculated in parallel: -constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { - return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; -} - -static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); -static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); -static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); -static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); -static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); -static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); -static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); -static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); -static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); - -template -void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - constexpr int nwarps = 4; - - constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; - const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; - const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (4*blocks_num_pb1 < 2*nsm) { - constexpr int parallel_blocks = 4; - fattn_kernel_t fattn_kernel; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } else { - constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); - return; - } - if (2*blocks_num_pb1 < 2*nsm) { - constexpr int parallel_blocks = 2; - fattn_kernel_t fattn_kernel; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } else { - constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); - return; - } - constexpr int parallel_blocks = 1; - fattn_kernel_t fattn_kernel; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } else { - constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); -} - -#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \ - template void ggml_cuda_flash_attn_ext_wmma_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ - -extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float); -extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float); -extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float); -extern DECL_FATTN_WMMA_F16_CASE(112, 16, float); -extern DECL_FATTN_WMMA_F16_CASE(128, 16, float); -extern DECL_FATTN_WMMA_F16_CASE(256, 16, float); - -extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float); -extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float); -extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float); -extern DECL_FATTN_WMMA_F16_CASE(112, 32, float); -extern DECL_FATTN_WMMA_F16_CASE(128, 32, float); -// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float); - -extern DECL_FATTN_WMMA_F16_CASE( 64, 8, half); -extern DECL_FATTN_WMMA_F16_CASE( 96, 8, half); -extern DECL_FATTN_WMMA_F16_CASE(128, 8, half); -extern DECL_FATTN_WMMA_F16_CASE(256, 8, half); - -extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half); -extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half); -extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half); -extern DECL_FATTN_WMMA_F16_CASE(112, 16, half); -extern DECL_FATTN_WMMA_F16_CASE(128, 16, half); -extern DECL_FATTN_WMMA_F16_CASE(256, 16, half); - -extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half); -extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half); -extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half); -extern DECL_FATTN_WMMA_F16_CASE(112, 32, half); -extern DECL_FATTN_WMMA_F16_CASE(128, 32, half); -extern DECL_FATTN_WMMA_F16_CASE(256, 16, half); +void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 0b26b0f8e05..b1e66d47083 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -1,5 +1,6 @@ #include "common.cuh" #include "fattn-common.cuh" +#include "fattn-mma-f16.cuh" #include "fattn-tile-f16.cuh" #include "fattn-tile-f32.cuh" #include "fattn-vec-f16.cuh" @@ -7,144 +8,56 @@ #include "fattn-wmma-f16.cuh" #include "fattn.cuh" -#include +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; -static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); - - if (prec != GGML_PREC_DEFAULT) { - if (Q->ne[1] <= 32 || Q->ne[0] > 128) { - constexpr int cols_per_block = 16; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); - break; - default: - GGML_ABORT("fatal error"); - break; - } - } else { - constexpr int cols_per_block = 32; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); - break; - // case 256: - // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); - // break; - default: - GGML_ABORT("fatal error"); - break; - } - } - return; - } - - if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { - constexpr int cols_per_block = 8; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); - break; - default: - GGML_ABORT("fatal error"); - break; - } - return; - } - - if (Q->ne[1] <= 32) { - constexpr int cols_per_block = 16; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); - break; - default: - GGML_ABORT("fatal error"); - break; - } - return; - } - - constexpr int cols_per_block = 32; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst); break; default: GGML_ABORT("fatal error"); break; } } + +static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + if (Q->ne[1] <= 8) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); + return; + } + + if (Q->ne[1] <= 16) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst); + return; + } + + if (Q->ne[1] <= 32) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst); +} + #define FATTN_VEC_F16_CASE(D, type_K, type_V) \ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ ggml_cuda_flash_attn_ext_vec_f16_case(ctx, dst); \ @@ -322,11 +235,19 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - if (!fp16_mma_available(cc)) { - if (Q->ne[1] <= 8) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + if (!new_mma_available(cc)) { + if (prec == GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 8) { + ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + } } else { - ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + if (Q->ne[1] <= 8) { + ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); + } } return; } @@ -341,5 +262,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } } - ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + // The MMA implementation needs Turing or newer, use the old WMMA code for Volta: + if (cc == GGML_CUDA_CC_VOLTA) { + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + } + + ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 7d11540afd2..9788a1389a3 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -1,11 +1,67 @@ +// This file contains primitives that expose the tensor core PTX instructions for CUDA code. +// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. +// The documentation for the PTX instructions can be found under: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction +// +// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C. +// A is a row-major matrix with shape I x K. +// B is a column-major matrix with shape K x J. +// C is a column-major matrix with shape I x J. +// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements. +// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile. +// All matrix tiles have ne physical 32 bit elements per warp. +// +// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. + #include "common.cuh" -struct mma_int_A_I16K4 { + +#if CUDART_VERSION >= 11800 + +static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { + int ret = 0; + +#ifdef NEW_MMA_AVAILABLE + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" + : "+r"(ret) : "r"(x)); +#else + NO_DEVICE_CODE; +#endif // defined(NEW_MMA_AVAILABLE) + return ret; +} + +#else + +static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { + // Imagine transposing row-major matrix to column-major matrix. + const int src_i_low = 2 * (threadIdx.x % 4); + const int src_i_high = src_i_low + 1; + const int src_j = threadIdx.x / 4; + + const int src_laneid_low = src_i_low * 4 + src_j / 2; + const int src_laneid_high = src_i_high * 4 + src_j / 2; + + const int shift_low = ((src_j + 0) % 2) * 16; + const int shift_high = ((src_j + 1) % 2) * 16; + + const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF; + const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000; + + return ret_low | ret_high; +} + +#endif // CUDART_VERSION >= 11800 + + +template +struct mma_A_I16K4 { + static_assert(sizeof(T) == 4, "bad type size"); + static constexpr int I = 16; static constexpr int K = 4; static constexpr int ne = 2; - int x[ne] = {0}; + T x[ne]; static __device__ __forceinline__ int get_i(const int l) { const int ret = (l%2) * (I/2) + threadIdx.x / K; @@ -21,27 +77,35 @@ struct mma_int_A_I16K4 { return ret; } - __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { -#if defined(INT8_MMA_AVAILABLE) - const int * xs = xs0 + (threadIdx.x%I)*stride; - asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" - : "+r"(x[0]), "+r"(x[1]) - : "l"(xs)); -#else + __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { #pragma unroll for (int l = 0; l < ne; ++l) { x[l] = xs0[get_i(l)*stride + get_k(l)]; } -#endif // defined(INT8_MMA_AVAILABLE) + } + + __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int *) x; + const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride; + asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "+r"(xi[0]), "+r"(xi[1]) + : "l"(xs)); +#else + load_generic(xs0, stride); +#endif // NEW_MMA_AVAILABLE } }; -struct mma_int_A_I16K8 { +template +struct mma_A_I16K8 { + static_assert(sizeof(T) == 4, "bad type size"); + static constexpr int I = 16; static constexpr int K = 8; static constexpr int ne = 4; - int x[ne] = {0}; + T x[ne]; static __device__ __forceinline__ int get_i(const int l) { const int ret = (l%2) * (I/2) + threadIdx.x / (K/2); @@ -57,31 +121,62 @@ struct mma_int_A_I16K8 { return ret; } - __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { -#if defined(INT8_MMA_AVAILABLE) - const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); - asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" - : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) - : "l"(xs)); -#else + __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { #pragma unroll for (int l = 0; l < ne; ++l) { x[l] = xs0[get_i(l)*stride + get_k(l)]; } -#endif // defined(INT8_MMA_AVAILABLE) } - __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) { - ((mma_int_A_I16K4 *) x)[0].load(xs0, stride); + __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int * ) x; + const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); + asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" + : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "l"(xs)); +#else + GGML_UNUSED(xs0); + GGML_UNUSED(stride); + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + } + + __device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int * ) x; + const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); + asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" + : "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3]) + : "l"(xs)); +#else + GGML_UNUSED(xs0); + GGML_UNUSED(stride); + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + } + + __device__ __forceinline__ void transpose() { + int * xi = (int *) x; + xi[0] = ggml_cuda_movmatrix(xi[0]); + + const int tmp = ggml_cuda_movmatrix(xi[1]); + xi[1] = ggml_cuda_movmatrix(xi[2]); + xi[2] = tmp; + + xi[3] = ggml_cuda_movmatrix(xi[3]); } }; -struct mma_int_B_J8K4 { +template +struct mma_B_J8K4 { + static_assert(sizeof(T) == 4, "bad type size"); + static constexpr int J = 8; static constexpr int K = 4; static constexpr int ne = 1; - int x[ne] = {0}; + T x[ne]; static __device__ __forceinline__ int get_j(const int /* l */) { const int ret = threadIdx.x / K; @@ -97,27 +192,34 @@ struct mma_int_B_J8K4 { return ret; } - __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { -#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster - const int * xs = xs0 + (threadIdx.x%J)*stride; - asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];" - : "+r"(x[0]) - : "l"(xs)); -#else + __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { #pragma unroll for (int l = 0; l < ne; ++l) { x[l] = xs0[get_j(l)*stride + get_k(l)]; } -#endif // defined(INT8_MMA_AVAILABLE) + } + + __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int *) x; + const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride; + asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];" + : "+r"(xi[0]) : "l"(xs)); +#else + load_generic(xs0, stride); +#endif // NEW_MMA_AVAILABLE } }; -struct mma_int_B_J8K8 { +template +struct mma_B_J8K8 { + static_assert(sizeof(T) == 4, "bad type size"); + static constexpr int J = 8; static constexpr int K = 8; static constexpr int ne = 2; - int x[ne] = {0}; + T x[ne]; static __device__ __forceinline__ int get_j(const int /* l */) { const int ret = threadIdx.x / (K/2); @@ -133,22 +235,31 @@ struct mma_int_B_J8K8 { return ret; } - __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { -#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster - const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K; - asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" - : "+r"(x[0]), "+r"(x[1]) - : "l"(xs)); -#else + __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { #pragma unroll for (int l = 0; l < ne; ++l) { x[l] = xs0[get_j(l)*stride + get_k(l)]; } -#endif // defined(INT8_MMA_AVAILABLE) + } + + __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int *) x; + const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K; + asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "+r"(xi[0]), "+r"(xi[1]) + : "l"(xs)); +#else + load_generic(xs0, stride); +#endif // NEW_MMA_AVAILABLE } }; -struct mma_int_C_I16J8 { +template +struct mma_C_I16J8 {}; + +template <> +struct mma_C_I16J8 { static constexpr int I = 16; static constexpr int J = 8; static constexpr int ne = 4; @@ -169,8 +280,8 @@ struct mma_int_C_I16J8 { return ret; } - __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) { -#ifdef INT8_MMA_AVAILABLE + __device__ __forceinline__ void mma(const mma_A_I16K4 & mma_A, const mma_B_J8K4 & mma_B) { +#ifdef NEW_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) @@ -188,11 +299,11 @@ struct mma_int_C_I16J8 { GGML_UNUSED(mma_A); GGML_UNUSED(mma_B); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) { -#ifdef INT8_MMA_AVAILABLE + __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { +#ifdef NEW_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) @@ -216,6 +327,132 @@ struct mma_int_C_I16J8 { GGML_UNUSED(mma_A); GGML_UNUSED(mma_B); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE + } +}; + +template <> +struct mma_C_I16J8 { + static constexpr int I = 16; + static constexpr int J = 4; + static constexpr int ne = 2; + + half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}}; + + static __device__ __forceinline__ int get_i(const int l) { + const int ret = l * (I/2) + threadIdx.x / J; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < I); + return ret; + } + + static __device__ __forceinline__ int get_j(const int /* l */) { + const int ret = threadIdx.x % J; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < J); + return ret; + } + + __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { +#ifdef NEW_MMA_AVAILABLE + int * Axi = (int *) mma_A.x; + int * Bxi = (int *) mma_B.x; + int * xi = (int *) x; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(xi[0]), "+r"(xi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(xi[0]), "+r"(xi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(xi[0]), "+r"(xi[1]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#else + GGML_UNUSED(mma_A); + GGML_UNUSED(mma_B); + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + } + + __device__ __forceinline__ mma_B_J8K8 to_mma_B() { + mma_B_J8K8 mma_B; + + int * xi = (int *) x; + int * Bxi = (int *) mma_B.x; + Bxi[0] = ggml_cuda_movmatrix(xi[0]); + Bxi[1] = ggml_cuda_movmatrix(xi[1]); + + return mma_B; + } +}; + +template <> +struct mma_C_I16J8 { + static constexpr int I = 16; + static constexpr int J = 8; + static constexpr int ne = 4; + + float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f}; + + static __device__ __forceinline__ int get_i(const int l) { + const int ret = (l/2) * (I/2) + threadIdx.x / (J/2); + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < I); + return ret; + } + + static __device__ __forceinline__ int get_j(const int l) { + const int ret = 2 * (threadIdx.x % (J/2)) + l%2; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < J); + return ret; + } + + __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { +#ifdef NEW_MMA_AVAILABLE + int * Axi = (int *) mma_A.x; + int * Bxi = (int *) mma_B.x; + int * xi = (int *) x; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#else + GGML_UNUSED(mma_A); + GGML_UNUSED(mma_B); + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + } + + __device__ __forceinline__ mma_B_J8K8 to_mma_B() { + mma_B_J8K8 mma_B; + mma_B.x[0] = make_half2(x[0], x[1]); + mma_B.x[1] = make_half2(x[2], x[3]); + + int * Bxi = (int *) mma_B.x; + Bxi[0] = ggml_cuda_movmatrix(Bxi[0]); + Bxi[1] = ggml_cuda_movmatrix(Bxi[1]); + + return mma_B; + } + + __device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) { +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_j(l)*stride + get_i(l)]; + } } }; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 270251df4f1..83cb78cbdd4 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -132,7 +132,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return false; } - if (int8_mma_available(cc)) { + if (new_mma_available(cc)) { return true; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 3cd508a1d4a..c05c8477812 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -87,7 +87,7 @@ struct tile_x_sizes { }; static constexpr int get_mmq_x_max_host(const int cc) { - return int8_mma_available(cc) ? 128 : + return new_mma_available(cc) ? 128 : #ifdef GGML_CUDA_FORCE_MMQ cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64; #else @@ -96,9 +96,9 @@ static constexpr int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE return 128; -#else // INT8_MMA_AVAILABLE +#else // NEW_MMA_AVAILABLE #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) return 128; @@ -116,7 +116,7 @@ static constexpr __device__ int get_mmq_x_max_device() { #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } static constexpr int get_mmq_y_host(const int cc) { @@ -209,10 +209,10 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) static int mmq_get_granularity_host(const int mmq_x, const int cc) { - return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8; + return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8; } -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 48 ? 16 : 8; } @@ -220,21 +220,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) { return 8; } -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE // ------------------------------------------------------------ template static __device__ __forceinline__ void load_tiles_q4_0( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_0; const int kqsx = threadIdx.x % QI4_0; @@ -250,12 +250,12 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; @@ -271,11 +271,11 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -322,14 +322,14 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( template static __device__ __forceinline__ void load_tiles_q4_1( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_1; const int kqsx = threadIdx.x % QI4_1; @@ -345,12 +345,12 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; @@ -366,11 +366,11 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -417,14 +417,14 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( template static __device__ __forceinline__ void load_tiles_q5_0( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI5_0; const int kqsx = threadIdx.x % QI5_0; @@ -456,13 +456,13 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; @@ -478,25 +478,25 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_q5_1( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI5_1; const int kqsx = threadIdx.x % QI5_1; @@ -526,13 +526,13 @@ template static __device__ __forceinlin qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; @@ -548,25 +548,25 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_q8_0( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_tile + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI8_0; const int kqsx = threadIdx.x % QI8_0; @@ -581,13 +581,13 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); #else x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0; @@ -603,11 +603,11 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -645,9 +645,9 @@ template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K8 mma_A; + typedef mma_B_J8K8 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -672,7 +672,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { const int k0 = k00 + k01; - A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); } #pragma unroll @@ -695,7 +695,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( mma_B B; float dB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -711,7 +711,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C; - C.mma_K8(A[n][k01/QI8_0], B); + C.mma(A[n][k01/QI8_0], B); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -756,9 +756,9 @@ template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K8 mma_A; + typedef mma_B_J8K8 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -782,7 +782,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { const int k0 = k00 + k01; - A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); } #pragma unroll @@ -805,7 +805,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( mma_B B; float2 dsB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -817,7 +817,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C; - C.mma_K8(A[n][k01/QI8_1], B); + C.mma(A[n][k01/QI8_1], B); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -864,12 +864,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE - typedef mma_int_A_I16K4 mma_A; - typedef mma_int_A_I16K8 mma_A_K8; - typedef mma_int_B_J8K4 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K4 mma_A; + typedef mma_A_I16K8 mma_A_K8; + typedef mma_B_J8K4 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -893,7 +893,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { const int k0 = k00 + k01; - ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + ((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); } #pragma unroll @@ -916,8 +916,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( mma_B B[2]; float dB[mma_C::ne/2]; - B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + // Here load_generic is faster than load_ldmatrix. + B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -929,8 +930,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C[2]; - C[0].mma_K4(A[n][k01/4 + 0], B[0]); - C[1].mma_K4(A[n][k01/4 + 1], B[1]); + C[0].mma(A[n][k01/4 + 0], B[0]); + C[1].mma(A[n][k01/4 + 1], B[1]); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -942,20 +943,20 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q2_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % QI2_K; @@ -977,11 +978,11 @@ template static __device__ __forceinlin const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; #else x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int sc_m = bxi->scales[kqsx]; @@ -992,11 +993,11 @@ template static __device__ __forceinlin const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); #endif // FAST_FP16_AVAILABLE -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; #else x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -1051,12 +1052,12 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE - typedef mma_int_A_I16K4 mma_A; - typedef mma_int_A_I16K8 mma_A_K8; - typedef mma_int_B_J8K4 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K4 mma_A; + typedef mma_A_I16K8 mma_A_K8; + typedef mma_B_J8K4 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -1081,7 +1082,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { const int k0 = k00 + k01; - ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + ((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); } } @@ -1118,24 +1119,25 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { mma_B B[2]; - B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + // Here load_generic is faster than load_ldmatrix. + B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); mma_C Cm[2]; if (k01 >= WARP_SIZE * 3/4) { mma_A A1; A1.x[0] = 0x01010101; A1.x[1] = 0x01010101; - Cm[0].mma_K4(A1, B[0]); - Cm[1].mma_K4(A1, B[1]); + Cm[0].mma(A1, B[0]); + Cm[1].mma(A1, B[1]); } #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C Cd[2]; - Cd[0].mma_K4(A[n][k01/4 + 0], B[0]); - Cd[1].mma_K4(A[n][k01/4 + 1], B[1]); + Cd[0].mma(A[n][k01/4 + 0], B[0]); + Cd[1].mma(A[n][k01/4 + 1], B[1]); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -1172,13 +1174,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q3_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else @@ -1186,7 +1188,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % QI3_K; @@ -1212,11 +1214,11 @@ template static __device__ __forceinlin const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; #else x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -1242,7 +1244,7 @@ template static __device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE const int8_t * sc8 = (const int8_t *) ≻ const float d = bxi->d; @@ -1252,10 +1254,10 @@ template static __device__ __forceinlin } #else x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } -#ifndef INT8_MMA_AVAILABLE +#ifndef NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) { int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y; @@ -1268,7 +1270,7 @@ template static __device__ __forceinlin x_df[i] = bxi->d; } -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template @@ -1317,7 +1319,7 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co template static __device__ __forceinline__ void load_tiles_q4_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else @@ -1325,7 +1327,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -1338,15 +1340,15 @@ template static __device__ __forceinlin const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; const int qs0 = get_int_b4(bxi->qs, threadIdx.x); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { @@ -1407,7 +1409,7 @@ template static __device__ __forceinlin x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; } -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template @@ -1446,7 +1448,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( template static __device__ __forceinline__ void load_tiles_q5_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2); #else @@ -1454,7 +1456,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -1478,16 +1480,16 @@ template static __device__ __forceinlin const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; #else x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0; x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { @@ -1548,7 +1550,7 @@ template static __device__ __forceinlin x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; } -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template @@ -1587,7 +1589,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( template static __device__ __forceinline__ void load_tiles_q6_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K); @@ -1596,7 +1598,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -1619,13 +1621,13 @@ template static __device__ __forceinlin const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0; const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); #else x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 @@ -1641,11 +1643,11 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d; #else x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } #pragma unroll @@ -1658,11 +1660,11 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); #else x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -1702,11 +1704,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE - typedef mma_int_A_I16K4 mma_A; - typedef mma_int_B_J8K4 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K4 mma_A; + typedef mma_B_J8K4 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -1732,8 +1734,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { const int k0 = k00 + k01; - A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); - A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); + A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); + A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); } #pragma unroll @@ -1771,8 +1773,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( mma_B B[2]; float dB[mma_C::ne/2]; - B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K); + // Here load_generic is faster than load_ldmatrix. + B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); + B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -1784,8 +1787,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C[2]; - C[0].mma_K4(A[n][k01/4 + 0], B[0]); - C[1].mma_K4(A[n][k01/4 + 1], B[1]); + C[0].mma(A[n][k01/4 + 0], B[0]); + C[1].mma(A[n][k01/4 + 1], B[1]); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -1805,20 +1808,20 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_iq4_nl( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_NL; const int kqsx = threadIdx.x % QI4_NL; @@ -1836,13 +1839,13 @@ template static __device__ __forceinlin const int aux_q4 = get_int_b2(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL; @@ -1858,25 +1861,25 @@ template static __device__ __forceinlin const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); #else x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq2_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI2_XXS/2); @@ -1905,36 +1908,36 @@ template static __device__ __forceinlin const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = aux32 >> 28; const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; #else x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq2_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI2_XS/2); @@ -1959,38 +1962,38 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq2_s( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI2_S/2); @@ -2022,38 +2025,38 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq3_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI3_XXS/2); @@ -2080,36 +2083,36 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = aux32 >> 28; const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; #else x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq3_s( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI3_S/2); @@ -2143,36 +2146,36 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; #else x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq1_s( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % QI1_S; @@ -2198,37 +2201,37 @@ template static __device__ __forceinlin const int grid0 = (grid >> 0) & 0x0F0F0F0F; const int grid1 = (grid >> 4) & 0x0F0F0F0F; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); #else x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq4_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = 0; // threadIdx.x / QI4_XS const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS @@ -2246,13 +2249,13 @@ template static __device__ __forceinlin const int aux_q4 = get_int_b4(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } #pragma unroll @@ -2270,11 +2273,11 @@ template static __device__ __forceinlin const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -2307,16 +2310,16 @@ template static __device__ __forceinline__ void mmq_write_back_mma( const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) { - typedef mma_int_C_I16J8 mma_C; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { @@ -2505,13 +2508,13 @@ static __device__ void mul_mat_q_process_tile( int * tile_y = (int *) data_mul_mat_q; int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; constexpr mmq_write_back_t write_back = mmq_write_back_mma; #else constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE constexpr int blocks_per_iter = MMQ_ITER_K / qk; @@ -2643,7 +2646,7 @@ static __global__ void mul_mat_q( const int jt = kbc / (blocks_per_ne00*nty); const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; - constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks. + constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. mul_mat_q_process_tile (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0, it, jt, kb0_start, kb0_stop); @@ -2749,7 +2752,7 @@ template static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) { const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); - const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const int shmem_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const int shmem_y = mmq_x*sizeof(block_q8_1_mmq); return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); } diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu new file mode 100644 index 00000000000..f09bdeff79a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16); +DECL_FATTN_MMA_F16_CASE(80, 16); +DECL_FATTN_MMA_F16_CASE(96, 16); +DECL_FATTN_MMA_F16_CASE(112, 16); +DECL_FATTN_MMA_F16_CASE(128, 16); +DECL_FATTN_MMA_F16_CASE(256, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu new file mode 100644 index 00000000000..221108873a0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 32); +DECL_FATTN_MMA_F16_CASE(80, 32); +DECL_FATTN_MMA_F16_CASE(96, 32); +DECL_FATTN_MMA_F16_CASE(112, 32); +DECL_FATTN_MMA_F16_CASE(128, 32); +DECL_FATTN_MMA_F16_CASE(256, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu new file mode 100644 index 00000000000..d24b085758d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64); +DECL_FATTN_MMA_F16_CASE(80, 64); +DECL_FATTN_MMA_F16_CASE(96, 64); +DECL_FATTN_MMA_F16_CASE(112, 64); +DECL_FATTN_MMA_F16_CASE(128, 64); +DECL_FATTN_MMA_F16_CASE(256, 64); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu new file mode 100644 index 00000000000..bdf86c0eaba --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8); +DECL_FATTN_MMA_F16_CASE(80, 8); +DECL_FATTN_MMA_F16_CASE(96, 8); +DECL_FATTN_MMA_F16_CASE(112, 8); +DECL_FATTN_MMA_F16_CASE(128, 8); +DECL_FATTN_MMA_F16_CASE(256, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu deleted file mode 100644 index 2d94e65c28c..00000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +++ /dev/null @@ -1,10 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 16, float); -DECL_FATTN_WMMA_F16_CASE(80, 16, float); -DECL_FATTN_WMMA_F16_CASE(96, 16, float); -DECL_FATTN_WMMA_F16_CASE(112, 16, float); -DECL_FATTN_WMMA_F16_CASE(128, 16, float); -DECL_FATTN_WMMA_F16_CASE(256, 16, float); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu deleted file mode 100644 index c3d9df3c443..00000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +++ /dev/null @@ -1,9 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 32, float); -DECL_FATTN_WMMA_F16_CASE(80, 32, float); -DECL_FATTN_WMMA_F16_CASE(96, 32, float); -DECL_FATTN_WMMA_F16_CASE(112, 32, float); -DECL_FATTN_WMMA_F16_CASE(128, 32, float); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu deleted file mode 100644 index bb680e401f7..00000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +++ /dev/null @@ -1,10 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 16, half); -DECL_FATTN_WMMA_F16_CASE(80, 16, half); -DECL_FATTN_WMMA_F16_CASE(96, 16, half); -DECL_FATTN_WMMA_F16_CASE(112, 16, half); -DECL_FATTN_WMMA_F16_CASE(128, 16, half); -DECL_FATTN_WMMA_F16_CASE(256, 16, half); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu deleted file mode 100644 index 073f71b1f3e..00000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +++ /dev/null @@ -1,10 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 32, half); -DECL_FATTN_WMMA_F16_CASE(80, 32, half); -DECL_FATTN_WMMA_F16_CASE(96, 32, half); -DECL_FATTN_WMMA_F16_CASE(112, 32, half); -DECL_FATTN_WMMA_F16_CASE(128, 32, half); -DECL_FATTN_WMMA_F16_CASE(256, 32, half); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu deleted file mode 100644 index d30710c5fa2..00000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +++ /dev/null @@ -1,8 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 8, half); -DECL_FATTN_WMMA_F16_CASE(96, 8, half); -DECL_FATTN_WMMA_F16_CASE(128, 8, half); -DECL_FATTN_WMMA_F16_CASE(256, 8, half); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index d7874e6eaf8..a2628f16e57 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -12,13 +12,13 @@ DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v}); """ -SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. +SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. -#include "../fattn-wmma-f16.cuh" +#include "../fattn-mma-f16.cuh" """ -SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n" +SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n" TYPES_MMQ = [ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", @@ -57,20 +57,12 @@ def get_head_sizes(type_k, type_v): with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v)) -for kq_acc_t in ["half", "float"]: - for cols_per_block in [8, 16, 32]: - if kq_acc_t == "float" and cols_per_block == 8: - continue +for cols_per_block in [8, 16, 32, 64]: + with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f: + f.write(SOURCE_FATTN_MMA_START) - with open(f"fattn-wmma-f16-instance-kq{kq_acc_t}-cpb{cols_per_block}.cu", "w") as f: - f.write(SOURCE_FATTN_WMMA_START) - - for head_size in [64, 80, 96, 112, 128, 256]: - if cols_per_block == 8 and head_size % 32 != 0: # wmma fragment is 8x32 - continue - if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance - continue - f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size)) + for head_size in [64, 80, 96, 112, 128, 256]: + f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size)) for type in TYPES_MMQ: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 8594093f052..129478ed785 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -25,6 +25,7 @@ #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} +#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 #define cublasCreate hipblasCreate diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 7a877bdc11a..eb03e10fa48 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -50,7 +50,7 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu") -file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu") +file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index 415b2b2e09c..2f555416e62 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -29,7 +29,7 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu") - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu") + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_MUSA ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_MUSA ${SRCS}) From c49ee07ff413c64630db58ab08fd51e0378336f5 Mon Sep 17 00:00:00 2001 From: uvos Date: Sun, 2 Feb 2025 22:08:05 +0100 Subject: [PATCH 45/64] HIP: add GGML_CUDA_CC_IS_* for amd familys as increasing cc archtectures for amd gpus are not supersets of eatch other (llama/11601) This fixes a bug where RDNA1 gpus other than gfx1010 where not handled correctly --- ggml/src/ggml-cuda/common.cuh | 7 +++++++ ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++-- ggml/src/ggml-cuda/mmq.cu | 2 +- ggml/src/ggml-cuda/mmq.cuh | 2 +- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 88be8fc8a6e..232163c1c6f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -61,6 +61,13 @@ #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA +#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2) +#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) +#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3) +#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA) +#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1) + #define GGML_CUDA_CC_QY1 210 #define GGML_CUDA_CC_QY2 220 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 383131c7789..bda10aec118 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1205,7 +1205,7 @@ static void ggml_cuda_op_mul_mat_cublas( CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - if (compute_capability == GGML_CUDA_CC_CDNA) { + if (GGML_CUDA_CC_IS_CDNA(compute_capability)) { const float alpha = 1.0f; const float beta = 0.0f; CUBLAS_CHECK( @@ -1750,7 +1750,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co beta = &beta_f32; } - if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { + if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) { cu_compute_type = CUBLAS_COMPUTE_32F; alpha = &alpha_f32; beta = &beta_f32; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 83cb78cbdd4..45212f66c00 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -148,5 +148,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } - return (cc < GGML_CUDA_CC_RDNA3 && cc != GGML_CUDA_CC_CDNA && cc != GGML_CUDA_CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index c05c8477812..7a2c4d85b79 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -120,7 +120,7 @@ static constexpr __device__ int get_mmq_x_max_device() { } static constexpr int get_mmq_y_host(const int cc) { - return cc >= GGML_CUDA_CC_OFFSET_AMD ? (cc == GGML_CUDA_CC_RDNA1 ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64); + return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64); } static constexpr __device__ int get_mmq_y_device() { From 9906792ec34b801963743b93f244178119223413 Mon Sep 17 00:00:00 2001 From: uvos Date: Sun, 2 Feb 2025 22:40:09 +0100 Subject: [PATCH 46/64] CUDA/HIP: add support for selectable warp size to mmv (llama/11519) CUDA/HIP: add support for selectable warp size to mmv --- ggml/src/ggml-cuda/common.cuh | 8 +++++++ ggml/src/ggml-cuda/mmv.cu | 38 ++++++++++++++++++++------------ ggml/src/ggml-cuda/vendors/hip.h | 2 ++ 3 files changed, 34 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 232163c1c6f..174916bc970 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -176,6 +176,14 @@ static constexpr bool new_mma_available(const int cc) { return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING; } +static constexpr __device__ int ggml_cuda_get_physical_warp_size() { +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) + return __AMDGCN_WAVEFRONT_SIZE; +#else + return 32; +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +} + [[noreturn]] static __device__ void no_device_code( const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu index ac45f2d17f1..5a9ddd9580a 100644 --- a/ggml/src/ggml-cuda/mmv.cu +++ b/ggml/src/ggml-cuda/mmv.cu @@ -5,9 +5,10 @@ template static __global__ void mul_mat_vec( const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) { - const int64_t row = blockIdx.x; - const int64_t channel = blockIdx.z; - const int tid = threadIdx.x; + const int64_t row = blockIdx.x; + const int64_t channel = blockIdx.z; + const int tid = threadIdx.x; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); x += (channel/channel_ratio)*stride_channel_x + row*stride_row; y += channel *stride_channel_y; @@ -18,8 +19,8 @@ static __global__ void mul_mat_vec( extern __shared__ char data_mmv[]; float * buf_iw = (float *) data_mmv; - if (block_size > WARP_SIZE) { - if (tid < WARP_SIZE) { + if (block_size > warp_size) { + if (tid < warp_size) { buf_iw[tid] = 0.0f; } __syncthreads(); @@ -67,16 +68,16 @@ static __global__ void mul_mat_vec( static_assert(std::is_same::value, "unsupported type"); } - sumf = warp_reduce_sum(sumf); + sumf = warp_reduce_sum(sumf); - if (block_size > WARP_SIZE) { - buf_iw[tid/WARP_SIZE] = sumf; + if (block_size > warp_size) { + buf_iw[tid/warp_size] = sumf; __syncthreads(); - if (tid >= WARP_SIZE) { + if (tid >= warp_size) { return; } sumf = buf_iw[tid]; - sumf = warp_reduce_sum(sumf); + sumf = warp_reduce_sum(sumf); } if (tid != 0) { @@ -96,10 +97,19 @@ static void launch_mul_mat_vec_cuda( GGML_ASSERT(stride_row % 2 == 0); GGML_ASSERT(nchannels_y % nchannels_x == 0); const int64_t channel_ratio = nchannels_y / nchannels_x; + int device; + int warp_size; - int64_t block_size_best = WARP_SIZE; - int64_t niter_best = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE); - for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) { + CUDA_CHECK(cudaGetDevice(&device)); + warp_size = ggml_cuda_info().devices[device].warp_size; + + int64_t block_size_best = warp_size; + int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size); + int64_t max_block_size = 256; + if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) { + max_block_size = 128; + } + for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) { const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size); if (niter < niter_best) { niter_best = niter; @@ -107,7 +117,7 @@ static void launch_mul_mat_vec_cuda( } } - const int smem = WARP_SIZE*sizeof(float); + const int smem = warp_size*sizeof(float); const dim3 block_nums(nrows, 1, nchannels_y); const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 129478ed785..81964611c60 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -1,5 +1,6 @@ #pragma once +#define HIP_ENABLE_WARP_SYNC_BUILTINS 1 #include #include #include @@ -8,6 +9,7 @@ // for rocblas_initialize() #include "rocblas/rocblas.h" #endif // __HIP_PLATFORM_AMD__ + #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F From fad2806352c58beca08667e30b2830c9ba192932 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 2 Feb 2025 23:48:29 +0100 Subject: [PATCH 47/64] HIP: fix flash_attn_stream_k_fixup warning (llama/11604) --- ggml/src/ggml-cuda/fattn-common.cuh | 10 ++++++++++ ggml/src/ggml-cuda/softmax.cu | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index cfd7c0f4475..d40ee2da418 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -516,6 +516,12 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } +// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ + template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) @@ -614,6 +620,10 @@ static __global__ void flash_attn_stream_k_fixup( } } +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index da377200e01..aac6e099988 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -18,7 +18,7 @@ __device__ float __forceinline__ t2f32(half val) { #ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wpass-failed" -#endif +#endif // __clang__ template static __global__ void soft_max_f32( const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, @@ -126,7 +126,7 @@ static __global__ void soft_max_f32( } #ifdef __clang__ #pragma clang diagnostic pop -#endif +#endif // __clang__ static __global__ void soft_max_back_f32( const float * grad, const float * dstf, float * dst, const int ncols, const float scale) { From dbeb7916b8489bbe615907ad0a6d8f9eaf15f58e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 3 Feb 2025 13:25:56 +0100 Subject: [PATCH 48/64] CUDA: fix Volta FlashAttention logic (llama/11615) --- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 2 +- ggml/src/ggml-cuda/fattn.cu | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 1054ff95d22..45702ad651f 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -561,7 +561,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); break; // case 256: - // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + // ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); // break; default: GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index b1e66d47083..b0cf152f52c 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -235,7 +235,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - if (!new_mma_available(cc)) { + if (!fp16_mma_available(cc)) { if (prec == GGML_PREC_DEFAULT) { if (Q->ne[1] <= 8) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); @@ -265,6 +265,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // The MMA implementation needs Turing or newer, use the old WMMA code for Volta: if (cc == GGML_CUDA_CC_VOLTA) { ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + return; } ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); From 344b98a44f5494d470e55f7a0a46b459bb24e2c3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Feb 2025 16:05:27 +0200 Subject: [PATCH 49/64] scripts : fix sync paths --- scripts/sync-ggml-am.sh | 54 ++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh index f5defa75920..ed3ecd89e88 100755 --- a/scripts/sync-ggml-am.sh +++ b/scripts/sync-ggml-am.sh @@ -138,33 +138,33 @@ if [ -f $SRC_WHISPER/ggml-src.patch ]; then # ggml/scripts/gen-authors.sh -> scripts/gen-authors.sh cat ggml-src.patch | sed -E \ - -e 's/(^[[:space:]]|[ab]\/)CMakeLists.txt/\1ggml\/CMakeLists.txt/g' \ - -e 's/(^[[:space:]]|[ab]\/)src\/CMakeLists.txt/\1ggml\/src\/CMakeLists.txt/g' \ - -e 's/(^[[:space:]]|[ab]\/)cmake\/FindSIMD.cmake/\1ggml\/cmake\/FindSIMD.cmake/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \ - -e 's/([[:space:]]|[ab]\/)src\/gguf(.*)\.cpp/\1ggml\/src\/gguf\2.cpp/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-blas\//\1ggml\/src\/ggml-blas\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-cann\//\1ggml\/src\/ggml-cann\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-cpu\//\1ggml\/src\/ggml-cpu\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-cuda\//\1ggml\/src\/ggml-cuda\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-hip\//\1ggml\/src\/ggml-hip\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-kompute\//\1ggml\/src\/ggml-kompute\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-metal\//\1ggml\/src\/ggml-metal\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-musa\//\1ggml\/src\/ggml-musa\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-opencl\//\1ggml\/src\/ggml-opencl\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-rpc\//\1ggml\/src\/ggml-rpc\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-sycl\//\1ggml\/src\/ggml-sycl\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-vulkan\//\1ggml\/src\/ggml-vulkan\//g' \ - -e 's/(^[[:space:]]|[ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \ - -e 's/(^[[:space:]]|[ab]\/)include\/gguf(.*)\.h/\1ggml\/include\/gguf\2.h/g' \ - -e 's/(^[[:space:]]|[ab]\/)examples\/common\.h/\1examples\/common.h/g' \ - -e 's/(^[[:space:]]|[ab]\/)examples\/common\.cpp/\1examples\/common.cpp/g' \ - -e 's/(^[[:space:]]|[ab]\/)examples\/common-ggml\.h/\1examples\/common-ggml.h/g' \ - -e 's/(^[[:space:]]|[ab]\/)examples\/common-ggml\.cpp/\1examples\/common-ggml.cpp/g' \ - -e 's/(^[[:space:]]|[ab]\/)LICENSE/\1LICENSE/g' \ - -e 's/(^[[:space:]]|[ab]\/)scripts\/gen-authors\.sh/\1scripts\/gen-authors.sh/g' \ + -e 's/(^[[:space:]]| [ab]\/)CMakeLists.txt/\1ggml\/CMakeLists.txt/g' \ + -e 's/(^[[:space:]]| [ab]\/)src\/CMakeLists.txt/\1ggml\/src\/CMakeLists.txt/g' \ + -e 's/(^[[:space:]]| [ab]\/)cmake\/FindSIMD.cmake/\1ggml\/cmake\/FindSIMD.cmake/g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \ + -e 's/([[:space:]]| [ab]\/)src\/gguf(.*)\.cpp/\1ggml\/src\/gguf\2.cpp/g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-blas\//\1ggml\/src\/ggml-blas\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-cann\//\1ggml\/src\/ggml-cann\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-cpu\//\1ggml\/src\/ggml-cpu\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-cuda\//\1ggml\/src\/ggml-cuda\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-hip\//\1ggml\/src\/ggml-hip\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-kompute\//\1ggml\/src\/ggml-kompute\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-metal\//\1ggml\/src\/ggml-metal\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-musa\//\1ggml\/src\/ggml-musa\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-opencl\//\1ggml\/src\/ggml-opencl\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-rpc\//\1ggml\/src\/ggml-rpc\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-sycl\//\1ggml\/src\/ggml-sycl\//g' \ + -e 's/([[:space:]]| [ab]\/)src\/ggml-vulkan\//\1ggml\/src\/ggml-vulkan\//g' \ + -e 's/(^[[:space:]]| [ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \ + -e 's/(^[[:space:]]| [ab]\/)include\/gguf(.*)\.h/\1ggml\/include\/gguf\2.h/g' \ + -e 's/(^[[:space:]]| [ab]\/)examples\/common\.h/\1examples\/common.h/g' \ + -e 's/(^[[:space:]]| [ab]\/)examples\/common\.cpp/\1examples\/common.cpp/g' \ + -e 's/(^[[:space:]]| [ab]\/)examples\/common-ggml\.h/\1examples\/common-ggml.h/g' \ + -e 's/(^[[:space:]]| [ab]\/)examples\/common-ggml\.cpp/\1examples\/common-ggml.cpp/g' \ + -e 's/(^[[:space:]]| [ab]\/)LICENSE/\1LICENSE/g' \ + -e 's/(^[[:space:]]| [ab]\/)scripts\/gen-authors\.sh/\1scripts\/gen-authors.sh/g' \ > ggml-src.patch.tmp mv ggml-src.patch.tmp ggml-src.patch From edc5d9267c71e3018fc90139ece5acec0afb2ec6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Feb 2025 16:05:34 +0200 Subject: [PATCH 50/64] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 1a990ac68c3..34f1cbf69ee 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -f33c42adaf5d7fc24093f3975c162d905d8e111a +498e0ecd2c4f9379439fd413805af10e8e9ff349 From b8ab126343a7142848a6e57bb5263093624d8399 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Feb 2025 16:24:38 +0200 Subject: [PATCH 51/64] cmake : sync cmake scripts --- ggml/cmake/BuildTypes.cmake | 54 ++++++++++++ ggml/cmake/GitVars.cmake | 22 +++++ ggml/cmake/ggml-config.cmake.in | 147 ++++++++++++++++++++++++++++++++ scripts/sync-ggml.sh | 2 +- 4 files changed, 224 insertions(+), 1 deletion(-) create mode 100644 ggml/cmake/BuildTypes.cmake create mode 100644 ggml/cmake/GitVars.cmake create mode 100644 ggml/cmake/ggml-config.cmake.in diff --git a/ggml/cmake/BuildTypes.cmake b/ggml/cmake/BuildTypes.cmake new file mode 100644 index 00000000000..a9c7b6c91ec --- /dev/null +++ b/ggml/cmake/BuildTypes.cmake @@ -0,0 +1,54 @@ +# Add new build types + +# ReleaseGG - Release with enabled asserts + +SET(CMAKE_CXX_FLAGS_RELEASEGG + "-O3" + CACHE STRING "Flags used by the c++ compiler during release builds with enabled asserts." + FORCE ) +SET(CMAKE_C_FLAGS_RELEASEGG + "-O3" + CACHE STRING "Flags used by the compiler during release builds with enabled asserts." + FORCE ) +SET(CMAKE_EXE_LINKER_FLAGS_RELEASEGG + "" + CACHE STRING "Flags used for linking binaries during release builds with enabled asserts." + FORCE ) +SET(CMAKE_SHARED_LINKER_FLAGS_RELEASEGG + "" + CACHE STRING "Flags used by the shared libraries linker during release builds with enabled asserts." + FORCE ) +MARK_AS_ADVANCED( + CMAKE_CXX_FLAGS_RELEASEGG + CMAKE_C_FLAGS_RELEASEGG + CMAKE_EXE_LINKER_FLAGS_RELEASEGG + CMAKE_SHARED_LINKER_FLAGS_RELEASEGG ) + +# RelWithDebInfoGG - RelWithDebInfo with enabled asserts + +SET(CMAKE_CXX_FLAGS_RELWITHDEBINFOGG + "-O2 -g" + CACHE STRING "Flags used by the c++ compiler during release builds with debug symbols and enabled asserts." + FORCE ) +SET(CMAKE_C_FLAGS_RELWITHDEBINFOGG + "-O2 -g" + CACHE STRING "Flags used by the compiler during release builds with debug symbols and enabled asserts." + FORCE ) +SET(CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG + "" + CACHE STRING "Flags used for linking binaries during release builds with debug symbols and enabled asserts." + FORCE ) +SET(CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG + "" + CACHE STRING "Flags used by the shared libraries linker during release builds with debug symbols and enabled asserts." + FORCE ) +MARK_AS_ADVANCED( + CMAKE_CXX_FLAGS_RELWITHDEBINFOGG + CMAKE_C_FLAGS_RELWITHDEBINFOGG + CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG + CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG ) + +if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo" "ReleaseGG" "RelWithDebInfoGG") +endif() diff --git a/ggml/cmake/GitVars.cmake b/ggml/cmake/GitVars.cmake new file mode 100644 index 00000000000..1a4c24ebf6a --- /dev/null +++ b/ggml/cmake/GitVars.cmake @@ -0,0 +1,22 @@ +find_package(Git) + +# the commit's SHA1 +execute_process(COMMAND + "${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8 + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_SHA1 + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + +# the date of the commit +execute_process(COMMAND + "${GIT_EXECUTABLE}" log -1 --format=%ad --date=local + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_DATE + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + +# the subject of the commit +execute_process(COMMAND + "${GIT_EXECUTABLE}" log -1 --format=%s + WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE GIT_COMMIT_SUBJECT + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) diff --git a/ggml/cmake/ggml-config.cmake.in b/ggml/cmake/ggml-config.cmake.in new file mode 100644 index 00000000000..bf39f9c007b --- /dev/null +++ b/ggml/cmake/ggml-config.cmake.in @@ -0,0 +1,147 @@ + +@GGML_VARIABLES_EXPANDED@ + +@PACKAGE_INIT@ + +set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@") +set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@") +set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@") + +find_package(Threads REQUIRED) + +find_library(GGML_LIBRARY ggml + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) + +add_library(ggml::ggml UNKNOWN IMPORTED) +set_target_properties(ggml::ggml + PROPERTIES + IMPORTED_LOCATION "${GGML_LIBRARY}") + +find_library(GGML_BASE_LIBRARY ggml-base + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) + +add_library(ggml::ggml-base UNKNOWN IMPORTED) +set_target_properties(ggml::ggml-base + PROPERTIES + IMPORTED_LOCATION "${GGML_BASE_LIBRARY}") + +if (NOT GGML_SHARED_LIB) + if (APPLE AND GGML_ACCELERATE) + find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK}) + endif() + + if (GGML_OPENMP) + find_package(OpenMP REQUIRED) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX) + endif() + + if (GGML_CPU_HBM) + find_library(memkind memkind REQUIRED) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind) + endif() + + if (GGML_BLAS) + find_package(BLAS REQUIRED) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES}) + list(APPEND GGML_CPU_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS}) + endif() + + if (GGML_CUDA) + find_package(CUDAToolkit REQUIRED) + endif() + + if (GGML_METAL) + find_library(FOUNDATION_LIBRARY Foundation REQUIRED) + find_library(METAL_FRAMEWORK Metal REQUIRED) + find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) + + list(APPEND GGML_METAL_INTERFACE_LINK_LIBRARIES + ${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK}) + endif() + + if (GGML_VULKAN) + find_package(Vulkan REQUIRED) + list(APPEND GGML_VULKAN_INTERFACE_LINK_LIBRARIES Vulkan::Vulkan) + endif() + + if (GGML_HIP) + find_package(hip REQUIRED) + find_package(hipblas REQUIRED) + find_package(rocblas REQUIRED) + list(APPEND GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas) + endif() + + if (GGML_SYCL) + find_package(DNNL) + if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL") + list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl) + endif() + if (WIN32) + find_package(IntelSYCL REQUIRED) + find_package(MKL REQUIRED) + list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL) + endif() + endif() +endif() + +set(_ggml_all_targets "") +foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS}) + string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}") + string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx) + + find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend} + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) + + message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}") + + add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}" + INTERFACE_COMPILE_FEATURES c_std_90 + POSITION_INDEPENDENT_CODE ON) + + string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}") + if(is_cpu_variant) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base") + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}") + + if(GGML_CPU_INTERFACE_LINK_OPTIONS) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}") + endif() + + else() + list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base") + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}") + + if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}") + endif() + endif() + + list(APPEND _ggml_all_targets ggml::${_ggml_backend}) +endforeach() + +add_library(ggml::all INTERFACE IMPORTED) +set_target_properties(ggml::all + PROPERTIES + INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}") + +check_required_components(ggml) diff --git a/scripts/sync-ggml.sh b/scripts/sync-ggml.sh index 30b592d2a56..077af2c9653 100755 --- a/scripts/sync-ggml.sh +++ b/scripts/sync-ggml.sh @@ -2,7 +2,7 @@ cp -rpv ../ggml/CMakeLists.txt ./ggml/CMakeLists.txt cp -rpv ../ggml/src/CMakeLists.txt ./ggml/src/CMakeLists.txt -cp -rpv ../ggml/cmake/FindSIMD.cmake ./ggml/cmake/FindSIMD.cmake +cp -rpv ../ggml/cmake/* ./ggml/cmake/ cp -rpv ../ggml/src/ggml*.c ./ggml/src/ cp -rpv ../ggml/src/ggml*.cpp ./ggml/src/ From 234460987e04b0727092ec571e8b1918d672e3c4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Feb 2025 19:50:24 +0200 Subject: [PATCH 52/64] ci : use ubuntu-22.04 instead of ubuntu-latest --- .github/workflows/bindings-go.yml | 4 +-- .github/workflows/bindings-ruby.yml | 4 +-- .github/workflows/build.yml | 48 ++++++++++++++--------------- .github/workflows/docker.yml | 2 +- .github/workflows/examples.yml | 4 +-- 5 files changed, 31 insertions(+), 31 deletions(-) diff --git a/.github/workflows/bindings-go.yml b/.github/workflows/bindings-go.yml index 9808cf0f573..ff420f2b636 100644 --- a/.github/workflows/bindings-go.yml +++ b/.github/workflows/bindings-go.yml @@ -10,8 +10,8 @@ on: - whisper.h jobs: - ubuntu-latest: - runs-on: ubuntu-latest + ubuntu-22: + runs-on: ubuntu-22.04 steps: - uses: actions/setup-go@v5 with: diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml index 0b427429da9..94ccf835e44 100644 --- a/.github/workflows/bindings-ruby.yml +++ b/.github/workflows/bindings-ruby.yml @@ -42,8 +42,8 @@ on: - examples/dr_wav.h jobs: - ubuntu-latest: - runs-on: ubuntu-latest + ubuntu-22: + runs-on: ubuntu-22.04 defaults: run: working-directory: bindings/ruby diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2e33e2a87de..2b880d34b3d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -16,8 +16,8 @@ env: VCPKG_BINARY_SOURCES: "clear;x-gha,readwrite" jobs: - ubuntu-latest: - runs-on: ubuntu-latest + ubuntu-22: + runs-on: ubuntu-22.04 strategy: fail-fast: false @@ -42,8 +42,8 @@ jobs: cmake -B build cmake --build build --config Release -j $(nproc)' - ubuntu-latest-arm64: - runs-on: ubuntu-latest + ubuntu-22-arm64: + runs-on: ubuntu-22.04 strategy: fail-fast: false @@ -68,8 +68,8 @@ jobs: cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a cmake --build build --config Release -j $(nproc)' - ubuntu-latest-arm-v7: - runs-on: ubuntu-latest + ubuntu-22-arm-v7: + runs-on: ubuntu-22.04 strategy: fail-fast: false @@ -129,8 +129,8 @@ jobs: # cmake -B build # cmake --build build --config Release - ubuntu-latest-gcc: - runs-on: ubuntu-latest + ubuntu-22-gcc: + runs-on: ubuntu-22.04 strategy: fail-fast: false @@ -157,8 +157,8 @@ jobs: make ctest -L gh --output-on-failure' - ubuntu-latest-gcc-arm64: - runs-on: ubuntu-latest + ubuntu-22-gcc-arm64: + runs-on: ubuntu-22.04 strategy: fail-fast: false @@ -185,8 +185,8 @@ jobs: make ctest -L gh --output-on-failure' - ubuntu-latest-gcc-arm-v7: - runs-on: ubuntu-latest + ubuntu-22-gcc-arm-v7: + runs-on: ubuntu-22.04 strategy: fail-fast: false @@ -213,8 +213,8 @@ jobs: make ctest -L gh --output-on-failure' - ubuntu-latest-clang: - runs-on: ubuntu-latest + ubuntu-22-clang: + runs-on: ubuntu-22.04 strategy: fail-fast: false @@ -244,8 +244,8 @@ jobs: make ctest -L gh --output-on-failure' - ubuntu-latest-gcc-sanitized: - runs-on: ubuntu-latest + ubuntu-22-gcc-sanitized: + runs-on: ubuntu-22.04 strategy: fail-fast: false @@ -584,7 +584,7 @@ jobs: 7z x sdl2.zip echo "SDL2_DIR=${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" | Out-File -FilePath $env:GITHUB_ENV -Append echo "${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" > SDL2_PATH.txt - + - name: Configure CMake shell: cmd run: | @@ -594,16 +594,16 @@ jobs: -DCMAKE_CUDA_ARCHITECTURES=all ^ -DWHISPER_SDL2=${{ matrix.sdl2 }} ^ -DSDL2_DIR="%SDL2_DIR%" - + - name: Build Project shell: cmd run: | cd ./build - cmake --build . --config ${{ matrix.build }} + cmake --build . --config ${{ matrix.build }} - name: Copy CUDA DLLs run: | - Get-ChildItem "${{ steps.cuda-toolkit.outputs.CUDA_PATH }}/bin/" -Filter "*.dll" | + Get-ChildItem "${{ steps.cuda-toolkit.outputs.CUDA_PATH }}/bin/" -Filter "*.dll" | Copy-Item -Destination "build/bin/${{ matrix.build }}" - name: Copy SDL2.dll @@ -617,7 +617,7 @@ jobs: path: build/bin/${{ matrix.build }} emscripten: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 strategy: matrix: @@ -684,7 +684,7 @@ jobs: run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' build android: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Clone @@ -714,7 +714,7 @@ jobs: # TODO: disable because of following fail: https://github.com/ggerganov/whisper.cpp/actions/runs/11019444420/job/30627193602 # android_java: -# runs-on: ubuntu-latest +# runs-on: ubuntu-22.04 # # steps: # - name: Clone @@ -783,7 +783,7 @@ jobs: # PGP_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} quantize: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Clone diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index ad4282fd09c..30887e3248e 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -11,7 +11,7 @@ jobs: name: Push Docker image to Docker Hub if: github.event.pull_request.draft == false - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 env: COMMIT_SHA: ${{ github.sha }} strategy: diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 808dd18c0b7..efc37d2bd91 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -10,8 +10,8 @@ on: - whisper.h jobs: - addon_node-ubuntu-latest: - runs-on: ubuntu-latest + addon_node-ubuntu-22: + runs-on: ubuntu-22.04 strategy: matrix: node-version: [ 16.x, 18.x ] From e0f4cef867c83ea068be3f8fd185f73db62b56e3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Feb 2025 20:12:37 +0200 Subject: [PATCH 53/64] ci : install git --- .github/workflows/build.yml | 22 +++++++++++----------- .github/workflows/examples.yml | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2b880d34b3d..2b3f8fef53f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -38,7 +38,7 @@ jobs: -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' set -e apt update - apt install -y build-essential libsdl2-dev cmake + apt install -y build-essential libsdl2-dev cmake git cmake -B build cmake --build build --config Release -j $(nproc)' @@ -64,7 +64,7 @@ jobs: -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' set -e apt update - apt install -y build-essential libsdl2-dev cmake + apt install -y build-essential libsdl2-dev cmake git cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a cmake --build build --config Release -j $(nproc)' @@ -90,7 +90,7 @@ jobs: -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' set -e apt update - apt install -y build-essential libsdl2-dev cmake + apt install -y build-essential libsdl2-dev cmake git cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp cmake --build build --config Release -j $(nproc)' @@ -152,7 +152,7 @@ jobs: -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' set -e apt update - apt install -y build-essential cmake libsdl2-dev + apt install -y build-essential cmake libsdl2-dev git cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} make ctest -L gh --output-on-failure' @@ -180,7 +180,7 @@ jobs: -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' set -e apt update - apt install -y build-essential cmake libsdl2-dev + apt install -y build-essential cmake libsdl2-dev git cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a make ctest -L gh --output-on-failure' @@ -208,7 +208,7 @@ jobs: -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' set -e apt update - apt install -y build-essential cmake libsdl2-dev + apt install -y build-essential cmake libsdl2-dev git cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp make ctest -L gh --output-on-failure' @@ -239,7 +239,7 @@ jobs: -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' set -e apt update - apt install -y clang build-essential cmake libsdl2-dev + apt install -y clang build-essential cmake libsdl2-dev git cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang make ctest -L gh --output-on-failure' @@ -267,7 +267,7 @@ jobs: -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' set -e apt update - apt install -y build-essential cmake + apt install -y build-essential cmake git cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON make ctest -L gh --output-on-failure' @@ -302,12 +302,12 @@ jobs: shell: bash run: | sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp + sudo apt install intel-oneapi-compiler-dpcpp-cpp git - name: install oneAPI MKL library shell: bash run: | - sudo apt install intel-oneapi-mkl-devel + sudo apt install intel-oneapi-mkl-devel git - name: Clone id: checkout @@ -352,7 +352,7 @@ jobs: shell: bash run: | sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp + sudo apt install intel-oneapi-compiler-dpcpp-cpp git - name: install oneAPI MKL library shell: bash diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index efc37d2bd91..74ef8e0faae 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -22,7 +22,7 @@ jobs: - name: Dependencies run: | sudo apt-get update - sudo apt-get install build-essential + sudo apt-get install build-essential git sudo apt-get install cmake sudo apt-get install libsdl2-dev From 90e3c5fc4070154b4d3c16b2f1548f8ae54ceaa8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Feb 2025 21:17:33 +0200 Subject: [PATCH 54/64] ci : more git --- .devops/cublas.Dockerfile | 2 +- .devops/main-cuda.Dockerfile | 4 ++-- .devops/main.Dockerfile | 4 ++-- .github/workflows/build.yml | 1 + 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.devops/cublas.Dockerfile b/.devops/cublas.Dockerfile index e7e51273a0f..1fae25635d7 100644 --- a/.devops/cublas.Dockerfile +++ b/.devops/cublas.Dockerfile @@ -12,7 +12,7 @@ FROM ${BASE_CUDA_DEV_CONTAINER} as build ARG CUDA_DOCKER_ARCH=all RUN apt-get update && \ - apt-get install -y build-essential git cmake libsdl2-dev wget + apt-get install -y build-essential git cmake libsdl2-dev wget git WORKDIR /app diff --git a/.devops/main-cuda.Dockerfile b/.devops/main-cuda.Dockerfile index a806fa09bde..75a395c70f2 100644 --- a/.devops/main-cuda.Dockerfile +++ b/.devops/main-cuda.Dockerfile @@ -17,7 +17,7 @@ ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH} ENV GGML_CUDA=1 RUN apt-get update && \ - apt-get install -y build-essential libsdl2-dev wget cmake \ + apt-get install -y build-essential libsdl2-dev wget cmake git \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* # Ref: https://stackoverflow.com/a/53464012 @@ -33,7 +33,7 @@ ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH WORKDIR /app RUN apt-get update && \ - apt-get install -y curl ffmpeg wget cmake \ + apt-get install -y curl ffmpeg wget cmake git \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* COPY --from=build /app /app diff --git a/.devops/main.Dockerfile b/.devops/main.Dockerfile index 066d4c83796..e8424126057 100644 --- a/.devops/main.Dockerfile +++ b/.devops/main.Dockerfile @@ -2,7 +2,7 @@ FROM ubuntu:22.04 AS build WORKDIR /app RUN apt-get update && \ - apt-get install -y build-essential wget cmake \ + apt-get install -y build-essential wget cmake git \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* COPY .. . @@ -12,7 +12,7 @@ FROM ubuntu:22.04 AS runtime WORKDIR /app RUN apt-get update && \ - apt-get install -y curl ffmpeg libsdl2-dev wget cmake \ + apt-get install -y curl ffmpeg libsdl2-dev wget cmake git \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* COPY --from=build /app /app diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2b3f8fef53f..38cd795ec59 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -393,6 +393,7 @@ jobs: msystem: ${{matrix.sys}} install: >- base-devel + git mingw-w64-${{matrix.env}}-toolchain mingw-w64-${{matrix.env}}-cmake mingw-w64-${{matrix.env}}-SDL2 From cff8868b5f4adb8b9db4ed949d8661110242c82a Mon Sep 17 00:00:00 2001 From: mgrachten Date: Mon, 3 Feb 2025 21:36:32 +0100 Subject: [PATCH 55/64] coreml : always convert to "neuralnetwork" (#2770) --- models/convert-whisper-to-coreml.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/convert-whisper-to-coreml.py b/models/convert-whisper-to-coreml.py index f66683e9b21..441efdd2d6f 100644 --- a/models/convert-whisper-to-coreml.py +++ b/models/convert-whisper-to-coreml.py @@ -245,7 +245,7 @@ def convert_encoder(hparams, model, quantize=False): model = ct.convert( traced_model, - convert_to=None if quantize else "mlprogram", # convert will fail if weights are quantized, not sure why + convert_to="neuralnetwork", inputs=[ct.TensorType(name="logmel_data", shape=input_shape)], outputs=[ct.TensorType(name="output")], compute_units=ct.ComputeUnit.ALL @@ -268,7 +268,7 @@ def convert_decoder(hparams, model, quantize=False): model = ct.convert( traced_model, - convert_to=None if quantize else "mlprogram", # convert will fail if weights are quantized, not sure why + convert_to="neuralnetwork", inputs=[ ct.TensorType(name="token_data", shape=tokens_shape, dtype=int), ct.TensorType(name="audio_data", shape=audio_shape) From 3f91832352b2aca890102dc7ebc182f7fa095151 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 3 Feb 2025 22:42:26 +0200 Subject: [PATCH 56/64] talk-llama : sync llama.cpp --- examples/talk-llama/llama-arch.cpp | 9 +- examples/talk-llama/llama-arch.h | 4 +- examples/talk-llama/llama-chat.cpp | 13 +- examples/talk-llama/llama-chat.h | 1 + examples/talk-llama/llama-grammar.cpp | 92 ++++- examples/talk-llama/llama-grammar.h | 23 +- examples/talk-llama/llama-mmap.cpp | 1 + examples/talk-llama/llama-model-loader.cpp | 76 +++- examples/talk-llama/llama-model-loader.h | 7 +- examples/talk-llama/llama-model.cpp | 81 ++++- examples/talk-llama/llama-model.h | 2 - examples/talk-llama/llama-quant.cpp | 3 +- examples/talk-llama/llama-sampling.cpp | 51 ++- examples/talk-llama/llama-vocab.cpp | 21 +- examples/talk-llama/llama.cpp | 393 +++++++++++---------- examples/talk-llama/llama.h | 30 +- examples/talk-llama/unicode.cpp | 5 +- 17 files changed, 581 insertions(+), 231 deletions(-) diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index d7d277e7297..97a1e7e5e01 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -179,6 +179,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, + { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" }, { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, @@ -1023,6 +1024,9 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, @@ -1443,10 +1447,11 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, }; -LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {} +LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} std::string LLM_KV::operator()(llm_kv kv) const { - return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); + return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix) + : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); } std::string LLM_TN_IMPL::str() const { diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 34984479045..122fdcebe0a 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -177,6 +177,7 @@ enum llm_kv { LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_CHAT_TEMPLATE, + LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_MID_ID, @@ -335,9 +336,10 @@ enum llm_tensor_layer { }; struct LLM_KV { - LLM_KV(llm_arch arch); + LLM_KV(llm_arch arch, const char * suffix = nullptr); llm_arch arch; + const char * suffix; std::string operator()(llm_kv kv) const; }; diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index 1347ec1560b..028a6479484 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -51,6 +51,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 }, { "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 }, + { "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE }, { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, @@ -115,7 +116,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) { return LLM_CHAT_TEMPLATE_PHI_3; } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { - return LLM_CHAT_TEMPLATE_FALCON_3; + return tmpl_contains("") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE; } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) { return LLM_CHAT_TEMPLATE_ZEPHYR; } else if (tmpl_contains("bos_token + message['role']")) { @@ -152,7 +153,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_MINICPM; } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { return LLM_CHAT_TEMPLATE_DEEPSEEK_2; - } else if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) { + } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) { return LLM_CHAT_TEMPLATE_DEEPSEEK_3; } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) { // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb @@ -440,6 +441,14 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|assistant|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) { + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n" << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) { // MiniCPM-3B-OpenHermes-2.5-v2-GGUF for (auto message : chat) { diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index 3a4d07ce3de..2f6a0e3e282 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -31,6 +31,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_LLAMA_3, LLM_CHAT_TEMPLATE_CHATGML_3, LLM_CHAT_TEMPLATE_CHATGML_4, + LLM_CHAT_TEMPLATE_GLMEDGE, LLM_CHAT_TEMPLATE_MINICPM, LLM_CHAT_TEMPLATE_EXAONE_3, LLM_CHAT_TEMPLATE_RWKV_WORLD, diff --git a/examples/talk-llama/llama-grammar.cpp b/examples/talk-llama/llama-grammar.cpp index bebe4e9a320..9b518d1ac64 100644 --- a/examples/talk-llama/llama-grammar.cpp +++ b/examples/talk-llama/llama-grammar.cpp @@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) { } } } catch (const std::exception & err) { - fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src); rules.clear(); return false; } @@ -960,10 +960,28 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; + return new llama_grammar { + vocab, + std::move(vec_rules), + std::move(stacks), + /* .partial_utf8 = */ {}, + /* .lazy =*/ false, + /* .awaiting_trigger = */ false, + /* .trigger_buffer = */ "", + /* .trigger_tokens = */ {}, + /* .trigger_words = */ {}, + }; } -struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { llama_grammar_parser parser; // if there is a grammar, parse it @@ -1035,10 +1053,31 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, } } while (true); + std::vector vec_trigger_tokens; + std::vector vec_trigger_words; + for (size_t i = 0; i < num_trigger_tokens; i++) { + GGML_ASSERT(trigger_tokens != nullptr); + vec_trigger_tokens.push_back(trigger_tokens[i]); + } + for (size_t i = 0; i < num_trigger_words; i++) { + GGML_ASSERT(trigger_words != nullptr); + vec_trigger_words.push_back(trigger_words[i]); + } + // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; + return new llama_grammar { + vocab, + std::move(vec_rules), + std::move(stacks), + /* .partial_utf8 = */ {}, + /* .lazy = */ lazy, + /* .awaiting_trigger = */ lazy, + /* .trigger_buffer = */ "", + std::move(vec_trigger_tokens), + std::move(vec_trigger_words), + }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { @@ -1055,6 +1094,11 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra grammar.rules, grammar.stacks, grammar.partial_utf8, + grammar.lazy, + grammar.awaiting_trigger, + grammar.trigger_buffer, + grammar.trigger_tokens, + grammar.trigger_words, }; // redirect elements in stacks to point to new rules @@ -1076,6 +1120,10 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { GGML_ASSERT(grammar.vocab != nullptr); + if (grammar.awaiting_trigger) { + return; + } + bool allow_eog = false; for (const auto & stack : grammar.stacks) { if (stack.empty()) { @@ -1115,6 +1163,34 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { GGML_ASSERT(grammar.vocab != nullptr); + const auto & piece = grammar.vocab->token_to_piece(token); + + if (grammar.awaiting_trigger) { + if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { + grammar.awaiting_trigger = false; + grammar.trigger_buffer.clear(); + llama_grammar_accept_str(grammar, piece); + LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); + return; + } else { + // TODO: consider a smarter incremental substring search algorithm (store last position to search from). + grammar.trigger_buffer += piece; + for (const auto & word : grammar.trigger_words) { + auto pos = grammar.trigger_buffer.find(word); + if (pos != std::string::npos) { + grammar.awaiting_trigger = false; + auto constrained_str = grammar.trigger_buffer.substr(pos); + grammar.trigger_buffer.clear(); + llama_grammar_accept_str(grammar, constrained_str); + LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str()); + return; + } + } + LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str()); + return; + } + } + if (grammar.vocab->is_eog(token)) { for (const auto & stack : grammar.stacks) { if (stack.empty()) { @@ -1124,8 +1200,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token GGML_ABORT("fatal error"); } - const std::string & piece = grammar.vocab->token_to_piece(token); + llama_grammar_accept_str(grammar, piece); +} +void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) { // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; @@ -1135,5 +1213,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token } grammar.partial_utf8 = decoded.second; - GGML_ASSERT(!grammar.stacks.empty()); + if (grammar.stacks.empty()) { + throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece); + } } diff --git a/examples/talk-llama/llama-grammar.h b/examples/talk-llama/llama-grammar.h index f8b40c65157..252d54d4c9f 100644 --- a/examples/talk-llama/llama-grammar.h +++ b/examples/talk-llama/llama-grammar.h @@ -114,6 +114,15 @@ struct llama_grammar { // buffer for partially generated UTF-8 sequence from accepted tokens llama_partial_utf8 partial_utf8; + + // lazy grammars wait for trigger words or tokens before constraining the sampling. + // we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens. + // (useful e.g. for tool_choice=required) + bool lazy = false; + bool awaiting_trigger = false; // Initialized to true for lazy grammars only + std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. + std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). + std::vector trigger_words; }; // @@ -127,7 +136,15 @@ struct llama_grammar * llama_grammar_init_impl( size_t n_rules, size_t start_rule_index); -struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); void llama_grammar_free_impl(struct llama_grammar * grammar); @@ -141,3 +158,7 @@ void llama_grammar_apply_impl( void llama_grammar_accept_impl( struct llama_grammar & grammar, llama_token token); + +void llama_grammar_accept_str( + struct llama_grammar & grammar, + const std::string & piece); diff --git a/examples/talk-llama/llama-mmap.cpp b/examples/talk-llama/llama-mmap.cpp index 57c6e4f510f..b716630a8c2 100644 --- a/examples/talk-llama/llama-mmap.cpp +++ b/examples/talk-llama/llama-mmap.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #ifdef __has_include #if __has_include() diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 53175f0e069..05d58ad90eb 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -64,6 +64,33 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { } } +// return a list of splits for a given path +// for example, given "-00002-of-00004.gguf", returns list of all 4 splits +static std::vector llama_get_list_splits(const std::string & path, const int idx, const int n_split) { + std::vector paths; + std::string split_prefix; + std::vector buf(llama_path_max(), 0); + + { + int ret = llama_split_prefix(buf.data(), buf.size(), path.c_str(), idx, n_split); + if (!ret) { + throw std::runtime_error(format("invalid split file name: %s", path.c_str())); + } + split_prefix = std::string(buf.data(), ret); + } + + if (split_prefix.empty()) { + throw std::runtime_error(format("invalid split file: %s", path.c_str())); + } + + for (int idx = 0; idx < n_split; ++idx) { + int ret = llama_split_path(buf.data(), buf.size(), split_prefix.c_str(), idx, n_split); + paths.push_back(std::string(buf.data(), ret)); + } + + return paths; +} + namespace GGUFMeta { template struct GKV_Base_Type { @@ -413,7 +440,12 @@ namespace GGUFMeta { template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); -llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) { +llama_model_loader::llama_model_loader( + const std::string & fname, + std::vector & splits, + bool use_mmap, + bool check_tensors, + const struct llama_model_kv_override * param_overrides_p) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -425,6 +457,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, } } + // Load the main GGUF struct ggml_context * ctx = NULL; struct gguf_init_params params = { /*.no_alloc = */ true, @@ -460,35 +493,54 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, // Load additional GGML contexts if (n_split > 1) { + // make sure the main file is loaded first uint16_t idx = 0; - get_key(llm_kv(LLM_KV_SPLIT_NO), idx); + const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO); + get_key(kv_split_no, idx); if (idx != 0) { - throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx)); + throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str())); + } + + // generate list of splits if needed + if (splits.empty()) { + splits = llama_get_list_splits(fname, idx, n_split); } - std::vector split_prefix(llama_path_max(), 0); - if (!llama_split_prefix(split_prefix.data(), split_prefix.size(), fname.c_str(), idx, n_split)) { - throw std::runtime_error(format("invalid split file: %s", fname.c_str())); + // in case user give a custom list of splits, check if it matches the expected number + if (n_split != (uint16_t)splits.size()) { + throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split)); } if (trace > 0) { LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); } - std::vector split_path(llama_path_max(), 0); + // load other splits for (idx = 1; idx < n_split; idx++) { - llama_split_path(split_path.data(), split_path.size(), split_prefix.data(), idx, n_split); + const char * fname_split = splits[idx].c_str(); struct gguf_init_params split_params = { /*.no_alloc = */ true, /*.ctx = */ &ctx, }; - gguf_context_ptr ctx_gguf { gguf_init_from_file(split_path.data(), split_params) }; + gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) }; if (!ctx_gguf) { - throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path.data())); + throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split)); + } + + // check idx + { + const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str()); + if (kid < 0) { + throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split)); + } + int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid); + if (idx_gguf != idx) { + throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx)); + } } - files.emplace_back(new llama_file(split_path.data(), "rb")); + files.emplace_back(new llama_file(fname_split, "rb")); contexts.emplace_back(ctx); // Save tensors data offset info of the shard. @@ -767,7 +819,7 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps for (const auto & file : files) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); - std::unique_ptr mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, is_numa_fn())); + std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? -1 : 0, is_numa_fn()); mmaps_used.emplace_back(mapping->size(), 0); if (mlock_mmaps) { std::unique_ptr mlock_mmap(new llama_mlock()); diff --git a/examples/talk-llama/llama-model-loader.h b/examples/talk-llama/llama-model-loader.h index b63d158d982..fe35404b268 100644 --- a/examples/talk-llama/llama-model-loader.h +++ b/examples/talk-llama/llama-model-loader.h @@ -90,7 +90,12 @@ struct llama_model_loader { size_t size_data = 0; std::vector> mmaps_used; - llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p); + llama_model_loader( + const std::string & fname, + std::vector & splits, // optional, only need if the split does not follow naming scheme + bool use_mmap, + bool check_tensors, + const struct llama_model_kv_override * param_overrides_p); template typename std::enable_if::value, bool>::type diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index f90f5e74607..0487c978b5e 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -1093,8 +1093,20 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 28: type = LLM_TYPE_6B; break; - case 40: type = LLM_TYPE_9B; break; + case 28: { + if (hparams.n_head(0) == 16) { + type = LLM_TYPE_1_5B; + } else { + type = LLM_TYPE_6B; + } + } break; + case 40: { + if (hparams.n_head(0) == 24) { + type = LLM_TYPE_4B; + } else { + type = LLM_TYPE_9B; + } + } break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -1303,10 +1315,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1); auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(cpu_dev)); return {cpu_dev, &pimpl->cpu_buft_list}; } const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); auto * dev = devices.at(layer_gpu); + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(dev)); return {dev, &pimpl->gpu_buft_list.at(dev)}; }; @@ -2203,6 +2217,50 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } } break; + case LLM_ARCH_PHIMOE: + { + const int64_t n_embd_head = n_embd / n_head; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + } break; case LLM_ARCH_PLAMO: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3022,9 +3080,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + } layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); @@ -3717,7 +3783,6 @@ struct llama_model_params llama_model_default_params() { /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, /*.main_gpu =*/ 0, /*.tensor_split =*/ nullptr, - /*.rpc_servers =*/ nullptr, /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, /*.kv_overrides =*/ nullptr, @@ -3912,8 +3977,10 @@ uint64_t llama_model_size(const struct llama_model * model) { return model->size(); } -const char * llama_model_chat_template(const struct llama_model * model) { - const auto & it = model->gguf_kv.find(LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE)); +const char * llama_model_chat_template(const struct llama_model * model, const char * name) { + const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) + : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE); + const auto & it = model->gguf_kv.find(key); if (it == model->gguf_kv.end()) { return nullptr; } diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 4cc8abb753a..a7c30444786 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -323,8 +323,6 @@ struct llama_model { // gguf metadata std::unordered_map gguf_kv; - std::vector rpc_servers; - // list of devices used in this model std::vector devices; diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index d4947a780c1..fb7982655a3 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -526,7 +526,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: kv_overrides = v->data(); } - llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides); + std::vector splits = {}; + llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); diff --git a/examples/talk-llama/llama-sampling.cpp b/examples/talk-llama/llama-sampling.cpp index b3a12386e8a..26974f53965 100644 --- a/examples/talk-llama/llama-sampling.cpp +++ b/examples/talk-llama/llama-sampling.cpp @@ -1433,13 +1433,30 @@ static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token } } +// Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle. +static struct llama_sampler * llama_sampler_init_grammar_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); + static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { auto * ctx = (llama_sampler_grammar *) smpl->ctx; if (!ctx->grammar) { return; } - auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str()); + std::vector trigger_words; + for (auto & word : ctx->grammar->trigger_words) { + trigger_words.push_back(word.c_str()); + } + auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), + ctx->grammar->lazy, trigger_words.data(), trigger_words.size(), + ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); llama_grammar_free_impl(ctx->grammar); ctx->grammar = grammar_new; @@ -1448,7 +1465,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; - auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr); + auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0); // copy the state { @@ -1484,7 +1501,15 @@ static struct llama_sampler_i llama_sampler_grammar_i = { /* .free = */ llama_sampler_grammar_free, }; -struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { +static struct llama_sampler * llama_sampler_init_grammar_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { auto * ctx = new llama_sampler_grammar; if (grammar_str != nullptr && grammar_str[0] != '\0') { @@ -1492,7 +1517,7 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * voc /* .vocab = */ vocab, /* .grammar_str = */ grammar_str, /* .grammar_root = */ grammar_root, - /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root), + /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens), }; } else { *ctx = { @@ -1509,6 +1534,24 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * voc }; } +struct llama_sampler * llama_sampler_init_grammar( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0); +} + +struct llama_sampler * llama_sampler_init_grammar_lazy( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens); +} + // penalties struct llama_sampler_penalties { diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 96b74e93a51..ad9ffe66aa7 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -439,7 +439,7 @@ struct llm_tokenizer_bpe_session { "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " "Are you sure this is what you want?\n", __FUNCTION__); } - if (vocab.get_add_bos() && output.size() >= 2 && *(output.end()-2) == vocab.token_eos()) { + if (vocab.get_add_eos() && output.size() >= 2 && *(output.end()-2) == vocab.token_eos()) { LLAMA_LOG_WARN( "%s: Added a EOS token to the prompt as specified by the model but the prompt " "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. " @@ -1245,8 +1245,13 @@ struct llama_vocab::impl { std::vector cache_special_tokens; std::vector cache_token_to_piece; // llama_token_to_piece(special = true); - - std::map, int> bpe_ranks; + struct pair_hash { + size_t operator()(const std::pair & p) const { + return std::hash{}(p.first) ^ //create some hash for pair + (std::hash{}(p.second) << 1); + } + }; + std::unordered_map, int, pair_hash> bpe_ranks; // set of all tokens that cause "end of generation" std::set special_eog_ids; @@ -1356,8 +1361,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // read vocab size from metadata uint32_t n_tokens = 0; - if (!ml.get_key(LLM_KV_VOCAB_SIZE, n_tokens, false)) { - LLAMA_LOG_WARN("%s: there is no vocab_size in metadata\n", __func__); + if (ml.get_key(LLM_KV_VOCAB_SIZE, n_tokens, false)) { + LLAMA_LOG_WARN("%s: adding %u dummy tokens\n", __func__, n_tokens); + id_to_token.resize(n_tokens); } return; @@ -1522,7 +1528,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; clean_spaces = false; } else if ( - tokenizer_pre == "qwen2") { + tokenizer_pre == "qwen2" || + tokenizer_pre == "deepseek-r1-qwen") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; } else if ( @@ -1685,7 +1692,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); linefeed_id = ids[0]; } else { - const std::vector ids = tokenize("\xC4\x8A", false); // U+010A + const std::vector ids = tokenize("\n", false); //GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); if (ids.empty()) { diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index daf1b7c97cd..5760017e0d9 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -31,7 +31,7 @@ #endif // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback -static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) { +static int llama_model_load(const std::string & fname, std::vector & splits, llama_model & model, llama_model_params & params) { // loading time will be recalculated after the first eval, so // we take page faults deferred by mmap() into consideration model.t_load_us = 0; @@ -40,7 +40,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam model.t_start_us = tm.t_start_us; try { - llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides); + llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides); ml.print_info(); @@ -4642,7 +4642,7 @@ struct llm_build_context { 0); cb(v_states, "v_states", il); - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this q_pe = ggml_rope_ext( ctx0, q_pe, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -4651,7 +4651,7 @@ struct llm_build_context { cb(q_pe, "q_pe", il); // shared RoPE key - k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this k_pe = ggml_rope_ext( ctx0, k_pe, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -6496,7 +6496,7 @@ struct llm_build_context { 0); cb(v_states, "v_states", il); - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this q_pe = ggml_rope_ext( ctx0, q_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -6505,7 +6505,7 @@ struct llm_build_context { cb(q_pe, "q_pe", il); // shared RoPE key - k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this k_pe = ggml_rope_ext( ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -7215,17 +7215,30 @@ struct llm_build_context { struct ggml_tensor * Qcur = nullptr; struct ggml_tensor * Kcur = nullptr; struct ggml_tensor * Vcur = nullptr; - - cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - + if (model.type == LLM_TYPE_1_5B || model.type == LLM_TYPE_4B || model.type == LLM_TYPE_9B) { + Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + } else { + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); @@ -7700,17 +7713,13 @@ struct llm_build_context { 1 ); + struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att)); ggml_build_forward_expand( gf, ggml_cpy( ctx0, - wkv_states, - ggml_view_1d( - ctx0, - kv_self.v_l[il], - hparams.n_embd_v_s() * n_seqs, - hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il]) - ) + ggml_view_1d(ctx0, last_norm_att, n_embd * n_seqs, 0), + ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il])) ) ); @@ -8432,74 +8441,33 @@ static enum ggml_status llama_graph_compute( return status; } -// decode a batch of tokens by evaluating the transformer -// in case of unsuccessful decoding (error or warning), -// the kv_cache state will be returned to its original state -// (for non-recurrent models) or cleaned (for recurrent models) -// -// - lctx: llama context -// - batch: batch to evaluate -// -// return 0 on success -// return positive int on warning -// return negative int on error -// -static int llama_decode_impl( - llama_context & lctx, - llama_batch inp_batch) { - - lctx.is_encoding = false; - - if (inp_batch.n_tokens == 0) { - LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); - return -1; - } - - // temporary allocate memory for the input batch if needed - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1); - - const llama_batch & batch = batch_allocr.batch; - const uint32_t n_tokens_all = batch.n_tokens; - +static int llama_prepare_sbatch( + llama_context & lctx, + const llama_batch & batch, + uint32_t & n_outputs) { const auto & model = lctx.model; - const auto & vocab = model.vocab; const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; - GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + const uint32_t n_tokens_all = batch.n_tokens; + const int64_t n_embd = hparams.n_embd; + + // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens + const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT if (batch.token) { for (uint32_t i = 0; i < n_tokens_all; ++i) { - if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { + if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) { LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); return -1; } } } - GGML_ASSERT(n_tokens_all <= cparams.n_batch); - GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens"); - if (lctx.t_compute_start_us == 0) { - lctx.t_compute_start_us = ggml_time_us(); - } lctx.n_queued_tokens += n_tokens_all; - - auto & kv_self = lctx.kv_self; - llama_kv_slot_restorer kv_slot_restorer(kv_self); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_vocab = vocab.n_tokens(); - - uint32_t n_outputs = 0; - uint32_t n_outputs_prev = 0; - - const auto n_ubatch = cparams.n_ubatch; - - // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens - const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; - lctx.embd_seq.clear(); // count outputs @@ -8515,7 +8483,7 @@ static int llama_decode_impl( } lctx.sbatch.from_batch(batch, n_embd, - /* simple_split */ !kv_self.recurrent, + /* simple_split */ !lctx.kv_self.recurrent, /* logits_all */ n_outputs == n_tokens_all); // reserve output buffer @@ -8524,70 +8492,148 @@ static int llama_decode_impl( return -2; }; - while (lctx.sbatch.n_tokens > 0) { - llama_ubatch ubatch; - if (kv_self.recurrent) { - if (embd_pooled) { - // Pooled embeddings cannot be split across ubatches (yet) - ubatch = lctx.sbatch.split_seq(n_ubatch); - } else { - // recurrent model architectures are easier to implement - // with equal-length sequences - ubatch = lctx.sbatch.split_equal(n_ubatch); - } + return 0; +} + +static int llama_prepare_ubatch( + llama_context & lctx, + llama_kv_slot_restorer & kv_slot_restorer, + llama_ubatch & ubatch, + const uint32_t n_outputs, + const uint32_t n_tokens_all) { + GGML_ASSERT(lctx.sbatch.n_tokens > 0); + + auto & kv_self = lctx.kv_self; + const auto & cparams = lctx.cparams; + const auto & hparams = lctx.model.hparams; + + // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens + const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; + + if (lctx.kv_self.recurrent) { + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = lctx.sbatch.split_seq(cparams.n_ubatch); } else { - ubatch = lctx.sbatch.split_simple(n_ubatch); + // recurrent model architectures are easier to implement + // with equal-length sequences + ubatch = lctx.sbatch.split_equal(cparams.n_ubatch); } - const uint32_t n_tokens = ubatch.n_tokens; + } else { + ubatch = lctx.sbatch.split_simple(cparams.n_ubatch); + } - // count the outputs in this u_batch - { - int32_t n_outputs_new = 0; + // count the outputs in this u_batch + { + int32_t n_outputs_new = 0; - if (n_outputs == n_tokens_all) { - n_outputs_new = n_tokens; - } else { - GGML_ASSERT(ubatch.output); - for (uint32_t i = 0; i < n_tokens; i++) { - n_outputs_new += (int32_t) (ubatch.output[i] != 0); - } + if (n_outputs == n_tokens_all) { + n_outputs_new = ubatch.n_tokens; + } else { + GGML_ASSERT(ubatch.output); + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + n_outputs_new += int32_t(ubatch.output[i] != 0); } + } + + // needs to happen before the graph is built + lctx.n_outputs = n_outputs_new; + } + + // non-causal masks do not use the KV cache + if (hparams.causal_attn) { + llama_kv_cache_update(&lctx); - // needs to happen before the graph is built - lctx.n_outputs = n_outputs_new; + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) { + kv_self.head = 0; } - int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; - ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch; + const auto slot = llama_kv_cache_find_slot(kv_self, ubatch); + if (!slot) { + return 1; + } + kv_slot_restorer.save(slot); - GGML_ASSERT(n_threads > 0); + if (!kv_self.recurrent) { + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + const uint32_t pad = llama_kv_cache_get_padding(cparams); + kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad))); + //kv_self.n = llama_kv_cache_cell_max(kv_self); + } + } - // non-causal masks do not use the KV cache - if (hparams.causal_attn) { - llama_kv_cache_update(&lctx); + return 0; +} - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self.head > kv_self.used + 2*n_tokens) { - kv_self.head = 0; - } +// decode a batch of tokens by evaluating the transformer +// in case of unsuccessful decoding (error or warning), +// the kv_cache state will be returned to its original state +// (for non-recurrent models) or cleaned (for recurrent models) +// +// - lctx: llama context +// - inp_batch: batch to evaluate +// +// return 0 on success +// return positive int on warning +// return negative int on error +// +static int llama_decode_impl( + llama_context & lctx, + llama_batch inp_batch) { - const auto slot = llama_kv_cache_find_slot(kv_self, ubatch); - if (!slot) { - return 1; - } - kv_slot_restorer.save(slot); + lctx.is_encoding = false; + + if (inp_batch.n_tokens == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); + return -1; + } + + // temporarily allocate memory for the input batch if needed + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1); + const llama_batch & batch = batch_allocr.batch; + + const auto & model = lctx.model; + const auto & vocab = model.vocab; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; - if (!kv_self.recurrent) { - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - const uint32_t pad = llama_kv_cache_get_padding(cparams); - kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad))); - //kv_self.n = llama_kv_cache_cell_max(kv_self); + if (lctx.t_compute_start_us == 0) { + lctx.t_compute_start_us = ggml_time_us(); + } + auto & kv_self = lctx.kv_self; + llama_kv_slot_restorer kv_slot_restorer(kv_self); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_vocab = vocab.n_tokens(); + + uint32_t n_outputs = 0; + uint32_t n_outputs_prev = 0; + + { + const int ret = llama_prepare_sbatch(lctx, batch, n_outputs); + if (ret != 0) { + return ret; + } + } + + while (lctx.sbatch.n_tokens > 0) { + llama_ubatch ubatch; + { + const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens); + if (ret != 0) { + return ret; } } + const int n_threads = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; + ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch; + + GGML_ASSERT(n_threads > 0); + //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); ggml_backend_sched_reset(lctx.sched.get()); @@ -8640,7 +8686,7 @@ static int llama_decode_impl( // update the kv ring buffer { - kv_self.head += n_tokens; + kv_self.head += ubatch.n_tokens; // Ensure kv cache head points to a valid index. if (kv_self.head >= kv_self.size) { @@ -9374,14 +9420,9 @@ int64_t llama_time_us(void) { return ggml_time_us(); } -struct llama_model * llama_load_model_from_file( - const char * path_model, - struct llama_model_params params) { - return llama_model_load_from_file(path_model, params); -} - -struct llama_model * llama_model_load_from_file( - const char * path_model, +static struct llama_model * llama_model_load_from_file_impl( + const std::string & path_model, + std::vector & splits, struct llama_model_params params) { ggml_time_init(); @@ -9404,53 +9445,13 @@ struct llama_model * llama_model_load_from_file( }; } - if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') { - // split the servers set them into model->rpc_servers - std::string servers(params.rpc_servers); - size_t pos = 0; - while ((pos = servers.find(',')) != std::string::npos) { - std::string server = servers.substr(0, pos); - model->rpc_servers.push_back(server); - servers.erase(0, pos + 1); - } - model->rpc_servers.push_back(servers); - } - - // add RPC devices - if (!model->rpc_servers.empty()) { - ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC"); - if (!rpc_reg) { - LLAMA_LOG_ERROR("%s: failed to find RPC backend\n", __func__); - llama_model_free(model); - return nullptr; - } - - typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint); - ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); - if (!ggml_backend_rpc_add_device_fn) { - LLAMA_LOG_ERROR("%s: failed to find RPC device add function\n", __func__); - llama_model_free(model); - return nullptr; - } - - for (const std::string & server : model->rpc_servers) { - ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str()); - if (dev) { - model->devices.push_back(dev); - } else { - LLAMA_LOG_ERROR("%s: failed to add RPC device for server '%s'\n", __func__, server.c_str()); - llama_model_free(model); - return nullptr; - } - } - } - // create list of devices to use with this model if (params.devices) { for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) { model->devices.push_back(*dev); } } else { + std::vector rpc_servers; // use all available devices for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { ggml_backend_dev_t dev = ggml_backend_dev_get(i); @@ -9461,10 +9462,19 @@ struct llama_model * llama_model_load_from_file( break; case GGML_BACKEND_DEVICE_TYPE_GPU: - model->devices.push_back(dev); + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + if (ggml_backend_reg_name(reg) == std::string("RPC")) { + rpc_servers.push_back(dev); + } else { + model->devices.push_back(dev); + } break; } } + // add RPC servers at the front of the list + if (!rpc_servers.empty()) { + model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end()); + } } // if using single GPU mode, remove all except the main GPU @@ -9485,7 +9495,7 @@ struct llama_model * llama_model_load_from_file( LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024); } - const int status = llama_model_load(path_model, *model, params); + const int status = llama_model_load(path_model, splits, *model, params); GGML_ASSERT(status <= 0); if (status < 0) { if (status == -1) { @@ -9501,6 +9511,35 @@ struct llama_model * llama_model_load_from_file( return model; } +// deprecated +struct llama_model * llama_load_model_from_file( + const char * path_model, + struct llama_model_params params) { + return llama_model_load_from_file(path_model, params); +} + +struct llama_model * llama_model_load_from_file( + const char * path_model, + struct llama_model_params params) { + std::vector splits = {}; + return llama_model_load_from_file_impl(path_model, splits, params); +} + +struct llama_model * llama_model_load_from_splits( + const char ** paths, + size_t n_paths, + struct llama_model_params params) { + std::vector splits; + if (n_paths == 0) { + LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__); + return nullptr; + } + for (size_t i = 0; i < n_paths; ++i) { + splits.push_back(paths[i]); + } + return llama_model_load_from_file_impl(splits.front(), splits, params); +} + struct llama_context * llama_init_from_model( struct llama_model * model, struct llama_context_params params) { diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index a184884c77a..61907ed404d 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -288,9 +288,6 @@ extern "C" { // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() const float * tensor_split; - // comma separated list of RPC servers to use for offloading - const char * rpc_servers; - // Called with a progress value between 0.0 and 1.0. Pass NULL to disable. // If the provided progress_callback returns true, model loading continues. // If it returns false, model loading is immediately aborted. @@ -418,10 +415,20 @@ extern "C" { struct llama_model_params params), "use llama_model_load_from_file instead"); + // Load the model from a file + // If the file is split into multiple parts, the file name must follow this pattern: -%05d-of-%05d.gguf + // If the split file name does not follow this pattern, use llama_model_load_from_splits LLAMA_API struct llama_model * llama_model_load_from_file( const char * path_model, struct llama_model_params params); + // Load the model from multiple splits (support custom naming scheme) + // The paths must be in the correct order + LLAMA_API struct llama_model * llama_model_load_from_splits( + const char ** paths, + size_t n_paths, + struct llama_model_params params); + DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), "use llama_model_free instead"); @@ -503,7 +510,8 @@ extern "C" { LLAMA_API uint64_t llama_model_size(const struct llama_model * model); // Get the default chat template. Returns nullptr if not available - LLAMA_API const char * llama_model_chat_template(const struct llama_model * model); + // If name is NULL, returns the default chat template + LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name); // Returns the total number of parameters in the model LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); @@ -951,7 +959,7 @@ extern "C" { LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab); LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab); - DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocabable_get_text instead"); + DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead"); DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead"); DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead"); DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead"); @@ -1191,6 +1199,18 @@ extern "C" { const char * grammar_str, const char * grammar_root); + /// @details Lazy grammar sampler, introduced in https://github.com/ggerganov/llama.cpp/pull/9639 + /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future. + /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. + LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); + /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first. LLAMA_API struct llama_sampler * llama_sampler_init_penalties( int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index 7aca6544bc7..89180da4152 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -7,18 +7,17 @@ #include #include +#include #include #include +#include #include #include #include #include #include -#include #include #include -#include -#include size_t unicode_len_utf8(char src) { const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; From cadfc50eabb46829a0d5ac7ffcb3778ad60a1257 Mon Sep 17 00:00:00 2001 From: billyct Date: Tue, 4 Feb 2025 04:49:06 +0800 Subject: [PATCH 57/64] node : add max_len params in node addon (#2760) --- examples/addon.node/addon.cpp | 2 ++ examples/addon.node/index.js | 1 + 2 files changed, 3 insertions(+) diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 4ada6ca5084..656bfe3f1dc 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -330,6 +330,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { bool no_timestamps = whisper_params.Get("no_timestamps").As(); int32_t audio_ctx = whisper_params.Get("audio_ctx").As(); bool comma_in_time = whisper_params.Get("comma_in_time").As(); + int32_t max_len = whisper_params.Get("max_len").As(); Napi::Value pcmf32Value = whisper_params.Get("pcmf32"); std::vector pcmf32_vec; @@ -352,6 +353,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { params.audio_ctx = audio_ctx; params.pcmf32 = pcmf32_vec; params.comma_in_time = comma_in_time; + params.max_len = max_len; Napi::Function callback = info[1].As(); Worker* worker = new Worker(callback, params); diff --git a/examples/addon.node/index.js b/examples/addon.node/index.js index 643ee756452..65fa31f8784 100644 --- a/examples/addon.node/index.js +++ b/examples/addon.node/index.js @@ -18,6 +18,7 @@ const whisperParams = { translate: true, no_timestamps: false, audio_ctx: 0, + max_len: 0, }; const arguments = process.argv.slice(2); From eb9e5032c4f91feb6323642ae700a56f505b3980 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 4 Feb 2025 09:30:08 +0200 Subject: [PATCH 58/64] ci : add stalebot --- close-issue.yml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 close-issue.yml diff --git a/close-issue.yml b/close-issue.yml new file mode 100644 index 00000000000..276a217d450 --- /dev/null +++ b/close-issue.yml @@ -0,0 +1,28 @@ +name: Close inactive issues +on: + schedule: + - cron: "42 0 * * *" + +# Fine-grant permission +# https://docs.github.com/en/actions/security-for-github-actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token +permissions: + issues: write + +jobs: + close-issues: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - uses: actions/stale@v5 + with: + exempt-issue-labels: "refactor,help wanted,good first issue,research,bug,roadmap" + days-before-issue-stale: 30 + days-before-issue-close: 14 + stale-issue-label: "stale" + close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." + days-before-pr-stale: -1 + days-before-pr-close: -1 + operations-per-run: 10000 + repo-token: ${{ secrets.GITHUB_TOKEN }} From 898c0cb9d11bf25376d02d943b1887aff1a2342f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 4 Feb 2025 10:50:10 +0200 Subject: [PATCH 59/64] readme : add maintenance roadmap --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 62afad47b3a..14609866fc6 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,9 @@ [![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) +> [!NOTE] +> New maintenance roadmap: https://github.com/ggerganov/whisper.cpp/discussions/2788 + Stable: [v1.7.4](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.7.4) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: From 16245b35e4e8c511eb4760d3475760fd3eaf425e Mon Sep 17 00:00:00 2001 From: Christian Kastner Date: Tue, 4 Feb 2025 00:17:15 +0100 Subject: [PATCH 60/64] cmake: Add ability to pass in GGML_BUILD_NUMBER (ggml/1096) This makes git as a dependency optional, and is useful in the case where ggml is built not from git, but from a tarball, or a distribution source package. This conditional also affects GGML_BUILD_COMMIT. Nothing seems to be using it, though, so there doesn't seem much value factor it out, or even require it. --- ggml/CMakeLists.txt | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 7c069e42037..75b5ea3b439 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -274,22 +274,25 @@ endif() # Generate version info based on git commit. -find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH) -execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - OUTPUT_VARIABLE GGML_BUILD_NUMBER - OUTPUT_STRIP_TRAILING_WHITESPACE -) - -if(GGML_BUILD_NUMBER EQUAL 1) - message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.") +if(NOT DEFINED GGML_BUILD_NUMBER) + find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH) + execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE GGML_BUILD_NUMBER + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + + if(GGML_BUILD_NUMBER EQUAL 1) + message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.") + endif() + + execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE GGML_BUILD_COMMIT + OUTPUT_STRIP_TRAILING_WHITESPACE + ) endif() -execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - OUTPUT_VARIABLE GGML_BUILD_COMMIT - OUTPUT_STRIP_TRAILING_WHITESPACE -) # Capture variables prefixed with GGML_. From dbcc669e1ae46d737dfab06ffb907676b3776583 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 4 Feb 2025 13:03:09 +0200 Subject: [PATCH 61/64] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 34f1cbf69ee..26a105f64f6 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -498e0ecd2c4f9379439fd413805af10e8e9ff349 +694244a6e40dc255f6bb4376fb17431c06633e6c From 33ea03f1315b25eacdcc16bf4b8ab33c90137bc7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 4 Feb 2025 13:03:40 +0200 Subject: [PATCH 62/64] authors : update --- AUTHORS | 211 +++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 210 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 33e6c9649c7..f523e0a7224 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,34 +1,51 @@ -# date: Tue Apr 9 20:27:03 EEST 2024 +# date: Tue Feb 4 13:03:35 EET 2025 # this file is auto-generated by scripts/gen-authors.sh 0/0 0cc4m 0xsourcecode <134374803+0xsourcecode@users.noreply.github.com> +65a <10104049+65a@users.noreply.github.com> +AIWintermuteAI <32562299+AIWintermuteAI@users.noreply.github.com> AT Aarni Koskela Aaron Pham <29749331+aarnphm@users.noreply.github.com> Aaron Taylor Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com> Abitofevrything <54505189+abitofevrything@users.noreply.github.com> +Adam Jones +Adrien Gallouët +Adrien Gallouët AfryMask Ahmad Bilal +Ahmad Tameem <113388789+Tameem-10xE@users.noreply.github.com> AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com> +AidanBeltonS +Akarshan Biswas +Akarshan Biswas Akash Mahajan Akash Mahajan Al Hoang <3811822-hoanga@users.noreply.gitlab.com> Alan +Albert Jin +Alberto Cabrera Pérez +Alberto Cabrera Pérez Aleksander Andrzejewski <18704749+aleksanderandrzejewski@users.noreply.github.com> Alex Azarov Alex Bacart <13940752+alex-bacart@users.noreply.github.com> Alex Evgrashin +Alex O'Connell <35843486+acon96@users.noreply.github.com> Alexandr Graschenkov Alexandru Mariuti Alexey Kharlamov Alfredo Montesinos Ali Alameh +Alter <0x7c48@gmail.com> Ananta Bastola +Andreas Kieslinger <47689530+aendk@users.noreply.github.com> +Andreas Lubbe Andreu Huguet Andrew Huynh +Andrew Minh Nguyen <40281306+amqdn@users.noreply.github.com> Andrew S Andy Maloney Anton Kostin @@ -40,8 +57,11 @@ AustinMroz Avik Sengupta Bader-eddine Ouaich <49657842+baderouaich@users.noreply.github.com> Baffin Lee +Ben Ashbaugh Ben Nortier Benjamin Heiniger +Bernhard M. Wiedemann +Binozo <70137898+Binozo@users.noreply.github.com> Bo-Yi Wu Boris Bliznioukov Borislav Stanimirov @@ -49,47 +69,86 @@ Brad Murray <59848399+bradmurray-dt@users.noreply.github.com> Brian Murray CRD716 Canis Lupus +Carlos Zoido Carolinabanana <140120812+Carolinabanana@users.noreply.github.com> +CarterLi999 <664681047@qq.com> ChangSeok Oh +Changyeon Kim Chaoqun <27287694+OpenWaygate@users.noreply.github.com> +Charles Xu <63788048+chaxu01@users.noreply.github.com> +Charles Xu +Chen Xi +Chen Xi +Chenguang Li <87689256+noemotiovon@users.noreply.github.com> Chia-Hsiang Cheng <88014292+garychia@users.noreply.github.com> Chidi Williams +Chris Elrod Christian <12550267+iceychris@users.noreply.github.com> +Christian Kastner Clifford Heath +Clint Herron Colin +Conrad Kramer +Corey Earwood +CrispStrobe <154636388+CrispStrobe@users.noreply.github.com> +DAN™ DGdev91 Damian Czaja +Dan Johansson <164997844+eddnjjn@users.noreply.github.com> +Dan Johansson Daniel Bevenius +Daniel Valdivia <18384552+dvaldivia@users.noreply.github.com> +Daniel Ziegenberg +Daniele <57776841+daniandtheweb@users.noreply.github.com> +Dave +Dave Airlie +Dave Airlie +Daven Sanassy David David Thorpe +DavidKorczynski Davidson Francis Dener Stassun +Dibakar Gope Didzis Gosko +Diego Devesa Digipom Dimo +Djip007 <3705339+Djip007@users.noreply.github.com> +Djip007 Dody Suria Wijaya +Dou Xinpeng <15529241576@163.com> +Dou Xinpeng <81913537+Dou-Git@users.noreply.github.com> Dr. Tom Murphy VII Ph.D <499244+tom7@users.noreply.github.com> Duncan McConnell Egor Egorov Elkana Bardugo Emmanuel Schmidbauer Engininja2 <139037756+Engininja2@users.noreply.github.com> +Eric Curtin Eric Swanson Eric Tendian +Eric Zhang <34133756+EZForever@users.noreply.github.com> Erik Scholz Evan Jones Evan Martin Eve <139727413+netrunnereve@users.noreply.github.com> Evgeny Kuznetsov F1L1P <78918286+F1L1Pv2@users.noreply.github.com> +Faisal Zaghloul Fangjun Kuang Felix Finn Voorhees +FirstTimeEZ <179362031+FirstTimeEZ@users.noreply.github.com> FlippFuzz <41221030+FlippFuzz@users.noreply.github.com> +Frankie Robertson Gang Chen Gavin Cai George Hindle Georgi Gerganov +Gilad S <7817232+giladgd@users.noreply.github.com> +Gilad S +Gilad S. <7817232+giladgd@users.noreply.github.com> GitAritron <103900385+GitAritron@users.noreply.github.com> GiviMAD Gleicon Moraes @@ -98,41 +157,66 @@ Guillaume Wenzek HY. Kelvin Lee <34256578+hykelvinlee42@users.noreply.github.com> Halalaluyafail3 <55773281+Halalaluyafail3@users.noreply.github.com> Hang +Haus1 Herman Semenov +HimariO +Hong Bo PENG Hrishikesh Barman +Hugo Ian Bicking Ian Bull +Ihar Hrachyshka Ikko Ashimine +Ikko Eltociear Ashimine InconsolableCellist <23345188+InconsolableCellist@users.noreply.github.com> Ismatulla Mansurov <47342870+sapoepsilon@users.noreply.github.com> +Ivan +Ivan Filipov <159561759+vanaka11@users.noreply.github.com> Ivan Gorin +Ivo von Putzer Reibegg JJ <103335846+computerscienceiscool@users.noreply.github.com> Jack Mousseau JacobLinCool Jakub Ráček Jared Van Bortel Jay Binks +Jayant +Jeff Bolz +Jeroen Mostert Jhen-Jie Hong Jhen-Jie Hong JidongZhang-THU <1119708529@qq.com> Jo Liss +Joe Todd Johan Johannes Gäßler John Balis +JohnnyB Jonathan Soo Jonno <1160532+razodactyl@users.noreply.github.com> Joonas Pihlajamaa Jose <34888496+Jerry-Master@users.noreply.github.com> Josh Bleecher Snyder +Josscii Judd Jumper775 <78500318+jumpers775@users.noreply.github.com> +Jun Hee Yoo +Junil Kim +Justina Cho Justine Tunney +Justine Tunney +KITAITI Makoto KP Kaiser Kamilake +Karol Kontny <82021046+kkontny@users.noreply.github.com> +Karthick Kartik Saranathan <278928+Kartiku@users.noreply.github.com> Kasumi <90275229+kasumi-1@users.noreply.github.com> Kawrakow <48489457+ikawrakow@users.noreply.github.com> +Kendrick Taylor Kevin Brothaler +Kevin Gibbons +Konosuke Sakai Konstantin Zhuravlyov Kreijstal Kylin <56434533+KyL0N@users.noreply.github.com> @@ -147,56 +231,110 @@ Luis Herrera Lukas Rist M. A. Ali <73258591+MightyStud@users.noreply.github.com> M. Eren Akbiyik +Ma Mingfei Maciek +Mahesh Madhav <67384846+heshpdx@users.noreply.github.com> Marcin Mielniczuk +Mark Karpelès +Mark Zhuang +Markus Tavenrath +Martin Delille Martin Warnaar +Masaya, Kato <62578291+msy-kato@users.noreply.github.com> Matheus de Sousa <23645013+keyehzy@users.noreply.github.com> +Mathieu Baudier Mathijs de Bruin Matija Pevec +Matt Stephenson +Max Krasnyansky +Max Krasnyansky Maximiliano Levi <8160966+maxilevi@users.noreply.github.com> Meng, Hengyu +Mengqing Cao Michael Podvitskiy Michael Rienstra Mikhail Grigorev Mohammadreza Hendiani Mohit Agarwal +Molly Sophia Murilo Santana +NETZkultur GmbH +Natsu Neil Chudleigh +Neo Zhang <14088817+arthw@users.noreply.github.com> Neo Zhang Jianyu Neuman Vong +Nicholai Tukanov Nicholas Albion +Nico Bosshard +Nicolò Scipione Niels Mayer +Nikita Sarychev <42014488+sARY77@users.noreply.github.com> +Nikolaj Olsson Okabintaro <103938900+Okabintaro@users.noreply.github.com> Oleg Sidorov Oleg Sidorov +Olivier Chafik Ondrej Kokes Ouadie EL FAROUKI +PAB Paul Tsochantaris +Pedro Probst +Peng +Peter Philipp Zabel Philippe Normand +Philippe Normand +Plamen Minev +Prashant Vithule <119530321+Vithulep@users.noreply.github.com> Przemysław Pawełczyk Qianhe Chen <54462604+chenqianhe@users.noreply.github.com> +R0CKSTAR +R0CKSTAR +Radoslav Gerganov Radosław Gryta +Rahul Vadhyar <107788610+RahulVadhyar@users.noreply.github.com> +Raiya Araki <83504221+rai62@users.noreply.github.com> Reinforce-II Reinis Muiznieks RelatedTitle +Rémy Oudompheng RhinoDevel Rich Jones +Robert Ormandi <52251610+ormandi@users.noreply.github.com> Robin Roddur Dasgupta Roland Rabien +Romain Biessy +Ronsor Rotem Dan Ryan Hitchman Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> RyanChang +SRHMorris <69468379+SRHMorris@users.noreply.github.com> +SXX +Sacha Arbonel +Salman Faroz +Salvatore Mesoraca Sam <49637763+Onlyartist9@users.noreply.github.com> Sam Pullara +Samuel Durante <44513615+samueldurantes@users.noreply.github.com> Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> +Sandro Hanea <40202887+sandrohanea@users.noreply.github.com> +Sergio López Sergio López +Shanshan Shen <467638484@qq.com> +Shijie <821898965@qq.com> +Shupei Fan Siddharth Ramakrishnan +Sigbjørn Skjæret Simon Moisselin Sindre Sorhus Slava Primenko +Srihari-mcw <96763064+Srihari-mcw@users.noreply.github.com> +Stavros Panakakis <53979866+Stavrospanakakis@users.noreply.github.com> +Stefan Sydow +Stefan Sydow Syahmi Azhar Syed Jafri Sơn Phan Trung @@ -205,37 +343,63 @@ Takeshi Inoue Tamotsu Takahashi Taras Glek Tauseef Mohiuddin <35351464+tauseefmohammed2@users.noreply.github.com> +Thamster Thijs Raymakers Thomas Fitzsimmons Tiago Fassoni Tienshiao Ma +Tim Miller Timothy Cronin <40186632+4imothy@users.noreply.github.com> Tobrun Todd +Toliver Tong Li <31761981+litongjava@users.noreply.github.com> +Tony Wasserka <4840017+neobrain@users.noreply.github.com> Topping1 <78745143+Topping1@users.noreply.github.com> Travis Cline UEXTM.com <84163508+uextm@users.noreply.github.com> +UsernamesLame <156965854+UsernamesLame@users.noreply.github.com> Vadim Peretokin Valentin Gosu <1454649+valenting@users.noreply.github.com> +Vin Misra Vulcan <93451215+trholding@users.noreply.github.com> WhiteOlivierus <36532695+WhiteOlivierus@users.noreply.github.com> +William Tambellini +William Tambellini +Wilson Silva Xiang (Kevin) Li Xiao-Yong Jin XiaotaoChen +Xingchen Song(宋星辰) +Xinpeng Dou <81913537+Dou-Git@users.noreply.github.com> +Xuan Son Nguyen Yajing Tang Yang Shen Yunès +Yuri Khrustalev +Yusuf Redžić <48274562+redzic@users.noreply.github.com> ZaBlazzingZephyrus <119159668+blazingzephyr@users.noreply.github.com> +Zhenwei Jin <109658203+kylo5aby@users.noreply.github.com> +Zhiyuan Li +Zhiyuan Li Zigfrid Zvezdin Zollner <24618122+Zolliner@users.noreply.github.com> +a3sh <38979186+A3shTnT@users.noreply.github.com> +ag2s20150909 <19373730+ag2s20150909@users.noreply.github.com> +agray3 ai-at-home <149282006+ai-at-home@users.noreply.github.com> +aldorof alonfaraj +amd-dwang +amritahs-ibm andypayne ardfork <134447697+ardfork@users.noreply.github.com> +arizhih <40765267+arizhih@users.noreply.github.com> automaticcat +bandoti <141645996+bandoti@users.noreply.github.com> be-next bert hubert +billyct bmwl bobqianic <129547291+bobqianic@users.noreply.github.com> bocytko @@ -248,7 +412,9 @@ byte-6174 <88070277+byte-6174@users.noreply.github.com> cdosoftei clach04 compilade <113953597+compilade@users.noreply.github.com> +compilade conradg +crummyh ddpasa <112642920+ddpasa@users.noreply.github.com> denersc dscripka @@ -256,28 +422,55 @@ duthils ecneladis faker fitzsim +fj-y-saito <85871716+fj-y-saito@users.noreply.github.com> fraxy-v <65565042+fraxy-v@users.noreply.github.com> genevera (she/her) geniusnut +gilbertgong +gn64 +goldwaving <77494627+goldwaving@users.noreply.github.com> greeshmay +haopeng <657407891@qq.com> +hipudding +hsinhoyeh hydai iamthad +issixx <46835150+issixx@users.noreply.github.com> james wolf +jdomke <28772296+jdomke@users.noreply.github.com> +jettoblack +jiez <373447296@qq.com> joecryptotoo <80373433+joecryptotoo@users.noreply.github.com> jorismertz <35079666+jorismertz@users.noreply.github.com> +junchao-loongson <68935141+junchao-loongson@users.noreply.github.com> junkfood <69683722+JunkFood02@users.noreply.github.com> jwijffels +k.h.lai kamranjon katsu560 kennethge <57784063+kenneth-ge@users.noreply.github.com> keyehzy +kunnis +l3utterfly leejet +leo-pony +lhez litong <31761981+litongjava@users.noreply.github.com> +liuwei-git <14815172+liuwei-git@users.noreply.github.com> lnyan +luoyu-intel m.bell +mahorozte <41834471+mahorozte@users.noreply.github.com> +mashizora <30516315+mashizora@users.noreply.github.com> +matt23654 +matteo +mgrachten mkiol +mky_coder <47767389+mkycoder@users.noreply.github.com> novag <7754358+novag@users.noreply.github.com> pajowu +pengxin99 +petterreinholdtsen polarmoon <90010972+polarmoon@users.noreply.github.com> rlapray sandrohanea <40202887+sandrohanea@users.noreply.github.com> @@ -287,15 +480,31 @@ shikokuchuo <53399081+shikokuchuo@users.noreply.github.com> slaren slashlib snadampal <87143774+snadampal@users.noreply.github.com> +someone13574 <81528246+someone13574@users.noreply.github.com> st-gr <38470677+st-gr@users.noreply.github.com> +stduhpf +stormofice <58337328+stormofice@users.noreply.github.com> texmex76 <40733439+texmex76@users.noreply.github.com> thefinaldegree +thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> +toboil-features <160222185+toboil-features@users.noreply.github.com> trixirt ulatekh undef +uvos +uvos +valVk venkr vicalloy +wangshuai09 <391746016@qq.com> +woachk <24752637+woachk@users.noreply.github.com> +xctan xdrudis +yuri@FreeBSD +zhangjixiong +zhentaoyu zhouwg <6889919+zhouwg@users.noreply.github.com> +zhouwg +谢乃闻 布客飞龙 <562826179@qq.com> Артём Земляк From 46d07b9c8506246bea4f2939470acfbbc63ce675 Mon Sep 17 00:00:00 2001 From: midnight Date: Wed, 5 Feb 2025 04:41:10 -0800 Subject: [PATCH 63/64] cmake : fix compile assumptions for power9/etc (#2777) * Add small comment re: VSX to readme Co-authored-by: midnight --- README.md | 16 +++++++++++++++- ggml/src/ggml-cpu/CMakeLists.txt | 18 +++++++----------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 14609866fc6..9748969c30f 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp - Plain C/C++ implementation without dependencies - Apple Silicon first-class citizen - optimized via ARM NEON, Accelerate framework, Metal and [Core ML](#core-ml-support) - AVX intrinsics support for x86 architectures -- VSX intrinsics support for POWER architectures +- [VSX intrinsics support for POWER architectures](#power-vsx-intrinsics) - Mixed F16 / F32 precision - [Integer quantization support](#quantization) - Zero memory allocations at runtime @@ -139,6 +139,20 @@ make -j large-v3-turbo | medium | 1.5 GiB | ~2.1 GB | | large | 2.9 GiB | ~3.9 GB | +## POWER VSX Intrinsics + +`whisper.cpp` supports POWER architectures and includes code which +significantly speeds operation on Linux running on POWER9/10, making it +capable of faster-than-realtime transcription on underclocked Raptor +Talos II. Ensure you have a BLAS package installed, and replace the +standard cmake setup with: + +```bash +# build with GGML_BLAS defined +cmake -B build -DGGML_BLAS=1 +cmake --build build --config Release +./build/bin/whisper-cli [ .. etc .. ] + ## Quantization `whisper.cpp` supports integer quantization of the Whisper `ggml` models. diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 6b3641c4263..26533e512ae 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -279,19 +279,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") message(STATUS "PowerPC detected") - execute_process(COMMAND bash -c "grep POWER10 /proc/cpuinfo | head -n 1" OUTPUT_VARIABLE POWER10_M) - string(FIND "${POWER10_M}" "POWER10" substring_index) - if (NOT DEFINED substring_index OR "${substring_index}" STREQUAL "") - set(substring_index -1) - endif() - - if (${substring_index} GREATER_EQUAL 0) - list(APPEND ARCH_FLAGS -mcpu=power10) + execute_process(COMMAND bash -c "grep POWER /proc/cpuinfo | head -n 1" OUTPUT_VARIABLE POWER_M) + if (${POWER_M} MATCHES "POWER10") + list(APPEND ARCH_FLAGS -mcpu=power10) + elseif (${POWER_M} MATCHES "POWER9") + list(APPEND ARCH_FLAGS -mcpu=power9) elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") - list(APPEND ARCH_FLAGS -mcpu=powerpc64le) + list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native) else() - list(APPEND ARCH_FLAGS -mcpu=native -mtune=native) - # TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) + list(APPEND ARCH_FLAGS -mcpu=powerpc64 -mtune=native) endif() elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") message(STATUS "loongarch64 detected") From d682e150908e10caa4c15883c633d7902d385237 Mon Sep 17 00:00:00 2001 From: Judd Date: Thu, 6 Feb 2025 15:37:21 +0800 Subject: [PATCH 64/64] Fixes for Windows (#2790) Fixes for Windows: * MSVC default to utf-8 without BOM. * Console output code page changed to utf-8. --------- Co-authored-by: Judd --- cmake/build-info.cmake | 2 ++ examples/cli/cli.cpp | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/cmake/build-info.cmake b/cmake/build-info.cmake index ea3dc55c834..b293c9b5d83 100644 --- a/cmake/build-info.cmake +++ b/cmake/build-info.cmake @@ -42,6 +42,8 @@ endif() if(MSVC) set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}") set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME}) + add_compile_options("$<$:/utf-8>") + add_compile_options("$<$:/utf-8>") else() execute_process( COMMAND sh -c "$@ --version | head -1" _ ${CMAKE_C_COMPILER} diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 57fbf5bff57..36ce26827ce 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -12,6 +12,11 @@ #include #include +#if defined(_WIN32) +#define NOMINMAX +#include +#endif + #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif @@ -916,6 +921,13 @@ static bool output_lrc(struct whisper_context * ctx, const char * fname, const w static void cb_log_disable(enum ggml_log_level , const char * , void * ) { } int main(int argc, char ** argv) { +#if defined(_WIN32) + // Set the console output code page to UTF-8, while command line arguments + // are still encoded in the system's code page. In this way, we can print + // non-ASCII characters to the console, and access files with non-ASCII paths. + SetConsoleOutputCP(CP_UTF8); +#endif + whisper_params params; // If the only argument starts with "@", read arguments line-by-line