Skip to content

Commit

Permalink
refactor pimpl
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Apr 19, 2024
1 parent a54282f commit d0ec3d7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 21 deletions.
47 changes: 28 additions & 19 deletions plugin/sycl/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ void QuantileHistMaker::Configure(const Args& args) {
pruner_->Configure(args);
param_.UpdateAllowUnknown(args);
hist_maker_param_.UpdateAllowUnknown(args);

bool has_fp64_support = qu_.get_device().has(::sycl::aspect::fp64);
if (hist_maker_param_.single_precision_histogram || !has_fp64_support) {
if (hist_maker_param_.single_precision_histogram) {
LOG(WARNING) << "Target device doesn't support fp64, using single_precision_histogram=True";
}
hist_precision_ = HistPrecision::fp32;
} else {
hist_precision_ = HistPrecision::fp64;
}
}

template<typename GradientSumT>
Expand Down Expand Up @@ -90,20 +100,16 @@ void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,
param_.learning_rate = lr / trees.size();
int_constraint_.Configure(param_, dmat->Info().num_col_);
// build tree
bool has_double_support = qu_.get_device().has(::sycl::aspect::fp64);
if (hist_maker_param_.single_precision_histogram || !has_double_support) {
if (!hist_maker_param_.single_precision_histogram) {
LOG(WARNING) << "Target device doesn't support fp64, using single_precision_histogram=True";
}
if (!pimpl_single) {
SetPimpl(&pimpl_single, dmat);
if (hist_precision_ == HistPricision::fp32) {
if (!pimpl_fp32) {
SetPimpl(&pimpl_fp32, dmat);
}
CallUpdate(pimpl_single, param, gpair, dmat, out_position, trees);
CallUpdate(pimpl_fp32, param, gpair, dmat, out_position, trees);
} else {
if (!pimpl_double) {
SetPimpl(&pimpl_double, dmat);
if (!pimpl_fp64) {
SetPimpl(&pimpl_fp64, dmat);
}
CallUpdate(pimpl_double, param, gpair, dmat, out_position, trees);
CallUpdate(pimpl_fp64, param, gpair, dmat, out_position, trees);
}

param_.learning_rate = lr;
Expand All @@ -113,16 +119,19 @@ void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param,

bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data,
linalg::MatrixView<float> out_preds) {
if (param_.subsample < 1.0f) {
return false;
if (param_.subsample < 1.0f) return false;

if (hist_precision_ == HistPricision::fp32) {
if (pimpl_fp32) {
return pimpl_fp32->UpdatePredictionCache(data, out_preds);
} else {
return false;
}
} else {
bool has_double_support = qu_.get_device().has(::sycl::aspect::fp64);
if ((hist_maker_param_.single_precision_histogram || !has_double_support) && pimpl_single) {
return pimpl_single->UpdatePredictionCache(data, out_preds);
} else if (pimpl_double) {
return pimpl_double->UpdatePredictionCache(data, out_preds);
if (pimpl_fp64) {
return pimpl_fp64->UpdatePredictionCache(data, out_preds);
} else {
return false;
return false;
}
}
}
Expand Down
7 changes: 5 additions & 2 deletions plugin/sycl/tree/updater_quantile_hist.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,11 @@ class QuantileHistMaker: public TreeUpdater {
const std::vector<RegTree *> &trees);

protected:
std::unique_ptr<HistUpdater<float>> pimpl_single;
std::unique_ptr<HistUpdater<double>> pimpl_double;
enum class HistPrecision {fp32, fp64};
HistPrecision hist_precision_;

std::unique_ptr<HistUpdater<float>> pimpl_fp32;
std::unique_ptr<HistUpdater<double>> pimpl_fp64;

std::unique_ptr<TreeUpdater> pruner_;
FeatureInteractionConstraintHost int_constraint_;
Expand Down

0 comments on commit d0ec3d7

Please sign in to comment.