Skip to content

Commit

Permalink
metal : add POOL2D and fix IM2COL (#9943)
Browse files Browse the repository at this point in the history
* add pool_2d

Signed-off-by: Junhee Yoo <[email protected]>

* fix im2col and add unittest for N>=1024

Signed-off-by: Junhee Yoo <[email protected]>

* add tests for N % 1024 != 0

Signed-off-by: Junhee Yoo <[email protected]>

* remove trailing whitespaces

Signed-off-by: Junhee Yoo <[email protected]>

* apply suggestions

Signed-off-by: Junhee Yoo <[email protected]>

* apply more optimization

- original IM2COL kernel + _ext with MIN()

Signed-off-by: Junhee Yoo <[email protected]>

* apply review: change kernel name of pool_2d

Signed-off-by: Junhee Yoo <[email protected]>

* apply review

Signed-off-by: Junhee Yoo <[email protected]>

* fix more formatting and enhance readability

Signed-off-by: Junhee Yoo <[email protected]>

---------

Signed-off-by: Junhee Yoo <[email protected]>
  • Loading branch information
junhee-yoo authored Oct 23, 2024
1 parent 873279b commit 4c9388f
Show file tree
Hide file tree
Showing 3 changed files with 299 additions and 19 deletions.
130 changes: 111 additions & 19 deletions ggml/src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32,
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
Expand Down Expand Up @@ -272,6 +274,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_SIN,
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,

GGML_METAL_KERNEL_TYPE_COUNT
};
Expand Down Expand Up @@ -685,6 +689,8 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
Expand Down Expand Up @@ -716,6 +722,8 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
}

[metal_library release];
Expand Down Expand Up @@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_IM2COL:
return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_POOL_1D:
case GGML_OP_POOL_2D:
return false;
case GGML_OP_POOL_2D:
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_ARANGE:
Expand Down Expand Up @@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node(
} break;
case GGML_OP_IM2COL:
{
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
Expand Down Expand Up @@ -2574,30 +2584,54 @@ static void ggml_metal_encode_node(
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;

id<MTLComputePipelineState> pipeline = nil;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;

const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;

switch (dst->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
case GGML_TYPE_F32: {
pipeline = (is_gt_mttpt ?
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
:
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
} break;
case GGML_TYPE_F16: {
pipeline = (is_gt_mttpt ?
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
:
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
} break;
default: GGML_ABORT("fatal error");
};

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];

[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
[encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
[encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];

if (is_gt_mttpt) {
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];

const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);

const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);

[encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
} else {
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
}
} break;
case GGML_OP_UPSCALE:
{
Expand Down Expand Up @@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node(

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_POOL_2D:
{
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);

const int32_t * opts = dst->op_params;
enum ggml_op_pool op = opts[0];

id<MTLComputePipelineState> pipeline = nil;
switch (src0t) {
case GGML_TYPE_F32: {
switch(op) {
case GGML_OP_POOL_AVG:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
case GGML_OP_POOL_MAX:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
}
} break;
default: GGML_ASSERT(false && "not implemented");
}

const int32_t k0 = opts[1];
const int32_t k1 = opts[2];
const int32_t s0 = opts[3];
const int32_t s1 = opts[4];
const int32_t p0 = opts[5];
const int32_t p1 = opts[6];

const int64_t IH = src0->ne[1];
const int64_t IW = src0->ne[0];

const int64_t N = dst->ne[3];
const int64_t OC = dst->ne[2];
const int64_t OH = dst->ne[1];
const int64_t OW = dst->ne[0];

const int64_t parallel_elements = N * OC * OH * OW;
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
[encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];

[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
} break;
default:
{
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
Expand Down
178 changes: 178 additions & 0 deletions ggml/src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1933,6 +1933,85 @@ kernel void kernel_im2col(
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;

typedef void (im2col_ext_t)(
device const float * x,
device char * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
constant int32_t & IH,
constant int32_t & CHW,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int32_t & d0,
constant int32_t & d1,
constant int32_t & N,
constant int32_t & KH,
constant int32_t & KW,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);

template <typename T>
kernel void kernel_im2col_ext(
device const float * x,
device char * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
constant int32_t & IH,
constant int32_t & CHW,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int32_t & d0,
constant int32_t & d1,
constant int32_t & N,
constant int32_t & KH,
constant int32_t & KW,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]

const int32_t d = tgpig[0] / CHW;
const int32_t chw = tgpig[0] % CHW;
const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
const int32_t HW = tgpig[0] % KHW;

const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
if (tpitg_0 >= N) {
return;
}

const int32_t tpitg_1 = HW / KW;
const int32_t tpitg_2 = HW % KW;

const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;

const int32_t offset_dst =
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
(tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);

device T * pdst = (device T *) (dst);

if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
pdst[offset_dst] = 0.0f;
} else {
const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}

template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;

kernel void kernel_upscale_f32(
device const char * src0,
device char * dst,
Expand Down Expand Up @@ -6372,3 +6451,102 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;

kernel void kernel_pool_2d_max_f32(
device const float * src0,
device float * dst,
constant int32_t & k0,
constant int32_t & k1,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int64_t & IH,
constant int64_t & IW,
constant int64_t & OH,
constant int64_t & OW,
constant int64_t & parallel_elements,
uint gid[[thread_position_in_grid]]) {

if (gid >= parallel_elements) {
return;
}

const int idx = gid;
const int I_HW = IH * IW;
const int O_HW = OH * OW;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / OW;
const int cur_ow = idx % O_HW % OW;

device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;

const int start_h = cur_oh * s1 - p1;
const int bh = MAX(0, start_h);
const int eh = MIN(IH, start_h + k1);
const int start_w = cur_ow * s0 - p0;
const int bw = MAX(0, start_w);
const int ew = MIN(IW, start_w + k0);

float res = -INFINITY;

for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
res = MAX(res, i_ptr[i * IW + j]);
}
}

o_ptr[cur_oh * OW + cur_ow] = res;
}

kernel void kernel_pool_2d_avg_f32(
device const float * src0,
device float * dst,
constant int32_t & k0,
constant int32_t & k1,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int64_t & IH,
constant int64_t & IW,
constant int64_t & OH,
constant int64_t & OW,
constant int64_t & parallel_elements,
uint gid[[thread_position_in_grid]]) {

if (gid >= parallel_elements) {
return;
}

const int idx = gid;
const int I_HW = IH * IW;
const int O_HW = OH * OW;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / OW;
const int cur_ow = idx % O_HW % OW;

device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;

const int start_h = cur_oh * s1 - p1;
const int bh = MAX(0, start_h);
const int eh = MIN(IH, start_h + k1);
const int start_w = cur_ow * s0 - p0;
const int bw = MAX(0, start_w);
const int ew = MIN(IW, start_w + k0);
// const float scale = 1. / ((eh - bh) * (ew - bw));
const float scale = 1. / (k0 * k1);

float res = 0;

for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
float cur = i_ptr[i * IW + j];
res += cur * scale;
}
}

o_ptr[cur_oh * OW + cur_ow] = res;
}
10 changes: 10 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3316,6 +3316,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));

// test cases for 2D im2col
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2048}, {3, 3, 1, 2048}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));

// sycl backend will limit task global_range < MAX_INT
// test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
// however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.)
Expand Down

0 comments on commit 4c9388f

Please sign in to comment.