Skip to content

Commit

Permalink
combine two kernels into one
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Dec 10, 2023
1 parent 9f87e11 commit 62444d7
Showing 1 changed file with 9 additions and 19 deletions.
28 changes: 9 additions & 19 deletions plugin/sycl/predictor/predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,25 +189,6 @@ void DevicePredictInternal(::sycl::queue* qu,
}
auto* miss_buff_ptr = miss_buff.Data();

events[0] = qu->submit([&](::sycl::handler& cgh) {
cgh.depends_on(events);
cgh.parallel_for<>(::sycl::range<1>(num_rows),
[=](::sycl::item<1> pid) {
int row_idx = pid.get_id(0);
auto* fval_buff_row_ptr = fval_buff_ptr + num_features * row_idx;
auto* miss_buff_row_ptr = miss_buff_ptr + num_features * row_idx;

const Entry* begin_ptr = data + row_ptr[row_idx];
const Entry* end_ptr = data + row_ptr[row_idx + 1];
for (const Entry* entry = begin_ptr; entry < end_ptr; entry += 1) {
fval_buff_row_ptr[entry->index] = entry->fvalue;
if constexpr (any_missing) {
miss_buff_row_ptr[entry->index] = 0;
}
}
});
});

auto& out_preds_vec = out_preds->HostVector();
::sycl::buffer<float, 1> out_preds_buf(out_preds_vec.data(), out_preds_vec.size());
events[0] = qu->submit([&](::sycl::handler& cgh) {
Expand All @@ -218,6 +199,15 @@ void DevicePredictInternal(::sycl::queue* qu,
auto* fval_buff_row_ptr = fval_buff_ptr + num_features * row_idx;
auto* miss_buff_row_ptr = miss_buff_ptr + num_features * row_idx;

const Entry* first_entry = data + row_ptr[row_idx];
const Entry* last_entry = data + row_ptr[row_idx + 1];
for (const Entry* entry = first_entry; entry < last_entry; entry += 1) {
fval_buff_row_ptr[entry->index] = entry->fvalue;
if constexpr (any_missing) {
miss_buff_row_ptr[entry->index] = 0;
}
}

if (num_group == 1) {
float sum = 0.0;
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
Expand Down

0 comments on commit 62444d7

Please sign in to comment.