Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/19927#discussion_r156249173
--- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala ---
@@ -156,54 +153,22 @@ final class OneVsRestModel private[ml] (
// Check schema
transformSchema(dataset.schema, logging = true)
- // determine the input columns: these need to be passed through
- val origCols = dataset.schema.map(f => col(f.name))
-
- // add an accumulator column to store predictions of all the models
- val accColName = "mbc$acc" + UUID.randomUUID().toString
- val initUDF = udf { () => Map[Int, Double]() }
- val newDataset = dataset.withColumn(accColName, initUDF())
-
- // persist if underlying dataset is not persistent.
- val handlePersistence = dataset.storageLevel == StorageLevel.NONE
- if (handlePersistence) {
- newDataset.persist(StorageLevel.MEMORY_AND_DISK)
- }
-
- // update the accumulator column with the result of prediction of models
- val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
- case (df, (model, index)) =>
- val rawPredictionCol = model.getRawPredictionCol
- val columns = origCols ++ List(col(rawPredictionCol), col(accColName))
-
- // add temporary column to store intermediate scores and update
- val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
- val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
- predictions + ((index, prediction(1)))
+ val predictUDF = udf { (features: Any) =>
+ var i = 0
+ var maxIndex = Double.NaN
+ var maxPred = Double.MinValue
+ while (i < models.length) {
+ val pred = models(i).predictRawAsFeaturesType(features)(1)
--- End diff --
I thought about this. When a model is a distributed model, e.g., one that
holds `DataFrame` or `RDD` data, it cannot be broadcast. `OneVsRest`
should support any kind of classifier, so we cannot assume that the model
produced by the specified classifier can be broadcast.
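
To make the concern concrete, here is a minimal sketch that does not use
Spark at all. It assumes that shipping a model to executors amounts to
serializing it, and it models a driver-side `DataFrame`/`RDD` reference as a
plain non-serializable class. `DriverOnlyHandle`, `SketchModel`, `LocalModel`,
`DistributedModel` and `ShipCheck` are hypothetical names made up for this
illustration; the exact failure mode inside Spark may differ.

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Hypothetical stand-in for a driver-side handle such as a DataFrame or RDD.
// It is deliberately NOT Serializable in this sketch (an assumption).
final class DriverOnlyHandle(val name: String)

// Hypothetical model interface, not a Spark class.
trait SketchModel extends Serializable {
  def predictRaw(features: Array[Double]): Double
}

// A purely local model: all of its state is plain, serializable data.
final case class LocalModel(weights: Array[Double]) extends SketchModel {
  def predictRaw(features: Array[Double]): Double =
    weights.zip(features).map { case (w, x) => w * x }.sum
}

// A "distributed" model: its state lives behind a driver-only handle.
final case class DistributedModel(data: DriverOnlyHandle) extends SketchModel {
  def predictRaw(features: Array[Double]): Double =
    throw new UnsupportedOperationException("needs driver-side data")
}

object ShipCheck {
  // Returns true if the object survives Java serialization, which is used
  // here as a rough proxy for "can be captured by a UDF or broadcast".
  def canShip(obj: AnyRef): Boolean = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    try { out.writeObject(obj); true }
    catch { case _: NotSerializableException => false }
    finally out.close()
  }
}
```

Under these assumptions, an array of `LocalModel` passes `ShipCheck.canShip`,
while an array that contains a `DistributedModel` does not. A `predictUDF` that
closes over `models` would only work in the first case.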