diff --git a/plugin/sycl/tree/updater_quantile_hist.cc b/plugin/sycl/tree/updater_quantile_hist.cc index e49d5a49f1d1..36709ef111c0 100644 --- a/plugin/sycl/tree/updater_quantile_hist.cc +++ b/plugin/sycl/tree/updater_quantile_hist.cc @@ -33,6 +33,16 @@ void QuantileHistMaker::Configure(const Args& args) { pruner_->Configure(args); param_.UpdateAllowUnknown(args); hist_maker_param_.UpdateAllowUnknown(args); + + bool has_fp64_support = qu_.get_device().has(::sycl::aspect::fp64); + if (hist_maker_param_.single_precision_histogram || !has_fp64_support) { + if (hist_maker_param_.single_precision_histogram) { + LOG(WARNING) << "Target device doesn't support fp64, using single_precision_histogram=True"; + } + hist_precision_ = HistPrecision::fp32; + } else { + hist_precision_ = HistPrecision::fp64; + } } template @@ -90,20 +100,16 @@ void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param, param_.learning_rate = lr / trees.size(); int_constraint_.Configure(param_, dmat->Info().num_col_); // build tree - bool has_double_support = qu_.get_device().has(::sycl::aspect::fp64); - if (hist_maker_param_.single_precision_histogram || !has_double_support) { - if (!hist_maker_param_.single_precision_histogram) { - LOG(WARNING) << "Target device doesn't support fp64, using single_precision_histogram=True"; - } - if (!pimpl_single) { - SetPimpl(&pimpl_single, dmat); + if (hist_precision_ == HistPricision::fp32) { + if (!pimpl_fp32) { + SetPimpl(&pimpl_fp32, dmat); } - CallUpdate(pimpl_single, param, gpair, dmat, out_position, trees); + CallUpdate(pimpl_fp32, param, gpair, dmat, out_position, trees); } else { - if (!pimpl_double) { - SetPimpl(&pimpl_double, dmat); + if (!pimpl_fp64) { + SetPimpl(&pimpl_fp64, dmat); } - CallUpdate(pimpl_double, param, gpair, dmat, out_position, trees); + CallUpdate(pimpl_fp64, param, gpair, dmat, out_position, trees); } param_.learning_rate = lr; @@ -113,16 +119,19 @@ void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param, bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data, linalg::MatrixView out_preds) { - if (param_.subsample < 1.0f) { - return false; + if (param_.subsample < 1.0f) return false; + + if (hist_precision_ == HistPricision::fp32) { + if (pimpl_fp32) { + return pimpl_fp32->UpdatePredictionCache(data, out_preds); + } else { + return false; + } } else { - bool has_double_support = qu_.get_device().has(::sycl::aspect::fp64); - if ((hist_maker_param_.single_precision_histogram || !has_double_support) && pimpl_single) { - return pimpl_single->UpdatePredictionCache(data, out_preds); - } else if (pimpl_double) { - return pimpl_double->UpdatePredictionCache(data, out_preds); + if (pimpl_fp64) { + return pimpl_fp64->UpdatePredictionCache(data, out_preds); } else { - return false; + return false; } } } diff --git a/plugin/sycl/tree/updater_quantile_hist.h b/plugin/sycl/tree/updater_quantile_hist.h index e26751617830..8890970208d0 100644 --- a/plugin/sycl/tree/updater_quantile_hist.h +++ b/plugin/sycl/tree/updater_quantile_hist.h @@ -93,8 +93,11 @@ class QuantileHistMaker: public TreeUpdater { const std::vector &trees); protected: - std::unique_ptr> pimpl_single; - std::unique_ptr> pimpl_double; + enum class HistPrecision {fp32, fp64}; + HistPrecision hist_precision_; + + std::unique_ptr> pimpl_fp32; + std::unique_ptr> pimpl_fp64; std::unique_ptr pruner_; FeatureInteractionConstraintHost int_constraint_;