Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21044#discussion_r181288725
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala ---
    @@ -195,15 +206,32 @@ final class OneVsRestModel private[ml] (
           newDataset.unpersist()
         }
     
    -    // output the index of the classifier with highest confidence as prediction
    -    val labelUDF = udf { (predictions: Map[Int, Double]) =>
    -      predictions.maxBy(_._2)._1.toDouble
    -    }
    +    // output the RawPrediction as vector
    +    if (getRawPredictionCol != "") {
    +      val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
    +        val predArray = Array.fill[Double](numClasses)(0.0)
    +        predictions.foreach { case (idx, value) => predArray(idx) = value }
    +        Vectors.dense(predArray)
    +      }
    +
    +      // output the index of the classifier with highest confidence as prediction
    +      val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble }
     
    -    // output label and label metadata as prediction
    -    aggregatedDataset
    -      .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
    -      .drop(accColName)
    +      aggregatedDataset
    +        .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
    +        .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
    +        .drop(accColName)
    +    }
    +    else {
    --- End diff --
    
    Scala style: `else` should go on the same line as the preceding closing brace, i.e. ```} else {```, rather than on its own line.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to