Github user imatiach-msft commented on a diff in the pull request: https://github.com/apache/spark/pull/17086#discussion_r185163674 --- Diff: mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala --- @@ -55,44 +60,128 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) - assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray)) - assert(math.abs(metrics.truePositiveRate(0.0) - tpRate0) < delta) - assert(math.abs(metrics.truePositiveRate(1.0) - tpRate1) < delta) - assert(math.abs(metrics.truePositiveRate(2.0) - tpRate2) < delta) - assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta) - assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta) - assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta) - assert(math.abs(metrics.precision(0.0) - precision0) < delta) - assert(math.abs(metrics.precision(1.0) - precision1) < delta) - assert(math.abs(metrics.precision(2.0) - precision2) < delta) - assert(math.abs(metrics.recall(0.0) - recall0) < delta) - assert(math.abs(metrics.recall(1.0) - recall1) < delta) - assert(math.abs(metrics.recall(2.0) - recall2) < delta) - assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta) - assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta) - assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta) - assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta) - assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta) - assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta) + assert(metrics.confusionMatrix.asML ~== confusionMatrix.asML relTol delta) + assert(metrics.truePositiveRate(0.0) ~== tpRate0 absTol delta) + assert(metrics.truePositiveRate(1.0) ~== tpRate1 absTol delta) + assert(metrics.truePositiveRate(2.0) 
~== tpRate2 absTol delta) + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 absTol delta) + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 absTol delta) + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 absTol delta) + assert(metrics.precision(0.0) ~== precision0 absTol delta) + assert(metrics.precision(1.0) ~== precision1 absTol delta) + assert(metrics.precision(2.0) ~== precision2 absTol delta) + assert(metrics.recall(0.0) ~== recall0 absTol delta) + assert(metrics.recall(1.0) ~== recall1 absTol delta) + assert(metrics.recall(2.0) ~== recall2 absTol delta) + assert(metrics.fMeasure(0.0) ~== f1measure0 absTol delta) + assert(metrics.fMeasure(1.0) ~== f1measure1 absTol delta) + assert(metrics.fMeasure(2.0) ~== f1measure2 absTol delta) + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 absTol delta) + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 absTol delta) + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 absTol delta) + + assert(metrics.accuracy ~== + (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) absTol delta) + assert(metrics.accuracy ~== metrics.precision absTol delta) + assert(metrics.accuracy ~== metrics.recall absTol delta) + assert(metrics.accuracy ~== metrics.fMeasure absTol delta) + assert(metrics.accuracy ~== metrics.weightedRecall absTol delta) + val weight0 = 4.0 / 9 + val weight1 = 4.0 / 9 + val weight2 = 1.0 / 9 + assert(metrics.weightedTruePositiveRate ~== + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) absTol delta) + assert(metrics.weightedFalsePositiveRate ~== + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) absTol delta) + assert(metrics.weightedPrecision ~== + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) absTol delta) + assert(metrics.weightedRecall ~== + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) absTol delta) + assert(metrics.weightedFMeasure ~== + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) absTol delta) + 
assert(metrics.weightedFMeasure(2.0) ~== + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) absTol delta) + assert(metrics.labels === labels) + } + + test("Multiclass evaluation metrics with weights") { + /* + * Confusion matrix for 3-class classification with total 9 instances with 2 weights: + * |2 * w1|1 * w2 |1 * w1| true class0 (4 instances) + * |1 * w2|2 * w1 + 1 * w2|0 | true class1 (4 instances) + * |0 |0 |1 * w2| true class2 (1 instance) + */ + val w1 = 2.2 + val w2 = 1.5 + val tw = 2.0 * w1 + 1.0 * w2 + 1.0 * w1 + 1.0 * w2 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2 + val confusionMatrix = Matrices.dense(3, 3, + Array(2 * w1, 1 * w2, 0, 1 * w2, 2 * w1 + 1 * w2, 0, 1 * w1, 0, 1 * w2)) + val labels = Array(0.0, 1.0, 2.0) + val predictionAndLabelsWithWeights = sc.parallelize( + Seq((0.0, 0.0, w1), (0.0, 1.0, w2), (0.0, 0.0, w1), (1.0, 0.0, w2), + (1.0, 1.0, w1), (1.0, 1.0, w2), (1.0, 1.0, w1), (2.0, 2.0, w2), + (2.0, 0.0, w1)), 2) + val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) + val tpRate0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) + val tpRate1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val tpRate2 = (1.0 * w2) / (1.0 * w2 + 0) + val fpRate0 = (1.0 * w2) / (tw - (2.0 * w1 + 1.0 * w2 + 1.0 * w1)) + val fpRate1 = (1.0 * w2) / (tw - (1.0 * w2 + 2.0 * w1 + 1.0 * w2)) + val fpRate2 = (1.0 * w1) / (tw - (1.0 * w2)) + val precision0 = (2.0 * w1) / (2 * w1 + 1 * w2) + val precision1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val precision2 = (1.0 * w2) / (1 * w1 + 1 * w2) + val recall0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) + val recall1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val recall2 = (1.0 * w2) / (1.0 * w2 + 0) + val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) + val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) + val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) + val f2measure0 = (1 + 2 * 2) * 
precision0 * recall0 / (2 * 2 * precision0 + recall0) + val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) + val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) + + assert(metrics.confusionMatrix.asML ~== confusionMatrix.asML relTol delta) + assert(metrics.truePositiveRate(0.0) ~== tpRate0 absTol delta) + assert(metrics.truePositiveRate(1.0) ~== tpRate1 absTol delta) + assert(metrics.truePositiveRate(2.0) ~== tpRate2 absTol delta) + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 absTol delta) + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 absTol delta) + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 absTol delta) + assert(metrics.precision(0.0) ~== precision0 absTol delta) + assert(metrics.precision(1.0) ~== precision1 absTol delta) + assert(metrics.precision(2.0) ~== precision2 absTol delta) + assert(metrics.recall(0.0) ~== recall0 absTol delta) + assert(metrics.recall(1.0) ~== recall1 absTol delta) + assert(metrics.recall(2.0) ~== recall2 absTol delta) + assert(metrics.fMeasure(0.0) ~== f1measure0 absTol delta) + assert(metrics.fMeasure(1.0) ~== f1measure1 absTol delta) + assert(metrics.fMeasure(2.0) ~== f1measure2 absTol delta) + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 absTol delta) + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 absTol delta) + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 absTol delta) - assert(math.abs(metrics.accuracy - - (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta) - assert(math.abs(metrics.accuracy - metrics.precision) < delta) - assert(math.abs(metrics.accuracy - metrics.recall) < delta) - assert(math.abs(metrics.accuracy - metrics.fMeasure) < delta) - assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta) - assert(math.abs(metrics.weightedTruePositiveRate - - ((4.0 / 9) * tpRate0 + (4.0 / 9) * tpRate1 + (1.0 / 9) * tpRate2)) < delta) - assert(math.abs(metrics.weightedFalsePositiveRate - - ((4.0 / 9) * fpRate0 + (4.0 / 
9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta) - assert(math.abs(metrics.weightedPrecision - - ((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta) - assert(math.abs(metrics.weightedRecall - - ((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta) - assert(math.abs(metrics.weightedFMeasure - - ((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta) - assert(math.abs(metrics.weightedFMeasure(2.0) - - ((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)) < delta) - assert(metrics.labels.sameElements(labels)) + assert(metrics.accuracy ~== + (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw absTol delta) + assert(metrics.accuracy ~== metrics.precision absTol delta) + assert(metrics.accuracy ~== metrics.recall absTol delta) + assert(metrics.accuracy ~== metrics.fMeasure absTol delta) + assert(metrics.accuracy ~== metrics.weightedRecall absTol delta) + val weight0 = (2 * w1 + 1 * w2 + 1 * w1) / tw + val weight1 = (1 * w2 + 2 * w1 + 1 * w2) / tw + val weight2 = 1 * w2 / tw + assert(metrics.weightedTruePositiveRate ~== + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) absTol delta) + assert(metrics.weightedFalsePositiveRate ~== + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) absTol delta) + assert(metrics.weightedPrecision ~== + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) absTol delta) + assert(metrics.weightedRecall ~== + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) absTol delta) + assert(metrics.weightedFMeasure ~== + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) absTol delta) + assert(metrics.weightedFMeasure(2.0) ~== + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) absTol delta) --- End diff -- good idea, I've replaced absTol with relTol, done
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org