Skip to content

Commit

Permalink
minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Apr 15, 2024
1 parent 1792bba commit 4a9de50
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 21 deletions.
10 changes: 5 additions & 5 deletions plugin/sycl/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@

#include <utility>

#include "xgboost/logging.h"

#include "updater_quantile_hist.h"
#include "../data.h"

namespace xgboost {
namespace sycl {
Expand Down Expand Up @@ -64,9 +61,12 @@ void QuantileHistMaker::CallUpdate(
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) {
const std::vector<GradientPair>& gpair_h = gpair->ConstHostVector();
USMVector<GradientPair, MemoryType::on_device> gpair_device(&qu_, gpair_h);
gpair_device_.Resize(&qu_, gpair_h.size());
qu_.memcpy(gpair_device_.Data(), gpair_h.data(), gpair_h.size() * sizeof(GradientPair));
qu_.wait();

for (auto tree : trees) {
pimpl->Update(ctx_, param, gmat_, gpair, gpair_device, dmat, out_position, tree);
pimpl->Update(ctx_, param, gmat_, gpair, gpair_device_, dmat, out_position, tree);
}
}

Expand Down
20 changes: 4 additions & 16 deletions plugin/sycl/tree/updater_quantile_hist.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,9 @@ class QuantileHistMaker: public TreeUpdater {
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
FromJson(config.at("train_param"), &this->param_);
try {
FromJson(config.at("sycl_hist_train_param"), &this->hist_maker_param_);
} catch (std::out_of_range& e) {
// XGBoost model is from 1.1.x, so 'cpu_hist_train_param' is missing.
// We add this compatibility check because it's just recently that we (developers) began
// persuade R users away from using saveRDS() for model serialization. Hopefully, one day,
// everyone will be using xgb.save().
LOG(WARNING) << "Attempted to load interal configuration for a model file that was generated "
<< "by a previous version of XGBoost. A likely cause for this warning is that the model "
<< "was saved with saveRDS() in R or pickle.dump() in Python. We strongly ADVISE AGAINST "
<< "using saveRDS() or pickle.dump() so that the model remains accessible in current and "
<< "upcoming XGBoost releases. Please use xgb.save() instead to preserve models for the "
<< "long term. For more details and explanation, see "
<< "https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html";
this->hist_maker_param_.UpdateAllowUnknown(Args{});
}
FromJson(config.at("sycl_hist_train_param"), &this->hist_maker_param_);
}

void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["train_param"] = ToJson(param_);
Expand Down Expand Up @@ -116,6 +102,8 @@ class QuantileHistMaker: public TreeUpdater {
::sycl::queue qu_;
DeviceManager device_manager;
ObjInfo const *task_{nullptr};

USMVector<GradientPair, MemoryType::on_device> gpair_device_;
};

} // namespace tree
Expand Down

0 comments on commit 4a9de50

Please sign in to comment.