Repository: spark Updated Branches: refs/heads/master 24d07e45d -> d1cf32010
[SPARK-14886][MLLIB] RankingMetrics.ndcgAt throw java.lang.ArrayIndexOutOfBoundsException ## What changes were proposed in this pull request? Handle case where number of predictions is less than label set, k in nDCG computation ## How was this patch tested? New unit test; existing tests Author: Sean Owen <[email protected]> Closes #12756 from srowen/SPARK-14886. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d1cf3201 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d1cf3201 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d1cf3201 Branch: refs/heads/master Commit: d1cf320105504f908ee01f33044d0a6b29c3c03f Parents: 24d07e4 Author: Sean Owen <[email protected]> Authored: Fri Apr 29 09:21:27 2016 +0200 Committer: Nick Pentreath <[email protected]> Committed: Fri Apr 29 09:21:27 2016 +0200 ---------------------------------------------------------------------- .../spark/mllib/evaluation/RankingMetrics.scala | 2 +- .../mllib/evaluation/RankingMetricsSuite.scala | 26 ++++++++++++++++---- 2 files changed, 22 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/d1cf3201/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index c45742c..4ed4a05 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -140,7 +140,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] var i = 0 while (i < n) { val gain = 1.0 / math.log(i + 2) - if (labSet.contains(pred(i))) { + if (i < pred.length && labSet.contains(pred(i))) { dcg += gain } if (i < labSetSize) { http://git-wip-us.apache.org/repos/asf/spark/blob/d1cf3201/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala index 77ec49d..8e9d910 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -22,14 +22,15 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { - test("Ranking metrics: map, ndcg") { + + test("Ranking metrics: MAP, NDCG") { val predictionAndLabels = sc.parallelize( Seq( - (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)), - (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)), - (Array[Int](1, 2, 3, 4, 5), Array[Int]()) + (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)), + (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3)), + (Array(1, 2, 3, 4, 5), Array[Int]()) ), 2) - val eps: Double = 1E-5 + val eps = 1.0E-5 val metrics = new RankingMetrics(predictionAndLabels) val map = metrics.meanAveragePrecision @@ -48,6 +49,21 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps) assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps) assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps) + } + + test("MAP, NDCG with few predictions (SPARK-14886)") { + val predictionAndLabels = sc.parallelize( + Seq( + (Array(1, 6, 2), Array(1, 2, 3, 4, 5)), + (Array[Int](), Array(1, 2, 3)) + ), 2) + val eps = 1.0E-5 + val metrics = new RankingMetrics(predictionAndLabels) + assert(metrics.precisionAt(1) ~== 0.5 absTol eps) + assert(metrics.precisionAt(2) ~== 0.25 absTol eps) + assert(metrics.ndcgAt(1) ~== 0.5 absTol eps) + assert(metrics.ndcgAt(2) ~== 0.30657 absTol eps) } + } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
