GitHub user mengxr commented on a diff in the pull request:
https://github.com/apache/spark/pull/3118#discussion_r20140244
--- Diff:
mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
---
@@ -24,39 +24,99 @@ import org.apache.spark.mllib.util.TestingUtils._
class BinaryClassificationMetricsSuite extends FunSuite with
LocalSparkContext {
- def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
+ private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~=
(x._2) absTol 1E-5
- def cond2(x: ((Double, Double), (Double, Double))): Boolean =
+ private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))):
Boolean =
(x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)
+ private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]):
Unit = {
+ assert(left.zip(right).forall(areWithinEpsilon))
+ }
+
+ private def assertTupleSequencesMatch(left: Seq[(Double, Double)],
right: Seq[(Double, Double)]): Unit = {
+ assert(left.zip(right).forall(pairsWithinEpsilon))
+ }
+
+ private def validateMetrics(metrics: BinaryClassificationMetrics,
+ expectedThresholds: Seq[Double],
+ expectedROCCurve: Seq[(Double, Double)],
+ expectedPRCurve: Seq[(Double, Double)],
+ expectedFMeasures1: Seq[Double],
+ expectedFmeasures2: Seq[Double],
+ expectedPrecisions: Seq[Double],
+ expectedRecalls: Seq[Double]) = {
+
+ assertSequencesMatch(metrics.thresholds().collect(),
expectedThresholds)
+ assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve)
+ assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(expectedROCCurve)
absTol 1E-5)
+ assertTupleSequencesMatch(metrics.pr().collect(), expectedPRCurve)
+ assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(expectedPRCurve)
absTol 1E-5)
+ assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(),
expectedThresholds.zip(expectedFMeasures1))
+ assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(),
expectedThresholds.zip(expectedFmeasures2))
+ assertTupleSequencesMatch(metrics.precisionByThreshold().collect(),
expectedThresholds.zip(expectedPrecisions))
+ assertTupleSequencesMatch(metrics.recallByThreshold().collect(),
expectedThresholds.zip(expectedRecalls))
+ }
+
test("binary evaluation metrics") {
val scoreAndLabels = sc.parallelize(
Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0),
(0.6, 1.0), (0.8, 1.0)), 2)
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
- val threshold = Seq(0.8, 0.6, 0.4, 0.1)
+ val thresholds = Seq(0.8, 0.6, 0.4, 0.1)
val numTruePositives = Seq(1, 3, 3, 4)
val numFalsePositives = Seq(0, 1, 2, 3)
val numPositives = 4
val numNegatives = 3
- val precision = numTruePositives.zip(numFalsePositives).map { case (t,
f) =>
+ val precisions = numTruePositives.zip(numFalsePositives).map { case
(t, f) =>
t.toDouble / (t + f)
}
- val recall = numTruePositives.map(t => t.toDouble / numPositives)
+ val recalls = numTruePositives.map(t => t.toDouble / numPositives)
val fpr = numFalsePositives.map(f => f.toDouble / numNegatives)
- val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
- val pr = recall.zip(precision)
+ val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
+ val pr = recalls.zip(precisions)
val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
- assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
- assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
- assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol
1E-5)
- assert(metrics.pr().collect().zip(prCurve).forall(cond2))
- assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol
1E-5)
-
assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
-
assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
-
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
-
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
+ validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2,
precisions, recalls)
+ }
+
+ test("binary evaluation metrics for All Positive RDD") {
+ val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2)
+ val metrics: BinaryClassificationMetrics = new
BinaryClassificationMetrics(scoreAndLabels)
--- End diff --
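The type info is not necessary: Scala infers `BinaryClassificationMetrics` from the right-hand side, so this can simply be `val metrics = new BinaryClassificationMetrics(scoreAndLabels)`, as in the first test.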
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]