Skip to content

Commit

Permalink
refactor device matrix inintialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Jan 30, 2024
1 parent 025b286 commit 5fcd136
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 18 deletions.
29 changes: 20 additions & 9 deletions plugin/sycl/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ class USMVector {
struct DeviceMatrix {
DMatrix* p_mat; // Pointer to the original matrix on the host
::sycl::queue qu_;
USMVector<size_t> row_ptr;
USMVector<Entry> data;
USMVector<size_t, MemoryType::on_device> row_ptr;
USMVector<Entry, MemoryType::on_device> data;
size_t total_offset;

DeviceMatrix(::sycl::queue qu, DMatrix* dmat) : p_mat(dmat), qu_(qu) {
Expand All @@ -231,6 +231,7 @@ struct DeviceMatrix {
}

row_ptr.Resize(&qu_, num_row + 1);
size_t* rows = row_ptr.Data();
data.Resize(&qu_, num_nonzero);

size_t data_offset = 0;
Expand All @@ -239,18 +240,28 @@ struct DeviceMatrix {
const auto& offset_vec = batch.offset.HostVector();
size_t batch_size = batch.Size();
if (batch_size > 0) {
std::copy(offset_vec.data(), offset_vec.data() + batch_size,
row_ptr.Data() + batch.base_rowid);
auto event = qu.memcpy(row_ptr.Data() + batch.base_rowid, offset_vec.data(), sizeof(size_t) * batch_size);
if (batch.base_rowid > 0) {
for (size_t i = 0; i < batch_size; i++)
row_ptr[i + batch.base_rowid] += batch.base_rowid;
const auto base_rowid = batch.base_rowid;
qu.submit([&](::sycl::handler& cgh) {
cgh.depends_on(event);
cgh.parallel_for<>(::sycl::range<1>(batch_size), [=](::sycl::id<1> pid) {
int row_id = pid[0];
rows[row_id] += base_rowid;
});
});
}
std::copy(data_vec.data(), data_vec.data() + offset_vec[batch_size],
data.Data() + data_offset);
qu.memcpy(data.Data() + data_offset, data_vec.data(), sizeof(Entry) * offset_vec[batch_size]);
data_offset += offset_vec[batch_size];
}
}
row_ptr[num_row] = data_offset;
qu.submit([&](::sycl::handler& cgh) {
cgh.single_task<>([=] {
rows[num_row] = data_offset;
});
});
qu.wait();

total_offset = data_offset;
}

Expand Down
6 changes: 0 additions & 6 deletions plugin/sycl/data/gradient_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,6 @@ void GHistIndexMatrix::Init(::sycl::queue qu,
const bool isDense = p_fmat_device.p_mat->IsDense();
this->isDense_ = isDense;

row_ptr = std::vector<size_t>(p_fmat_device.row_ptr.Begin(), p_fmat_device.row_ptr.End());
row_ptr_device = p_fmat_device.row_ptr;

index.setQueue(qu);

row_stride = 0;
Expand All @@ -151,9 +148,6 @@ void GHistIndexMatrix::Init(::sycl::queue qu,
index.ResizeOffset(n_offsets);
offsets = index.Offset();
qu.memcpy(offsets, cut.Ptrs().data(), sizeof(uint32_t) * n_offsets).wait_and_throw();
// for (size_t i = 0; i < n_offsets; ++i) {
// offsets[i] = cut.Ptrs()[i];
// }
}

if (isDense) {
Expand Down
2 changes: 0 additions & 2 deletions plugin/sycl/data/gradient_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,6 @@ struct Index {
*/
struct GHistIndexMatrix {
/*! \brief row pointer to rows by element position */
std::vector<size_t> row_ptr;
USMVector<size_t> row_ptr_device;
/*! \brief The index data */
Index index;
/*! \brief hit count of each index */
Expand Down
4 changes: 3 additions & 1 deletion plugin/sycl/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,10 @@ void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) {
if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
updater_monitor_.Start("GmatInitialization");
updater_monitor_.Start("DeviceMatrixInitialization");
sycl::DeviceMatrix dmat_device(qu_, dmat);
updater_monitor_.Stop("DeviceMatrixInitialization");
updater_monitor_.Start("GmatInitialization");
gmat_.Init(qu_, ctx_, dmat_device, static_cast<uint32_t>(param_.max_bin));
updater_monitor_.Stop("GmatInitialization");
is_gmat_initialized_ = true;
Expand Down

0 comments on commit 5fcd136

Please sign in to comment.