Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/21044#discussion_r181288721
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala ---
@@ -195,15 +206,32 @@ final class OneVsRestModel private[ml] (
        newDataset.unpersist()
      }

-    // output the index of the classifier with highest confidence as prediction
-    val labelUDF = udf { (predictions: Map[Int, Double]) =>
-      predictions.maxBy(_._2)._1.toDouble
-    }
+    // output the RawPrediction as vector
+    if (getRawPredictionCol != "") {
+      val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
+        val predArray = Array.fill[Double](numClasses)(0.0)
--- End diff ---
This causes a subtle serialization bug: `numClasses` refers to a field of the
class OneVsRestModel, so when Spark's closure capture serializes this UDF to
send to executors, it will end up serializing the entire OneVsRestModel object
rather than just the value of `numClasses`. Make a local copy of the value of
`numClasses` within the transform() method to avoid this issue.
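
The pattern being suggested can be sketched outside Spark with plain Java
serialization. The names below (`Model`, `badUdf`, `goodUdf`) are hypothetical
stand-ins, not the actual OneVsRestModel code: a lambda that reads a class
field captures the enclosing instance, while a lambda that reads a local copy
captures only the primitive value.

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Hypothetical stand-in for OneVsRestModel; deliberately NOT Serializable,
// so dragging it into a closure makes serialization fail.
class Model(val numClasses: Int) {

  // Referencing the field inside the lambda compiles to `Model.this.numClasses`,
  // so the closure captures the whole Model instance -- the trap described above.
  def badUdf: Int => Array[Double] =
    (i: Int) => Array.fill[Double](numClasses)(0.0).updated(i, 1.0)

  // Copying the field into a local val first means the closure captures
  // only an Int, not the enclosing Model.
  def goodUdf: Int => Array[Double] = {
    val localNumClasses = numClasses // local copy, as the review comment suggests
    (i: Int) => Array.fill[Double](localNumClasses)(0.0).updated(i, 1.0)
  }
}

object ClosureCaptureDemo {
  // Returns true iff `f` survives Java serialization (Scala 2.12+ lambdas
  // are serializable as long as everything they capture is).
  def serializes(f: AnyRef): Boolean =
    try {
      new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(f)
      true
    } catch {
      case _: NotSerializableException => false
    }

  def main(args: Array[String]): Unit = {
    val model = new Model(3)
    println(s"badUdf serializable:  ${serializes(model.badUdf)}")
    println(s"goodUdf serializable: ${serializes(model.goodUdf)}")
  }
}
```

In real Spark code the effect is the same: the UDF that reads the field ships
the whole model to every executor (or fails outright if the model is not
serializable), while the local-copy version ships only the integer.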
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]