Skip to content

Commit

Permalink
CUDA: fix 1D im2col, add tests (ggml/993)
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler authored and ggerganov committed Oct 23, 2024
1 parent c19af0a commit 80273a3
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 9 deletions.
1 change: 0 additions & 1 deletion ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3141,7 +3141,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ROPE:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_IM2COL:
return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
Expand Down
6 changes: 3 additions & 3 deletions ggml/src/ggml-cuda/im2col.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];

const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t batch = src1->ne[3];
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t batch = src1->ne[is_2D ? 3 : 2];
const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32

if(dst->type == GGML_TYPE_F16) {
im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
Expand Down
36 changes: 31 additions & 5 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3308,15 +3308,41 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}

test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
// test cases for 1D im2col
// im2col 1D
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
for (int s0 : {1, 3}) {
for (int p0 : {0, 3}) {
for (int d0 : {1, 3}) {
test_cases.emplace_back(new test_im2col(
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1},
s0, 0, p0, 0, d0, 0, false));
}
}
}

// im2col 2D
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
for (int s0 : {1, 3}) {
for (int s1 : {1, 3}) {
for (int p0 : {0, 3}) {
for (int p1 : {0, 3}) {
for (int d0 : {1, 3}) {
for (int d1 : {1, 3}) {
test_cases.emplace_back(new test_im2col(
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2},
s0, s1, p0, p1, d0, d1, true));
}
}
}
}
}
}

// test cases for 2D im2col
// extra tests for im2col 2D
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
Expand Down

0 comments on commit 80273a3

Please sign in to comment.