This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 9497a075c9 [SEDONA-688] Verify KNN parameter K must be equal or larger than 1 (#1739)
9497a075c9 is described below
commit 9497a075c9c8517ddc2223c0deb92c43c5bfbdde
Author: Feng Zhang <[email protected]>
AuthorDate: Fri Jan 3 18:39:21 2025 -0800
[SEDONA-688] Verify KNN parameter K must be equal or larger than 1 (#1739)
---
.../strategy/join/BroadcastObjectSideKNNJoinExec.scala | 2 +-
.../strategy/join/BroadcastQuerySideKNNJoinExec.scala | 2 +-
.../sql/sedona_sql/strategy/join/JoinQueryDetector.scala | 8 ++++++++
.../spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala | 2 +-
.../test/scala/org/apache/sedona/sql/KnnJoinSuite.scala | 16 ++++++++++++++++
5 files changed, 27 insertions(+), 3 deletions(-)
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala
index 1b21c79e7c..c5777be3c1 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala
@@ -120,7 +120,7 @@ case class BroadcastObjectSideKNNJoinExec(
sedonaConf: SedonaConf): Unit = {
require(numPartitions > 0, "The number of partitions must be greater than 0.")
val kValue: Int = this.k.eval().asInstanceOf[Int]
- require(kValue > 0, "The number of neighbors must be greater than 0.")
+ require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.")
objectsShapes.setNeighborSampleNumber(kValue)
broadcastJoin = true
}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
index 812bc6e6d6..001c0a1ca3 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
@@ -127,7 +127,7 @@ case class BroadcastQuerySideKNNJoinExec(
sedonaConf: SedonaConf): Unit = {
require(numPartitions > 0, "The number of partitions must be greater than 0.")
val kValue: Int = this.k.eval().asInstanceOf[Int]
- require(kValue > 0, "The number of neighbors must be greater than 0.")
+ require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.")
objectsShapes.setNeighborSampleNumber(kValue)
val joinPartitions: Integer = numPartitions
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index 825855b88c..da9bd5359b 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -582,6 +582,10 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
return Nil
}
+ // validate the k value
+ val kValue: Int = distance.eval().asInstanceOf[Int]
+ require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.")
+
val leftShape = children.head
val rightShape = children.tail.head
@@ -711,6 +715,10 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
if (spatialPredicate == SpatialPredicate.KNN) {
{
+ // validate the k value for KNN join
+ val kValue: Int = distance.get.eval().asInstanceOf[Int]
+ require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.")
+
val leftShape = children.head
val rightShape = children.tail.head
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala
index 2b9bbfb50b..fdc53d13ce 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala
@@ -162,7 +162,7 @@ case class KNNJoinExec(
sedonaConf: SedonaConf): Unit = {
require(numPartitions > 0, "The number of partitions must be greater than 0.")
val kValue: Int = this.k.eval().asInstanceOf[Int]
- require(kValue > 0, "The number of neighbors must be greater than 0.")
+ require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.")
objectsShapes.setNeighborSampleNumber(kValue)
exactSpatialPartitioning(objectsShapes, queryShapes, numPartitions)
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
index 1d6119d02d..f3b07c2501 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala
@@ -209,6 +209,22 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
"[1,3][1,6][1,13][1,16][2,1][2,5][2,11][2,15][3,3][3,9][3,13][3,19]")
}
+ it("KNN Join should verify the correct parameter k is passed to the join function") {
+ val df = sparkSession
+ .range(0, 1)
+ .toDF("id")
+ .withColumn("geom", expr("ST_Point(id, id)"))
+ .repartition(1)
+ df.createOrReplaceTempView("df1")
+ val exception = intercept[IllegalArgumentException] {
+ sparkSession
+ .sql(s"SELECT A.ID, B.ID FROM df1 A JOIN df1 B ON ST_KNN(A.GEOM, B.GEOM, 0, false)")
+ .collect()
+ }
+ exception.getMessage should include(
+ "The number of neighbors (k) must be equal or greater than 1.")
+ }
+
it("KNN Join with exact algorithms with additional join conditions on id") {
val df = sparkSession.sql(
s"SELECT QUERIES.ID, OBJECTS.ID FROM QUERIES JOIN OBJECTS ON ST_KNN(QUERIES.GEOM, OBJECTS.GEOM, 4, false) AND QUERIES.ID > 1")