qb-tarushg commented on a change in pull request #24458: [SPARK-27540][MLlib] Add 'meanAveragePrecision_at_k' metric to RankingMetrics
URL: https://github.com/apache/spark/pull/24458#discussion_r278807435
##########
File path: mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
##########
@@ -92,6 +92,56 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
     }.mean()
   }
+
+  /**
+   * Returns the mean average precision (MAP) at ranking position k of all the queries.
+   * If a query has an empty ground truth set, its average precision will be zero and a
+   * log warning is generated.
+   *
+   * @param k the position to compute the truncated precision, must be positive
+   * @return the mean average precision at the first k ranking positions
+   */
+  @Since("3.0.0")
+  def meanAveragePrecisionAt(k: Int): Double = {
+    require(k > 0, "ranking position k should be positive")
+    predictionAndLabels
+      .map { case (pred, lab) =>
+        averagePrecisionAt(pred, lab, k)
+      }
+      .mean()
+  }
+
+  /**
+   * Computes the average precision at the first k ranking positions of a single query.
+   * If the query has an empty ground truth set, the value will be zero and a log
+   * warning is generated.
+   *
+   * @param pred predicted ranking
+   * @param lab ground truth
+   * @param k use the top k predicted ranking, must be positive
+   * @return average precision at the first k ranking positions
+   */
+  private def averagePrecisionAt(pred: Array[T], lab: Array[T], k: Int): Double = {
+    val labSet = lab.toSet
+
+    if (labSet.nonEmpty) {
+      var i = 0
+      var cnt = 0
+      var precSum = 0.0
+      val n = math.min(k, pred.length)
+      while (i < n) {
+        if (labSet.contains(pred(i))) {
+          cnt += 1
+          precSum += cnt.toDouble / (i + 1)
+        }
+        i += 1
+      }
+      precSum / math.min(labSet.size, k)
Review comment:
This is true when we are calculating `AP`, but for `AP@k` I think we have to normalize by `math.min(labSet.size, k)`. I cross-checked this against the IR metrics literature and also found references to AP@k in TensorFlow:
https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/metrics_impl.py#L2932
https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/metrics_impl.py#L3052
I'm not sure whether my understanding is right, though. Please take a look and let me know, and I will fix the code above.
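
For concreteness, here is a minimal self-contained sketch (plain Scala, no Spark; the object name, the `normalizer` parameter, and the toy data are all hypothetical, for illustration only) contrasting the two denominators on a query where both of the top k = 2 predictions are relevant:

```scala
// Illustrative sketch of AP@k with a pluggable denominator. Not part of the PR.
object ApAtKExample {
  // Same precision-sum loop as the patch; only the final denominator varies.
  def apAtK[T](pred: Array[T], lab: Array[T], k: Int,
               normalizer: (Int, Int) => Int): Double = {
    val labSet = lab.toSet
    var i = 0
    var cnt = 0
    var precSum = 0.0
    val n = math.min(k, pred.length)
    while (i < n) {
      if (labSet.contains(pred(i))) {
        cnt += 1
        precSum += cnt.toDouble / (i + 1)
      }
      i += 1
    }
    precSum / normalizer(labSet.size, k)
  }

  def main(args: Array[String]): Unit = {
    val pred = Array("a", "b", "x", "y", "z") // predicted ranking
    val lab  = Array("a", "b", "c", "d", "e") // 5 relevant documents
    val k = 2 // both top-2 predictions are relevant: perfect at this cutoff

    // precSum = 1/1 + 2/2 = 2.0 in both cases; only the denominator differs.
    // Normalizing by labSet.size alone caps the score below 1.0:
    println(apAtK(pred, lab, k, (size, _) => size))                // 2.0 / 5 = 0.4
    // Normalizing by min(labSet.size, k) lets a perfect top-k reach 1.0:
    println(apAtK(pred, lab, k, (size, kk) => math.min(size, kk))) // 2.0 / 2 = 1.0
  }
}
```

With `math.min(labSet.size, k)` in the denominator, a perfect top-k ranking always scores 1.0, which matches the TensorFlow references above; dividing by `labSet.size` alone caps the best achievable AP@k below 1.0 whenever there are more relevant documents than k.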