diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala index 98b70a63c4f6..7d623b180f57 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala @@ -136,6 +136,7 @@ private[spark] trait XGBoostEstimator[ /** * Sort partition for Ranker issue. + * * @param dataset * @return */ @@ -320,6 +321,7 @@ private[spark] trait XGBoostEstimator[ trainRDD.zipPartitions(evalRDD) { (left, right) => new Iterator[Watches] { override def hasNext: Boolean = left.hasNext + override def next(): Watches = { val trainDMatrix = buildDMatrix(left) val evalDMatrix = buildDMatrix(right) @@ -332,6 +334,7 @@ private[spark] trait XGBoostEstimator[ trainRDD.mapPartitions { iter => new Iterator[Watches] { override def hasNext: Boolean = iter.hasNext + override def next(): Watches = { val dm = buildDMatrix(iter) new Watches(Array(dm), Array(Utils.TRAIN_NAME), None) @@ -527,24 +530,27 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML /** Predict */ private[spark] def predictInternal(booster: Booster, dm: DMatrix, pred: PredictedColumns, - batchRow: Iterator[Row]): Seq[Row] = { - var tmpOut = batchRow.toSeq.map(_.toSeq) - val zip = (left: Seq[Seq[_]], right: Array[Array[Float]]) => left.zip(right).map { - case (a, b) => a ++ Seq(b) - } + originalRowIter: Iterator[Row]): Iterator[Row] = { + val tmpIters: ArrayBuffer[Iterator[Row]] = ArrayBuffer.empty if (pred.predLeaf) { - tmpOut = zip(tmpOut, booster.predictLeaf(dm)) + tmpIters += booster.predictLeaf(dm).map(Row(_)).iterator } if (pred.predContrib) { - tmpOut = zip(tmpOut, booster.predictContrib(dm)) + tmpIters += booster.predictContrib(dm).map(Row(_)).iterator } if (pred.predRaw) { - tmpOut = zip(tmpOut, booster.predict(dm, outPutMargin = true)) + tmpIters += booster.predict(dm, outPutMargin = true).map(Row(_)).iterator } if (pred.predTmp) { - tmpOut = zip(tmpOut, booster.predict(dm, outPutMargin = false)) + tmpIters += booster.predict(dm, outPutMargin = false).map(Row(_)).iterator + } + + tmpIters.foldLeft(originalRowIter) { case (accIter, nextIter) => + // Zip the accumulated iterator with the next iterator + accIter.zip(nextIter).map { case (a: Row, b: Row) => + Row.fromSeq(a.toSeq ++ b.toSeq) + } } - tmpOut.map(Row.fromSeq) } override def transform(dataset: Dataset[_]): DataFrame = {