Skip to content

Commit

Permalink
[jvm-packages] fix transform performance issue
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Oct 23, 2024
1 parent d9684ea commit 8071c7b
Showing 1 changed file with 16 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ private[spark] trait XGBoostEstimator[

/**
* Sort partition for Ranker issue.
*
* @param dataset
* @return
*/
Expand Down Expand Up @@ -320,6 +321,7 @@ private[spark] trait XGBoostEstimator[
trainRDD.zipPartitions(evalRDD) { (left, right) =>
new Iterator[Watches] {
override def hasNext: Boolean = left.hasNext

override def next(): Watches = {
val trainDMatrix = buildDMatrix(left)
val evalDMatrix = buildDMatrix(right)
Expand All @@ -332,6 +334,7 @@ private[spark] trait XGBoostEstimator[
trainRDD.mapPartitions { iter =>
new Iterator[Watches] {
override def hasNext: Boolean = iter.hasNext

override def next(): Watches = {
val dm = buildDMatrix(iter)
new Watches(Array(dm), Array(Utils.TRAIN_NAME), None)
Expand Down Expand Up @@ -527,24 +530,27 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML

/** Predict */
private[spark] def predictInternal(booster: Booster, dm: DMatrix, pred: PredictedColumns,
batchRow: Iterator[Row]): Seq[Row] = {
var tmpOut = batchRow.toSeq.map(_.toSeq)
val zip = (left: Seq[Seq[_]], right: Array[Array[Float]]) => left.zip(right).map {
case (a, b) => a ++ Seq(b)
}
originalRowIter: Iterator[Row]): Iterator[Row] = {
val tmpIters: ArrayBuffer[Iterator[Row]] = ArrayBuffer.empty
if (pred.predLeaf) {
tmpOut = zip(tmpOut, booster.predictLeaf(dm))
tmpIters += booster.predictLeaf(dm).map(Row(_)).iterator
}
if (pred.predContrib) {
tmpOut = zip(tmpOut, booster.predictContrib(dm))
tmpIters += booster.predictContrib(dm).map(Row(_)).iterator
}
if (pred.predRaw) {
tmpOut = zip(tmpOut, booster.predict(dm, outPutMargin = true))
tmpIters += booster.predict(dm, outPutMargin = true).map(Row(_)).iterator
}
if (pred.predTmp) {
tmpOut = zip(tmpOut, booster.predict(dm, outPutMargin = false))
tmpIters += booster.predict(dm, outPutMargin = false).map(Row(_)).iterator
}

tmpIters.foldLeft(originalRowIter) { case (accIter, nextIter) =>
// Zip the accumulated iterator with the next iterator
accIter.zip(nextIter).map { case (a: Row, b: Row) =>
Row.fromSeq(a.toSeq ++ b.toSeq)
}
}
tmpOut.map(Row.fromSeq)
}

override def transform(dataset: Dataset[_]): DataFrame = {
Expand Down

0 comments on commit 8071c7b

Please sign in to comment.