Github user zhengruifeng commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19927#discussion_r156314727
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala ---
    @@ -156,54 +153,22 @@ final class OneVsRestModel private[ml] (
         // Check schema
         transformSchema(dataset.schema, logging = true)
     
    -    // determine the input columns: these need to be passed through
    -    val origCols = dataset.schema.map(f => col(f.name))
    -
    -    // add an accumulator column to store predictions of all the models
    -    val accColName = "mbc$acc" + UUID.randomUUID().toString
    -    val initUDF = udf { () => Map[Int, Double]() }
    -    val newDataset = dataset.withColumn(accColName, initUDF())
    -
    -    // persist if underlying dataset is not persistent.
    -    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
    -    if (handlePersistence) {
    -      newDataset.persist(StorageLevel.MEMORY_AND_DISK)
    -    }
    -
    -    // update the accumulator column with the result of prediction of models
    -    val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
    -      case (df, (model, index)) =>
    -        val rawPredictionCol = model.getRawPredictionCol
    -        val columns = origCols ++ List(col(rawPredictionCol), col(accColName))
    -
    -        // add temporary column to store intermediate scores and update
    -        val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
    -        val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
    -          predictions + ((index, prediction(1)))
    +    val predictUDF = udf { (features: Any) =>
    +      var i = 0
    +      var maxIndex = Double.NaN
    +      var maxPred = Double.MinValue
    +      while (i < models.length) {
    +        val pred = models(i).predictRawAsFeaturesType(features)(1)
    --- End diff --
    
    @WeichenXu123 I agree that we should take distributed models such as `DistributedLDAModel` into account in the future, but for now none of the classification models are distributed.
    What about letting each model tell us whether it is distributed (e.g. by adding a method named `isDistributed`), and having OVR choose its prediction strategy based on that?
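    A minimal sketch of the proposal, assuming a hypothetical `isDistributed` flag that defaults to `false` (matching today's local classification models). OVR would fall back to the old per-model passes only when some sub-model reports itself as distributed. `OvrSubModel`, `PredictionPath` and `OvrPathChooser` are illustrative names, not existing Spark API.

```scala
// Hypothetical: each sub-model reports whether it needs distributed
// (multi-pass) prediction. Defaults to false, as for current classifiers.
trait OvrSubModel {
  def isDistributed: Boolean = false
}

sealed trait PredictionPath
// Old approach: one pass over the data per model, accumulating scores in a column.
case object PerModelPasses extends PredictionPath
// New approach: a single UDF that evaluates every model locally on each row.
case object SingleUdfPass extends PredictionPath

object OvrPathChooser {
  // Use the single-UDF path only if every sub-model can predict locally.
  def choose(models: Seq[OvrSubModel]): PredictionPath =
    if (models.exists(_.isDistributed)) PerModelPasses else SingleUdfPass
}
```

    Keeping the default on the trait means existing models need no changes; only a future distributed model would override the flag.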

