Github user imatiach-msft commented on a diff in the pull request:

    https://github.com/apache/spark/pull/17086#discussion_r231296074
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala ---
    @@ -39,21 +46,28 @@ class MulticlassMetrics @Since("1.1.0") 
(predictionAndLabels: RDD[(Double, Doubl
       private[mllib] def this(predictionAndLabels: DataFrame) =
         this(predictionAndLabels.rdd.map(r => (r.getDouble(0), 
r.getDouble(1))))
     
    -  private lazy val labelCountByClass: Map[Double, Long] = 
predictionAndLabels.values.countByValue()
    -  private lazy val labelCount: Long = labelCountByClass.values.sum
    -  private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
    -    .map { case (prediction, label) =>
    -      (label, if (label == prediction) 1 else 0)
    +  private lazy val labelCountByClass: Map[Double, Double] =
    +    predLabelsWeight.map {
    +      case (prediction: Double, label: Double, weight: Double) =>
    +        (label, weight)
    +    }.reduceByKey(_ + _).collect().toMap
    +  private lazy val labelCount: Double = labelCountByClass.values.sum
    +  private lazy val tpByClass: Map[Double, Double] = predLabelsWeight
    +    .map {
    --- End diff --
    
    See the test failure here for reference: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/98529/testReport/


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to