diff --git a/plugin/sycl/tree/hist_updater.cc b/plugin/sycl/tree/hist_updater.cc index 7f91772845a4..18c1f02a3023 100644 --- a/plugin/sycl/tree/hist_updater.cc +++ b/plugin/sycl/tree/hist_updater.cc @@ -79,6 +79,162 @@ void HistUpdater::BuildLocalHistograms( builder_monitor_.Stop("BuildLocalHistograms"); } +template +void HistUpdater::BuildNodeStats( + const common::GHistIndexMatrix &gmat, + RegTree *p_tree, + const USMVector &gpair) { + builder_monitor_.Start("BuildNodeStats"); + for (auto const& entry : qexpand_depth_wise_) { + int nid = entry.nid; + this->InitNewNode(nid, gmat, gpair, *p_tree); + // add constraints (once per split: only when the visited node is a right child) + if (!(*p_tree)[nid].IsLeftChild() && !(*p_tree)[nid].IsRoot()) { + // it's a right child + auto parent_id = (*p_tree)[nid].Parent(); + auto left_sibling_id = (*p_tree)[parent_id].LeftChild(); + auto parent_split_feature_id = snode_host_[parent_id].best.SplitIndex(); + tree_evaluator_.AddSplit( + parent_id, left_sibling_id, nid, parent_split_feature_id, + snode_host_[left_sibling_id].weight, snode_host_[nid].weight); + interaction_constraints_.Split(parent_id, parent_split_feature_id, + left_sibling_id, nid); + } + } + builder_monitor_.Stop("BuildNodeStats"); +} + +template +void HistUpdater::AddSplitsToTree( + const common::GHistIndexMatrix &gmat, + RegTree *p_tree, + int *num_leaves, + int depth, + std::vector* nodes_for_apply_split, + std::vector* temp_qexpand_depth) { + builder_monitor_.Start("AddSplitsToTree"); + auto evaluator = tree_evaluator_.GetEvaluator(); + for (auto const& entry : qexpand_depth_wise_) { + const auto lr = param_.learning_rate; + int nid = entry.nid; + + if (snode_host_[nid].best.loss_chg < kRtEps || + (param_.max_depth > 0 && depth == param_.max_depth) || + (param_.max_leaves > 0 && (*num_leaves) == param_.max_leaves)) { + (*p_tree)[nid].SetLeaf(snode_host_[nid].weight * lr); + } else { + nodes_for_apply_split->push_back(entry); + + NodeEntry& e = snode_host_[nid]; + bst_float left_leaf_weight = + evaluator.CalcWeight(nid, 
GradStats{e.best.left_sum}) * lr; + bst_float right_leaf_weight = + evaluator.CalcWeight(nid, GradStats{e.best.right_sum}) * lr; + p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, + e.best.DefaultLeft(), e.weight, left_leaf_weight, + right_leaf_weight, e.best.loss_chg, e.stats.GetHess(), + e.best.left_sum.GetHess(), e.best.right_sum.GetHess()); + + int left_id = (*p_tree)[nid].LeftChild(); + int right_id = (*p_tree)[nid].RightChild(); + temp_qexpand_depth->push_back(ExpandEntry(left_id, p_tree->GetDepth(left_id))); + temp_qexpand_depth->push_back(ExpandEntry(right_id, p_tree->GetDepth(right_id))); + // net change in leaf count: one parent removed, two children added + (*num_leaves)++; + } + } + builder_monitor_.Stop("AddSplitsToTree"); +} + + +template +void HistUpdater::EvaluateAndApplySplits( + const common::GHistIndexMatrix &gmat, + RegTree *p_tree, + int *num_leaves, + int depth, + std::vector *temp_qexpand_depth) { + EvaluateSplits(qexpand_depth_wise_, gmat, *p_tree); + + std::vector nodes_for_apply_split; + AddSplitsToTree(gmat, p_tree, num_leaves, depth, + &nodes_for_apply_split, temp_qexpand_depth); + ApplySplit(nodes_for_apply_split, gmat, p_tree); +} + +// Split nodes to 2 sets depending on amount of rows in each node +// Histograms for small nodes will be built explicitly +// Histograms for big nodes will be built by 'Subtraction Trick' +// Exception: in distributed setting, we always build the histogram for the left child node +// and use 'Subtraction Trick' to build the histogram for the right child node, so that the +// workers operate on the same set of tree nodes. NOTE(review): not implemented below; confirm. 
+template +void HistUpdater::SplitSiblings( + const std::vector &nodes, + std::vector *small_siblings, + std::vector *big_siblings, + RegTree *p_tree) { + builder_monitor_.Start("SplitSiblings"); + for (auto const& entry : nodes) { + int nid = entry.nid; + RegTree::Node &node = (*p_tree)[nid]; + if (node.IsRoot()) { + small_siblings->push_back(entry); + } else { + const int32_t left_id = (*p_tree)[node.Parent()].LeftChild(); + const int32_t right_id = (*p_tree)[node.Parent()].RightChild(); + + if (nid == left_id && row_set_collection_[left_id ].Size() < + row_set_collection_[right_id].Size()) { + small_siblings->push_back(entry); + } else if (nid == right_id && row_set_collection_[right_id].Size() <= + row_set_collection_[left_id ].Size()) { + small_siblings->push_back(entry); + } else { + big_siblings->push_back(entry); + } + } + } + builder_monitor_.Stop("SplitSiblings"); +} + +template +void HistUpdater::ExpandWithDepthWise( + const common::GHistIndexMatrix &gmat, + RegTree *p_tree, + const USMVector &gpair) { + int num_leaves = 0; + + // seed the queue with the root node and its depth; no loss_chg is passed here + qexpand_depth_wise_.emplace_back(ExpandEntry::kRootNid, + p_tree->GetDepth(ExpandEntry::kRootNid)); + ++num_leaves; + for (int depth = 0; depth < param_.max_depth + 1; depth++) { + std::vector sync_ids; + std::vector temp_qexpand_depth; + SplitSiblings(qexpand_depth_wise_, &nodes_for_explicit_hist_build_, + &nodes_for_subtraction_trick_, p_tree); + hist_rows_adder_->AddHistRows(this, &sync_ids, p_tree); + BuildLocalHistograms(gmat, p_tree, gpair); + hist_synchronizer_->SyncHistograms(this, sync_ids, p_tree); + BuildNodeStats(gmat, p_tree, gpair); + + EvaluateAndApplySplits(gmat, p_tree, &num_leaves, depth, + &temp_qexpand_depth); + + // clean up per-level work lists before the next depth + qexpand_depth_wise_.clear(); + nodes_for_subtraction_trick_.clear(); + nodes_for_explicit_hist_build_.clear(); + if (temp_qexpand_depth.empty()) { + break; + } else { + qexpand_depth_wise_ = 
temp_qexpand_depth; + temp_qexpand_depth.clear(); + } + } +} + template void HistUpdater::ExpandWithLossGuide( const common::GHistIndexMatrix& gmat, @@ -326,7 +482,7 @@ void HistUpdater::InitData( if (param_.grow_policy == xgboost::tree::TrainParam::kLossGuide) { qexpand_loss_guided_.reset(new ExpandQueue(LossGuide)); } else { - LOG(WARNING) << "Depth-wise building is not yet implemented"; + qexpand_depth_wise_.clear(); } } builder_monitor_.Stop("InitData"); diff --git a/plugin/sycl/tree/hist_updater.h b/plugin/sycl/tree/hist_updater.h index 75b318f1e713..6d7d8407059c 100644 --- a/plugin/sycl/tree/hist_updater.h +++ b/plugin/sycl/tree/hist_updater.h @@ -129,6 +129,36 @@ class HistUpdater { const USMVector &gpair, const RegTree& tree); + // Split nodes to 2 sets depending on amount of rows in each node + // Histograms for small nodes will be built explicitly + // Histograms for big nodes will be built by 'Subtraction Trick' + void SplitSiblings(const std::vector& nodes, + std::vector* small_siblings, + std::vector* big_siblings, + RegTree *p_tree); + + void BuildNodeStats(const common::GHistIndexMatrix &gmat, + RegTree *p_tree, + const USMVector &gpair); + + void EvaluateAndApplySplits(const common::GHistIndexMatrix &gmat, + RegTree *p_tree, + int *num_leaves, + int depth, + std::vector *temp_qexpand_depth); + + void AddSplitsToTree( + const common::GHistIndexMatrix &gmat, + RegTree *p_tree, + int *num_leaves, + int depth, + std::vector* nodes_for_apply_split, + std::vector* temp_qexpand_depth); + + void ExpandWithDepthWise(const common::GHistIndexMatrix &gmat, + RegTree *p_tree, + const USMVector &gpair); + void BuildLocalHistograms(const common::GHistIndexMatrix &gmat, RegTree *p_tree, const USMVector &gpair); @@ -180,6 +210,7 @@ class HistUpdater { std::function>; std::unique_ptr qexpand_loss_guided_; + std::vector qexpand_depth_wise_; enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData }; DataLayout data_layout_; diff --git 
a/tests/cpp/plugin/test_sycl_hist_updater.cc b/tests/cpp/plugin/test_sycl_hist_updater.cc index 9b5526c14bcf..7789b44381dd 100644 --- a/tests/cpp/plugin/test_sycl_hist_updater.cc +++ b/tests/cpp/plugin/test_sycl_hist_updater.cc @@ -75,6 +75,13 @@ class TestHistUpdater : public HistUpdater { const USMVector &gpair) { HistUpdater::ExpandWithLossGuide(gmat, p_tree, gpair); } + + auto TestExpandWithDepthWise(const common::GHistIndexMatrix& gmat, + DMatrix *p_fmat, + RegTree* p_tree, + const USMVector &gpair) { + HistUpdater::ExpandWithDepthWise(gmat, p_tree, gpair); + } }; void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) { @@ -544,6 +551,55 @@ void TestHistUpdaterExpandWithLossGuide(const xgboost::tree::TrainParam& param) } +template +void TestHistUpdaterExpandWithDepthWise(const xgboost::tree::TrainParam& param) { + const size_t num_rows = 3; + const size_t num_columns = 1; + const size_t n_bins = 16; + + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + + std::vector data = {7, 3, 15}; + auto p_fmat = GetDMatrixFromData(data, num_rows, num_columns); + + DeviceMatrix dmat; + dmat.Init(qu, p_fmat.get()); + common::GHistIndexMatrix gmat; + gmat.Init(qu, &ctx, dmat, n_bins); + + std::vector gpair_host = {{1, 2}, {3, 1}, {1, 1}}; + USMVector gpair(&qu, gpair_host); + + RegTree tree; + FeatureInteractionConstraintHost int_constraints; + ObjInfo task{ObjInfo::kRegression}; + std::unique_ptr pruner{TreeUpdater::Create("prune", &ctx, &task)}; + TestHistUpdater updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get()); + updater.SetHistSynchronizer(new BatchHistSynchronizer()); + updater.SetHistRowsAdder(new BatchHistRowsAdder()); + auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree); + + updater.TestExpandWithDepthWise(gmat, p_fmat.get(), &tree, gpair); + + const auto& nodes = 
tree.GetNodes(); + std::vector ans(data.size()); + for (size_t data_idx = 0; data_idx < data.size(); ++data_idx) { + size_t node_idx = 0; + while (!nodes[node_idx].IsLeaf()) { + node_idx = data[data_idx] < nodes[node_idx].SplitCond() ? nodes[node_idx].LeftChild() : nodes[node_idx].RightChild(); + } + ans[data_idx] = nodes[node_idx].LeafValue(); + } + + ASSERT_NEAR(ans[0], -0.15, 1e-6); + ASSERT_NEAR(ans[1], -0.45, 1e-6); + ASSERT_NEAR(ans[2], -0.15, 1e-6); +} + TEST(SyclHistUpdater, Sampling) { xgboost::tree::TrainParam param; param.UpdateAllowUnknown(Args{{"subsample", "0.7"}}); @@ -620,4 +676,12 @@ TEST(SyclHistUpdater, ExpandWithLossGuide) { TestHistUpdaterExpandWithLossGuide(param); } +TEST(SyclHistUpdater, ExpandWithDepthWise) { + xgboost::tree::TrainParam param; + param.UpdateAllowUnknown(Args{{"max_depth", "2"}}); + + TestHistUpdaterExpandWithDepthWise(param); + TestHistUpdaterExpandWithDepthWise(param); +} + } // namespace xgboost::sycl::tree