Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Dec 8, 2023
1 parent 048b552 commit 3875c25
Showing 1 changed file with 8 additions and 16 deletions.
24 changes: 8 additions & 16 deletions plugin/sycl/predictor/predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,24 +177,16 @@ void DevicePredictInternal(::sycl::queue* qu,
std::vector<::sycl::event> events(1);
events[0] = qu->fill(fval_buff_ptr, missing_response, num_features * num_rows);

const size_t max_work_group_size =
qu->get_device().get_info<::sycl::info::device::max_work_group_size>();
const size_t feat_local = num_features < max_work_group_size ? num_features : max_work_group_size;
events[0] = qu->submit([&](::sycl::handler& cgh) {
cgh.depends_on(events);
cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(num_rows, feat_local),
::sycl::range<2>( 1, feat_local)),
[=](::sycl::nd_item<2> pid) {
int global_idx = pid.get_global_id(0);
int j = pid.get_global_id(1);
auto* fval_buff_row_ptr = fval_buff_ptr + num_features * global_idx;

// const Entry* begin_ptr = data + row_ptr[global_idx];
// const Entry* end_ptr = data + row_ptr[global_idx + 1];
size_t n_columns = row_ptr[global_idx + 1] - row_ptr[global_idx];
// for (const Entry* entry = begin_ptr; entry < end_ptr; entry += 1) {
for (int column = j; column < n_columns; column += feat_local) {
const Entry* entry = data + row_ptr[global_idx] + column;
cgh.parallel_for<>(::sycl::range<1>(num_rows),
[=](::sycl::item<1> pid) {
int row_idx = pid.get_id(0);
auto* fval_buff_row_ptr = fval_buff_ptr + num_features * row_idx;

const Entry* begin_ptr = data + row_ptr[row_idx];
const Entry* end_ptr = data + row_ptr[row_idx + 1];
for (const Entry* entry = begin_ptr; entry < end_ptr; entry += 1) {
fval_buff_row_ptr[entry->index] = {entry->fvalue, 0};
}
});
Expand Down

0 comments on commit 3875c25

Please sign in to comment.