Support array type for CPU
wbo4958 committed Oct 24, 2024
1 parent e1e7d8d commit 564a239
Showing 4 changed files with 58 additions and 16 deletions.
@@ -30,6 +30,7 @@ import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsWritable, MLReader, MLWritable, MLWriter}
import org.apache.spark.ml.xgboost.{SparkUtils, XGBProbabilisticClassifierParams}
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions.{col, udf}
@@ -187,16 +188,20 @@ private[spark] trait XGBoostEstimator[
* @return
*/
private[spark] def preprocess(dataset: Dataset[_]): (Dataset[_], ColumnIndices) = {
val featureIsArray: Boolean = featureTypeChecking(dataset.schema)

// Columns to be selected for XGBoost training
val selectedCols: ArrayBuffer[Column] = ArrayBuffer.empty
val schema = dataset.schema

def selectCol(c: Param[String], targetType: DataType) = {
if (isDefinedNonEmpty(c)) {
// Validation col should be a boolean column.
if (c == featuresCol) {
selectedCols.append(col($(c)))
// If the features column is an array, force-cast it to array of float
val featureCol = if (featureIsArray) {
col($(featuresCol)).cast(ArrayType(FloatType))
} else col($(featuresCol))
selectedCols.append(featureCol)
} else {
selectedCols.append(castIfNeeded(schema, $(c), targetType))
}
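
For reference, the force-cast above can be reproduced in isolation. A minimal sketch (the local session and column names are illustrative, not part of this commit):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ArrayType, FloatType}

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

// A user-supplied features column typed as array<double>.
val df = Seq((1.0f, Seq(1.0, 2.0, 3.0))).toDF("label", "features")

// The same cast the estimator applies on the driver: array<double> -> array<float>.
val casted = df.withColumn("features", col("features").cast(ArrayType(FloatType)))
casted.printSchema() // features: array<float>
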
@@ -219,22 +224,30 @@
columnIndexes: ColumnIndices): RDD[XGBLabeledPoint] = {
val isSetMissing = isSet(missing)
dataset.toDF().rdd.map { row =>
val features = row.getAs[Vector](columnIndexes.featureId.get)
val label = row.getFloat(columnIndexes.labelId)
val weight = columnIndexes.weightId.map(row.getFloat).getOrElse(1.0f)
val baseMargin = columnIndexes.marginId.map(row.getFloat).getOrElse(Float.NaN)
val group = columnIndexes.groupId.map(row.getInt).getOrElse(-1)
// To make "0" meaningful, we convert sparse vector if possible to dense to create DMatrix.
features match {
case _: SparseVector => if (!isSetMissing) {
throw new IllegalArgumentException("We've detected sparse vectors in the dataset that " +
"need conversion to dense format. However, we can't assume 0 for missing values as " +
"it may be meaningful. Please specify the missing value explicitly to ensure " +
"accurate data representation for analysis.")
}
case _ =>

val values = row.schema(columnIndexes.featureId.get).dataType match {
case _: ArrayType =>
// The driver has already cast the array column to array(float), so it's safe to
// read it as WrappedArray[Float]
row.getAs[mutable.WrappedArray[Float]](columnIndexes.featureId.get).toArray
case _: VectorUDT =>
val features = row.getAs[Vector](columnIndexes.featureId.get)
features match {
case _: SparseVector => if (!isSetMissing) {
throw new IllegalArgumentException("We've detected sparse vectors in the dataset " +
"that need conversion to dense format. However, we can't assume 0 for missing " +
"values as it may be meaningful. Please specify the missing value explicitly to" +
"ensure accurate data representation for analysis.")
}
case _ => // DenseVector
}
// To make "0" meaningful, we convert sparse vector if possible to dense.
features.toArray.map(_.toFloat)
}
val values = features.toArray.map(_.toFloat)
XGBLabeledPoint(label, values.length, null, values, weight, group, baseMargin)
}
}
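
The executor-side extraction above relies on that driver-side cast. The same pattern, sketched standalone (the helper name and signature are assumptions, not code from this commit):

import scala.collection.mutable
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.ArrayType

// Read the features of a Row as Array[Float], whether the column is
// array<float> (already cast on the driver) or an ML Vector.
def featureValues(row: Row, featureIdx: Int): Array[Float] =
  row.schema(featureIdx).dataType match {
    case _: ArrayType =>
      // Safe because the driver cast the column to array<float>.
      row.getAs[mutable.WrappedArray[Float]](featureIdx).toArray
    case _ =>
      row.getAs[Vector](featureIdx).toArray.map(_.toFloat)
  }
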
@@ -581,7 +594,6 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML
}

override def transform(dataset: Dataset[_]): DataFrame = {

if (getPlugin.isDefined) {
return getPlugin.get.transform(this, dataset)
}
@@ -20,7 +20,8 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.sql.types.StructType
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types.{ArrayType, StructType}

import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}

@@ -222,6 +223,16 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe
def setFeatureNames(value: Array[String]): T = set(featureNames, value).asInstanceOf[T]

def setFeatureTypes(value: Array[String]): T = set(featureTypes, value).asInstanceOf[T]

protected def featureTypeChecking(schema: StructType): Boolean = {
// The features column must be of Vector or Array type.
schema(getFeaturesCol).dataType match {
case _: ArrayType => true
case _: VectorUDT => false
case _ => throw new IllegalArgumentException("Feature type must be Array type " +
"or Vector type")
}
}
}
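
As a rough illustration of the new featureTypeChecking helper (hypothetical schemas, not code from this commit): an array features column yields true, a Vector column yields false, and anything else throws.

import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType}

val arraySchema = StructType(Seq(StructField("features", ArrayType(DoubleType))))
val badSchema = StructType(Seq(StructField("features", StringType)))
// With getFeaturesCol == "features":
// featureTypeChecking(arraySchema) => true (array features, cast to array<float> on the driver)
// featureTypeChecking(badSchema)   => throws IllegalArgumentException
// A Vector features column (VectorUDT) would return false.
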

private[spark] trait SchemaValidationTrait {
@@ -112,6 +112,15 @@ trait PerTest extends BeforeAndAfterEach {
(1.0, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7))
))).toDF("label", "margin", "weight", "features")

def smallBinaryClassificationArray: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
(1.0, 0.5, 1.0, Seq(1.0, 2.0, 3.0)),
(0.0, 0.4, -3.0, Seq(0.0, 0.0, 0.0)),
(0.0, 0.3, 1.0, Seq(0.0, 3.0, 0.0)),
(1.0, 1.2, 0.2, Seq(2.0, 0.0, 4.0)),
(0.0, -0.5, 0.0, Seq(0.2, 1.2, 2.0)),
(1.0, -0.4, -2.1, Seq(0.5, 2.2, 1.7))
))).toDF("label", "margin", "weight", "features")

def smallMultiClassificationVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
(1.0, 0.5, 1.0, Vectors.dense(1.0, 2.0, 3.0)),
(0.0, 0.4, -3.0, Vectors.dense(0.0, 0.0, 0.0)),
@@ -121,7 +130,6 @@
(2.0, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7))
))).toDF("label", "margin", "weight", "features")


def smallGroupVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
(1.0, 0, 0.5, 2.0, Vectors.dense(1.0, 2.0, 3.0)),
(0.0, 1, 0.4, 1.0, Vectors.dense(0.0, 0.0, 0.0)),
@@ -21,6 +21,8 @@ import java.io.File
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ArrayType, FloatType}
import org.scalatest.funsuite.AnyFunSuite

import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
@@ -38,6 +40,15 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
assert(classifier.getNumWorkers === classifier.getNumWorkers)
}

test("xxxxxx") {
val df = smallBinaryClassificationArray
// df.select(col("features").cast(ArrayType(FloatType))).printSchema()
// df.printSchema()
val classifier = new XGBoostClassifier()
val model = classifier.fit(df)
model.transform(df).show()
}
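
If the smoke test above were tightened beyond show(), a minimal assertion could look like this (the prediction column name follows the Spark ML default and is assumed here):

val out = model.transform(df)
// Every input row should come back, with a prediction column attached.
assert(out.columns.contains("prediction"))
assert(out.count() === df.count())
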

test("XGBoostClassification copy") {
val model = new XGBoostClassificationModel("hello").setNthread(2).setNumWorkers(10)
val modelCopied = model.copy(ParamMap.empty)
