Support more tests (#14)
wbo4958 authored Jun 23, 2024
1 parent 693860a commit c034a9a
Showing 9 changed files with 450 additions and 25 deletions.
@@ -51,7 +51,7 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
}

// TODO, support numeric type
private def preprocess[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
private[spark] def preprocess[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
estimator: XGBoostEstimator[T, M], dataset: Dataset[_]): Dataset[_] = {

// Columns to be selected for XGBoost training
@@ -80,7 +80,8 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
estimator.repartitionIfNeeded(input)
}

private def validate[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
// visible for testing
private[spark] def validate[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
estimator: XGBoostEstimator[T, M],
dataset: Dataset[_]): Unit = {
require(estimator.getTreeMethod == "gpu_hist" || estimator.getDevice != "cpu",
@@ -246,6 +246,7 @@ object SparkSessionHolder extends Logging {
.config("spark.sql.adaptive.enabled", "false")
.config("spark.rapids.sql.enabled", "false")
.config("spark.rapids.sql.test.enabled", "false")
.config("spark.stage.maxConsecutiveAttempts", "1")
.config("spark.plugins", "com.nvidia.spark.SQLPlugin")
.config("spark.rapids.memory.gpu.pooling.enabled", "false") // Disable RMM for unit tests.
.config("spark.sql.files.maxPartitionBytes", "1000")
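The added spark.stage.maxConsecutiveAttempts=1 makes a failing test stage abort on its first attempt instead of being retried, so failures surface immediately. A minimal standalone sketch of the same setting (session name hypothetical):

import org.apache.spark.sql.SparkSession

// Hypothetical local session mirroring the test config above.
val spark = SparkSession.builder()
  .master("local[1]")
  .config("spark.stage.maxConsecutiveAttempts", "1") // abort a stage after one failed attempt
  .getOrCreate()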
@@ -0,0 +1,229 @@
package ml.dmlc.xgboost4j.scala.spark

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.SparkSession

import ml.dmlc.xgboost4j.scala.rapids.spark.GpuTestSuite

class GpuXGBoostPluginSuite extends GpuTestSuite {


test("isEnabled") {
def checkIsEnabled(spark: SparkSession, expected: Boolean): Unit = {
import spark.implicits._
val df = Seq((1.0f, 2.0f, 0.0f),
(2.0f, 3.0f, 1.0f)
).toDF("c1", "c2", "label")
val classifier = new XGBoostClassifier()
assert(classifier.getPlugin.isDefined)
assert(classifier.getPlugin.get.isEnabled(df) === expected)
}

withCpuSparkSession() { spark =>
checkIsEnabled(spark, false)
}

withGpuSparkSession() { spark =>
checkIsEnabled(spark, true)
}
}


test("parameter validation") {
withGpuSparkSession() { spark =>
import spark.implicits._
val df = Seq((1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
(3.0f, 4.0f, 5.0f, 6.0f, 0.0f, 0.1f),
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
(5.0f, 6.0f, 7.0f, 8.0f, 0.0f, 0.1f),
).toDF("c1", "c2", "weight", "margin", "label", "other")
val classifier = new XGBoostClassifier()

val plugin = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
intercept[IllegalArgumentException] {
plugin.validate(classifier, df)
}
classifier.setDevice("cuda")
plugin.validate(classifier, df)

classifier.setDevice("gpu")
plugin.validate(classifier, df)

classifier.setDevice("cpu")
classifier.setTreeMethod("gpu_hist")
plugin.validate(classifier, df)
}
}
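// The rule exercised above comes from GpuXGBoostPlugin.validate (first hunk):
// training through the GPU plugin requires device set to "cuda" or "gpu", or
// tree_method set to "gpu_hist"; any other combination throws
// IllegalArgumentException.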

test("preprocess") {
withGpuSparkSession() { spark =>
import spark.implicits._
val df = Seq((1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
(3.0f, 4.0f, 5.0f, 6.0f, 0.0f, 0.1f),
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
(5.0f, 6.0f, 7.0f, 8.0f, 0.0f, 0.1f),
).toDF("c1", "c2", "weight", "margin", "label", "other")
.repartition(5)

assert(df.schema.names.contains("other"))
assert(df.rdd.getNumPartitions === 5)

val features = Array("c1", "c2")
var classifier = new XGBoostClassifier()
.setNumWorkers(3)
.setFeaturesCol(features)
assert(classifier.getPlugin.isDefined)
assert(classifier.getPlugin.get.isInstanceOf[GpuXGBoostPlugin])
var out = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
.preprocess(classifier, df)

assert(out.schema.names.contains("c1") && out.schema.names.contains("c2"))
assert(out.schema.names.contains(classifier.getLabelCol))
assert(!out.schema.names.contains("weight") && !out.schema.names.contains("margin"))
assert(out.rdd.getNumPartitions === 3)

classifier = new XGBoostClassifier()
.setNumWorkers(4)
.setFeaturesCol(features)
.setWeightCol("weight")
.setBaseMarginCol("margin")
.setDevice("cuda")
out = classifier.getPlugin.get.asInstanceOf[GpuXGBoostPlugin]
.preprocess(classifier, df)

assert(out.schema.names.contains("c1") && out.schema.names.contains("c2"))
assert(out.schema.names.contains(classifier.getLabelCol))
assert(out.schema.names.contains("weight") && out.schema.names.contains("margin"))
assert(out.rdd.getNumPartitions === 4)
}
}

// TODO: figure out why rowNum is 5 and nonMissing is 9
test("build RDD Watches") {
withGpuSparkSession() { spark =>
import spark.implicits._

// dataPoint -> (missing, rowNum, nonMissing)
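// Note: with missing = 0.0f, zero-valued feature cells count as missing,
// which is why the expected (rowNum, nonMissing) pairs differ between the
// two cases below; the exact numbers are what the current GPU build
// produces (see the TODO above).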
Map(0.0f -> (0.0f, 4, 8), Float.NaN -> (0.0f, 5, 10)).foreach {
case (data, (missing, expectedRowNum, expectedNonMissing)) =>
val df = Seq(
(1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
(3.0f, data, 5.0f, 6.0f, 0.0f, 0.1f),
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
(5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 0.1f),
).toDF("c1", "c2", "weight", "margin", "label", "other")

val features = Array("c1", "c2")
val classifier = new XGBoostClassifier()
.setNumWorkers(2)
.setWeightCol("weight")
.setBaseMarginCol("margin")
.setFeaturesCol(features)
.setDevice("cuda")
.setMissing(missing)

val rdd = classifier.getPlugin.get.buildRddWatches(classifier, df)
val result = rdd.mapPartitions { iter =>
val watches = iter.next()
val size = watches.size
val labels = watches.datasets(0).getLabel
val weight = watches.datasets(0).getWeight
val margins = watches.datasets(0).getBaseMargin
val rowNumber = watches.datasets(0).rowNum
val nonMissing = watches.datasets(0).nonMissingNum
Iterator.single((size, rowNumber, nonMissing, labels, weight, margins))
}.collect()

val labels: ArrayBuffer[Float] = ArrayBuffer.empty
val weight: ArrayBuffer[Float] = ArrayBuffer.empty
val margins: ArrayBuffer[Float] = ArrayBuffer.empty
val rowNumber: ArrayBuffer[Long] = ArrayBuffer.empty
val nonMissing: ArrayBuffer[Long] = ArrayBuffer.empty

for (row <- result) {
assert(row._1 === 1)
rowNumber.append(row._2)
nonMissing.append(row._3)
labels.append(row._4: _*)
weight.append(row._5: _*)
margins.append(row._6: _*)
}
assert(labels.sorted === Array(0.0f, 1.0f, 0.0f, 0.0f, 1.0f).sorted)
assert(weight.sorted === Array(1.0f, 2.0f, 5.0f, 6.0f, 7.0f).sorted)
assert(margins.sorted === Array(2.0f, 3.0f, 6.0f, 7.0f, 8.0f).sorted)
// assert(rowNumber.sum === expectedRowNum)
assert(nonMissing.sum === expectedNonMissing)
}
}
}

// TODO: figure out why rowNum is 5 and nonMissing is 9
test("build RDD Watches with Eval") {
withGpuSparkSession() { spark =>
import spark.implicits._

val train = Seq(
(1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
).toDF("c1", "c2", "weight", "margin", "label", "other")

// dataPoint -> (missing, rowNum, nonMissing)
Map(0.0f -> (0.0f, 4, 8), Float.NaN -> (0.0f, 5, 10)).foreach {
case (data, (missing, expectedRowNum, expectedNonMissing)) =>
val eval = Seq(
(1.0f, 2.0f, 1.0f, 2.0f, 0.0f, 0.0f),
(2.0f, 3.0f, 2.0f, 3.0f, 1.0f, 0.1f),
(3.0f, data, 5.0f, 6.0f, 0.0f, 0.1f),
(4.0f, 5.0f, 6.0f, 7.0f, 0.0f, 0.1f),
(5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 0.1f),
).toDF("c1", "c2", "weight", "margin", "label", "other")

val features = Array("c1", "c2")
val classifier = new XGBoostClassifier()
.setNumWorkers(2)
.setWeightCol("weight")
.setBaseMarginCol("margin")
.setFeaturesCol(features)
.setDevice("cuda")
.setMissing(missing)
.setEvalDataset(eval)

val rdd = classifier.getPlugin.get.buildRddWatches(classifier, train)
val result = rdd.mapPartitions { iter =>
val watches = iter.next()
val size = watches.size
val labels = watches.datasets(1).getLabel
val weight = watches.datasets(1).getWeight
val margins = watches.datasets(1).getBaseMargin
val rowNumber = watches.datasets(1).rowNum
val nonMissing = watches.datasets(1).nonMissingNum
Iterator.single((size, rowNumber, nonMissing, labels, weight, margins))
}.collect()

val labels: ArrayBuffer[Float] = ArrayBuffer.empty
val weight: ArrayBuffer[Float] = ArrayBuffer.empty
val margins: ArrayBuffer[Float] = ArrayBuffer.empty
val rowNumber: ArrayBuffer[Long] = ArrayBuffer.empty
val nonMissing: ArrayBuffer[Long] = ArrayBuffer.empty

for (row <- result) {
assert(row._1 === 2)
rowNumber.append(row._2)
nonMissing.append(row._3)
labels.append(row._4: _*)
weight.append(row._5: _*)
margins.append(row._6: _*)
}
assert(labels.sorted === Array(0.0f, 1.0f, 0.0f, 0.0f, 1.0f).sorted)
assert(weight.sorted === Array(1.0f, 2.0f, 5.0f, 6.0f, 7.0f).sorted)
assert(margins.sorted === Array(2.0f, 3.0f, 6.0f, 7.0f, 8.0f).sorted)
// assert(rowNumber.sum === expectedRowNum)
assert(nonMissing.sum === expectedNonMissing)
}
}
}
}
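For orientation, a minimal sketch of the Watches layout the two tests above rely on (API names taken from the tests; illustrative, not the plugin's implementation):

// Each partition yields one Watches: the training matrix sits at index 0
// and, when setEvalDataset is used, the eval matrix at index 1; hence the
// size === 1 vs. size === 2 assertions above.
val rdd = classifier.getPlugin.get.buildRddWatches(classifier, train)
val counts = rdd.mapPartitions { iter =>
  val watches = iter.next()
  val trainRows = watches.datasets(0).rowNum
  val evalRows = if (watches.size > 1) Some(watches.datasets(1).rowNum) else None
  Iterator.single((trainRows, evalRows))
}.collect()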
@@ -27,6 +27,7 @@ import org.apache.spark.sql.functions.{col, udf}

import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.params.ClassificationParams
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.{binaryClassificationObjs, multiClassificationObjs}


class XGBoostClassifier(override val uid: String,
@@ -42,41 +43,59 @@ class XGBoostClassifier(override val uid: String,

xgboost2SparkParams(xgboostParams)

/**
* Validate the parameters before training, throwing an exception if they are invalid
*/
override protected def validate(dataset: Dataset[_]): Unit = {
super.validate(dataset)

// The default objective is for the regression case.
private def validateObjective(dataset: Dataset[_]): Unit = {
// If the objective is set explicitly, it must be one of binaryClassificationObjs
// or multiClassificationObjs.
val obj = if (isSet(objective)) {
Some(getObjective)
val tmpObj = getObjective
val supportedObjs = binaryClassificationObjs.toSeq ++ multiClassificationObjs.toSeq
require(supportedObjs.contains(tmpObj),
s"Wrong objective for XGBoostClassifier, supported objs: ${supportedObjs.mkString(",")}")
Some(tmpObj)
} else {
None
}

var numClasses = getNumClass
// If the user didn't set it, infer it.
if (numClasses == 0) {
numClasses = SparkUtils.getNumClasses(dataset, getLabelCol)
def inferNumClasses: Int = {
var numClasses = getNumClass
// Infer the number of classes when num_class is not set explicitly.
// Note that when the user sets num_class explicitly, we don't re-validate it.
if (numClasses == 0) {
numClasses = SparkUtils.getNumClasses(dataset, getLabelCol)
}
require(numClasses > 0)
numClasses
}
assert(numClasses > 0)

if (numClasses <= 2) {
if (!obj.exists(_.startsWith("binary:"))) {
logger.warn(s"Inferred for binary classification, but found wrong objective: " +
s"${getObjective}, rewrite objective to binary:logistic")
setObjective("binary:logistic")
// The objective was set explicitly.
if (obj.isDefined) {
if (multiClassificationObjs.contains(getObjective)) {
setNumClass(inferNumClasses)
} else {
// Binary classification doesn't require num_class to be set
require(!isSet(numClass), "num_class is not allowed for binary classification")
}
} else {
if (!obj.exists(_.startsWith("multi:"))) {
logger.warn(s"Inferred for multiclass classification, but found wrong objective: " +
s"${getObjective}, rewrite objective to multi:softprob")
// Infer the objective from the inferred number of classes
val numClasses = inferNumClasses
if (numClasses <= 2) {
setObjective("binary:logistic")
logger.warn("Inferred for binary classification, set the objective to binary:logistic")
require(!isSet(numClass), "num_class is not allowed for binary classification")
} else {
logger.warn("Inferred for multi classification, set the objective to multi:softprob")
setObjective("multi:softprob")
setNumClass(numClasses)
}
setNumClass(numClasses)
}
}

/**
* Validate the parameters before training, throwing an exception if they are invalid
*/
override protected[spark] def validate(dataset: Dataset[_]): Unit = {
super.validate(dataset)
validateObjective(dataset)
}

override protected def createModel(booster: Booster, summary: XGBoostTrainingSummary):
@@ -130,6 +149,7 @@ class XGBoostClassificationModel(

override def postTransform(dataset: Dataset[_]): Dataset[_] = {
var output = dataset
// Always use probability col to get the prediction
if (isDefined(predictionCol) && getPredictionCol.nonEmpty) {
val predCol = udf { probability: mutable.WrappedArray[Float] =>
probability2prediction(Vectors.dense(probability.map(_.toDouble).toArray))
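Net effect of the reworked objective handling, as a hedged usage sketch (setter names from the diff; exact log and error messages may differ):

// Explicit binary objective: num_class must not be set.
val binary = new XGBoostClassifier().setObjective("binary:logistic")
// binary.setNumClass(2)  // would fail validate(): num_class is rejected for binary objectives

// Explicit multiclass objective: num_class is inferred from the label column
// during validate(dataset) when the user has not set it.
val multi = new XGBoostClassifier().setObjective("multi:softprob")

// No objective set: validate(dataset) infers the number of classes and picks
// binary:logistic or multi:softprob, logging a warning about the rewrite.
val inferred = new XGBoostClassifier()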
@@ -89,6 +89,9 @@ private[spark] abstract class XGBoostEstimator[
}
}

/** Visible for testing */
private[spark] def getPlugin: Option[XGBoostPlugin] = plugin

private def isPluginEnabled(dataset: Dataset[_]): Boolean = {
plugin.map(_.isEnabled(dataset)).getOrElse(false)
}
@@ -339,7 +342,7 @@
/**
* Validate the parameters before training, throwing an exception if they are invalid
*/
protected def validate(dataset: Dataset[_]): Unit = {
protected[spark] def validate(dataset: Dataset[_]): Unit = {
validateSparkSslConf(dataset.sparkSession)
val schema = dataset.schema
SparkUtils.checkNumericType(schema, $(labelCol))
@@ -127,6 +127,9 @@ private[spark] trait LearningTaskParams extends Params {

private[spark] object LearningTaskParams {

val binaryClassificationObjs = HashSet("binary:logistic", "binary:hinge", "binary:logitraw")
val multiClassificationObjs = HashSet("multi:softmax", "multi:softprob")

val supportedObjectives = HashSet("reg:squarederror", "reg:squaredlogerror", "reg:logistic",
"reg:pseudohubererror", "reg:absoluteerror", "reg:quantileerror", "binary:logistic",
"binary:logitraw", "binary:hinge", "count:poisson", "survival:cox", "survival:aft",
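Both new sets should be subsets of the existing supportedObjectives whitelist; a quick sanity check, assuming the multi:* entries appear in the part of that list cut off by the diff:

import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams._

// binary:* entries are visible in the hunk above; multi:* entries are assumed
// to be in the truncated tail of supportedObjectives.
assert((binaryClassificationObjs ++ multiClassificationObjs).subsetOf(supportedObjectives))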
@@ -45,6 +45,7 @@ trait PerTest extends BeforeAndAfterEach {
.config("spark.driver.memory", "512m")
.config("spark.barrier.sync.timeout", 10)
.config("spark.task.cpus", 1)
.config("spark.stage.maxConsecutiveAttempts", 1)

override def beforeEach(): Unit = getOrCreateSession
