This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 493ec7439 [SEDONA-624] Bind references in distance expression to relations lazily to avoid exception in query plan canonicalization (#1518)
493ec7439 is described below

commit 493ec7439a9a3b0bc3cf12ee93f1f5c8a8547f17
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Jul 10 01:06:56 2024 +0800

    [SEDONA-624] Bind references in distance expression to relations lazily to avoid exception in query plan canonicalization (#1518)
---
 .../strategy/join/DistanceJoinExec.scala           |  2 +-
 .../org/apache/sedona/sql/SpatialJoinSuite.scala   | 42 +++++++++++++++++++---
 2 files changed, 39 insertions(+), 5 deletions(-)

diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
index 4c62d6a82..425eb1b8a 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
@@ -72,7 +72,7 @@ case class DistanceJoinExec(
     with TraitJoinQueryExec
     with Logging {
 
-  private val boundRadius = if (distanceBoundToLeft) {
+  private lazy val boundRadius = if (distanceBoundToLeft) {
     BindReferences.bindReference(distance, left.output)
   } else {
     BindReferences.bindReference(distance, right.output)
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala 
b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
index 21740b324..737eb41df 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
@@ -102,13 +102,15 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
     }
   }
 
-  describe("Sedona-SQL Spatial Join Test with SELECT *") {
+  describe("Sedona-SQL Spatial Join Test with SELECT * and SELECT COUNT(*)") {
     val joinConditions = Table(
       "join condition",
       "ST_Contains(df1.geom, df2.geom)",
       "ST_Contains(df2.geom, df1.geom)",
       "ST_Distance(df1.geom, df2.geom) < 1.0",
-      "ST_Distance(df2.geom, df1.geom) < 1.0")
+      "ST_Distance(df2.geom, df1.geom) < 1.0",
+      "ST_Distance(df1.geom, df2.geom) < df1.dist",
+      "ST_Distance(df1.geom, df2.geom) < df2.dist")
 
     forAll(joinConditions) { joinCondition =>
       it(s"should SELECT * in join query with $joinCondition produce correct 
result") {
@@ -120,6 +122,16 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
         assert(result === expected)
       }
 
+      it(s"should SELECT COUNT(*) in join query with $joinCondition produce 
correct result") {
+        val result = sparkSession
+          .sql(s"SELECT COUNT(*) FROM df1 JOIN df2 ON $joinCondition")
+          .collect()
+          .head
+          .getLong(0)
+        val expected = buildExpectedResult(joinCondition).length
+        assert(result === expected)
+      }
+
       it(
         s"should SELECT * in join query with $joinCondition produce correct 
result, broadcast the left side") {
         val resultAll = sparkSession
@@ -131,6 +143,17 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
         assert(result === expected)
       }
 
+      it(
+        s"should SELECT COUNT(*) in join query with $joinCondition produce 
correct result, broadcast the left side") {
+        val result = sparkSession
+          .sql(s"SELECT /*+ BROADCAST(df1) */ COUNT(*) FROM df1 JOIN df2 ON 
$joinCondition")
+          .collect()
+          .head
+          .getLong(0)
+        val expected = buildExpectedResult(joinCondition).length
+        assert(result === expected)
+      }
+
       it(
         s"should SELECT * in join query with $joinCondition produce correct 
result, broadcast the right side") {
         val resultAll = sparkSession
@@ -141,6 +164,17 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
         assert(result.nonEmpty)
         assert(result === expected)
       }
+
+      it(
+        s"should SELECT COUNT(*) in join query with $joinCondition produce 
correct result, broadcast the right side") {
+        val result = sparkSession
+          .sql(s"SELECT /*+ BROADCAST(df2) */ COUNT(*) FROM df1 JOIN df2 ON 
$joinCondition")
+          .collect()
+          .head
+          .getLong(0)
+        val expected = buildExpectedResult(joinCondition).length
+        assert(result === expected)
+      }
     }
   }
 
@@ -192,7 +226,7 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
     }
   }
 
-  describe("Spatial join should work with dataframe containing 0 partitions") {
+  describe("Spatial join should work with dataframe containing various number 
of partitions") {
     val queries = Table(
       "join queries",
       "SELECT * FROM df1 JOIN dfEmpty WHERE ST_Intersects(df1.geom, 
dfEmpty.geom)",
@@ -203,7 +237,7 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
       "SELECT /*+ BROADCAST(dfEmpty) */ * FROM dfEmpty JOIN df1 WHERE 
ST_Intersects(df1.geom, dfEmpty.geom)")
 
     forAll(queries) { query =>
-      it(s"Legacy join: $query") {
+      it(s"empty dataframes: $query") {
         withConf(Map(spatialJoinPartitionSideConfKey -> "left")) {
           val resultRows = sparkSession.sql(query).collect()
           assert(resultRows.isEmpty)

Reply via email to