Skip to content

Commit

Permalink
[bp] Fix rng for the column sampler. (#10998) (#11004)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Nov 19, 2024
1 parent 5973d60 commit 7b675da
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/common/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class ColumnSampler {
};

inline auto MakeColumnSampler(Context const* ctx) {
std::uint32_t seed = common::GlobalRandomEngine()();
std::uint32_t seed = common::GlobalRandom()();
auto rc = collective::Broadcast(ctx, linalg::MakeVec(&seed, 1), 0);
collective::SafeColl(rc);
auto cs = std::make_shared<common::ColumnSampler>(seed);
Expand Down
10 changes: 2 additions & 8 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -867,12 +867,7 @@ class GPUHistMaker : public TreeUpdater {
CHECK_GE(ctx_->Ordinal(), 0) << "Must have at least one device";
info_ = &dmat->Info();

// Synchronise the column sampling seed
uint32_t column_sampling_seed = common::GlobalRandom()();
auto rc = collective::Broadcast(
ctx_, linalg::MakeVec(&column_sampling_seed, sizeof(column_sampling_seed)), 0);
SafeColl(rc);
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
this->column_sampler_ = common::MakeColumnSampler(ctx_);

auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()};
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
Expand Down Expand Up @@ -1012,8 +1007,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {

monitor_.Start(__func__);
CHECK(ctx_->IsCUDA()) << error::InvalidCUDAOrdinal();
uint32_t column_sampling_seed = common::GlobalRandom()();
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
this->column_sampler_ = common::MakeColumnSampler(ctx_);

p_last_fmat_ = p_fmat;
initialised_ = true;
Expand Down
22 changes: 22 additions & 0 deletions tests/python/test_updaters.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,28 @@ def test_exact_sample_by_node_error(self) -> None:
num_boost_round=2,
)

@pytest.mark.parametrize("tree_method", ["approx", "hist"])
def test_colsample_rng(self, tree_method: str) -> None:
"""Test rng has an effect on column sampling."""
X, y, _ = tm.make_regression(128, 16, use_cupy=False)
reg0 = xgb.XGBRegressor(
n_estimators=2,
colsample_bynode=0.5,
random_state=42,
tree_method=tree_method,
)
reg0.fit(X, y)

reg1 = xgb.XGBRegressor(
n_estimators=2,
colsample_bynode=0.5,
random_state=43,
tree_method=tree_method,
)
reg1.fit(X, y)

assert list(reg0.feature_importances_) != list(reg1.feature_importances_)

@given(
exact_parameter_strategy,
hist_parameter_strategy,
Expand Down

0 comments on commit 7b675da

Please sign in to comment.