Skip to content

Commit

Permalink
reuse FindSplitConditons from cpu branch
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Jul 22, 2024
1 parent 916d4a4 commit d1d1f3c
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 38 deletions.
35 changes: 5 additions & 30 deletions plugin/sycl/tree/hist_updater.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <limits>
#include <vector>

#include "../../src/tree/common_row_partitioner.h"

#include "../common/hist_util.h"
#include "../../src/collective/allreduce.h"

Expand Down Expand Up @@ -786,34 +788,6 @@ void HistUpdater<GradientSumT>::EnumerateSplit(
best.SplitIndex() == total_split_index) p_best->Update(best);
}

template <typename GradientSumT>
void HistUpdater<GradientSumT>::FindSplitConditions(
const std::vector<ExpandEntry>& nodes,
const RegTree& tree,
const common::GHistIndexMatrix& gmat,
std::vector<int32_t>* split_conditions) {
const size_t n_nodes = nodes.size();
split_conditions->resize(n_nodes);

for (size_t i = 0; i < nodes.size(); ++i) {
const int32_t nid = nodes[i].nid;
const bst_uint fid = tree[nid].SplitIndex();
const bst_float split_pt = tree[nid].SplitCond();
const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
int32_t split_cond = -1;
// convert floating-point split_pt into corresponding bin_id
// split_cond = -1 indicates that split_pt is less than all known cut points
CHECK_LT(upper_bound,
static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
for (uint32_t i = lower_bound; i < upper_bound; ++i) {
if (split_pt == gmat.cut.Values()[i]) {
split_cond = static_cast<int32_t>(i);
}
}
(*split_conditions)[i] = split_cond;
}
}
template <typename GradientSumT>
void HistUpdater<GradientSumT>::AddSplitsToRowSet(
const std::vector<ExpandEntry>& nodes,
Expand All @@ -835,11 +809,12 @@ void HistUpdater<GradientSumT>::ApplySplit(
const common::GHistIndexMatrix& gmat,
const common::HistCollection<GradientSumT, MemoryType::on_device>& hist,
RegTree* p_tree) {
using CommonRowPartitioner = xgboost::tree::CommonRowPartitioner;
builder_monitor_.Start("ApplySplit");

const size_t n_nodes = nodes.size();
std::vector<int32_t> split_conditions;
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
std::vector<int32_t> split_conditions(n_nodes);
CommonRowPartitioner::FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);

partition_builder_.Init(&qu_, n_nodes, [&](size_t node_in_set) {
const int32_t nid = nodes[node_in_set].nid;
Expand Down
5 changes: 0 additions & 5 deletions plugin/sycl/tree/hist_updater.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,6 @@ class HistUpdater {

void AddSplitsToRowSet(const std::vector<ExpandEntry>& nodes, RegTree* p_tree);


void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
const common::GHistIndexMatrix& gmat,
std::vector<int32_t>* split_conditions);

void InitData(const common::GHistIndexMatrix& gmat,
const USMVector<GradientPair, MemoryType::on_device> &gpair,
const DMatrix& fmat,
Expand Down
7 changes: 4 additions & 3 deletions src/tree/common_row_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,10 @@ class CommonRowPartitioner {
}
}

template <typename ExpandEntry>
void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) {
/* Making GHistIndexMatrix_t a templete parameter allows reuse this function for sycl-plugin */
template <typename ExpandEntry, typename GHistIndexMatrix_t>
void static FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
const GHistIndexMatrix_t& gmat, std::vector<int32_t>* split_conditions) {
auto const& ptrs = gmat.cut.Ptrs();
auto const& vals = gmat.cut.Values();

Expand Down
8 changes: 8 additions & 0 deletions tests/cpp/plugin/test_sycl_hist_updater.cc
Original file line number Diff line number Diff line change
Expand Up @@ -432,4 +432,12 @@ TEST(SyclHistUpdater, EvaluateSplits) {
TestHistUpdaterEvaluateSplits<double>(param);
}

TEST(SyclHistUpdater, ApplySplit) {
xgboost::tree::TrainParam param;
param.UpdateAllowUnknown(Args{{"max_depth", "3"}});

// TestHistUpdaterApplySplit<float>(param);
// TestHistUpdaterApplySplit<double>(param);
}

} // namespace xgboost::sycl::tree

0 comments on commit d1d1f3c

Please sign in to comment.