From ca823e66ead618434b6a5fbd230ac2f2629ba35e Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Fri, 8 Mar 2024 06:00:12 -0800 Subject: [PATCH] add todo comment; make sub_group_size dynamical --- plugin/sycl/tree/updater_quantile_hist.cc | 20 ++++++++++++-------- plugin/sycl/tree/updater_quantile_hist.h | 3 +++ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/plugin/sycl/tree/updater_quantile_hist.cc b/plugin/sycl/tree/updater_quantile_hist.cc index b6e38947d561..16db0d835183 100644 --- a/plugin/sycl/tree/updater_quantile_hist.cc +++ b/plugin/sycl/tree/updater_quantile_hist.cc @@ -1034,12 +1034,12 @@ void QuantileHistMaker::Builder::EvaluateSplits( const NodeEntry* snode = snode_.DataConst(); const float min_child_weight = param_.min_child_weight; - const size_t local_size = 16; + event = qu_.submit([&](::sycl::handler& cgh) { cgh.depends_on(event); - cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(total_features, local_size), - ::sycl::range<2>(1, local_size)), - [=](::sycl::nd_item<2> pid) [[intel::reqd_sub_group_size(local_size)]] { + cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(total_features, sub_group_size_), + ::sycl::range<2>(1, sub_group_size_)), + [=](::sycl::nd_item<2> pid) { int i = pid.get_global_id(0); auto sg = pid.get_sub_group(); int nid = split_queries_device[i].nid; @@ -1084,10 +1084,14 @@ void QuantileHistMaker::Builder::EnumerateSplit( GradStats sum(0, 0); - int32_t local_size = sg.get_local_range().size(); + int32_t sub_group_size = sg.get_local_range().size(); const size_t local_id = sg.get_local_id()[0]; - for (int32_t i = ibegin + local_id; i < iend; i += local_size) { + /* TODO(razdoburdin) + * Currently the first additions are fast and the last are slow. 
+ * Maybe computing the reduce over group in a separate kernel and reusing it here can be faster + */ + for (int32_t i = ibegin + local_id; i < iend; i += sub_group_size) { sum += GradStats( ::sycl::inclusive_scan_over_group(sg, hist_data[i].GetGrad(), std::plus<>()), ::sycl::inclusive_scan_over_group(sg, hist_data[i].GetHess(), std::plus<>())); @@ -1101,9 +1105,9 @@ void QuantileHistMaker::Builder::EnumerateSplit( } } - const bool last_iter = i + local_size >= iend; + const bool last_iter = i + sub_group_size >= iend; if (!last_iter) { - size_t end = i - local_id + local_size; + size_t end = i - local_id + sub_group_size; if (end > iend) end = iend; for (size_t j = i + 1; j < end; ++j) { sum += GradStats(hist_data[j].GetGrad(), hist_data[j].GetHess()); diff --git a/plugin/sycl/tree/updater_quantile_hist.h b/plugin/sycl/tree/updater_quantile_hist.h index a94267243ab5..f4a30a373862 100644 --- a/plugin/sycl/tree/updater_quantile_hist.h +++ b/plugin/sycl/tree/updater_quantile_hist.h @@ -212,6 +212,8 @@ class QuantileHistMaker: public TreeUpdater { snode_(&qu, 1u << (param.max_depth + 1), NodeEntry(param)) { builder_monitor_.Init("SYCL::Quantile::Builder"); kernel_monitor_.Init("SYCL::Quantile::Kernels"); + const auto sub_group_sizes = qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>(); + sub_group_size_ = sub_group_sizes.back(); } // update one tree, growing void Update(Context const * ctx, @@ -385,6 +387,7 @@ class QuantileHistMaker: public TreeUpdater { } } // --data fields-- + size_t sub_group_size_; const xgboost::tree::TrainParam& param_; // number of omp thread used during training int nthread_;