GitHub user srowen commented on a diff in the pull request:
https://github.com/apache/spark/pull/17086#discussion_r231724984
--- Diff:
mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
---
@@ -27,10 +27,17 @@ import org.apache.spark.sql.DataFrame
/**
* Evaluator for multiclass classification.
*
- * @param predictionAndLabels an RDD of (prediction, label) pairs.
+ * @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or
+ * (prediction, label) pairs.
*/
@Since("1.1.0")
-class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) {
+class MulticlassMetrics @Since("3.0.0") (predAndLabelsWithOptWeight: RDD[_]) {
+ val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map {
+ case (prediction: Double, label: Double, weight: Double) =>
+ (prediction, label, weight)
+ case (prediction: Double, label: Double) =>
+ (prediction, label, 1.0)
--- End diff --
If you're making one more change, consider adding an explicit check on the
element type here, since the RDD can now hold anything at compile time. For
example: `case other => throw new IllegalArgumentException(s"Expected tuples,
got $other")`

Actually, to tighten this back down a little, I wonder whether the
constructor argument could be `RDD[_ <: Product]`. That still admits Tuple2
and Tuple3, along with a lot of other things, but it is more specific than
'anything'.
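
To illustrate both suggestions together, here is a rough sketch rather than the actual patch. It assumes a plain `Seq` standing in for `RDD` so that it runs without a Spark dependency, and the names `WeightedTuples` and `normalize` are hypothetical:

```scala
// Sketch only: Seq stands in for RDD so this compiles without Spark.
// WeightedTuples and normalize are hypothetical names, not part of the PR.
object WeightedTuples {
  // Normalizes (prediction, label) or (prediction, label, weight) tuples
  // to triples, defaulting a missing weight to 1.0.
  def normalize(data: Seq[_ <: Product]): Seq[(Double, Double, Double)] =
    data.map {
      case (prediction: Double, label: Double, weight: Double) =>
        (prediction, label, weight)
      case (prediction: Double, label: Double) =>
        (prediction, label, 1.0)
      // The bound still admits non-tuple Products (case classes, Option, ...),
      // so fail with a descriptive message instead of a bare MatchError.
      case other =>
        throw new IllegalArgumentException(s"Expected tuples, got $other")
    }
}
```

With the `Product` bound, something like a collection of bare Doubles should be rejected at compile time, while the fallthrough case covers whatever the bound still lets through at runtime.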