Skip to content

Commit

Permalink
Fixing the UB bug
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitry.razdoburdin committed Feb 21, 2022
1 parent e18e342 commit 93661a1
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions plugin/updater_oneapi/regression_obj_oneapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class RegLossObjOneAPI : public ObjFunction {
sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
is_null_weight ? 1 : info.weights_.Size());

const size_t n_targets = std::max(info.labels.Shape(1), static_cast<size_t>(1));

sycl::buffer<int, 1> additional_input_buf(1);
{
auto additional_input_acc = additional_input_buf.get_access<sycl::access::mode::write>();
Expand All @@ -92,7 +94,7 @@ class RegLossObjOneAPI : public ObjFunction {
cgh.parallel_for<>(sycl::range<1>(ndata), [=](sycl::id<1> pid) {
int idx = pid[0];
bst_float p = Loss::PredTransform(preds_acc[idx]);
bst_float w = is_null_weight ? 1.0f : weights_acc[idx];
bst_float w = is_null_weight ? 1.0f : weights_acc[idx/n_targets];
bst_float label = labels_acc[idx];
if (label == 1.0f) {
w *= scale_pos_weight;
Expand Down Expand Up @@ -125,7 +127,6 @@ class RegLossObjOneAPI : public ObjFunction {

void PredTransform(HostDeviceVector<float> *io_preds) const override {
size_t const ndata = io_preds->Size();

sycl::buffer<bst_float, 1> io_preds_buf(io_preds->HostPointer(), io_preds->Size());

qu_.submit([&](sycl::handler& cgh) {
Expand Down

0 comments on commit 93661a1

Please sign in to comment.