Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/21044#discussion_r181287383
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala ---
@@ -195,15 +206,32 @@ final class OneVsRestModel private[ml] (
newDataset.unpersist()
}
- // output the index of the classifier with highest confidence as prediction
- val labelUDF = udf { (predictions: Map[Int, Double]) =>
- predictions.maxBy(_._2)._1.toDouble
- }
+ // output the RawPrediction as vector
+ if (getRawPredictionCol != "") {
+ val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
+ val predArray = Array.fill[Double](numClasses)(0.0)
+ predictions.foreach { case (idx, value) => predArray(idx) = value }
+ Vectors.dense(predArray)
+ }
+
+ // output the index of the classifier with highest confidence as prediction
+ val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble }
--- End diff --
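Suggestion: rename the lambda parameter to `rawPredictions`, since it now receives the raw prediction vector: `udf { (rawPredictions: Vector) => ... }`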
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]