diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala index 1b46d2f050bb..26a68f085fbb 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala @@ -88,11 +88,9 @@ object SparkMLlibPipeline { "max_depth" -> 2, "objective" -> "multi:softprob", "num_class" -> 3, - "num_round" -> 100, - "num_workers" -> numWorkers, "device" -> device ) - ) + ).setNumRound(10).setNumWorkers(numWorkers) booster.setFeaturesCol("features") booster.setLabelCol("classIndex") val labelConverter = new IndexToString() diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index 8653e9d9d47f..408a10011f92 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -108,6 +108,7 @@ class XGBoostClassifier(override val uid: String, XGBoostClassificationModel = { new XGBoostClassificationModel(uid, numberClasses, booster, Some(summary)) } + } object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] { diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala index a17856d41246..7f0d26370c86 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala @@ -377,10 +377,8 @@ private[spark] trait XGBoostEstimator[ } else { setNthread(taskCpus) } - } - def train(dataset: Dataset[_]): M = { validate(dataset) @@ -402,11 +400,6 @@ private[spark] trait XGBoostEstimator[ } override def copy(extra: ParamMap): Learner = defaultCopy(extra).asInstanceOf[Learner] - - // Not used in XGBoost - override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema, true) - } } /** Indicate what to be predicted */ @@ -440,11 +433,6 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML def summary: Option[XGBoostTrainingSummary] - // Not used in XGBoost - override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema, false) - } - protected[spark] def postTransform(dataset: Dataset[_], pred: PredictedColumns): Dataset[_] = { var output = dataset // Convert leaf/contrib to the vector from array diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala index 7ef57e79e1ce..8345cab35149 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala @@ -20,8 +20,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.xgboost.SparkUtils -import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.types.StructType import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait} @@ -231,8 +230,6 @@ private[spark] trait SchemaValidationTrait { fitting: Boolean): StructType = schema } - - /** * XGBoost ranking spark-specific parameters *