From c25ed199f44b8b7eac9a857019faf1e1d3fccd89 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Tue, 13 Aug 2024 12:25:24 +0200 Subject: [PATCH] Add tests for depth-wise policy (#63) * add tests for depth-wise policy * fix compilation error --------- Co-authored-by: Dmitry Razdoburdin <> --- tests/cpp/plugin/test_sycl_hist_updater.cc | 62 ++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/cpp/plugin/test_sycl_hist_updater.cc b/tests/cpp/plugin/test_sycl_hist_updater.cc index e9f3ecfa2132..31cd2757166d 100644 --- a/tests/cpp/plugin/test_sycl_hist_updater.cc +++ b/tests/cpp/plugin/test_sycl_hist_updater.cc @@ -73,6 +73,13 @@ class TestHistUpdater : public HistUpdater { const USMVector &gpair) { HistUpdater::ExpandWithLossGuide(gmat, p_tree, gpair); } + + auto TestExpandWithDepthWise(const common::GHistIndexMatrix& gmat, + DMatrix *p_fmat, + RegTree* p_tree, + const USMVector &gpair) { + HistUpdater::ExpandWithDepthWise(gmat, p_tree, gpair); + } }; void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) { @@ -532,6 +539,53 @@ void TestHistUpdaterExpandWithLossGuide(const xgboost::tree::TrainParam& param) ASSERT_NEAR(ans[2], -0.15, 1e-6); } +template +void TestHistUpdaterExpandWithDepthWise(const xgboost::tree::TrainParam& param) { + const size_t num_rows = 3; + const size_t num_columns = 1; + const size_t n_bins = 16; + + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + + std::vector data = {7, 3, 15}; + auto p_fmat = GetDMatrixFromData(data, num_rows, num_columns); + + DeviceMatrix dmat; + dmat.Init(qu, p_fmat.get()); + common::GHistIndexMatrix gmat; + gmat.Init(qu, &ctx, dmat, n_bins); + + std::vector gpair_host = {{1, 2}, {3, 1}, {1, 1}}; + USMVector gpair(&qu, gpair_host); + + RegTree tree; + FeatureInteractionConstraintHost int_constraints; + TestHistUpdater updater(&ctx, qu, param, int_constraints, p_fmat.get()); + updater.SetHistSynchronizer(new BatchHistSynchronizer()); + updater.SetHistRowsAdder(new BatchHistRowsAdder()); + auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree); + + updater.TestExpandWithDepthWise(gmat, p_fmat.get(), &tree, gpair); + + const auto& nodes = tree.GetNodes(); + std::vector ans(data.size()); + for (size_t data_idx = 0; data_idx < data.size(); ++data_idx) { + size_t node_idx = 0; + while (!nodes[node_idx].IsLeaf()) { + node_idx = data[data_idx] < nodes[node_idx].SplitCond() ? nodes[node_idx].LeftChild() : nodes[node_idx].RightChild(); + } + ans[data_idx] = nodes[node_idx].LeafValue(); + } + + ASSERT_NEAR(ans[0], -0.15, 1e-6); + ASSERT_NEAR(ans[1], -0.45, 1e-6); + ASSERT_NEAR(ans[2], -0.15, 1e-6); +} + TEST(SyclHistUpdater, Sampling) { xgboost::tree::TrainParam param; param.UpdateAllowUnknown(Args{{"subsample", "0.7"}}); @@ -608,4 +662,12 @@ TEST(SyclHistUpdater, ExpandWithLossGuide) { TestHistUpdaterExpandWithLossGuide(param); } +TEST(SyclHistUpdater, ExpandWithDepthWise) { + xgboost::tree::TrainParam param; + param.UpdateAllowUnknown(Args{{"max_depth", "2"}}); + + TestHistUpdaterExpandWithDepthWise(param); + TestHistUpdaterExpandWithDepthWise(param); +} + } // namespace xgboost::sycl::tree