Skip to content

Commit

Permalink
add todo comment; make sub_group_size dynamical
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Mar 8, 2024
1 parent 5847b1a commit ca823e6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
20 changes: 12 additions & 8 deletions plugin/sycl/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1034,12 +1034,12 @@ void QuantileHistMaker::Builder<GradientSumT>::EvaluateSplits(
const NodeEntry<GradientSumT>* snode = snode_.DataConst();

const float min_child_weight = param_.min_child_weight;
const size_t local_size = 16;

event = qu_.submit([&](::sycl::handler& cgh) {
cgh.depends_on(event);
cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(total_features, local_size),
::sycl::range<2>(1, local_size)),
[=](::sycl::nd_item<2> pid) [[intel::reqd_sub_group_size(local_size)]] {
cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(total_features, sub_group_size_),
::sycl::range<2>(1, sub_group_size_)),
[=](::sycl::nd_item<2> pid) {
int i = pid.get_global_id(0);
auto sg = pid.get_sub_group();
int nid = split_queries_device[i].nid;
Expand Down Expand Up @@ -1084,10 +1084,14 @@ void QuantileHistMaker::Builder<GradientSumT>::EnumerateSplit(

GradStats<GradientSumT> sum(0, 0);

int32_t local_size = sg.get_local_range().size();
int32_t sub_group_size = sg.get_local_range().size();
const size_t local_id = sg.get_local_id()[0];

for (int32_t i = ibegin + local_id; i < iend; i += local_size) {
/* TODO(razdoburdin)
* Currently the first additions are fast and the last are slow.
* Maybe calculating of reduce overgroup in seprate kernel and reusing it here can be faster
*/
for (int32_t i = ibegin + local_id; i < iend; i += sub_group_size) {
sum += GradStats<GradientSumT>(
::sycl::inclusive_scan_over_group(sg, hist_data[i].GetGrad(), std::plus<>()),
::sycl::inclusive_scan_over_group(sg, hist_data[i].GetHess(), std::plus<>()));
Expand All @@ -1101,9 +1105,9 @@ void QuantileHistMaker::Builder<GradientSumT>::EnumerateSplit(
}
}

const bool last_iter = i + local_size >= iend;
const bool last_iter = i + sub_group_size >= iend;
if (!last_iter) {
size_t end = i - local_id + local_size;
size_t end = i - local_id + sub_group_size;
if (end > iend) end = iend;
for (size_t j = i + 1; j < end; ++j) {
sum += GradStats<GradientSumT>(hist_data[j].GetGrad(), hist_data[j].GetHess());
Expand Down
3 changes: 3 additions & 0 deletions plugin/sycl/tree/updater_quantile_hist.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ class QuantileHistMaker: public TreeUpdater {
snode_(&qu, 1u << (param.max_depth + 1), NodeEntry<GradientSumT>(param)) {
builder_monitor_.Init("SYCL::Quantile::Builder");
kernel_monitor_.Init("SYCL::Quantile::Kernels");
const auto sub_group_sizes = qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>();
sub_group_size_ = sub_group_sizes.back();
}
// update one tree, growing
void Update(Context const * ctx,
Expand Down Expand Up @@ -385,6 +387,7 @@ class QuantileHistMaker: public TreeUpdater {
}
}
// --data fields--
size_t sub_group_size_;
const xgboost::tree::TrainParam& param_;
// number of omp thread used during training
int nthread_;
Expand Down

0 comments on commit ca823e6

Please sign in to comment.