From 1780f5b9d9f85e6ef2f95dffef9e0303b28feb91 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin <> Date: Thu, 23 Jan 2025 03:37:59 -0800 Subject: [PATCH] add row-wise processing to PushRowPage --- src/common/quantile.h | 113 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 91 insertions(+), 22 deletions(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index e189b259b159..f877e522ee1a 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -840,47 +840,116 @@ class SketchContainerImpl { template void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz, size_t n_features, bool is_dense, IsValid is_valid) { - auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid); - dmlc::OMPException exc; -#pragma omp parallel num_threads(n_threads_) - { - exc.Run([&]() { - auto tid = static_cast(omp_get_thread_num()); - auto const begin = thread_columns_ptr[tid]; - auto const end = thread_columns_ptr[tid + 1]; - - // do not iterate if no columns are assigned to the thread - if (begin < end && end <= n_features) { - for (size_t ridx = 0; ridx < batch.Size(); ++ridx) { + size_t ridx_block_size = batch.Size() / n_threads_ + (batch.Size() % n_threads_ > 0); + size_t min_ridx_block_size = 1024; + if ((n_features < n_threads_) && (ridx_block_size > min_ridx_block_size)) { + /* Row-wise parallelisation. + */ + std::vector> categories_buff(n_threads_ * n_features); + std::vector sketches_buff(n_threads_ * n_features); + + #pragma omp parallel num_threads(n_threads_) + { + exc.Run([&]() { + auto tid = static_cast(omp_get_thread_num()); + WQSketch* sketches_th = sketches_buff.data() + tid * n_features; + std::set* categories_th = categories_buff.data() + tid * n_features; + + for (size_t ii = 0; ii < n_features; ii++) { + auto n_bins = std::min(static_cast(max_bins_), columns_size_[ii]); + auto eps = 1.0 / (static_cast(n_bins) * WQSketch::kFactor); + sketches_th[ii].Init(columns_size_[ii], eps); + } + + size_t ridx_begin = tid * ridx_block_size; + size_t ridx_end = std::min(ridx_begin + ridx_block_size, batch.Size()); + for (size_t ridx = ridx_begin; ridx < ridx_end; ++ridx) { auto const &line = batch.GetLine(ridx); auto w = weights[ridx + base_rowid]; if (is_dense) { - for (size_t ii = begin; ii < end; ii++) { + for (size_t ii = 0; ii < n_features; ii++) { auto elem = line.GetElement(ii); if (is_valid(elem)) { if (IsCat(feature_types_, ii)) { - categories_[ii].emplace(elem.value); + categories_th[ii].emplace(elem.value); } else { - sketches_[ii].Push(elem.value, w); + sketches_th[ii].Push(elem.value, w); } } } } else { - for (size_t i = 0; i < line.Size(); ++i) { - auto const &elem = line.GetElement(i); - if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) { + for (size_t ii = 0; ii < line.Size(); ++ii) { + auto elem = line.GetElement(ii); + if (is_valid(elem)) { if (IsCat(feature_types_, elem.column_idx)) { - categories_[elem.column_idx].emplace(elem.value); + categories_th[elem.column_idx].emplace(elem.value); } else { - sketches_[elem.column_idx].Push(elem.value, w); + sketches_th[elem.column_idx].Push(elem.value, w); } } } } } - } - }); + #pragma omp barrier + + size_t fidx_block_size = n_features / n_threads_ + (n_features % n_threads_ > 0); + size_t fidx_begin = tid * fidx_block_size; + size_t fidx_end = std::min(fidx_begin + fidx_block_size, n_features); + for (size_t ii = fidx_begin; ii < fidx_end; ++ii) { + for (size_t th = 0; th < n_threads_; ++th) { + if (IsCat(feature_types_, ii)) { + categories_[ii].merge(categories_buff[th * n_features + ii]); + } else { + typename WQSketch::SummaryContainer summary; + sketches_buff[th * n_features + ii].GetSummary(&summary); + sketches_[ii].PushSummary(summary); + } + } + } + }); + } + } else { + auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid); + #pragma omp parallel num_threads(n_threads_) + { + exc.Run([&]() { + auto tid = static_cast(omp_get_thread_num()); + auto const begin = thread_columns_ptr[tid]; + auto const end = thread_columns_ptr[tid + 1]; + + // do not iterate if no columns are assigned to the thread + if (begin < end && end <= n_features) { + for (size_t ridx = 0; ridx < batch.Size(); ++ridx) { + auto const &line = batch.GetLine(ridx); + auto w = weights[ridx + base_rowid]; + if (is_dense) { + for (size_t ii = begin; ii < end; ii++) { + auto elem = line.GetElement(ii); + if (is_valid(elem)) { + if (IsCat(feature_types_, ii)) { + categories_[ii].emplace(elem.value); + } else { + sketches_[ii].Push(elem.value, w); + } + } + } + } else { + for (size_t i = 0; i < line.Size(); ++i) { + auto const &elem = line.GetElement(i); + if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) { + if (IsCat(feature_types_, elem.column_idx)) { + categories_[elem.column_idx].emplace(elem.value); + } else { + sketches_[elem.column_idx].Push(elem.value, w); + } + } + } + } + } + } + }); + } } exc.Rethrow(); }