Github user imatiach-msft commented on a diff in the pull request:

    https://github.com/apache/spark/pull/17086#discussion_r231294624
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala 
---
    @@ -39,21 +46,28 @@ class MulticlassMetrics @Since("1.1.0") 
(predictionAndLabels: RDD[(Double, Doubl
       private[mllib] def this(predictionAndLabels: DataFrame) =
         this(predictionAndLabels.rdd.map(r => (r.getDouble(0), 
r.getDouble(1))))
     
    -  private lazy val labelCountByClass: Map[Double, Long] = 
predictionAndLabels.values.countByValue()
    -  private lazy val labelCount: Long = labelCountByClass.values.sum
    -  private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
    -    .map { case (prediction, label) =>
    -      (label, if (label == prediction) 1 else 0)
    +  private lazy val labelCountByClass: Map[Double, Double] =
    +    predLabelsWeight.map {
    +      case (prediction: Double, label: Double, weight: Double) =>
    +        (label, weight)
    +    }.reduceByKey(_ + _).collect().toMap
    +  private lazy val labelCount: Double = labelCountByClass.values.sum
    +  private lazy val tpByClass: Map[Double, Double] = predLabelsWeight
    +    .map {
    --- End diff --
    
    it looks like this will actually cause tests to fail: if we filter everything out first, the key may be missing from the resulting map entirely, whereas we want the key to still be present, just with a 0 value.
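
A minimal sketch of the concern, using plain Scala collections rather than Spark RDDs (the label values and weights here are made up for illustration): aggregating after a filter drops a label whose rows were all filtered out, while aggregating without the filter keeps that label with a 0.0 value.

```scala
object KeyDropSketch {
  // hypothetical (label, weight) pairs; label 2.0 has only zero-weight rows
  val pairs = Seq((0.0, 1.0), (1.0, 1.0), (2.0, 0.0))

  // filter first, then aggregate per label: label 2.0 vanishes from the map
  val filteredFirst: Map[Double, Double] =
    pairs.filter(_._2 != 0.0)
      .groupBy(_._1)
      .map { case (label, rows) => (label, rows.map(_._2).sum) }

  // aggregate without filtering: label 2.0 is present with value 0.0
  val keepZeros: Map[Double, Double] =
    pairs.groupBy(_._1)
      .map { case (label, rows) => (label, rows.map(_._2).sum) }

  def main(args: Array[String]): Unit = {
    assert(!filteredFirst.contains(2.0)) // key is gone entirely
    assert(keepZeros(2.0) == 0.0)        // key survives with a 0 value
    println("ok")
  }
}
```

The same shape applies to `reduceByKey` on an RDD: a key that never reaches the reduce (because every one of its records was filtered out) simply does not appear in the collected map.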


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org