Skip to content

Commit

Permalink
udpate
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 26, 2024
1 parent 499ca78 commit 468b812
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,9 @@ object SparkMLlibPipeline {
"max_depth" -> 2,
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> numWorkers,
"device" -> device
)
)
).setNumRound(10).setNumWorkers(numWorkers)
booster.setFeaturesCol("features")
booster.setLabelCol("classIndex")
val labelConverter = new IndexToString()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class XGBoostClassifier(override val uid: String,
XGBoostClassificationModel = {
new XGBoostClassificationModel(uid, numberClasses, booster, Some(summary))
}

}

object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,8 @@ private[spark] trait XGBoostEstimator[
} else {
setNthread(taskCpus)
}

}


def train(dataset: Dataset[_]): M = {
validate(dataset)

Expand All @@ -402,11 +400,6 @@ private[spark] trait XGBoostEstimator[
}

override def copy(extra: ParamMap): Learner = defaultCopy(extra).asInstanceOf[Learner]

// Not used in XGBoost
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, true)
}
}

/** Indicate what to be predicted */
Expand Down Expand Up @@ -440,11 +433,6 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML

def summary: Option[XGBoostTrainingSummary]

// Not used in XGBoost
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, false)
}

protected[spark] def postTransform(dataset: Dataset[_], pred: PredictedColumns): Dataset[_] = {
var output = dataset
// Convert leaf/contrib to the vector from array
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.xgboost.SparkUtils
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.sql.types.StructType

import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}

Expand Down Expand Up @@ -231,8 +230,6 @@ private[spark] trait SchemaValidationTrait {
fitting: Boolean): StructType = schema
}



/**
* XGBoost ranking spark-specific parameters
*
Expand Down

0 comments on commit 468b812

Please sign in to comment.