Add tests for EvaluateSplits (dmlc#59)
* minor refactoring

* optimize host-device memory sync

* add test for EvaluateSplits

* linting

---------

Co-authored-by: Dmitry Razdoburdin <>
razdoburdin authored Jul 17, 2024
1 parent ba88551 commit 916d4a4
Showing 4 changed files with 124 additions and 39 deletions.
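The "optimize host-device memory sync" bullet shows up in EvaluateSplits below as a grow-only staging buffer: split_queries_host_ is overwritten in place up to a running pos, extended with push_back only past its previous high-water mark, and never shrunk. A minimal sketch of that pattern, with hypothetical names (Query, FillQueries) rather than the plugin's real types:

#include <cstddef>
#include <vector>

struct Query {
  int nid;  // tree-node id
  int fid;  // feature id
};

// Grow-only staging buffer: after a few calls the vector reaches its
// high-water mark and later calls stop allocating entirely.
std::size_t FillQueries(std::vector<Query>* staging,
                        const std::vector<int>& nodes, int n_features) {
  std::size_t pos = 0;
  for (int nid : nodes) {
    for (int fid = 0; fid < n_features; ++fid) {
      if (pos < staging->size()) {
        (*staging)[pos] = Query{nid, fid};    // reuse an existing slot
      } else {
        staging->push_back(Query{nid, fid});  // grow only when needed
      }
      ++pos;
    }
  }
  return pos;  // only the first `pos` entries are valid this round
}

The caller then copies just the first `pos` entries to the device, instead of resizing (and potentially reallocating) the host vector on every call.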
60 changes: 26 additions & 34 deletions plugin/sycl/tree/hist_updater.cc
@@ -183,7 +183,7 @@ void HistUpdater<GradientSumT>::EvaluateAndApplySplits(
     int *num_leaves,
     int depth,
     std::vector<ExpandEntry> *temp_qexpand_depth) {
-  EvaluateSplits(qexpand_depth_wise_, gmat, hist_, *p_tree);
+  EvaluateSplits(qexpand_depth_wise_, gmat, *p_tree);
 
   std::vector<ExpandEntry> nodes_for_apply_split;
   AddSplitsToTree(gmat, p_tree, num_leaves, depth,
@@ -280,7 +280,7 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(
 
   this->InitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, *p_tree);
 
-  this->EvaluateSplits({node}, gmat, hist_, *p_tree);
+  this->EvaluateSplits({node}, gmat, *p_tree);
   node.split.loss_chg = snode_host_[ExpandEntry::kRootNid].best.loss_chg;
 
   qexpand_loss_guided_->push(node);
@@ -325,7 +325,7 @@ void HistUpdater<GradientSumT>::ExpandWithLossGuide(
                       snode_host_[cleft].weight, snode_host_[cright].weight);
     interaction_constraints_.Split(nid, featureid, cleft, cright);
 
-    this->EvaluateSplits({left_node, right_node}, gmat, hist_, *p_tree);
+    this->EvaluateSplits({left_node, right_node}, gmat, *p_tree);
     left_node.split.loss_chg = snode_host_[cleft].best.loss_chg;
     right_node.split.loss_chg = snode_host_[cright].best.loss_chg;
 
@@ -472,7 +472,7 @@ void HistUpdater<GradientSumT>::InitSampling(
       });
     });
   } else {
-    // Use oneDPL uniform for better perf, as far as bernoulli_distribution uses fp64
+    // Use oneDPL uniform, as far as bernoulli_distribution uses fp64
     event = qu_.submit([&](::sycl::handler& cgh) {
       auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
       cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
@@ -649,45 +649,32 @@ template<typename GradientSumT>
 void HistUpdater<GradientSumT>::EvaluateSplits(
     const std::vector<ExpandEntry>& nodes_set,
     const common::GHistIndexMatrix& gmat,
-    const common::HistCollection<GradientSumT, MemoryType::on_device>& hist,
     const RegTree& tree) {
   builder_monitor_.Start("EvaluateSplits");
 
   const size_t n_nodes_in_set = nodes_set.size();
 
   using FeatureSetType = std::shared_ptr<HostDeviceVector<bst_feature_t>>;
-  std::vector<FeatureSetType> features_sets(n_nodes_in_set);
-
-  // Generate feature set for each tree node
-  size_t total_features = 0;
-  for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
-    const int32_t nid = nodes_set[nid_in_set].nid;
-    features_sets[nid_in_set] = column_sampler_->GetFeatureSet(tree.GetDepth(nid));
-    for (size_t idx = 0; idx < features_sets[nid_in_set]->Size(); idx++) {
-      const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx];
-      if (interaction_constraints_.Query(nid, fid)) {
-        total_features++;
-      }
-    }
-  }
-
-  split_queries_host_.resize(total_features);
-  size_t pos = 0;
 
+  size_t pos = 0;
   for (size_t nid_in_set = 0; nid_in_set < n_nodes_in_set; ++nid_in_set) {
-    const size_t nid = nodes_set[nid_in_set].nid;
-
-    for (size_t idx = 0; idx < features_sets[nid_in_set]->Size(); idx++) {
-      const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx];
+    const bst_node_t nid = nodes_set[nid_in_set].nid;
+    FeatureSetType features_set = column_sampler_->GetFeatureSet(tree.GetDepth(nid));
+    for (size_t idx = 0; idx < features_set->Size(); idx++) {
+      const size_t fid = features_set->ConstHostVector()[idx];
       if (interaction_constraints_.Query(nid, fid)) {
-        split_queries_host_[pos].nid = nid;
-        split_queries_host_[pos].fid = fid;
-        split_queries_host_[pos].hist = hist[nid].DataConst();
-        split_queries_host_[pos].best = snode_host_[nid].best;
-        pos++;
+        auto this_hist = hist_[nid].DataConst();
+        if (pos < split_queries_host_.size()) {
+          split_queries_host_[pos] = SplitQuery{nid, fid, this_hist};
+        } else {
+          split_queries_host_.push_back({nid, fid, this_hist});
+        }
+        ++pos;
       }
     }
   }
+  const size_t total_features = pos;
 
   split_queries_device_.Resize(&qu_, total_features);
   auto event = qu_.memcpy(split_queries_device_.Data(), split_queries_host_.data(),
@@ -702,10 +689,14 @@ void HistUpdater<GradientSumT>::EvaluateSplits(
   snode_device_.ResizeNoCopy(&qu_, snode_host_.size());
   event = qu_.memcpy(snode_device_.Data(), snode_host_.data(),
                      snode_host_.size() * sizeof(NodeEntry<GradientSumT>), event);
-  const NodeEntry<GradientSumT>* snode = snode_device_.DataConst();
+  const NodeEntry<GradientSumT>* snode = snode_device_.Data();
 
   const float min_child_weight = param_.min_child_weight;
 
+  best_splits_device_.ResizeNoCopy(&qu_, total_features);
+  if (best_splits_host_.size() < total_features) best_splits_host_.resize(total_features);
+  SplitEntry<GradientSumT>* best_splits = best_splits_device_.Data();
+
   event = qu_.submit([&](::sycl::handler& cgh) {
     cgh.depends_on(event);
     cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(total_features, sub_group_size_),
@@ -717,17 +708,18 @@ void HistUpdater<GradientSumT>::EvaluateSplits(
         int fid = split_queries_device[i].fid;
         const GradientPairT* hist_data = split_queries_device[i].hist;
 
+        best_splits[i] = snode[nid].best;
         EnumerateSplit(sg, cut_ptr, cut_val, hist_data, snode[nid],
-                       &(split_queries_device[i].best), fid, nid, evaluator, min_child_weight);
+                       &(best_splits[i]), fid, nid, evaluator, min_child_weight);
       });
   });
-  event = qu_.memcpy(split_queries_host_.data(), split_queries_device_.Data(),
-                     total_features * sizeof(SplitQuery), event);
+  event = qu_.memcpy(best_splits_host_.data(), best_splits,
+                     total_features * sizeof(SplitEntry<GradientSumT>), event);
 
   qu_.wait();
   for (size_t i = 0; i < total_features; i++) {
     int nid = split_queries_host_[i].nid;
-    snode_host_[nid].best.Update(split_queries_host_[i].best);
+    snode_host_[nid].best.Update(best_splits_host_[i]);
   }
 
   builder_monitor_.Stop("EvaluateSplits");
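Two effects of this rewrite are worth noting. Moving SplitEntry best out of SplitQuery (see the header change below) means the host-to-device transfer now carries only read-only query inputs, while results come back through the dedicated best_splits_host_/best_splits_device_ pair; and both transfers appear to move only the total_features-element prefix of buffers that may be larger. A standalone sketch of that prefix-copy idiom, assuming a SYCL 2020 compiler (names and sizes here are illustrative, not the plugin's):

#include <cstddef>
#include <sycl/sycl.hpp>
#include <vector>

int main() {
  sycl::queue q;

  std::vector<int> staging(1024, 7);  // capacity kept from earlier calls
  std::size_t used = 100;             // entries actually filled this call

  // Device buffer sized to the used prefix; only `used` elements move.
  int* dev = sycl::malloc_device<int>(used, q);
  q.memcpy(dev, staging.data(), used * sizeof(int)).wait();

  sycl::free(dev, q);
  return 0;
}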
9 changes: 5 additions & 4 deletions plugin/sycl/tree/hist_updater.h
@@ -95,9 +95,8 @@ class HistUpdater {
   friend class DistributedHistRowsAdder<GradientSumT>;
 
   struct SplitQuery {
-    int nid;
-    int fid;
-    SplitEntry<GradientSumT> best;
+    bst_node_t nid;
+    size_t fid;
     const GradientPairT* hist;
   };
 
@@ -106,7 +105,6 @@
 
   void EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
                       const common::GHistIndexMatrix& gmat,
-                      const common::HistCollection<GradientSumT, MemoryType::on_device>& hist,
                       const RegTree& tree);
 
   // Enumerate the split values of specific feature
@@ -222,6 +220,9 @@ class HistUpdater {
   std::vector<SplitQuery> split_queries_host_;
   USMVector<SplitQuery, MemoryType::on_device> split_queries_device_;
 
+  USMVector<SplitEntry<GradientSumT>, MemoryType::on_device> best_splits_device_;
+  std::vector<SplitEntry<GradientSumT>> best_splits_host_;
+
   TreeEvaluator<GradientSumT> tree_evaluator_;
   FeatureInteractionConstraintHost interaction_constraints_;
 
2 changes: 1 addition & 1 deletion tests/ci_build/conda_env/linux_sycl_test.yml
@@ -1,7 +1,7 @@
 name: linux_sycl_test
 channels:
 - conda-forge
-- intel
+- https://software.repos.intel.com/python/conda/
 dependencies:
 - python=3.8
 - cmake
92 changes: 92 additions & 0 deletions tests/cpp/plugin/test_sycl_hist_updater.cc
@@ -52,6 +52,13 @@ class TestHistUpdater : public HistUpdater<GradientSumT> {
     HistUpdater<GradientSumT>::InitNewNode(nid, gmat, gpair, fmat, tree);
     return HistUpdater<GradientSumT>::snode_host_[nid];
   }
+
+  auto TestEvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
+                          const common::GHistIndexMatrix& gmat,
+                          const RegTree& tree) {
+    HistUpdater<GradientSumT>::EvaluateSplits(nodes_set, gmat, tree);
+    return HistUpdater<GradientSumT>::snode_host_;
+  }
 };
 
 void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) {
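The TestEvaluateSplits wrapper added above is the usual way to unit-test a protected member: derive from the class under test and forward through a public method. A minimal sketch of the pattern, with hypothetical names (Widget, TestWidget):

class Widget {
 protected:
  int Compute(int x) { return 2 * x; }  // protected: invisible to tests
};

class TestWidget : public Widget {
 public:
  int TestCompute(int x) { return Compute(x); }  // public forwarding wrapper
};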
@@ -301,6 +308,83 @@ void TestHistUpdaterInitNewNode(const xgboost::tree::TrainParam& param, float sp
   EXPECT_NEAR(snode.stats.GetHess(), grad_stat.GetHess(), 1e-6 * grad_stat.GetHess());
 }
 
+template <typename GradientSumT>
+void TestHistUpdaterEvaluateSplits(const xgboost::tree::TrainParam& param) {
+  const size_t num_rows = 1u << 8;
+  const size_t num_columns = 2;
+  const size_t n_bins = 32;
+
+  Context ctx;
+  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
+
+  DeviceManager device_manager;
+  auto qu = device_manager.GetQueue(ctx.Device());
+  ObjInfo task{ObjInfo::kRegression};
+
+  auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0f}.GenerateDMatrix();
+
+  FeatureInteractionConstraintHost int_constraints;
+
+  TestHistUpdater<GradientSumT> updater(&ctx, qu, param, int_constraints, p_fmat.get());
+  updater.SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
+  updater.SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
+
+  USMVector<GradientPair, MemoryType::on_device> gpair(&qu, num_rows);
+  auto* gpair_ptr = gpair.Data();
+  GenerateRandomGPairs(&qu, gpair_ptr, num_rows, false);
+
+  DeviceMatrix dmat;
+  dmat.Init(qu, p_fmat.get());
+  common::GHistIndexMatrix gmat;
+  gmat.Init(qu, &ctx, dmat, n_bins);
+
+  RegTree tree;
+  tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
+  ExpandEntry node(ExpandEntry::kRootNid, tree.GetDepth(ExpandEntry::kRootNid));
+
+  auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree);
+  auto& row_idxs = row_set_collection->Data();
+  const size_t* row_idxs_ptr = row_idxs.DataConst();
+  const auto* hist = updater.TestBuildHistogramsLossGuide(node, gmat, &tree, gpair);
+  const auto snode_init = updater.TestInitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, tree);
+
+  const auto snode_updated = updater.TestEvaluateSplits({node}, gmat, tree);
+  auto best_loss_chg = snode_updated[0].best.loss_chg;
+  auto stats = snode_init.stats;
+  auto root_gain = snode_init.root_gain;
+
+  // Check all splits manually. Save the best one and compare with the answer.
+  TreeEvaluator<GradientSumT> tree_evaluator(qu, param, num_columns);
+  auto evaluator = tree_evaluator.GetEvaluator();
+  const uint32_t* cut_ptr = gmat.cut_device.Ptrs().DataConst();
+  const size_t size = gmat.cut_device.Ptrs().Size();
+  int n_better_splits = 0;
+  const auto* hist_ptr = (*hist)[0].DataConst();
+  std::vector<bst_float> best_loss_chg_des(1, -1);
+  {
+    ::sycl::buffer<bst_float> best_loss_chg_buff(best_loss_chg_des.data(), 1);
+    qu.submit([&](::sycl::handler& cgh) {
+      auto best_loss_chg_acc = best_loss_chg_buff.template get_access<::sycl::access::mode::read_write>(cgh);
+      cgh.single_task<>([=]() {
+        for (size_t i = 1; i < size; ++i) {
+          GradStats<GradientSumT> left(0, 0);
+          GradStats<GradientSumT> right = stats - left;
+          for (size_t j = cut_ptr[i-1]; j < cut_ptr[i]; ++j) {
+            auto loss_change = evaluator.CalcSplitGain(0, i - 1, left, right) - root_gain;
+            if (loss_change > best_loss_chg_acc[0]) {
+              best_loss_chg_acc[0] = loss_change;
+            }
+            left.Add(hist_ptr[j].GetGrad(), hist_ptr[j].GetHess());
+            right = stats - left;
+          }
+        }
+      });
+    }).wait();
+  }
+
+  ASSERT_NEAR(best_loss_chg_des[0], best_loss_chg, 1e-6);
+}
+
 TEST(SyclHistUpdater, Sampling) {
   xgboost::tree::TrainParam param;
   param.UpdateAllowUnknown(Args{{"subsample", "0.7"}});
@@ -340,4 +424,12 @@ TEST(SyclHistUpdater, InitNewNode) {
   TestHistUpdaterInitNewNode<double>(param, 0.5);
 }
 
+TEST(SyclHistUpdater, EvaluateSplits) {
+  xgboost::tree::TrainParam param;
+  param.UpdateAllowUnknown(Args{{"max_depth", "3"}});
+
+  TestHistUpdaterEvaluateSplits<float>(param);
+  TestHistUpdaterEvaluateSplits<double>(param);
+}
+
 }  // namespace xgboost::sycl::tree

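For reference, the brute-force check in TestHistUpdaterEvaluateSplits scans every bin boundary of the root node, accumulating the histogram into left/right GradStats and keeping the best CalcSplitGain(...) - root_gain. Assuming CalcSplitGain implements the standard second-order XGBoost split gain with L2 regularization (constraints and the minimum-split-loss penalty are handled elsewhere), the quantity being maximized is

\text{loss\_chg} \;=\; \underbrace{\frac{G_L^2}{H_L+\lambda} + \frac{G_R^2}{H_R+\lambda}}_{\texttt{CalcSplitGain}} \;-\; \underbrace{\frac{(G_L+G_R)^2}{H_L+H_R+\lambda}}_{\texttt{root\_gain}}

where G and H are the gradient and hessian sums on each side of the candidate split.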