srowen commented on a change in pull request #26415: [SPARK-18409][ML] LSH approxNearestNeighbors should use approxQuantile instead of sort
URL: https://github.com/apache/spark/pull/26415#discussion_r347015773
 
 

 ##########
 File path: mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
 ##########
 @@ -137,14 +139,23 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
      val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash), DataTypes.DoubleType)
       val hashDistCol = hashDistUDF(col($(outputCol)))
 
-      // Compute threshold to get exact k elements.
-      // TODO: SPARK-18409: Use approxQuantile to get the threshold
-      val modelDatasetSortedByHash = modelDataset.sort(hashDistCol).limit(numNearestNeighbors)
-      val thresholdDataset = modelDatasetSortedByHash.select(max(hashDistCol))
-      val hashThreshold = thresholdDataset.take(1).head.getDouble(0)
-
-      // Filter the dataset where the hash value is less than the threshold.
-      modelDataset.filter(hashDistCol <= hashThreshold)
+      val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)
+      var filtered: DataFrame = null
+      var requestedNum = numNearestNeighbors
+      do {
+        requestedNum *= 2
+        if (requestedNum > modelDataset.count()) {
+          requestedNum = modelDataset.count().toInt
+        }
+        var quantile = requestedNum.toDouble / modelDataset.count()
+        var hashThreshold = modelDatasetWithDist.stat
+          .approxQuantile(distCol, Array(quantile), 0.001)
+
+        // Filter the dataset where the hash value is less than the threshold.
+        filtered = modelDatasetWithDist.filter(hashDistCol <= hashThreshold(0))
 
 Review comment:
   Nevermind my comment; you're not reusing that filtered subset in the next loop, oops. This is correct. And yes, you do have to count the filtered result on each iteration to know whether you've collected enough neighbors.
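   For reference, here's roughly how I'd structure that loop: hoist the full-dataset count out of the loop so it's computed once, and make the per-iteration recount explicit in the loop condition. Just a sketch with placeholder structure, not a drop-in replacement:

```scala
// Sketch: count the full dataset once up front; only the filtered
// candidate set needs recounting on each pass.
val datasetCount = modelDataset.count()
val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)

var filtered: DataFrame = null
var requestedNum: Long = numNearestNeighbors
do {
  requestedNum = math.min(requestedNum * 2, datasetCount)
  val quantile = requestedNum.toDouble / datasetCount
  val hashThreshold = modelDatasetWithDist.stat
    .approxQuantile(distCol, Array(quantile), 0.001)
  // Re-filter from the full dataset: rows between the old and new
  // thresholds were excluded last pass but qualify now.
  filtered = modelDatasetWithDist.filter(hashDistCol <= hashThreshold(0))
} while (filtered.count() < numNearestNeighbors && requestedNum < datasetCount)
```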
   
   Certainly the hope is that the first pass will yield enough neighbors, and the loop is just there as a fallback. I think it can be a win if the relative error tolerance is loose enough that one or even two approxQuantile passes are faster than a full sort, but I don't know how it plays out at scale.
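   The relevant knob is the third argument to `approxQuantile`. The patch uses 0.001; something looser might be enough here, since the threshold only needs to over-select candidates and the final top k is taken by exact distance afterwards anyway. Purely illustrative, reusing the names from the sketch above:

```scala
// A looser relativeError makes each quantile pass cheaper; 0.05 is an
// arbitrary example value, not a recommendation.
val hashThreshold = modelDatasetWithDist.stat
  .approxQuantile(distCol, Array(quantile), 0.05)
```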
   
   You raise a good point: the docs below do say 'at most k elements', although the current implementation will return exactly k (assuming there are at least k points in the input). It repeats that twice. Hm. I'd also support just making one pass and picking a larger multiple of the requested number of neighbors; a sketch of that follows. But I don't mind the approach you have here (modulo a few optimizations above).
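   That single-pass variant would look something like this (the overshoot factor is made up and would need tuning; names follow the earlier sketch):

```scala
// Sketch: over-request by a fixed factor in one approxQuantile pass and
// accept slightly more or fewer than k candidates; the caller's exact
// distance sort/limit trims the result to k anyway.
val overshoot = 4.0  // hypothetical factor
val quantile = math.min(1.0, overshoot * numNearestNeighbors / datasetCount)
val hashThreshold = modelDatasetWithDist.stat
  .approxQuantile(distCol, Array(quantile), 0.001)
val filtered = modelDatasetWithDist.filter(hashDistCol <= hashThreshold(0))
```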
   
   We could also optimize the `requestedNum > modelDataset.count()` case by just returning the whole input rather than spending another quantile pass on it; something like the snippet below.
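   Reusing the hypothetical names from above:

```scala
// Sketch: once the (doubled) request covers the whole dataset, the
// threshold can't exclude anything, so skip the quantile computation.
if (requestedNum >= datasetCount) {
  filtered = modelDatasetWithDist  // every row qualifies
} else {
  // ... approxQuantile + filter path from the sketch above ...
}
```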
   
   I wonder if it's reasonable to construct a simple synthetic large input in a test, and check whether this is actually faster at some scale, how loose the relative error has to be to get that speedup, and how often a second loop iteration is needed at all.
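   Something along these lines could serve as a starting point; entirely hypothetical (assumes a test `spark` session, and BucketedRandomProjectionLSH is just one convenient concrete LSH implementation):

```scala
import org.apache.spark.ml.feature.BucketedRandomProjectionLSH
import org.apache.spark.ml.linalg.Vectors

// Hypothetical micro-benchmark: random vectors at a modest scale, timing
// approxNearestNeighbors before/after the change (or across relativeError
// settings). Wall-clock on one machine is only indicative, of course.
val n = 100000
val dim = 10
val rng = new scala.util.Random(42)
val df = spark.createDataFrame(
  (0 until n).map(i => (i, Vectors.dense(Array.fill(dim)(rng.nextGaussian()))))
).toDF("id", "features")

val model = new BucketedRandomProjectionLSH()
  .setBucketLength(2.0)
  .setNumHashTables(3)
  .setInputCol("features")
  .setOutputCol("hashes")
  .fit(df)

val key = Vectors.dense(Array.fill(dim)(0.0))
val start = System.nanoTime()
model.approxNearestNeighbors(df, key, 100).count()
println(s"approxNearestNeighbors took ${(System.nanoTime() - start) / 1e6} ms")
```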
