Skip to content

Commit

Permalink
reuse some code from main xgboost
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Mar 14, 2024
1 parent 04a7205 commit 88c3929
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 66 deletions.
59 changes: 2 additions & 57 deletions plugin/sycl/tree/param.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,62 +45,8 @@ struct TrainParam {
}
};


/*! \brief core statistics used for tree construction */
template<typename GradType>
struct GradStats {
/*! \brief sum gradient statistics */
GradType sum_grad { 0 };
/*! \brief sum hessian statistics */
GradType sum_hess { 0 };


public:
GradType GetGrad() const { return sum_grad; }
GradType GetHess() const { return sum_hess; }

GradStats<GradType>& operator+= (const GradStats<GradType>& rhs) {
sum_grad += rhs.sum_grad;
sum_hess += rhs.sum_hess;

return *this;
}

GradStats<GradType>& operator-= (const GradStats<GradType>& rhs) {
sum_grad -= rhs.sum_grad;
sum_hess -= rhs.sum_hess;

return *this;
}

friend GradStats<GradType> operator+ (GradStats<GradType> lhs,
const GradStats<GradType> rhs) {
lhs += rhs;
return lhs;
}

friend GradStats<GradType> operator- (GradStats<GradType> lhs,
const GradStats<GradType> rhs) {
lhs -= rhs;
return lhs;
}


friend std::ostream& operator<<(std::ostream& os, GradStats s) {
os << s.GetGrad() << "/" << s.GetHess();
return os;
}

GradStats() {
}

template <typename GpairT>
explicit GradStats(const GpairT &sum)
: sum_grad(sum.GetGrad()), sum_hess(sum.GetHess()) {}
explicit GradStats(const GradType grad, const GradType hess)
: sum_grad(grad), sum_hess(hess) {}
};

template <typename GradType>
using GradStats = xgboost::detail::GradientPairInternal<GradType>;

/*!
* \brief SYCL implementation of SplitEntryContainer for device compilation.
Expand Down Expand Up @@ -208,7 +154,6 @@ struct SplitEntryContainer {
}
};


template<typename GradType>
using SplitEntry = SplitEntryContainer<GradStats<GradType>>;

Expand Down
10 changes: 5 additions & 5 deletions plugin/sycl/tree/split_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ class TreeEvaluator {
}
}

inline GradType Sqr(GradType a) const { return a * a; }
// inline GradType Sqr(GradType a) const { return a * a; }

inline GradType CalcGainGivenWeight(GradType sum_grad, GradType sum_hess, GradType w) const {
return -(2.0f * sum_grad * w + (sum_hess + param.reg_lambda) * this->Sqr(w));
return -(2.0f * sum_grad * w + (sum_hess + param.reg_lambda) * xgboost::common::Sqr(w));
}

inline GradType CalcGainGivenWeight(bst_node_t nid, const GradStats<GradType>& stats,
Expand All @@ -159,10 +159,10 @@ class TreeEvaluator {
}
// Avoiding tree::CalcGainGivenWeight can significantly reduce avg floating point error.
if (param.max_delta_step == 0.0f && has_constraint == false) {
return this->Sqr(this->ThresholdL1(stats.sum_grad, param.reg_alpha)) /
(stats.sum_hess + param.reg_lambda);
return xgboost::common::Sqr(this->ThresholdL1(stats.GetGrad(), param.reg_alpha)) /
(stats.GetHess() + param.reg_lambda);
}
return this->CalcGainGivenWeight(stats.sum_grad, stats.sum_hess, w);
return this->CalcGainGivenWeight(stats.GetGrad(), stats.GetHess(), w);
}

GradType CalcGain(bst_node_t nid, const GradStats<GradType>& stats) const {
Expand Down
7 changes: 3 additions & 4 deletions plugin/sycl/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1092,9 +1092,8 @@ void QuantileHistMaker::Builder<GradientSumT>::EnumerateSplit(
* Maybe calculating of reduce overgroup in seprate kernel and reusing it here can be faster
*/
for (int32_t i = ibegin + local_id; i < iend; i += sub_group_size) {
sum += GradStats<GradientSumT>(
::sycl::inclusive_scan_over_group(sg, hist_data[i].GetGrad(), std::plus<>()),
::sycl::inclusive_scan_over_group(sg, hist_data[i].GetHess(), std::plus<>()));
sum.Add(::sycl::inclusive_scan_over_group(sg, hist_data[i].GetGrad(), std::plus<>()),
::sycl::inclusive_scan_over_group(sg, hist_data[i].GetHess(), std::plus<>()));

if (sum.GetHess() >= min_child_weight) {
GradStats<GradientSumT> c = snode.stats - sum;
Expand All @@ -1110,7 +1109,7 @@ void QuantileHistMaker::Builder<GradientSumT>::EnumerateSplit(
size_t end = i - local_id + sub_group_size;
if (end > iend) end = iend;
for (size_t j = i + 1; j < end; ++j) {
sum += GradStats<GradientSumT>(hist_data[j].GetGrad(), hist_data[j].GetHess());
sum.Add(hist_data[j].GetGrad(), hist_data[j].GetHess());
}
}
}
Expand Down

0 comments on commit 88c3929

Please sign in to comment.