forked from dmlc/xgboost
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add the basics of partition builder implementation and the related test
- Loading branch information
Dmitry Razdoburdin
committed
Jan 26, 2024
1 parent
65d7bf2
commit 5ea7624
Showing
6 changed files
with
250 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
/*! | ||
* Copyright 2017-2023 XGBoost contributors | ||
*/ | ||
#ifndef PLUGIN_SYCL_COMMON_PARTITION_BUILDER_H_ | ||
#define PLUGIN_SYCL_COMMON_PARTITION_BUILDER_H_ | ||
|
||
#pragma GCC diagnostic push | ||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare" | ||
#pragma GCC diagnostic ignored "-W#pragma-messages" | ||
#include <xgboost/data.h> | ||
#pragma GCC diagnostic pop | ||
#include <xgboost/tree_model.h> | ||
|
||
#include <algorithm> | ||
#include <vector> | ||
#include <utility> | ||
|
||
#pragma GCC diagnostic push | ||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare" | ||
#include "../../../src/common/column_matrix.h" | ||
#pragma GCC diagnostic pop | ||
|
||
#include "../data.h" | ||
|
||
#include <CL/sycl.hpp> | ||
|
||
namespace xgboost { | ||
namespace sycl { | ||
namespace common { | ||
|
||
// The builder is required for samples partition to left and rights children for set of nodes | ||
class PartitionBuilder { | ||
public: | ||
static constexpr size_t maxLocalSums = 256; | ||
static constexpr size_t subgroupSize = 16; | ||
|
||
|
||
template<typename Func> | ||
void Init(::sycl::queue* qu, size_t n_nodes, Func funcNTaks) { | ||
qu_ = qu; | ||
nodes_offsets_.resize(n_nodes+1); | ||
result_rows_.resize(2 * n_nodes); | ||
n_nodes_ = n_nodes; | ||
|
||
|
||
nodes_offsets_[0] = 0; | ||
for (size_t i = 1; i < n_nodes+1; ++i) { | ||
nodes_offsets_[i] = nodes_offsets_[i-1] + funcNTaks(i-1); | ||
} | ||
|
||
if (data_.Size() < nodes_offsets_[n_nodes]) { | ||
data_.Resize(qu, nodes_offsets_[n_nodes]); | ||
} | ||
} | ||
|
||
size_t GetSubgroupSize() { | ||
return subgroupSize; | ||
} | ||
|
||
|
||
size_t GetNLeftElems(int nid) const { | ||
return result_rows_[2 * nid]; | ||
} | ||
|
||
|
||
size_t GetNRightElems(int nid) const { | ||
return result_rows_[2 * nid + 1]; | ||
} | ||
|
||
void SetNLeftElems(int nid, size_t val) { | ||
result_rows_[2 * nid] = val; | ||
} | ||
|
||
|
||
void SetNRightElems(int nid, size_t val) { | ||
result_rows_[2 * nid + 1] = val; | ||
} | ||
|
||
xgboost::common::Span<size_t> GetData(int nid) { | ||
return { data_.Data() + nodes_offsets_[nid], nodes_offsets_[nid + 1] - nodes_offsets_[nid] }; | ||
} | ||
|
||
void MergeToArray(size_t nid, | ||
size_t* data_result, | ||
::sycl::event* event) { | ||
size_t n_nodes_total = GetNLeftElems(nid) + GetNRightElems(nid); | ||
if (n_nodes_total > 0) { | ||
const size_t* data = data_.Data() + nodes_offsets_[nid]; | ||
qu_->memcpy(data_result, data, sizeof(size_t) * n_nodes_total, *event); | ||
} | ||
} | ||
|
||
protected: | ||
std::vector<size_t> nodes_offsets_; | ||
std::vector<size_t> result_rows_; | ||
std::vector<::sycl::event> nodes_events_; | ||
size_t n_nodes_; | ||
|
||
USMVector<size_t, MemoryType::on_device> parts_size_; | ||
USMVector<size_t, MemoryType::on_device> data_; | ||
|
||
::sycl::queue* qu_; | ||
}; | ||
|
||
} // namespace common | ||
} // namespace sycl | ||
} // namespace xgboost | ||
|
||
|
||
#endif // PLUGIN_SYCL_COMMON_PARTITION_BUILDER_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/** | ||
* Copyright 2020-2023 by XGBoost contributors | ||
*/ | ||
#include <gtest/gtest.h> | ||
|
||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "../../../plugin/sycl/common/partition_builder.h" | ||
#include "../../../plugin/sycl/device_manager.h" | ||
#include "../helpers.h" | ||
|
||
namespace xgboost::sycl::common { | ||
|
||
TEST(SyclPartitionBuilder, BasicTest) { | ||
constexpr size_t kNodes = 5; | ||
// Number of rows for each node | ||
std::vector<size_t> rows = { 5, 5, 10, 1, 2 }; | ||
|
||
DeviceManager device_manager; | ||
auto qu = device_manager.GetQueue(DeviceOrd::SyclDefault()); | ||
PartitionBuilder builder; | ||
builder.Init(&qu, kNodes, [&](size_t i) { | ||
return rows[i]; | ||
}); | ||
|
||
// We test here only the basics, thus syntetic partition builder is adopted | ||
// Number of rows to go left for each node. | ||
std::vector<size_t> rows_for_left_node = { 2, 0, 7, 1, 2 }; | ||
|
||
size_t first_row_id = 0; | ||
for(size_t nid = 0; nid < kNodes; ++nid) { | ||
size_t n_rows_nodes = rows[nid]; | ||
|
||
auto rid_buff = builder.GetData(nid); | ||
size_t rid_buff_size = rid_buff.size(); | ||
auto* rid_buff_ptr = rid_buff.data(); | ||
|
||
size_t n_left = rows_for_left_node[nid]; | ||
size_t n_right = rows[nid] - n_left; | ||
|
||
qu.submit([&](::sycl::handler& cgh) { | ||
cgh.parallel_for<>(::sycl::range<1>(n_left), [=](::sycl::id<1> pid) { | ||
int row_id = first_row_id + pid[0]; | ||
rid_buff_ptr[pid[0]] = row_id; | ||
}); | ||
}); | ||
qu.wait(); | ||
first_row_id += n_left; | ||
|
||
// We are storing indexes for the right side in the tail of the array to save some memory | ||
qu.submit([&](::sycl::handler& cgh) { | ||
cgh.parallel_for<>(::sycl::range<1>(n_right), [=](::sycl::id<1> pid) { | ||
int row_id = first_row_id + pid[0]; | ||
rid_buff_ptr[rid_buff_size - pid[0] - 1] = row_id; | ||
}); | ||
}); | ||
qu.wait(); | ||
first_row_id += n_right; | ||
|
||
builder.SetNLeftElems(nid, n_left); | ||
builder.SetNRightElems(nid, n_right); | ||
} | ||
|
||
::sycl::event event; | ||
std::vector<size_t> v(*std::max_element(rows.begin(), rows.end())); | ||
size_t row_id = 0; | ||
for(size_t nid = 0; nid < kNodes; ++nid) { | ||
builder.MergeToArray(nid, v.data(), &event); | ||
qu.wait(); | ||
|
||
// Check that row_id for left side are correct | ||
for(size_t j = 0; j < rows_for_left_node[nid]; ++j) { | ||
ASSERT_EQ(v[j], row_id++); | ||
} | ||
|
||
// Check that row_id for right side are correct | ||
for(size_t j = 0; j < rows[nid] - rows_for_left_node[nid]; ++j) { | ||
ASSERT_EQ(v[rows[nid] - j - 1], row_id++); | ||
} | ||
|
||
// Check that number of left/right rows are correct | ||
size_t n_left = builder.GetNLeftElems(nid); | ||
size_t n_right = builder.GetNRightElems(nid); | ||
ASSERT_EQ(n_left, rows_for_left_node[nid]); | ||
ASSERT_EQ(n_right, (rows[nid] - rows_for_left_node[nid])); | ||
} | ||
} | ||
|
||
} // namespace xgboost::common |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters