Skip to content

Commit

Permalink
Support vector transform and more tests. (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 authored Jun 24, 2024
1 parent c034a9a commit 6ec4606
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark

import scala.collection.mutable

import org.apache.spark.ml.functions.array_to_vector
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader}
Expand Down Expand Up @@ -150,16 +151,24 @@ class XGBoostClassificationModel(
override def postTransform(dataset: Dataset[_]): Dataset[_] = {
var output = dataset
// Always use probability col to get the prediction
if (isDefined(predictionCol) && getPredictionCol.nonEmpty) {
if (isDefinedNonEmpty(predictionCol)) {
val predCol = udf { probability: mutable.WrappedArray[Float] =>
probability2prediction(Vectors.dense(probability.map(_.toDouble).toArray))
}
output = output.withColumn(getPredictionCol, predCol(col(TMP_TRANSFORMED_COL)))
}

if (isDefined(probabilityCol) && getProbabilityCol.nonEmpty) {
output = output.withColumnRenamed(TMP_TRANSFORMED_COL, getProbabilityCol)
if (isDefinedNonEmpty(probabilityCol)) {
output = output.withColumn(TMP_TRANSFORMED_COL,
array_to_vector(output.col(TMP_TRANSFORMED_COL)))
.withColumnRenamed(TMP_TRANSFORMED_COL, getProbabilityCol)
}

if (isDefinedNonEmpty(rawPredictionCol)) {
output = output.withColumn(getRawPredictionCol,
array_to_vector(output.col(getRawPredictionCol)))
}

output.drop(TMP_TRANSFORMED_COL)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.jdk.CollectionConverters.iterableAsScalaIterableConverter
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.Path
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.functions.array_to_vector
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsWritable, MLReader, MLWritable, MLWriter}
Expand All @@ -34,10 +35,11 @@ import org.apache.spark.sql._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType}

import ml.dmlc.xgboost4j.java.{Booster => JBooster}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
import ml.dmlc.xgboost4j.scala.spark.Utils.MLVectorToXGBLabeledPoint
import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.scala.spark.params.{ParamUtils, _}


/**
Expand Down Expand Up @@ -67,7 +69,7 @@ private[spark] trait NonParamVariables[T <: XGBoostEstimator[T, M], M <: XGBoost

private[spark] abstract class XGBoostEstimator[
Learner <: XGBoostEstimator[Learner, M], M <: XGBoostModel[M]] extends Estimator[M]
with XGBoostParams[Learner] with SparkParams[Learner]
with XGBoostParams[Learner] with SparkParams[Learner] with ParamUtils[Learner]
with NonParamVariables[Learner, M] with ParamMapConversion with DefaultParamsWritable {

protected val logger = LogFactory.getLog("XGBoostSpark")
Expand Down Expand Up @@ -142,7 +144,7 @@ private[spark] abstract class XGBoostEstimator[

// function to get the column id according to the parameter
def columnId(param: Param[String]): Option[Int] = {
if (isDefined(param) && $(param).nonEmpty) {
if (isDefinedNonEmpty(param)) {
Some(schema.fieldIndex($(param)))
} else {
None
Expand All @@ -164,10 +166,6 @@ private[spark] abstract class XGBoostEstimator[
groupId)
}

private[spark] def isDefinedNonEmpty(param: Param[String]): Boolean = {
if (isDefined(param) && $(param).nonEmpty) true else false
}

/**
* Preprocess the dataset to meet the xgboost input requirement
*
Expand Down Expand Up @@ -346,11 +344,11 @@ private[spark] abstract class XGBoostEstimator[
validateSparkSslConf(dataset.sparkSession)
val schema = dataset.schema
SparkUtils.checkNumericType(schema, $(labelCol))
if (isDefined(weightCol) && $(weightCol).nonEmpty) {
if (isDefinedNonEmpty(weightCol)) {
SparkUtils.checkNumericType(schema, $(weightCol))
}

if (isDefined(baseMarginCol) && $(baseMarginCol).nonEmpty) {
if (isDefinedNonEmpty(baseMarginCol)) {
SparkUtils.checkNumericType(schema, $(baseMarginCol))
}

Expand Down Expand Up @@ -413,7 +411,7 @@ private[spark] abstract class XGBoostModel[M <: XGBoostModel[M]](
override val uid: String,
private val model: Booster,
private val trainingSummary: Option[XGBoostTrainingSummary]) extends Model[M] with MLWritable
with XGBoostParams[M] with SparkParams[M] {
with XGBoostParams[M] with SparkParams[M] with ParamUtils[M] {

protected val TMP_TRANSFORMED_COL = "_tmp_xgb_transformed_col"

Expand Down Expand Up @@ -446,7 +444,7 @@ private[spark] abstract class XGBoostModel[M <: XGBoostModel[M]](

/** If the parameter is defined and non-empty, add it to the schema and return true */
def addToSchema(param: Param[String], colName: Option[String] = None): Boolean = {
if (isDefined(param) && $(param).nonEmpty) {
if (isDefinedNonEmpty(param)) {
val name = colName.getOrElse($(param))
schema = schema.add(StructField(name, ArrayType(FloatType)))
true
Expand All @@ -467,7 +465,7 @@ private[spark] abstract class XGBoostModel[M <: XGBoostModel[M]](
hasRawPredictionCol = addToSchema(p.rawPredictionCol)
hasTransformedCol = addToSchema(p.probabilityCol, Some(TMP_TRANSFORMED_COL))

if (isDefined(predictionCol) && getPredictionCol.nonEmpty) {
if (isDefinedNonEmpty(predictionCol)) {
// Let's use transformed col to calculate the prediction
if (!hasTransformedCol) {
// Add the transformed col for prediction
Expand All @@ -488,7 +486,7 @@ private[spark] abstract class XGBoostModel[M <: XGBoostModel[M]](
val bBooster = spark.sparkContext.broadcast(nativeBooster)
val featureName = getFeaturesCol

val outputData = dataset.toDF().mapPartitions { rowIter =>
var output = dataset.toDF().mapPartitions { rowIter =>

rowIter.grouped(inferBatchSize).flatMap { batchRow =>
val features = batchRow.iterator.map(row => row.getAs[Vector](
Expand Down Expand Up @@ -524,7 +522,19 @@ private[spark] abstract class XGBoostModel[M <: XGBoostModel[M]](

}(Encoders.row(schema))
bBooster.unpersist(blocking = false)
postTransform(outputData).toDF()

// Convert the leaf/contrib prediction columns from array to vector
if (hasLeafPredictionCol) {
output = output.withColumn(getLeafPredictionCol,
array_to_vector(output.col(getLeafPredictionCol)))
}

if (hasContribPredictionCol) {
output = output.withColumn(getContribPredictionCol,
array_to_vector(output.col(getContribPredictionCol)))
}

postTransform(output).toDF()
}

override def write: MLWriter = new XGBoostModelWriter[XGBoostModel[_]](this)
Expand All @@ -536,15 +546,17 @@ private[spark] abstract class XGBoostModel[M <: XGBoostModel[M]](
* @param instance model to be written
*/
private[spark] class XGBoostModelWriter[M <: XGBoostModel[M]](instance: M) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
SparkUtils.saveMetadata(instance, path, sc)

// Save model data
val dataPath = new Path(path, "data").toString
val internalPath = new Path(dataPath, "model")
val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
val format = optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
try {
instance.nativeBooster.saveModel(outputStream)
instance.nativeBooster.saveModel(outputStream, format)
} finally {
outputStream.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,10 @@ private[spark] trait XGBoostParams[T <: Params] extends TreeBoosterParams

def setNthread(value: Int): T = set(nthread, value).asInstanceOf[T]
}

private[spark] trait ParamUtils[T <: Params] extends Params {

  /**
   * Returns true when `param` has a value available (explicitly set or defaulted)
   * and that value is a non-empty string.
   */
  def isDefinedNonEmpty(param: Param[String]): Boolean = {
    if (!isDefined(param)) {
      false
    } else {
      $(param).nonEmpty
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ trait PerTest extends BeforeAndAfterEach {
import Utils.XGBLabeledPointFeatures
val it = labeledPoints.iterator.zipWithIndex
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
(id, labeledPoint.label, labeledPoint.features)
(id, labeledPoint.label, labeledPoint.features, labeledPoint.weight)
}

ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
.toDF("id", "label", "features")
.toDF("id", "label", "features", "weight")
}

protected def compareTwoFiles(lhs: String, rhs: String): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark

import scala.io.Source
import scala.util.Random

import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

Expand All @@ -31,8 +32,8 @@ trait TrainTestData {
Source.fromInputStream(is).getLines()
}

protected def getLabeledPoints(resource: String, featureSize: Int, zeroBased: Boolean):
Seq[XGBLabeledPoint] = {
protected def getLabeledPoints(resource: String, featureSize: Int,
zeroBased: Boolean): Seq[XGBLabeledPoint] = {
getResourceLines(resource).map { line =>
val labelAndFeatures = line.split(" ")
val label = labelAndFeatures.head.toFloat
Expand Down Expand Up @@ -65,10 +66,32 @@ trait TrainTestData {
object Classification extends TrainTestData {
  val train: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.train", 126, zeroBased = false)
  val test: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.test", 126, zeroBased = false)

  // Use a locally-seeded RNG rather than calling setSeed on the global scala.util.Random
  // singleton: the global approach makes the generated weights depend on object
  // initialization order and on any other code touching Random (MultiClassification
  // seeds the same singleton). A private Random(10) produces the exact same sequence
  // (both delegate to java.util.Random seeded with 10) without the shared mutable state.
  private val rng = new Random(10)

  // One deterministic pseudo-random weight in [0, 1) per training row.
  val randomWeights: Array[Float] = Array.fill(train.length)(rng.nextFloat())

  // Same rows as `train`, with the synthetic random weights attached.
  val trainWithWeight: Seq[XGBLabeledPoint] = train.zipWithIndex.map { case (v, index) =>
    XGBLabeledPoint(v.label, v.size, v.indices, v.values,
      randomWeights(index), v.group, v.baseMargin)
  }
}

object MultiClassification extends TrainTestData {
val train: Seq[XGBLabeledPoint] = getLabeledPoints("/dermatology.data")

private def split(): (Seq[XGBLabeledPoint], Seq[XGBLabeledPoint]) = {
val tmp: Seq[XGBLabeledPoint] = getLabeledPoints("/dermatology.data")
Random.setSeed(100)
val randomizedTmp = Random.shuffle(tmp)
val splitIndex = (randomizedTmp.length * 0.8).toInt
(randomizedTmp.take(splitIndex), randomizedTmp.drop(splitIndex))
}

val (train, test) = split()
Random.setSeed(10)
val randomWeights = Array.fill(train.length)(Random.nextFloat())
val trainWithWeight = train.zipWithIndex.map { case (v, index) =>
XGBLabeledPoint(v.label, v.size, v.indices, v.values,
randomWeights(index), v.group, v.baseMargin)
}

private def getLabeledPoints(resource: String): Seq[XGBLabeledPoint] = {
getResourceLines(resource).map { line =>
Expand Down
Loading

0 comments on commit 6ec4606

Please sign in to comment.