This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 493ec7439 [SEDONA-624] Bind references in distance expression to
relations lazily to avoid exception in query plan canonicalization (#1518)
493ec7439 is described below
commit 493ec7439a9a3b0bc3cf12ee93f1f5c8a8547f17
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Jul 10 01:06:56 2024 +0800
[SEDONA-624] Bind references in distance expression to relations lazily to
avoid exception in query plan canonicalization (#1518)
---
.../strategy/join/DistanceJoinExec.scala | 2 +-
.../org/apache/sedona/sql/SpatialJoinSuite.scala | 42 +++++++++++++++++++---
2 files changed, 39 insertions(+), 5 deletions(-)
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
index 4c62d6a82..425eb1b8a 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
@@ -72,7 +72,7 @@ case class DistanceJoinExec(
with TraitJoinQueryExec
with Logging {
- private val boundRadius = if (distanceBoundToLeft) {
+ private lazy val boundRadius = if (distanceBoundToLeft) {
BindReferences.bindReference(distance, left.output)
} else {
BindReferences.bindReference(distance, right.output)
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
index 21740b324..737eb41df 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
@@ -102,13 +102,15 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
}
}
- describe("Sedona-SQL Spatial Join Test with SELECT *") {
+ describe("Sedona-SQL Spatial Join Test with SELECT * and SELECT COUNT(*)") {
val joinConditions = Table(
"join condition",
"ST_Contains(df1.geom, df2.geom)",
"ST_Contains(df2.geom, df1.geom)",
"ST_Distance(df1.geom, df2.geom) < 1.0",
- "ST_Distance(df2.geom, df1.geom) < 1.0")
+ "ST_Distance(df2.geom, df1.geom) < 1.0",
+ "ST_Distance(df1.geom, df2.geom) < df1.dist",
+ "ST_Distance(df1.geom, df2.geom) < df2.dist")
forAll(joinConditions) { joinCondition =>
it(s"should SELECT * in join query with $joinCondition produce correct
result") {
@@ -120,6 +122,16 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
assert(result === expected)
}
+ it(s"should SELECT COUNT(*) in join query with $joinCondition produce
correct result") {
+ val result = sparkSession
+ .sql(s"SELECT COUNT(*) FROM df1 JOIN df2 ON $joinCondition")
+ .collect()
+ .head
+ .getLong(0)
+ val expected = buildExpectedResult(joinCondition).length
+ assert(result === expected)
+ }
+
it(
s"should SELECT * in join query with $joinCondition produce correct
result, broadcast the left side") {
val resultAll = sparkSession
@@ -131,6 +143,17 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
assert(result === expected)
}
+ it(
+ s"should SELECT COUNT(*) in join query with $joinCondition produce
correct result, broadcast the left side") {
+ val result = sparkSession
+ .sql(s"SELECT /*+ BROADCAST(df1) */ COUNT(*) FROM df1 JOIN df2 ON
$joinCondition")
+ .collect()
+ .head
+ .getLong(0)
+ val expected = buildExpectedResult(joinCondition).length
+ assert(result === expected)
+ }
+
it(
s"should SELECT * in join query with $joinCondition produce correct
result, broadcast the right side") {
val resultAll = sparkSession
@@ -141,6 +164,17 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
assert(result.nonEmpty)
assert(result === expected)
}
+
+ it(
+ s"should SELECT COUNT(*) in join query with $joinCondition produce
correct result, broadcast the right side") {
+ val result = sparkSession
+ .sql(s"SELECT /*+ BROADCAST(df2) */ COUNT(*) FROM df1 JOIN df2 ON
$joinCondition")
+ .collect()
+ .head
+ .getLong(0)
+ val expected = buildExpectedResult(joinCondition).length
+ assert(result === expected)
+ }
}
}
@@ -192,7 +226,7 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
}
}
- describe("Spatial join should work with dataframe containing 0 partitions") {
+ describe("Spatial join should work with dataframe containing various number
of partitions") {
val queries = Table(
"join queries",
"SELECT * FROM df1 JOIN dfEmpty WHERE ST_Intersects(df1.geom,
dfEmpty.geom)",
@@ -203,7 +237,7 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
"SELECT /*+ BROADCAST(dfEmpty) */ * FROM dfEmpty JOIN df1 WHERE
ST_Intersects(df1.geom, dfEmpty.geom)")
forAll(queries) { query =>
- it(s"Legacy join: $query") {
+ it(s"empty dataframes: $query") {
withConf(Map(spatialJoinPartitionSideConfKey -> "left")) {
val resultRows = sparkSession.sql(query).collect()
assert(resultRows.isEmpty)