Skip to content

Commit

Permalink
refactor UpdatePredictionCache
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Oct 29, 2024
1 parent f863bee commit 9334d5e
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 28 deletions.
4 changes: 2 additions & 2 deletions include/xgboost/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -608,13 +608,13 @@ auto MakeTensorView(Context const *ctx, Order order, common::Span<T, ext> data,

template <typename T, typename... S>
auto MakeTensorView(Context const *ctx, HostDeviceVector<T> *data, S &&...shape) {
auto span = ctx->IsCUDA() ? data->DeviceSpan() : data->HostSpan();
auto span = ctx->IsCPU() ? data->HostSpan() : data->DeviceSpan();
return MakeTensorView(ctx->Device(), span, std::forward<S>(shape)...);
}

template <typename T, typename... S>
auto MakeTensorView(Context const *ctx, HostDeviceVector<T> const *data, S &&...shape) {
auto span = ctx->IsCUDA() ? data->ConstDeviceSpan() : data->ConstHostSpan();
auto span = ctx->IsCPU() ? data->ConstHostSpan() : data->ConstDeviceSpan();
return MakeTensorView(ctx->Device(), span, std::forward<S>(shape)...);
}

Expand Down
27 changes: 4 additions & 23 deletions plugin/sycl/tree/hist_updater.cc
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ template<typename GradientSumT>
bool HistUpdater<GradientSumT>::UpdatePredictionCache(
const DMatrix* data,
linalg::MatrixView<float> out_preds) {
CHECK(out_preds.Device().IsSycl());
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
// conjunction with Update().
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
Expand All @@ -372,23 +373,6 @@ bool HistUpdater<GradientSumT>::UpdatePredictionCache(
builder_monitor_.Start("UpdatePredictionCache");
CHECK_GT(out_preds.Size(), 0U);

const size_t stride = out_preds.Stride(0);
const bool is_first_group = (out_pred_ptr == nullptr);
const size_t gid = out_pred_ptr == nullptr ? 0 : &out_preds(0) - out_pred_ptr;
const bool is_last_group = (gid + 1 == stride);

const int buffer_size = out_preds.Size() *stride;
if (buffer_size == 0) return true;

::sycl::event event;
if (is_first_group) {
out_preds_buf_.ResizeNoCopy(qu_, buffer_size);
out_pred_ptr = &out_preds(0);
event = qu_->memcpy(out_preds_buf_.Data(), out_pred_ptr,
buffer_size * sizeof(bst_float), event);
}
auto* out_preds_buf_ptr = out_preds_buf_.Data();

size_t n_nodes = row_set_collection_.Size();
std::vector<::sycl::event> events(n_nodes);
for (size_t node = 0; node < n_nodes; node++) {
Expand All @@ -408,17 +392,14 @@ bool HistUpdater<GradientSumT>::UpdatePredictionCache(
const size_t num_rows = rowset.Size();

events[node] = qu_->submit([&](::sycl::handler& cgh) {
cgh.depends_on(event);
cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::item<1> pid) {
out_preds_buf_ptr[rid[pid.get_id(0)]*stride + gid] += leaf_value;
size_t row_id = rid[pid.get_id(0)];
float& val = const_cast<float&>(out_preds(row_id));
val += leaf_value;
});
});
}
}
if (is_last_group) {
qu_->memcpy(out_pred_ptr, out_preds_buf_ptr, buffer_size * sizeof(bst_float), events);
out_pred_ptr = nullptr;
}
qu_->wait();

builder_monitor_.Stop("UpdatePredictionCache");
Expand Down
3 changes: 0 additions & 3 deletions plugin/sycl/tree/hist_updater.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,6 @@ class HistUpdater {
std::unique_ptr<HistSynchronizer<GradientSumT>> hist_synchronizer_;
std::unique_ptr<HistRowsAdder<GradientSumT>> hist_rows_adder_;

USMVector<bst_float, MemoryType::on_device> out_preds_buf_;
bst_float* out_pred_ptr = nullptr;

std::vector<GradientPairT> reduce_buffer_;
::sycl::queue* qu_;
};
Expand Down
1 change: 1 addition & 0 deletions src/objective/objective.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ std::string ObjFunction::GetSyclImplementationName(const std::string& name) {
return name + sycl_postfix;
} else {
// Function hasn't specific sycl implementation
LOG(FATAL) << "`" << name << "` doesn't have sycl implementation yet\n";
return name;
}
}
Expand Down

0 comments on commit 9334d5e

Please sign in to comment.