Github user WeichenXu123 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21044#discussion_r181286908

    --- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala ---
    @@ -195,15 +206,32 @@ final class OneVsRestModel private[ml] (
           newDataset.unpersist()
         }
    -    // output the index of the classifier with highest confidence as prediction
    -    val labelUDF = udf { (predictions: Map[Int, Double]) =>
    -      predictions.maxBy(_._2)._1.toDouble
    -    }
    +    // output the RawPrediction as vector
    +    if (getRawPredictionCol != "") {
    +      val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
    +        val predArray = Array.fill[Double](numClasses)(0.0)
    +        predictions.foreach { case (idx, value) => predArray(idx) = value }
    +        Vectors.dense(predArray)
    +      }
    +
    +      // output the index of the classifier with highest confidence as prediction
    +      val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble }
    -    // output label and label metadata as prediction
    -    aggregatedDataset
    -      .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
    -      .drop(accColName)
    +      aggregatedDataset
    +        .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
    +        .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
    +        .drop(accColName)
    +    }
    +    else {
    +      // output the index of the classifier with highest confidence as prediction
    +      val labelUDF = udf { (predictions: Map[Int, Double]) =>
    +        predictions.maxBy(_._2)._1.toDouble
    +      }
    +      // output confidence as rwa prediction, label and label metadata as prediction
    --- End diff --

    rwa -> raw
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org