Skip to content

Commit

Permalink
add row-wise processing to PushRowPage
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Jan 23, 2025
1 parent 7c69d92 commit 1780f5b
Showing 1 changed file with 91 additions and 22 deletions.
113 changes: 91 additions & 22 deletions src/common/quantile.h
Original file line number Diff line number Diff line change
Expand Up @@ -840,47 +840,116 @@ class SketchContainerImpl {
template <typename Batch, typename IsValid>
void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz,
size_t n_features, bool is_dense, IsValid is_valid) {
auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);

dmlc::OMPException exc;
#pragma omp parallel num_threads(n_threads_)
{
exc.Run([&]() {
auto tid = static_cast<uint32_t>(omp_get_thread_num());
auto const begin = thread_columns_ptr[tid];
auto const end = thread_columns_ptr[tid + 1];

// do not iterate if no columns are assigned to the thread
if (begin < end && end <= n_features) {
for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
size_t ridx_block_size = batch.Size() / n_threads_ + (batch.Size() % n_threads_ > 0);
size_t min_ridx_block_size = 1024;
if ((n_features < n_threads_) && (ridx_block_size > min_ridx_block_size)) {
/* Row-wise parallelisation.
*/
std::vector<std::set<float>> categories_buff(n_threads_ * n_features);
std::vector<WQSketch> sketches_buff(n_threads_ * n_features);

#pragma omp parallel num_threads(n_threads_)
{
exc.Run([&]() {
auto tid = static_cast<uint32_t>(omp_get_thread_num());
WQSketch* sketches_th = sketches_buff.data() + tid * n_features;
std::set<float>* categories_th = categories_buff.data() + tid * n_features;

for (size_t ii = 0; ii < n_features; ii++) {
auto n_bins = std::min(static_cast<bst_idx_t>(max_bins_), columns_size_[ii]);
auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
sketches_th[ii].Init(columns_size_[ii], eps);
}

size_t ridx_begin = tid * ridx_block_size;
size_t ridx_end = std::min(ridx_begin + ridx_block_size, batch.Size());
for (size_t ridx = ridx_begin; ridx < ridx_end; ++ridx) {
auto const &line = batch.GetLine(ridx);
auto w = weights[ridx + base_rowid];
if (is_dense) {
for (size_t ii = begin; ii < end; ii++) {
for (size_t ii = 0; ii < n_features; ii++) {
auto elem = line.GetElement(ii);
if (is_valid(elem)) {
if (IsCat(feature_types_, ii)) {
categories_[ii].emplace(elem.value);
categories_th[ii].emplace(elem.value);
} else {
sketches_[ii].Push(elem.value, w);
sketches_th[ii].Push(elem.value, w);
}
}
}
} else {
for (size_t i = 0; i < line.Size(); ++i) {
auto const &elem = line.GetElement(i);
if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
for (size_t ii = 0; ii < line.Size(); ++ii) {
auto elem = line.GetElement(ii);
if (is_valid(elem)) {
if (IsCat(feature_types_, elem.column_idx)) {
categories_[elem.column_idx].emplace(elem.value);
categories_th[elem.column_idx].emplace(elem.value);
} else {
sketches_[elem.column_idx].Push(elem.value, w);
sketches_th[elem.column_idx].Push(elem.value, w);
}
}
}
}
}
}
});
#pragma omp barrier

size_t fidx_block_size = n_features / n_threads_ + (n_features % n_threads_ > 0);
size_t fidx_begin = tid * fidx_block_size;
size_t fidx_end = std::min(fidx_begin + fidx_block_size, n_features);
for (size_t ii = fidx_begin; ii < fidx_end; ++ii) {
for (size_t th = 0; th < n_threads_; ++th) {
if (IsCat(feature_types_, ii)) {
categories_[ii].merge(categories_buff[th * n_features + ii]);
} else {
typename WQSketch::SummaryContainer summary;
sketches_buff[th * n_features + ii].GetSummary(&summary);
sketches_[ii].PushSummary(summary);
}
}
}
});
}
} else {
auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
#pragma omp parallel num_threads(n_threads_)
{
exc.Run([&]() {
auto tid = static_cast<uint32_t>(omp_get_thread_num());
auto const begin = thread_columns_ptr[tid];
auto const end = thread_columns_ptr[tid + 1];

// do not iterate if no columns are assigned to the thread
if (begin < end && end <= n_features) {
for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
auto const &line = batch.GetLine(ridx);
auto w = weights[ridx + base_rowid];
if (is_dense) {
for (size_t ii = begin; ii < end; ii++) {
auto elem = line.GetElement(ii);
if (is_valid(elem)) {
if (IsCat(feature_types_, ii)) {
categories_[ii].emplace(elem.value);
} else {
sketches_[ii].Push(elem.value, w);
}
}
}
} else {
for (size_t i = 0; i < line.Size(); ++i) {
auto const &elem = line.GetElement(i);
if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
if (IsCat(feature_types_, elem.column_idx)) {
categories_[elem.column_idx].emplace(elem.value);
} else {
sketches_[elem.column_idx].Push(elem.value, w);
}
}
}
}
}
}
});
}
}
exc.Rethrow();
}
Expand Down

0 comments on commit 1780f5b

Please sign in to comment.