Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 25, 2024
1 parent 66a97c4 commit 1fd3d48
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
// UnsafeProjection is not serializable so do it on the executor side
val toUnsafe = UnsafeProjection.create(originalSchema)


synchronized {
val device = booster.getAttr("device")
if (device != null && device.trim.isEmpty) {
Expand Down Expand Up @@ -237,8 +236,7 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
table.getRowCount().toInt)
val rowIterator = currentBatch.rowIterator().asScala.map(toUnsafe)
.map(converter(_))


model.predictInternal(booster, dm, pred, rowIterator).toIterator
} finally {
dm.delete()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,51 +490,47 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML
(schema, PredictedColumns(predLeaf, predContrib, predRaw, predTmp))
}

private[spark] def predictInternal(booster: Booster, dm: DMatrix) = {

/**
 * Runs the requested prediction kinds over a single batch and appends each
 * resulting prediction array as an extra column on the corresponding input row.
 *
 * @param booster  booster used to produce the predictions
 * @param dm       DMatrix holding the batch's features
 * @param pred     flags selecting which prediction columns to compute
 * @param batchRow original input rows for this batch
 * @return the input rows, each extended with the enabled prediction arrays
 *         (in the order: leaf, contrib, raw margin, transformed)
 */
private[spark] def predictInternal(booster: Booster, dm: DMatrix, pred: PredictedColumns,
    batchRow: Iterator[Row]): Seq[Row] = {
  // Each enabled stage contributes one appended column. Thunks ensure a
  // disabled prediction is never computed.
  val stages: Seq[(Boolean, () => Array[Array[Float]])] = Seq(
    (pred.predLeaf, () => booster.predictLeaf(dm)),
    (pred.predContrib, () => booster.predictContrib(dm)),
    (pred.predRaw, () => booster.predict(dm, outPutMargin = true)),
    (pred.predTmp, () => booster.predict(dm, outPutMargin = false))
  )
  val initial: Seq[Seq[Any]] = batchRow.toSeq.map(_.toSeq)
  val withPredictions = stages.foldLeft(initial) { case (rows, (enabled, predictFn)) =>
    if (enabled) {
      rows.zip(predictFn()).map { case (row, extra) => row ++ Seq(extra) }
    } else {
      rows
    }
  }
  withPredictions.map(Row.fromSeq)
}

override def transform(dataset: Dataset[_]): DataFrame = {

val spark = dataset.sparkSession
val bBooster = dataset.sparkSession.sparkContext.broadcast(nativeBooster)

val (schema, pred) = preprocess(dataset)

// TODO configurable
val inferBatchSize = 32 << 10
// Broadcast the booster to each executor.
val bBooster = spark.sparkContext.broadcast(nativeBooster)
val featureName = getFeaturesCol

var output = dataset.toDF().mapPartitions { rowIter =>

rowIter.grouped(inferBatchSize).flatMap { batchRow =>
val features = batchRow.iterator.map(row => row.getAs[Vector](
row.fieldIndex(featureName)))

// DMatrix used for prediction
val dm = new DMatrix(features.map(_.asXGB))

try {
var tmpOut = batchRow.map(_.toSeq)

val zip = (left: Seq[Seq[_]], right: Array[Array[Float]]) => left.zip(right).map {
case (a, b) => a ++ Seq(b)
}

if (pred.predLeaf) {
tmpOut = zip(tmpOut, bBooster.value.predictLeaf(dm))
}
if (pred.predContrib) {
tmpOut = zip(tmpOut, bBooster.value.predictContrib(dm))
}
if (pred.predRaw) {
tmpOut = zip(tmpOut, bBooster.value.predict(dm, outPutMargin = true))
}
if (pred.predTmp) {
tmpOut = zip(tmpOut, bBooster.value.predict(dm, outPutMargin = false))
}
tmpOut.map(Row.fromSeq)
predictInternal(bBooster.value, dm, pred, batchRow.toIterator)
} finally {
dm.delete()
}
Expand Down

0 comments on commit 1fd3d48

Please sign in to comment.