huaxingao commented on a change in pull request #26415: [SPARK-18409][ML] LSH
approxNearestNeighbors should use approxQuantile instead of sort
URL: https://github.com/apache/spark/pull/26415#discussion_r346986229
##########
File path: mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
##########
@@ -137,14 +139,23 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash),
DataTypes.DoubleType)
val hashDistCol = hashDistUDF(col($(outputCol)))
- // Compute threshold to get exact k elements.
- // TODO: SPARK-18409: Use approxQuantile to get the threshold
- val modelDatasetSortedByHash =
modelDataset.sort(hashDistCol).limit(numNearestNeighbors)
- val thresholdDataset = modelDatasetSortedByHash.select(max(hashDistCol))
- val hashThreshold = thresholdDataset.take(1).head.getDouble(0)
-
- // Filter the dataset where the hash value is less than the threshold.
- modelDataset.filter(hashDistCol <= hashThreshold)
+ val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)
+ var filtered: DataFrame = null
+ var requestedNum = numNearestNeighbors
+ do {
+ requestedNum *= 2
+ if (requestedNum > modelDataset.count()) {
+ requestedNum = modelDataset.count().toInt
+ }
+ var quantile = requestedNum.toDouble / modelDataset.count()
+ var hashThreshold = modelDatasetWithDist.stat
+ .approxQuantile(distCol, Array(quantile), 0.001)
+
+ // Filter the dataset where the hash value is less than the threshold.
+ filtered = modelDatasetWithDist.filter(hashDistCol <= hashThreshold(0))
Review comment:
It seems to me that I have to apply the filter to find out whether it yields
enough nearest neighbors. If it does not, I go back through the loop and
double the quantile.
I am debating whether I should continue this PR. Its purpose is to improve
performance, but if the first pass through the loop doesn't return enough
nearest neighbors and we have to iterate several times, performance could be
worse than with the original code.
The doc of ```approxNearestNeighbors``` says ```Given a large dataset
and an item, approximately find at most k items which have the closest distance
to the item. ``` If "at most k" is the contract, then I guess we can just use a
quantile that should yield about 2x the requested number of results, and accept
getting fewer than k elements. However, the original implementation returns
exactly k elements, and I am not sure whether we can change that behavior.
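
To make the single-pass "at most k" option concrete, here is a rough sketch on a local collection rather than a DataFrame. It is not the PR's code: `QuantileThresholdSketch`, `quantileValue` and `atMostK` are hypothetical names, and the exact quantile computed by sorting is only a stand-in for `approxQuantile`, which would avoid the full sort on a real dataset.

```scala
object QuantileThresholdSketch {
  // Hypothetical stand-in for approxQuantile: exact empirical quantile
  // of an already sorted, non-empty local collection.
  def quantileValue(sorted: IndexedSeq[Double], q: Double): Double = {
    require(sorted.nonEmpty, "need at least one distance")
    val rank = math.ceil(q * sorted.length).toInt - 1
    sorted(math.min(sorted.length - 1, math.max(0, rank)))
  }

  // Single pass, no retry loop: pick the quantile expected to cover about
  // 2k candidates, filter by that threshold, and keep at most k of them.
  def atMostK(hashDists: Seq[Double], k: Int): Seq[Double] = {
    if (hashDists.isEmpty || k <= 0) {
      Seq.empty
    } else {
      val quantile = math.min(1.0, 2.0 * k / hashDists.length)
      val threshold = quantileValue(hashDists.sorted.toIndexedSeq, quantile)
      // With an approximate quantile this filter may keep fewer than k rows;
      // under the "at most k" reading of the doc that would be acceptable.
      hashDists.filter(_ <= threshold).sorted.take(k)
    }
  }
}
```

With an exact quantile the filter always keeps at least k candidates when the dataset has that many, so the shortfall case only arises once the threshold is approximate. Whether that is acceptable is the open question about changing the original behavior.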