Skip to content

Commit

Permalink
Merge approx tests. (dmlc#10583)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jul 16, 2024
1 parent 5a92ffe commit a6a8a55
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 61 deletions.
49 changes: 49 additions & 0 deletions tests/cpp/tree/test_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
#include <gtest/gtest.h>

#include "../../../src/tree/common_row_partitioner.h"
#include "../../../src/tree/param.h" // for TrainParam
#include "../collective/test_worker.h" // for TestDistributedGlobal
#include "../helpers.h"
#include "test_column_split.h" // for TestColumnSplit
#include "test_partitioner.h"
#include "xgboost/tree_model.h" // for RegTree

namespace xgboost::tree {
namespace {
Expand Down Expand Up @@ -76,6 +78,53 @@ TEST(Approx, Partitioner) {
}
}

TEST(Approx, InteractionConstraint) {
auto constexpr kRows = 32;
auto constexpr kCols = 16;
auto p_dmat = GenerateCatDMatrix(kRows, kCols, 0.6f, false);
Context ctx;

linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
gpair.Data()->Copy(GenerateRandomGradients(kRows));

ObjInfo task{ObjInfo::kRegression};
{
// With constraints
RegTree tree{1, kCols};

std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
TrainParam param;
param.UpdateAllowUnknown(
Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}});
std::vector<HostDeviceVector<bst_node_t>> position(1);
updater->Configure(Args{});
updater->Update(&param, &gpair, p_dmat.get(), position, {&tree});

ASSERT_EQ(tree.NumExtraNodes(), 4);
ASSERT_EQ(tree[0].SplitIndex(), 1);

ASSERT_EQ(tree[tree[0].LeftChild()].SplitIndex(), 0);
ASSERT_EQ(tree[tree[0].RightChild()].SplitIndex(), 0);
}
{
// Without constraints
RegTree tree{1u, kCols};

std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)};
std::vector<HostDeviceVector<bst_node_t>> position(1);
TrainParam param;
param.Init(Args{});
updater->Configure(Args{});
updater->Update(&param, &gpair, p_dmat.get(), position, {&tree});

ASSERT_EQ(tree.NumExtraNodes(), 10);
ASSERT_EQ(tree[0].SplitIndex(), 1);

ASSERT_NE(tree[tree[0].LeftChild()].SplitIndex(), 0);
ASSERT_NE(tree[tree[0].RightChild()].SplitIndex(), 0);
}
}

namespace {
void TestColumnSplitPartitioner(size_t n_samples, size_t base_rowid, std::shared_ptr<DMatrix> Xy,
std::vector<float>* hess, float min_value, float mid_value,
Expand Down
8 changes: 6 additions & 2 deletions tests/cpp/tree/test_column_split.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ inline std::shared_ptr<DMatrix> GenerateCatDMatrix(std::size_t rows, std::size_t
for (size_t i = 0; i < ft.size(); ++i) {
ft[i] = (i % 3 == 0) ? FeatureType::kNumerical : FeatureType::kCategorical;
}
return RandomDataGenerator(rows, cols, 0.6f).Seed(3).Type(ft).MaxCategory(17).GenerateDMatrix();
return RandomDataGenerator(rows, cols, sparsity)
.Seed(3)
.Type(ft)
.MaxCategory(17)
.GenerateDMatrix();
} else {
return RandomDataGenerator{rows, cols, 0.6f}.Seed(3).GenerateDMatrix();
return RandomDataGenerator{rows, cols, sparsity}.Seed(3).GenerateDMatrix();
}
}

Expand Down
59 changes: 0 additions & 59 deletions tests/cpp/tree/test_histmaker.cc

This file was deleted.

0 comments on commit a6a8a55

Please sign in to comment.