diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4f7bc5f2fd6b..6176a247db0f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -86,7 +86,7 @@ option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)
 option(PLUGIN_RMM "Build with RAPIDS Memory Manager (RMM)" OFF)
 option(PLUGIN_FEDERATED "Build with Federated Learning" OFF)
 ## TODO: 1. Add check if DPC++ compiler is used for building
-option(PLUGIN_UPDATER_ONEAPI "DPC++ updater" OFF)
+option(PLUGIN_SYCL "SYCL plugin" OFF)
 option(ADD_PKGCONFIG "Add xgboost.pc into system." ON)
 
 #-- Checks for building XGBoost
@@ -264,14 +264,14 @@ if (PLUGIN_RMM)
   get_target_property(rmm_link_libs rmm::rmm INTERFACE_LINK_LIBRARIES)
 endif (PLUGIN_RMM)
 
-if (PLUGIN_UPDATER_ONEAPI)
+if (PLUGIN_SYCL)
   set(CMAKE_CXX_LINK_EXECUTABLE
       "icpx -qopenmp <FLAGS> <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES>")
   set(CMAKE_CXX_CREATE_SHARED_LIBRARY
       "icpx -qopenmp <CMAKE_SHARED_LIBRARY_CXX_FLAGS> <LANGUAGE_COMPILE_FLAGS> <LINK_FLAGS> \
       <CMAKE_SHARED_LIBRARY_CREATE_CXX_FLAGS> <SONAME_FLAG>,<TARGET_SONAME> \
      -o <TARGET> <OBJECTS> <LINK_LIBRARIES>")
-endif (PLUGIN_UPDATER_ONEAPI)
+endif (PLUGIN_SYCL)
 
 #-- library
 if (BUILD_STATIC_LIB)
diff --git a/plugin/CMakeLists.txt b/plugin/CMakeLists.txt
index d2df479c80d7..7ad96703fa5a 100644
--- a/plugin/CMakeLists.txt
+++ b/plugin/CMakeLists.txt
@@ -2,37 +2,37 @@
 if (PLUGIN_DENSE_PARSER)
   target_sources(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/dense_parser/dense_libsvm.cc)
 endif (PLUGIN_DENSE_PARSER)
 
-if (PLUGIN_UPDATER_ONEAPI)
+if (PLUGIN_SYCL)
   set(CMAKE_CXX_COMPILER "icpx")
-  add_library(oneapi_plugin OBJECT
-    ${xgboost_SOURCE_DIR}/plugin/updater_oneapi/hist_util_oneapi.cc
-    ${xgboost_SOURCE_DIR}/plugin/updater_oneapi/regression_obj_oneapi.cc
-    ${xgboost_SOURCE_DIR}/plugin/updater_oneapi/multiclass_obj_oneapi.cc
-    ${xgboost_SOURCE_DIR}/plugin/updater_oneapi/updater_quantile_hist_oneapi.cc
-    ${xgboost_SOURCE_DIR}/plugin/updater_oneapi/device_manager_oneapi.cc
-    ${xgboost_SOURCE_DIR}/plugin/updater_oneapi/predictor_oneapi.cc)
-  target_include_directories(oneapi_plugin
+  add_library(plugin_sycl OBJECT
+    ${xgboost_SOURCE_DIR}/plugin/sycl/common/hist_util.cc
+    ${xgboost_SOURCE_DIR}/plugin/sycl/objective/regression_obj.cc
+    ${xgboost_SOURCE_DIR}/plugin/sycl/objective/multiclass_obj.cc
+    ${xgboost_SOURCE_DIR}/plugin/sycl/tree/updater_quantile_hist.cc
+    ${xgboost_SOURCE_DIR}/plugin/sycl/device_manager.cc
+    ${xgboost_SOURCE_DIR}/plugin/sycl/predictor/predictor.cc)
+  target_include_directories(plugin_sycl
     PRIVATE
     ${xgboost_SOURCE_DIR}/include
     ${xgboost_SOURCE_DIR}/dmlc-core/include
     ${xgboost_SOURCE_DIR}/rabit/include)
-  target_compile_definitions(oneapi_plugin PUBLIC -DXGBOOST_USE_ONEAPI=1)
-  target_link_libraries(oneapi_plugin PUBLIC -fsycl)
-  set_target_properties(oneapi_plugin PROPERTIES
+  target_compile_definitions(plugin_sycl PUBLIC -DXGBOOST_USE_SYCL=1)
+  target_link_libraries(plugin_sycl PUBLIC -fsycl)
+  set_target_properties(plugin_sycl PROPERTIES
     COMPILE_FLAGS -fsycl
     CXX_STANDARD 17
     CXX_STANDARD_REQUIRED ON
     POSITION_INDEPENDENT_CODE ON)
   if (USE_OPENMP)
     find_package(OpenMP REQUIRED)
-    set_target_properties(oneapi_plugin PROPERTIES
+    set_target_properties(plugin_sycl PROPERTIES
       COMPILE_FLAGS "-fsycl -qopenmp")
   endif (USE_OPENMP)
-  # Get compilation and link flags of oneapi_plugin and propagate to objxgboost
-  target_link_libraries(objxgboost PUBLIC oneapi_plugin)
-  # Add all objects of oneapi_plugin to objxgboost
-  target_sources(objxgboost INTERFACE $<TARGET_OBJECTS:oneapi_plugin>)
-endif (PLUGIN_UPDATER_ONEAPI)
+  # Get compilation and link flags of plugin_sycl and propagate to objxgboost
+  target_link_libraries(objxgboost PUBLIC plugin_sycl)
+  # Add all objects of plugin_sycl to objxgboost
+  target_sources(objxgboost INTERFACE $<TARGET_OBJECTS:plugin_sycl>)
+endif (PLUGIN_SYCL)
 
 # Add the Federated Learning plugin if enabled.
 if (PLUGIN_FEDERATED)
diff --git a/plugin/updater_oneapi/README.md b/plugin/sycl/README.md
similarity index 70%
rename from plugin/updater_oneapi/README.md
rename to plugin/sycl/README.md
index ddb05e497925..afccad0e3a00 100755
--- a/plugin/updater_oneapi/README.md
+++ b/plugin/sycl/README.md
@@ -1,8 +1,8 @@
-# DPC++-based Algorithm for Tree Construction
-This plugin adds support of OneAPI programming model for tree construction and prediction algorithms to XGBoost.
+# SYCL-based Algorithm for Tree Construction
+This plugin adds support for the SYCL programming model for tree construction and prediction algorithms to XGBoost.
 
 ## Usage
-Specify the 'device' parameter as one of the following options to offload model training and inference on OneAPI device.
+Specify the 'device' parameter as one of the following options to offload model training and inference to a SYCL device.
 
 ### Algorithms
 | device | Description |
@@ -27,6 +27,6 @@ From the command line on Linux starting from the xgboost directory:
 ```bash
 $ mkdir build
 $ cd build
-$ EXPORT CXX=dpcpp && cmake .. -DPLUGIN_UPDATER_ONEAPI=ON
+$ export CXX=dpcpp && cmake .. -DPLUGIN_SYCL=ON
 $ make -j
 ```
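Since the README only names the build flag and the `device` strings, here is a minimal end-to-end sketch of what using the renamed plugin would look like through XGBoost's C API. It assumes a library built with `-DPLUGIN_SYCL=ON`; the `device` values (`sycl`, `sycl:cpu`, `sycl:gpu`) follow the README's table, and the data, parameters, and round count are invented purely for illustration.

```cpp
// Hedged illustration, not part of the PR: train a small model on a SYCL
// device via the public C API. Assumes libxgboost was configured with
// -DPLUGIN_SYCL=ON; error-code checking is omitted for brevity.
#include <vector>
#include <xgboost/c_api.h>

int main() {
  const bst_ulong kRows = 256, kCols = 8;
  std::vector<float> data(kRows * kCols, 0.5f);
  std::vector<float> labels(kRows, 1.0f);

  DMatrixHandle dtrain;
  XGDMatrixCreateFromMat(data.data(), kRows, kCols, /*missing=*/-1.0f, &dtrain);
  XGDMatrixSetFloatInfo(dtrain, "label", labels.data(), kRows);

  BoosterHandle booster;
  XGBoosterCreate(&dtrain, 1, &booster);
  XGBoosterSetParam(booster, "tree_method", "hist");
  XGBoosterSetParam(booster, "device", "sycl");  // or "sycl:cpu" / "sycl:gpu"

  for (int iter = 0; iter < 10; ++iter) {
    XGBoosterUpdateOneIter(booster, iter, dtrain);
  }

  XGBoosterFree(booster);
  XGDMatrixFree(dtrain);
  return 0;
}
```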
diff --git a/plugin/updater_oneapi/hist_util_oneapi.cc b/plugin/sycl/common/hist_util.cc
similarity index 57%
rename from plugin/updater_oneapi/hist_util_oneapi.cc
rename to plugin/sycl/common/hist_util.cc
index cb4e37513407..4edc9768b3e1 100644
--- a/plugin/updater_oneapi/hist_util_oneapi.cc
+++ b/plugin/sycl/common/hist_util.cc
@@ -1,15 +1,16 @@
 /*!
  * Copyright 2017-2023 by Contributors
- * \file hist_util_oneapi.cc
+ * \file hist_util.cc
  */
 #include <vector>
 #include <limits>
 
-#include "hist_util_oneapi.h"
+#include "hist_util.h"
 
 #include "CL/sycl.hpp"
 
 namespace xgboost {
+namespace sycl {
 namespace common {
 
 uint32_t SearchBin(const bst_float* cut_values, const uint32_t* cut_ptrs, Entry const& e) {
@@ -48,26 +49,26 @@ void mergeSort(BinIdxType* begin, BinIdxType* end, BinIdxType* buf) {
 }
 
 template <typename BinIdxType>
-void GHistIndexMatrixOneAPI::SetIndexData(sycl::queue qu,
-                                          common::Span<BinIdxType> index_data_span,
-                                          const DeviceMatrixOneAPI &dmat_device,
-                                          size_t nbins,
-                                          size_t row_stride,
-                                          uint32_t* offsets) {
+void GHistIndexMatrix::SetIndexData(::sycl::queue qu,
+                                    xgboost::common::Span<BinIdxType> index_data_span,
+                                    const DeviceMatrix &dmat_device,
+                                    size_t nbins,
+                                    size_t row_stride,
+                                    uint32_t* offsets) {
   const xgboost::Entry *data_ptr = dmat_device.data.DataConst();
   const bst_row_t *offset_vec = dmat_device.row_ptr.DataConst();
   const size_t num_rows = dmat_device.row_ptr.Size() - 1;
   BinIdxType* index_data = index_data_span.data();
   const bst_float* cut_values = cut_device.Values().DataConst();
   const uint32_t* cut_ptrs = cut_device.Ptrs().DataConst();
-  sycl::buffer<size_t, 1> hit_count_buf(hit_count.data(), hit_count.size());
+  ::sycl::buffer<size_t, 1> hit_count_buf(hit_count.data(), hit_count.size());
 
   USMVector<BinIdxType> sort_buf(qu, num_rows * row_stride);
   BinIdxType* sort_data = sort_buf.Data();
 
-  qu.submit([&](sycl::handler& cgh) {
-    auto hit_count_acc = hit_count_buf.template get_access<sycl::access::mode::atomic>(cgh);
-    cgh.parallel_for<>(sycl::range<1>(num_rows), [=](sycl::item<1> pid) {
+  qu.submit([&](::sycl::handler& cgh) {
+    auto hit_count_acc = hit_count_buf.template get_access<::sycl::access::mode::atomic>(cgh);
+    cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::item<1> pid) {
       const size_t i = pid.get_id(0);
       const size_t ibegin = offset_vec[i];
       const size_t iend = offset_vec[i + 1];
@@ -76,7 +77,7 @@ void GHistIndexMatrixOneAPI::SetIndexData(sycl::queue qu,
       for (bst_uint j = 0; j < size; ++j) {
         uint32_t idx = SearchBin(cut_values, cut_ptrs, data_ptr[ibegin + j]);
         index_data[start + j] = offsets ? idx - offsets[j] : idx;
-        sycl::atomic_fetch_add(hit_count_acc[idx], 1);
+        ::sycl::atomic_fetch_add(hit_count_acc[idx], 1);
       }
       if (!offsets) {
         // Sparse case only
@@ -89,29 +90,29 @@ void GHistIndexMatrixOneAPI::SetIndexData(sycl::queue qu,
   }).wait();
 }
 
-void GHistIndexMatrixOneAPI::ResizeIndex(const size_t n_offsets,
+void GHistIndexMatrix::ResizeIndex(const size_t n_offsets,
                                    const size_t n_index,
                                    const bool isDense) {
   if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
-    index.SetBinTypeSize(kUint8BinsTypeSize);
+    index.SetBinTypeSize(BinTypeSize::kUint8BinsTypeSize);
     index.Resize((sizeof(uint8_t)) * n_index);
   } else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
               max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense) {
-    index.SetBinTypeSize(kUint16BinsTypeSize);
+    index.SetBinTypeSize(BinTypeSize::kUint16BinsTypeSize);
     index.Resize((sizeof(uint16_t)) * n_index);
   } else {
-    index.SetBinTypeSize(kUint32BinsTypeSize);
+    index.SetBinTypeSize(BinTypeSize::kUint32BinsTypeSize);
     index.Resize((sizeof(uint32_t)) * n_index);
   }
 }
 
-void GHistIndexMatrixOneAPI::Init(sycl::queue qu,
-                                  Context const * ctx,
-                                  const DeviceMatrixOneAPI& p_fmat_device,
-                                  int max_bins) {
+void GHistIndexMatrix::Init(::sycl::queue qu,
+                            Context const * ctx,
+                            const DeviceMatrix& p_fmat_device,
+                            int max_bins) {
   nfeatures = p_fmat_device.p_mat->Info().num_col_;
 
-  cut = SketchOnDMatrix(ctx, p_fmat_device.p_mat, max_bins);
+  cut = xgboost::common::SketchOnDMatrix(ctx, p_fmat_device.p_mat, max_bins);
   cut_device.Init(qu, cut);
 
   max_num_bins = max_bins;
@@ -155,25 +156,25 @@ void GHistIndexMatrixOneAPI::Init(sycl::queue qu,
   if (isDense) {
     BinTypeSize curent_bin_size = index.GetBinTypeSize();
-    if (curent_bin_size == kUint8BinsTypeSize) {
-      common::Span<uint8_t> index_data_span = {index.data<uint8_t>(),
-                                               n_index};
+    if (curent_bin_size == BinTypeSize::kUint8BinsTypeSize) {
+      xgboost::common::Span<uint8_t> index_data_span = {index.data<uint8_t>(),
+                                                        n_index};
       SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets);
-    } else if (curent_bin_size == kUint16BinsTypeSize) {
-      common::Span<uint16_t> index_data_span = {index.data<uint16_t>(),
-                                                n_index};
+    } else if (curent_bin_size == BinTypeSize::kUint16BinsTypeSize) {
+      xgboost::common::Span<uint16_t> index_data_span = {index.data<uint16_t>(),
+                                                         n_index};
       SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets);
     } else {
-      CHECK_EQ(curent_bin_size, kUint32BinsTypeSize);
-      common::Span<uint32_t> index_data_span = {index.data<uint32_t>(),
-                                                n_index};
+      CHECK_EQ(curent_bin_size, BinTypeSize::kUint32BinsTypeSize);
+      xgboost::common::Span<uint32_t> index_data_span = {index.data<uint32_t>(),
+                                                         n_index};
       SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets);
     }
   /* For sparse DMatrix we have to store index of feature for each bin
      in index field to choose the right offset. So offset is nullptr and index is not reduced */
   } else {
-    common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
+    xgboost::common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
     SetIndexData(qu, index_data_span, p_fmat_device, nbins, row_stride, offsets);
   }
 }
@@ -182,81 +183,81 @@ void GHistIndexMatrixOneAPI::Init(sycl::queue qu,
 /*!
  * \brief Fill histogram with zeroes
  */
 template<typename GradientSumT>
-void InitHist(sycl::queue qu, GHistRowOneAPI<GradientSumT>& hist, size_t size) {
+void InitHist(::sycl::queue qu, GHistRow<GradientSumT>& hist, size_t size) {
   qu.fill(hist.Begin(), xgboost::detail::GradientPairInternal<GradientSumT>(), size);
 }
-template void InitHist(sycl::queue qu, GHistRowOneAPI<float>& hist, size_t size);
-template void InitHist(sycl::queue qu, GHistRowOneAPI<double>& hist, size_t size);
+template void InitHist(::sycl::queue qu, GHistRow<float>& hist, size_t size);
+template void InitHist(::sycl::queue qu, GHistRow<double>& hist, size_t size);
 
 /*!
  * \brief Copy histogram from src to dst
  */
 template<typename GradientSumT>
-void CopyHist(sycl::queue qu,
-              GHistRowOneAPI<GradientSumT>& dst,
-              const GHistRowOneAPI<GradientSumT>& src,
+void CopyHist(::sycl::queue qu,
+              GHistRow<GradientSumT>& dst,
+              const GHistRow<GradientSumT>& src,
               size_t size) {
   GradientSumT* pdst = reinterpret_cast<GradientSumT*>(dst.Data());
   const GradientSumT* psrc = reinterpret_cast<const GradientSumT*>(src.DataConst());
 
-  qu.submit([&](sycl::handler& cgh) {
-    cgh.parallel_for<>(sycl::range<1>(2 * size), [=](sycl::item<1> pid) {
+  qu.submit([&](::sycl::handler& cgh) {
+    cgh.parallel_for<>(::sycl::range<1>(2 * size), [=](::sycl::item<1> pid) {
       const size_t i = pid.get_id(0);
       pdst[i] = psrc[i];
     });
   }).wait();
 }
-template void CopyHist(sycl::queue qu,
-                       GHistRowOneAPI<float>& dst,
-                       const GHistRowOneAPI<float>& src,
-                       size_t size);
-template void CopyHist(sycl::queue qu,
-                       GHistRowOneAPI<double>& dst,
-                       const GHistRowOneAPI<double>& src,
-                       size_t size);
+template void CopyHist(::sycl::queue qu,
+                       GHistRow<float>& dst,
+                       const GHistRow<float>& src,
+                       size_t size);
+template void CopyHist(::sycl::queue qu,
+                       GHistRow<double>& dst,
+                       const GHistRow<double>& src,
+                       size_t size);
 /*!
  * \brief Compute Subtraction: dst = src1 - src2
  */
 template<typename GradientSumT>
-sycl::event SubtractionHist(sycl::queue qu,
-                            GHistRowOneAPI<GradientSumT>& dst,
-                            const GHistRowOneAPI<GradientSumT>& src1,
-                            const GHistRowOneAPI<GradientSumT>& src2,
-                            size_t size, sycl::event event_priv) {
+::sycl::event SubtractionHist(::sycl::queue qu,
+                              GHistRow<GradientSumT>& dst,
+                              const GHistRow<GradientSumT>& src1,
+                              const GHistRow<GradientSumT>& src2,
+                              size_t size, ::sycl::event event_priv) {
   GradientSumT* pdst = reinterpret_cast<GradientSumT*>(dst.Data());
   const GradientSumT* psrc1 = reinterpret_cast<const GradientSumT*>(src1.DataConst());
   const GradientSumT* psrc2 = reinterpret_cast<const GradientSumT*>(src2.DataConst());
 
-  auto event_final = qu.submit([&](sycl::handler& cgh) {
+  auto event_final = qu.submit([&](::sycl::handler& cgh) {
     cgh.depends_on(event_priv);
-    cgh.parallel_for<>(sycl::range<1>(2 * size), [pdst, psrc1, psrc2](sycl::item<1> pid) {
+    cgh.parallel_for<>(::sycl::range<1>(2 * size), [pdst, psrc1, psrc2](::sycl::item<1> pid) {
       const size_t i = pid.get_id(0);
       pdst[i] = psrc1[i] - psrc2[i];
     });
   });
   return event_final;
 }
-template sycl::event SubtractionHist(sycl::queue qu,
-                                     GHistRowOneAPI<float>& dst,
-                                     const GHistRowOneAPI<float>& src1,
-                                     const GHistRowOneAPI<float>& src2,
-                                     size_t size, sycl::event event_priv);
-template sycl::event SubtractionHist(sycl::queue qu,
-                                     GHistRowOneAPI<double>& dst,
-                                     const GHistRowOneAPI<double>& src1,
-                                     const GHistRowOneAPI<double>& src2,
-                                     size_t size, sycl::event event_priv);
+template ::sycl::event SubtractionHist(::sycl::queue qu,
+                                       GHistRow<float>& dst,
+                                       const GHistRow<float>& src1,
+                                       const GHistRow<float>& src2,
+                                       size_t size, ::sycl::event event_priv);
+template ::sycl::event SubtractionHist(::sycl::queue qu,
+                                       GHistRow<double>& dst,
+                                       const GHistRow<double>& src1,
+                                       const GHistRow<double>& src2,
+                                       size_t size, ::sycl::event event_priv);
 
 // Kernel with buffer using
 template<typename FPType, bool isDense, typename BinIdxType>
-sycl::event BuildHistKernel(sycl::queue qu,
+::sycl::event BuildHistKernel(::sycl::queue qu,
                             const USMVector<GradientPair>& gpair_device,
-                            const RowSetCollectionOneAPI::Elem& row_indices,
-                            const GHistIndexMatrixOneAPI& gmat,
-                            GHistRowOneAPI<FPType>& hist,
-                            GHistRowOneAPI<FPType>& hist_buffer,
-                            sycl::event event_priv) {
+                            const RowSetCollection::Elem& row_indices,
+                            const GHistIndexMatrix& gmat,
+                            GHistRow<FPType>& hist,
+                            GHistRow<FPType>& hist_buffer,
+                            ::sycl::event event_priv) {
   const size_t size = row_indices.Size();
   const size_t* rid = row_indices.begin;
   const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
@@ -266,7 +267,7 @@ sycl::event BuildHistKernel(sycl::queue qu,
   FPType* hist_data = reinterpret_cast<FPType*>(hist.Data());
   const size_t nbins = gmat.nbins;
 
-  const size_t max_feat_local = qu.get_device().get_info<sycl::info::device::max_work_group_size>();
+  const size_t max_feat_local = qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
   const size_t feat_local = n_columns < max_feat_local ? n_columns : max_feat_local;
 
   const size_t max_nblocks = hist_buffer.Size() / (nbins * 2);
@@ -276,10 +277,10 @@ sycl::event BuildHistKernel(sycl::queue qu,
   FPType* hist_buffer_data = reinterpret_cast<FPType*>(hist_buffer.Data());
 
   auto event_fill = qu.fill(hist_buffer_data, FPType(0), nblocks * nbins * 2, event_priv);
-  auto event_main = qu.submit([&](sycl::handler& cgh) {
+  auto event_main = qu.submit([&](::sycl::handler& cgh) {
     cgh.depends_on(event_fill);
-    cgh.parallel_for<>(sycl::nd_range<2>(sycl::range<2>(nblocks, feat_local),
-                                         sycl::range<2>(1, feat_local)), [=](sycl::nd_item<2> pid) {
+    cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(nblocks, feat_local),
+                                           ::sycl::range<2>(1, feat_local)), [=](::sycl::nd_item<2> pid) {
       size_t block = pid.get_global_id(0);
       size_t feat = pid.get_global_id(1);
@@ -290,7 +291,7 @@ sycl::event BuildHistKernel(sycl::queue qu,
         const size_t icol_start = n_columns * rid[i];
         const size_t idx_gh = rid[i];
 
-        pid.barrier(sycl::access::fence_space::local_space);
+        pid.barrier(::sycl::access::fence_space::local_space);
         const BinIdxType* gr_index_local = gradient_index + icol_start;
 
         for (size_t j = feat; j < n_columns; j += feat_local) {
@@ -308,9 +309,9 @@ sycl::event BuildHistKernel(sycl::queue qu,
     });
   });
 
-  auto event_save = qu.submit([&](sycl::handler& cgh) {
+  auto event_save = qu.submit([&](::sycl::handler& cgh) {
     cgh.depends_on(event_main);
-    cgh.parallel_for<>(sycl::range<1>(nbins), [=](sycl::item<1> pid) {
+    cgh.parallel_for<>(::sycl::range<1>(nbins), [=](::sycl::item<1> pid) {
       size_t idx_bin = pid.get_id(0);
 
       FPType gsum = 0.0f;
@@ -330,12 +331,12 @@ sycl::event BuildHistKernel(sycl::queue qu,
 
 // Kernel with atomic using
 template<typename FPType, bool isDense>
-sycl::event BuildHistKernel(sycl::queue qu,
+::sycl::event BuildHistKernel(::sycl::queue qu,
                             const USMVector<GradientPair>& gpair_device,
-                            const RowSetCollectionOneAPI::Elem& row_indices,
-                            const GHistIndexMatrixOneAPI& gmat,
-                            GHistRowOneAPI<FPType>& hist,
-                            sycl::event event_priv) {
+                            const RowSetCollection::Elem& row_indices,
+                            const GHistIndexMatrix& gmat,
+                            GHistRow<FPType>& hist,
+                            ::sycl::event event_priv) {
   const size_t size = row_indices.Size();
   const size_t* rid = row_indices.begin;
   const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
@@ -345,14 +346,14 @@ sycl::event BuildHistKernel(sycl::queue qu,
   FPType* hist_data = reinterpret_cast<FPType*>(hist.Data());
   const size_t nbins = gmat.nbins;
 
-  const size_t max_feat_local = qu.get_device().get_info<sycl::info::device::max_work_group_size>();
+  const size_t max_feat_local = qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
   const size_t feat_local = n_columns < max_feat_local ? n_columns : max_feat_local;
 
   auto event_fill = qu.fill(hist_data, FPType(0), nbins * 2, event_priv);
-  auto event_main = qu.submit([&](sycl::handler& cgh) {
+  auto event_main = qu.submit([&](::sycl::handler& cgh) {
     cgh.depends_on(event_fill);
-    cgh.parallel_for<>(sycl::range<2>(size, feat_local),
-                       [=](sycl::item<2> pid) {
+    cgh.parallel_for<>(::sycl::range<2>(size, feat_local),
+                       [=](::sycl::item<2> pid) {
       size_t i = pid.get_id(0);
       size_t feat = pid.get_id(1);
@@ -379,19 +380,19 @@ sycl::event BuildHistKernel(sycl::queue qu,
 }
 
 template<typename FPType, typename BinIdxType>
-sycl::event BuildHistDispatchKernel(sycl::queue qu,
+::sycl::event BuildHistDispatchKernel(::sycl::queue qu,
                                     const USMVector<GradientPair>& gpair_device,
-                                    const RowSetCollectionOneAPI::Elem& row_indices,
-                                    const GHistIndexMatrixOneAPI& gmat,
-                                    GHistRowOneAPI<FPType>& hist,
+                                    const RowSetCollection::Elem& row_indices,
+                                    const GHistIndexMatrix& gmat,
+                                    GHistRow<FPType>& hist,
                                     bool isDense,
-                                    GHistRowOneAPI<FPType>& hist_buffer,
-                                    sycl::event events_priv) {
+                                    GHistRow<FPType>& hist_buffer,
+                                    ::sycl::event events_priv) {
   const size_t size = row_indices.Size();
   const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
   const size_t nbins = gmat.nbins;
 
-  const size_t max_feat_local = qu.get_device().get_info<sycl::info::device::max_work_group_size>();
+  const size_t max_feat_local = qu.get_device().get_info<::sycl::info::device::max_work_group_size>();
   const size_t feat_local = n_columns < max_feat_local ? n_columns : max_feat_local;
 
   // max cycle size, while atomics are still effective
@@ -419,26 +420,26 @@ sycl::event BuildHistDispatchKernel(sycl::queue qu,
 }
 
 template<typename FPType>
-sycl::event BuildHistKernel(sycl::queue qu,
+::sycl::event BuildHistKernel(::sycl::queue qu,
                             const USMVector<GradientPair>& gpair_device,
-                            const RowSetCollectionOneAPI::Elem& row_indices,
-                            const GHistIndexMatrixOneAPI& gmat, const bool isDense,
-                            GHistRowOneAPI<FPType>& hist,
-                            GHistRowOneAPI<FPType>& hist_buffer,
-                            sycl::event event_priv) {
+                            const RowSetCollection::Elem& row_indices,
+                            const GHistIndexMatrix& gmat, const bool isDense,
+                            GHistRow<FPType>& hist,
+                            GHistRow<FPType>& hist_buffer,
+                            ::sycl::event event_priv) {
   const bool is_dense = isDense;
   switch (gmat.index.GetBinTypeSize()) {
-    case kUint8BinsTypeSize:
+    case BinTypeSize::kUint8BinsTypeSize:
       return BuildHistDispatchKernel<FPType, uint8_t>(qu, gpair_device, row_indices,
                                                       gmat, hist, is_dense, hist_buffer,
                                                       event_priv);
       break;
-    case kUint16BinsTypeSize:
+    case BinTypeSize::kUint16BinsTypeSize:
       return BuildHistDispatchKernel<FPType, uint16_t>(qu, gpair_device, row_indices,
                                                        gmat, hist, is_dense, hist_buffer,
                                                        event_priv);
       break;
-    case kUint32BinsTypeSize:
+    case BinTypeSize::kUint32BinsTypeSize:
       return BuildHistDispatchKernel<FPType, uint32_t>(qu, gpair_device, row_indices,
                                                        gmat, hist, is_dense, hist_buffer,
                                                        event_priv);
@@ -449,50 +450,51 @@ sycl::event BuildHistKernel(sycl::queue qu,
 }
 
 template<typename GradientSumT>
-sycl::event GHistBuilderOneAPI<GradientSumT>::BuildHist(const USMVector<GradientPair>& gpair_device,
-                                                        const RowSetCollectionOneAPI::Elem& row_indices,
-                                                        const GHistIndexMatrixOneAPI &gmat,
-                                                        GHistRowT& hist,
-                                                        bool isDense,
-                                                        GHistRowT& hist_buffer,
-                                                        sycl::event event_priv) {
+::sycl::event GHistBuilder<GradientSumT>::BuildHist(const USMVector<GradientPair>& gpair_device,
+                                                    const RowSetCollection::Elem& row_indices,
+                                                    const GHistIndexMatrix &gmat,
+                                                    GHistRowT& hist,
+                                                    bool isDense,
+                                                    GHistRowT& hist_buffer,
+                                                    ::sycl::event event_priv) {
   return BuildHistKernel<GradientSumT>(qu_, gpair_device, row_indices, gmat,
                                        isDense, hist, hist_buffer, event_priv);
 }
 
 template
-sycl::event GHistBuilderOneAPI<float>::BuildHist(const USMVector<GradientPair>& gpair_device,
-                                                 const RowSetCollectionOneAPI::Elem& row_indices,
-                                                 const GHistIndexMatrixOneAPI& gmat,
-                                                 GHistRowOneAPI<float>& hist,
-                                                 bool isDense,
-                                                 GHistRowOneAPI<float>& hist_buffer,
-                                                 sycl::event event_priv);
+::sycl::event GHistBuilder<float>::BuildHist(const USMVector<GradientPair>& gpair_device,
+                                             const RowSetCollection::Elem& row_indices,
+                                             const GHistIndexMatrix& gmat,
+                                             GHistRow<float>& hist,
+                                             bool isDense,
+                                             GHistRow<float>& hist_buffer,
+                                             ::sycl::event event_priv);
 template
-sycl::event GHistBuilderOneAPI<double>::BuildHist(const USMVector<GradientPair>& gpair_device,
-                                                  const RowSetCollectionOneAPI::Elem& row_indices,
-                                                  const GHistIndexMatrixOneAPI& gmat,
-                                                  GHistRowOneAPI<double>& hist,
-                                                  bool isDense,
-                                                  GHistRowOneAPI<double>& hist_buffer,
-                                                  sycl::event event_priv);
+::sycl::event GHistBuilder<double>::BuildHist(const USMVector<GradientPair>& gpair_device,
+                                              const RowSetCollection::Elem& row_indices,
+                                              const GHistIndexMatrix& gmat,
+                                              GHistRow<double>& hist,
+                                              bool isDense,
+                                              GHistRow<double>& hist_buffer,
+                                              ::sycl::event event_priv);
 
 template<typename GradientSumT>
-void GHistBuilderOneAPI<GradientSumT>::SubtractionTrick(GHistRowT& self,
-                                                        GHistRowT& sibling,
-                                                        GHistRowT& parent) {
+void GHistBuilder<GradientSumT>::SubtractionTrick(GHistRowT& self,
+                                                  GHistRowT& sibling,
+                                                  GHistRowT& parent) {
   const size_t size = self.Size();
   CHECK_EQ(sibling.Size(), size);
   CHECK_EQ(parent.Size(), size);
 
-  SubtractionHist(qu_, self, parent, sibling, size, sycl::event());
+  SubtractionHist(qu_, self, parent, sibling, size, ::sycl::event());
 }
 template
-void GHistBuilderOneAPI<float>::SubtractionTrick(GHistRowOneAPI<float>& self,
-                                                 GHistRowOneAPI<float>& sibling,
-                                                 GHistRowOneAPI<float>& parent);
+void GHistBuilder<float>::SubtractionTrick(GHistRow<float>& self,
+                                           GHistRow<float>& sibling,
+                                           GHistRow<float>& parent);
 template
-void GHistBuilderOneAPI<double>::SubtractionTrick(GHistRowOneAPI<double>& self,
-                                                  GHistRowOneAPI<double>& sibling,
-                                                  GHistRowOneAPI<double>& parent);
+void GHistBuilder<double>::SubtractionTrick(GHistRow<double>& self,
+                                            GHistRow<double>& sibling,
+                                            GHistRow<double>& parent);
 }  // namespace common
+}  // namespace sycl
 }  // namespace xgboost
diff --git a/plugin/updater_oneapi/hist_util_oneapi.h b/plugin/sycl/common/hist_util.h
similarity index 66%
rename from plugin/updater_oneapi/hist_util_oneapi.h
rename to plugin/sycl/common/hist_util.h
index 95e0fb6801c9..fe8ab730bd8d 100644
--- a/plugin/updater_oneapi/hist_util_oneapi.h
+++ b/plugin/sycl/common/hist_util.h
@@ -1,49 +1,50 @@
 /*!
  * Copyright 2017-2023 by Contributors
- * \file hist_util_oneapi.h
+ * \file hist_util.h
  */
-#ifndef XGBOOST_COMMON_HIST_UTIL_ONEAPI_H_
-#define XGBOOST_COMMON_HIST_UTIL_ONEAPI_H_
+#ifndef XGBOOST_COMMON_HIST_UTIL_SYCL_H_
+#define XGBOOST_COMMON_HIST_UTIL_SYCL_H_
 
 #include <vector>
 
-#include "data_oneapi.h"
-#include "row_set_oneapi.h"
+#include "../data.h"
+#include "row_set.h"
 
 #include "../../src/common/hist_util.h"
 
 #include "CL/sycl.hpp"
 
 namespace xgboost {
+namespace sycl {
 namespace common {
 
 template<typename GradientSumT, MemoryType memory_type = MemoryType::shared>
-using GHistRowOneAPI = USMVector<xgboost::detail::GradientPairInternal<GradientSumT>, memory_type>;
+using GHistRow = USMVector<xgboost::detail::GradientPairInternal<GradientSumT>, memory_type>;
 
 template<typename T>
-using AtomicRef = sycl::atomic_ref<T,
-                                   sycl::memory_order::relaxed,
-                                   sycl::memory_scope::device,
-                                   sycl::access::address_space::ext_intel_global_device_space>;
+using AtomicRef = ::sycl::atomic_ref<T,
+                                     ::sycl::memory_order::relaxed,
+                                     ::sycl::memory_scope::device,
+                                     ::sycl::access::address_space::ext_intel_global_device_space>;
 
 /*!
- * \brief OneAPI implementation of HistogramCuts stored in USM buffers to provide access from device kernels
+ * \brief SYCL implementation of HistogramCuts stored in USM buffers to provide access from device kernels
  */
-class HistogramCutsOneAPI {
+class HistogramCuts {
 protected:
   using BinIdx = uint32_t;
 
 public:
-  HistogramCutsOneAPI() {}
+  HistogramCuts() {}
 
-  HistogramCutsOneAPI(sycl::queue qu) {
+  HistogramCuts(::sycl::queue qu) {
     cut_ptrs_.Resize(qu_, 1, 0);
   }
 
-  ~HistogramCutsOneAPI() {
+  ~HistogramCuts() {
   }
 
-  void Init(sycl::queue qu, HistogramCuts const& cuts) {
+  void Init(::sycl::queue qu, xgboost::common::HistogramCuts const& cuts) {
     qu_ = qu;
     cut_values_.Init(qu_, cuts.cut_values_.HostVector());
     cut_ptrs_.Init(qu_, cuts.cut_ptrs_.HostVector());
@@ -59,20 +60,22 @@ class HistogramCutsOneAPI {
   USMVector<bst_float> cut_values_;
   USMVector<uint32_t> cut_ptrs_;
   USMVector<float> min_vals_;
-  sycl::queue qu_;
+  ::sycl::queue qu_;
 };
 
+using BinTypeSize = ::xgboost::common::BinTypeSize;
+
 /*!
  * \brief Index data and offsets stored in USM buffers to provide access from device kernels
  */
-struct IndexOneAPI {
-  IndexOneAPI() {
+struct Index {
+  Index() {
     SetBinTypeSize(binTypeSize_);
   }
-  IndexOneAPI(const IndexOneAPI& i) = delete;
-  IndexOneAPI& operator=(IndexOneAPI i) = delete;
-  IndexOneAPI(IndexOneAPI&& i) = delete;
-  IndexOneAPI& operator=(IndexOneAPI&& i) = delete;
+  Index(const Index& i) = delete;
+  Index& operator=(Index i) = delete;
+  Index(Index&& i) = delete;
+  Index& operator=(Index&& i) = delete;
   uint32_t operator[](size_t i) const {
     if (!offset_.Empty()) {
       return func_(data_.DataConst(), i) + offset_[i%p_];
@@ -83,19 +86,19 @@ struct IndexOneAPI {
   void SetBinTypeSize(BinTypeSize binTypeSize) {
     binTypeSize_ = binTypeSize;
     switch (binTypeSize) {
-      case kUint8BinsTypeSize:
+      case BinTypeSize::kUint8BinsTypeSize:
         func_ = &GetValueFromUint8;
         break;
-      case kUint16BinsTypeSize:
+      case BinTypeSize::kUint16BinsTypeSize:
         func_ = &GetValueFromUint16;
         break;
-      case kUint32BinsTypeSize:
+      case BinTypeSize::kUint32BinsTypeSize:
         func_ = &GetValueFromUint32;
         break;
       default:
-        CHECK(binTypeSize == kUint8BinsTypeSize ||
-              binTypeSize == kUint16BinsTypeSize ||
-              binTypeSize == kUint32BinsTypeSize);
+        CHECK(binTypeSize == BinTypeSize::kUint8BinsTypeSize ||
+              binTypeSize == BinTypeSize::kUint16BinsTypeSize ||
+              binTypeSize == BinTypeSize::kUint32BinsTypeSize);
     }
   }
   BinTypeSize GetBinTypeSize() const {
@@ -141,7 +144,7 @@ struct IndexOneAPI {
     return data_.End();
   }
 
-  void setQueue(sycl::queue qu) {
+  void setQueue(::sycl::queue qu) {
     qu_ = qu;
   }
 
@@ -160,11 +163,11 @@ struct IndexOneAPI {
   USMVector<uint8_t> data_;
   USMVector<uint32_t> offset_;  // size of this field is equal to number of features
-  BinTypeSize binTypeSize_ {kUint8BinsTypeSize};
+  BinTypeSize binTypeSize_ {BinTypeSize::kUint8BinsTypeSize};
   size_t p_ {1};
   Func func_;
 
-  sycl::queue qu_;
+  ::sycl::queue qu_;
 };
 
 
@@ -173,17 +176,17 @@ struct IndexOneAPI {
  *
  *  Transform floating values to integer index in histogram
  */
-struct GHistIndexMatrixOneAPI {
+struct GHistIndexMatrix {
   /*! \brief row pointer to rows by element position */
   std::vector<size_t> row_ptr;
   USMVector<size_t> row_ptr_device;
   /*! \brief The index data */
-  IndexOneAPI index;
+  Index index;
   /*! \brief hit count of each index */
   std::vector<size_t> hit_count;
   /*! \brief The corresponding cuts */
-  HistogramCuts cut;
-  HistogramCutsOneAPI cut_device;
+  xgboost::common::HistogramCuts cut;
+  HistogramCuts cut_device;
   DMatrix* p_fmat;
   size_t max_num_bins;
   size_t nbins;
@@ -191,11 +194,11 @@ struct GHistIndexMatrixOneAPI {
   size_t row_stride;
 
   // Create a global histogram matrix based on a given DMatrix device wrapper
-  void Init(sycl::queue qu, Context const * ctx, const DeviceMatrixOneAPI& p_fmat_device, int max_num_bins);
+  void Init(::sycl::queue qu, Context const * ctx, const sycl::DeviceMatrix& p_fmat_device, int max_num_bins);
 
   template <typename BinIdxType>
-  void SetIndexData(sycl::queue qu, common::Span<BinIdxType> index_data_span,
-                    const DeviceMatrixOneAPI &dmat_device,
+  void SetIndexData(::sycl::queue qu, xgboost::common::Span<BinIdxType> index_data_span,
+                    const sycl::DeviceMatrix &dmat_device,
                     size_t nbins, size_t row_stride, uint32_t* offsets);
 
   void ResizeIndex(const size_t n_offsets, const size_t n_index,
@@ -219,42 +222,42 @@ struct GHistIndexMatrixOneAPI {
   bool isDense_;
 };
 
-class ColumnMatrixOneAPI;
+class ColumnMatrix;
 
 /*!
  * \brief Fill histogram with zeroes
  */
 template<typename GradientSumT>
-void InitHist(sycl::queue qu,
-              GHistRowOneAPI<GradientSumT>& hist,
+void InitHist(::sycl::queue qu,
+              GHistRow<GradientSumT>& hist,
               size_t size);
 
 /*!
 * \brief Copy histogram from src to dst
 */
 template<typename GradientSumT>
-void CopyHist(sycl::queue qu,
-              GHistRowOneAPI<GradientSumT>& dst,
-              const GHistRowOneAPI<GradientSumT>& src,
+void CopyHist(::sycl::queue qu,
+              GHistRow<GradientSumT>& dst,
+              const GHistRow<GradientSumT>& src,
               size_t size);
 
 /*!
 * \brief Compute subtraction: dst = src1 - src2
 */
 template<typename GradientSumT>
-sycl::event SubtractionHist(sycl::queue qu,
-                            GHistRowOneAPI<GradientSumT>& dst,
-                            const GHistRowOneAPI<GradientSumT>& src1,
-                            const GHistRowOneAPI<GradientSumT>& src2,
-                            size_t size, sycl::event event_priv);
+::sycl::event SubtractionHist(::sycl::queue qu,
+                              GHistRow<GradientSumT>& dst,
+                              const GHistRow<GradientSumT>& src1,
+                              const GHistRow<GradientSumT>& src2,
+                              size_t size, ::sycl::event event_priv);
 
 /*!
 * \brief Histograms of gradient statistics for multiple nodes
*/
 template<typename GradientSumT>
-class HistCollectionOneAPI {
+class HistCollection {
  public:
-  using GHistRowT = GHistRowOneAPI<GradientSumT>;
+  using GHistRowT = GHistRow<GradientSumT>;
 
   // Access histogram for i-th node
   GHistRowT& operator[](bst_uint nid) {
@@ -266,7 +269,7 @@ class HistCollectionOneAPI {
   }
 
   // Initialize histogram collection
-  void Init(sycl::queue qu, uint32_t nbins) {
+  void Init(::sycl::queue qu, uint32_t nbins) {
     qu_ = qu;
     if (nbins_ != nbins) {
       nbins_ = nbins;
@@ -280,7 +283,7 @@ class HistCollectionOneAPI {
   }
 
   // Create an empty histogram for i-th node
-  sycl::event AddHistRow(bst_uint nid) {
+  ::sycl::event AddHistRow(bst_uint nid) {
     if (nid >= data_.size()) {
       data_.resize(nid + 1);
     }
@@ -297,18 +300,18 @@ class HistCollectionOneAPI {
 
   std::vector<GHistRowT> data_;
 
-  sycl::queue qu_;
+  ::sycl::queue qu_;
 };
 
 /*!
 * \brief Stores temporary histograms to compute them in parallel
*/
 template<typename GradientSumT>
-class ParallelGHistBuilderOneAPI {
+class ParallelGHistBuilder {
  public:
-  using GHistRowT = GHistRowOneAPI<GradientSumT>;
+  using GHistRowT = GHistRow<GradientSumT>;
 
-  void Init(sycl::queue qu, size_t nbins) {
+  void Init(::sycl::queue qu, size_t nbins) {
     qu_ = qu;
     if (nbins != nbins_) {
       hist_buffer_.Init(qu_, nbins);
@@ -328,34 +331,34 @@ class ParallelGHistBuilderOneAPI {
   /*! \brief Number of bins in each histogram */
   size_t nbins_ = 0;
   /*! \brief Buffers for histograms for all nodes processed */
-  HistCollectionOneAPI<GradientSumT> hist_buffer_;
+  HistCollection<GradientSumT> hist_buffer_;
   /*! \brief Buffer for additional histograms for Parallel processing  */
   GHistRowT hist_device_buffer_;
 
-  sycl::queue qu_;
+  ::sycl::queue qu_;
 };
 
 /*!
 * \brief Builder for histograms of gradient statistics
*/
 template<typename GradientSumT>
-class GHistBuilderOneAPI {
+class GHistBuilder {
  public:
   template<MemoryType memory_type = MemoryType::shared>
-  using GHistRowT = GHistRowOneAPI<GradientSumT, memory_type>;
+  using GHistRowT = GHistRow<GradientSumT, memory_type>;
 
-  GHistBuilderOneAPI() = default;
-  GHistBuilderOneAPI(sycl::queue qu, uint32_t nbins) : qu_{qu}, nbins_{nbins} {}
+  GHistBuilder() = default;
+  GHistBuilder(::sycl::queue qu, uint32_t nbins) : qu_{qu}, nbins_{nbins} {}
 
   // Construct a histogram via histogram aggregation
-  sycl::event BuildHist(const USMVector<GradientPair>& gpair_device,
-                        const RowSetCollectionOneAPI::Elem& row_indices,
-                        const GHistIndexMatrixOneAPI& gmat,
-                        GHistRowT<>& HistCollectionOneAPI,
+  ::sycl::event BuildHist(const USMVector<GradientPair>& gpair_device,
+                          const RowSetCollection::Elem& row_indices,
+                          const GHistIndexMatrix& gmat,
+                          GHistRowT<>& HistCollection,
                           bool isDense,
                           GHistRowT<>& hist_buffer,
-                        sycl::event evens);
+                          ::sycl::event evens);
 
   // Construct a histogram via subtraction trick
   void SubtractionTrick(GHistRowT<>& self,
@@ -370,8 +373,9 @@ class GHistBuilderOneAPI {
   /*! \brief Number of all bins over all features */
   uint32_t nbins_ { 0 };
 
-  sycl::queue qu_;
+  ::sycl::queue qu_;
 };
 }  // namespace common
+}  // namespace sycl
 }  // namespace xgboost
-#endif  // XGBOOST_COMMON_HIST_UTIL_ONEAPI_H_
+#endif  // XGBOOST_COMMON_HIST_UTIL_SYCL_H_
diff --git a/plugin/updater_oneapi/row_set_oneapi.h b/plugin/sycl/common/row_set.h
similarity index 85%
rename from plugin/updater_oneapi/row_set_oneapi.h
rename to plugin/sycl/common/row_set.h
index 309c85ad432e..4d5bff83119b 100644
--- a/plugin/updater_oneapi/row_set_oneapi.h
+++ b/plugin/sycl/common/row_set.h
@@ -1,8 +1,8 @@
 /*!
  * Copyright 2017-2023 XGBoost contributors
  */
-#ifndef XGBOOST_COMMON_ROW_SET_ONEAPI_H_
-#define XGBOOST_COMMON_ROW_SET_ONEAPI_H_
+#ifndef XGBOOST_COMMON_ROW_SET_SYCL_H_
+#define XGBOOST_COMMON_ROW_SET_SYCL_H_
 
 #include <xgboost/data.h>
 
@@ -11,18 +11,19 @@
 #include <memory>
 
-#include "data_oneapi.h"
+#include "../data.h"
 
 #include "CL/sycl.hpp"
 
 namespace xgboost {
+namespace sycl {
 namespace common {
 
 /*! \brief Collection of rowsets stored on device in USM memory */
-class RowSetCollectionOneAPI {
+class RowSetCollection {
 public:
  /*!
  * \brief data structure to store an instance set, a subset of
  *  rows (instances) associated with a particular node in a decision
@@ -124,14 +125,14 @@ class RowSetCollectionOneAPI {
 
 // The builder is required for samples partition to left and rights children for set of nodes
-class PartitionBuilderOneAPI {
+class PartitionBuilder {
 public:
  static constexpr size_t maxLocalSums = 256;
  static constexpr size_t subgroupSize = 16;
 
  template<typename Func>
-  void Init(sycl::queue qu, size_t n_nodes, Func funcNTaks) {
+  void Init(::sycl::queue qu, size_t n_nodes, Func funcNTaks) {
    qu_ = qu;
    nodes_offsets_.resize(n_nodes+1);
    result_rows_.resize(2 * n_nodes);
@@ -151,17 +152,17 @@ class PartitionBuilderOneAPI {
  }
 
-  common::Span<size_t> GetData(int nid) {
+  xgboost::common::Span<size_t> GetData(int nid) {
    return { data_.Data() + nodes_offsets_[nid], nodes_offsets_[nid + 1] - nodes_offsets_[nid] };
  }
 
-  common::Span<size_t> GetPrefixSums() {
+  xgboost::common::Span<size_t> GetPrefixSums() {
    return { prefix_sums_.Data(), prefix_sums_.Size() };
  }
 
-  size_t GetLocalSize(const common::Range1d& range) {
+  size_t GetLocalSize(const xgboost::common::Range1d& range) {
    size_t range_size = range.end() - range.begin();
    size_t local_subgroups = range_size / (maxLocalSums * subgroupSize) +
                             !!(range_size % (maxLocalSums * subgroupSize));
    return subgroupSize * local_subgroups;
@@ -183,11 +184,11 @@ class PartitionBuilderOneAPI {
  //   }
 
-  // sycl::event SetNLeftRightElems(sycl::queue& qu, const USMVector<size_t>& parts_size,
-  //                                const std::vector<sycl::event>& priv_events) {
-  //   auto event = qu.submit([&](sycl::handler& cgh) {
+  // ::sycl::event SetNLeftRightElems(::sycl::queue& qu, const USMVector<size_t>& parts_size,
+  //                                  const std::vector<::sycl::event>& priv_events) {
+  //   auto event = qu.submit([&](::sycl::handler& cgh) {
  //     cgh.depends_on(priv_events);
-  //     cgh.parallel_for<>(sycl::range<1>(n_nodes_), [=](sycl::item<1> nid) {
+  //     cgh.parallel_for<>(::sycl::range<1>(n_nodes_), [=](::sycl::item<1> nid) {
  //       const size_t node_in_set = nid.get_id(0);
  //       result_left_rows_[node_in_set] = parts_size[2 * node_in_set];
  //       result_right_rows_[node_in_set] = parts_size[2 * node_in_set + 1];
@@ -214,15 +215,15 @@ class PartitionBuilderOneAPI {
  }
 
-  sycl::event MergeToArray(sycl::queue& qu, size_t node_in_set,
+  ::sycl::event MergeToArray(::sycl::queue& qu, size_t node_in_set,
                            size_t* data_result,
-                           sycl::event priv_event) {
+                           ::sycl::event priv_event) {
    size_t n_nodes_total = GetNLeftElems(node_in_set) + GetNRightElems(node_in_set);
    if (n_nodes_total > 0) {
      const size_t* data = data_.Data() + nodes_offsets_[node_in_set];
      return qu.memcpy(data_result, data, sizeof(size_t) * n_nodes_total, priv_event);
    } else {
-      return sycl::event();
+      return ::sycl::event();
    }
  }
 
@@ -250,12 +251,13 @@ class PartitionBuilderOneAPI {
 
  USMVector<size_t> prefix_sums_;
 
-  sycl::queue qu_;
+  ::sycl::queue qu_;
 };
 
 }  // namespace common
+}  // namespace sycl
 }  // namespace xgboost
 
-#endif  // XGBOOST_COMMON_ROW_SET_ONEAPI_H_
+#endif  // XGBOOST_COMMON_ROW_SET_SYCL_H_
diff --git a/plugin/updater_oneapi/data_oneapi.h b/plugin/sycl/data.h
similarity index 72%
rename from plugin/updater_oneapi/data_oneapi.h
rename to plugin/sycl/data.h
index a4d6ade88342..5f72cc9eafbb 100644
--- a/plugin/updater_oneapi/data_oneapi.h
+++ b/plugin/sycl/data.h
@@ -1,15 +1,19 @@
 /*!
  * Copyright by Contributors 2017-2023
  */
-#ifndef XGBOOST_COMMON_DATA_ONEAPI_H_
-#define XGBOOST_COMMON_DATA_ONEAPI_H_
+#ifndef XGBOOST_COMMON_DATA_SYCL_H_
+#define XGBOOST_COMMON_DATA_SYCL_H_
 
 #include <algorithm>
 #include <memory>
 #include <vector>
 
 #include "xgboost/base.h"
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
+#pragma GCC diagnostic ignored "-W#pragma-messages"
 #include "xgboost/data.h"
+#pragma GCC diagnostic pop
 #include "xgboost/logging.h"
 #include "xgboost/host_device_vector.h"
 
@@ -18,38 +22,36 @@
 #include "CL/sycl.hpp"
 
 namespace xgboost {
-
+namespace sycl {
 enum class MemoryType { shared, on_device};
 
 
 template <typename T>
 class USMDeleter {
  public:
-  explicit USMDeleter(sycl::queue qu) : qu_(qu) {}
+  explicit USMDeleter(::sycl::queue qu) : qu_(qu) {}
 
   void operator()(T* data) const {
-    sycl::free(data, qu_);
+    ::sycl::free(data, qu_);
   }
 
  private:
-  sycl::queue qu_;
+  ::sycl::queue qu_;
 };
 
-/* OneAPI implementation of a HostDeviceVector, storing both host and device memory in a single USM buffer.
-   Synchronization between host and device is managed by the compiler runtime. */
 template <typename T, MemoryType memory_type = MemoryType::shared>
 class USMVector {
   static_assert(std::is_standard_layout<T>::value, "USMVector admits only POD types");
 
-  std::shared_ptr<T> allocate_memory_(sycl::queue& qu, size_t size) {
+  std::shared_ptr<T> allocate_memory_(::sycl::queue& qu, size_t size) {
     if constexpr (memory_type == MemoryType::shared) {
-      return std::shared_ptr<T>(sycl::malloc_shared<T>(size_, qu), USMDeleter<T>(qu));
+      return std::shared_ptr<T>(::sycl::malloc_shared<T>(size_, qu), USMDeleter<T>(qu));
     } else {
-      return std::shared_ptr<T>(sycl::malloc_device<T>(size_, qu), USMDeleter<T>(qu));
+      return std::shared_ptr<T>(::sycl::malloc_device<T>(size_, qu), USMDeleter<T>(qu));
     }
   }
 
-  void copy_vector_to_memory_(sycl::queue& qu, const std::vector<T> &vec) {
+  void copy_vector_to_memory_(::sycl::queue& qu, const std::vector<T> &vec) {
     if constexpr (memory_type == MemoryType::shared) {
       std::copy(vec.begin (), vec.end (), data_.get());
     } else {
@@ -61,26 +63,22 @@ class USMVector {
 
  public:
  USMVector() : size_(0), capacity_(0), data_(nullptr) {}
 
-  USMVector(sycl::queue& qu, size_t size) : size_(size), capacity_(size) {
+  USMVector(::sycl::queue& qu, size_t size) : size_(size), capacity_(size) {
    data_ = allocate_memory_(qu, size_);
  }
 
-  USMVector(sycl::queue& qu, size_t size, T v) : size_(size), capacity_(size) {
+  USMVector(::sycl::queue& qu, size_t size, T v) : size_(size), capacity_(size) {
    data_ = allocate_memory_(qu, size_);
    qu.fill(data_.get(), v, size_).wait();
  }
 
-  USMVector(sycl::queue& qu, const std::vector<T> &vec) {
+  USMVector(::sycl::queue& qu, const std::vector<T> &vec) {
    size_ = vec.size();
    capacity_ = size_;
    data_ = allocate_memory_(qu, size_);
    copy_vector_to_memory_(qu, vec);
  }
 
-// Bug. Copy constructor doesn't copy data.
-//  USMVector(const USMVector& other) : qu_(other.qu_), size_(other.size_), data_(other.data_) {
-//  }
-
  ~USMVector() {
  }
 
@@ -112,7 +110,7 @@ class USMVector {
    capacity_ = 0;
  }
 
-  void Resize(sycl::queue& qu, size_t size_new) {
+  void Resize(::sycl::queue& qu, size_t size_new) {
    if (size_new <= capacity_) {
      size_ = size_new;
    } else {
@@ -127,24 +125,7 @@ class USMVector {
    }
  }
 
-  // T Get(sycl::queue qu, size_t idx, std::vector<sycl::event>* events_ptr) const {
-  //   T val;
-  //   auto event = qu.memcpy(&val, data_.get() + idx, sizeof(T), *events_ptr);
-  //   events_ptr->emplace_back(event);
-  //   return val;
-  // }
-
-  // T Get(sycl::queue& qu, size_t idx) const {
-  //   T val;
-  //   last_event_ = qu.memcpy(&val, data_.get() + idx, sizeof(T));
-  //   return val;
-  // }
-
-  // sycl::event GetLastEvent() const {
-  //   return last_event_;
-  // }
-
-  void Resize(sycl::queue& qu, size_t size_new, T v) {
+  void Resize(::sycl::queue& qu, size_t size_new, T v) {
    if (size_new <= size_) {
      size_ = size_new;
    } else if (size_new <= capacity_) {
@@ -163,10 +144,10 @@ class USMVector {
    }
  }
 
-  sycl::event ResizeAsync(sycl::queue& qu, size_t size_new, T v) {
+  ::sycl::event ResizeAsync(::sycl::queue& qu, size_t size_new, T v) {
    if (size_new <= size_) {
      size_ = size_new;
-      return sycl::event();
+      return ::sycl::event();
    } else if (size_new <= capacity_) {
      auto event = qu.fill(data_.get() + size_, v, size_new - size_);
      size_ = size_new;
@@ -177,7 +158,7 @@ class USMVector {
      size_ = size_new;
      capacity_ = size_new;
      data_ = allocate_memory_(qu, size_);
-      sycl::event event;
+      ::sycl::event event;
      if (size_old > 0) {
        event = qu.memcpy(data_.get(), data_old.get(), sizeof(T) * size_old);
      }
@@ -185,7 +166,7 @@ class USMVector {
    }
  }
 
-  sycl::event ResizeAndFill(sycl::queue& qu, size_t size_new, int v) {
+  ::sycl::event ResizeAndFill(::sycl::queue& qu, size_t size_new, int v) {
    if (size_new <= size_) {
      size_ = size_new;
      return qu.memset(data_.get(), v, size_new * sizeof(T));
@@ -202,11 +183,11 @@ class USMVector {
    }
  }
 
-  sycl::event Fill(sycl::queue& qu, T v) {
+  ::sycl::event Fill(::sycl::queue& qu, T v) {
    return qu.fill(data_.get(), v, size_);
  }
 
-  void Init(sycl::queue& qu, const std::vector<T> &vec) {
+  void Init(::sycl::queue& qu, const std::vector<T> &vec) {
    size_ = vec.size();
    capacity_ = size_;
    data_ = allocate_memory_(qu, size_);
@@ -219,18 +200,17 @@ class USMVector {
  size_t size_;
  size_t capacity_;
  std::shared_ptr<T> data_;
-  // mutable sycl::event last_event_;
 };
 
 /* Wrapper for DMatrix which stores all batches in a single USM buffer */
-struct DeviceMatrixOneAPI {
+struct DeviceMatrix {
  DMatrix* p_mat;  // Pointer to the original matrix on the host
-  sycl::queue qu_;
+  ::sycl::queue qu_;
  USMVector<bst_row_t> row_ptr;
  USMVector<Entry> data;
  size_t total_offset;
 
-  DeviceMatrixOneAPI(sycl::queue qu, DMatrix* dmat) : p_mat(dmat), qu_(qu) {
+  DeviceMatrix(::sycl::queue qu, DMatrix* dmat) : p_mat(dmat), qu_(qu) {
    size_t num_row = 0;
    size_t num_nonzero = 0;
    for (auto &batch : dmat->GetBatches<SparsePage>()) {
@@ -264,10 +244,10 @@ struct DeviceMatrixOneAPI {
    total_offset = data_offset;
  }
 
-  ~DeviceMatrixOneAPI() {
+  ~DeviceMatrix() {
  }
 };
-
+}  // namespace sycl
 }  // namespace xgboost
-#endif
+#endif  // XGBOOST_COMMON_DATA_SYCL_H_
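As a quick orientation for reviewers, the following sketch exercises the renamed `USMVector` exactly as declared above (constructor with fill value, event-returning `Fill`, and `Resize` with a fill value). Only the member signatures are taken from the header; the include path and queue construction are assumptions made for illustration.

```cpp
// Minimal sketch (not from the PR) using the USMVector API shown above.
#include "plugin/sycl/data.h"  // assumed include path

int main() {
  ::sycl::queue qu{::sycl::default_selector_v};

  // Shared-USM vector of 8 floats, initialized to 1.0f by the constructor.
  xgboost::sycl::USMVector<float, xgboost::sycl::MemoryType::shared> vec(qu, 8, 1.0f);

  // Fill() returns a ::sycl::event; wait before reading on the host.
  vec.Fill(qu, 2.0f).wait();

  // Grow to 16 elements, filling the new tail with zeros.
  vec.Resize(qu, 16, 0.0f);

  return vec.Size() == 16 ? 0 : 1;
}
```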
diff --git a/plugin/updater_oneapi/device_manager_oneapi.cc b/plugin/sycl/device_manager.cc
similarity index 71%
rename from plugin/updater_oneapi/device_manager_oneapi.cc
rename to plugin/sycl/device_manager.cc
index 783b253e480f..2dbe2f96c59e 100644
--- a/plugin/updater_oneapi/device_manager_oneapi.cc
+++ b/plugin/sycl/device_manager.cc
@@ -1,14 +1,19 @@
 /*!
  * Copyright 2017-2022 by Contributors
- * \file device_manager_oneapi.cc
+ * \file device_manager.cc
  */
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
+#pragma GCC diagnostic ignored "-W#pragma-messages"
 #include <xgboost/context.h>
+#pragma GCC diagnostic pop
 
-#include "./device_manager_oneapi.h"
+#include "../sycl/device_manager.h"
 
 namespace xgboost {
+namespace sycl {
 
-sycl::device DeviceManagerOneAPI::GetDevice(const DeviceOrd& device_spec) const {
+::sycl::device DeviceManager::GetDevice(const DeviceOrd& device_spec) const {
   if (!device_spec.IsSycl()) {
     LOG(WARNING) << "Sycl kernel is executed with non-sycl context: "
                  << device_spec.Name() << ". "
@@ -35,16 +40,16 @@ sycl::device DeviceManagerOneAPI::GetDevice(const DeviceOrd& device_spec) const
     }
   } else {
     if(device_spec.IsSyclCPU()) {
-      return sycl::device(sycl::cpu_selector_v);
+      return ::sycl::device(::sycl::cpu_selector_v);
     } else if(device_spec.IsSyclGPU()) {
-      return sycl::device(sycl::gpu_selector_v);
+      return ::sycl::device(::sycl::gpu_selector_v);
     } else {
-      return sycl::device(sycl::default_selector_v);
+      return ::sycl::device(::sycl::default_selector_v);
     }
   }
 }
 
-sycl::queue DeviceManagerOneAPI::GetQueue(const DeviceOrd& device_spec) const {
+::sycl::queue DeviceManager::GetQueue(const DeviceOrd& device_spec) const {
   if (!device_spec.IsSycl()) {
     LOG(WARNING) << "Sycl kernel is executed with non-sycl context: "
                  << device_spec.Name() << ". "
@@ -65,36 +70,36 @@ sycl::queue DeviceManagerOneAPI::GetQueue(const DeviceOrd& device_spec) const {
     if (device_spec.IsSyclDefault()) {
       auto& devices = device_register.devices;
       CHECK_LT(device_idx, devices.size());
-      queue_register[device_spec.Name()] = sycl::queue(devices[device_idx]);
+      queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]);
     } else if (device_spec.IsSyclCPU()) {
       auto& cpu_devices = device_register.cpu_devices;
       CHECK_LT(device_idx, cpu_devices.size());
-      queue_register[device_spec.Name()] = sycl::queue(cpu_devices[device_idx]);;
+      queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);
     } else if (device_spec.IsSyclGPU()) {
       auto& gpu_devices = device_register.gpu_devices;
       CHECK_LT(device_idx, gpu_devices.size());
-      queue_register[device_spec.Name()] = sycl::queue(gpu_devices[device_idx]);
+      queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]);
     }
   } else {
     if (device_spec.IsSyclCPU()) {
-      queue_register[device_spec.Name()] = sycl::queue(sycl::cpu_selector_v);
+      queue_register[device_spec.Name()] = ::sycl::queue(::sycl::cpu_selector_v);
     } else if (device_spec.IsSyclGPU()) {
-      queue_register[device_spec.Name()] = sycl::queue(sycl::gpu_selector_v);
+      queue_register[device_spec.Name()] = ::sycl::queue(::sycl::gpu_selector_v);
     } else {
-      queue_register[device_spec.Name()] = sycl::queue(sycl::default_selector_v);
+      queue_register[device_spec.Name()] = ::sycl::queue(::sycl::default_selector_v);
     }
   }
   return queue_register.at(device_spec.Name());
 }
 
-DeviceManagerOneAPI::DeviceRegister& DeviceManagerOneAPI::GetDevicesRegister() const {
+DeviceManager::DeviceRegister& DeviceManager::GetDevicesRegister() const {
   static DeviceRegister device_register;
 
   if (device_register.devices.size() == 0) {
     std::lock_guard<std::mutex> guard(device_registering_mutex);
-    std::vector<sycl::device> devices = sycl::device::get_devices();
+    std::vector<::sycl::device> devices = ::sycl::device::get_devices();
     for (size_t i = 0; i < devices.size(); i++) {
-      LOG(INFO) << "device_index = " << i << ", name = " << devices[i].get_info<sycl::info::device::name>();
+      LOG(INFO) << "device_index = " << i << ", name = " << devices[i].get_info<::sycl::info::device::name>();
     }
 
     for (size_t i = 0; i < devices.size(); i++) {
@@ -109,9 +114,10 @@ DeviceManagerOneAPI::DeviceRegister& DeviceManagerOneAPI::GetDevicesRegister() c
   return device_register;
 }
 
-DeviceManagerOneAPI::QueueRegister_t& DeviceManagerOneAPI::GetQueueRegister() const {
+DeviceManager::QueueRegister_t& DeviceManager::GetQueueRegister() const {
   static QueueRegister_t queue_register;
   return queue_register;
 }
 
+}  // namespace sycl
 }  // namespace xgboost
\ No newline at end of file
diff --git a/plugin/sycl/device_manager.h b/plugin/sycl/device_manager.h
new file mode 100644
index 000000000000..7941b9c47536
--- /dev/null
+++ b/plugin/sycl/device_manager.h
@@ -0,0 +1,44 @@
+/*!
+ * Copyright 2017-2022 by Contributors
+ * \file device_manager.h
+ */
+#ifndef XGBOOST_DEVICE_MANAGER_SYCL_H_
+#define XGBOOST_DEVICE_MANAGER_SYCL_H_
+
+#include <string>
+#include <mutex>
+#include <unordered_map>
+
+#include "CL/sycl.hpp"
+#include "xgboost/context.h"
+
+namespace xgboost {
+namespace sycl {
+
+class DeviceManager {
+ public:
+  ::sycl::queue GetQueue(const DeviceOrd& device_spec) const;
+
+  ::sycl::device GetDevice(const DeviceOrd& device_spec) const;
+
+ private:
+  using QueueRegister_t = std::unordered_map<std::string, ::sycl::queue>;
+
+  struct DeviceRegister {
+    std::vector<::sycl::device> devices;
+    std::vector<::sycl::device> cpu_devices;
+    std::vector<::sycl::device> gpu_devices;
+  };
+
+  QueueRegister_t& GetQueueRegister() const;
+
+  DeviceRegister& GetDevicesRegister() const;
+
+  mutable std::mutex queue_registering_mutex;
+  mutable std::mutex device_registering_mutex;
+};
+
+}  // namespace sycl
+}  // namespace xgboost
+
+#endif  // XGBOOST_DEVICE_MANAGER_SYCL_H_
\ No newline at end of file
@@ -36,14 +43,14 @@ DMLC_REGISTRY_FILE_TAG(multiclass_obj_oneapi); * \param end end iterator of input */ template -inline void SoftmaxOneAPI(Iterator start, Iterator end) { +inline void Softmax(Iterator start, Iterator end) { bst_float wmax = *start; for (Iterator i = start+1; i != end; ++i) { - wmax = sycl::max(*i, wmax); + wmax = ::sycl::max(*i, wmax); } float wsum = 0.0f; for (Iterator i = start; i != end; ++i) { - *i = sycl::exp(*i - wmax); + *i = ::sycl::exp(*i - wmax); wsum += *i; } for (Iterator i = start; i != end; ++i) { @@ -60,7 +67,7 @@ inline void SoftmaxOneAPI(Iterator start, Iterator end) { * \tparam Iterator The type of the iterator. */ template -inline Iterator FindMaxIndexOneAPI(Iterator begin, Iterator end) { +inline Iterator FindMaxIndex(Iterator begin, Iterator end) { Iterator maxit = begin; for (Iterator it = begin; it != end; ++it) { if (*it > *maxit) maxit = it; @@ -69,19 +76,19 @@ inline Iterator FindMaxIndexOneAPI(Iterator begin, Iterator end) { } -struct SoftmaxMultiClassParamOneAPI : public XGBoostParameter { +struct SoftmaxMultiClassParam : public XGBoostParameter { int num_class; // declare parameters - DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParamOneAPI) { + DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParam) { DMLC_DECLARE_FIELD(num_class).set_lower_bound(1) .describe("Number of output class in the multi-class classification."); } }; -class SoftmaxMultiClassObjOneAPI : public ObjFunction { +class SoftmaxMultiClassObj : public ObjFunction { public: - explicit SoftmaxMultiClassObjOneAPI(bool output_prob) + explicit SoftmaxMultiClassObj(bool output_prob) : output_prob_(output_prob) {} @@ -99,7 +106,7 @@ class SoftmaxMultiClassObjOneAPI : public ObjFunction { return; } CHECK(preds.Size() == (static_cast(param_.num_class) * info.labels.Size())) - << "SoftmaxMultiClassObjOneAPI: label size and pred size does not match.\n" + << "SoftmaxMultiClassObj: label size and pred size does not match.\n" << "label.Size() * num_class: " << info.labels.Size() * static_cast(param_.num_class) << "\n" << "num_class: " << param_.num_class << "\n" @@ -119,65 +126,53 @@ class SoftmaxMultiClassObjOneAPI : public ObjFunction { << "Number of weights should be equal to number of data points."; } - - sycl::buffer preds_buf(preds.HostPointer(), preds.Size()); - sycl::buffer labels_buf(info.labels.Data()->HostPointer(), info.labels.Size()); - sycl::buffer out_gpair_buf(out_gpair->HostPointer(), out_gpair->Size()); - sycl::buffer weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(), + ::sycl::buffer preds_buf(preds.HostPointer(), preds.Size()); + ::sycl::buffer labels_buf(info.labels.Data()->HostPointer(), info.labels.Size()); + ::sycl::buffer out_gpair_buf(out_gpair->HostPointer(), out_gpair->Size()); + ::sycl::buffer weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(), is_null_weight ? 
1 : info.weights_.Size()); - - sycl::buffer additional_input_buf(1); - { - auto additional_input_acc = additional_input_buf.template get_access(); - additional_input_acc[0] = 1; // Fill the label_correct flag - } - - - qu_.submit([&](sycl::handler& cgh) { - auto preds_acc = preds_buf.template get_access(cgh); - auto labels_acc = labels_buf.template get_access(cgh); - auto weights_acc = weights_buf.template get_access(cgh); - auto out_gpair_acc = out_gpair_buf.template get_access(cgh); - auto additional_input_acc = additional_input_buf.template get_access(cgh); - cgh.parallel_for<>(sycl::range<1>(ndata), [=](sycl::id<1> pid) { - int idx = pid[0]; - - - bst_float const * point = &preds_acc[idx * nclass]; - - - // Part of Softmax function - bst_float wmax = std::numeric_limits::min(); - for (int k = 0; k < nclass; k++) { wmax = sycl::max(point[k], wmax); } - float wsum = 0.0f; - for (int k = 0; k < nclass; k++) { wsum += sycl::exp(point[k] - wmax); } - auto label = labels_acc[idx]; - if (label < 0 || label >= nclass) { - additional_input_acc[0] = 0; - label = 0; - } - bst_float wt = is_null_weight ? 1.0f : weights_acc[idx]; - for (int k = 0; k < nclass; ++k) { - bst_float p = expf(point[k] - wmax) / static_cast(wsum); - const float eps = 1e-16f; - const bst_float h = sycl::max(2.0f * p * (1.0f - p) * wt, eps); - p = label == k ? p - 1.0f : p; - out_gpair_acc[idx * nclass + k] = GradientPair(p * wt, h); - } - }); - }).wait(); - - int flag = 1; - { - auto additional_input_acc = additional_input_buf.template get_access(); - flag = additional_input_acc[0]; - } - + { + ::sycl::buffer additional_input_buf(&flag, 1); + qu_.submit([&](::sycl::handler& cgh) { + auto preds_acc = preds_buf.template get_access<::sycl::access::mode::read>(cgh); + auto labels_acc = labels_buf.template get_access<::sycl::access::mode::read>(cgh); + auto weights_acc = weights_buf.template get_access<::sycl::access::mode::read>(cgh); + auto out_gpair_acc = out_gpair_buf.template get_access<::sycl::access::mode::write>(cgh); + auto additional_input_acc = additional_input_buf.template get_access<::sycl::access::mode::write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) { + int idx = pid[0]; + + + bst_float const * point = &preds_acc[idx * nclass]; + + + // Part of Softmax function + bst_float wmax = std::numeric_limits::min(); + for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(point[k], wmax); } + float wsum = 0.0f; + for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(point[k] - wmax); } + auto label = labels_acc[idx]; + if (label < 0 || label >= nclass) { + additional_input_acc[0] = 0; + label = 0; + } + bst_float wt = is_null_weight ? 1.0f : weights_acc[idx]; + for (int k = 0; k < nclass; ++k) { + bst_float p = expf(point[k] - wmax) / static_cast(wsum); + const float eps = 1e-16f; + const bst_float h = ::sycl::max(2.0f * p * (1.0f - p) * wt, eps); + p = label == k ? 
p - 1.0f : p; + out_gpair_acc[idx * nclass + k] = GradientPair(p * wt, h); + } + }); + }).wait(); + } + // additional_input_buf is destroyed, content is copyed to the "flag" if (flag == 0) { - LOG(FATAL) << "SoftmaxMultiClassObjOneAPI: label must be in [0, num_class)."; + LOG(FATAL) << "SYCL::SoftmaxMultiClassObj: label must be in [0, num_class)."; } } void PredTransform(HostDeviceVector* io_preds) const override { @@ -198,29 +193,29 @@ class SoftmaxMultiClassObjOneAPI : public ObjFunction { { - sycl::buffer io_preds_buf(io_preds->HostPointer(), io_preds->Size()); + ::sycl::buffer io_preds_buf(io_preds->HostPointer(), io_preds->Size()); if (prob) { - qu_.submit([&](sycl::handler& cgh) { - auto io_preds_acc = io_preds_buf.template get_access(cgh); - cgh.parallel_for<>(sycl::range<1>(ndata), [=](sycl::id<1> pid) { + qu_.submit([&](::sycl::handler& cgh) { + auto io_preds_acc = io_preds_buf.template get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) { int idx = pid[0]; bst_float * point = &io_preds_acc[idx * nclass]; - SoftmaxOneAPI(point, point + nclass); + Softmax(point, point + nclass); }); }).wait(); } else { - sycl::buffer max_preds_buf(max_preds_.HostPointer(), max_preds_.Size()); + ::sycl::buffer max_preds_buf(max_preds_.HostPointer(), max_preds_.Size()); - qu_.submit([&](sycl::handler& cgh) { - auto io_preds_acc = io_preds_buf.template get_access(cgh); - auto max_preds_acc = max_preds_buf.template get_access(cgh); - cgh.parallel_for<>(sycl::range<1>(ndata), [=](sycl::id<1> pid) { + qu_.submit([&](::sycl::handler& cgh) { + auto io_preds_acc = io_preds_buf.template get_access<::sycl::access::mode::read>(cgh); + auto max_preds_acc = max_preds_buf.template get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) { int idx = pid[0]; bst_float const * point = &io_preds_acc[idx * nclass]; - max_preds_acc[idx] = FindMaxIndexOneAPI(point, point + nclass) - point; + max_preds_acc[idx] = FindMaxIndex(point, point + nclass) - point; }); }).wait(); } @@ -240,9 +235,9 @@ class SoftmaxMultiClassObjOneAPI : public ObjFunction { void SaveConfig(Json* p_out) const override { auto& out = *p_out; if (this->output_prob_) { - out["name"] = String("multi:softprob_oneapi"); + out["name"] = String("multi:softprob_sycl"); } else { - out["name"] = String("multi:softmax_oneapi"); + out["name"] = String("multi:softmax_sycl"); } out["softmax_multiclass_param"] = ToJson(param_); } @@ -257,29 +252,30 @@ class SoftmaxMultiClassObjOneAPI : public ObjFunction { // output probability bool output_prob_; // parameter - SoftmaxMultiClassParamOneAPI param_; + SoftmaxMultiClassParam param_; // Cache for max_preds mutable HostDeviceVector max_preds_; - DeviceManagerOneAPI device_manager; + sycl::DeviceManager device_manager; - mutable sycl::queue qu_; + mutable ::sycl::queue qu_; }; // register the objective functions -DMLC_REGISTER_PARAMETER(SoftmaxMultiClassParamOneAPI); +DMLC_REGISTER_PARAMETER(SoftmaxMultiClassParam); -XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClassOneAPI, "multi:softmax_oneapi") +XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax_sycl") .describe("Softmax for multi-class classification, output class index.") -.set_body([]() { return new SoftmaxMultiClassObjOneAPI(false); }); +.set_body([]() { return new SoftmaxMultiClassObj(false); }); -XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClassOneAPI, "multi:softprob_oneapi") +XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClass, 
"multi:softprob_sycl") .describe("Softmax for multi-class classification, output probability distribution.") -.set_body([]() { return new SoftmaxMultiClassObjOneAPI(true); }); +.set_body([]() { return new SoftmaxMultiClassObj(true); }); } // namespace obj +} // namespace sycl } // namespace xgboost diff --git a/plugin/updater_oneapi/regression_loss_oneapi.h b/plugin/sycl/objective/regression_loss.h similarity index 75% rename from plugin/updater_oneapi/regression_loss_oneapi.h rename to plugin/sycl/objective/regression_loss.h index beab461e2aff..e1ae55602a64 100755 --- a/plugin/updater_oneapi/regression_loss_oneapi.h +++ b/plugin/sycl/objective/regression_loss.h @@ -1,8 +1,8 @@ /*! * Copyright 2017-2023 XGBoost contributors */ -#ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_ -#define XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_ +#ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_SYCL_H_ +#define XGBOOST_OBJECTIVE_REGRESSION_LOSS_SYCL_H_ #include #include @@ -11,6 +11,7 @@ #include "CL/sycl.hpp" namespace xgboost { +namespace sycl { namespace obj { /*! @@ -18,13 +19,13 @@ namespace obj { * \param x input parameter * \return the transformed value. */ -inline float SigmoidOneAPI(float x) { - return 1.0f / (1.0f + sycl::exp(-x)); +inline float Sigmoid(float x) { + return 1.0f / (1.0f + ::sycl::exp(-x)); } // common regressions // linear regression -struct LinearSquareLossOneAPI { +struct LinearSquareLoss { static bst_float PredTransform(bst_float x) { return x; } static bool CheckLabel(bst_float x) { return true; } static bst_float FirstOrderGradient(bst_float predt, bst_float label) { @@ -37,25 +38,25 @@ struct LinearSquareLossOneAPI { static const char* LabelErrorMsg() { return ""; } static const char* DefaultEvalMetric() { return "rmse"; } - static const char* Name() { return "reg:squarederror_oneapi"; } + static const char* Name() { return "reg:squarederror_sycl"; } static ObjInfo Info() { return {ObjInfo::kRegression, true, false}; } }; -// TODO: DPC++ does not fully support std math inside offloaded kernels -struct SquaredLogErrorOneAPI { +// TODO: SYCL does not fully support std math inside offloaded kernels +struct SquaredLogError { static bst_float PredTransform(bst_float x) { return x; } static bool CheckLabel(bst_float label) { return label > -1; } static bst_float FirstOrderGradient(bst_float predt, bst_float label) { predt = std::max(predt, (bst_float)(-1 + 1e-6)); // ensure correct value for log1p - return (sycl::log1p(predt) - sycl::log1p(label)) / (predt + 1); + return (::sycl::log1p(predt) - ::sycl::log1p(label)) / (predt + 1); } static bst_float SecondOrderGradient(bst_float predt, bst_float label) { predt = std::max(predt, (bst_float)(-1 + 1e-6)); - float res = (-sycl::log1p(predt) + sycl::log1p(label) + 1) / - sycl::pow(predt + 1, (bst_float)2); + float res = (-::sycl::log1p(predt) + ::sycl::log1p(label) + 1) / + ::sycl::pow(predt + 1, (bst_float)2); res = std::max(res, (bst_float)1e-6f); return res; } @@ -65,16 +66,16 @@ struct SquaredLogErrorOneAPI { } static const char* DefaultEvalMetric() { return "rmsle"; } - static const char* Name() { return "reg:squaredlogerror_oneapi"; } + static const char* Name() { return "reg:squaredlogerror_sycl"; } static ObjInfo Info() { return ObjInfo::kRegression; } }; // logistic loss for probability regression task -struct LogisticRegressionOneAPI { +struct LogisticRegression { // duplication is necessary, as __device__ specifier // cannot be made conditional on template parameter - static bst_float PredTransform(bst_float x) { return 
SigmoidOneAPI(x); } + static bst_float PredTransform(bst_float x) { return Sigmoid(x); } static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; } static bst_float FirstOrderGradient(bst_float predt, bst_float label) { return predt - label; @@ -84,7 +85,7 @@ struct LogisticRegressionOneAPI { return std::max(predt * (1.0f - predt), eps); } template - static T PredTransform(T x) { return SigmoidOneAPI(x); } + static T PredTransform(T x) { return Sigmoid(x); } template static T FirstOrderGradient(T predt, T label) { return predt - label; } template @@ -102,52 +103,53 @@ struct LogisticRegressionOneAPI { } static const char* DefaultEvalMetric() { return "rmse"; } - static const char* Name() { return "reg:logistic_oneapi"; } + static const char* Name() { return "reg:logistic_sycl"; } static ObjInfo Info() { return ObjInfo::kRegression; } }; // logistic loss for binary classification task -struct LogisticClassificationOneAPI : public LogisticRegressionOneAPI { +struct LogisticClassification : public LogisticRegression { static const char* DefaultEvalMetric() { return "logloss"; } - static const char* Name() { return "binary:logistic_oneapi"; } + static const char* Name() { return "binary:logistic_sycl"; } }; // logistic loss, but predict un-transformed margin -struct LogisticRawOneAPI : public LogisticRegressionOneAPI { +struct LogisticRaw : public LogisticRegression { // duplication is necessary, as __device__ specifier // cannot be made conditional on template parameter static bst_float PredTransform(bst_float x) { return x; } static bst_float FirstOrderGradient(bst_float predt, bst_float label) { - predt = SigmoidOneAPI(predt); + predt = Sigmoid(predt); return predt - label; } static bst_float SecondOrderGradient(bst_float predt, bst_float label) { const bst_float eps = 1e-16f; - predt = SigmoidOneAPI(predt); + predt = Sigmoid(predt); return std::max(predt * (1.0f - predt), eps); } template static T PredTransform(T x) { return x; } template static T FirstOrderGradient(T predt, T label) { - predt = SigmoidOneAPI(predt); + predt = Sigmoid(predt); return predt - label; } template static T SecondOrderGradient(T predt, T label) { const T eps = T(1e-16f); - predt = SigmoidOneAPI(predt); + predt = Sigmoid(predt); return std::max(predt * (T(1.0f) - predt), eps); } static const char* DefaultEvalMetric() { return "logloss"; } - static const char* Name() { return "binary:logitraw_oneapi"; } + static const char* Name() { return "binary:logitraw_sycl"; } static ObjInfo Info() { return ObjInfo::kRegression; } }; } // namespace obj +} // namespace sycl } // namespace xgboost -#endif // XGBOOST_OBJECTIVE_REGRESSION_LOSS_ONEAPI_H_ +#endif // XGBOOST_OBJECTIVE_REGRESSION_LOSS_SYCL_H_ diff --git a/plugin/sycl/objective/regression_obj.cc b/plugin/sycl/objective/regression_obj.cc new file mode 100755 index 000000000000..bd59cd729408 --- /dev/null +++ b/plugin/sycl/objective/regression_obj.cc @@ -0,0 +1,192 @@ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include +#include +#pragma GCC diagnostic pop +#include +#include +#include +#include + +#include "xgboost/host_device_vector.h" +#include "xgboost/json.h" +#include "xgboost/parameter.h" +#include "xgboost/span.h" + +#include "../../src/common/transform.h" +#include "../../src/common/common.h" +#include "regression_loss.h" +#include "../device_manager.h" + +#include "CL/sycl.hpp" + +namespace xgboost { +namespace sycl { +namespace obj { + 
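Both this new file and the multiclass objective above rely on the same scoped-buffer idiom for error reporting: a host int is wrapped in a ::sycl::buffer inside a block, the kernel clears it when it meets an invalid label, and the buffer destructor at the closing brace copies the final value back into the host variable. A minimal standalone sketch of that idiom (hypothetical names, not part of this patch):

```cpp
#include <CL/sycl.hpp>
#include <iostream>

int main() {
  ::sycl::queue qu;
  int flag = 1;  // 1 == "all inputs valid"
  {
    // The buffer borrows the host variable for the lifetime of this scope.
    ::sycl::buffer<int, 1> flag_buf(&flag, 1);
    qu.submit([&](::sycl::handler& cgh) {
      auto flag_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
      cgh.parallel_for<>(::sycl::range<1>(1000), [=](::sycl::id<1> pid) {
        // Every offending work-item writes the same value, so the race is benign.
        if (pid[0] % 7 == 0) flag_acc[0] = 0;
      });
    }).wait();
  }  // buffer destroyed here; the device value is copied back into `flag`
  std::cout << "flag = " << flag << std::endl;  // prints "flag = 0"
  return 0;
}
```

RegLossObj::GetGradient below follows exactly this shape, raising LOG(FATAL) on the host once the flag has been read back.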
+DMLC_REGISTRY_FILE_TAG(regression_obj_sycl); + +struct RegLossParam : public XGBoostParameter { + float scale_pos_weight; + // declare parameters + DMLC_DECLARE_PARAMETER(RegLossParam) { + DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f) + .describe("Scale the weight of positive examples by this factor"); + } +}; + +template +class RegLossObj : public ObjFunction { + protected: + HostDeviceVector label_correct_; + + public: + RegLossObj() = default; + + void Configure(const std::vector >& args) override { + param_.UpdateAllowUnknown(args); + qu_ = device_manager.GetQueue(ctx_->Device()); + } + + void GetGradient(const HostDeviceVector& preds, + const MetaInfo &info, + int iter, + HostDeviceVector* out_gpair) override { + if (info.labels.Size() == 0U) { + LOG(WARNING) << "Label set is empty."; + } + CHECK_EQ(preds.Size(), info.labels.Size()) + << " " << "labels are not correctly provided, " + << "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", " + << "Loss: " << Loss::Name(); + + size_t const ndata = preds.Size(); + out_gpair->Resize(ndata); + + // TODO: add label_correct check + label_correct_.Resize(1); + label_correct_.Fill(1); + + bool is_null_weight = info.weights_.Size() == 0; + + ::sycl::buffer preds_buf(preds.HostPointer(), preds.Size()); + ::sycl::buffer labels_buf(info.labels.Data()->HostPointer(), info.labels.Size()); + ::sycl::buffer out_gpair_buf(out_gpair->HostPointer(), out_gpair->Size()); + ::sycl::buffer weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(), + is_null_weight ? 1 : info.weights_.Size()); + + const size_t n_targets = std::max(info.labels.Shape(1), static_cast(1)); + + auto scale_pos_weight = param_.scale_pos_weight; + if (!is_null_weight) { + CHECK_EQ(info.weights_.Size(), info.labels.Shape(0)) + << "Number of weights should be equal to number of data points."; + } + + int flag = 1; + { + ::sycl::buffer additional_input_buf(&flag, 1); + qu_.submit([&](::sycl::handler& cgh) { + auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh); + auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh); + auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh); + auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh); + auto additional_input_acc = additional_input_buf.get_access<::sycl::access::mode::write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) { + int idx = pid[0]; + bst_float p = Loss::PredTransform(preds_acc[idx]); + bst_float w = is_null_weight ? 1.0f : weights_acc[idx/n_targets]; + bst_float label = labels_acc[idx]; + if (label == 1.0f) { + w *= scale_pos_weight; + } + if (!Loss::CheckLabel(label)) { + // If there is an incorrect label, the host code will know.
+ additional_input_acc[0] = 0; + } + out_gpair_acc[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w, + Loss::SecondOrderGradient(p, label) * w); + }); + }).wait(); + } + // additional_input_buf is destroyed, content is copied to the "flag" + + if (flag == 0) { + LOG(FATAL) << Loss::LabelErrorMsg(); + } + + } + + public: + const char* DefaultEvalMetric() const override { + return Loss::DefaultEvalMetric(); + } + + void PredTransform(HostDeviceVector *io_preds) const override { + size_t const ndata = io_preds->Size(); + ::sycl::buffer io_preds_buf(io_preds->HostPointer(), io_preds->Size()); + + qu_.submit([&](::sycl::handler& cgh) { + auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) { + int idx = pid[0]; + io_preds_acc[idx] = Loss::PredTransform(io_preds_acc[idx]); + }); + }).wait(); + } + + float ProbToMargin(float base_score) const override { + return Loss::ProbToMargin(base_score); + } + + struct ObjInfo Task() const override { + return Loss::Info(); + }; + + uint32_t Targets(MetaInfo const& info) const override { + // Multi-target regression. + return std::max(static_cast(1), info.labels.Shape(1)); + } + + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String(Loss::Name()); + out["reg_loss_param"] = ToJson(param_); + } + + void LoadConfig(Json const& in) override { + FromJson(in["reg_loss_param"], &param_); + } + + protected: + RegLossParam param_; + sycl::DeviceManager device_manager; + + mutable ::sycl::queue qu_; +}; + +// register the objective functions +DMLC_REGISTER_PARAMETER(RegLossParam); + +// TODO: Find a better way to dispatch names of SYCL kernels with various template parameters of loss function +XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression, LinearSquareLoss::Name()) +.describe("Regression with squared error with SYCL backend.") +.set_body([]() { return new RegLossObj(); }); +XGBOOST_REGISTER_OBJECTIVE(SquareLogError, SquaredLogError::Name()) +.describe("Regression with root mean squared logarithmic error with SYCL backend.") +.set_body([]() { return new RegLossObj(); }); +XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, LogisticRegression::Name()) +.describe("Logistic regression for probability regression task with SYCL backend.") +.set_body([]() { return new RegLossObj(); }); +XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, LogisticClassification::Name()) +.describe("Logistic regression for binary classification task with SYCL backend.") +.set_body([]() { return new RegLossObj(); }); +XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, LogisticRaw::Name()) +.describe("Logistic regression for classification, output score " + "before logistic transformation with SYCL backend.") +.set_body([]() { return new RegLossObj(); }); + +} // namespace obj +} // namespace sycl +} // namespace xgboost diff --git a/plugin/updater_oneapi/predictor_oneapi.cc b/plugin/sycl/predictor/predictor.cc similarity index 60% rename from plugin/updater_oneapi/predictor_oneapi.cc rename to plugin/sycl/predictor/predictor.cc index 121391ffa90b..346a14061b6c 100755 --- a/plugin/updater_oneapi/predictor_oneapi.cc +++ b/plugin/sycl/predictor/predictor.cc @@ -4,9 +4,13 @@ #include #include #include +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" #include +#pragma GCC diagnostic pop -#include "data_oneapi.h" +#include "../data.h" #include "dmlc/registry.h" @@ -14,104
+18,26 @@ #include "xgboost/predictor.h" #include "xgboost/tree_updater.h" +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" #include "../../src/data/adapter.h" +#pragma GCC diagnostic pop #include "../../src/common/math.h" #include "../../src/gbm/gbtree_model.h" -#include "./device_manager_oneapi.h" +#include "../device_manager.h" #include "CL/sycl.hpp" namespace xgboost { +namespace sycl { namespace predictor { -DMLC_REGISTRY_FILE_TAG(predictor_oneapi); - -class PredictorOneAPI : public Predictor { - void SetupBackend() { - const DeviceOrd device_spec = ctx_->Device(); - - bool is_cpu; - if (device_spec.IsSycl()) { - sycl::device device = device_manager.GetDevice(device_spec); - is_cpu = device.is_cpu(); - } else { - is_cpu = true; - } - LOG(INFO) << "device = " << device_spec.Name() << ", is_cpu = " << int(is_cpu); - if (is_cpu) { - predictor_backend_.reset(Predictor::Create("cpu_predictor", ctx_)); - } else{ - predictor_backend_.reset(Predictor::Create("oneapi_predictor_backend", ctx_)); - } - } - - public: - explicit PredictorOneAPI(Context const* context) : - Predictor::Predictor{context} { - SetupBackend(); - } - - void Configure(const std::vector>& args) override { - SetupBackend(); - predictor_backend_->Configure(args); - } - - void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, - const gbm::GBTreeModel &model, uint32_t tree_begin, - uint32_t tree_end = 0) const override { - predictor_backend_->PredictBatch(dmat, predts, model, tree_begin, tree_end); - } - - bool InplacePredict(std::shared_ptr p_m, - const gbm::GBTreeModel &model, float missing, - PredictionCacheEntry *out_preds, uint32_t tree_begin, - unsigned tree_end) const override { - return predictor_backend_->InplacePredict(p_m, model, missing, out_preds, tree_begin, tree_end); - } - - void PredictInstance(const SparsePage::Inst& inst, - std::vector* out_preds, - const gbm::GBTreeModel& model, unsigned ntree_limit, - bool is_column_split) const override { - predictor_backend_->PredictInstance(inst, out_preds, model, ntree_limit, is_column_split); - } - - void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* out_preds, - const gbm::GBTreeModel& model, unsigned ntree_limit) const override { - predictor_backend_->PredictLeaf(p_fmat, out_preds, model, ntree_limit); - } - - void PredictContribution(DMatrix* p_fmat, HostDeviceVector* out_contribs, - const gbm::GBTreeModel& model, uint32_t ntree_limit, - const std::vector* tree_weights, - bool approximate, int condition, - unsigned condition_feature) const override { - predictor_backend_->PredictContribution(p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate, condition, condition_feature); - } - - void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector* out_contribs, - const gbm::GBTreeModel& model, unsigned ntree_limit, - const std::vector* tree_weights, - bool approximate) const override { - predictor_backend_->PredictInteractionContributions(p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate); - } - - protected: - void InitOutPredictions(const MetaInfo& info, - HostDeviceVector* out_preds, - const gbm::GBTreeModel& model) const { - predictor_backend_->InitOutPredictions(info, out_preds, model); - } - - private: - DeviceManagerOneAPI device_manager; - std::unique_ptr predictor_backend_; -}; +DMLC_REGISTRY_FILE_TAG(predictor_sycl); /* Wrapper for descriptor of a tree node */ -struct DeviceNodeOneAPI { - DeviceNodeOneAPI() +struct DeviceNode { + DeviceNode() : fidx(-1), 
left_child_idx(-1), right_child_idx(-1) {} union NodeValue { @@ -124,7 +50,7 @@ struct DeviceNodeOneAPI { int right_child_idx; NodeValue val; - DeviceNodeOneAPI(const RegTree::Node& n) { + DeviceNode(const RegTree::Node& n) { this->left_child_idx = n.LeftChild(); this->right_child_idx = n.RightChild(); this->fidx = n.SplitIndex(); @@ -158,22 +84,22 @@ struct DeviceNodeOneAPI { float GetWeight() const { return val.leaf_weight; } }; -/* OneAPI implementation of a device model, storing tree structure in USM buffers to provide access from device kernels */ -class DeviceModelOneAPI { +/* SYCL implementation of a device model, storing tree structure in USM buffers to provide access from device kernels */ +class DeviceModel { public: - sycl::queue qu_; - USMVector nodes_; + ::sycl::queue qu_; + USMVector nodes_; USMVector tree_segments_; USMVector tree_group_; size_t tree_beg_; size_t tree_end_; int num_group_; - DeviceModelOneAPI() {} + DeviceModel() {} - ~DeviceModelOneAPI() {} + ~DeviceModel() {} - void Init(sycl::queue qu, const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) { + void Init(::sycl::queue qu, const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) { qu_ = qu; tree_segments_.Resize(qu_, (tree_end - tree_begin) + 1); @@ -230,8 +156,8 @@ float GetFvalue(int ridx, int fidx, Entry* data, size_t* row_ptr, bool& is_missi return 0.0; } -float GetLeafWeight(int ridx, const DeviceNodeOneAPI* tree, Entry* data, size_t* row_ptr) { - DeviceNodeOneAPI n = tree[0]; +float GetLeafWeight(int ridx, const DeviceNode* tree, Entry* data, size_t* row_ptr) { + DeviceNode n = tree[0]; int node_id = 0; bool is_missing; while (!n.IsLeaf()) { @@ -252,8 +178,8 @@ float GetLeafWeight(int ridx, const DeviceNodeOneAPI* tree, Entry* data, size_t* return n.GetWeight(); } -void DevicePredictInternal(sycl::queue qu, - DeviceMatrixOneAPI* dmat, +void DevicePredictInternal(::sycl::queue qu, + sycl::DeviceMatrix* dmat, HostDeviceVector* out_preds, const gbm::GBTreeModel& model, size_t tree_begin, @@ -261,13 +187,13 @@ void DevicePredictInternal(sycl::queue qu, if (tree_end - tree_begin == 0) { return; } - DeviceModelOneAPI device_model; + DeviceModel device_model; device_model.Init(qu, model, tree_begin, tree_end); auto& out_preds_vec = out_preds->HostVector(); - DeviceNodeOneAPI* nodes = device_model.nodes_.Data(); - sycl::buffer out_preds_buf(out_preds_vec.data(), out_preds_vec.size()); + DeviceNode* nodes = device_model.nodes_.Data(); + ::sycl::buffer out_preds_buf(out_preds_vec.data(), out_preds_vec.size()); size_t* tree_segments = device_model.tree_segments_.Data(); int* tree_group = device_model.tree_group_.Data(); size_t* row_ptr = dmat->row_ptr.Data(); @@ -276,21 +202,21 @@ void DevicePredictInternal(sycl::queue qu, int num_rows = dmat->row_ptr.Size() - 1; int num_group = model.learner_model_param->num_output_group; - qu.submit([&](sycl::handler& cgh) { - auto out_predictions = out_preds_buf.template get_access(cgh); - cgh.parallel_for<>(sycl::range<1>(num_rows), [=](sycl::id<1> pid) { + qu.submit([&](::sycl::handler& cgh) { + auto out_predictions = out_preds_buf.template get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::id<1> pid) { int global_idx = pid[0]; if (global_idx >= num_rows) return; if (num_group == 1) { float sum = 0.0; for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { - const DeviceNodeOneAPI* tree = nodes + tree_segments[tree_idx - tree_begin]; + const DeviceNode* tree = nodes + 
tree_segments[tree_idx - tree_begin]; sum += GetLeafWeight(global_idx, tree, data, row_ptr); } out_predictions[global_idx] += sum; } else { for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { - const DeviceNodeOneAPI* tree = nodes + tree_segments[tree_idx - tree_begin]; + const DeviceNode* tree = nodes + tree_segments[tree_idx - tree_begin]; int out_prediction_idx = global_idx * num_group + tree_group[tree_idx]; out_predictions[out_prediction_idx] += GetLeafWeight(global_idx, tree, data, row_ptr); } @@ -299,7 +225,7 @@ void DevicePredictInternal(sycl::queue qu, }).wait(); } -class PredictorBackendOneAPI : public Predictor { +class Predictor : public xgboost::Predictor { protected: void InitOutPredictions(const MetaInfo& info, HostDeviceVector* out_preds, @@ -334,10 +260,8 @@ class PredictorBackendOneAPI : public Predictor { } public: - explicit PredictorBackendOneAPI(Context const* context) : - Predictor::Predictor{context}, cpu_predictor(Predictor::Create("cpu_predictor", context)) { - qu_ = device_manager.GetQueue(context->Device()); - } + explicit Predictor(Context const* context) : + xgboost::Predictor::Predictor{context}, cpu_predictor(xgboost::Predictor::Create("cpu_predictor", context)) {} void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, const gbm::GBTreeModel &model, uint32_t tree_begin, @@ -347,13 +271,14 @@ class PredictorBackendOneAPI : public Predictor { if (this->device_matrix_cache_.find(dmat) == this->device_matrix_cache_.end()) { this->device_matrix_cache_.emplace( - dmat, std::unique_ptr( - new DeviceMatrixOneAPI(qu_, dmat))); + dmat, std::unique_ptr( + new sycl::DeviceMatrix(qu_, dmat))); } - DeviceMatrixOneAPI* device_matrix = device_matrix_cache_.find(dmat)->second.get(); + sycl::DeviceMatrix* device_matrix = device_matrix_cache_.find(dmat)->second.get(); */ - DeviceMatrixOneAPI device_matrix(qu_, dmat); // TODO: remove temporary workaround after cache fix + ::sycl::queue qu = device_manager.GetQueue(ctx_->Device()); + sycl::DeviceMatrix device_matrix(qu, dmat); // TODO: remove temporary workaround after cache fix auto* out_preds = &predts->predictions; if (tree_end == 0) { @@ -361,7 +286,7 @@ class PredictorBackendOneAPI : public Predictor { } if (tree_begin < tree_end) { - DevicePredictInternal(qu_, &device_matrix, out_preds, model, tree_begin, tree_end); + DevicePredictInternal(qu, &device_matrix, out_preds, model, tree_begin, tree_end); } } @@ -369,6 +294,7 @@ class PredictorBackendOneAPI : public Predictor { const gbm::GBTreeModel &model, float missing, PredictionCacheEntry *out_preds, uint32_t tree_begin, unsigned tree_end) const override { + LOG(WARNING) << "InplacePredict is not yet implemented for SYCL. CPU Predictor is used."; return cpu_predictor->InplacePredict(p_m, model, missing, out_preds, tree_begin, tree_end); } @@ -376,11 +302,13 @@ class PredictorBackendOneAPI : public Predictor { std::vector* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit, bool is_column_split) const override { + LOG(WARNING) << "PredictInstance is not yet implemented for SYCL. CPU Predictor is used."; cpu_predictor->PredictInstance(inst, out_preds, model, ntree_limit, is_column_split); } void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* out_preds, const gbm::GBTreeModel& model, unsigned ntree_limit) const override { + LOG(WARNING) << "PredictLeaf is not yet implemented for SYCL. 
CPU Predictor is used."; cpu_predictor->PredictLeaf(p_fmat, out_preds, model, ntree_limit); } @@ -389,6 +317,7 @@ class PredictorBackendOneAPI : public Predictor { const std::vector* tree_weights, bool approximate, int condition, unsigned condition_feature) const override { + LOG(WARNING) << "PredictContribution is not yet implemented for SYCL. CPU Predictor is used."; cpu_predictor->PredictContribution(p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate, condition, condition_feature); } @@ -396,25 +325,22 @@ class PredictorBackendOneAPI : public Predictor { const gbm::GBTreeModel& model, unsigned ntree_limit, const std::vector* tree_weights, bool approximate) const override { + LOG(WARNING) << "PredictInteractionContributions is not yet implemented for SYCL. CPU Predictor is used."; cpu_predictor->PredictInteractionContributions(p_fmat, out_contribs, model, ntree_limit, tree_weights, approximate); } private: - DeviceManagerOneAPI device_manager; - sycl::queue qu_; + DeviceManager device_manager; - std::unique_ptr cpu_predictor; + std::unique_ptr cpu_predictor; - std::unordered_map> device_matrix_cache_; + std::unordered_map> device_matrix_cache_; }; -XGBOOST_REGISTER_PREDICTOR(PredictorOneAPI, "oneapi_predictor") -.describe("Make predictions using DPC++.") -.set_body([](Context const *ctx) { return new PredictorOneAPI(ctx); }); - -XGBOOST_REGISTER_PREDICTOR(PredictorBackendOneAPI, "oneapi_predictor_backend") -.describe("Make predictions using DPC++.") -.set_body([](Context const* ctx) { return new PredictorBackendOneAPI(ctx); }); +XGBOOST_REGISTER_PREDICTOR(Predictor, "sycl_predictor") +.describe("Make predictions using SYCL.") +.set_body([](Context const* ctx) { return new Predictor(ctx); }); } // namespace predictor +} // namespace sycl } // namespace xgboost diff --git a/plugin/updater_oneapi/param_oneapi.h b/plugin/sycl/tree/param.h similarity index 79% rename from plugin/updater_oneapi/param_oneapi.h rename to plugin/sycl/tree/param.h index 27ff1132f0cb..52cc5c5ab438 100644 --- a/plugin/updater_oneapi/param_oneapi.h +++ b/plugin/sycl/tree/param.h @@ -1,8 +1,8 @@ /*! * Copyright 2014-2023 by Contributors */ -#ifndef XGBOOST_TREE_PARAM_ONEAPI_H_ -#define XGBOOST_TREE_PARAM_ONEAPI_H_ +#ifndef XGBOOST_TREE_PARAM_SYCL_H_ +#define XGBOOST_TREE_PARAM_SYCL_H_ #include @@ -14,25 +14,28 @@ #include "xgboost/parameter.h" #include "xgboost/data.h" +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" #include "../src/tree/param.h" - +#pragma GCC diagnostic pop namespace xgboost { +namespace sycl { namespace tree { /*! \brief Wrapper for necessary training parameters for regression tree to access on device */ -struct TrainParamOneAPI { +struct TrainParam { float min_child_weight; float reg_lambda; float reg_alpha; float max_delta_step; - TrainParamOneAPI() {} + TrainParam() {} - TrainParamOneAPI(const TrainParam& param) { + TrainParam(const xgboost::tree::TrainParam& param) { reg_lambda = param.reg_lambda; reg_alpha = param.reg_alpha; min_child_weight = param.min_child_weight; @@ -43,7 +46,7 @@ struct TrainParamOneAPI { /*! \brief core statistics used for tree construction */ template -struct GradStatsOneAPI { +struct GradStats { /*! \brief sum gradient statistics */ GradType sum_grad { 0 }; /*! 
\brief sum hessian statistics */ @@ -55,20 +58,20 @@ struct GradStatsOneAPI { GradType GetHess() const { return sum_hess; } - friend std::ostream& operator<<(std::ostream& os, GradStatsOneAPI s) { + friend std::ostream& operator<<(std::ostream& os, GradStats s) { os << s.GetGrad() << "/" << s.GetHess(); return os; } - GradStatsOneAPI() { + GradStats() { } template - explicit GradStatsOneAPI(const GpairT &sum) + explicit GradStats(const GpairT &sum) : sum_grad(sum.GetGrad()), sum_hess(sum.GetHess()) {} - explicit GradStatsOneAPI(const GradType grad, const GradType hess) + explicit GradStats(const GradType grad, const GradType hess) : sum_grad(grad), sum_hess(hess) {} /*! * \brief accumulate statistics @@ -78,16 +81,16 @@ struct GradStatsOneAPI { /*! \brief add statistics to the data */ - inline void Add(const GradStatsOneAPI& b) { + inline void Add(const GradStats& b) { sum_grad += b.sum_grad; sum_hess += b.sum_hess; } /*! \brief same as add, reduce is used in All Reduce */ - inline static void Reduce(GradStatsOneAPI& a, const GradStatsOneAPI& b) { // NOLINT(*) + inline static void Reduce(GradStats& a, const GradStats& b) { // NOLINT(*) a.Add(b); } /*! \brief set current value to a - b */ - inline void SetSubstract(const GradStatsOneAPI& a, const GradStatsOneAPI& b) { + inline void SetSubstract(const GradStats& a, const GradStats& b) { sum_grad = a.sum_grad - b.sum_grad; sum_hess = a.sum_hess - b.sum_hess; } @@ -102,11 +105,11 @@ struct GradStatsOneAPI { /*! - * \brief OneAPI implementation of SplitEntryContainer for device compilation. + * \brief SYCL implementation of SplitEntryContainer for device compilation. * Original structure cannot be used due to std::isinf usage, which is not supported */ template -struct SplitEntryContainerOneAPI { +struct SplitEntryContainer { /*! \brief loss change after split this node */ bst_float loss_chg {0.0f}; /*! \brief split index */ @@ -118,10 +121,10 @@ struct SplitEntryContainerOneAPI { GradientT right_sum; - SplitEntryContainerOneAPI() = default; + SplitEntryContainer() = default; - friend std::ostream& operator<<(std::ostream& os, SplitEntryContainerOneAPI const& s) { + friend std::ostream& operator<<(std::ostream& os, SplitEntryContainer const& s) { os << "loss_chg: " << s.loss_chg << ", " << "split index: " << s.SplitIndex() << ", " << "split value: " << s.split_value << ", " @@ -144,7 +147,7 @@ struct SplitEntryContainerOneAPI { * \param split_index the feature index where the split is on */ inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const { - if (sycl::isinf(new_loss_chg)) { // in some cases new_loss_chg can be NaN or Inf, + if (::sycl::isinf(new_loss_chg)) { // in some cases new_loss_chg can be NaN or Inf, // for example when lambda = 0 & min_child_weight = 0 // skip value in this case return false; @@ -159,7 +162,7 @@ struct SplitEntryContainerOneAPI { * \param e candidate split solution * \return whether the proposed split is better and can replace current split */ - inline bool Update(const SplitEntryContainerOneAPI &e) { + inline bool Update(const SplitEntryContainer &e) { if (this->NeedReplace(e.loss_chg, e.SplitIndex())) { this->loss_chg = e.loss_chg; this->sindex = e.sindex; @@ -200,17 +203,18 @@ struct SplitEntryContainerOneAPI { /*! 
\brief same as update, used by AllReduce*/ - inline static void Reduce(SplitEntryContainerOneAPI &dst, // NOLINT(*) - const SplitEntryContainerOneAPI &src) { // NOLINT(*) + inline static void Reduce(SplitEntryContainer &dst, // NOLINT(*) + const SplitEntryContainer &src) { // NOLINT(*) dst.Update(src); } }; template -using SplitEntryOneAPI = SplitEntryContainerOneAPI>; +using SplitEntry = SplitEntryContainer>; -} -} -#endif // XGBOOST_TREE_PARAM_H_ +} // namespace tree +} // namespace sycl +} // namespace xgboost +#endif // XGBOOST_TREE_PARAM_SYCL_H_ diff --git a/plugin/updater_oneapi/split_evaluator_oneapi.h b/plugin/sycl/tree/split_evaluator.h similarity index 77% rename from plugin/updater_oneapi/split_evaluator_oneapi.h rename to plugin/sycl/tree/split_evaluator.h index 593c7a4910d6..e8809b96686b 100644 --- a/plugin/updater_oneapi/split_evaluator_oneapi.h +++ b/plugin/sycl/tree/split_evaluator.h @@ -2,8 +2,8 @@ * Copyright 2018-2023 by Contributors */ -#ifndef XGBOOST_TREE_SPLIT_EVALUATOR_ONEAPI_H_ -#define XGBOOST_TREE_SPLIT_EVALUATOR_ONEAPI_H_ +#ifndef XGBOOST_TREE_SPLIT_EVALUATOR_SYCL_H_ +#define XGBOOST_TREE_SPLIT_EVALUATOR_SYCL_H_ #include #include @@ -11,7 +11,7 @@ #include #include -#include "param_oneapi.h" +#include "param.h" #include "xgboost/tree_model.h" #include "xgboost/host_device_vector.h" @@ -23,15 +23,16 @@ #include "CL/sycl.hpp" namespace xgboost { +namespace sycl { namespace tree { -/*! \brief OneAPI implementation of TreeEvaluator, with USM memory for temporary buffer to access on device. +/*! \brief SYCL implementation of TreeEvaluator, with USM memory for temporary buffer to access on device. * It also contains own implementation of SplitEvaluator for device compilation, because some of the functions from the original SplitEvaluator are currently not supported */ template -class TreeEvaluatorOneAPI { +class TreeEvaluator { // hist and exact use parent id to calculate constraints. 
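// Note: the root has no real parent; its parent id is encoded as the largest
// positive bst_node_t (the expression below evaluates to 2^31 - 1), so
// constraint lookups for the root land in a dedicated sentinel slot.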
static constexpr bst_node_t kRootParentId = (-1 & static_cast((1U << 31) - 1)); @@ -39,12 +40,12 @@ class TreeEvaluatorOneAPI { USMVector lower_bounds_; USMVector upper_bounds_; USMVector monotone_; - TrainParamOneAPI param_; - sycl::queue qu_; + TrainParam param_; + ::sycl::queue qu_; bool has_constraint_; public: - TreeEvaluatorOneAPI(sycl::queue qu, TrainParam const& p, bst_feature_t n_features) { + TreeEvaluator(::sycl::queue qu, xgboost::tree::TrainParam const& p, bst_feature_t n_features) { qu_ = qu; if (p.monotone_constraints.empty()) { monotone_.Resize(qu_, n_features, 0); @@ -56,7 +57,7 @@ class TreeEvaluatorOneAPI { upper_bounds_.Resize(qu_, p.MaxNodes(), std::numeric_limits::max()); has_constraint_ = true; } - param_ = TrainParamOneAPI(p); + param_ = TrainParam(p); } struct SplitEvaluator { @@ -64,12 +65,12 @@ class TreeEvaluatorOneAPI { GradType* lower; GradType* upper; bool has_constraint; - TrainParamOneAPI param; + TrainParam param; GradType CalcSplitGain(bst_node_t nidx, bst_feature_t fidx, - const GradStatsOneAPI& left, - const GradStatsOneAPI& right) const { + const GradStats& left, + const GradStats& right) const { int constraint = constraints[fidx]; const GradType negative_infinity = -std::numeric_limits::infinity(); GradType wleft = this->CalcWeight(nidx, left); @@ -86,7 +87,7 @@ class TreeEvaluatorOneAPI { } } - inline GradType ThresholdL1OneAPI(GradType w, GradType alpha) const { + inline GradType ThresholdL1(GradType w, GradType alpha) const { if (w > + alpha) { return w - alpha; } @@ -96,19 +97,19 @@ class TreeEvaluatorOneAPI { return 0.0; } - inline GradType CalcWeightOneAPI(GradType sum_grad, GradType sum_hess) const { + inline GradType CalcWeight(GradType sum_grad, GradType sum_hess) const { if (sum_hess < param.min_child_weight || sum_hess <= 0.0) { return 0.0; } - GradType dw = -this->ThresholdL1OneAPI(sum_grad, param.reg_alpha) / (sum_hess + param.reg_lambda); + GradType dw = -this->ThresholdL1(sum_grad, param.reg_alpha) / (sum_hess + param.reg_lambda); if (param.max_delta_step != 0.0f && std::abs(dw) > param.max_delta_step) { - dw = sycl::copysign((GradType)param.max_delta_step, dw); + dw = ::sycl::copysign((GradType)param.max_delta_step, dw); } return dw; } - inline GradType CalcWeight(bst_node_t nodeid, const GradStatsOneAPI& stats) const { - GradType w = this->CalcWeightOneAPI(stats.GetGrad(), stats.GetHess()); + inline GradType CalcWeight(bst_node_t nodeid, const GradStats& stats) const { + GradType w = this->CalcWeight(stats.GetGrad(), stats.GetHess()); if (!has_constraint) { return w; } @@ -130,19 +131,19 @@ class TreeEvaluatorOneAPI { return -(2.0f * sum_grad * w + (sum_hess + param.reg_lambda) * this->Sqr(w)); } - inline GradType CalcGainGivenWeight(bst_node_t nid, const GradStatsOneAPI& stats, GradType w) const { + inline GradType CalcGainGivenWeight(bst_node_t nid, const GradStats& stats, GradType w) const { if (stats.GetHess() <= 0) { return .0f; } // Avoiding tree::CalcGainGivenWeight can significantly reduce avg floating point error. 
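// With no max_delta_step clipping and no monotone constraints, the optimal
// leaf weight is w* = -ThresholdL1(G, alpha) / (H + lambda), so the gain
// reduces to the closed form ThresholdL1(G, alpha)^2 / (H + lambda) used below.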
if (param.max_delta_step == 0.0f && has_constraint == false) { - return this->Sqr(this->ThresholdL1OneAPI(stats.sum_grad, param.reg_alpha)) / + return this->Sqr(this->ThresholdL1(stats.sum_grad, param.reg_alpha)) / (stats.sum_hess + param.reg_lambda); } return this->CalcGainGivenWeight(stats.sum_grad, stats.sum_hess, w); } - GradType CalcGain(bst_node_t nid, const GradStatsOneAPI& stats) const { + GradType CalcGain(bst_node_t nid, const GradStats& stats) const { return this->CalcGainGivenWeight(nid, stats, this->CalcWeight(nid, stats)); } }; @@ -165,8 +166,8 @@ class TreeEvaluatorOneAPI { GradType* lower = lower_bounds_.Data(); GradType* upper = upper_bounds_.Data(); int* monotone = monotone_.Data(); - qu_.submit([&](sycl::handler& cgh) { - cgh.parallel_for<>(sycl::range<1>(1), [=](sycl::item<1> pid) { + qu_.submit([&](::sycl::handler& cgh) { + cgh.parallel_for<>(::sycl::range<1>(1), [=](::sycl::item<1> pid) { lower[leftid] = lower[nodeid]; upper[leftid] = upper[nodeid]; @@ -187,6 +188,7 @@ class TreeEvaluatorOneAPI { } }; } // namespace tree +} // namespace sycl } // namespace xgboost -#endif // XGBOOST_TREE_SPLIT_EVALUATOR_ONEAPI_H_ +#endif // XGBOOST_TREE_SPLIT_EVALUATOR_SYCL_H_ diff --git a/plugin/updater_oneapi/updater_quantile_hist_oneapi.cc b/plugin/sycl/tree/updater_quantile_hist.cc similarity index 76% rename from plugin/updater_oneapi/updater_quantile_hist_oneapi.cc rename to plugin/sycl/tree/updater_quantile_hist.cc index 579b903cb54c..a23ae310d321 100644 --- a/plugin/updater_oneapi/updater_quantile_hist_oneapi.cc +++ b/plugin/sycl/tree/updater_quantile_hist.cc @@ -1,61 +1,34 @@ /*! * Copyright 2017-2023 by Contributors - * \file updater_quantile_hist_oneapi.cc + * \file updater_quantile_hist.cc */ #include +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" #include -#include "xgboost/logging.h" #include "xgboost/tree_updater.h" +#pragma GCC diagnostic pop + +#include "xgboost/logging.h" -#include "updater_quantile_hist_oneapi.h" -#include "data_oneapi.h" +#include "updater_quantile_hist.h" +#include "../data.h" namespace xgboost { +namespace sycl { namespace tree { -using sycl::ext::oneapi::plus; -using sycl::ext::oneapi::minimum; -using sycl::ext::oneapi::maximum; - -DMLC_REGISTRY_FILE_TAG(updater_quantile_hist_oneapi); - -DMLC_REGISTER_PARAMETER(OneAPIHistMakerTrainParam); - -void QuantileHistMakerOneAPI::Configure(const Args& args) { - const DeviceOrd device_spec = ctx_->Device(); - - sycl::device device = device_manager.GetDevice(device_spec); - bool is_cpu = device.is_cpu(); - LOG(INFO) << "device = " << device_spec.Name() << ", is_cpu = " << int(is_cpu); - - if (is_cpu) - { - updater_backend_.reset(TreeUpdater::Create("grow_quantile_histmaker", ctx_, task_)); - updater_backend_->Configure(args); - } - else - { - updater_backend_.reset(TreeUpdater::Create("grow_quantile_histmaker_oneapi_backend", ctx_, task_)); - updater_backend_->Configure(args); - } -} +using ::sycl::ext::oneapi::plus; +using ::sycl::ext::oneapi::minimum; +using ::sycl::ext::oneapi::maximum; -void QuantileHistMakerOneAPI::Update(TrainParam const *param, - HostDeviceVector *gpair, - DMatrix *dmat, - common::Span> out_position, - const std::vector &trees) { - updater_backend_->Update(param, gpair, dmat, out_position, trees); -} +DMLC_REGISTRY_FILE_TAG(updater_quantile_hist_sycl); -bool QuantileHistMakerOneAPI::UpdatePredictionCache( - const DMatrix* data, - linalg::MatrixView out_preds) { - return 
updater_backend_->UpdatePredictionCache(data, out_preds); -} +DMLC_REGISTER_PARAMETER(HistMakerTrainParam); -void QuantileHistMakerOneAPIBackend::Configure(const Args& args) { +void QuantileHistMaker::Configure(const Args& args) { const DeviceOrd device_spec = ctx_->Device(); qu_ = device_manager.GetQueue(device_spec); @@ -69,7 +42,7 @@ void QuantileHistMakerOneAPIBackend::Configure(const Args& args) { } template -void QuantileHistMakerOneAPIBackend::SetBuilder(std::unique_ptr>* builder, +void QuantileHistMaker::SetBuilder(std::unique_ptr>* builder, DMatrix *dmat) { builder->reset(new Builder( qu_, @@ -77,35 +50,35 @@ void QuantileHistMakerOneAPIBackend::SetBuilder(std::unique_ptrSetHistSynchronizer(new DistributedHistSynchronizerOneAPI()); - (*builder)->SetHistRowsAdder(new DistributedHistRowsAdderOneAPI()); + (*builder)->SetHistSynchronizer(new DistributedHistSynchronizer()); + (*builder)->SetHistRowsAdder(new DistributedHistRowsAdder()); } else { - (*builder)->SetHistSynchronizer(new BatchHistSynchronizerOneAPI()); - (*builder)->SetHistRowsAdder(new BatchHistRowsAdderOneAPI()); + (*builder)->SetHistSynchronizer(new BatchHistSynchronizer()); + (*builder)->SetHistRowsAdder(new BatchHistRowsAdder()); } } template -void QuantileHistMakerOneAPIBackend::CallBuilderUpdate(const std::unique_ptr>& builder, - TrainParam const *param, - HostDeviceVector *gpair, - DMatrix *dmat, - common::Span> out_position, - const std::vector &trees) { +void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr>& builder, + xgboost::tree::TrainParam const *param, + HostDeviceVector *gpair, + DMatrix *dmat, + xgboost::common::Span> out_position, + const std::vector &trees) { const std::vector& gpair_h = gpair->ConstHostVector(); USMVector gpair_device(qu_, gpair_h); for (auto tree : trees) { builder->Update(ctx_, param, gmat_, gpair, gpair_device, dmat, out_position, tree); } } -void QuantileHistMakerOneAPIBackend::Update(TrainParam const *param, - HostDeviceVector *gpair, - DMatrix *dmat, - common::Span> out_position, - const std::vector &trees) { +void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param, + HostDeviceVector *gpair, + DMatrix *dmat, + xgboost::common::Span> out_position, + const std::vector &trees) { if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) { updater_monitor_.Start("GmatInitialization"); - DeviceMatrixOneAPI dmat_device(qu_, dmat); + sycl::DeviceMatrix dmat_device(qu_, dmat); gmat_.Init(qu_, ctx_, dmat_device, static_cast(param_.max_bin)); updater_monitor_.Stop("GmatInitialization"); is_gmat_initialized_ = true; @@ -115,7 +88,7 @@ void QuantileHistMakerOneAPIBackend::Update(TrainParam const *param, param_.learning_rate = lr / trees.size(); int_constraint_.Configure(param_, dmat->Info().num_col_); // build tree - bool has_double_support = qu_.get_device().has(sycl::aspect::fp64); + bool has_double_support = qu_.get_device().has(::sycl::aspect::fp64); if (hist_maker_param_.single_precision_histogram || !has_double_support) { if (!hist_maker_param_.single_precision_histogram) { LOG(WARNING) << "Target device doesn't support fp64, using single_precision_histogram=True"; @@ -136,12 +109,12 @@ void QuantileHistMakerOneAPIBackend::Update(TrainParam const *param, p_last_dmat_ = dmat; } -bool QuantileHistMakerOneAPIBackend::UpdatePredictionCache(const DMatrix* data, +bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data, linalg::MatrixView out_preds) { if (param_.subsample < 1.0f) { return false; } else { - bool has_double_support = 
qu_.get_device().has(sycl::aspect::fp64); + bool has_double_support = qu_.get_device().has(::sycl::aspect::fp64); if ((hist_maker_param_.single_precision_histogram || !has_double_support) && float_builder_) { return float_builder_->UpdatePredictionCache(data, out_preds); } else if (double_builder_) { @@ -153,9 +126,9 @@ bool QuantileHistMakerOneAPIBackend::UpdatePredictionCache(const DMatrix* data, } template -void BatchHistSynchronizerOneAPI::SyncHistograms(BuilderT *builder, - std::vector& sync_ids, - RegTree *p_tree) { +void BatchHistSynchronizer::SyncHistograms(BuilderT *builder, + std::vector& sync_ids, + RegTree *p_tree) { builder->builder_monitor_.Start("SyncHistograms"); const size_t nbins = builder->hist_builder_.GetNumBins(); @@ -168,7 +141,7 @@ void BatchHistSynchronizerOneAPI::SyncHistograms(BuilderT *builder const size_t parent_id = (*p_tree)[entry.nid].Parent(); auto parent_hist = builder->hist_[parent_id]; auto sibling_hist = builder->hist_[entry.sibling_nid]; - hist_sync_events_[i] = common::SubtractionHist(builder->qu_, sibling_hist, parent_hist, this_hist, nbins, sycl::event()); + hist_sync_events_[i] = common::SubtractionHist(builder->qu_, sibling_hist, parent_hist, this_hist, nbins, ::sycl::event()); } } builder->qu_.wait_and_throw(); @@ -177,7 +150,7 @@ void BatchHistSynchronizerOneAPI::SyncHistograms(BuilderT *builder } template -void DistributedHistSynchronizerOneAPI::SyncHistograms(BuilderT* builder, +void DistributedHistSynchronizer::SyncHistograms(BuilderT* builder, std::vector& sync_ids, RegTree *p_tree) { builder->builder_monitor_.Start("SyncHistograms"); @@ -193,7 +166,7 @@ void DistributedHistSynchronizerOneAPI::SyncHistograms(BuilderT* b const size_t parent_id = (*p_tree)[entry.nid].Parent(); auto parent_hist = builder->hist_local_worker_[parent_id]; auto sibling_hist = builder->hist_[entry.sibling_nid]; - common::SubtractionHist(builder->qu_, sibling_hist, parent_hist, this_hist, nbins, sycl::event()); + common::SubtractionHist(builder->qu_, sibling_hist, parent_hist, this_hist, nbins, ::sycl::event()); // Store posible parent node auto sibling_local = builder->hist_local_worker_[entry.sibling_nid]; common::CopyHist(builder->qu_, sibling_local, sibling_hist, nbins); @@ -208,7 +181,7 @@ void DistributedHistSynchronizerOneAPI::SyncHistograms(BuilderT* b } template -void DistributedHistSynchronizerOneAPI::ParallelSubtractionHist( +void DistributedHistSynchronizer::ParallelSubtractionHist( BuilderT* builder, const std::vector& nodes, const RegTree * p_tree) { @@ -221,14 +194,14 @@ void DistributedHistSynchronizerOneAPI::ParallelSubtractionHist( if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) { auto parent_hist = builder->hist_[(*p_tree)[entry.nid].Parent()]; auto sibling_hist = builder->hist_[entry.sibling_nid]; - common::SubtractionHist(builder->qu_, this_hist, parent_hist, sibling_hist, nbins, sycl::event()); + common::SubtractionHist(builder->qu_, this_hist, parent_hist, sibling_hist, nbins, ::sycl::event()); } } } } template -void QuantileHistMakerOneAPIBackend::Builder::ReduceHists(std::vector& sync_ids, size_t nbins) { +void QuantileHistMaker::Builder::ReduceHists(std::vector& sync_ids, size_t nbins) { std::vector reduce_buffer(sync_ids.size() * nbins); for (size_t i = 0; i < sync_ids.size(); i++) { auto this_hist = hist_[sync_ids[i]]; @@ -247,9 +220,9 @@ void QuantileHistMakerOneAPIBackend::Builder::ReduceHists(std::vec } template -void BatchHistRowsAdderOneAPI::AddHistRows(BuilderT *builder, - std::vector& sync_ids, - RegTree *p_tree) { 
+void BatchHistRowsAdder::AddHistRows(BuilderT *builder, + std::vector& sync_ids, + RegTree *p_tree) { builder->builder_monitor_.Start("AddHistRows"); int max_nid = 0; @@ -275,9 +248,9 @@ void BatchHistRowsAdderOneAPI::AddHistRows(BuilderT *builder, } template -void DistributedHistRowsAdderOneAPI::AddHistRows(BuilderT *builder, - std::vector& sync_ids, - RegTree *p_tree) { +void DistributedHistRowsAdder::AddHistRows(BuilderT *builder, + std::vector& sync_ids, + RegTree *p_tree) { builder->builder_monitor_.Start("AddHistRows"); const size_t explicit_size = builder->nodes_for_explicit_hist_build_.size(); const size_t subtaction_size = builder->nodes_for_subtraction_trick_.size(); @@ -308,21 +281,21 @@ void DistributedHistRowsAdderOneAPI::AddHistRows(BuilderT *builder } template -void QuantileHistMakerOneAPIBackend::Builder::SetHistSynchronizer( - HistSynchronizerOneAPI *sync) { +void QuantileHistMaker::Builder::SetHistSynchronizer( + HistSynchronizer *sync) { hist_synchronizer_.reset(sync); } template -void QuantileHistMakerOneAPIBackend::Builder::SetHistRowsAdder( - HistRowsAdderOneAPI *adder) { +void QuantileHistMaker::Builder::SetHistRowsAdder( + HistRowsAdder *adder) { hist_rows_adder_.reset(adder); } template -void QuantileHistMakerOneAPIBackend::Builder::BuildHistogramsLossGuide( +void QuantileHistMaker::Builder::BuildHistogramsLossGuide( ExpandEntry entry, - const GHistIndexMatrixOneAPI &gmat, + const GHistIndexMatrix &gmat, RegTree *p_tree, const USMVector &gpair_device) { nodes_for_explicit_hist_build_.clear(); @@ -342,14 +315,14 @@ void QuantileHistMakerOneAPIBackend::Builder::BuildHistogramsLossG } template -void QuantileHistMakerOneAPIBackend::Builder::BuildLocalHistograms( - const GHistIndexMatrixOneAPI &gmat, +void QuantileHistMaker::Builder::BuildLocalHistograms( + const GHistIndexMatrix &gmat, RegTree *p_tree, const USMVector &gpair_device) { - builder_monitor_.Start("BuildLocalHistogramsOneAPI"); + builder_monitor_.Start("BuildLocalHistograms"); const size_t n_nodes = nodes_for_explicit_hist_build_.size(); for (auto& event : hist_build_events_) { - event = sycl::event(); + event = ::sycl::event(); } const size_t event_idx = 0; @@ -367,12 +340,12 @@ void QuantileHistMakerOneAPIBackend::Builder::BuildLocalHistograms } } qu_.wait_and_throw(); - builder_monitor_.Stop("BuildLocalHistogramsOneAPI"); + builder_monitor_.Stop("BuildLocalHistograms"); } template -void QuantileHistMakerOneAPIBackend::Builder::BuildNodeStats( - const GHistIndexMatrixOneAPI &gmat, +void QuantileHistMaker::Builder::BuildNodeStats( + const GHistIndexMatrix &gmat, DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair) { @@ -397,8 +370,8 @@ void QuantileHistMakerOneAPIBackend::Builder::BuildNodeStats( } template -void QuantileHistMakerOneAPIBackend::Builder::AddSplitsToTree( - const GHistIndexMatrixOneAPI &gmat, +void QuantileHistMaker::Builder::AddSplitsToTree( + const GHistIndexMatrix &gmat, RegTree *p_tree, int *num_leaves, int depth, @@ -418,9 +391,9 @@ void QuantileHistMakerOneAPIBackend::Builder::AddSplitsToTree( NodeEntry& e = snode_[nid]; bst_float left_leaf_weight = - evaluator.CalcWeight(nid, GradStatsOneAPI{e.best.left_sum}) * param_.learning_rate; + evaluator.CalcWeight(nid, GradStats{e.best.left_sum}) * param_.learning_rate; bst_float right_leaf_weight = - evaluator.CalcWeight(nid, GradStatsOneAPI{e.best.right_sum}) * param_.learning_rate; + evaluator.CalcWeight(nid, GradStats{e.best.right_sum}) * param_.learning_rate; p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, 
e.best.DefaultLeft(), e.weight, left_leaf_weight, right_leaf_weight, e.best.loss_chg, e.stats.GetHess(), @@ -439,8 +412,8 @@ void QuantileHistMakerOneAPIBackend::Builder::AddSplitsToTree( } template -void QuantileHistMakerOneAPIBackend::Builder::EvaluateAndApplySplits( - const GHistIndexMatrixOneAPI &gmat, +void QuantileHistMaker::Builder::EvaluateAndApplySplits( + const GHistIndexMatrix &gmat, RegTree *p_tree, int *num_leaves, int depth, @@ -461,7 +434,7 @@ void QuantileHistMakerOneAPIBackend::Builder::EvaluateAndApplySpli // and use 'Subtraction Trick' to built the histogram for the right child node. // This ensures that the workers operate on the same set of tree nodes. template -void QuantileHistMakerOneAPIBackend::Builder::SplitSiblings( +void QuantileHistMaker::Builder::SplitSiblings( const std::vector &nodes, std::vector *small_siblings, std::vector *big_siblings, @@ -491,8 +464,8 @@ void QuantileHistMakerOneAPIBackend::Builder::SplitSiblings( } template -void QuantileHistMakerOneAPIBackend::Builder::ExpandWithDepthWise( - const GHistIndexMatrixOneAPI &gmat, +void QuantileHistMaker::Builder::ExpandWithDepthWise( + const GHistIndexMatrix &gmat, DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair, @@ -531,8 +504,8 @@ void QuantileHistMakerOneAPIBackend::Builder::ExpandWithDepthWise( } template -void QuantileHistMakerOneAPIBackend::Builder::ExpandWithLossGuide( - const GHistIndexMatrixOneAPI& gmat, +void QuantileHistMaker::Builder::ExpandWithLossGuide( + const GHistIndexMatrix& gmat, DMatrix* p_fmat, RegTree* p_tree, const std::vector &gpair, @@ -563,9 +536,9 @@ void QuantileHistMakerOneAPIBackend::Builder::ExpandWithLossGuide( auto evaluator = tree_evaluator_.GetEvaluator(); NodeEntry& e = snode_[nid]; bst_float left_leaf_weight = - evaluator.CalcWeight(nid, GradStatsOneAPI{e.best.left_sum}) * param_.learning_rate; + evaluator.CalcWeight(nid, GradStats{e.best.left_sum}) * param_.learning_rate; bst_float right_leaf_weight = - evaluator.CalcWeight(nid, GradStatsOneAPI{e.best.right_sum}) * param_.learning_rate; + evaluator.CalcWeight(nid, GradStats{e.best.right_sum}) * param_.learning_rate; p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft(), e.weight, left_leaf_weight, right_leaf_weight, e.best.loss_chg, e.stats.GetHess(), @@ -608,23 +581,23 @@ void QuantileHistMakerOneAPIBackend::Builder::ExpandWithLossGuide( } template -void QuantileHistMakerOneAPIBackend::Builder::Update( +void QuantileHistMaker::Builder::Update( Context const * ctx, - TrainParam const *param, - const GHistIndexMatrixOneAPI &gmat, + xgboost::tree::TrainParam const *param, + const GHistIndexMatrix &gmat, HostDeviceVector *gpair, const USMVector& gpair_device, DMatrix *p_fmat, - common::Span> out_position, + xgboost::common::Span> out_position, RegTree *p_tree) { builder_monitor_.Start("Update"); const std::vector& gpair_h = gpair->ConstHostVector(); - tree_evaluator_ = TreeEvaluatorOneAPI(qu_, param_, p_fmat->Info().num_col_); + tree_evaluator_ = TreeEvaluator(qu_, param_, p_fmat->Info().num_col_); interaction_constraints_.Reset(); this->InitData(ctx, gmat, gpair_h, gpair_device, *p_fmat, *p_tree); - if (param_.grow_policy == TrainParam::kLossGuide) { + if (param_.grow_policy == xgboost::tree::TrainParam::kLossGuide) { ExpandWithLossGuide(gmat, p_fmat, p_tree, gpair_h, gpair_device); } else { ExpandWithDepthWise(gmat, p_fmat, p_tree, gpair_h, gpair_device); @@ -641,7 +614,7 @@ void QuantileHistMakerOneAPIBackend::Builder::Update( } template -bool 
QuantileHistMakerOneAPIBackend::Builder::UpdatePredictionCache( +bool QuantileHistMaker::Builder::UpdatePredictionCache( const DMatrix* data, linalg::MatrixView out_preds) { // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in @@ -654,11 +627,11 @@ bool QuantileHistMakerOneAPIBackend::Builder::UpdatePredictionCach const size_t stride = out_preds.Stride(0); const int buffer_size = out_preds.Size()*stride - stride + 1; - sycl::buffer out_preds_buf(&out_preds(0), buffer_size); + ::sycl::buffer out_preds_buf(&out_preds(0), buffer_size); size_t n_nodes = row_set_collection_.Size(); for (size_t node = 0; node < n_nodes; node++) { - const RowSetCollectionOneAPI::Elem& rowset = row_set_collection_[node]; + const RowSetCollection::Elem& rowset = row_set_collection_[node]; if (rowset.begin != nullptr && rowset.end != nullptr && rowset.Size() != 0) { int nid = rowset.node_id; bst_float leaf_value; @@ -675,9 +648,9 @@ bool QuantileHistMakerOneAPIBackend::Builder::UpdatePredictionCach const size_t* rid = rowset.begin; const size_t num_rows = rowset.Size(); - qu_.submit([&](sycl::handler& cgh) { - auto out_predictions = out_preds_buf.template get_access(cgh); - cgh.parallel_for<>(sycl::range<1>(num_rows), [=](sycl::item<1> pid) { + qu_.submit([&](::sycl::handler& cgh) { + auto out_predictions = out_preds_buf.template get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::item<1> pid) { out_predictions[rid[pid.get_id(0)]*stride] += leaf_value; }); }).wait(); @@ -688,12 +661,12 @@ bool QuantileHistMakerOneAPIBackend::Builder::UpdatePredictionCach return true; } template -void QuantileHistMakerOneAPIBackend::Builder::InitSampling(const std::vector& gpair, +void QuantileHistMaker::Builder::InitSampling(const std::vector& gpair, const USMVector &gpair_device, const DMatrix& fmat, USMVector& row_indices_device) { const auto& info = fmat.Info(); - auto& rnd = common::GlobalRandom(); + auto& rnd = xgboost::common::GlobalRandom(); #if XGBOOST_CUSTOMIZE_GLOBAL_PRNG std::bernoulli_distribution coin_flip(param_.subsample); size_t j = 0; @@ -775,10 +748,10 @@ void QuantileHistMakerOneAPIBackend::Builder::InitSampling(const s size_t* indices_ptr = row_indices.Data(); const GradientPair* gpair_ptr = gpair_device.DataConst(); const uint8_t* coin_flips_ptr = coin_flips_device.DataConst(); - std::vector events; - events.emplace_back(qu_.submit([&](sycl::handler& cgh) { - cgh.parallel_for<>(sycl::range<1>(sycl::range<1>(nblocks)), - [offsets_ptr, indices_ptr, coin_flips_ptr, block_size, size, gpair_ptr](sycl::item<1> pid) { + std::vector<::sycl::event> events; + events.emplace_back(qu_.submit([&](::sycl::handler& cgh) { + cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(nblocks)), + [offsets_ptr, indices_ptr, coin_flips_ptr, block_size, size, gpair_ptr](::sycl::item<1> pid) { const size_t block = pid.get_id(0); size_t start = block * block_size; @@ -798,10 +771,10 @@ void QuantileHistMakerOneAPIBackend::Builder::InitSampling(const s for (size_t i = 1; i < nblocks; ++i) { const size_t ibegin = i * block_size; const size_t idx = row_indices.Get(qu_, i, &events); - qu_.submit([&](sycl::handler& cgh) { + qu_.submit([&](::sycl::handler& cgh) { cgh.depends_on(events); - cgh.parallel_for<>(sycl::range<1>(sycl::range<1>(idx)), - [indices_ptr, prefix_sum, ibegin](sycl::item<1> pid) { + cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(idx)), + [indices_ptr, prefix_sum, ibegin](::sycl::item<1> pid) { const size_t k = pid.get_id(0); 
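// Compaction step: append the k-th surviving row index of this block
// directly after the prefix_sum indices kept from the preceding blocks.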
indices_ptr[prefix_sum + k] = indices_ptr[ibegin + k]; }); @@ -815,9 +788,9 @@ void QuantileHistMakerOneAPIBackend::Builder::InitSampling(const s #endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG } template -void QuantileHistMakerOneAPIBackend::Builder::InitData( +void QuantileHistMaker::Builder::InitData( Context const * ctx, - const GHistIndexMatrixOneAPI& gmat, + const GHistIndexMatrix& gmat, const std::vector& gpair, const USMVector &gpair_device, const DMatrix& fmat, @@ -825,7 +798,7 @@ void QuantileHistMakerOneAPIBackend::Builder::InitData( CHECK((param_.max_depth > 0 || param_.max_leaves > 0)) << "max_depth or max_leaves cannot be both 0 (unlimited); " << "at least one should be a positive quantity."; - if (param_.grow_policy == TrainParam::kDepthWise) { + if (param_.grow_policy == xgboost::tree::TrainParam::kDepthWise) { CHECK(param_.max_depth > 0) << "max_depth cannot be 0 (unlimited) " << "when grow_policy is depthwise."; } @@ -854,7 +827,7 @@ void QuantileHistMakerOneAPIBackend::Builder::InitData( { this->nthread_ = omp_get_num_threads(); } - hist_builder_ = GHistBuilderOneAPI(qu_, nbins); + hist_builder_ = GHistBuilder(qu_, nbins); USMVector& row_indices = row_set_collection_.Data(); row_indices.Resize(qu_, info.num_row_); @@ -862,12 +835,12 @@ void QuantileHistMakerOneAPIBackend::Builder::InitData( // mark subsample and build list of member rows if (param_.subsample < 1.0f) { - CHECK_EQ(param_.sampling_method, TrainParam::kUniform) + CHECK_EQ(param_.sampling_method, xgboost::tree::TrainParam::kUniform) << "Only uniform sampling is supported, " << "gradient-based sampling is only support by GPU Hist."; InitSampling(gpair, gpair_device, fmat, row_indices); } else { - MemStackAllocatorOneAPI buff(this->nthread_); + MemStackAllocator buff(this->nthread_); bool* p_buff = buff.Get(); std::fill(p_buff, p_buff + this->nthread_, false); @@ -906,9 +879,9 @@ void QuantileHistMakerOneAPIBackend::Builder::InitData( qu_.memcpy(p_row_indices, row_indices_buff.data(), j * sizeof(size_t)).wait(); row_indices.Resize(qu_, j); } else { - qu_.submit([&](sycl::handler& cgh) { - cgh.parallel_for<>(sycl::range<1>(sycl::range<1>(info.num_row_)), - [p_row_indices](sycl::item<1> pid) { + qu_.submit([&](::sycl::handler& cgh) { + cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(info.num_row_)), + [p_row_indices](::sycl::item<1> pid) { const size_t idx = pid.get_id(0); p_row_indices[idx] = idx; }); @@ -966,7 +939,7 @@ void QuantileHistMakerOneAPIBackend::Builder::InitData( qu_.wait_and_throw(); } { - if (param_.grow_policy == TrainParam::kLossGuide) { + if (param_.grow_policy == xgboost::tree::TrainParam::kLossGuide) { qexpand_loss_guided_.reset(new ExpandQueue(LossGuide)); } else { qexpand_depth_wise_.clear(); @@ -980,8 +953,8 @@ void QuantileHistMakerOneAPIBackend::Builder::InitData( // then - there are no missing values // else - there are missing values template -bool QuantileHistMakerOneAPIBackend::Builder::SplitContainsMissingValues( - const GradStatsOneAPI& e, const NodeEntry& snode) { +bool QuantileHistMaker::Builder::SplitContainsMissingValues( + const GradStats& e, const NodeEntry& snode) { if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) { return false; } else { @@ -991,10 +964,10 @@ bool QuantileHistMakerOneAPIBackend::Builder::SplitContainsMissing // nodes_set - set of nodes to be processed in parallel template -void QuantileHistMakerOneAPIBackend::Builder::EvaluateSplits( +void QuantileHistMaker::Builder::EvaluateSplits( const std::vector& nodes_set, - const 
GHistIndexMatrixOneAPI& gmat, - const HistCollectionOneAPI& hist, + const GHistIndexMatrix& gmat, + const HistCollection& hist, const RegTree& tree) { builder_monitor_.Start("EvaluateSplits"); @@ -1045,13 +1018,13 @@ void QuantileHistMakerOneAPIBackend::Builder::EvaluateSplits( const bst_float* cut_minval = gmat.cut_device.MinValues().DataConst(); const NodeEntry* snode = snode_.DataConst(); - TrainParamOneAPI param(param_); + TrainParam param(param_); - qu_.submit([&](sycl::handler& cgh) { - cgh.parallel_for<>(sycl::nd_range<2>(sycl::range<2>(total_features, local_size), - sycl::range<2>(1, local_size)), [=](sycl::nd_item<2> pid) [[intel::reqd_sub_group_size(16)]] { - TrainParamOneAPI param_device(param); - typename TreeEvaluatorOneAPI::SplitEvaluator evaluator_device = evaluator; + qu_.submit([&](::sycl::handler& cgh) { + cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(total_features, local_size), + ::sycl::range<2>(1, local_size)), [=](::sycl::nd_item<2> pid) [[intel::reqd_sub_group_size(16)]] { + TrainParam param_device(param); + typename TreeEvaluator::SplitEvaluator evaluator_device = evaluator; int i = pid.get_global_id(0); auto sg = pid.get_sub_group(); int nid = split_queries_device[i].nid; @@ -1075,21 +1048,21 @@ void QuantileHistMakerOneAPIBackend::Builder::EvaluateSplits( // for the particular feature fid. template template -GradStatsOneAPI QuantileHistMakerOneAPIBackend::Builder::EnumerateSplit( +GradStats QuantileHistMaker::Builder::EnumerateSplit( const uint32_t* cut_ptr, const bst_float* cut_val, const bst_float* cut_minval, const GradientPairT* hist_data, const NodeEntry& snode, - SplitEntryOneAPI& p_best, + SplitEntry& p_best, bst_uint fid, bst_uint nodeID, - typename TreeEvaluatorOneAPI::SplitEvaluator const &evaluator_device, - const TrainParamOneAPI& param) { - GradStatsOneAPI c; - GradStatsOneAPI e; + typename TreeEvaluator::SplitEvaluator const &evaluator_device, + const TrainParam& param) { + GradStats c; + GradStats e; // best split so far - SplitEntryOneAPI best; + SplitEntry best; // bin boundaries // imin: index (offset) of the minimum value for feature fid @@ -1120,7 +1093,7 @@ GradStatsOneAPI QuantileHistMakerOneAPIBackend::Builder( - evaluator_device.CalcSplitGain(nodeID, fid, GradStatsOneAPI{c}, GradStatsOneAPI{e}) - snode.root_gain); + evaluator_device.CalcSplitGain(nodeID, fid, GradStats{c}, GradStats{e}) - snode.root_gain); if (i == imin) { split_pt = cut_minval[fid]; } else { @@ -1139,18 +1112,18 @@ GradStatsOneAPI QuantileHistMakerOneAPIBackend::Builder -GradStatsOneAPI QuantileHistMakerOneAPIBackend::Builder::EnumerateSplit( - sycl::sub_group& sg, +GradStats QuantileHistMaker::Builder::EnumerateSplit( + ::sycl::sub_group& sg, const uint32_t* cut_ptr, const bst_float* cut_val, const GradientPairT* hist_data, const NodeEntry& snode, - SplitEntryOneAPI& p_best, + SplitEntry& p_best, bst_uint fid, bst_uint nodeID, - typename TreeEvaluatorOneAPI::SplitEvaluator const &evaluator_device, - const TrainParamOneAPI& param) { - SplitEntryOneAPI best; + typename TreeEvaluator::SplitEvaluator const &evaluator_device, + const TrainParam& param) { + SplitEntry best; int32_t ibegin = static_cast(cut_ptr[fid]); int32_t iend = static_cast(cut_ptr[fid + 1]); @@ -1164,14 +1137,14 @@ GradStatsOneAPI QuantileHistMakerOneAPIBackend::Builder()); - GradientSumT e_hess = sum_hess + sycl::inclusive_scan_over_group(sg, hist_data[i].GetHess(), std::plus<>()); + GradientSumT e_grad = sum_grad + ::sycl::inclusive_scan_over_group(sg, hist_data[i].GetGrad(), std::plus<>()); + 
GradientSumT e_hess = sum_hess + ::sycl::inclusive_scan_over_group(sg, hist_data[i].GetHess(), std::plus<>()); if (e_hess >= param.min_child_weight) { GradientSumT c_grad = tot_grad - e_grad; GradientSumT c_hess = tot_hess - e_hess; if (c_hess >= param.min_child_weight) { - GradStatsOneAPI e(e_grad, e_hess); - GradStatsOneAPI c(c_grad, c_hess); + GradStats e(e_grad, e_hess); + GradStats c(c_grad, c_hess); bst_float loss_chg; bst_float split_pt; loss_chg = static_cast( @@ -1180,28 +1153,28 @@ GradStatsOneAPI QuantileHistMakerOneAPIBackend::Builder()); - sum_hess += sycl::reduce_over_group(sg, hist_data[i].GetHess(), std::plus<>()); + sum_grad += ::sycl::reduce_over_group(sg, hist_data[i].GetGrad(), std::plus<>()); + sum_hess += ::sycl::reduce_over_group(sg, hist_data[i].GetHess(), std::plus<>()); } - bst_float total_loss_chg = sycl::reduce_over_group(sg, best.loss_chg, maximum<>()); - bst_feature_t total_split_index = sycl::reduce_over_group(sg, best.loss_chg == total_loss_chg ? best.SplitIndex() : (1U << 31) - 1U, minimum<>()); + bst_float total_loss_chg = ::sycl::reduce_over_group(sg, best.loss_chg, maximum<>()); + bst_feature_t total_split_index = ::sycl::reduce_over_group(sg, best.loss_chg == total_loss_chg ? best.SplitIndex() : (1U << 31) - 1U, minimum<>()); if (best.loss_chg == total_loss_chg && best.SplitIndex() == total_split_index) p_best.Update(best); - return GradStatsOneAPI(sum_grad, sum_hess); + return GradStats(sum_grad, sum_hess); } // split row indexes (rid_span) to 2 parts (both stored in rid_buf) depending // on comparison of indexes values (idx_span) and split point (split_cond) // Handle dense columns template -inline sycl::event PartitionDenseKernel(sycl::queue& qu, - const GHistIndexMatrixOneAPI& gmat, - const RowSetCollectionOneAPI::Elem& rid_span, +inline ::sycl::event PartitionDenseKernel(::sycl::queue& qu, + const GHistIndexMatrix& gmat, + const RowSetCollection::Elem& rid_span, const size_t fid, const int32_t split_cond, - common::Span& rid_buf, + xgboost::common::Span& rid_buf, size_t* parts_size, - sycl::event priv_event) { + ::sycl::event priv_event) { const size_t row_stride = gmat.row_stride; const BinIdxType* gradient_index = gmat.index.data(); const size_t* rid = rid_span.begin; @@ -1210,9 +1183,9 @@ inline sycl::event PartitionDenseKernel(sycl::queue& qu, size_t* p_rid_buf = rid_buf.data(); - auto event = qu.submit([&](sycl::handler& cgh) { + auto event = qu.submit([&](::sycl::handler& cgh) { cgh.depends_on(priv_event); - cgh.parallel_for<>(sycl::range<1>(range_size), [=](sycl::item<1> nid) { + cgh.parallel_for<>(::sycl::range<1>(range_size), [=](::sycl::item<1> nid) { const size_t id = rid[nid.get_id(0)]; const int32_t value = static_cast(gradient_index[id * row_stride + fid] + offset); const bool is_left = value <= split_cond; @@ -1232,14 +1205,14 @@ inline sycl::event PartitionDenseKernel(sycl::queue& qu, // on comparison of indexes values (idx_span) and split point (split_cond) // Handle dense columns template -inline sycl::event PartitionSparseKernel(sycl::queue& qu, - const GHistIndexMatrixOneAPI& gmat, - const RowSetCollectionOneAPI::Elem& rid_span, +inline ::sycl::event PartitionSparseKernel(::sycl::queue& qu, + const GHistIndexMatrix& gmat, + const RowSetCollection::Elem& rid_span, const size_t fid, const int32_t split_cond, - common::Span& rid_buf, + xgboost::common::Span& rid_buf, size_t* parts_size, - sycl::event priv_event) { + ::sycl::event priv_event) { const size_t row_stride = gmat.row_stride; const BinIdxType* gradient_index = 
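// Sequential restatement of what EnumerateSplit above computes per feature:
// walk the histogram bins keeping running (grad, hess) sums for the left
// child, derive the right child by subtracting from the node totals, and keep
// the best admissible gain. The sub-group version parallelizes the running
// sums with inclusive_scan_over_group; this loop shows the same arithmetic,
// with an unregularized gain for brevity.
struct GradSum { double grad = 0.0, hess = 0.0; };

inline double GainOf(const GradSum& s) {
  return s.hess > 0.0 ? s.grad * s.grad / s.hess : 0.0;
}

int BestBin(const GradSum* hist, int nbins, GradSum total, double min_child_weight) {
  GradSum left;
  int best_bin = -1;
  double best_gain = 0.0;
  for (int i = 0; i < nbins; ++i) {
    left.grad += hist[i].grad;
    left.hess += hist[i].hess;
    const GradSum right{total.grad - left.grad, total.hess - left.hess};
    if (left.hess < min_child_weight || right.hess < min_child_weight) continue;
    const double gain = GainOf(left) + GainOf(right) - GainOf(total);
    if (gain > best_gain) { best_gain = gain; best_bin = i; }
  }
  return best_bin;  // -1: no admissible split for this feature
}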
gmat.index.data(); const size_t* rid = rid_span.begin; @@ -1248,9 +1221,9 @@ inline sycl::event PartitionSparseKernel(sycl::queue& qu, const bst_float* cut_vals = gmat.cut_device.Values().DataConst(); size_t* p_rid_buf = rid_buf.data(); - auto event = qu.submit([&](sycl::handler& cgh) { + auto event = qu.submit([&](::sycl::handler& cgh) { cgh.depends_on(priv_event); - cgh.parallel_for<>(sycl::range<1>(range_size), [=](sycl::item<1> nid) { + cgh.parallel_for<>(::sycl::range<1>(range_size), [=](::sycl::item<1> nid) { const size_t id = rid[nid.get_id(0)]; const BinIdxType* gr_index_local = gradient_index + row_stride * id; @@ -1270,14 +1243,14 @@ inline sycl::event PartitionSparseKernel(sycl::queue& qu, template template -sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKernel( +::sycl::event QuantileHistMaker::Builder::PartitionKernel( const size_t nid, const int32_t split_cond, - const GHistIndexMatrixOneAPI& gmat, + const GHistIndexMatrix& gmat, const RegTree::Node& node, - common::Span& rid_buf, + xgboost::common::Span& rid_buf, size_t* parts_size, - sycl::event priv_event) { + ::sycl::event priv_event) { const bst_uint fid = node.SplitIndex(); const bool default_left = node.DefaultLeft(); @@ -1297,10 +1270,10 @@ sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKern } template -void QuantileHistMakerOneAPIBackend::Builder::FindSplitConditions( +void QuantileHistMaker::Builder::FindSplitConditions( const std::vector& nodes, const RegTree& tree, - const GHistIndexMatrixOneAPI& gmat, + const GHistIndexMatrix& gmat, std::vector* split_conditions) { const size_t n_nodes = nodes.size(); split_conditions->resize(n_nodes); @@ -1325,7 +1298,7 @@ void QuantileHistMakerOneAPIBackend::Builder::FindSplitConditions( } } template -void QuantileHistMakerOneAPIBackend::Builder::AddSplitsToRowSet(const std::vector& nodes, +void QuantileHistMaker::Builder::AddSplitsToRowSet(const std::vector& nodes, RegTree* p_tree) { const size_t n_nodes = nodes.size(); for (size_t i = 0; i < n_nodes; ++i) { @@ -1339,9 +1312,9 @@ void QuantileHistMakerOneAPIBackend::Builder::AddSplitsToRowSet(co } template -void QuantileHistMakerOneAPIBackend::Builder::ApplySplit(const std::vector nodes, - const GHistIndexMatrixOneAPI& gmat, - const HistCollectionOneAPI& hist, +void QuantileHistMaker::Builder::ApplySplit(const std::vector nodes, + const GHistIndexMatrix& gmat, + const HistCollection& hist, RegTree* p_tree) { builder_monitor_.Start("ApplySplit"); @@ -1360,35 +1333,35 @@ void QuantileHistMakerOneAPIBackend::Builder::ApplySplit(const std for (size_t node_in_set = 0; node_in_set < n_nodes; node_in_set++) { const int32_t nid = nodes[node_in_set].nid; - sycl::event& cur_event = apply_split_events_[node_in_set]; + ::sycl::event& cur_event = apply_split_events_[node_in_set]; if (row_set_collection_[nid].Size() > 0) { const RegTree::Node& node = (*p_tree)[nid]; - common::Span rid_buf = partition_builder_.GetData(node_in_set); + xgboost::common::Span rid_buf = partition_builder_.GetData(node_in_set); size_t* part_size = parts_size_.Data() + 2 * node_in_set; int32_t split_condition = split_conditions[node_in_set]; switch (gmat.index.GetBinTypeSize()) { - case common::kUint8BinsTypeSize: + case common::BinTypeSize::kUint8BinsTypeSize: cur_event = PartitionKernel(nid, split_condition, gmat, node, rid_buf, part_size, event); break; - case common::kUint16BinsTypeSize: + case common::BinTypeSize::kUint16BinsTypeSize: cur_event = PartitionKernel(nid, split_condition, gmat, node, rid_buf, part_size, event); break; 
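// Host-side sketch of the dense partition step defined above: rows whose
// quantized value is <= split_cond fill the buffer from the front, the rest
// fill it from the back, and both counts land in parts_size. The device
// kernels do this with atomic cursors, and the sparse variant additionally
// searches the row for the feature, taking the node's default direction when
// it is missing. Illustrative names only.
#include <cstddef>
#include <cstdint>
#include <vector>

void PartitionRows(const std::vector<std::size_t>& rid,
                   const std::vector<int32_t>& bin_value,  // quantized value per row
                   int32_t split_cond,
                   std::vector<std::size_t>* rid_buf,
                   std::size_t parts_size[2]) {
  const std::size_t n = rid.size();
  rid_buf->resize(n);
  std::size_t n_left = 0, n_right = 0;
  for (std::size_t i = 0; i < n; ++i) {
    const std::size_t id = rid[i];
    if (bin_value[id] <= split_cond) {
      (*rid_buf)[n_left++] = id;            // left part grows from the front
    } else {
      (*rid_buf)[n - 1 - n_right++] = id;   // right part grows from the back
    }
  }
  parts_size[0] = n_left;
  parts_size[1] = n_right;
}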
- case common::kUint32BinsTypeSize: + case common::BinTypeSize::kUint32BinsTypeSize: cur_event = PartitionKernel(nid, split_condition, gmat, node, rid_buf, part_size, event); break; default: CHECK(false); // no default behavior } } else { - cur_event = sycl::event(); + cur_event = ::sycl::event(); } } - sycl::event event_cpy = qu_.memcpy(partition_builder_.GetResultRowsPtr(), parts_size_.DataConst(), sizeof(size_t) * 2 * n_nodes, apply_split_events_); + ::sycl::event event_cpy = qu_.memcpy(partition_builder_.GetResultRowsPtr(), parts_size_.DataConst(), sizeof(size_t) * 2 * n_nodes, apply_split_events_); qu_.wait_and_throw(); merge_to_array_events_.resize(n_nodes); for (size_t node_in_set = 0; node_in_set < n_nodes; node_in_set++) { - sycl::event& cur_event = merge_to_array_events_[node_in_set]; + ::sycl::event& cur_event = merge_to_array_events_[node_in_set]; const int32_t nid = nodes[node_in_set].nid; size_t* data_result = const_cast(row_set_collection_[nid].begin); cur_event = partition_builder_.MergeToArray(qu_, node_in_set, data_result, event_cpy); @@ -1401,8 +1374,8 @@ void QuantileHistMakerOneAPIBackend::Builder::ApplySplit(const std } template -void QuantileHistMakerOneAPIBackend::Builder::InitNewNode(int nid, - const GHistIndexMatrixOneAPI& gmat, +void QuantileHistMaker::Builder::InitNewNode(int nid, + const GHistIndexMatrix& gmat, const std::vector& gpair, const DMatrix& fmat, const RegTree& tree) { @@ -1428,7 +1401,7 @@ void QuantileHistMakerOneAPIBackend::Builder::InitNewNode(int nid, grad_stat.Add(et.GetGrad(), et.GetHess()); } } else { - const RowSetCollectionOneAPI::Elem e = row_set_collection_[nid]; + const RowSetCollection::Elem e = row_set_collection_[nid]; // for (const size_t* it = e.begin; it < e.end; ++it) { // grad_stat.Add(gpair[*it].GetGrad(), gpair[*it].GetHess()); // } @@ -1440,7 +1413,7 @@ void QuantileHistMakerOneAPIBackend::Builder::InitNewNode(int nid, } collective::Allreduce(reinterpret_cast(&grad_stat), 2); // histred_.Allreduce(&grad_stat, 1); - snode_[nid].stats = GradStatsOneAPI(grad_stat.GetGrad(), grad_stat.GetHess()); + snode_[nid].stats = GradStats(grad_stat.GetGrad(), grad_stat.GetHess()); } else { int parent_id = tree[nid].Parent(); if (tree[nid].IsLeftChild()) { @@ -1456,46 +1429,40 @@ void QuantileHistMakerOneAPIBackend::Builder::InitNewNode(int nid, auto evaluator = tree_evaluator_.GetEvaluator(); bst_uint parentid = tree[nid].Parent(); snode_[nid].weight = static_cast( - evaluator.CalcWeight(parentid, GradStatsOneAPI{snode_[nid].stats})); + evaluator.CalcWeight(parentid, GradStats{snode_[nid].stats})); snode_[nid].root_gain = static_cast( - evaluator.CalcGain(parentid, GradStatsOneAPI{snode_[nid].stats})); + evaluator.CalcGain(parentid, GradStats{snode_[nid].stats})); } builder_monitor_.Stop("InitNewNode"); } -template struct QuantileHistMakerOneAPIBackend::Builder; -template struct QuantileHistMakerOneAPIBackend::Builder; -template sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKernel( - const size_t nid, const int32_t split_cond, const GHistIndexMatrixOneAPI &gmat, - const RegTree::Node& node, common::Span& rid_buf, size_t* parts_size, sycl::event priv_event); -template sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKernel( - const size_t nid, const int32_t split_cond, const GHistIndexMatrixOneAPI &gmat, - const RegTree::Node& node, common::Span& rid_buf, size_t* parts_size, sycl::event priv_event); -template sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKernel( - const size_t nid, const int32_t 
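// InitNewNode above turns a node's accumulated gradient statistics into its
// weight and root gain; with L2 regularization lambda these are the standard
// closed forms the evaluator implements (shown unconstrained, illustrative):
double CalcWeightSketch(double sum_grad, double sum_hess, double lambda) {
  return -sum_grad / (sum_hess + lambda);  // argmin of the node's quadratic objective
}

double CalcGainSketch(double sum_grad, double sum_hess, double lambda) {
  return sum_grad * sum_grad / (sum_hess + lambda);  // objective reduction at that weight
}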
split_cond, const GHistIndexMatrixOneAPI &gmat, - const RegTree::Node& node, common::Span& rid_buf, size_t* parts_size, sycl::event priv_event); -template sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKernel( - const size_t nid, const int32_t split_cond, const GHistIndexMatrixOneAPI &gmat, - const RegTree::Node& node, common::Span& rid_buf, size_t* parts_size, sycl::event priv_event); -template sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKernel( - const size_t nid, const int32_t split_cond, const GHistIndexMatrixOneAPI &gmat, - const RegTree::Node& node, common::Span& rid_buf, size_t* parts_size, sycl::event priv_event); -template sycl::event QuantileHistMakerOneAPIBackend::Builder::PartitionKernel( - const size_t nid, const int32_t split_cond, const GHistIndexMatrixOneAPI &gmat, - const RegTree::Node& node, common::Span& rid_buf, size_t* parts_size, sycl::event priv_event); - -XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMakerOneAPI, "grow_quantile_histmaker_oneapi") -.describe("Grow tree using quantized histogram with dpc++.") -.set_body( - [](Context const* ctx, ObjInfo const * task) { - return new QuantileHistMakerOneAPI(ctx, task); - }); - -XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMakerOneAPIBackend, "grow_quantile_histmaker_oneapi_backend") -.describe("Grow tree using quantized histogram with dpc++ on GPU.") +template struct QuantileHistMaker::Builder; +template struct QuantileHistMaker::Builder; +template ::sycl::event QuantileHistMaker::Builder::PartitionKernel( + const size_t nid, const int32_t split_cond, const GHistIndexMatrix &gmat, + const RegTree::Node& node, xgboost::common::Span& rid_buf, size_t* parts_size, ::sycl::event priv_event); +template ::sycl::event QuantileHistMaker::Builder::PartitionKernel( + const size_t nid, const int32_t split_cond, const GHistIndexMatrix &gmat, + const RegTree::Node& node, xgboost::common::Span& rid_buf, size_t* parts_size, ::sycl::event priv_event); +template ::sycl::event QuantileHistMaker::Builder::PartitionKernel( + const size_t nid, const int32_t split_cond, const GHistIndexMatrix &gmat, + const RegTree::Node& node, xgboost::common::Span& rid_buf, size_t* parts_size, ::sycl::event priv_event); +template ::sycl::event QuantileHistMaker::Builder::PartitionKernel( + const size_t nid, const int32_t split_cond, const GHistIndexMatrix &gmat, + const RegTree::Node& node, xgboost::common::Span& rid_buf, size_t* parts_size, ::sycl::event priv_event); +template ::sycl::event QuantileHistMaker::Builder::PartitionKernel( + const size_t nid, const int32_t split_cond, const GHistIndexMatrix &gmat, + const RegTree::Node& node, xgboost::common::Span& rid_buf, size_t* parts_size, ::sycl::event priv_event); +template ::sycl::event QuantileHistMaker::Builder::PartitionKernel( + const size_t nid, const int32_t split_cond, const GHistIndexMatrix &gmat, + const RegTree::Node& node, xgboost::common::Span& rid_buf, size_t* parts_size, ::sycl::event priv_event); + +XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl") +.describe("Grow tree using quantized histogram with SYCL.") .set_body( [](Context const* ctx, ObjInfo const * task) { - return new QuantileHistMakerOneAPIBackend(ctx, task); + return new QuantileHistMaker(ctx, task); }); } // namespace tree +} // namespace sycl } // namespace xgboost diff --git a/plugin/updater_oneapi/updater_quantile_hist_oneapi.h b/plugin/sycl/tree/updater_quantile_hist.h similarity index 62% rename from plugin/updater_oneapi/updater_quantile_hist_oneapi.h rename to 
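// With the plugin enabled, the updater registered above is reachable through
// the usual parameter surface; an illustrative C-API snippet (the helper and
// its BoosterHandle are assumptions, and the mapping to
// "grow_quantile_histmaker_sycl" happens in MapTreeMethodToUpdaters later in
// this patch):
#include <xgboost/c_api.h>

void UseSyclHist(BoosterHandle booster) {
  XGBoosterSetParam(booster, "device", "sycl");
  XGBoosterSetParam(booster, "tree_method", "hist");
}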
plugin/sycl/tree/updater_quantile_hist.h index 70e13c8c9e98..434f6bbdbedd 100644 --- a/plugin/updater_oneapi/updater_quantile_hist_oneapi.h +++ b/plugin/sycl/tree/updater_quantile_hist.h @@ -1,9 +1,9 @@ /*! * Copyright 2017-2021 by Contributors - * \file updater_quantile_hist_oneapi.h + * \file updater_quantile_hist.h */ -#ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_ONEAPI_H_ -#define XGBOOST_TREE_UPDATER_QUANTILE_HIST_ONEAPI_H_ +#ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_SYCL_H_ +#define XGBOOST_TREE_UPDATER_QUANTILE_HIST_SYCL_H_ #include #include @@ -11,10 +11,10 @@ #include -#include "hist_util_oneapi.h" -#include "row_set_oneapi.h" -#include "split_evaluator_oneapi.h" -#include "device_manager_oneapi.h" +#include "../common/hist_util.h" +#include "../common/row_set.h" +#include "split_evaluator.h" +#include "../device_manager.h" #include "xgboost/data.h" #include "xgboost/json.h" @@ -22,16 +22,16 @@ #include "../../src/common/random.h" namespace xgboost { - +namespace sycl { /*! * \brief A C-style array with in-stack allocation. As long as the array is smaller than MaxStackSize, it will be allocated inside the stack. Otherwise, it will be heap-allocated. Temporary copy of implementation to remove dependency on updater_quantile_hist.h */ template -class MemStackAllocatorOneAPI { +class MemStackAllocator { public: - explicit MemStackAllocatorOneAPI(size_t required_size): required_size_(required_size) { + explicit MemStackAllocator(size_t required_size): required_size_(required_size) { } T* Get() { @@ -47,7 +47,7 @@ class MemStackAllocatorOneAPI { return ptr_; } - ~MemStackAllocatorOneAPI() { + ~MemStackAllocator() { if (do_free_) free(ptr_); } @@ -61,122 +61,71 @@ class MemStackAllocatorOneAPI { namespace tree { -using xgboost::common::HistCollectionOneAPI; -using xgboost::common::GHistBuilderOneAPI; -using xgboost::common::GHistIndexMatrixOneAPI; -using xgboost::common::GHistRowOneAPI; -using xgboost::common::RowSetCollectionOneAPI; +using xgboost::sycl::common::HistCollection; +using xgboost::sycl::common::GHistBuilder; +using xgboost::sycl::common::GHistIndexMatrix; +using xgboost::sycl::common::GHistRow; +using xgboost::sycl::common::RowSetCollection; +using xgboost::sycl::common::PartitionBuilder; template -class HistSynchronizerOneAPI; +class HistSynchronizer; template -class BatchHistSynchronizerOneAPI; +class BatchHistSynchronizer; template -class DistributedHistSynchronizerOneAPI; +class DistributedHistSynchronizer; template -class HistRowsAdderOneAPI; +class HistRowsAdder; template -class BatchHistRowsAdderOneAPI; +class BatchHistRowsAdder; template -class DistributedHistRowsAdderOneAPI; +class DistributedHistRowsAdder; // training parameters specific to this algorithm -struct OneAPIHistMakerTrainParam - : public XGBoostParameter { +struct HistMakerTrainParam + : public XGBoostParameter { bool single_precision_histogram = false; // declare parameters - DMLC_DECLARE_PARAMETER(OneAPIHistMakerTrainParam) { + DMLC_DECLARE_PARAMETER(HistMakerTrainParam) { DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( "Use single precision to build histograms."); } }; -/*! 
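// Condensed, self-contained restatement of the MemStackAllocator idea above:
// a fixed-size stack buffer serves small requests, larger ones fall back to
// the heap. Alignment and error handling are omitted for brevity.
#include <cstddef>
#include <cstdlib>

template <typename T, std::size_t MaxStackSize>
class SmallBuffer {
 public:
  explicit SmallBuffer(std::size_t n)
      : ptr_(n <= MaxStackSize ? stack_mem_
                               : static_cast<T*>(std::malloc(n * sizeof(T)))) {}
  ~SmallBuffer() {
    if (ptr_ != stack_mem_) std::free(ptr_);
  }
  T* Get() { return ptr_; }

 private:
  T stack_mem_[MaxStackSize];
  T* ptr_;
};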
\brief construct a tree using quantized feature values with DPC++ interface */ -class QuantileHistMakerOneAPI: public TreeUpdater { - public: - explicit QuantileHistMakerOneAPI(Context const* ctx, ObjInfo const * task) : TreeUpdater(ctx), ctx_(ctx), task_{task} {} - void Configure(const Args& args) override; - - void Update(TrainParam const *param, - HostDeviceVector* gpair, - DMatrix* dmat, - common::Span> out_position, - const std::vector& trees) override; - - bool UpdatePredictionCache(const DMatrix* data, - linalg::MatrixView out_preds) override; - - void LoadConfig(Json const& in) override { - if (updater_backend_) { - updater_backend_->LoadConfig(in); - } else { - auto const& config = get(in); - FromJson(config.at("train_param"), &this->param_); - } - } - - void SaveConfig(Json* p_out) const override { - if (updater_backend_) { - updater_backend_->SaveConfig(p_out); - } else { - auto& out = *p_out; - out["train_param"] = ToJson(param_); - } - } - - char const* Name() const override { - if (updater_backend_) { - return updater_backend_->Name(); - } else { - return "grow_quantile_histmaker_oneapi"; - } - } - - protected: - // training parameter - TrainParam param_; - - DeviceManagerOneAPI device_manager; - - ObjInfo const *task_{nullptr}; - Context const* ctx_; - std::unique_ptr updater_backend_; -}; - // data structure template struct NodeEntry { /*! \brief statics for node entry */ - GradStatsOneAPI stats; + GradStats stats; /*! \brief loss of this node, without split */ GradType root_gain; /*! \brief weight calculated related to current data */ GradType weight; /*! \brief current best solution */ - SplitEntryOneAPI best; + SplitEntry best; // constructor - explicit NodeEntry(const TrainParam& param) + explicit NodeEntry(const xgboost::tree::TrainParam& param) : root_gain(0.0f), weight(0.0f) {} }; // actual builder that runs the algorithm -/*! \brief construct a tree using quantized feature values with DPC++ backend on GPU*/ -class QuantileHistMakerOneAPIBackend: public TreeUpdater { +/*! \brief construct a tree using quantized feature values with SYCL backend*/ +class QuantileHistMaker: public TreeUpdater { public: - explicit QuantileHistMakerOneAPIBackend(Context const* ctx, ObjInfo const * task) : TreeUpdater(ctx), ctx_(ctx), task_{task} { - updater_monitor_.Init("QuantileHistMakerOneAPIBackend"); + explicit QuantileHistMaker(Context const* ctx, ObjInfo const * task) : TreeUpdater(ctx), ctx_(ctx), task_{task} { + updater_monitor_.Init("SYCLQuantileHistMaker"); } void Configure(const Args& args) override; - void Update(TrainParam const *param, + void Update(xgboost::tree::TrainParam const *param, HostDeviceVector* gpair, DMatrix* dmat, - common::Span> out_position, + xgboost::common::Span> out_position, const std::vector& trees) override; bool UpdatePredictionCache(const DMatrix* data, @@ -186,7 +135,7 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { auto const& config = get(in); FromJson(config.at("train_param"), &this->param_); try { - FromJson(config.at("oneapi_hist_train_param"), &this->hist_maker_param_); + FromJson(config.at("sycl_hist_train_param"), &this->hist_maker_param_); } catch (std::out_of_range& e) { // XGBoost model is from 1.1.x, so 'cpu_hist_train_param' is missing. 
// We add this compatibility check because it's just recently that we (developers) began @@ -205,33 +154,33 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { void SaveConfig(Json* p_out) const override { auto& out = *p_out; out["train_param"] = ToJson(param_); - out["oneapi_hist_train_param"] = ToJson(hist_maker_param_); + out["sycl_hist_train_param"] = ToJson(hist_maker_param_); } char const* Name() const override { - return "grow_quantile_histmaker_oneapi_backend"; + return "grow_quantile_histmaker_sycl"; } protected: template - friend class HistSynchronizerOneAPI; + friend class HistSynchronizer; template - friend class BatchHistSynchronizerOneAPI; + friend class BatchHistSynchronizer; template - friend class DistributedHistSynchronizerOneAPI; + friend class DistributedHistSynchronizer; template - friend class HistRowsAdderOneAPI; + friend class HistRowsAdder; template - friend class BatchHistRowsAdderOneAPI; + friend class BatchHistRowsAdder; template - friend class DistributedHistRowsAdderOneAPI; + friend class DistributedHistRowsAdder; - OneAPIHistMakerTrainParam hist_maker_param_; + HistMakerTrainParam hist_maker_param_; // training parameter - TrainParam param_; + xgboost::tree::TrainParam param_; // quantized data matrix - GHistIndexMatrixOneAPI gmat_; + GHistIndexMatrix gmat_; // (optional) data matrix with feature grouping // column accessor DMatrix const* p_last_dmat_ {nullptr}; @@ -241,11 +190,11 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { struct Builder { public: template - using GHistRowT = GHistRowOneAPI; + using GHistRowT = GHistRow; using GradientPairT = xgboost::detail::GradientPairInternal; // constructor - explicit Builder(sycl::queue qu, - const TrainParam& param, + explicit Builder(::sycl::queue qu, + const xgboost::tree::TrainParam& param, std::unique_ptr pruner, FeatureInteractionConstraintHost int_constraints_, DMatrix const* fmat) @@ -255,25 +204,25 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { interaction_constraints_{std::move(int_constraints_)}, p_last_tree_(nullptr), p_last_fmat_(fmat), snode_(qu, 1u << (param.max_depth + 1), NodeEntry(param)) { - builder_monitor_.Init("QuantileOneAPI::Builder"); - kernel_monitor_.Init("QuantileOneAPI::Kernels"); + builder_monitor_.Init("SYCL::Quantile::Builder"); + kernel_monitor_.Init("SYCL::Quantile::Kernels"); } // update one tree, growing void Update(Context const * ctx, - TrainParam const *param, - const GHistIndexMatrixOneAPI &gmat, + xgboost::tree::TrainParam const *param, + const GHistIndexMatrix &gmat, HostDeviceVector *gpair, const USMVector& gpair_device, DMatrix *p_fmat, - common::Span> out_position, + xgboost::common::Span> out_position, RegTree *p_tree); - inline sycl::event BuildHist(const USMVector& gpair_device, - const RowSetCollectionOneAPI::Elem row_indices, - const GHistIndexMatrixOneAPI& gmat, + inline ::sycl::event BuildHist(const USMVector& gpair_device, + const RowSetCollection::Elem row_indices, + const GHistIndexMatrix& gmat, GHistRowT& hist, GHistRowT& hist_buffer, - sycl::event event_priv) { + ::sycl::event event_priv) { return hist_builder_.BuildHist(gpair_device, row_indices, gmat, hist, data_layout_ != kSparseData, hist_buffer, event_priv); } @@ -287,24 +236,24 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { bool UpdatePredictionCache(const DMatrix* data, linalg::MatrixView p_out_preds); - void SetHistSynchronizer(HistSynchronizerOneAPI* sync); - void SetHistRowsAdder(HistRowsAdderOneAPI* adder); + void 
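// Condensed restatement of the LoadConfig fallback above, with a string map
// standing in for the Json object: a model serialized before the
// "sycl_hist_train_param" key existed simply keeps the defaults instead of
// failing to load.
#include <map>
#include <stdexcept>
#include <string>

struct HistParamSketch { bool single_precision_histogram = false; };

HistParamSketch LoadHistParam(const std::map<std::string, std::string>& config) {
  HistParamSketch param;
  try {
    param.single_precision_histogram =
        config.at("single_precision_histogram") == "1";
  } catch (const std::out_of_range&) {
    // Older serialized model: the key is missing, keep the defaults.
  }
  return param;
}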
SetHistSynchronizer(HistSynchronizer* sync); + void SetHistRowsAdder(HistRowsAdder* adder); // initialize temp data structure void InitData(Context const * ctx, - const GHistIndexMatrixOneAPI& gmat, + const GHistIndexMatrix& gmat, const std::vector& gpair, const USMVector &gpair_device, const DMatrix& fmat, const RegTree& tree); protected: - friend class HistSynchronizerOneAPI; - friend class BatchHistSynchronizerOneAPI; - friend class DistributedHistSynchronizerOneAPI; - friend class HistRowsAdderOneAPI; - friend class BatchHistRowsAdderOneAPI; - friend class DistributedHistRowsAdderOneAPI; + friend class HistSynchronizer; + friend class BatchHistSynchronizer; + friend class DistributedHistSynchronizer; + friend class HistRowsAdder; + friend class BatchHistRowsAdder; + friend class DistributedHistRowsAdder; /* tree growing policies */ struct ExpandEntry { @@ -320,7 +269,7 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { : nid(nid), sibling_nid(sibling_nid), depth(depth), loss_chg(loss_chg), timestamp(tstmp) {} - bool IsValid(TrainParam const ¶m, int32_t num_leaves) const { + bool IsValid(xgboost::tree::TrainParam const ¶m, int32_t num_leaves) const { bool ret = loss_chg <= kRtEps || (param.max_depth > 0 && this->depth == param.max_depth) || (param.max_leaves > 0 && num_leaves == param.max_leaves); @@ -331,7 +280,7 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { struct SplitQuery { int nid; int fid; - SplitEntryOneAPI best; + SplitEntry best; const GradientPairT* hist; }; @@ -340,48 +289,48 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { const DMatrix& fmat, USMVector& row_indices); void EvaluateSplits(const std::vector& nodes_set, - const GHistIndexMatrixOneAPI& gmat, - const HistCollectionOneAPI& hist, + const GHistIndexMatrix& gmat, + const HistCollection& hist, const RegTree& tree); // Enumerate the split values of specific feature // Returns the sum of gradients corresponding to the data points that contains a non-missing // value for the particular feature fid. 
template - static GradStatsOneAPI EnumerateSplit( + static GradStats EnumerateSplit( const uint32_t* cut_ptr,const bst_float* cut_val, const bst_float* cut_minval, const GradientPairT* hist_data, - const NodeEntry &snode, SplitEntryOneAPI& p_best, bst_uint fid, + const NodeEntry &snode, SplitEntry& p_best, bst_uint fid, bst_uint nodeID, - typename TreeEvaluatorOneAPI::SplitEvaluator const &evaluator, const TrainParamOneAPI& param); + typename TreeEvaluator::SplitEvaluator const &evaluator, const TrainParam& param); - static GradStatsOneAPI EnumerateSplit(sycl::sub_group& sg, + static GradStats EnumerateSplit(::sycl::sub_group& sg, const uint32_t* cut_ptr, const bst_float* cut_val, const GradientPairT* hist_data, - const NodeEntry &snode, SplitEntryOneAPI& p_best, bst_uint fid, + const NodeEntry &snode, SplitEntry& p_best, bst_uint fid, bst_uint nodeID, - typename TreeEvaluatorOneAPI::SplitEvaluator const &evaluator, const TrainParamOneAPI& param); + typename TreeEvaluator::SplitEvaluator const &evaluator, const TrainParam& param); void ApplySplit(std::vector nodes, - const GHistIndexMatrixOneAPI& gmat, - const HistCollectionOneAPI& hist, + const GHistIndexMatrix& gmat, + const HistCollection& hist, RegTree* p_tree); template - sycl::event PartitionKernel(const size_t nid, + ::sycl::event PartitionKernel(const size_t nid, const int32_t split_cond, - const GHistIndexMatrixOneAPI &gmat, + const GHistIndexMatrix &gmat, const RegTree::Node& node, - common::Span& rid_buf, + xgboost::common::Span& rid_buf, size_t* parts_size, - sycl::event priv_event); + ::sycl::event priv_event); void AddSplitsToRowSet(const std::vector& nodes, RegTree* p_tree); void FindSplitConditions(const std::vector& nodes, const RegTree& tree, - const GHistIndexMatrixOneAPI& gmat, std::vector* split_conditions); + const GHistIndexMatrix& gmat, std::vector* split_conditions); void InitNewNode(int nid, - const GHistIndexMatrixOneAPI& gmat, + const GHistIndexMatrix& gmat, const std::vector& gpair, const DMatrix& fmat, const RegTree& tree); @@ -390,21 +339,21 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { // is equal to sum of statistics for all values: // then - there are no missing values // else - there are missing values - static bool SplitContainsMissingValues(const GradStatsOneAPI& e, const NodeEntry& snode); + static bool SplitContainsMissingValues(const GradStats& e, const NodeEntry& snode); - void ExpandWithDepthWise(const GHistIndexMatrixOneAPI &gmat, + void ExpandWithDepthWise(const GHistIndexMatrix &gmat, DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair, const USMVector &gpair_device); - void BuildLocalHistograms(const GHistIndexMatrixOneAPI &gmat, + void BuildLocalHistograms(const GHistIndexMatrix &gmat, RegTree *p_tree, const USMVector &gpair_device); void BuildHistogramsLossGuide( ExpandEntry entry, - const GHistIndexMatrixOneAPI &gmat, + const GHistIndexMatrix &gmat, RegTree *p_tree, const USMVector &gpair_device); @@ -416,16 +365,16 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { std::vector* big_siblings, RegTree *p_tree); - void ParallelSubtractionHist(const common::BlockedSpace2d& space, + void ParallelSubtractionHist(const xgboost::common::BlockedSpace2d& space, const std::vector& nodes, const RegTree * p_tree); - void BuildNodeStats(const GHistIndexMatrixOneAPI &gmat, + void BuildNodeStats(const GHistIndexMatrix &gmat, DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair); - void EvaluateAndApplySplits(const GHistIndexMatrixOneAPI &gmat, + void 
EvaluateAndApplySplits(const GHistIndexMatrix &gmat, RegTree *p_tree, int *num_leaves, int depth, @@ -433,7 +382,7 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { std::vector *temp_qexpand_depth); void AddSplitsToTree( - const GHistIndexMatrixOneAPI &gmat, + const GHistIndexMatrix &gmat, RegTree *p_tree, int *num_leaves, int depth, @@ -441,7 +390,7 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { std::vector* nodes_for_apply_split, std::vector* temp_qexpand_depth); - void ExpandWithLossGuide(const GHistIndexMatrixOneAPI& gmat, + void ExpandWithLossGuide(const GHistIndexMatrix& gmat, DMatrix* p_fmat, RegTree* p_tree, const std::vector &gpair, @@ -457,29 +406,29 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { } } // --data fields-- - const TrainParam& param_; + const xgboost::tree::TrainParam& param_; // number of omp thread used during training int nthread_; - common::ColumnSampler column_sampler_; + xgboost::common::ColumnSampler column_sampler_; // the internal row sets - RowSetCollectionOneAPI row_set_collection_; + RowSetCollection row_set_collection_; USMVector split_queries_device_; /*! \brief TreeNode Data: statistics for each constructed node */ USMVector> snode_; /*! \brief culmulative histogram of gradients. */ - HistCollectionOneAPI hist_; + HistCollection hist_; /*! \brief culmulative local parent histogram of gradients. */ - HistCollectionOneAPI hist_local_worker_; - TreeEvaluatorOneAPI tree_evaluator_; + HistCollection hist_local_worker_; + TreeEvaluator tree_evaluator_; /*! \brief feature with least # of bins. to be used for dense specialization of InitNewNode() */ uint32_t fid_least_bins_; - GHistBuilderOneAPI hist_builder_; + GHistBuilder hist_builder_; std::unique_ptr pruner_; FeatureInteractionConstraintHost interaction_constraints_; - common::PartitionBuilderOneAPI partition_builder_; + PartitionBuilder partition_builder_; // back pointers to tree and data matrix const RegTree* p_last_tree_; @@ -500,32 +449,32 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData }; DataLayout data_layout_; - common::Monitor builder_monitor_; - common::Monitor kernel_monitor_; + xgboost::common::Monitor builder_monitor_; + xgboost::common::Monitor kernel_monitor_; constexpr static size_t kNumParallelBuffers = 1; - std::array, kNumParallelBuffers> hist_buffers_; - std::array hist_build_events_; + std::array, kNumParallelBuffers> hist_buffers_; + std::array<::sycl::event, kNumParallelBuffers> hist_build_events_; USMVector parts_size_; std::vector parts_size_cpu_; - std::vector apply_split_events_; - std::vector merge_to_array_events_; + std::vector<::sycl::event> apply_split_events_; + std::vector<::sycl::event> merge_to_array_events_; // rabit::op::Reducer histred_; - std::unique_ptr> hist_synchronizer_; - std::unique_ptr> hist_rows_adder_; + std::unique_ptr> hist_synchronizer_; + std::unique_ptr> hist_rows_adder_; - sycl::queue qu_; + ::sycl::queue qu_; }; - common::Monitor updater_monitor_; + xgboost::common::Monitor updater_monitor_; template void SetBuilder(std::unique_ptr>*, DMatrix *dmat); template void CallBuilderUpdate(const std::unique_ptr>& builder, - TrainParam const *param, + xgboost::tree::TrainParam const *param, HostDeviceVector *gpair, DMatrix *dmat, - common::Span> out_position, + xgboost::common::Span> out_position, const std::vector &trees); protected: @@ -535,43 +484,43 @@ class QuantileHistMakerOneAPIBackend: public TreeUpdater { std::unique_ptr 
pruner_; FeatureInteractionConstraintHost int_constraint_; - sycl::queue qu_; - DeviceManagerOneAPI device_manager; + ::sycl::queue qu_; + DeviceManager device_manager; Context const* ctx_; ObjInfo const *task_{nullptr}; }; template -class HistSynchronizerOneAPI { +class HistSynchronizer { public: - using BuilderT = QuantileHistMakerOneAPIBackend::Builder; + using BuilderT = QuantileHistMaker::Builder; virtual void SyncHistograms(BuilderT* builder, std::vector& sync_ids, RegTree *p_tree) = 0; - virtual ~HistSynchronizerOneAPI() = default; + virtual ~HistSynchronizer() = default; }; template -class BatchHistSynchronizerOneAPI: public HistSynchronizerOneAPI { +class BatchHistSynchronizer: public HistSynchronizer { public: - using BuilderT = QuantileHistMakerOneAPIBackend::Builder; + using BuilderT = QuantileHistMaker::Builder; void SyncHistograms(BuilderT* builder, std::vector& sync_ids, RegTree *p_tree) override; - std::vector GetEvents() const { + std::vector<::sycl::event> GetEvents() const { return hist_sync_events_; } private: - std::vector hist_sync_events_; + std::vector<::sycl::event> hist_sync_events_; }; template -class DistributedHistSynchronizerOneAPI: public HistSynchronizerOneAPI { +class DistributedHistSynchronizer: public HistSynchronizer { public: - using BuilderT = QuantileHistMakerOneAPIBackend::Builder; + using BuilderT = QuantileHistMaker::Builder; using ExpandEntryT = typename BuilderT::ExpandEntry; void SyncHistograms(BuilderT* builder, std::vector& sync_ids, RegTree *p_tree) override; @@ -582,30 +531,31 @@ class DistributedHistSynchronizerOneAPI: public HistSynchronizerOneAPI -class HistRowsAdderOneAPI { +class HistRowsAdder { public: - using BuilderT = QuantileHistMakerOneAPIBackend::Builder; + using BuilderT = QuantileHistMaker::Builder; virtual void AddHistRows(BuilderT* builder, std::vector& sync_ids, RegTree *p_tree) = 0; - virtual ~HistRowsAdderOneAPI() = default; + virtual ~HistRowsAdder() = default; }; template -class BatchHistRowsAdderOneAPI: public HistRowsAdderOneAPI { +class BatchHistRowsAdder: public HistRowsAdder { public: - using BuilderT = QuantileHistMakerOneAPIBackend::Builder; + using BuilderT = QuantileHistMaker::Builder; void AddHistRows(BuilderT*, std::vector& sync_ids, RegTree *p_tree) override; }; template -class DistributedHistRowsAdderOneAPI: public HistRowsAdderOneAPI { +class DistributedHistRowsAdder: public HistRowsAdder { public: - using BuilderT = QuantileHistMakerOneAPIBackend::Builder; + using BuilderT = QuantileHistMaker::Builder; void AddHistRows(BuilderT*, std::vector& sync_ids, RegTree *p_tree) override; }; } // namespace tree +} // namespace sycl } // namespace xgboost -#endif // XGBOOST_TREE_UPDATER_QUANTILE_HIST_ONEAPI_H_ +#endif // XGBOOST_TREE_UPDATER_QUANTILE_HIST_SYCL_H_ diff --git a/plugin/updater_oneapi/device_manager_oneapi.h b/plugin/updater_oneapi/device_manager_oneapi.h deleted file mode 100644 index 92f02939ca1d..000000000000 --- a/plugin/updater_oneapi/device_manager_oneapi.h +++ /dev/null @@ -1,44 +0,0 @@ -/*! 
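// The synchronizer/adder hierarchy declared above follows a strategy pattern:
// the builder owns an abstract policy so single-node and distributed training
// can swap histogram-synchronization behavior. A minimal non-template
// analogue of that wiring:
#include <memory>

struct HistSyncPolicy {
  virtual void SyncHistograms() = 0;
  virtual ~HistSyncPolicy() = default;
};

struct BatchSyncSketch : HistSyncPolicy {
  void SyncHistograms() override { /* local parent-minus-sibling subtraction */ }
};

struct DistributedSyncSketch : HistSyncPolicy {
  void SyncHistograms() override { /* subtraction plus an allreduce across workers */ }
};

struct BuilderSketch {
  void SetHistSynchronizer(HistSyncPolicy* sync) { hist_synchronizer_.reset(sync); }
  std::unique_ptr<HistSyncPolicy> hist_synchronizer_;
};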
- * Copyright 2017-2022 by Contributors - * \file device_manager_oneapi.h - */ -#ifndef XGBOOST_DEVICE_MANAGER_ONEAPI_H_ -#define XGBOOST_DEVICE_MANAGER_ONEAPI_H_ - -#include -#include -#include - -#include "CL/sycl.hpp" -#include "xgboost/context.h" - -namespace xgboost { - -class DeviceManagerOneAPI { - public: - // DeviceManagerOneAPI(); - - sycl::queue GetQueue(const DeviceOrd& device_spec) const; - - sycl::device GetDevice(const DeviceOrd& device_spec) const; - - private: - using QueueRegister_t = std::unordered_map; - - struct DeviceRegister { - std::vector devices; - std::vector cpu_devices; - std::vector gpu_devices; - }; - - QueueRegister_t& GetQueueRegister() const; - - DeviceRegister& GetDevicesRegister() const; - - mutable std::mutex queue_registering_mutex; - mutable std::mutex device_registering_mutex; -}; - -} // namespace xgboost - -#endif // XGBOOST_DEVICE_MANAGER_ONEAPI_H_ \ No newline at end of file diff --git a/plugin/updater_oneapi/regression_obj_oneapi.cc b/plugin/updater_oneapi/regression_obj_oneapi.cc deleted file mode 100755 index 3c157a80e797..000000000000 --- a/plugin/updater_oneapi/regression_obj_oneapi.cc +++ /dev/null @@ -1,193 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "xgboost/host_device_vector.h" -#include "xgboost/json.h" -#include "xgboost/parameter.h" -#include "xgboost/span.h" - -#include "../../src/common/transform.h" -#include "../../src/common/common.h" -#include "regression_loss_oneapi.h" -#include "device_manager_oneapi.h" - -#include "CL/sycl.hpp" - -namespace xgboost { -namespace obj { - -DMLC_REGISTRY_FILE_TAG(regression_obj_oneapi); - -struct RegLossParamOneAPI : public XGBoostParameter { - float scale_pos_weight; - // declare parameters - DMLC_DECLARE_PARAMETER(RegLossParamOneAPI) { - DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f) - .describe("Scale the weight of positive examples by this factor"); - } -}; - -template -class RegLossObjOneAPI : public ObjFunction { - protected: - HostDeviceVector label_correct_; - - public: - RegLossObjOneAPI() = default; - - void Configure(const std::vector >& args) override { - param_.UpdateAllowUnknown(args); - qu_ = device_manager.GetQueue(ctx_->Device()); - } - - void GetGradient(const HostDeviceVector& preds, - const MetaInfo &info, - int iter, - HostDeviceVector* out_gpair) override { - if (info.labels.Size() == 0U) { - LOG(WARNING) << "Label set is empty."; - } - CHECK_EQ(preds.Size(), info.labels.Size()) - << " " << "labels are not correctly provided" - << "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", " - << "Loss: " << Loss::Name(); - - size_t const ndata = preds.Size(); - out_gpair->Resize(ndata); - - // TODO: add label_correct check - label_correct_.Resize(1); - label_correct_.Fill(1); - - bool is_null_weight = info.weights_.Size() == 0; - - sycl::buffer preds_buf(preds.HostPointer(), preds.Size()); - sycl::buffer labels_buf(info.labels.Data()->HostPointer(), info.labels.Size()); - sycl::buffer out_gpair_buf(out_gpair->HostPointer(), out_gpair->Size()); - sycl::buffer weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(), - is_null_weight ? 
1 : info.weights_.Size()); - - const size_t n_targets = std::max(info.labels.Shape(1), static_cast(1)); - - sycl::buffer additional_input_buf(1); - { - auto additional_input_acc = additional_input_buf.get_access(); - additional_input_acc[0] = 1; // Fill the label_correct flag - } - - auto scale_pos_weight = param_.scale_pos_weight; - if (!is_null_weight) { - CHECK_EQ(info.weights_.Size(), info.labels.Shape(0)) - << "Number of weights should be equal to number of data points."; - } - - qu_.submit([&](sycl::handler& cgh) { - auto preds_acc = preds_buf.get_access(cgh); - auto labels_acc = labels_buf.get_access(cgh); - auto weights_acc = weights_buf.get_access(cgh); - auto out_gpair_acc = out_gpair_buf.get_access(cgh); - auto additional_input_acc = additional_input_buf.get_access(cgh); - cgh.parallel_for<>(sycl::range<1>(ndata), [=](sycl::id<1> pid) { - int idx = pid[0]; - bst_float p = Loss::PredTransform(preds_acc[idx]); - bst_float w = is_null_weight ? 1.0f : weights_acc[idx/n_targets]; - bst_float label = labels_acc[idx]; - if (label == 1.0f) { - w *= scale_pos_weight; - } - if (!Loss::CheckLabel(label)) { - // If there is an incorrect label, the host code will know. - additional_input_acc[0] = 0; - } - out_gpair_acc[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w, - Loss::SecondOrderGradient(p, label) * w); - }); - }).wait(); - - int flag = 1; - { - auto additional_input_acc = additional_input_buf.get_access(); - flag = additional_input_acc[0]; - } - - if (flag == 0) { - LOG(FATAL) << Loss::LabelErrorMsg(); - } - - } - - public: - const char* DefaultEvalMetric() const override { - return Loss::DefaultEvalMetric(); - } - - void PredTransform(HostDeviceVector *io_preds) const override { - size_t const ndata = io_preds->Size(); - sycl::buffer io_preds_buf(io_preds->HostPointer(), io_preds->Size()); - - qu_.submit([&](sycl::handler& cgh) { - auto io_preds_acc = io_preds_buf.get_access(cgh); - cgh.parallel_for<>(sycl::range<1>(ndata), [=](sycl::id<1> pid) { - int idx = pid[0]; - io_preds_acc[idx] = Loss::PredTransform(io_preds_acc[idx]); - }); - }).wait(); - } - - float ProbToMargin(float base_score) const override { - return Loss::ProbToMargin(base_score); - } - - struct ObjInfo Task() const override { - return Loss::Info(); - }; - - uint32_t Targets(MetaInfo const& info) const override { - // Multi-target regression. 
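// Per-element math of the GetGradient kernel above, restated for the
// logistic loss (the SYCL rewrite keeps the same formulas): transform the
// margin, scale the weight for positive labels, then emit first- and
// second-order gradients, with the hessian clamped away from zero.
#include <algorithm>
#include <cmath>
#include <utility>

std::pair<float, float> LogisticGPair(float margin, float label, float weight,
                                      float scale_pos_weight) {
  const float p = 1.0f / (1.0f + std::exp(-margin));        // PredTransform
  const float w = (label == 1.0f) ? weight * scale_pos_weight : weight;
  const float grad = (p - label) * w;                        // first order
  const float hess = std::max(p * (1.0f - p), 1e-16f) * w;   // second order
  return {grad, hess};
}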
- return std::max(static_cast(1), info.labels.Shape(1)); - } - - void SaveConfig(Json* p_out) const override { - auto& out = *p_out; - out["name"] = String(Loss::Name()); - out["reg_loss_param"] = ToJson(param_); - } - - void LoadConfig(Json const& in) override { - FromJson(in["reg_loss_param"], ¶m_); - } - - protected: - RegLossParamOneAPI param_; - DeviceManagerOneAPI device_manager; - - mutable sycl::queue qu_; -}; - -// register the objective functions -DMLC_REGISTER_PARAMETER(RegLossParamOneAPI); - -// TODO: Find a better way to dispatch names of DPC++ kernels with various template parameters of loss function -XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegressionOneAPI, LinearSquareLossOneAPI::Name()) -.describe("Regression with squared error with DPC++ backend.") -.set_body([]() { return new RegLossObjOneAPI(); }); -XGBOOST_REGISTER_OBJECTIVE(SquareLogErrorOneAPI, SquaredLogErrorOneAPI::Name()) -.describe("Regression with root mean squared logarithmic error with DPC++ backend.") -.set_body([]() { return new RegLossObjOneAPI(); }); -XGBOOST_REGISTER_OBJECTIVE(LogisticRegressionOneAPI, LogisticRegressionOneAPI::Name()) -.describe("Logistic regression for probability regression task with DPC++ backend.") -.set_body([]() { return new RegLossObjOneAPI(); }); -XGBOOST_REGISTER_OBJECTIVE(LogisticClassificationOneAPI, LogisticClassificationOneAPI::Name()) -.describe("Logistic regression for binary classification task with DPC++ backend.") -.set_body([]() { return new RegLossObjOneAPI(); }); -XGBOOST_REGISTER_OBJECTIVE(LogisticRawOneAPI, LogisticRawOneAPI::Name()) -.describe("Logistic regression for classification, output score " - "before logistic transformation with DPC++ backend.") -.set_body([]() { return new RegLossObjOneAPI(); }); - -} // namespace obj -} // namespace xgboost diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f6ffc795f2ca..ce62c6fda840 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -16,9 +16,9 @@ if (USE_CUDA) target_sources(objxgboost PRIVATE ${CUDA_SOURCES}) endif (USE_CUDA) -if (PLUGIN_UPDATER_ONEAPI) - target_compile_definitions(objxgboost PRIVATE -DXGBOOST_USE_ONEAPI=1) -endif (PLUGIN_UPDATER_ONEAPI) +if (PLUGIN_SYCL) + target_compile_definitions(objxgboost PRIVATE -DXGBOOST_USE_SYCL=1) +endif (PLUGIN_SYCL) target_include_directories(objxgboost PRIVATE diff --git a/src/common/common.h b/src/common/common.h index bedff80b33d5..a92f6acf52c9 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -164,10 +164,10 @@ inline void AssertGPUSupport() { #endif // XGBOOST_USE_CUDA } -inline void AssertOneAPISupport() { -#ifndef XGBOOST_USE_ONEAPI - LOG(FATAL) << "XGBoost version not compiled with OneAPI support."; -#endif // XGBOOST_USE_ONEAPI +inline void AssertSYCLSupport() { +#ifndef XGBOOST_USE_SYCL + LOG(FATAL) << "XGBoost version not compiled with SYCL support."; +#endif // XGBOOST_USE_SYCL } void SetDevice(std::int32_t device); diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index cd1567dc0d7f..2df7540719ec 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -53,7 +53,7 @@ std::string MapTreeMethodToUpdaters(Context const* ctx, TreeMethod tree_method) case TreeMethod::kHist: { return ctx->DispatchDevice([] { return "grow_quantile_histmaker"; }, [] { return "grow_gpu_hist"; }, - [] { return "grow_quantile_histmaker_oneapi"; }); + [] { return "grow_quantile_histmaker_sycl"; }); } case TreeMethod::kApprox: { return ctx->DispatchDevice([] { return "grow_histmaker"; }, [] { return "grow_gpu_approx"; }); @@ -114,13 +114,13 @@ void 
GBTree::Configure(Args const& cfg) { } #endif // defined(XGBOOST_USE_CUDA) -#if defined(XGBOOST_USE_ONEAPI) - if (!oneapi_predictor_) { - oneapi_predictor_ = - std::unique_ptr(Predictor::Create("oneapi_predictor", this->ctx_)); +#if defined(XGBOOST_USE_SYCL) + if (!sycl_predictor_) { + sycl_predictor_ = + std::unique_ptr(Predictor::Create("sycl_predictor", this->ctx_)); } - oneapi_predictor_->Configure(cfg); -#endif // defined(XGBOOST_USE_ONEAPI) + sycl_predictor_->Configure(cfg); +#endif // defined(XGBOOST_USE_SYCL) // `updater` parameter was manually specified specified_updater_ = @@ -571,11 +571,11 @@ void GBTree::InplacePredict(std::shared_ptr p_m, float missing, CHECK(gpu_predictor_); return gpu_predictor_; } else { -#if defined(XGBOOST_USE_ONEAPI) - common::AssertOneAPISupport(); - CHECK(oneapi_predictor_); - return oneapi_predictor_; -#endif // defined(XGBOOST_USE_ONEAPI) +#if defined(XGBOOST_USE_SYCL) + common::AssertSYCLSupport(); + CHECK(sycl_predictor_); + return sycl_predictor_; +#endif // defined(XGBOOST_USE_SYCL) } } @@ -616,11 +616,11 @@ void GBTree::InplacePredict(std::shared_ptr p_m, float missing, CHECK(gpu_predictor_); return gpu_predictor_; } else { -#if defined(XGBOOST_USE_ONEAPI) - common::AssertOneAPISupport(); - CHECK(oneapi_predictor_); - return oneapi_predictor_; -#endif // defined(XGBOOST_USE_ONEAPI) +#if defined(XGBOOST_USE_SYCL) + common::AssertSYCLSupport(); + CHECK(sycl_predictor_); + return sycl_predictor_; +#endif // defined(XGBOOST_USE_SYCL) } return cpu_predictor_; diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 81e568368024..1ed40951038f 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -349,9 +349,9 @@ class GBTree : public GradientBooster { // Predictors std::unique_ptr cpu_predictor_; std::unique_ptr gpu_predictor_{nullptr}; -#if defined(XGBOOST_USE_ONEAPI) - std::unique_ptr oneapi_predictor_; -#endif // defined(XGBOOST_USE_ONEAPI) +#if defined(XGBOOST_USE_SYCL) + std::unique_ptr sycl_predictor_; +#endif // defined(XGBOOST_USE_SYCL) common::Monitor monitor_; }; diff --git a/src/learner.cc b/src/learner.cc index 4ba31544ba1e..b03d13f4ee5f 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -780,7 +780,7 @@ class LearnerConfiguration : public Learner { // Once binary IO is gone, NONE of these config is useful. if (cfg_.find("num_class") != cfg_.cend() && cfg_.at("num_class") != "0" && (tparam_.objective != "multi:softprob") && - (tparam_.objective != "multi:softprob_oneapi")) { + (tparam_.objective != "multi:softprob_sycl")) { cfg_["num_output_group"] = cfg_["num_class"]; if (atoi(cfg_["num_class"].c_str()) > 1 && cfg_.count("objective") == 0) { tparam_.objective = "multi:softmax"; diff --git a/src/objective/objective.cc b/src/objective/objective.cc index 4dee8abe410e..e526e67b2048 100644 --- a/src/objective/objective.cc +++ b/src/objective/objective.cc @@ -34,7 +34,7 @@ ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx) { // Return sycl specific implementation name if possible. 
std::string ObjFunction::GetSyclImplementationName(const std::string& name) { - const std::string sycl_postfix = "_oneapi"; + const std::string sycl_postfix = "_sycl"; auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name + sycl_postfix); if (e != nullptr) { // Function has specific sycl implementation diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index f268d219b77b..2294ab75d4f3 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -13,10 +13,10 @@ if (USE_CUDA) list(APPEND TEST_SOURCES ${CUDA_TEST_SOURCES}) endif (USE_CUDA) -file(GLOB_RECURSE ONEAPI_TEST_SOURCES "plugin/test_sycl_*.cc") -if (NOT PLUGIN_UPDATER_ONEAPI) - list(REMOVE_ITEM TEST_SOURCES ${ONEAPI_TEST_SOURCES}) -endif (NOT PLUGIN_UPDATER_ONEAPI) +file(GLOB_RECURSE SYCL_TEST_SOURCES "plugin/test_sycl_*.cc") +if (NOT PLUGIN_SYCL) + list(REMOVE_ITEM TEST_SOURCES ${SYCL_TEST_SOURCES}) +endif (NOT PLUGIN_SYCL) if (PLUGIN_FEDERATED) target_include_directories(testxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/federated) diff --git a/tests/cpp/objective/test_multiclass_obj.cc b/tests/cpp/objective/test_multiclass_obj.cc index d028ef9cfa18..7ed50ed883d1 100644 --- a/tests/cpp/objective/test_multiclass_obj.cc +++ b/tests/cpp/objective/test_multiclass_obj.cc @@ -8,15 +8,19 @@ namespace xgboost { -TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) { - Context ctx = MakeCUDACtx(GPUIDX); +void TestSoftmaxMultiClassObjGPair(const Context* ctx) { + std::string obj_name = "multi:softmax"; + if (ctx->IsSycl()) { + obj_name += "_sycl"; + } + std::vector> args {{"num_class", "3"}}; std::unique_ptr obj { - ObjFunction::Create("multi:softmax", &ctx) + ObjFunction::Create(obj_name, ctx) }; obj->Configure(args); - CheckConfigReload(obj, "multi:softmax"); + CheckConfigReload(obj, obj_name); CheckObjFunction(obj, {1.0f, 0.0f, 2.0f, 2.0f, 0.0f, 1.0f}, // preds @@ -35,14 +39,18 @@ TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) { ASSERT_NO_THROW(obj->DefaultEvalMetric()); } -TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassBasic)) { - auto ctx = MakeCUDACtx(GPUIDX); +void TestSoftmaxMultiClassBasic(const Context* ctx) { + std::string obj_name = "multi:softmax"; + if (ctx->IsSycl()) { + obj_name += "_sycl"; + } + std::vector> args{ std::pair("num_class", "3")}; - std::unique_ptr obj{ObjFunction::Create("multi:softmax", &ctx)}; + std::unique_ptr obj{ObjFunction::Create(obj_name, ctx)}; obj->Configure(args); - CheckConfigReload(obj, "multi:softmax"); + CheckConfigReload(obj, obj_name); HostDeviceVector io_preds = {2.0f, 0.0f, 1.0f, 1.0f, 0.0f, 2.0f}; @@ -56,16 +64,20 @@ TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassBasic)) { } } -TEST(Objective, DeclareUnifiedTest(SoftprobMultiClassBasic)) { - Context ctx = MakeCUDACtx(GPUIDX); +void TestSoftprobMultiClassBasic(const Context* ctx) { + std::string obj_name = "multi:softprob"; + if (ctx->IsSycl()) { + obj_name += "_sycl"; + } + std::vector> args { std::pair("num_class", "3")}; std::unique_ptr obj { - ObjFunction::Create("multi:softprob", &ctx) + ObjFunction::Create(obj_name, ctx) }; obj->Configure(args); - CheckConfigReload(obj, "multi:softprob"); + CheckConfigReload(obj, obj_name); HostDeviceVector io_preds = {2.0f, 0.0f, 1.0f}; std::vector out_preds = {0.66524096f, 0.09003057f, 0.24472847f}; @@ -77,4 +89,20 @@ TEST(Objective, DeclareUnifiedTest(SoftprobMultiClassBasic)) { EXPECT_NEAR(preds[i], out_preds[i], 0.01f); } } + +TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) { + Context ctx = 
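// The renamed dispatch rule in GetSyclImplementationName, restated: an
// objective resolves to its "<name>_sycl" registration when one exists,
// otherwise the generic name is kept. The registry is stubbed with a set
// here for illustration.
#include <set>
#include <string>

std::string SyclImplName(const std::string& name,
                         const std::set<std::string>& registry) {
  const std::string sycl_postfix = "_sycl";
  if (registry.count(name + sycl_postfix) != 0) {
    return name + sycl_postfix;  // objective has a SYCL-specific implementation
  }
  return name;
}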
MakeCUDACtx(GPUIDX); + TestSoftmaxMultiClassObjGPair(&ctx); +} + + +TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassBasic)) { + auto ctx = MakeCUDACtx(GPUIDX); + TestSoftmaxMultiClassBasic(&ctx); +} + +TEST(Objective, DeclareUnifiedTest(SoftprobMultiClassBasic)) { + Context ctx = MakeCUDACtx(GPUIDX); + TestSoftprobMultiClassBasic(&ctx); +} } // namespace xgboost diff --git a/tests/cpp/objective/test_multiclass_obj.h b/tests/cpp/objective/test_multiclass_obj.h new file mode 100644 index 000000000000..cf34f718ebeb --- /dev/null +++ b/tests/cpp/objective/test_multiclass_obj.h @@ -0,0 +1,19 @@ +/** + * Copyright 2020-2023 by XGBoost Contributors + */ +#ifndef XGBOOST_TEST_MULTICLASS_OBJ_H_ +#define XGBOOST_TEST_MULTICLASS_OBJ_H_ + +#include // for Context + +namespace xgboost { + +void TestSoftmaxMultiClassObjGPair(const Context* ctx); + +void TestSoftmaxMultiClassBasic(const Context* ctx); + +void TestSoftprobMultiClassBasic(const Context* ctx); + +} // namespace xgboost + +#endif // XGBOOST_TEST_MULTICLASS_OBJ_H_ \ No newline at end of file diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index b8a40603b348..ceda723d0907 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -16,11 +16,14 @@ namespace xgboost { -TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) { - Context ctx = MakeCUDACtx(GPUIDX); - std::vector> args; +void TestLinearRegressionGPair(const Context* ctx) { + std::string obj_name = "reg:squarederror"; + if (ctx->IsSycl()) { + obj_name += "_sycl"; + } - std::unique_ptr obj{ObjFunction::Create("reg:squarederror", &ctx)}; + std::vector> args; + std::unique_ptr obj{ObjFunction::Create(obj_name, ctx)}; obj->Configure(args); CheckObjFunction(obj, @@ -35,16 +38,19 @@ TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) { {}, // empty weight {0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0}, {1, 1, 1, 1, 1, 1, 1, 1}); - ASSERT_NO_THROW(obj->DefaultEvalMetric()); + ASSERT_NO_THROW(obj->DefaultEvalMetric()); } -TEST(Objective, DeclareUnifiedTest(SquaredLog)) { - Context ctx = MakeCUDACtx(GPUIDX); +void TestSquaredLog(const Context* ctx) { + std::string obj_name = "reg:squaredlogerror"; + if (ctx->IsSycl()) { + obj_name += "_sycl"; + } std::vector> args; - std::unique_ptr obj{ObjFunction::Create("reg:squaredlogerror", &ctx)}; + std::unique_ptr obj{ObjFunction::Create(obj_name, ctx)}; obj->Configure(args); - CheckConfigReload(obj, "reg:squaredlogerror"); + CheckConfigReload(obj, obj_name); CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred @@ -61,42 +67,16 @@ TEST(Objective, DeclareUnifiedTest(SquaredLog)) { ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"rmsle"}); } -TEST(Objective, DeclareUnifiedTest(PseudoHuber)) { - Context ctx = MakeCUDACtx(GPUIDX); - Args args; - - std::unique_ptr obj{ObjFunction::Create("reg:pseudohubererror", &ctx)}; - obj->Configure(args); - CheckConfigReload(obj, "reg:pseudohubererror"); - - CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // labels - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // weights - {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f}, // out_grad - {0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f}); // out_hess - CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // labels - {}, // empty weights - {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f}, // out_grad - {0.410660f, 0.476140f, 0.630510f, 
diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc
index b8a40603b348..ceda723d0907 100644
--- a/tests/cpp/objective/test_regression_obj.cc
+++ b/tests/cpp/objective/test_regression_obj.cc
@@ -16,11 +16,14 @@ namespace xgboost {
 
-TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) {
-  Context ctx = MakeCUDACtx(GPUIDX);
-  std::vector<std::pair<std::string, std::string>> args;
+void TestLinearRegressionGPair(const Context* ctx) {
+  std::string obj_name = "reg:squarederror";
+  if (ctx->IsSycl()) {
+    obj_name += "_sycl";
+  }
 
-  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:squarederror", &ctx)};
+  std::vector<std::pair<std::string, std::string>> args;
+  std::unique_ptr<ObjFunction> obj{ObjFunction::Create(obj_name, ctx)};
 
   obj->Configure(args);
   CheckObjFunction(obj,
@@ -35,16 +38,19 @@ TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) {
                    {},  // empty weight
                    {0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0},
                    {1, 1, 1, 1, 1, 1, 1, 1});
-  ASSERT_NO_THROW(obj->DefaultEvalMetric());
+  ASSERT_NO_THROW(obj->DefaultEvalMetric());
 }
 
-TEST(Objective, DeclareUnifiedTest(SquaredLog)) {
-  Context ctx = MakeCUDACtx(GPUIDX);
+void TestSquaredLog(const Context* ctx) {
+  std::string obj_name = "reg:squaredlogerror";
+  if (ctx->IsSycl()) {
+    obj_name += "_sycl";
+  }
 
   std::vector<std::pair<std::string, std::string>> args;
-  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:squaredlogerror", &ctx)};
+  std::unique_ptr<ObjFunction> obj{ObjFunction::Create(obj_name, ctx)};
   obj->Configure(args);
-  CheckConfigReload(obj, "reg:squaredlogerror");
+  CheckConfigReload(obj, obj_name);
 
   CheckObjFunction(obj,
                    {0.1f, 0.2f, 0.4f, 0.8f, 1.6f},  // pred
@@ -61,42 +67,16 @@ TEST(Objective, DeclareUnifiedTest(SquaredLog)) {
   ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"rmsle"});
 }
 
-TEST(Objective, DeclareUnifiedTest(PseudoHuber)) {
-  Context ctx = MakeCUDACtx(GPUIDX);
-  Args args;
-
-  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:pseudohubererror", &ctx)};
-  obj->Configure(args);
-  CheckConfigReload(obj, "reg:pseudohubererror");
-
-  CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f},                          // pred
-                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // labels
-                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // weights
-                   {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f},  // out_grad
-                   {0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f});    // out_hess
-  CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f},                          // pred
-                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // labels
-                   {},                                                           // empty weights
-                   {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f},  // out_grad
-                   {0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f});    // out_hess
-  ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"mphe"});
-
-  obj->Configure({{"huber_slope", "0.1"}});
-  CheckConfigReload(obj, "reg:pseudohubererror");
-  CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f},                          // pred
-                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // labels
-                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // weights
-                   {-0.099388f, -0.099228f, -0.098639f, -0.089443f, 0.098639f},  // out_grad
-                   {0.0013467f, 0.001908f, 0.004443f, 0.089443f, 0.004443f});    // out_hess
-}
-
-TEST(Objective, DeclareUnifiedTest(LogisticRegressionGPair)) {
-  Context ctx = MakeCUDACtx(GPUIDX);
+void TestLogisticRegressionGPair(const Context* ctx) {
+  std::string obj_name = "reg:logistic";
+  if (ctx->IsSycl()) {
+    obj_name += "_sycl";
+  }
 
   std::vector<std::pair<std::string, std::string>> args;
-  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:logistic", &ctx)};
+  std::unique_ptr<ObjFunction> obj{ObjFunction::Create(obj_name, ctx)};
   obj->Configure(args);
-  CheckConfigReload(obj, "reg:logistic");
+  CheckConfigReload(obj, obj_name);
 
   CheckObjFunction(obj,
                    { 0, 0.1f, 0.9f,    1,    0,  0.1f,  0.9f,     1}, // preds
@@ -106,13 +86,17 @@ TEST(Objective, DeclareUnifiedTest(LogisticRegressionGPair)) {
                    {0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f}); // out_hess
 }
 
-TEST(Objective, DeclareUnifiedTest(LogisticRegressionBasic)) {
-  Context ctx = MakeCUDACtx(GPUIDX);
+void TestLogisticRegressionBasic(const Context* ctx) {
+  std::string obj_name = "reg:logistic";
+  if (ctx->IsSycl()) {
+    obj_name += "_sycl";
+  }
+
   std::vector<std::pair<std::string, std::string>> args;
-  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:logistic", &ctx)};
+  std::unique_ptr<ObjFunction> obj{ObjFunction::Create(obj_name, ctx)};
   obj->Configure(args);
-  CheckConfigReload(obj, "reg:logistic");
+  CheckConfigReload(obj, obj_name);
 
   // test label validation
   EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {10}, {1}, {0}, {0}))
@@ -135,11 +119,15 @@ TEST(Objective, DeclareUnifiedTest(LogisticRegressionBasic)) {
   }
 }
 
-TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) {
-  Context ctx = MakeCUDACtx(GPUIDX);
+void TestLogisticRawGPair(const Context* ctx) {
+  std::string obj_name = "binary:logitraw";
+  if (ctx->IsSycl()) {
+    obj_name += "_sycl";
+  }
+
   std::vector<std::pair<std::string, std::string>> args;
-  std::unique_ptr<ObjFunction> obj {
-    ObjFunction::Create("binary:logitraw", &ctx)
+  std::unique_ptr<ObjFunction> obj {
+    ObjFunction::Create(obj_name, ctx)
   };
   obj->Configure(args);
@@ -151,6 +139,60 @@ TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) {
                    {0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f});
 }
 
+TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) {
+  Context ctx = MakeCUDACtx(GPUIDX);
+  TestLinearRegressionGPair(&ctx);
+}
+
+TEST(Objective, DeclareUnifiedTest(SquaredLog)) {
+  Context ctx = MakeCUDACtx(GPUIDX);
+  TestSquaredLog(&ctx);
+}
+
+TEST(Objective, DeclareUnifiedTest(PseudoHuber)) {
+  Context ctx = MakeCUDACtx(GPUIDX);
+  Args args;
+
+  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:pseudohubererror", &ctx)};
+  obj->Configure(args);
+  CheckConfigReload(obj, "reg:pseudohubererror");
+
+  CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f},                          // pred
+                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // labels
+                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // weights
+                   {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f},  // out_grad
+                   {0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f});    // out_hess
+  CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f},                          // pred
+                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // labels
+                   {},                                                           // empty weights
+                   {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f},  // out_grad
+                   {0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f});    // out_hess
+  ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"mphe"});
+
+  obj->Configure({{"huber_slope", "0.1"}});
+  CheckConfigReload(obj, "reg:pseudohubererror");
+  CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f},                          // pred
+                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // labels
+                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},                               // weights
+                   {-0.099388f, -0.099228f, -0.098639f, -0.089443f, 0.098639f},  // out_grad
+                   {0.0013467f, 0.001908f, 0.004443f, 0.089443f, 0.004443f});    // out_hess
+}
+
+TEST(Objective, DeclareUnifiedTest(LogisticRegressionGPair)) {
+  Context ctx = MakeCUDACtx(GPUIDX);
+  TestLogisticRegressionGPair(&ctx);
+}
+
+TEST(Objective, DeclareUnifiedTest(LogisticRegressionBasic)) {
+  Context ctx = MakeCUDACtx(GPUIDX);
+  TestLogisticRegressionBasic(&ctx);
+}
+
+TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) {
+  Context ctx = MakeCUDACtx(GPUIDX);
+  TestLogisticRawGPair(&ctx);
+}
+
 TEST(Objective, DeclareUnifiedTest(PoissonRegressionGPair)) {
   Context ctx = MakeCUDACtx(GPUIDX);
   std::vector<std::pair<std::string, std::string>> args;
diff --git a/tests/cpp/objective/test_regression_obj.h b/tests/cpp/objective/test_regression_obj.h
new file mode 100644
index 000000000000..13056c4bbfc1
--- /dev/null
+++ b/tests/cpp/objective/test_regression_obj.h
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020-2023 by XGBoost Contributors
+ */
+#ifndef XGBOOST_TEST_REGRESSION_OBJ_H_
+#define XGBOOST_TEST_REGRESSION_OBJ_H_
+
+#include <xgboost/context.h>  // for Context
+
+namespace xgboost {
+
+void TestLinearRegressionGPair(const Context* ctx);
+
+void TestSquaredLog(const Context* ctx);
+
+void TestLogisticRegressionGPair(const Context* ctx);
+
+void TestLogisticRegressionBasic(const Context* ctx);
+
+void TestLogisticRawGPair(const Context* ctx);
+
+}  // namespace xgboost
+
+#endif  // XGBOOST_TEST_REGRESSION_OBJ_H_
\ No newline at end of file
diff --git a/tests/cpp/plugin/test_sycl_multiclass_obj.cc b/tests/cpp/plugin/test_sycl_multiclass_obj.cc
new file mode 100644
index 000000000000..fadfc6d41c96
--- /dev/null
+++ b/tests/cpp/plugin/test_sycl_multiclass_obj.cc
@@ -0,0 +1,28 @@
+/*!
+ * Copyright 2018-2019 XGBoost contributors
+ */
+#include <gtest/gtest.h>
+#include <xgboost/context.h>
+
+#include "../objective/test_multiclass_obj.h"
+
+namespace xgboost {
+
+TEST(SyclObjective, SoftmaxMultiClassObjGPair) {
+  Context ctx;
+  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
+  TestSoftmaxMultiClassObjGPair(&ctx);
+}
+
+TEST(SyclObjective, SoftmaxMultiClassBasic) {
+  Context ctx;
+  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
+  TestSoftmaxMultiClassBasic(&ctx);
+}
+
+TEST(SyclObjective, SoftprobMultiClassBasic) {
+  Context ctx;
+  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
+  TestSoftprobMultiClassBasic(&ctx);
+}
+}  // namespace xgboost
diff --git a/tests/cpp/plugin/test_sycl_regression_obj.cc b/tests/cpp/plugin/test_sycl_regression_obj.cc
index 0b5b6bf20776..00041395f46f 100755
--- a/tests/cpp/plugin/test_sycl_regression_obj.cc
+++ b/tests/cpp/plugin/test_sycl_regression_obj.cc
@@ -4,131 +4,48 @@
 #include <gtest/gtest.h>
 #include <xgboost/context.h>
 #include <xgboost/objective.h>
-#include <string>
+
 #include "../helpers.h"
+#include "../objective/test_regression_obj.h"
+
 namespace xgboost {
 
 TEST(SyclObjective, LinearRegressionGPair) {
   Context ctx;
   ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
-  std::vector<std::pair<std::string, std::string>> args;
-
-  std::unique_ptr<ObjFunction> obj {
-    ObjFunction::Create("reg:squarederror_oneapi", &ctx)
-  };
-
-  obj->Configure(args);
-  CheckObjFunction(obj,
-                   {0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
-                   {0, 0, 0, 0, 1, 1, 1, 1},
-                   {1, 1, 1, 1, 1, 1, 1, 1},
-                   {0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0},
-                   {1, 1, 1, 1, 1, 1, 1, 1});
-  CheckObjFunction(obj,
-                   {0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
-                   {0, 0, 0, 0, 1, 1, 1, 1},
-                   {},  // empty weight
-                   {0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0},
-                   {1, 1, 1, 1, 1, 1, 1, 1});
-  ASSERT_NO_THROW(obj->DefaultEvalMetric());
+  TestLinearRegressionGPair(&ctx);
 }
 
 TEST(SyclObjective, SquaredLog) {
   Context ctx;
   ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
-  std::vector<std::pair<std::string, std::string>> args;
-
-  std::unique_ptr<ObjFunction> obj { ObjFunction::Create("reg:squaredlogerror_oneapi", &ctx) };
-  obj->Configure(args);
-  CheckConfigReload(obj, "reg:squaredlogerror_oneapi");
-
-  CheckObjFunction(obj,
-                   {0.1f, 0.2f, 0.4f, 0.8f, 1.6f},  // pred
-                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},  // labels
-                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},  // weights
-                   {-0.5435f, -0.4257f, -0.25475f, -0.05855f, 0.1009f},
-                   { 1.3205f,  1.0492f,  0.69215f,  0.34115f, 0.1091f});
-  CheckObjFunction(obj,
-                   {0.1f, 0.2f, 0.4f, 0.8f, 1.6f},  // pred
-                   {1.0f, 1.0f, 1.0f, 1.0f, 1.0f},  // labels
-                   {},                              // empty weights
-                   {-0.5435f, -0.4257f, -0.25475f, -0.05855f, 0.1009f},
-                   { 1.3205f,  1.0492f,  0.69215f,  0.34115f, 0.1091f});
-  ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"rmsle"});
+  TestSquaredLog(&ctx);
 }
 
 TEST(SyclObjective, LogisticRegressionGPair) {
   Context ctx;
   ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
-  std::vector<std::pair<std::string, std::string>> args;
-  std::unique_ptr<ObjFunction> obj { ObjFunction::Create("reg:logistic_oneapi", &ctx) };
-
-  obj->Configure(args);
-  CheckConfigReload(obj, "reg:logistic_oneapi");
-
-  CheckObjFunction(obj,
-                   { 0, 0.1f, 0.9f,    1,    0,  0.1f,  0.9f,     1}, // preds
-                   { 0,    0,    0,    0,    1,     1,     1,     1}, // labels
-                   { 1,    1,    1,    1,    1,     1,     1,     1}, // weights
-                   { 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f}, // out_grad
-                   {0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f}); // out_hess
+  TestLogisticRegressionGPair(&ctx);
 }
 
 TEST(SyclObjective, LogisticRegressionBasic) {
   Context ctx;
   ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
-  std::vector<std::pair<std::string, std::string>> args;
-  std::unique_ptr<ObjFunction> obj {
-    ObjFunction::Create("reg:logistic_oneapi", &ctx)
-  };
-
-  obj->Configure(args);
-  CheckConfigReload(obj, "reg:logistic_oneapi");
-
-  // test label validation
-  EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {10}, {1}, {0}, {0}))
-    << "Expected error when label not in range [0,1f] for LogisticRegression";
-
-  // test ProbToMargin
-  EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.197f, 0.01f);
-  EXPECT_NEAR(obj->ProbToMargin(0.5f), 0, 0.01f);
-  EXPECT_NEAR(obj->ProbToMargin(0.9f), 2.197f, 0.01f);
-  EXPECT_ANY_THROW(obj->ProbToMargin(10))
-    << "Expected error when base_score not in range [0,1f] for LogisticRegression";
-
-  // test PredTransform
-  HostDeviceVector<bst_float> io_preds = {0, 0.1f, 0.5f, 0.9f, 1};
-  std::vector<bst_float> out_preds = {0.5f, 0.524f, 0.622f, 0.710f, 0.731f};
-  obj->PredTransform(&io_preds);
-  auto& preds = io_preds.HostVector();
-  for (int i = 0; i < static_cast<int>(io_preds.Size()); ++i) {
-    EXPECT_NEAR(preds[i], out_preds[i], 0.01f);
-  }
+
+  TestLogisticRegressionBasic(&ctx);
 }
 
 TEST(SyclObjective, LogisticRawGPair) {
   Context ctx;
   ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
-  std::vector<std::pair<std::string, std::string>> args;
-  std::unique_ptr<ObjFunction> obj {
-    ObjFunction::Create("binary:logitraw_oneapi", &ctx)
-  };
-
-  obj->Configure(args);
-
-  CheckObjFunction(obj,
-                   { 0, 0.1f, 0.9f,    1,    0,  0.1f,  0.9f,     1},
-                   { 0,    0,    0,    0,    1,     1,     1,     1},
-                   { 1,    1,    1,    1,    1,     1,     1,     1},
-                   { 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f},
-                   {0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f});
+  TestLogisticRawGPair(&ctx);
 }
 
 TEST(SyclObjective, CPUvsSycl) {
   Context ctx;
   ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
   ObjFunction * obj_sycl =
-    ObjFunction::Create("reg:squarederror_oneapi", &ctx);
+    ObjFunction::Create("reg:squarederror_sycl", &ctx);
 
   ctx = ctx.MakeCPU();
   ObjFunction * obj_cpu =
diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h
index d7cddf50c1c1..1634789be581 100644
--- a/tests/cpp/predictor/test_predictor.h
+++ b/tests/cpp/predictor/test_predictor.h
@@ -35,7 +35,7 @@ inline auto CreatePredictorForTest(Context const* ctx) {
   if (ctx->IsCPU()) {
     return Predictor::Create("cpu_predictor", ctx);
   } else if (ctx->IsSycl()) {
-    return Predictor::Create("oneapi_predictor", ctx);
+    return Predictor::Create("sycl_predictor", ctx);
   } else {
     return Predictor::Create("gpu_predictor", ctx);
   }
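The truncated CPUvsSycl hunk above pits the `_sycl` objective against its CPU counterpart derived from the same Context. A sketch of that setup, assuming an XGBoost build configured with `-DPLUGIN_SYCL=ON` and the public headers on the include path:

```cpp
#include <xgboost/context.h>
#include <xgboost/objective.h>

#include <iostream>
#include <memory>

int main() {
  // Request the SYCL device, as the tests above do via UpdateAllowUnknown.
  xgboost::Context ctx;
  ctx.UpdateAllowUnknown(xgboost::Args{{"device", "sycl"}});
  std::unique_ptr<xgboost::ObjFunction> obj_sycl{
      xgboost::ObjFunction::Create("reg:squarederror_sycl", &ctx)};

  // Derive a CPU context from the same one and build the default objective.
  xgboost::Context cpu_ctx = ctx.MakeCPU();
  std::unique_ptr<xgboost::ObjFunction> obj_cpu{
      xgboost::ObjFunction::Create("reg:squarederror", &cpu_ctx)};

  obj_sycl->Configure({});
  obj_cpu->Configure({});

  // Both variants implement the same objective, so they report the same
  // default metric; gradient comparisons would follow the CheckObjFunction
  // pattern used throughout these tests.
  std::cout << obj_sycl->DefaultEvalMetric() << " == "
            << obj_cpu->DefaultEvalMetric() << std::endl;
  return 0;
}
```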
diff --git a/tests/python-oneapi/test_oneapi_prediction.py b/tests/python-sycl/test_sycl_prediction.py
similarity index 88%
rename from tests/python-oneapi/test_oneapi_prediction.py
rename to tests/python-sycl/test_sycl_prediction.py
index 6eee485b618b..bfa3f18cc697 100644
--- a/tests/python-oneapi/test_oneapi_prediction.py
+++ b/tests/python-sycl/test_sycl_prediction.py
@@ -17,7 +17,7 @@
 }).filter(lambda x: x['max_depth'] > 0 or x['max_leaves'] > 0)
 
-class TestOneAPIPredict(unittest.TestCase):
+class TestSYCLPredict(unittest.TestCase):
     def test_predict(self):
         iterations = 10
         np.random.seed(1)
@@ -49,15 +49,15 @@ def test_predict(self):
         cpu_pred_val = bst.predict(dval, output_margin=True)
 
         bst.set_param({"device": "sycl:gpu"})
-        oneapi_pred_train = bst.predict(dtrain, output_margin=True)
-        oneapi_pred_test = bst.predict(dtest, output_margin=True)
-        oneapi_pred_val = bst.predict(dval, output_margin=True)
+        sycl_pred_train = bst.predict(dtrain, output_margin=True)
+        sycl_pred_test = bst.predict(dtest, output_margin=True)
+        sycl_pred_val = bst.predict(dval, output_margin=True)
 
-        np.testing.assert_allclose(cpu_pred_train, oneapi_pred_train,
+        np.testing.assert_allclose(cpu_pred_train, sycl_pred_train,
                                    rtol=1e-6)
-        np.testing.assert_allclose(cpu_pred_val, oneapi_pred_val,
+        np.testing.assert_allclose(cpu_pred_val, sycl_pred_val,
                                    rtol=1e-6)
-        np.testing.assert_allclose(cpu_pred_test, oneapi_pred_test,
+        np.testing.assert_allclose(cpu_pred_test, sycl_pred_test,
                                    rtol=1e-6)
 
     def non_increasing(self, L):
@@ -109,17 +109,17 @@ def test_sklearn(self):
         cpu_train_score = m.score(X_train, y_train)
         cpu_test_score = m.score(X_test, y_test)
 
-        # Now with oneapi_predictor
+        # Now with sycl_predictor
         params['device'] = 'sycl:gpu'
         m.set_params(**params)
         # m = xgb.XGBRegressor(**params).fit(X_train, y_train)
-        oneapi_train_score = m.score(X_train, y_train)
+        sycl_train_score = m.score(X_train, y_train)
         # m = xgb.XGBRegressor(**params).fit(X_train, y_train)
-        oneapi_test_score = m.score(X_test, y_test)
+        sycl_test_score = m.score(X_test, y_test)
 
-        assert np.allclose(cpu_train_score, oneapi_train_score)
-        assert np.allclose(cpu_test_score, oneapi_test_score)
+        assert np.allclose(cpu_train_score, sycl_train_score)
+        assert np.allclose(cpu_test_score, sycl_test_score)
 
     @given(strategies.integers(1, 10), tm.make_dataset_strategy(),
            shap_parameter_strategy)
diff --git a/tests/python-oneapi/test_oneapi_training_continuation.py b/tests/python-sycl/test_sycl_training_continuation.py
similarity index 93%
rename from tests/python-oneapi/test_oneapi_training_continuation.py
rename to tests/python-sycl/test_sycl_training_continuation.py
index 2ce809b076ce..5ce042e1bf15 100644
--- a/tests/python-oneapi/test_oneapi_training_continuation.py
+++ b/tests/python-sycl/test_sycl_training_continuation.py
@@ -5,7 +5,7 @@
 rng = np.random.RandomState(1994)
 
-class TestOneAPITrainingContinuation:
+class TestSYCLTrainingContinuation:
     def run_training_continuation(self, use_json):
         kRows = 64
         kCols = 32
@@ -49,8 +49,8 @@ def recursive_compare(obj_0, obj_1):
                 obj_1 = json.loads(dump_1[i])
             recursive_compare(obj_0, obj_1)
 
-    def test_oneapi_training_continuation_binary(self):
+    def test_sycl_training_continuation_binary(self):
         self.run_training_continuation(False)
 
-    def test_oneapi_training_continuation_json(self):
+    def test_sycl_training_continuation_json(self):
         self.run_training_continuation(True)
diff --git a/tests/python-oneapi/test_oneapi_updaters.py b/tests/python-sycl/test_sycl_updaters.py
similarity index 94%
rename from tests/python-oneapi/test_oneapi_updaters.py
rename to tests/python-sycl/test_sycl_updaters.py
index 282f7cfb2264..3f9ae59cc5dd 100644
--- a/tests/python-oneapi/test_oneapi_updaters.py
+++ b/tests/python-sycl/test_sycl_updaters.py
@@ -33,11 +33,11 @@ def train_result(param, dmat, num_rounds):
     return result
 
-class TestOneAPIUpdaters:
+class TestSYCLUpdaters:
     @given(parameter_strategy, strategies.integers(1, 5),
            tm.make_dataset_strategy())
     @settings(deadline=None)
-    def test_oneapi_hist(self, param, num_rounds, dataset):
+    def test_sycl_hist(self, param, num_rounds, dataset):
         param['tree_method'] = 'hist'
         param['device'] = 'sycl:gpu'
         param['verbosity'] = 0
@@ -48,7 +48,7 @@ def test_sycl_hist(self, param, num_rounds, dataset):
     @given(tm.make_dataset_strategy(), strategies.integers(0, 1))
     @settings(deadline=None)
-    def test_specified_device_id_oneapi_update(self, dataset, device_id):
+    def test_specified_device_id_sycl_update(self, dataset, device_id):
         # Read the list of SYCL devices
         sycl_ls = os.popen('sycl-ls').read()
         devices = sycl_ls.split('\n')
diff --git a/tests/python-oneapi/test_oneapi_with_sklearn.py b/tests/python-sycl/test_sycl_with_sklearn.py
similarity index 96%
rename from tests/python-oneapi/test_oneapi_with_sklearn.py
rename to tests/python-sycl/test_sycl_with_sklearn.py
index 0f71efee72de..bc34ed46a5b7 100644
--- a/tests/python-oneapi/test_oneapi_with_sklearn.py
+++ b/tests/python-sycl/test_sycl_with_sklearn.py
@@ -12,7 +12,7 @@
 rng = np.random.RandomState(1994)
 
-def test_oneapi_binary_classification():
+def test_sycl_binary_classification():
     from sklearn.datasets import load_digits
     from sklearn.model_selection import KFold