huaxingao commented on a change in pull request #26415: [SPARK-18409][ML] LSH approxNearestNeighbors should use approxQuantile instead of sort
URL: https://github.com/apache/spark/pull/26415#discussion_r347619272
##########
File path: mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
##########
```diff
@@ -137,14 +139,23 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
     val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash),
       DataTypes.DoubleType)
     val hashDistCol = hashDistUDF(col($(outputCol)))
-    // Compute threshold to get exact k elements.
-    // TODO: SPARK-18409: Use approxQuantile to get the threshold
-    val modelDatasetSortedByHash = modelDataset.sort(hashDistCol).limit(numNearestNeighbors)
-    val thresholdDataset = modelDatasetSortedByHash.select(max(hashDistCol))
-    val hashThreshold = thresholdDataset.take(1).head.getDouble(0)
-
-    // Filter the dataset where the hash value is less than the threshold.
-    modelDataset.filter(hashDistCol <= hashThreshold)
+    val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)
+    var filtered: DataFrame = null
+    var requestedNum = numNearestNeighbors
+    do {
+      requestedNum *= 2
+      if (requestedNum > modelDataset.count()) {
+        requestedNum = modelDataset.count().toInt
+      }
+      var quantile = requestedNum.toDouble / modelDataset.count()
+      var hashThreshold = modelDatasetWithDist.stat
+        .approxQuantile(distCol, Array(quantile), 0.001)
+
+      // Filter the dataset where the hash value is less than the threshold.
+      filtered = modelDatasetWithDist.filter(hashDistCol <= hashThreshold(0))
```
Review comment:
Sounds good.
I did a small test to try to find a good value for the relative error. I used the existing ```BucketedRandomProjectionLSHSuite``` but made the datasets bigger: one with incremental values and one with random values.
```scala
val data1 = {
  for (i <- -200 until 200; j <- -200 until 200) yield Vectors.dense(i * 10, j * 10)
}
val dataset1 = spark.createDataFrame(data1.map(Tuple1.apply)).toDF("keys")
```
```scala
import scala.util.Random

val data2 = {
  for (i <- -200 until 200; j <- -200 until 200) yield Vectors.dense(Random.nextInt, Random.nextInt)
}
val dataset2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys")
```
So each dataset has 160000 rows, and I asked for 40000 nearest neighbors. I tested relative errors of 0.001, 0.005, 0.01, 0.05, 0.1, and 0.2, but didn't see much change in performance; I'm not sure if that is because the dataset is too small.
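For reference, the timings below came from a loop along these lines (a minimal sketch, not the exact harness; it assumes the suite's `BucketedRandomProjectionLSH` setup, and since the relative error is not a public parameter, each tested value is assumed to be substituted into the `approxQuantile` call in `LSH.scala` before running):

```scala
import org.apache.spark.ml.feature.BucketedRandomProjectionLSH
import org.apache.spark.ml.linalg.Vectors

// Fit an LSH model on the enlarged dataset built above.
val brp = new BucketedRandomProjectionLSH()
  .setInputCol("keys")
  .setOutputCol("values")
  .setBucketLength(2.0)
val model = brp.fit(dataset1)

// Time a single approxNearestNeighbors call; count() forces evaluation.
val key = Vectors.dense(0.0, 0.0)
val start = System.nanoTime()
model.approxNearestNeighbors(dataset1, key, 40000).count()
println(s"elapsed: ${(System.nanoTime() - start) / 1e9}s")
```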
**Dataset1 with incremental values**

| relative error | 1st run | 2nd run | 3rd run | 4th run | 5th run | average |
| -------------- | ------- | ------- | ------- | ------- | ------- | ------- |
| 0.001 | 7.61s | 6.56s | 6.39s | 7.32s | 7.49s | 7.07s |
| 0.005 | 6.39s | 6.44s | 6.62s | 6.39s | 7.54s | 6.67s |
| 0.01 | 6.56s | 6.38s | 7.34s | 6.58s | 6.68s | 6.71s |
| 0.05 | 6.51s | 6.24s | 7.24s | 6.34s | 6.54s | 6.57s |
| 0.1 | 6.28s | 6.20s | 6.34s | 6.68s | 7.07s | 6.51s |
| 0.2 | 6.39s | 6.21s | 6.25s | 6.22s | 6.30s | 6.27s |
**Dataset2 with random values**

| relative error | 1st run | 2nd run | 3rd run | 4th run | 5th run | average |
| -------------- | ------- | ------- | ------- | ------- | ------- | ------- |
| 0.001 | 7.66s | 6.77s | 6.75s | 7.78s | 6.64s | 7.11s |
| 0.005 | 6.57s | 6.61s | 6.75s | 7.42s | 6.60s | 6.79s |
| 0.01 | 6.68s | 7.44s | 6.25s | 6.69s | 7.48s | 6.91s |
| 0.05 | 6.59s | 6.54s | 6.75s | 6.62s | 6.63s | 6.62s |
| 0.1 | 7.73s | 6.58s | 6.61s | 6.68s | 6.55s | 6.83s |
| 0.2 | 6.61s | 6.62s | 6.54s | 6.51s | 6.59s | 6.57s |
It seems to me that a fixed value for the relative error may not be good. For example, 0.05 might be a fine relative error when getting 40000 nearest neighbors from 160000 rows, but it is too big when getting 400 nearest neighbors from 160000 rows. I'll pick `err = 0.2 * M / N`, where `M` is `numNearestNeighbors` and `N` is the dataset count. Since the quantile we probe is `p = M / N + err`, this effectively gives `p = 1.2 * M / N`. Hope this makes sense.
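For concreteness, here is what that rule works out to in the two cases above (just the arithmetic, wrapped in a hypothetical helper):

```scala
// err = 0.2 * M / N, so the probed quantile p = M / N + err = 1.2 * M / N.
def relativeError(m: Int, n: Long): Double = 0.2 * m / n

relativeError(40000, 160000L) // 0.05   -> p = 0.30  (vs. exact quantile 0.25)
relativeError(400, 160000L)   // 5.0E-4 -> p = 0.003 (vs. exact quantile 0.0025)
```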