Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 25, 2024
1 parent 1fd3d48 commit 2513297
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.TaskContext
import org.apache.spark.ml.functions.array_to_vector

import ml.dmlc.xgboost4j.java.{CudfColumnBatch, GpuColumnBatch}
import ml.dmlc.xgboost4j.scala.{DMatrix, QuantileDMatrix}
Expand Down Expand Up @@ -273,7 +274,19 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
}
bBooster.unpersist(false)
bOriginalSchema.unpersist(false)
var output = dataset.sparkSession.createDataFrame(rdd, transformedSchema)

dataset.toDF()
// Convert leaf/contrib to the vector from array
if (pred.predLeaf) {
output = output.withColumn(model.getLeafPredictionCol,
array_to_vector(output.col(model.getLeafPredictionCol)))
}

if (pred.predContrib) {
output = output.withColumn(model.getContribPredictionCol,
array_to_vector(output.col(model.getContribPredictionCol)))
}

model.postTransform(output).toDF()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,8 @@ private[spark] trait NonParamVariables[T <: XGBoostEstimator[T, M], M <: XGBoost
}
}

private[spark] trait XGBoostEstimator[
Learner <: XGBoostEstimator[Learner, M], M <: XGBoostModel[M]] extends Estimator[M]
with XGBoostParams[Learner] with SparkParams[Learner] with ParamUtils[Learner]
with NonParamVariables[Learner, M] with ParamMapConversion with DefaultParamsWritable {

protected val logger = LogFactory.getLog("XGBoostSpark")

// Find the XGBoostPlugin by ServiceLoader
private[spark] trait PluginMixin {
// Find the XGBoostPlugin by ServiceLoader
private val plugin: Option[XGBoostPlugin] = {
val classLoader = Option(Thread.currentThread().getContextClassLoader)
.getOrElse(getClass.getClassLoader)
Expand All @@ -92,11 +86,20 @@ private[spark] trait XGBoostEstimator[
}

/** Visiable for testing */
private[spark] def getPlugin: Option[XGBoostPlugin] = plugin
protected[spark] def getPlugin: Option[XGBoostPlugin] = plugin

private def isPluginEnabled(dataset: Dataset[_]): Boolean = {
protected def isPluginEnabled(dataset: Dataset[_]): Boolean = {
plugin.map(_.isEnabled(dataset)).getOrElse(false)
}
}

private[spark] trait XGBoostEstimator[
Learner <: XGBoostEstimator[Learner, M], M <: XGBoostModel[M]] extends Estimator[M]
with XGBoostParams[Learner] with SparkParams[Learner] with ParamUtils[Learner]
with NonParamVariables[Learner, M] with ParamMapConversion with DefaultParamsWritable
with PluginMixin {

protected val logger = LogFactory.getLog("XGBoostSpark")

/**
* Pre-convert input double data to floats to align with XGBoost's internal float-based
Expand Down Expand Up @@ -383,7 +386,7 @@ private[spark] trait XGBoostEstimator[
validate(dataset)

val rdd = if (isPluginEnabled(dataset)) {
plugin.get.buildRddWatches(this, dataset)
getPlugin.foreach(_.buildRddWatches(this, dataset))
} else {
val (input, columnIndexes) = preprocess(dataset)
toRdd(input, columnIndexes)
Expand Down Expand Up @@ -420,7 +423,7 @@ private[spark] case class PredictedColumns(predLeaf: Boolean, predContrib: Boole
* @tparam the exact model which must extend from XGBoostModel
*/
private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with MLWritable
with XGBoostParams[M] with SparkParams[M] with ParamUtils[M] {
with XGBoostParams[M] with SparkParams[M] with ParamUtils[M] with PluginMixin {

protected val TMP_TRANSFORMED_COL = "_tmp_xgb_transformed_col"

Expand All @@ -440,7 +443,7 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML
validateAndTransformSchema(schema, false)
}

protected def postTransform(dataset: Dataset[_]): Dataset[_] = dataset
protected[spark] def postTransform(dataset: Dataset[_]): Dataset[_] = dataset

/**
* Preprocess the schema before transforming.
Expand Down Expand Up @@ -514,10 +517,12 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML

override def transform(dataset: Dataset[_]): DataFrame = {

val bBooster = dataset.sparkSession.sparkContext.broadcast(nativeBooster)
if (getPlugin.isDefined) {
return getPlugin.get.transform(this, dataset)
}

val bBooster = dataset.sparkSession.sparkContext.broadcast(nativeBooster)
val (schema, pred) = preprocess(dataset)

// TODO configurable
val inferBatchSize = 32 << 10
// Broadcast the booster to each executor.
Expand Down

0 comments on commit 2513297

Please sign in to comment.