Improve objective dispatching; improve code sharing
Dmitry Razdoburdin committed Dec 7, 2023
1 parent e66ae85 commit 6fa0f7f
Showing 11 changed files with 99 additions and 298 deletions.
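In short: the SYCL plugin stops carrying its own copies of the objective building blocks (a local Softmax, FindMaxIndex, SoftmaxMultiClassParam, RegLossParam, and the whole plugin/sycl/objective/regression_loss.h) and instead includes the shared implementations from src/common/ and src/objective/. The SYCL kernels are then registered under the shared objective names with a "_sycl" suffix, while SaveConfig writes the generic names.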
70 changes: 11 additions & 59 deletions plugin/sycl/objective/multiclass_obj.cc
@@ -18,10 +18,14 @@
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wtautological-constant-compare"
 #include "xgboost/data.h"
+#include "../../src/common/math.h"
 #pragma GCC diagnostic pop
 #include "xgboost/logging.h"
 #include "xgboost/objective.h"
 #include "xgboost/json.h"
 #include "xgboost/span.h"
+
+#include "../../../src/objective/multiclass_param.h"
+
 #include "../device_manager.h"
 #include <CL/sycl.hpp>
@@ -32,55 +36,6 @@ namespace obj {

 DMLC_REGISTRY_FILE_TAG(multiclass_obj_sycl);

-/*!
- * \brief Do inplace softmax transformaton on start to end
- *
- * \tparam Iterator Input iterator type
- *
- * \param start Start iterator of input
- * \param end end iterator of input
- */
-template <typename Iterator>
-inline void Softmax(Iterator start, Iterator end) {
-  bst_float wmax = *start;
-  for (Iterator i = start+1; i != end; ++i) {
-    wmax = ::sycl::max(*i, wmax);
-  }
-  float wsum = 0.0f;
-  for (Iterator i = start; i != end; ++i) {
-    *i = ::sycl::exp(*i - wmax);
-    wsum += *i;
-  }
-  for (Iterator i = start; i != end; ++i) {
-    *i /= static_cast<float>(wsum);
-  }
-}
-
-/*!
- * \brief Find the maximum iterator within the iterators
- * \param begin The begining iterator.
- * \param end The end iterator.
- * \return the iterator point to the maximum value.
- * \tparam Iterator The type of the iterator.
- */
-template<typename Iterator>
-inline Iterator FindMaxIndex(Iterator begin, Iterator end) {
-  Iterator maxit = begin;
-  for (Iterator it = begin; it != end; ++it) {
-    if (*it > *maxit) maxit = it;
-  }
-  return maxit;
-}
-
-struct SoftmaxMultiClassParam : public XGBoostParameter<SoftmaxMultiClassParam> {
-  int num_class;
-  // declare parameters
-  DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParam) {
-    DMLC_DECLARE_FIELD(num_class).set_lower_bound(1)
-        .describe("Number of output class in the multi-class classification.");
-  }
-};
-
 class SoftmaxMultiClassObj : public ObjFunction {
  public:
  explicit SoftmaxMultiClassObj(bool output_prob)
@@ -188,8 +143,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
       auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
       cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
         int idx = pid[0];
-        bst_float * point = &io_preds_acc[idx * nclass];
-        Softmax(point, point + nclass);
+        auto it = io_preds_acc.begin() + idx * nclass;
+        common::Softmax(it, it + nclass);
       });
     }).wait();
   } else {
@@ -200,8 +155,8 @@
       auto max_preds_acc = max_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
       cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
         int idx = pid[0];
-        bst_float const * point = &io_preds_acc[idx * nclass];
-        max_preds_acc[idx] = FindMaxIndex(point, point + nclass) - point;
+        auto it = io_preds_acc.begin() + idx * nclass;
+        max_preds_acc[idx] = common::FindMaxIndex(it, it + nclass) - it;
       });
     }).wait();
   }
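Both kernels follow the same SYCL dispatch pattern: wrap the host data in a buffer, request a read-write accessor inside submit, and launch one work-item per row with parallel_for. A minimal standalone sketch of that pattern, with an illustrative kernel name, sizes, and operation (not taken from the commit):

#include <CL/sycl.hpp>
#include <vector>

int main() {
  ::sycl::queue qu;
  constexpr int ndata = 4, nclass = 3;
  std::vector<float> preds(ndata * nclass, 1.0f);
  {
    ::sycl::buffer<float, 1> io_preds_buf(preds.data(), ::sycl::range<1>(preds.size()));
    qu.submit([&](::sycl::handler& cgh) {
      auto acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
      // One work-item per row, exactly as in the kernels above.
      cgh.parallel_for<class demo_kernel>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
        int idx = pid[0];
        for (int k = 0; k < nclass; ++k) acc[idx * nclass + k] *= 0.5f;
      });
    }).wait();
  }  // the buffer's destructor copies the results back into preds
  return 0;
}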
@@ -218,9 +173,9 @@
   void SaveConfig(Json* p_out) const override {
     auto& out = *p_out;
     if (this->output_prob_) {
-      out["name"] = String("multi:softprob_sycl");
+      out["name"] = String("multi:softprob");
     } else {
-      out["name"] = String("multi:softmax_sycl");
+      out["name"] = String("multi:softmax");
     }
     out["softmax_multiclass_param"] = ToJson(param_);
   }
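Note that SaveConfig now records the generic objective names ("multi:softprob" / "multi:softmax") rather than the "_sycl"-suffixed ones; presumably this keeps models trained with the SYCL plugin loadable by any backend, with the "_sycl" names used only for registry dispatch (see the XGBOOST_REGISTER_OBJECTIVE calls below).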
@@ -233,7 +188,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
   // output probability
   bool output_prob_;
   // parameter
-  SoftmaxMultiClassParam param_;
+  xgboost::obj::SoftmaxMultiClassParam param_;
   // Cache for max_preds
   mutable HostDeviceVector<bst_float> max_preds_;

@@ -242,9 +197,6 @@
   mutable ::sycl::queue qu_;
 };

-// register the objective functions
-DMLC_REGISTER_PARAMETER(SoftmaxMultiClassParam);
-
 XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax_sycl")
 .describe("Softmax for multi-class classification, output class index.")
 .set_body([]() { return new SoftmaxMultiClassObj(false); });
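The deleted Softmax and FindMaxIndex above are replaced by their shared counterparts in src/common/math.h (included under the warning-suppression pragmas), and SoftmaxMultiClassParam now comes from src/objective/multiclass_param.h -- which is presumably also why the DMLC_REGISTER_PARAMETER call disappears, since a DMLC parameter may only be registered once per binary. A minimal sketch of what the shared helpers do, mirroring the deleted local versions rather than the exact upstream code:

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

// In-place softmax over [start, end), subtracting the max for numerical
// stability -- same behavior as the deleted local Softmax above.
template <typename Iterator>
void Softmax(Iterator start, Iterator end) {
  float wmax = *start;
  for (Iterator i = start + 1; i != end; ++i) {
    wmax = std::max(*i, wmax);
  }
  float wsum = 0.0f;
  for (Iterator i = start; i != end; ++i) {
    *i = std::exp(*i - wmax);
    wsum += *i;
  }
  for (Iterator i = start; i != end; ++i) {
    *i /= wsum;
  }
}

// Iterator to the maximum element -- same behavior as the deleted FindMaxIndex.
template <typename Iterator>
Iterator FindMaxIndex(Iterator begin, Iterator end) {
  Iterator maxit = begin;
  for (Iterator it = begin; it != end; ++it) {
    if (*it > *maxit) maxit = it;
  }
  return maxit;
}

int main() {
  std::vector<float> scores{1.0f, 2.0f, 0.5f};
  Softmax(scores.begin(), scores.end());
  // Predicted class = offset of the maximum probability, as in the kernel above.
  std::cout << FindMaxIndex(scores.begin(), scores.end()) - scores.begin() << "\n";
}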
155 changes: 0 additions & 155 deletions plugin/sycl/objective/regression_loss.h

This file was deleted.
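Judging from the include changes in regression_obj.cc below, these 155 lines were the SYCL-local copies of the regression loss classes (LinearSquareLoss, SquaredLogError, LogisticRegression, and friends); the plugin now consumes the shared definitions in src/objective/regression_loss.h instead.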

54 changes: 26 additions & 28 deletions plugin/sycl/objective/regression_obj.cc
@@ -22,7 +22,12 @@

 #include "../../src/common/transform.h"
 #include "../../src/common/common.h"
-#include "regression_loss.h"
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
+#include "../../../src/objective/regression_loss.h"
+#pragma GCC diagnostic pop
+#include "../../../src/objective/regression_param.h"
+
 #include "../device_manager.h"

 #include <CL/sycl.hpp>
@@ -33,15 +38,6 @@ namespace obj {

 DMLC_REGISTRY_FILE_TAG(regression_obj_sycl);

-struct RegLossParam : public XGBoostParameter<RegLossParam> {
-  float scale_pos_weight;
-  // declare parameters
-  DMLC_DECLARE_PARAMETER(RegLossParam) {
-    DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
-        .describe("Scale the weight of positive examples by this factor");
-  }
-};
-
 template<typename Loss>
 class RegLossObj : public ObjFunction {
  protected:
@@ -164,35 +160,37 @@
 }

 protected:
-  RegLossParam param_;
+  xgboost::obj::RegLossParam param_;
   sycl::DeviceManager device_manager;

   mutable ::sycl::queue qu_;
 };

-// register the objective functions
-DMLC_REGISTER_PARAMETER(RegLossParam);
-
+/* TODO(razdoburdin):
+ * Find a better way to dispatch names of SYCL kernels with various
+ * template parameters of loss function
+ */
-XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression, LinearSquareLoss::Name())
+XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression,
+                           std::string(xgboost::obj::LinearSquareLoss::Name()) + "_sycl")
 .describe("Regression with squared error with SYCL backend.")
-.set_body([]() { return new RegLossObj<LinearSquareLoss>(); });
-XGBOOST_REGISTER_OBJECTIVE(SquareLogError, SquaredLogError::Name())
+.set_body([]() { return new RegLossObj<xgboost::obj::LinearSquareLoss>(); });
+
+XGBOOST_REGISTER_OBJECTIVE(SquareLogError,
+                           std::string(xgboost::obj::SquaredLogError::Name()) + "_sycl")
 .describe("Regression with root mean squared logarithmic error with SYCL backend.")
-.set_body([]() { return new RegLossObj<SquaredLogError>(); });
-XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, LogisticRegression::Name())
+.set_body([]() { return new RegLossObj<xgboost::obj::SquaredLogError>(); });
+
+XGBOOST_REGISTER_OBJECTIVE(LogisticRegression,
+                           std::string(xgboost::obj::LogisticRegression::Name()) + "_sycl")
 .describe("Logistic regression for probability regression task with SYCL backend.")
-.set_body([]() { return new RegLossObj<LogisticRegression>(); });
-XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, LogisticClassification::Name())
+.set_body([]() { return new RegLossObj<xgboost::obj::LogisticRegression>(); });
+
+XGBOOST_REGISTER_OBJECTIVE(LogisticClassification,
+                           std::string(xgboost::obj::LogisticClassification::Name()) + "_sycl")
 .describe("Logistic regression for binary classification task with SYCL backend.")
-.set_body([]() { return new RegLossObj<LogisticClassification>(); });
-XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, LogisticRaw::Name())
+.set_body([]() { return new RegLossObj<xgboost::obj::LogisticClassification>(); });
+
+XGBOOST_REGISTER_OBJECTIVE(LogisticRaw,
+                           std::string(xgboost::obj::LogisticRaw::Name()) + "_sycl")
 .describe("Logistic regression for classification, output score "
           "before logistic transformation with SYCL backend.")
-.set_body([]() { return new RegLossObj<LogisticRaw>(); });
+.set_body([]() { return new RegLossObj<xgboost::obj::LogisticRaw>(); });

 } // namespace obj
 } // namespace sycl
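The registration pattern above is the commit's "objective dispatching" improvement: each SYCL objective's registry key is derived from the shared loss class's static Name() plus a "_sycl" suffix, so the base name has a single source of truth. A small self-contained sketch of the convention; the SyclObjectiveName helper is illustrative, not part of the commit (the real registrations build the string inline), and the mock loss struct only reproduces Name():

#include <iostream>
#include <string>

// Illustrative stand-in for a shared loss class; the real one lives in
// src/objective/regression_loss.h and also provides the gradient routines.
struct LinearSquareLoss {
  static const char* Name() { return "reg:squarederror"; }
};

// Hypothetical helper showing the naming convention used by the
// XGBOOST_REGISTER_OBJECTIVE calls above: shared name + "_sycl".
template <typename Loss>
std::string SyclObjectiveName() {
  return std::string(Loss::Name()) + "_sycl";
}

int main() {
  // Prints "reg:squarederror_sycl" -- the key under which the SYCL
  // implementation is looked up in the objective registry.
  std::cout << SyclObjectiveName<LinearSquareLoss>() << "\n";
}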