Skip to content

Commit

Permalink
udpate
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 22, 2024
1 parent bce89be commit 6a428f2
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.functions.{col, udf}

import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.params.ClassificationParams
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.{binaryClassificationObjs, multiClassificationObjs}


class XGBoostClassifier(override val uid: String,
Expand All @@ -45,37 +46,47 @@ class XGBoostClassifier(override val uid: String,
/**
* Validate the parameters before training, throw exception if possible
*/
override protected def validate(dataset: Dataset[_]): Unit = {
override protected[spark] def validate(dataset: Dataset[_]): Unit = {
super.validate(dataset)

// The default objective is for regression case.
val obj = if (isSet(objective)) {
Some(getObjective)
val tmpObj = getObjective
val supportedObjs = (binaryClassificationObjs.toSeq ++ multiClassificationObjs.toSeq)
require(supportedObjs.contains(tmpObj),
s"Wrong objective for XGBoostClassifier, supported objs: ${supportedObjs.mkString(",")}")
Some(tmpObj)
} else {
None
}

var numClasses = getNumClass
// If user didn't set it, inferred it.
// For binary classification
if (obj.isDefined && obj.exists(binaryClassificationObjs.contains)) {
numClasses = 2
}

// Infer num class if possible,
// Note that user sets the num classes explicitly, we're not checking that.
if (numClasses == 0) {
numClasses = SparkUtils.getNumClasses(dataset, getLabelCol)
}
assert(numClasses > 0)

if (numClasses <= 2) {
if (!obj.exists(_.startsWith("binary:"))) {
if (!obj.exists(binaryClassificationObjs.contains)) {
logger.warn(s"Inferred for binary classification, but found wrong objective: " +
s"${getObjective}, rewrite objective to binary:logistic")
setObjective("binary:logistic")
}
} else {
if (!obj.exists(_.startsWith("multi:"))) {
if (!obj.exists(multiClassificationObjs.contains)) {
logger.warn(s"Inferred for multiclass classification, but found wrong objective: " +
s"${getObjective}, rewrite objective to multi:softprob")
setObjective("multi:softprob")
}
setNumClass(numClasses)
}
setNumClass(numClasses)

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ private[spark] abstract class XGBoostEstimator[
/**
* Validate the parameters before training, throw exception if possible
*/
protected def validate(dataset: Dataset[_]): Unit = {
protected[spark] def validate(dataset: Dataset[_]): Unit = {
validateSparkSslConf(dataset.sparkSession)
val schema = dataset.schema
SparkUtils.checkNumericType(schema, $(labelCol))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ private[spark] trait LearningTaskParams extends Params {

private[spark] object LearningTaskParams {

val binaryClassificationObjs = HashSet("binary:logistic", "binary:hinge")
val multiClassificationObjs = HashSet("multi:softmax", "multi:softprob")

val supportedObjectives = HashSet("reg:squarederror", "reg:squaredlogerror", "reg:logistic",
"reg:pseudohubererror", "reg:absoluteerror", "reg:quantileerror", "binary:logistic",
"binary:logitraw", "binary:hinge", "count:poisson", "survival:cox", "survival:aft",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.param.ParamMap
import org.scalatest.funsuite.AnyFunSuite

import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.{binaryClassificationObjs, multiClassificationObjs}
import ml.dmlc.xgboost4j.scala.spark.params.XGBoostParams

class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
Expand Down Expand Up @@ -97,6 +98,86 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
check(modelLoaded)
}

test("XGBoostClassificationModel transformed schema") {
val trainDf = smallBinaryClassificationVector
val classifier = new XGBoostClassifier().setRabitTrackerHostIp("127.0.0.1")
val model = classifier.fit(trainDf)
val out = model.transform(trainDf)

out.printSchema()
// assert(out.schema.names.contains(""))
}

test("Supported objectives") {
val classifier = new XGBoostClassifier()
val df = smallMultiClassificationVector
(binaryClassificationObjs.toSeq ++ multiClassificationObjs.toSeq).foreach { obj =>
classifier.setObjective(obj)
classifier.validate(df)
}

classifier.setObjective("reg:squaredlogerror")
intercept[IllegalArgumentException](
classifier.validate(df)
)
}

test("Binaryclassification infer objective and num_class") {
val trainDf = smallBinaryClassificationVector
var classifier = new XGBoostClassifier()
assert(classifier.getObjective === "reg:squarederror")
assert(classifier.getNumClass === 0)
classifier.validate(trainDf)
assert(classifier.getObjective === "binary:logistic")
assert(classifier.getNumClass === 2)

// Infer objective according num class
classifier = new XGBoostClassifier()
classifier.setNumClass(2)
classifier.validate(trainDf)
assert(classifier.getObjective === "binary:logistic")
assert(classifier.getNumClass === 2)

// Infer to num class according to num class
classifier = new XGBoostClassifier()
classifier.setObjective("binary:logistic")
classifier.validate(trainDf)
assert(classifier.getObjective === "binary:logistic")
assert(classifier.getNumClass === 2)
}

test("Classification infer objective and num_class") {
val trainDf = smallMultiClassificationVector
var classifier = new XGBoostClassifier()
assert(classifier.getObjective === "reg:squarederror")
assert(classifier.getNumClass === 0)
classifier.validate(trainDf)
assert(classifier.getObjective === "multi:softprob")
assert(classifier.getNumClass === 3)

// Infer to objective according to num class
classifier = new XGBoostClassifier()
classifier.setNumClass(3)
classifier.validate(trainDf)
assert(classifier.getObjective === "multi:softprob")
assert(classifier.getNumClass === 3)

// Infer to num class according to objective
classifier = new XGBoostClassifier()
classifier.setObjective("multi:softmax")
classifier.validate(trainDf)
assert(classifier.getObjective === "multi:softmax")
assert(classifier.getNumClass === 3)
}

test("Binary classification") {

}

test("Multiclass classification") {

}


test("pipeline") {
val spark = ss
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.scala.DMatrix

class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
test("nthread") {
val classifier = new XGBoostClassifier().setNthread(100)

intercept[IllegalArgumentException](
classifier.validate(smallBinaryClassificationVector)
)
}

test("RuntimeParameter") {
var runtimeParams = new XGBoostClassifier(
Expand Down

0 comments on commit 6a428f2

Please sign in to comment.