SYCL. Refactor gradient calculation with HostDeviceVector (dmlc#10922)
---------

Co-authored-by: Dmitry Razdoburdin <>
razdoburdin authored Oct 24, 2024
1 parent 9c4f190 commit e8a3ead
Showing 13 changed files with 176 additions and 432 deletions.
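At its core, the refactor replaces the SYCL plugin's device-only USMVector<GradientPair> with the framework-wide HostDeviceVector<GradientPair> along the histogram-building path, so gradients are read through HostDeviceVector::ConstDevicePointer() instead of USMVector::DataConst(). A minimal sketch of the new access pattern (BuildHistExample and DummyKernel are illustrative names, not code from this commit):

    #include <cstddef>

    #include "xgboost/base.h"                // xgboost::GradientPair
    #include "xgboost/host_device_vector.h"  // xgboost::HostDeviceVector

    // Illustrative stand-in for any SYCL kernel launch that consumes gradients.
    void DummyKernel(const xgboost::GradientPair* pgh, std::size_t n);

    void BuildHistExample(const xgboost::HostDeviceVector<xgboost::GradientPair>& gpair) {
      // Before this commit the plugin kept its own on-device USMVector copy of
      // the gradients and read it via DataConst(); now the shared
      // HostDeviceVector hands out its device pointer directly.
      const xgboost::GradientPair* pgh = gpair.ConstDevicePointer();
      DummyKernel(pgh, gpair.Size());
    }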
9 changes: 9 additions & 0 deletions include/xgboost/span.h
@@ -112,7 +112,16 @@ namespace xgboost::common {

#else

+#if defined(__SYCL_DEVICE_ONLY__)
+
+// SYCL doesn't support termination
+#define SYCL_KERNEL_CHECK(cond)
+
+#define KERNEL_CHECK(cond) SYCL_KERNEL_CHECK(cond)
+
+#else // defined(__SYCL_DEVICE_ONLY__)
 #define KERNEL_CHECK(cond) (XGBOOST_EXPECT((cond), true) ? static_cast<void>(0) : std::terminate())
+#endif // defined(__SYCL_DEVICE_ONLY__)

#define SPAN_CHECK(cond) KERNEL_CHECK(cond)

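The new branch makes the same KERNEL_CHECK call site compile to a no-op in SYCL device code, which cannot call std::terminate, while the host pass keeps its terminating behavior. A minimal sketch of the effect (SpanAccessExample is an illustrative name; assumes xgboost/span.h is included):

    #include <cstddef>

    #include "xgboost/span.h"  // KERNEL_CHECK / SYCL_KERNEL_CHECK

    void SpanAccessExample(std::size_t idx, std::size_t size) {
      // SYCL device pass (__SYCL_DEVICE_ONLY__): expands to
      // SYCL_KERNEL_CHECK(idx < size), i.e. to nothing.
      // Host pass: expands to
      //   (XGBOOST_EXPECT((idx < size), true) ? static_cast<void>(0) : std::terminate())
      KERNEL_CHECK(idx < size);
    }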
34 changes: 17 additions & 17 deletions plugin/sycl/common/hist_util.cc
@@ -118,7 +118,7 @@ inline auto GetBlocksParameters(::sycl::queue* qu, size_t size, size_t max_nbloc
// Kernel using a buffer
template<typename FPType, typename BinIdxType, bool isDense>
::sycl::event BuildHistKernel(::sycl::queue* qu,
-                              const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+                              const HostDeviceVector<GradientPair>& gpair,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat,
GHistRow<FPType, MemoryType::on_device>* hist,
@@ -128,7 +128,7 @@ ::sycl::event BuildHistKernel(::sycl::queue* qu,
const size_t size = row_indices.Size();
const size_t* rid = row_indices.begin;
const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
-  const auto* pgh = gpair_device.DataConst();
+  const auto* pgh = gpair.ConstDevicePointer();
const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
const uint32_t* offsets = gmat.cut.cut_ptrs_.ConstDevicePointer();
const size_t nbins = gmat.nbins;
@@ -199,7 +199,7 @@ ::sycl::event BuildHistKernel(::sycl::queue* qu,
// Kernel using atomics
template<typename FPType, typename BinIdxType, bool isDense>
::sycl::event BuildHistKernel(::sycl::queue* qu,
-                              const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+                              const HostDeviceVector<GradientPair>& gpair,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat,
GHistRow<FPType, MemoryType::on_device>* hist,
@@ -208,7 +208,7 @@ ::sycl::event BuildHistKernel(::sycl::queue* qu,
const size_t* rid = row_indices.begin;
const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
const GradientPair::ValueT* pgh =
-      reinterpret_cast<const GradientPair::ValueT*>(gpair_device.DataConst());
+      reinterpret_cast<const GradientPair::ValueT*>(gpair.ConstDevicePointer());
const BinIdxType* gradient_index = gmat.index.data<BinIdxType>();
const uint32_t* offsets = gmat.cut.cut_ptrs_.ConstDevicePointer();
FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
@@ -254,7 +254,7 @@ ::sycl::event BuildHistKernel(::sycl::queue* qu,
template<typename FPType, typename BinIdxType>
::sycl::event BuildHistDispatchKernel(
::sycl::queue* qu,
-    const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+    const HostDeviceVector<GradientPair>& gpair,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat,
GHistRow<FPType, MemoryType::on_device>* hist,
@@ -273,28 +273,28 @@ ::sycl::event BuildHistDispatchKernel(
use_atomic = use_atomic || force_atomic_use;
if (!use_atomic) {
if (isDense) {
-      return BuildHistKernel<FPType, BinIdxType, true>(qu, gpair_device, row_indices,
+      return BuildHistKernel<FPType, BinIdxType, true>(qu, gpair, row_indices,
gmat, hist, hist_buffer,
events_priv);
} else {
-      return BuildHistKernel<FPType, uint32_t, false>(qu, gpair_device, row_indices,
+      return BuildHistKernel<FPType, uint32_t, false>(qu, gpair, row_indices,
gmat, hist, hist_buffer,
events_priv);
}
} else {
if (isDense) {
-      return BuildHistKernel<FPType, BinIdxType, true>(qu, gpair_device, row_indices,
+      return BuildHistKernel<FPType, BinIdxType, true>(qu, gpair, row_indices,
gmat, hist, events_priv);
} else {
-      return BuildHistKernel<FPType, uint32_t, false>(qu, gpair_device, row_indices,
+      return BuildHistKernel<FPType, uint32_t, false>(qu, gpair, row_indices,
gmat, hist, events_priv);
}
}
}

template<typename FPType>
::sycl::event BuildHistKernel(::sycl::queue* qu,
-                              const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+                              const HostDeviceVector<GradientPair>& gpair,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat, const bool isDense,
GHistRow<FPType, MemoryType::on_device>* hist,
@@ -304,17 +304,17 @@ ::sycl::event BuildHistKernel(::sycl::queue* qu,
const bool is_dense = isDense;
switch (gmat.index.GetBinTypeSize()) {
case BinTypeSize::kUint8BinsTypeSize:
-      return BuildHistDispatchKernel<FPType, uint8_t>(qu, gpair_device, row_indices,
+      return BuildHistDispatchKernel<FPType, uint8_t>(qu, gpair, row_indices,
gmat, hist, is_dense, hist_buffer,
event_priv, force_atomic_use);
break;
case BinTypeSize::kUint16BinsTypeSize:
-      return BuildHistDispatchKernel<FPType, uint16_t>(qu, gpair_device, row_indices,
+      return BuildHistDispatchKernel<FPType, uint16_t>(qu, gpair, row_indices,
gmat, hist, is_dense, hist_buffer,
event_priv, force_atomic_use);
break;
case BinTypeSize::kUint32BinsTypeSize:
-      return BuildHistDispatchKernel<FPType, uint32_t>(qu, gpair_device, row_indices,
+      return BuildHistDispatchKernel<FPType, uint32_t>(qu, gpair, row_indices,
gmat, hist, is_dense, hist_buffer,
event_priv, force_atomic_use);
break;
@@ -325,22 +325,22 @@ ::sycl::event BuildHistKernel(::sycl::queue* qu,

template <typename GradientSumT>
::sycl::event GHistBuilder<GradientSumT>::BuildHist(
-    const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+    const HostDeviceVector<GradientPair>& gpair,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix &gmat,
GHistRowT<MemoryType::on_device>* hist,
bool isDense,
GHistRowT<MemoryType::on_device>* hist_buffer,
::sycl::event event_priv,
bool force_atomic_use) {
-  return BuildHistKernel<GradientSumT>(qu_, gpair_device, row_indices, gmat,
+  return BuildHistKernel<GradientSumT>(qu_, gpair, row_indices, gmat,
isDense, hist, hist_buffer, event_priv,
force_atomic_use);
}

template
::sycl::event GHistBuilder<float>::BuildHist(
-    const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+    const HostDeviceVector<GradientPair>& gpair,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat,
GHistRow<float, MemoryType::on_device>* hist,
@@ -350,7 +350,7 @@ ::sycl::event GHistBuilder<float>::BuildHist(
bool force_atomic_use);
template
::sycl::event GHistBuilder<double>::BuildHist(
-    const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+    const HostDeviceVector<GradientPair>& gpair,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat,
GHistRow<double, MemoryType::on_device>* hist,
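With the signature change above, callers hand the HostDeviceVector straight to the builder. A hedged sketch of a call site (UpdateHistExample and the namespace alias are assumptions based on the plugin layout, not code from this patch):

    namespace sc = xgboost::sycl::common;  // assumed plugin namespace

    ::sycl::event UpdateHistExample(
        sc::GHistBuilder<float>* builder,
        const xgboost::HostDeviceVector<xgboost::GradientPair>& gpair,
        const sc::RowSetCollection::Elem& row_indices,
        const sc::GHistIndexMatrix& gmat,
        sc::GHistRow<float, xgboost::sycl::MemoryType::on_device>* hist,
        sc::GHistRow<float, xgboost::sycl::MemoryType::on_device>* hist_buffer) {
      // The builder now reads gradients via gpair.ConstDevicePointer()
      // internally, so no USMVector mirror has to be prepared here.
      return builder->BuildHist(gpair, row_indices, gmat, hist, /*isDense=*/true,
                                hist_buffer, ::sycl::event{},
                                /*force_atomic_use=*/false);
    }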
2 changes: 1 addition & 1 deletion plugin/sycl/common/hist_util.h
@@ -155,7 +155,7 @@ class GHistBuilder {
GHistBuilder(::sycl::queue* qu, uint32_t nbins) : qu_{qu}, nbins_{nbins} {}

// Construct a histogram via histogram aggregation
-  ::sycl::event BuildHist(const USMVector<GradientPair, MemoryType::on_device>& gpair_device,
+  ::sycl::event BuildHist(const HostDeviceVector<GradientPair>& gpair,
const RowSetCollection::Elem& row_indices,
const GHistIndexMatrix& gmat,
GHistRowT<MemoryType::on_device>* HistCollection,
194 changes: 0 additions & 194 deletions plugin/sycl/common/linalg_op.h
@@ -40,200 +40,6 @@ ::sycl::event GroupWiseKernel(::sycl::queue* qu, int* flag_ptr,
});
return event;
}

struct Argument {
template <typename T>
operator T&&() const;
};

template <typename Fn, typename Is, typename = void>
struct ArgumentsPassedImpl
: std::false_type {};

template <typename Fn, size_t ...Is>
struct ArgumentsPassedImpl<Fn, std::index_sequence<Is...>,
decltype(std::declval<Fn>()(((void)Is, Argument{})...), void())>
: std::true_type {};

template <typename Fn, size_t N>
struct ArgumentsPassed : ArgumentsPassedImpl<Fn, std::make_index_sequence<N>> {};

template <typename OutputDType, typename InputDType,
size_t BatchSize, size_t MaxNumInputs>
class BatchProcessingHelper {
public:
static constexpr size_t kBatchSize = BatchSize;
using InputType = HostDeviceVector<InputDType>;
using OutputType = HostDeviceVector<OutputDType>;

private:
template <size_t NumInput = 0>
void Host2Buffers(InputDType* in_buffer_ptr, const InputType& input) {
/*
     * Some inputs may have fewer than one sample per output symbol.
*/
const size_t sub_sample_rate = ndata_ * sample_rates_[NumInput+1] / input.Size();
const size_t n_samples = batch_size_ * sample_rates_[NumInput+1] / sub_sample_rate;

const InputDType* in_host_ptr = input.HostPointer() +
batch_begin_ * sample_rates_[NumInput+1] / sub_sample_rate;

events_[NumInput] =
qu_->memcpy(in_buffer_ptr, in_host_ptr, n_samples * sizeof(InputDType),
events_[MaxNumInputs - 2]);
}

template <size_t NumInput = 0, class... InputTypes>
void Host2Buffers(InputDType* in_buffer_ptr, const InputType& input,
const InputTypes&... other_inputs) {
// Make copy for the first input in the list
Host2Buffers<NumInput>(in_buffer_ptr, input);
    // Recursive call for the remaining inputs
InputDType* next_input = in_buffer_.Data() + in_buff_offsets_[NumInput + 1];
Host2Buffers<NumInput+1>(next_input, other_inputs...);
}

void Buffers2Host(OutputType* output) {
const size_t n_samples = batch_size_ * sample_rates_[0];
OutputDType* out_host_ptr = output->HostPointer() + batch_begin_* sample_rates_[0];
events_[MaxNumInputs - 1] =
qu_->memcpy(out_host_ptr, out_buffer_.DataConst(), n_samples * sizeof(OutputDType),
events_[MaxNumInputs - 2]);
}

void Buffers2Host(InputType* output) {
const size_t n_samples = batch_size_ * sample_rates_[1];
InputDType* out_host_ptr = output->HostPointer() + batch_begin_* sample_rates_[1];
events_[MaxNumInputs - 1] =
qu_->memcpy(out_host_ptr, in_buffer_.DataConst(), n_samples * sizeof(InputDType),
events_[MaxNumInputs - 2]);
}

template <size_t NumInputs = 1, typename Fn, class... InputTypes>
void Call(Fn &&fn, const InputDType* input, const InputTypes*... other_inputs) {
    static_assert(NumInputs <= MaxNumInputs,
                  "Too many arguments in the passed function");
    /* The passed lambda may take fewer inputs than MaxNumInputs;
     * only the required number of arguments needs to be passed.
     */
// 1 for events, 1 for batch_size, 1 for output
if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1 + 1>::value) {
events_[MaxNumInputs - 2] = fn(events_, batch_size_,
out_buffer_.Data(), input, other_inputs...);
} else {
const InputDType* next_input = in_buffer_.DataConst() +
in_buff_offsets_[MaxNumInputs - 1 - NumInputs];
Call<NumInputs+1>(std::forward<Fn>(fn), next_input, input, other_inputs...);
}
}

template <size_t NumInputs = 1, typename Fn, class... InputTypes>
void Call(Fn &&fn, InputDType* io, const InputDType* input, const InputTypes*... other_inputs) {
    static_assert(NumInputs <= MaxNumInputs,
                  "Too many arguments in the passed function");
if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1>::value) {
events_[MaxNumInputs - 2] = fn(events_, batch_size_,
io, input, other_inputs...);
} else {
const InputDType* next_input = in_buffer_.DataConst() +
in_buff_offsets_[MaxNumInputs - NumInputs];
Call<NumInputs+1>(std::forward<Fn>(fn), io, next_input, input, other_inputs...);
}
}

template <size_t NumInputs = 1, typename Fn>
void Call(Fn &&fn, InputDType* io) {
    static_assert(NumInputs <= MaxNumInputs,
                  "Too many arguments in the passed function");
if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1>::value) {
events_[MaxNumInputs - 2] = fn(events_, batch_size_, io);
} else {
const InputDType* next_input = in_buffer_.DataConst() +
in_buff_offsets_[MaxNumInputs - 1];
Call<NumInputs+1>(std::forward<Fn>(fn), io, next_input);
}
}

public:
BatchProcessingHelper() = default;

  // The first element of sample_rate always corresponds to the output sample rate
void InitBuffers(::sycl::queue* qu, const std::vector<int>& sample_rate) {
assert(sample_rate.size() == MaxNumInputs + 1);
sample_rates_ = sample_rate;
qu_ = qu;
events_.resize(MaxNumInputs + 2);
out_buffer_.Resize(qu, kBatchSize * sample_rate.front());

in_buff_offsets_[0] = 0;
for (size_t i = 1; i < MaxNumInputs; ++i) {
in_buff_offsets_[i] = in_buff_offsets_[i - 1] + kBatchSize * sample_rate[i];
}
const size_t in_buff_size = in_buff_offsets_.back() + kBatchSize * sample_rate.back();
in_buffer_.Resize(qu, in_buff_size);
}

/*
   * Batch-wise processing on the SYCL device
* output = fn(inputs)
*/
template <typename Fn, class... InputTypes>
void Calculate(Fn &&fn, OutputType* output, const InputTypes&... inputs) {
ndata_ = output->Size() / sample_rates_.front();
const size_t nBatch = ndata_ / kBatchSize + (ndata_ % kBatchSize > 0);
for (size_t batch = 0; batch < nBatch; ++batch) {
batch_begin_ = batch * kBatchSize;
batch_end_ = (batch == nBatch - 1) ? ndata_ : batch_begin_ + kBatchSize;
batch_size_ = batch_end_ - batch_begin_;

// Iteratively copy all inputs to device buffers
Host2Buffers(in_buffer_.Data(), inputs...);
// Pack buffers and call function
// We shift input pointer to keep the same order of inputs after packing
Call(std::forward<Fn>(fn), in_buffer_.DataConst() + in_buff_offsets_.back());
// Copy results to host
Buffers2Host(output);
}
}

/*
   * Batch-wise processing on the SYCL device
* input = fn(input, other_inputs)
*/
template <typename Fn, class... InputTypes>
void Calculate(Fn &&fn, InputType* input, const InputTypes&... other_inputs) {
ndata_ = input->Size();
const size_t nBatch = ndata_ / kBatchSize + (ndata_ % kBatchSize > 0);
for (size_t batch = 0; batch < nBatch; ++batch) {
batch_begin_ = batch * kBatchSize;
batch_end_ = (batch == nBatch - 1) ? ndata_ : batch_begin_ + kBatchSize;
batch_size_ = batch_end_ - batch_begin_;

// Iteratively copy all inputs to device buffers.
      // inputs are passed by const reference
Host2Buffers(in_buffer_.Data(), *(input), other_inputs...);
// Pack buffers and call function
// We shift input pointer to keep the same order of inputs after packing
Call(std::forward<Fn>(fn), in_buffer_.Data());
// Copy results to host
Buffers2Host(input);
}
}

private:
std::array<int, MaxNumInputs> in_buff_offsets_;
std::vector<int> sample_rates_;
size_t ndata_;
size_t batch_begin_;
size_t batch_end_;
// is not equal to kBatchSize for the last batch
size_t batch_size_;
::sycl::queue* qu_;
std::vector<::sycl::event> events_;
USMVector<InputDType, MemoryType::on_device> in_buffer_;
USMVector<OutputDType, MemoryType::on_device> out_buffer_;
};

} // namespace linalg
} // namespace sycl
} // namespace xgboost
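For reference, the removed BatchProcessingHelper dispatched on callable arity with the Argument/ArgumentsPassed detection idiom deleted above. A self-contained sketch of that idiom in plain C++17 (no XGBoost dependencies):

    #include <cstddef>
    #include <type_traits>
    #include <utility>

    // Converts to any type; used only in unevaluated contexts, never defined.
    struct Argument {
      template <typename T>
      operator T&&() const;
    };

    // Primary template: Fn is not callable with N arguments.
    template <typename Fn, typename Is, typename = void>
    struct ArgumentsPassedImpl : std::false_type {};

    // Chosen via SFINAE when Fn(Argument, ..., Argument) is well-formed.
    template <typename Fn, std::size_t... Is>
    struct ArgumentsPassedImpl<Fn, std::index_sequence<Is...>,
        decltype(std::declval<Fn>()(((void)Is, Argument{})...), void())>
        : std::true_type {};

    template <typename Fn, std::size_t N>
    struct ArgumentsPassed : ArgumentsPassedImpl<Fn, std::make_index_sequence<N>> {};

    // Usage: branch on how many inputs a user-supplied callback accepts.
    inline auto fn2 = [](int, float) {};
    static_assert(ArgumentsPassed<decltype(fn2), 2>::value, "callable with 2 args");
    static_assert(!ArgumentsPassed<decltype(fn2), 3>::value, "not callable with 3");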