diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala
index 548f8b5adda4..215b7408249d 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala
@@ -144,7 +144,7 @@ object PreXGBoost extends PreXGBoostProvider {
         })
 
       (PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group,
-        est.getNumWorkers, est.needDeterministicRepartitioning), evalSets, xgbInput)
+        est.getNumWorkers), evalSets, xgbInput)
 
     case _ => throw new RuntimeException("Unsupporting " + estimator)
   }
@@ -379,7 +379,7 @@ object PreXGBoost extends PreXGBoostProvider {
             xgbExecutionParam.allowNonZeroForMissing),
           getCacheDirName(xgbExecutionParam.useExternalMemory))
         Iterator.single(buildWatches)
-      })
+      }).cache()
     } else {
       coPartitionGroupSets(trainingData, evalSetsMap, xgbExecutionParam.numWorkers).mapPartitions(
         labeledPointGroupSets => {
@@ -390,7 +390,7 @@ object PreXGBoost extends PreXGBoostProvider {
             },
           getCacheDirName(xgbExecutionParam.useExternalMemory))
         Iterator.single(buildWatches)
-      })
+      }).cache()
     }
   }

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala
index 9581ea0f2c59..ba559499df39 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala
@@ -28,11 +28,6 @@ private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with Le
   with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol
   with HasLabelCol with HasFeaturesCols with HasHandleInvalid {
 
-  def needDeterministicRepartitioning: Boolean = {
-    isDefined(checkpointPath) && getCheckpointPath != null && getCheckpointPath.nonEmpty &&
-      isDefined(checkpointInterval) && getCheckpointInterval > 0
-  }
-
   /**
    * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
    * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
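The `.cache()` calls added to PreXGBoost.scala keep the per-partition `Watches` RDD materialized, so later Spark actions reuse it instead of re-running the per-partition build. Below is a minimal, self-contained sketch of that pattern; the names (`CacheWatchesSketch`, the summing stand-in for the Watches construction) are hypothetical and not part of xgboost4j:

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch: cache an RDD whose partitions each hold one expensive
// object, so repeated actions do not rebuild it. Names are hypothetical.
object CacheWatchesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("cache-sketch").getOrCreate()
    val sc = spark.sparkContext

    val data = sc.parallelize(1 to 1000, numSlices = 2)

    // Analogous to Iterator.single(buildWatches): one heavyweight value
    // per partition.
    val perPartition = data.mapPartitions { it =>
      val expensive = it.sum // stand-in for building caches from the partition
      Iterator.single(expensive)
    }.cache() // without cache(), each action below re-runs the expensive build

    perPartition.count()   // first action materializes and caches the partitions
    perPartition.collect() // second action reads from the cache

    spark.stop()
  }
}
```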
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/DataUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/DataUtils.scala
index acc605b1f0a5..98217e4afcbd 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/DataUtils.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/DataUtils.scala
@@ -72,33 +72,25 @@ object DataUtils extends Serializable {
 
   private def attachPartitionKey(
       row: Row,
-      deterministicPartition: Boolean,
       numWorkers: Int,
-      xgbLp: XGBLabeledPoint): (Int, XGBLabeledPoint) = {
-    if (deterministicPartition) {
-      (math.abs(row.hashCode() % numWorkers), xgbLp)
+      xgbLp: XGBLabeledPoint,
+      group: Option[Int]): (Int, XGBLabeledPoint) = {
+    // If a group exists, we must use the group as the key to make sure all
+    // instances of a group end up in the same partition.
+    if (group.isDefined) {
+      (group.get % numWorkers, xgbLp)
+    // If no group exists, we can use the row hash as the repartition key.
     } else {
-      (1, xgbLp)
+      (math.abs(row.hashCode() % numWorkers), xgbLp)
     }
   }
 
   private def repartitionRDDs(
-      deterministicPartition: Boolean,
       numWorkers: Int,
       arrayOfRDDs: Array[RDD[(Int, XGBLabeledPoint)]]): Array[RDD[XGBLabeledPoint]] = {
-    if (deterministicPartition) {
       arrayOfRDDs.map {rdd => rdd.partitionBy(new HashPartitioner(numWorkers))}.map {
         rdd => rdd.map(_._2)
       }
-    } else {
-      arrayOfRDDs.map(rdd => {
-        if (rdd.getNumPartitions != numWorkers) {
-          rdd.map(_._2).repartition(numWorkers)
-        } else {
-          rdd.map(_._2)
-        }
-      })
-    }
   }
 
   /** Packed parameters used by [[convertDataFrameToXGBLabeledPointRDDs]] */
@@ -107,8 +99,7 @@ object DataUtils extends Serializable {
       weight: Column,
       baseMargin: Column,
       group: Option[Column],
-      numWorkers: Int,
-      deterministicPartition: Boolean)
+      numWorkers: Int)
 
   /**
    * convertDataFrameToXGBLabeledPointRDDs converts DataFrames to an array of RDD[XGBLabeledPoint]
@@ -122,8 +113,7 @@ object DataUtils extends Serializable {
       dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {
 
     packedParams match {
-      case j @ PackedParams(labelCol, featuresCol, weight, baseMargin, group, numWorkers,
-        deterministicPartition) =>
+      case j @ PackedParams(labelCol, featuresCol, weight, baseMargin, group, numWorkers) =>
         val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
           featuresCol,
           weight.cast(FloatType),
@@ -141,7 +131,7 @@ object DataUtils extends Serializable {
             case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
           }
           val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, group, baseMargin)
-          attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
+          attachPartitionKey(row, numWorkers, xgbLp, Some(group))
         case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
           val (size, indices, values) = features match {
             case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
@@ -149,10 +139,10 @@ object DataUtils extends Serializable {
           }
           val xgbLp = XGBLabeledPoint(label, size, indices, values, weight,
             baseMargin = baseMargin)
-          attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
+          attachPartitionKey(row, numWorkers, xgbLp, None)
       }
     }
-        repartitionRDDs(deterministicPartition, numWorkers, arrayOfRDDs)
+        repartitionRDDs(numWorkers, arrayOfRDDs)
 
       case _ => throw new IllegalArgumentException("Wrong PackedParams") // never reach here
     }
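With `needDeterministicRepartitioning` removed, `attachPartitionKey` always produces a deterministic key: the group id (mod `numWorkers`) for ranking data, the row hash otherwise. Keying by group before `partitionBy(new HashPartitioner(numWorkers))` is what guarantees a whole query group lands in one partition. A self-contained sketch of that keying rule follows; the data and object name are hypothetical, not the xgboost4j types:

```scala
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

// Sketch of the group-based keying rule: key by group id so every instance
// of a query group lands in a single partition. Data is hypothetical.
object GroupKeySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[4]").setAppName("group-key"))
    val numWorkers = 4

    // (groupId, instance) pairs; groups 0 and 5 each have several instances.
    val rows = sc.parallelize(Seq((0, "a"), (0, "b"), (1, "c"), (5, "d"), (5, "e")))

    // Same key computation as attachPartitionKey when a group is present.
    val keyed = rows.map { case (group, inst) => (group % numWorkers, inst) }
    val partitioned = keyed.partitionBy(new HashPartitioner(numWorkers))

    // Each key maps to exactly one partition, so no group is ever split.
    partitioned
      .mapPartitionsWithIndex { (idx, it) =>
        it.map { case (k, v) => s"partition=$idx key=$k value=$v" }
      }
      .collect()
      .foreach(println)

    sc.stop()
  }
}
```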
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala
index 8d9723bb62ef..88dd252a0e72 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala
@@ -25,26 +25,6 @@ import org.apache.spark.sql.functions._
 
 class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
 
-  test("perform deterministic partitioning when checkpointInternal and" +
-    " checkpointPath is set (Classifier)") {
-    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
-    val xgbClassifier = new XGBoostClassifier(paramMap)
-    assert(xgbClassifier.needDeterministicRepartitioning)
-  }
-
-  test("perform deterministic partitioning when checkpointInternal and" +
-    " checkpointPath is set (Regressor)") {
-    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
-    val xgbRegressor = new XGBoostRegressor(paramMap)
-    assert(xgbRegressor.needDeterministicRepartitioning)
-  }
-
   test("deterministic partitioning takes effect with various parts of data") {
     val trainingDF = buildDataFrame(Classification.train)
     // the test idea is that, we apply a chain of repartitions over trainingDFs but they
@@ -62,8 +42,7 @@ class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite
         lit(1.0),
         lit(Float.NaN),
         None,
-        numWorkers,
-        deterministicPartition = true),
+        numWorkers),
       df
     ).head)
     val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {
@@ -97,8 +76,7 @@ class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite
         lit(1.0),
         lit(Float.NaN),
         None,
-        10,
-        deterministicPartition = true), df
+        10), df
     ).head
     val partitionsSizes = dfRepartitioned
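The surviving test asserts the property the hash key provides: which partition a row ends up in depends only on the row's content, not on how the input DataFrame was previously repartitioned. A standalone sketch of that property, using hypothetical string rows in place of `XGBLabeledPoint`s:

```scala
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

// Sketch of the determinism property: hash-based keys send a row to the same
// partition regardless of the input's physical layout. Rows are hypothetical.
object DeterminismSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[4]").setAppName("determinism"))
    val numWorkers = 4
    val rows = (1 to 100).map(i => s"row-$i")

    // Same key computation as attachPartitionKey without a group.
    def partitionContents(inputSlices: Int): Map[Int, Set[String]] =
      sc.parallelize(rows, inputSlices)
        .map(r => (math.abs(r.hashCode % numWorkers), r))
        .partitionBy(new HashPartitioner(numWorkers))
        .mapPartitionsWithIndex { (idx, it) =>
          Iterator.single(idx -> it.map(_._2).toSet)
        }
        .collect()
        .toMap

    // Different physical splits of the same rows yield identical partitions.
    assert(partitionContents(3) == partitionContents(7))
    sc.stop()
  }
}
```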