Github user imatiach-msft commented on a diff in the pull request:
https://github.com/apache/spark/pull/17086#discussion_r184420711
--- Diff:
mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
---
@@ -39,21 +46,28 @@ class MulticlassMetrics @Since("1.1.0")
(predictionAndLabels: RDD[(Double, Doubl
private[mllib] def this(predictionAndLabels: DataFrame) =
this(predictionAndLabels.rdd.map(r => (r.getDouble(0),
r.getDouble(1))))
- private lazy val labelCountByClass: Map[Double, Long] =
predictionAndLabels.values.countByValue()
- private lazy val labelCount: Long = labelCountByClass.values.sum
- private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
- .map { case (prediction, label) =>
- (label, if (label == prediction) 1 else 0)
+ private lazy val labelCountByClass: Map[Double, Double] =
+ predLabelsWeight.map {
+ case (prediction: Double, label: Double, weight: Double) =>
+ (label, weight)
+ }.mapValues(weight => weight).reduceByKey(_ + _).collect().toMap
--- End diff --
Good catch! Removed.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org