Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/21044#discussion_r181286908
--- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala ---
@@ -195,15 +206,32 @@ final class OneVsRestModel private[ml] (
newDataset.unpersist()
}
- // output the index of the classifier with highest confidence as
prediction
- val labelUDF = udf { (predictions: Map[Int, Double]) =>
- predictions.maxBy(_._2)._1.toDouble
- }
+ // output the RawPrediction as vector
+ if (getRawPredictionCol != "") {
+ val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
+ val predArray = Array.fill[Double](numClasses)(0.0)
+ predictions.foreach { case (idx, value) => predArray(idx) = value }
+ Vectors.dense(predArray)
+ }
+
+ // output the index of the classifier with highest confidence as
prediction
+ val labelUDF = udf { (predictions: Vector) =>
predictions.argmax.toDouble }
- // output label and label metadata as prediction
- aggregatedDataset
- .withColumn($(predictionCol), labelUDF(col(accColName)),
labelMetadata)
- .drop(accColName)
+ aggregatedDataset
+ .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
+ .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)),
labelMetadata)
+ .drop(accColName)
+ }
+ else {
+ // output the index of the classifier with highest confidence as
prediction
+ val labelUDF = udf { (predictions: Map[Int, Double]) =>
+ predictions.maxBy(_._2)._1.toDouble
+ }
+ // output confidence as rwa prediction, label and label metadata as
prediction
--- End diff --
Typo: "rwa" should be "raw" (the comment should read "output confidence as raw prediction, ...").
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]