diff --git a/plugin/sycl/tree/param.h b/plugin/sycl/tree/param.h index 3f68f28a6bb2..9e4d52a14287 100644 --- a/plugin/sycl/tree/param.h +++ b/plugin/sycl/tree/param.h @@ -45,62 +45,8 @@ struct TrainParam { } }; - -/*! \brief core statistics used for tree construction */ -template -struct GradStats { - /*! \brief sum gradient statistics */ - GradType sum_grad { 0 }; - /*! \brief sum hessian statistics */ - GradType sum_hess { 0 }; - - - public: - GradType GetGrad() const { return sum_grad; } - GradType GetHess() const { return sum_hess; } - - GradStats& operator+= (const GradStats& rhs) { - sum_grad += rhs.sum_grad; - sum_hess += rhs.sum_hess; - - return *this; - } - - GradStats& operator-= (const GradStats& rhs) { - sum_grad -= rhs.sum_grad; - sum_hess -= rhs.sum_hess; - - return *this; - } - - friend GradStats operator+ (GradStats lhs, - const GradStats rhs) { - lhs += rhs; - return lhs; - } - - friend GradStats operator- (GradStats lhs, - const GradStats rhs) { - lhs -= rhs; - return lhs; - } - - - friend std::ostream& operator<<(std::ostream& os, GradStats s) { - os << s.GetGrad() << "/" << s.GetHess(); - return os; - } - - GradStats() { - } - - template - explicit GradStats(const GpairT &sum) - : sum_grad(sum.GetGrad()), sum_hess(sum.GetHess()) {} - explicit GradStats(const GradType grad, const GradType hess) - : sum_grad(grad), sum_hess(hess) {} -}; - +template +using GradStats = xgboost::detail::GradientPairInternal; /*! * \brief SYCL implementation of SplitEntryContainer for device compilation. @@ -208,7 +154,6 @@ struct SplitEntryContainer { } }; - template using SplitEntry = SplitEntryContainer>; diff --git a/plugin/sycl/tree/split_evaluator.h b/plugin/sycl/tree/split_evaluator.h index e9eb1c2da321..d9618bf5162c 100644 --- a/plugin/sycl/tree/split_evaluator.h +++ b/plugin/sycl/tree/split_evaluator.h @@ -146,10 +146,10 @@ class TreeEvaluator { } } - inline GradType Sqr(GradType a) const { return a * a; } + // inline GradType Sqr(GradType a) const { return a * a; } inline GradType CalcGainGivenWeight(GradType sum_grad, GradType sum_hess, GradType w) const { - return -(2.0f * sum_grad * w + (sum_hess + param.reg_lambda) * this->Sqr(w)); + return -(2.0f * sum_grad * w + (sum_hess + param.reg_lambda) * xgboost::common::Sqr(w)); } inline GradType CalcGainGivenWeight(bst_node_t nid, const GradStats& stats, @@ -159,10 +159,10 @@ class TreeEvaluator { } // Avoiding tree::CalcGainGivenWeight can significantly reduce avg floating point error. if (param.max_delta_step == 0.0f && has_constraint == false) { - return this->Sqr(this->ThresholdL1(stats.sum_grad, param.reg_alpha)) / - (stats.sum_hess + param.reg_lambda); + return xgboost::common::Sqr(this->ThresholdL1(stats.GetGrad(), param.reg_alpha)) / + (stats.GetHess() + param.reg_lambda); } - return this->CalcGainGivenWeight(stats.sum_grad, stats.sum_hess, w); + return this->CalcGainGivenWeight(stats.GetGrad(), stats.GetHess(), w); } GradType CalcGain(bst_node_t nid, const GradStats& stats) const { diff --git a/plugin/sycl/tree/updater_quantile_hist.cc b/plugin/sycl/tree/updater_quantile_hist.cc index 16db0d835183..384cb1b8c0ab 100644 --- a/plugin/sycl/tree/updater_quantile_hist.cc +++ b/plugin/sycl/tree/updater_quantile_hist.cc @@ -1092,9 +1092,8 @@ void QuantileHistMaker::Builder::EnumerateSplit( * Maybe calculating of reduce overgroup in seprate kernel and reusing it here can be faster */ for (int32_t i = ibegin + local_id; i < iend; i += sub_group_size) { - sum += GradStats( - ::sycl::inclusive_scan_over_group(sg, hist_data[i].GetGrad(), std::plus<>()), - ::sycl::inclusive_scan_over_group(sg, hist_data[i].GetHess(), std::plus<>())); + sum.Add(::sycl::inclusive_scan_over_group(sg, hist_data[i].GetGrad(), std::plus<>()), + ::sycl::inclusive_scan_over_group(sg, hist_data[i].GetHess(), std::plus<>())); if (sum.GetHess() >= min_child_weight) { GradStats c = snode.stats - sum; @@ -1110,7 +1109,7 @@ void QuantileHistMaker::Builder::EnumerateSplit( size_t end = i - local_id + sub_group_size; if (end > iend) end = iend; for (size_t j = i + 1; j < end; ++j) { - sum += GradStats(hist_data[j].GetGrad(), hist_data[j].GetHess()); + sum.Add(hist_data[j].GetGrad(), hist_data[j].GetHess()); } } }