Skip to content

Commit

Permalink
validation fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Nov 22, 2024
1 parent 1c13be4 commit 65c93f4
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
22 changes: 13 additions & 9 deletions plugin/sycl/common/linalg_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,22 @@ bool Validate(DeviceOrd device, TensorView<T, D> t, Fn&& fn) {

int flag = 0;
{
::sycl::buffer<int> buff(&flag, 1);
size_t size = xgboost::linalg::cend(t) - xgboost::linalg::cbegin(t);
::sycl::buffer<int, 1> flag_buf(&flag, 1);
qu->submit([&](::sycl::handler& cgh) {
auto reduction = ::sycl::reduction(buff, cgh, ::sycl::maximum<>());
cgh.parallel_for<>(::sycl::range<1>(size), reduction,
[=](::sycl::id<1> pid, auto& max) {
const size_t i = pid[0];
auto it = xgboost::linalg::cbegin(t) + i;
max.combine(!const_cast<Fn&&>(fn)(*it));
auto flag_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(t.Size()),
[=](::sycl::id<1> pid) {
const size_t idx = pid[0];
const T& value = call(t, xgboost::linalg::UnravelIndex(idx, t.Shape()));
bool is_valid = const_cast<Fn&&>(fn)(value);
if (!is_valid) {
AtomicRef<int> flag_ref(flag_acc[0]);
flag_ref = 1;
}
});
}).wait_and_throw();
});
}
qu->wait_and_throw();
return (flag == 0);
}

Expand Down
3 changes: 1 addition & 2 deletions src/objective/regression_obj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ class RegLossObj : public FitInterceptGlmLike {
if (iter == 0) {
ValidateLabel(info);
}

size_t const ndata = preds.Size();
out_gpair->SetDevice(ctx_->Device());
auto device = ctx_->Device();
Expand All @@ -132,7 +131,7 @@ class RegLossObj : public FitInterceptGlmLike {
additional_input_.HostVector().begin()[1] = is_null_weight;

const size_t nthreads = ctx_->Threads();
bool on_device = device.IsCUDA();
bool on_device = !device.IsCPU();
// On CPU we run the transformation each thread processing a contiguous block of data
// for better performance.
const size_t n_data_blocks = std::max(static_cast<size_t>(1), (on_device ? ndata : nthreads));
Expand Down

0 comments on commit 65c93f4

Please sign in to comment.