
Commit 5f3de01: "update"
wbo4958 committed Jun 8, 2024 (1 parent: 06f5204)
Showing 6 changed files with 55 additions and 58 deletions.

@@ -17,12 +17,12 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import ml.dmlc.xgboost4j.scala.Booster
-import ml.dmlc.xgboost4j.scala.spark.params.ClassifierParams
+import ml.dmlc.xgboost4j.scala.spark.params.ClassificationParams
 import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}
 
 class NewXGBoostClassifier(override val uid: String)
   extends XGBoostEstimator[NewXGBoostClassifier, NewXGBoostClassificationModel]
-  with ClassifierParams with DefaultParamsWritable {
+  with ClassificationParams[NewXGBoostClassifier] with DefaultParamsWritable {
 
   def this() = this(Identifiable.randomUID("xgbc"))
 
@@ -39,7 +39,7 @@ class NewXGBoostClassificationModel(
     trainingSummary: XGBoostTrainingSummary
   )
   extends XGBoostModel[NewXGBoostClassificationModel](uid, booster, trainingSummary)
-  with ClassifierParams {
+  with ClassificationParams[NewXGBoostClassificationModel] {
 
 
 }

@@ -43,7 +43,7 @@ private case class ColumnIndexes(label: String, features: String,
 private[spark] abstract class XGBoostEstimator[
   Learner <: XGBoostEstimator[Learner, M],
   M <: XGBoostModel[M]
-] extends Estimator[M] with XGBoostParams with SparkParams {
+] extends Estimator[M] with XGBoostParams[Learner] with SparkParams[Learner] {
 
   /**
    * Pre-convert input double data to floats to align with XGBoost's internal float-based

@@ -163,16 +163,15 @@ private[spark] abstract class XGBoostEstimator[
     val rdd = toRdd(input, columnIndexes)
 
     val paramMap = Map(
-      "num_rounds" -> 10,
       "num_workers" -> 1,
-      "num_round" -> 1
+      "num_round" -> 100
     )
 
     val (booster, metrics) = NewXGBoost.train(
       dataset.sparkSession.sparkContext, rdd, paramMap)
 
     val summary = XGBoostTrainingSummary(metrics)
-    createModel(booster, summary)
+    copyValues(createModel(booster, summary))
   }
 
   override def copy(extra: ParamMap): Learner = defaultCopy(extra)
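
Note on the copyValues change above: Params.copyValues in Spark ML copies every param value that is set (or has a default) on one Params instance onto another instance that declares params with the same names. Wrapping createModel in copyValues is what makes settings applied to the estimator (label column, prediction columns, and so on) survive fit() and reappear on the returned model. A minimal, self-contained sketch of the mechanism; the Source/Target names are illustrative, not part of this commit:

    import org.apache.spark.ml.param.{Param, ParamMap, Params}
    import org.apache.spark.ml.util.Identifiable

    // A tiny "estimator-like" Params class; makeTarget() plays the role of
    // fit(): it builds a fresh Target and copies this instance's param
    // values onto it.
    class Source(override val uid: String) extends Params {
      def this() = this(Identifiable.randomUID("src"))

      final val labelCol = new Param[String](this, "labelCol", "label column name")
      setDefault(labelCol, "label")

      def setLabelCol(value: String): this.type = set(labelCol, value)

      // copyValues is protected in Params, so it has to be called from inside.
      def makeTarget(): Target = copyValues(new Target(Identifiable.randomUID("tgt")))

      override def copy(extra: ParamMap): Source = defaultCopy(extra)
    }

    // A "model-like" Params class declaring a param with the same name.
    class Target(override val uid: String) extends Params {
      final val labelCol = new Param[String](this, "labelCol", "label column name")
      setDefault(labelCol, "label")

      override def copy(extra: ParamMap): Target = defaultCopy(extra)
    }

    // Without copyValues the target keeps its default "label"; with it, the
    // explicitly set value carries over:
    //   val t = new Source().setLabelCol("class").makeTarget()
    //   t.getOrDefault(t.labelCol)   // "class"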

@@ -196,10 +195,10 @@ private[spark] abstract class XGBoostEstimator[
 private[spark] abstract class XGBoostModel[M <: XGBoostModel[M]](
     override val uid: String,
     protected val booster: Booster,
-    protected val trainingSummary: XGBoostTrainingSummary) extends Model[M]
-  with XGBoostParams with SparkParams {
+    protected val trainingSummary: XGBoostTrainingSummary)
+  extends Model[M] with XGBoostParams[M] with SparkParams[M] {
 
-  override def copy(extra: ParamMap): M = defaultCopy(extra)
+  override def copy(extra: ParamMap): M = defaultCopy(extra).asInstanceOf[M]
 
   def nativeBooster: Booster = booster
 
@@ -229,7 +228,6 @@ private[spark] abstract class XGBoostModel[M <: XGBoostModel[M]](
     var schema = dataset.schema
 
     var hasRawPredictionCol = false
-
     this match {
       case p: HasRawPredictionCol =>
         if (isDefined(p.rawPredictionCol) && p.getRawPredictionCol.nonEmpty) {

@@ -229,6 +229,8 @@ trait HasBaseMarginCol extends Params {
 
   /** @group getParam */
   final def getBaseMarginCol: String = $(baseMarginCol)
+
+  def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)
 }
 
 trait HasGroupCol extends Params {
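
For context on the new setter: a base margin is a per-row initial prediction that XGBoost adds to the row's raw score before the link function is applied, which lets each row start from its own prior rather than a global constant. A small usage sketch against this commit's API (the column names are illustrative):

    // Each row's "base_margin" value seeds that row's raw prediction.
    val estimator = new NewXGBoostClassifier()
      .setLabelCol("label")
      .setBaseMarginCol("base_margin")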

@@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructType
 
 private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
   with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables with HasWeightCol
-  with NewHasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol
+  with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol
   with HasFeaturesCol
   with HasLabelCol with HasFeaturesCols with HasHandleInvalid {

@@ -19,54 +19,47 @@ package ml.dmlc.xgboost4j.scala.spark.params
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.param.{IntParam, Param, ParamValidators, Params}
 
-
-trait NewHasBaseMarginCol extends Params {
-  final val baseMarginCol: Param[String] = new Param[String](this, "baseMarginCol",
-    "Initial prediction (aka base margin) column name.")
-
-  /** @group getParam */
-  final def getBaseMarginCol: String = $(baseMarginCol)
-
-  def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)
-}
-
-trait NewHasGroupCol extends Params {
-
-  final val groupCol: Param[String] = new Param[String](this, "groupCol",
-    "group column name for ranker.")
-
-  final def getGroupCol: String = $(groupCol)
-
-  def setGroupCol(value: String): this.type = set(groupCol, value)
-
-}
-
-private[spark] trait SparkParams extends Params
-  with HasFeaturesCol with HasLabelCol with NewHasBaseMarginCol
+private[spark] trait SparkParams[T <: Params] extends Params
+  with HasFeaturesCol with HasLabelCol with HasBaseMarginCol
   with HasWeightCol with HasPredictionCol with HasLeafPredictionCol with HasContribPredictionCol {
 
-  final val numWorkers = new IntParam(this, "numWorkers", "number of workers used to run xgboost",
+  final val numWorkers = new IntParam(this, "numWorkers", "Number of workers used to run xgboost",
     ParamValidators.gtEq(1))
   setDefault(numWorkers, 1)
 
   final def getNumWorkers: Int = $(numWorkers)
 
-  def setNumWorkers(value: Int): this.type = set(numWorkers, value)
+  def setNumWorkers(value: Int): T = set(numWorkers, value).asInstanceOf[T]
+
+  def setFeaturesCol(value: String): T = set(featuresCol, value).asInstanceOf[T]
 
-  def setLabelCol(value: String): this.type = set(labelCol, value)
+  def setLabelCol(value: String): T = set(labelCol, value).asInstanceOf[T]
 
-  def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)
+  def setLeafPredictionCol(value: String): T = set(leafPredictionCol, value).asInstanceOf[T]
 
-  def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)
+  def setContribPredictionCol(value: String): T = set(contribPredictionCol, value).asInstanceOf[T]
 }
 
-private[spark] trait ClassifierParams extends HasRawPredictionCol with HasProbabilityCol {
-  def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
+private[spark] trait ClassificationParams[T <: Params] extends HasRawPredictionCol
+  with HasProbabilityCol {
+
+  def setRawPredictionCol(value: String): T = set(rawPredictionCol, value).asInstanceOf[T]
 
-  def setProbabilityCol(value: String): this.type = set(probabilityCol, value)
+  def setProbabilityCol(value: String): T = set(probabilityCol, value).asInstanceOf[T]
 }
 
+private[spark] trait RankerParams[T <: Params] extends HasGroupCol {
+  def setGroupCol(value: String): T = set(groupCol, value).asInstanceOf[T]
+}
+
-private[spark] trait XGBoostParams extends Params {
+private[spark] trait XGBoostParams[T <: Params] extends Params {
 
 }
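
A note on the recurring this.type-to-T change in this file: parameterizing the shared traits with the concrete class (SparkParams[T], ClassificationParams[T], RankerParams[T]) lets a setter written once in the trait return the concrete estimator or model type. this.type chains just as well in Scala, but it erases to the declaring trait in bytecode, so Java callers would lose the concrete type mid-chain; returning T with an asInstanceOf cast is a common way to keep a fluent API usable from Java. A minimal sketch of the pattern; MyParams/MyEstimator are hypothetical names, not part of this commit:

    import org.apache.spark.ml.param.{IntParam, ParamMap, Params, ParamValidators}
    import org.apache.spark.ml.util.Identifiable

    // The shared trait takes the concrete class as a type parameter, so a
    // setter written once in the trait can return that concrete type.
    trait MyParams[T <: Params] extends Params {
      final val numWorkers = new IntParam(this, "numWorkers",
        "Number of workers used to run xgboost", ParamValidators.gtEq(1))
      setDefault(numWorkers, 1)

      def setNumWorkers(value: Int): T = set(numWorkers, value).asInstanceOf[T]
    }

    // The concrete class passes itself as T, F-bounded style.
    class MyEstimator(override val uid: String) extends MyParams[MyEstimator] {
      def this() = this(Identifiable.randomUID("myEst"))
      override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
    }

    // Chained setters keep the concrete type, from Scala and from Java alike:
    //   val est = new MyEstimator().setNumWorkers(4)   // est: MyEstimator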

@@ -1,50 +1,54 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import org.apache.spark.ml.feature.VectorAssembler
-import org.apache.spark.sql.functions.{array, col}
-import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+import org.apache.spark.sql.functions.{array, col, lit, rand}
 import org.scalatest.funsuite.AnyFunSuite
 
 class NewXGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
 
   test("test NewXGBoostClassifierSuite") {
-    // Define the schema for the fake data
-    val features = Array("feature1", "feature2", "feature3", "feature4")
-
     val spark = ss
     import spark.implicits._
-    val df = Seq(
-      (1.0, 0.0, 0.0, 0.0, 0.0, 30),
-      (2.0, 3.0, 4.0, 4.0, 0.0, 31),
-      (3.0, 4.0, 5.0, 5.0, 1.0, 32),
-      (4.0, 5.0, 6.0, 6.0, 1.0, 33),
-    ).toDF("feature1", "feature2", "feature3", "feature4", "label", "base_margin")
+    // val features = Array("feature1", "feature2", "feature3", "feature4")
+
+    // val df = Seq(
+    //   (1.0, 0.0, 0.0, 0.0, 0.0, 30),
+    //   (2.0, 3.0, 4.0, 4.0, 0.0, 31),
+    //   (3.0, 4.0, 5.0, 5.0, 1.0, 32),
+    //   (4.0, 5.0, 6.0, 6.0, 1.0, 33),
+    // ).toDF("feature1", "feature2", "feature3", "feature4", "label", "base_margin")
+
+    var df = spark.read.parquet("/home/bobwang/data/iris/parquet")
 
     // Select the features and label columns
-    val labelCol = "label"
+    val labelCol = "class"
+
+    val features = df.schema.names.filter(_ != labelCol)
+
+    df = df.withColumn("base_margin", lit(20))
+      .withColumn("weight", rand(1))
 
     // Assemble the feature columns into a single vector column
     val assembler = new VectorAssembler()
      .setInputCols(features)
      .setOutputCol("features")
     val dataset = assembler.transform(df)
 
-    val arrayInput = df.select(array(features.map(col(_)): _*).as("features"),
-      col("label"), col("base_margin"))
+    // val arrayInput = df.select(array(features.map(col(_)): _*).as("features"),
+    //   col("label"), col("base_margin"))
 
     val est = new NewXGBoostClassifier()
       .setNumWorkers(1)
       .setLabelCol(labelCol)
       .setBaseMarginCol("base_margin")
       .setRawPredictionCol("raw")
-      .setProbabilityCol("")
+      .setProbabilityCol("prob")
       .setContribPredictionCol("contrb")
       .setLeafPredictionCol("leaf")
+    // val est = new XGBoostClassifier().setLabelCol(labelCol)
 
     // est.fit(arrayInput)
     val model = est.fit(dataset)
+    // model.setProbabilityCol("")
     model.setLeafPredictionCol("leaf")
     model.setContribPredictionCol("conb")
     model.transform(dataset).show()
   }
