diff --git a/plugin/sycl/tree/hist_updater.cc b/plugin/sycl/tree/hist_updater.cc index 0f22bc9a3834..426c8debc59d 100644 --- a/plugin/sycl/tree/hist_updater.cc +++ b/plugin/sycl/tree/hist_updater.cc @@ -9,6 +9,8 @@ #include #include +#include + #include "hist_updater.h" #include "../common/hist_util.h" #include "../../../src/common/threading_utils.h" // MemStackAllocator @@ -80,20 +82,14 @@ void HistUpdater::BuildLocalHistograms( const USMVector &gpair_device) { builder_monitor_.Start("BuildLocalHistograms"); const size_t n_nodes = nodes_for_explicit_hist_build_.size(); - for (auto& event : hist_build_events_) { - event = ::sycl::event(); - } + ::sycl::event event; for (size_t i = 0; i < n_nodes; i++) { const int32_t nid = nodes_for_explicit_hist_build_[i].nid; - const size_t event_idx = i % kNumParallelBuffers; - auto& event = hist_build_events_[event_idx]; if (row_set_collection_[nid].Size() > 0) { - auto& hist_buff = hist_buffers_[event_idx]; - event = BuildHist(gpair_device, row_set_collection_[nid], gmat, &(hist_[nid]), - &(hist_buff.GetDeviceBuffer()), event); + &(hist_buffer_.GetDeviceBuffer()), event); } else { common::InitHist(qu_, &(hist_[nid]), hist_[nid].Size(), &event); } @@ -350,7 +346,7 @@ void HistUpdater::Update( tree_evaluator_.Reset(qu_, param_, p_fmat->Info().num_col_); interaction_constraints_.Reset(); - this->InitData(ctx, gmat, gpair_h, gpair_device, *p_fmat, *p_tree); + this->InitData(ctx, gmat, gpair_device, *p_fmat, *p_tree); if (param_.grow_policy == xgboost::tree::TrainParam::kLossGuide) { ExpandWithLossGuide(gmat, p_fmat, p_tree, gpair_h, gpair_device); } else { @@ -433,82 +429,47 @@ bool HistUpdater::UpdatePredictionCache( template void HistUpdater::InitSampling( - const std::vector& gpair, - const USMVector &gpair_device, - const DMatrix& fmat, - USMVector* row_indices_device) { - const auto& info = fmat.Info(); - auto& rnd = xgboost::common::GlobalRandom(); -#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG - std::bernoulli_distribution coin_flip(param_.subsample); - size_t j = 0; - - std::vector row_indices(row_indices_device->Size()); - qu_.memcpy(row_indices.data(), row_indices_device->DataConst(), - row_indices.size() * sizeof(size_t)).wait(); - for (size_t i = 0; i < info.num_row_; ++i) { - if (gpair[i].GetHess() >= 0.0f && coin_flip(rnd)) { - row_indices[j++] = i; - } - } - qu_.memcpy(row_indices_device->Data(), row_indices.data(), - row_indices.size() * sizeof(size_t)).wait(); - /* resize row_indices to reduce memory */ - row_indices_device->Resize(qu_, j); -#else - const size_t nthread = this->nthread_; - std::vector row_offsets(nthread, 0); - /* usage of mt19937_64 give 2x speed up for subsampling */ - std::vector rnds(nthread); - /* create engine for each thread */ - for (std::mt19937& r : rnds) { - r = rnd; - } - - std::vector row_indices(row_indices_device->Size()); - qu_.memcpy(row_indices.data(), row_indices_device->DataConst(), - row_indices.size() * sizeof(size_t)).wait(); - const size_t discard_size = info.num_row_ / nthread; - #pragma omp parallel num_threads(nthread) - { - const size_t tid = omp_get_thread_num(); - const size_t ibegin = tid * discard_size; - const size_t iend = (tid == (nthread - 1)) ? - info.num_row_ : ibegin + discard_size; - std::bernoulli_distribution coin_flip(param_.subsample); - - rnds[tid].discard(2*discard_size * tid); - for (size_t i = ibegin; i < iend; ++i) { - if (gpair[i].GetHess() >= 0.0f && coin_flip(rnds[tid])) { - row_indices[ibegin + row_offsets[tid]++] = i; - } - } - } - - /* discard global engine */ - rnd = rnds[nthread - 1]; - size_t prefix_sum = row_offsets[0]; - for (size_t i = 1; i < nthread; ++i) { - const size_t ibegin = i * discard_size; + const USMVector &gpair, + USMVector* row_indices) { + const size_t num_rows = row_indices->Size(); + auto* row_idx = row_indices->Data(); + const auto* gpair_ptr = gpair.DataConst(); + uint64_t num_samples = 0; + const auto subsample = param_.subsample; + ::sycl::event event; - for (size_t k = 0; k < row_offsets[i]; ++k) { - row_indices[prefix_sum + k] = row_indices[ibegin + k]; - } - prefix_sum += row_offsets[i]; + { + ::sycl::buffer flag_buf(&num_samples, 1); + uint64_t seed = seed_; + seed_ += num_rows; + event = qu_.submit([&](::sycl::handler& cgh) { + auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)), + [=](::sycl::item<1> pid) { + uint64_t i = pid.get_id(0); + + // Create minstd_rand engine + oneapi::dpl::minstd_rand engine(seed, i); + oneapi::dpl::bernoulli_distribution coin_flip(subsample); + + auto rnd = coin_flip(engine); + if (gpair_ptr[i].GetHess() >= 0.0f && rnd) { + AtomicRef num_samples_ref(flag_buf_acc[0]); + row_idx[num_samples_ref++] = i; + } + }); + }); } - qu_.memcpy(row_indices_device->Data(), row_indices.data(), - row_indices.size() * sizeof(size_t)).wait(); - row_indices_device->Resize(&qu_, prefix_sum); -#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG + row_indices->Resize(&qu_, num_samples, 0, &event); + qu_.wait(); } template void HistUpdater::InitData( Context const * ctx, const common::GHistIndexMatrix& gmat, - const std::vector& gpair, - const USMVector &gpair_device, + const USMVector &gpair, const DMatrix& fmat, const RegTree& tree) { CHECK((param_.max_depth > 0 || param_.max_leaves > 0)) @@ -528,18 +489,16 @@ void HistUpdater::InitData( uint32_t nbins = gmat.cut.Ptrs().back(); hist_.Init(qu_, nbins); hist_local_worker_.Init(qu_, nbins); - for (auto& buffer : hist_buffers_) { - buffer.Init(qu_, nbins); - size_t buffer_size = 2048; - const size_t min_block_size = 128; - if (buffer_size > info.num_row_ / min_block_size + 1) { - buffer_size = info.num_row_ / min_block_size + 1; - } - buffer.Reset(buffer_size); + + hist_buffer_.Init(qu_, nbins); + size_t buffer_size = 2048; + const size_t min_block_size = 128; + if (buffer_size > info.num_row_ / min_block_size + 1) { + buffer_size = info.num_row_ / min_block_size + 1; } + hist_buffer_.Reset(buffer_size); // initialize histogram builder - this->nthread_ = omp_get_num_threads(); hist_builder_ = common::GHistBuilder(qu_, nbins); USMVector* row_indices = &(row_set_collection_.Data()); @@ -551,55 +510,47 @@ void HistUpdater::InitData( CHECK_EQ(param_.sampling_method, xgboost::tree::TrainParam::kUniform) << "Only uniform sampling is supported, " << "gradient-based sampling is only support by GPU Hist."; - InitSampling(gpair, gpair_device, fmat, row_indices); + InitSampling(gpair, row_indices); } else { - xgboost::common::MemStackAllocator buff(this->nthread_); - bool* p_buff = buff.data(); - std::fill(p_buff, p_buff + this->nthread_, false); - - const size_t block_size = info.num_row_ / this->nthread_ + !!(info.num_row_ % this->nthread_); - - #pragma omp parallel num_threads(this->nthread_) + int has_neg_hess = 0; + const GradientPair* gpair_ptr = gpair.DataConst(); + ::sycl::event event; { - const size_t tid = omp_get_thread_num(); - const size_t ibegin = tid * block_size; - const size_t iend = std::min(static_cast(ibegin + block_size), - static_cast(info.num_row_)); - - for (size_t i = ibegin; i < iend; ++i) { - if (gpair[i].GetHess() < 0.0f) { - p_buff[tid] = true; - break; - } - } - } - - bool has_neg_hess = false; - for (int32_t tid = 0; tid < this->nthread_; ++tid) { - if (p_buff[tid]) { - has_neg_hess = true; - } - } - - if (has_neg_hess) { - size_t j = 0; - std::vector row_indices_buff(row_indices->Size()); - for (size_t i = 0; i < info.num_row_; ++i) { - if (gpair[i].GetHess() >= 0.0f) { - row_indices_buff[j++] = i; - } - } - qu_.memcpy(p_row_indices, row_indices_buff.data(), j * sizeof(size_t)).wait(); - row_indices->Resize(&qu_, j); - } else { - qu_.submit([&](::sycl::handler& cgh) { + ::sycl::buffer flag_buf(&has_neg_hess, 1); + event = qu_.submit([&](::sycl::handler& cgh) { + auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh); cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(info.num_row_)), - [p_row_indices](::sycl::item<1> pid) { + [=](::sycl::item<1> pid) { const size_t idx = pid.get_id(0); p_row_indices[idx] = idx; + if (gpair_ptr[idx].GetHess() < 0.0f) { + AtomicRef has_neg_hess_ref(flag_buf_acc[0]); + has_neg_hess_ref.fetch_max(1); + } }); - }).wait_and_throw(); + }); + } + + if (has_neg_hess) { + size_t max_idx = 0; + { + ::sycl::buffer flag_buf(&max_idx, 1); + event = qu_.submit([&](::sycl::handler& cgh) { + cgh.depends_on(event); + auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(info.num_row_)), + [=](::sycl::item<1> pid) { + const size_t idx = pid.get_id(0); + if (gpair_ptr[idx].GetHess() >= 0.0f) { + AtomicRef max_idx_ref(flag_buf_acc[0]); + p_row_indices[max_idx_ref++] = idx; + } + }); + }); + } + row_indices->Resize(&qu_, max_idx, 0, &event); } + qu_.wait_and_throw(); } } diff --git a/plugin/sycl/tree/hist_updater.h b/plugin/sycl/tree/hist_updater.h index 45b3bd2d8c23..c509782e6dbb 100644 --- a/plugin/sycl/tree/hist_updater.h +++ b/plugin/sycl/tree/hist_updater.h @@ -106,8 +106,7 @@ class HistUpdater { // initialize temp data structure void InitData(Context const * ctx, const common::GHistIndexMatrix& gmat, - const std::vector& gpair, - const USMVector &gpair_device, + const USMVector &gpair, const DMatrix& fmat, const RegTree& tree); @@ -125,9 +124,8 @@ class HistUpdater { const GradientPairT* hist; }; - void InitSampling(const std::vector& gpair, - const USMVector &gpair_device, - const DMatrix& fmat, USMVector* row_indices); + void InitSampling(const USMVector &gpair, + USMVector* row_indices); void EvaluateSplits(const std::vector& nodes_set, const common::GHistIndexMatrix& gmat, @@ -243,8 +241,6 @@ class HistUpdater { // --data fields-- size_t sub_group_size_; const xgboost::tree::TrainParam& param_; - // number of omp thread used during training - int nthread_; xgboost::common::ColumnSampler column_sampler_; // the internal row sets common::RowSetCollection row_set_collection_; @@ -265,6 +261,8 @@ class HistUpdater { std::unique_ptr pruner_; FeatureInteractionConstraintHost interaction_constraints_; + uint64_t seed_ = 0; + common::PartitionBuilder partition_builder_; // back pointers to tree and data matrix @@ -288,9 +286,11 @@ class HistUpdater { xgboost::common::Monitor builder_monitor_; xgboost::common::Monitor kernel_monitor_; - constexpr static size_t kNumParallelBuffers = 1; - std::array, kNumParallelBuffers> hist_buffers_; - std::array<::sycl::event, kNumParallelBuffers> hist_build_events_; + + common::ParallelGHistBuilder hist_buffer_; + + std::vector rnds; + std::vector<::sycl::event> merge_to_array_events_; std::unique_ptr> hist_synchronizer_; std::unique_ptr> hist_rows_adder_; diff --git a/tests/ci_build/conda_env/linux_sycl_test.yml b/tests/ci_build/conda_env/linux_sycl_test.yml index bb14c1e77ebb..6326ce33cb52 100644 --- a/tests/ci_build/conda_env/linux_sycl_test.yml +++ b/tests/ci_build/conda_env/linux_sycl_test.yml @@ -18,3 +18,4 @@ dependencies: - pytest-timeout - pytest-cov - dpcpp_linux-64 +- onedpl-devel \ No newline at end of file