Skip to content

Commit

Permalink
[EM] CPU implementation for external memory QDM. (dmlc#10682)
Browse files Browse the repository at this point in the history
- A new DMatrix type.
- Extract common code into a new QDM base class.

Not yet working:
- Not exposed to the interface yet, will wait for the GPU implementation.
- ~No meta info yet, still working on the source.~
- Exporting data to CSR is not supported yet.
  • Loading branch information
trivialfis authored Aug 9, 2024
1 parent ac83666 commit 7bccc1e
Show file tree
Hide file tree
Showing 33 changed files with 1,199 additions and 498 deletions.
3 changes: 3 additions & 0 deletions R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ OBJECTS= \
$(PKGROOT)/src/data/gradient_index_format.o \
$(PKGROOT)/src/data/sparse_page_dmatrix.o \
$(PKGROOT)/src/data/sparse_page_source.o \
$(PKGROOT)/src/data/extmem_quantile_dmatrix.o \
$(PKGROOT)/src/data/quantile_dmatrix.o \
$(PKGROOT)/src/data/batch_utils.o \
$(PKGROOT)/src/data/proxy_dmatrix.o \
$(PKGROOT)/src/data/iterative_dmatrix.o \
$(PKGROOT)/src/predictor/predictor.o \
Expand Down
3 changes: 3 additions & 0 deletions R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ OBJECTS= \
$(PKGROOT)/src/data/gradient_index_format.o \
$(PKGROOT)/src/data/sparse_page_dmatrix.o \
$(PKGROOT)/src/data/sparse_page_source.o \
$(PKGROOT)/src/data/extmem_quantile_dmatrix.o \
$(PKGROOT)/src/data/quantile_dmatrix.o \
$(PKGROOT)/src/data/batch_utils.o \
$(PKGROOT)/src/data/proxy_dmatrix.o \
$(PKGROOT)/src/data/iterative_dmatrix.o \
$(PKGROOT)/src/predictor/predictor.o \
Expand Down
84 changes: 51 additions & 33 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <xgboost/string_view.h>

#include <algorithm>
#include <cstdint> // for int32_t, uint8_t
#include <limits>
#include <memory>
#include <string>
Expand Down Expand Up @@ -499,8 +500,12 @@ class BatchSet {

struct XGBAPIThreadLocalEntry;

/*!
* \brief Internal data structured used by XGBoost during training.
/**
* @brief Internal data structured used by XGBoost to hold all external data.
*
* There are multiple variants of the DMatrix class and can be accessed through the
* @ref Create() methods. The DMatrix itself holds the predictor `X`, and other data
* including labels and sample weights are stored in the @ref MetaInfo class.
*/
class DMatrix {
public:
Expand All @@ -518,13 +523,13 @@ class DMatrix {
/*! \brief Get thread local memory for returning data from DMatrix. */
[[nodiscard]] XGBAPIThreadLocalEntry& GetThreadLocal() const;
/**
* \brief Get the context object of this DMatrix. The context is created during construction of
* @brief Get the context object of this DMatrix. The context is created during construction of
* DMatrix with user specified `nthread` parameter.
*/
[[nodiscard]] virtual Context const* Ctx() const = 0;

/**
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
* @brief Gets batches. Use range based for loop over BatchSet to access individual batches.
*/
template <typename T>
BatchSet<T> GetBatches();
Expand All @@ -548,57 +553,57 @@ class DMatrix {
[[nodiscard]] bool IsDense() const { return this->Info().IsDense(); }

/**
* \brief Load DMatrix from URI.
* @brief Load DMatrix from URI.
*
* \param uri The URI of input.
* \param silent Whether print information during loading.
* \param data_split_mode Indicate how the data was split beforehand.
* \return The created DMatrix.
* @param uri The URI of input.
* @param silent Whether print information during loading.
* @param data_split_mode Indicate how the data was split beforehand.
* @return The created DMatrix.
*/
static DMatrix* Load(const std::string& uri, bool silent = true,
DataSplitMode data_split_mode = DataSplitMode::kRow);

/**
* \brief Creates a new DMatrix from an external data adapter.
* @brief Creates a new DMatrix from an external data adapter.
*
* \tparam AdapterT Type of the adapter.
* \param [in,out] adapter View onto an external data.
* \param missing Values to count as missing.
* \param nthread Number of threads for construction.
* \param cache_prefix (Optional) The cache prefix for external memory.
* \param data_split_mode (Optional) Data split mode.
* @tparam AdapterT Type of the adapter.
* @param [in,out] adapter View onto an external data.
* @param missing Values to count as missing.
* @param nthread Number of threads for construction.
* @param cache_prefix (Optional) The cache prefix for external memory.
* @param data_split_mode (Optional) Data split mode.
*
* \return a Created DMatrix.
* @return a Created DMatrix.
*/
template <typename AdapterT>
static DMatrix* Create(AdapterT* adapter, float missing, int nthread,
const std::string& cache_prefix = "",
DataSplitMode data_split_mode = DataSplitMode::kRow);

/**
* \brief Create a new Quantile based DMatrix used for histogram based algorithm.
* @brief Create a new Quantile based DMatrix used for histogram based algorithm.
*
* \tparam DataIterHandle External iterator type, defined in C API.
* \tparam DMatrixHandle DMatrix handle, defined in C API.
* \tparam DataIterResetCallback Callback for reset, prototype defined in C API.
* \tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API.
* @tparam DataIterHandle External iterator type, defined in C API.
* @tparam DMatrixHandle DMatrix handle, defined in C API.
* @tparam DataIterResetCallback Callback for reset, prototype defined in C API.
* @tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API.
*
* \param iter External data iterator
* \param proxy A hanlde to ProxyDMatrix
* \param ref Reference Quantile DMatrix.
* \param reset Callback for reset
* \param next Callback for next
* \param missing Value that should be treated as missing.
* \param nthread number of threads used for initialization.
* \param max_bin Maximum number of bins.
* @param iter External data iterator
* @param proxy A hanlde to ProxyDMatrix
* @param ref Reference Quantile DMatrix.
* @param reset Callback for reset
* @param next Callback for next
* @param missing Value that should be treated as missing.
* @param nthread number of threads used for initialization.
* @param max_bin Maximum number of bins.
*
* \return A created quantile based DMatrix.
* @return A created quantile based DMatrix.
*/
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
int nthread, bst_bin_t max_bin);
std::int32_t nthread, bst_bin_t max_bin);

/**
* @brief Create an external memory DMatrix with callbacks.
Expand All @@ -622,9 +627,22 @@ class DMatrix {
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset,
XGDMatrixCallbackNext* next, float missing, int32_t nthread,
XGDMatrixCallbackNext* next, float missing, std::int32_t nthread,
std::string cache, bool on_host);

/**
* @brief Create an external memory quantile DMatrix with callbacks.
*
* Parameters are a combination of the external memory DMatrix and the quantile DMatrix.
*
* @return A created external memory quantile DMatrix.
*/
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
std::int32_t nthread, bst_bin_t max_bin, std::string cache);

virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;

/**
Expand Down
13 changes: 13 additions & 0 deletions src/data/batch_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/**
* Copyright 2023-2024, XGBoost Contributors
*/
#include "batch_utils.h"

#include "../common/error_msg.h" // for InconsistentMaxBin

namespace xgboost::data::detail {
void CheckParam(BatchParam const& init, BatchParam const& param) {
CHECK_EQ(param.max_bin, init.max_bin) << error::InconsistentMaxBin();
CHECK(!param.regen && param.hess.empty()) << "Only `hist` tree method can use `QuantileDMatrix`.";
}
} // namespace xgboost::data::detail
5 changes: 5 additions & 0 deletions src/data/batch_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,10 @@ inline bool RegenGHist(BatchParam old, BatchParam p) {
}
return p.regen || old.ParamNotEqual(p);
}

/**
* @brief Validate the batch parameter from the caller
*/
void CheckParam(BatchParam const& init, BatchParam const& param);
} // namespace xgboost::data::detail
#endif // XGBOOST_DATA_BATCH_UTILS_H_
89 changes: 52 additions & 37 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,44 @@
#include <tuple> // for get, apply
#include <type_traits> // for remove_pointer_t, remove_reference

#include "../collective/communicator-inl.h" // for GetRank, GetWorldSize, Allreduce, IsFederated
#include "../collective/allgather.h"
#include "../collective/allreduce.h"
#include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
#include "../common/group_data.h" // for ParallelGroupBuilder
#include "../common/io.h" // for PeekableInStream
#include "../common/linalg_op.h" // for ElementWiseTransformHost
#include "../common/math.h" // for CheckNAN
#include "../common/numeric.h" // for Iota, RunLengthEncode
#include "../common/threading_utils.h" // for ParallelFor
#include "../common/version.h" // for Version
#include "../data/adapter.h" // for COOTuple, FileAdapter, IsValidFunctor
#include "../data/iterative_dmatrix.h" // for IterativeDMatrix
#include "./sparse_page_dmatrix.h" // for SparsePageDMatrix
#include "array_interface.h" // for ArrayInterfaceHandler, ArrayInterface, Dispa...
#include "dmlc/base.h" // for BeginPtr
#include "dmlc/common.h" // for OMPException
#include "dmlc/data.h" // for Parser
#include "dmlc/endian.h" // for ByteSwap, DMLC_IO_NO_ENDIAN_SWAP
#include "dmlc/io.h" // for Stream
#include "dmlc/thread_local.h" // for ThreadLocalStore
#include "ellpack_page.h" // for EllpackPage
#include "file_iterator.h" // for ValidateFileFormat, FileIterator, Next, Reset
#include "gradient_index.h" // for GHistIndexMatrix
#include "simple_dmatrix.h" // for SimpleDMatrix
#include "sparse_page_writer.h" // for SparsePageFormatReg
#include "validation.h" // for LabelsCheck, WeightsCheck, ValidateQueryGroup
#include "xgboost/base.h" // for bst_group_t, bst_idx_t, bst_float, bst_ulong
#include "xgboost/context.h" // for Context
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/learner.h" // for HostDeviceVector
#include "xgboost/linalg.h" // for Tensor, Stack, TensorView, Vector, ArrayInte...
#include "xgboost/logging.h" // for Error, LogCheck_EQ, CHECK, CHECK_EQ, LOG
#include "xgboost/span.h" // for Span, operator!=, SpanIterator
#include "xgboost/string_view.h" // for operator==, operator<<, StringView
#include "../collective/allgather.h" // for AllgatherStrings
#include "../collective/allreduce.h" // for Allreduce
#include "../collective/communicator-inl.h" // for GetRank, IsFederated
#include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
#include "../common/group_data.h" // for ParallelGroupBuilder
#include "../common/io.h" // for PeekableInStream
#include "../common/linalg_op.h" // for ElementWiseTransformHost
#include "../common/math.h" // for CheckNAN
#include "../common/numeric.h" // for Iota, RunLengthEncode
#include "../common/threading_utils.h" // for ParallelFor
#include "../common/version.h" // for Version
#include "../data/adapter.h" // for COOTuple, FileAdapter, IsValidFunctor
#include "../data/extmem_quantile_dmatrix.h" // for ExtMemQuantileDMatrix
#include "../data/iterative_dmatrix.h" // for IterativeDMatrix
#include "./sparse_page_dmatrix.h" // for SparsePageDMatrix
#include "array_interface.h" // for ArrayInterfaceHandler, ArrayInterface, Dispa...
#include "dmlc/base.h" // for BeginPtr
#include "dmlc/common.h" // for OMPException
#include "dmlc/data.h" // for Parser
#include "dmlc/endian.h" // for ByteSwap, DMLC_IO_NO_ENDIAN_SWAP
#include "dmlc/io.h" // for Stream
#include "dmlc/thread_local.h" // for ThreadLocalStore
#include "ellpack_page.h" // for EllpackPage
#include "file_iterator.h" // for ValidateFileFormat, FileIterator, Next, Reset
#include "gradient_index.h" // for GHistIndexMatrix
#include "simple_dmatrix.h" // for SimpleDMatrix
#include "sparse_page_writer.h" // for SparsePageFormatReg
#include "validation.h" // for LabelsCheck, WeightsCheck, ValidateQueryGroup
#include "xgboost/base.h" // for bst_group_t, bst_idx_t, bst_float, bst_ulong
#include "xgboost/context.h" // for Context
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/learner.h" // for HostDeviceVector
#include "xgboost/linalg.h" // for Tensor, Stack, TensorView, Vector, ArrayInte...
#include "xgboost/logging.h" // for Error, LogCheck_EQ, CHECK, CHECK_EQ, LOG
#include "xgboost/span.h" // for Span, operator!=, SpanIterator
#include "xgboost/string_view.h" // for operator==, operator<<, StringView

namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg<::xgboost::SparsePage>);
Expand Down Expand Up @@ -909,6 +910,15 @@ DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, DataIterReset
return new data::SparsePageDMatrix{iter, proxy, reset, next, missing, n_threads, cache, on_host};
}

template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
std::int32_t nthread, bst_bin_t max_bin, std::string cache) {
return new data::ExtMemQuantileDMatrix{
iter, proxy, ref, reset, next, missing, nthread, std::move(cache), max_bin};
}

template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback,
XGDMatrixCallbackNext>(DataIterHandle iter, DMatrixHandle proxy,
std::shared_ptr<DMatrix> ref,
Expand All @@ -922,6 +932,11 @@ template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCa
XGDMatrixCallbackNext* next, float missing,
int32_t n_threads, std::string, bool);

template DMatrix*
DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback, XGDMatrixCallbackNext>(
DataIterHandle, DMatrixHandle, std::shared_ptr<DMatrix>, DataIterResetCallback*,
XGDMatrixCallbackNext*, float, std::int32_t, bst_bin_t, std::string);

template <typename AdapterT>
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&,
DataSplitMode data_split_mode) {
Expand Down
2 changes: 1 addition & 1 deletion src/data/ellpack_page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ EllpackDeviceAccessor EllpackPageImpl::GetHostAccessor(
CHECK_EQ(h_gidx_buffer->size(), gidx_buffer.size());
CHECK_NE(gidx_buffer.size(), 0);
dh::safe_cuda(cudaMemcpyAsync(h_gidx_buffer->data(), gidx_buffer.data(), gidx_buffer.size_bytes(),
cudaMemcpyDefault, dh::DefaultStream()));
cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
return {DeviceOrd::CPU(),
cuts_,
is_dense,
Expand Down
Loading

0 comments on commit 7bccc1e

Please sign in to comment.