GitHub user imatiach-msft commented on a diff in the pull request:
https://github.com/apache/spark/pull/17086#discussion_r184430052
--- Diff:
mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
---
@@ -95,4 +95,95 @@ class MulticlassMetricsSuite extends SparkFunSuite with
MLlibTestSparkContext {
((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) *
f2measure2)) < delta)
assert(metrics.labels.sameElements(labels))
}
+
+ test("Multiclass evaluation metrics with weights") {
+ /*
+ * Confusion matrix for 3-class classification with total 9 instances
with 2 weights:
+ * |2 * w1|1 * w2 |1 * w1| true class0 (4 instances)
+ * |1 * w2|2 * w1 + 1 * w2|0 | true class1 (4 instances)
+ * |0 |0 |1 * w2| true class2 (1 instance)
+ */
+ val w1 = 2.2
+ val w2 = 1.5
+ val tw = 2.0 * w1 + 1.0 * w2 + 1.0 * w1 + 1.0 * w2 + 2.0 * w1 + 1.0 *
w2 + 1.0 * w2
+ val confusionMatrix = Matrices.dense(3, 3,
+ Array(2 * w1, 1 * w2, 0, 1 * w2, 2 * w1 + 1 * w2, 0, 1 * w1, 0, 1 *
w2))
+ val labels = Array(0.0, 1.0, 2.0)
+ val predictionAndLabelsWithWeights = sc.parallelize(
+ Seq((0.0, 0.0, w1), (0.0, 1.0, w2), (0.0, 0.0, w1), (1.0, 0.0, w2),
+ (1.0, 1.0, w1), (1.0, 1.0, w2), (1.0, 1.0, w1), (2.0, 2.0, w2),
+ (2.0, 0.0, w1)), 2)
+ val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights)
+ val delta = 0.0000001
+ val tpRate0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1)
+ val tpRate1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2)
+ val tpRate2 = (1.0 * w2) / (1.0 * w2 + 0)
+ val fpRate0 = (1.0 * w2) / (tw - (2.0 * w1 + 1.0 * w2 + 1.0 * w1))
+ val fpRate1 = (1.0 * w2) / (tw - (1.0 * w2 + 2.0 * w1 + 1.0 * w2))
+ val fpRate2 = (1.0 * w1) / (tw - (1.0 * w2))
+ val precision0 = (2.0 * w1) / (2 * w1 + 1 * w2)
+ val precision1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 *
w2)
+ val precision2 = (1.0 * w2) / (1 * w1 + 1 * w2)
+ val recall0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1)
+ val recall1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2)
+ val recall2 = (1.0 * w2) / (1.0 * w2 + 0)
+ val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
+ val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
+ val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
+ val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 *
precision0 + recall0)
+ val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 *
precision1 + recall1)
+ val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 *
precision2 + recall2)
+
+
assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray))
+ assert(math.abs(metrics.truePositiveRate(0.0) - tpRate0) < delta)
+ assert(math.abs(metrics.truePositiveRate(1.0) - tpRate1) < delta)
+ assert(math.abs(metrics.truePositiveRate(2.0) - tpRate2) < delta)
+ assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta)
+ assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta)
+ assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta)
+ assert(math.abs(metrics.precision(0.0) - precision0) < delta)
+ assert(math.abs(metrics.precision(1.0) - precision1) < delta)
+ assert(math.abs(metrics.precision(2.0) - precision2) < delta)
+ assert(math.abs(metrics.recall(0.0) - recall0) < delta)
+ assert(math.abs(metrics.recall(1.0) - recall1) < delta)
+ assert(math.abs(metrics.recall(2.0) - recall2) < delta)
+ assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta)
+ assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta)
+ assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta)
+ assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta)
+ assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta)
+ assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta)
+
+ assert(math.abs(metrics.accuracy -
+ (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw) < delta)
+ assert(math.abs(metrics.accuracy - metrics.precision) < delta)
+ assert(math.abs(metrics.accuracy - metrics.recall) < delta)
+ assert(math.abs(metrics.accuracy - metrics.fMeasure) < delta)
+ assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta)
+ assert(math.abs(metrics.weightedTruePositiveRate -
+ (((2 * w1 + 1 * w2 + 1 * w1) / tw) * tpRate0 +
+ ((1 * w2 + 2 * w1 + 1 * w2) / tw) * tpRate1 +
+ (1 * w2 / tw) * tpRate2)) < delta)
+ assert(math.abs(metrics.weightedFalsePositiveRate -
+ (((2 * w1 + 1 * w2 + 1 * w1) / tw) * fpRate0 +
+ ((1 * w2 + 2 * w1 + 1 * w2) / tw) * fpRate1 +
+ (1 * w2 / tw) * fpRate2)) < delta)
+ assert(math.abs(metrics.weightedPrecision -
+ (((2 * w1 + 1 * w2 + 1 * w1) / tw) * precision0 +
+ ((1 * w2 + 2 * w1 + 1 * w2) / tw) * precision1 +
+ (1 * w2 / tw) * precision2)) < delta)
+ assert(math.abs(metrics.weightedRecall -
+ (((2 * w1 + 1 * w2 + 1 * w1) / tw) * recall0 +
+ ((1 * w2 + 2 * w1 + 1 * w2) / tw) * recall1 +
+ (1 * w2 / tw) * recall2)) < delta)
+ assert(math.abs(metrics.weightedFMeasure -
+ (((2 * w1 + 1 * w2 + 1 * w1) / tw) * f1measure0 +
+ ((1 * w2 + 2 * w1 + 1 * w2) / tw) * f1measure1 +
+ (1 * w2 / tw) * f1measure2)) < delta)
+ assert(math.abs(metrics.weightedFMeasure(2.0) -
+ (((2 * w1 + 1 * w2 + 1 * w1) / tw) * f2measure0 +
+ ((1 * w2 + 2 * w1 + 1 * w2) / tw) * f2measure1 +
+ (1 * w2 / tw) * f2measure2)) < delta)
--- End diff --
Sure — I was trying to follow the format of the other existing test. I've made the change in both test cases.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]