Commit 3343a26: update
wbo4958 committed Jun 7, 2024 (1 parent: 5ccaafe)
Showing 3 changed files with 46 additions and 7 deletions.
@@ -17,16 +17,17 @@
package ml.dmlc.xgboost4j.scala.spark

import ml.dmlc.xgboost4j.scala.spark.params.{NewHasGroupCol, SparkParams, XGBoostParams}
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.MLVectorToXGBLabeledPoint
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.XGBoostSchemaUtils
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{FloatType, StructType}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}

import scala.collection.mutable.ArrayBuffer

@@ -89,7 +90,6 @@ private[spark] abstract class XGBoostEstimator[
None
}


var groupName: Option[String] = None
this match {
case p: NewHasGroupCol =>
@@ -203,10 +203,47 @@ private[spark] abstract class XGBoostModel[M <: XGBoostModel[M]](

  def summary: XGBoostTrainingSummary = trainingSummary

  /**
   * Predict label for the given features.
   * This method is used to implement `transform()` and output [[predictionCol]].
   */
  // def predict(features: Vector): Double

  // def predictRaw(features: Vector): Vector

  override def transform(dataset: Dataset[_]): DataFrame = {
    val spark = dataset.sparkSession

    // Validate and log the input schema.
    transformSchema(dataset.schema, logging = true)

    // Broadcast the booster to each executor.
    val bBooster = spark.sparkContext.broadcast(booster)

    val featureIndex = dataset.schema.fieldIndex(getFeaturesCol)

    // TODO: make the inference batch size configurable
    val inferBatchSize = 32 << 10

    // Output schema: the input columns plus an appended array<float> column for the raw prediction.
    val schema = StructType(dataset.schema.fields ++
      Seq(StructField(name = XGBoostClassificationModel._rawPredictionCol, dataType =
        ArrayType(FloatType, containsNull = false), nullable = false)))

    val outputData = dataset.toDF().mapPartitions { rowIter =>
      // Score the partition in batches: build one DMatrix per batch and predict with the broadcast booster.
      rowIter.grouped(inferBatchSize).flatMap { batchRow =>
        val features = batchRow.iterator.map(row => row.getAs[Vector](featureIndex))
        val dm = new DMatrix(features.map(_.asXGB))
        val rawIter = bBooster.value.predict(dm).map(Row(_))
        batchRow.zip(rawIter).map { case (original, raw) =>
          Row.fromSeq(original.toSeq ++ raw.toSeq)
        }
      }
    }(Encoders.row(schema))

    outputData.toDF()
  }

  override def transformSchema(schema: StructType): StructType = schema

}
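
To make the new transform easier to follow, here is a small, self-contained sketch of the batched-inference pattern it uses inside mapPartitions: rows are grouped into fixed-size batches, each batch of feature vectors is turned into a DMatrix through the asXGB conversion imported from DataUtils, and the broadcast Booster scores the whole batch at once. The scoreBatch helper and the BatchScoringSketch object are illustrative only, not part of the commit, and assume DataUtils.MLVectorToXGBLabeledPoint is accessible here as it is in the file above.

import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.MLVectorToXGBLabeledPoint
import org.apache.spark.ml.linalg.Vector

object BatchScoringSketch {
  // Hypothetical helper mirroring what each mapPartitions task does above:
  // build one DMatrix per batch and score it with the (broadcast) booster.
  def scoreBatch(booster: Booster, batch: Seq[Vector]): Array[Array[Float]] = {
    // asXGB converts an MLlib Vector into an xgboost4j LabeledPoint.
    val dm = new DMatrix(batch.iterator.map(_.asXGB))
    // predict returns one Array[Float] per input row (the raw scores for a classifier).
    booster.predict(dm)
  }
}

Scoring per batch rather than per row amortizes the cost of constructing native DMatrix objects and of the JNI call into XGBoost, which is why transform groups the row iterator by inferBatchSize before predicting.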
@@ -17,7 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark.params

import org.apache.spark.ml.param.{IntParam, Param, ParamValidators, Params}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasRawPredictionCol, HasWeightCol}


trait NewHasBaseMarginCol extends Params {
@@ -44,7 +44,7 @@ trait NewHasGroupCol extends Params {

private[spark] trait SparkParams extends Params
with HasFeaturesCol with HasLabelCol with NewHasBaseMarginCol
with HasWeightCol {
with HasWeightCol with HasPredictionCol with HasRawPredictionCol {

final val numWorkers = new IntParam(this, "numWorkers", "number of workers used to run xgboost",
ParamValidators.gtEq(1))
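
As context for the SparkParams change just above: mixing in Spark's shared HasPredictionCol and HasRawPredictionCol traits gives the params class the standard predictionCol and rawPredictionCol params, their getters, and the usual defaults. A minimal sketch of what the traits contribute; the DemoParams class is hypothetical and only for illustration.

import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared.{HasPredictionCol, HasRawPredictionCol}
import org.apache.spark.ml.util.Identifiable

// Hypothetical params holder: the two shared traits contribute the
// "predictionCol" and "rawPredictionCol" params with defaults
// "prediction" and "rawPrediction".
class DemoParams(override val uid: String) extends Params
  with HasPredictionCol with HasRawPredictionCol {

  def this() = this(Identifiable.randomUID("demoParams"))

  override def copy(extra: ParamMap): DemoParams = defaultCopy(extra)
}

// new DemoParams().getPredictionCol     // "prediction"
// new DemoParams().getRawPredictionCol  // "rawPrediction"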
@@ -40,6 +40,8 @@ class NewXGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderP

    // est.fit(arrayInput)
    val model = est.fit(dataset)

    model.transform(dataset).show()
  }

}
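
A possible follow-up for the test above (an assumption, not part of this commit): besides show(), the suite could assert that transform appends exactly one array<float> column on top of the input schema, matching the StructField the model's transform builds.

// Hypothetical assertion sketch, reusing `model` and `dataset` from the test above.
val transformed = model.transform(dataset)
val appended = transformed.schema.fieldNames.diff(dataset.schema.fieldNames)
assert(appended.length == 1)
assert(transformed.schema(appended.head).dataType ==
  org.apache.spark.sql.types.ArrayType(org.apache.spark.sql.types.FloatType, containsNull = false))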
