Skip to content

Commit

Permalink
Implement column sampler in CUDA. (dmlc#9785)
Browse files Browse the repository at this point in the history
- CUDA implementation.
- Extract the broadcasting logic, we will need the context parameter after revamping the collective implementation.
- Some changes to the event loop for fixing a deadlock in CI.
- Move argsort into algorithms.cuh, add support for cuda stream.
  • Loading branch information
trivialfis authored Nov 16, 2023
1 parent 178cfe7 commit fedd967
Show file tree
Hide file tree
Showing 20 changed files with 447 additions and 232 deletions.
25 changes: 19 additions & 6 deletions src/collective/loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,14 @@ void Loop::Process() {
break;
}

auto unlock_notify = [&](bool is_blocking) {
auto unlock_notify = [&](bool is_blocking, bool stop) {
if (!is_blocking) {
return;
std::lock_guard guard{mu_};
stop_ = stop;
} else {
stop_ = stop;
lock.unlock();
}
lock.unlock();
cv_.notify_one();
};

Expand All @@ -145,13 +148,14 @@ void Loop::Process() {
auto rc = this->EmptyQueue(&qcopy);
// Handle error
if (!rc.OK()) {
unlock_notify(is_blocking, true);
std::lock_guard<std::mutex> guard{rc_lock_};
this->rc_ = std::move(rc);
unlock_notify(is_blocking);
return;
}

CHECK(qcopy.empty());
unlock_notify(is_blocking);
unlock_notify(is_blocking, false);
}
}

Expand All @@ -170,12 +174,21 @@ Result Loop::Stop() {
}

[[nodiscard]] Result Loop::Block() {
{
std::lock_guard<std::mutex> guard{rc_lock_};
if (!rc_.OK()) {
return std::move(rc_);
}
}
this->Submit(Op{Op::kBlock});
{
std::unique_lock lock{mu_};
cv_.wait(lock, [this] { return (this->queue_.empty()) || stop_; });
}
return std::move(rc_);
{
std::lock_guard<std::mutex> lock{rc_lock_};
return std::move(rc_);
}
}

Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
Expand Down
3 changes: 3 additions & 0 deletions src/collective/loop.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ class Loop {
std::mutex mu_;
std::queue<Op> queue_;
std::chrono::seconds timeout_;

Result rc_;
std::mutex rc_lock_; // lock for transferring error info.

bool stop_{false};
std::exception_ptr curr_exce_{nullptr};
common::Monitor mutable timer_;
Expand Down
91 changes: 81 additions & 10 deletions src/common/algorithm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
#include "xgboost/logging.h" // CHECK
#include "xgboost/span.h" // Span,byte

namespace xgboost {
namespace common {
namespace xgboost::common {
namespace detail {
// Wrapper around cub sort to define is_decending
template <bool IS_DESCENDING, typename KeyT, typename BeginOffsetIteratorT,
Expand Down Expand Up @@ -127,29 +126,31 @@ inline void SegmentedSortKeys(Context const *ctx, Span<V const> group_ptr,
template <bool accending, bool per_seg_index, typename U, typename V, typename IdxT>
void SegmentedArgSort(Context const *ctx, Span<U> values, Span<V> group_ptr,
Span<IdxT> sorted_idx) {
auto cuctx = ctx->CUDACtx();
CHECK_GE(group_ptr.size(), 1ul);
std::size_t n_groups = group_ptr.size() - 1;
std::size_t bytes = 0;
if (per_seg_index) {
SegmentedSequence(ctx, group_ptr, sorted_idx);
} else {
dh::Iota(sorted_idx);
dh::Iota(sorted_idx, cuctx->Stream());
}
dh::TemporaryArray<std::remove_const_t<U>> values_out(values.size());
dh::TemporaryArray<std::remove_const_t<IdxT>> sorted_idx_out(sorted_idx.size());

detail::DeviceSegmentedRadixSortPair<!accending>(
nullptr, bytes, values.data(), values_out.data().get(), sorted_idx.data(),
sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
group_ptr.data() + 1, ctx->CUDACtx()->Stream());
group_ptr.data() + 1, cuctx->Stream());
dh::TemporaryArray<byte> temp_storage(bytes);
detail::DeviceSegmentedRadixSortPair<!accending>(
temp_storage.data().get(), bytes, values.data(), values_out.data().get(), sorted_idx.data(),
sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
group_ptr.data() + 1, ctx->CUDACtx()->Stream());
group_ptr.data() + 1, cuctx->Stream());

dh::safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice,
cuctx->Stream()));
}

/**
Expand All @@ -159,11 +160,12 @@ void SegmentedArgSort(Context const *ctx, Span<U> values, Span<V> group_ptr,
template <typename SegIt, typename ValIt>
void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, ValIt val_begin,
ValIt val_end, dh::device_vector<std::size_t> *p_sorted_idx) {
auto cuctx = ctx->CUDACtx();
using Tup = thrust::tuple<std::int32_t, float>;
auto &sorted_idx = *p_sorted_idx;
std::size_t n = std::distance(val_begin, val_end);
sorted_idx.resize(n);
dh::Iota(dh::ToSpan(sorted_idx));
dh::Iota(dh::ToSpan(sorted_idx), cuctx->Stream());
dh::device_vector<Tup> keys(sorted_idx.size());
auto key_it = dh::MakeTransformIterator<Tup>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) -> Tup {
Expand All @@ -177,14 +179,83 @@ void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, V
return thrust::make_tuple(seg_idx, residue);
});
thrust::copy(ctx->CUDACtx()->CTP(), key_it, key_it + keys.size(), keys.begin());
thrust::stable_sort_by_key(ctx->CUDACtx()->TP(), keys.begin(), keys.end(), sorted_idx.begin(),
thrust::stable_sort_by_key(cuctx->TP(), keys.begin(), keys.end(), sorted_idx.begin(),
[=] XGBOOST_DEVICE(Tup const &l, Tup const &r) {
if (thrust::get<0>(l) != thrust::get<0>(r)) {
return thrust::get<0>(l) < thrust::get<0>(r); // segment index
}
return thrust::get<1>(l) < thrust::get<1>(r); // residue
});
}
} // namespace common
} // namespace xgboost

template <bool accending, typename IdxT, typename U>
void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
xgboost::common::Span<IdxT> sorted_idx) {
std::size_t bytes = 0;
auto cuctx = ctx->CUDACtx();
dh::Iota(sorted_idx, cuctx->Stream());

using KeyT = typename decltype(keys)::value_type;
using ValueT = std::remove_const_t<IdxT>;

dh::TemporaryArray<KeyT> out(keys.size());
cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()), out.data().get());
dh::TemporaryArray<IdxT> sorted_idx_out(sorted_idx.size());
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()),
sorted_idx_out.data().get());

// track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
using OffsetT = std::conditional_t<!dh::BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
if (accending) {
void *d_temp_storage = nullptr;
#if THRUST_MAJOR_VERSION >= 2
dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
cuctx->Stream())));
#else
dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
nullptr, false)));
#endif
dh::TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
#if THRUST_MAJOR_VERSION >= 2
dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
cuctx->Stream())));
#else
dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
nullptr, false)));
#endif
} else {
void *d_temp_storage = nullptr;
#if THRUST_MAJOR_VERSION >= 2
dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
cuctx->Stream())));
#else
dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
nullptr, false)));
#endif
dh::TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
#if THRUST_MAJOR_VERSION >= 2
dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
cuctx->Stream())));
#else
dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
nullptr, false)));
#endif
}

dh::safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice,
cuctx->Stream()));
}
} // namespace xgboost::common
#endif // XGBOOST_COMMON_ALGORITHM_CUH_
82 changes: 12 additions & 70 deletions src/common/device_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ inline void LaunchN(size_t n, L lambda) {
}

template <typename Container>
void Iota(Container array) {
LaunchN(array.size(), [=] __device__(size_t i) { array[i] = i; });
void Iota(Container array, cudaStream_t stream) {
LaunchN(array.size(), stream, [=] __device__(size_t i) { array[i] = i; });
}

namespace detail {
Expand Down Expand Up @@ -597,6 +597,16 @@ class DoubleBuffer {
T *Other() { return buff.Alternate(); }
};

template <typename T>
xgboost::common::Span<T> LazyResize(xgboost::Context const *ctx,
xgboost::HostDeviceVector<T> *buffer, std::size_t n) {
buffer->SetDevice(ctx->Device());
if (buffer->Size() < n) {
buffer->Resize(n);
}
return buffer->DeviceSpan().subspan(0, n);
}

/**
* \brief Copies device span to std::vector.
*
Expand Down Expand Up @@ -1060,74 +1070,6 @@ void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items)
InclusiveScan(d_in, d_out, cub::Sum(), num_items);
}

template <bool accending, typename IdxT, typename U>
void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_idx) {
size_t bytes = 0;
Iota(sorted_idx);

using KeyT = typename decltype(keys)::value_type;
using ValueT = std::remove_const_t<IdxT>;

TemporaryArray<KeyT> out(keys.size());
cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()),
out.data().get());
TemporaryArray<IdxT> sorted_idx_out(sorted_idx.size());
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()),
sorted_idx_out.data().get());

// track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
using OffsetT = std::conditional_t<!BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
if (accending) {
void *d_temp_storage = nullptr;
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr)));
#else
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
#endif
TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr)));
#else
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
#endif
} else {
void *d_temp_storage = nullptr;
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr)));
#else
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
#endif
TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr)));
#else
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
#endif
}

safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
}

class CUDAStreamView;

class CUDAEvent {
Expand Down
43 changes: 30 additions & 13 deletions src/common/random.cc
Original file line number Diff line number Diff line change
@@ -1,32 +1,50 @@
/*!
* Copyright 2020 by XGBoost Contributors
* \file random.cc
/**
* Copyright 2020-2023, XGBoost Contributors
*/
#include "random.h"

namespace xgboost {
namespace common {
#include <algorithm> // for sort, max, copy
#include <memory> // for shared_ptr

#include "xgboost/host_device_vector.h" // for HostDeviceVector

namespace xgboost::common {
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample) {
if (colsample == 1.0f) {
return p_features;
}

int n = std::max(1, static_cast<int>(colsample * p_features->Size()));
auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();

if (ctx_->IsCUDA()) {
#if defined(XGBOOST_USE_CUDA)
cuda_impl::SampleFeature(ctx_, n, p_features, p_new_features, this->feature_weights_,
&this->weight_buffer_, &this->idx_buffer_, &rng_);
return p_new_features;
#else
AssertGPUSupport();
return nullptr;
#endif // defined(XGBOOST_USE_CUDA)
}

const auto &features = p_features->HostVector();
CHECK_GT(features.size(), 0);

int n = std::max(1, static_cast<int>(colsample * features.size()));
auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
auto &new_features = *p_new_features;

if (feature_weights_.size() != 0) {
if (!feature_weights_.Empty()) {
auto const &h_features = p_features->HostVector();
std::vector<float> weights(h_features.size());
auto const &h_feature_weight = feature_weights_.ConstHostVector();
auto &weight = this->weight_buffer_.HostVector();
weight.resize(h_features.size());
for (size_t i = 0; i < h_features.size(); ++i) {
weights[i] = feature_weights_[h_features[i]];
weight[i] = h_feature_weight[h_features[i]];
}
CHECK(ctx_);
new_features.HostVector() =
WeightedSamplingWithoutReplacement(ctx_, p_features->HostVector(), weights, n);
WeightedSamplingWithoutReplacement(ctx_, p_features->HostVector(), weight, n);
} else {
new_features.Resize(features.size());
std::copy(features.begin(), features.end(), new_features.HostVector().begin());
Expand All @@ -36,5 +54,4 @@ std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
std::sort(new_features.HostVector().begin(), new_features.HostVector().end());
return p_new_features;
}
} // namespace common
} // namespace xgboost
} // namespace xgboost::common
Loading

0 comments on commit fedd967

Please sign in to comment.