This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new dc3c5634 [SEDONA-286] Support optimized distance join on ST_DistanceSphere and ST_DistanceSpheroid (#845)
dc3c5634 is described below

commit dc3c563492da6dccf116a98b3b381f178daf3c8c
Author: Jia Yu <[email protected]>
AuthorDate: Wed May 31 01:01:37 2023 -0700

    [SEDONA-286] Support optimized distance join on ST_DistanceSphere and ST_DistanceSpheroid (#845)
---
 docs/api/sql/Optimizer.md                          | 37 +++++++--
 .../strategy/join/DistanceJoinExec.scala           |  5 +-
 .../strategy/join/JoinQueryDetector.scala          | 97 ++++++++++++++--------
 .../strategy/join/SpatialIndexExec.scala           |  3 +-
 .../strategy/join/TraitJoinQueryBase.scala         | 31 ++++++-
 .../sedona/sql/BroadcastIndexJoinSuite.scala       | 48 +++++++++++
 .../org/apache/sedona/sql/TestBaseScala.scala      | 23 +++++
 .../apache/sedona/sql/predicateJoinTestScala.scala | 48 ++++++++++-
 8 files changed, 242 insertions(+), 50 deletions(-)

diff --git a/docs/api/sql/Optimizer.md b/docs/api/sql/Optimizer.md
index 025b964b..3fa0242b 100644
--- a/docs/api/sql/Optimizer.md
+++ b/docs/api/sql/Optimizer.md
@@ -28,7 +28,9 @@ SELECT *
 FROM pointdf, polygondf
 WHERE ST_Within(pointdf.pointshape, polygondf.polygonshape)
 ```
+
 Spark SQL Physical plan:
+
 ```
 == Physical Plan ==
 RangeJoin polygonshape#20: geometry, pointshape#43: geometry, false
@@ -44,9 +46,9 @@ RangeJoin polygonshape#20: geometry, pointshape#43: geometry, false
 
 ## Distance join
 
-Introduction: Find geometries from A and geometries from B such that the internal Euclidean distance of each geometry pair is less or equal than a certain distance
+Introduction: Find geometries from A and geometries from B such that the distance of each geometry pair is less than or equal to a certain distance. It supports the planar Euclidean distance calculator `ST_Distance` and the meter-based geodesic distance calculators `ST_DistanceSpheroid` and `ST_DistanceSphere`.
 
-Spark SQL Example:
+Spark SQL Example for planar Euclidean distance:
 
 *Only consider ==fully within a certain distance==*
 ```sql
@@ -73,7 +75,26 @@ DistanceJoin pointshape1#12: geometry, pointshape2#33: geometry, 2.0, true
 ```
 
 !!!warning
-       Sedona doesn't control the distance's unit (degree or meter). It is same with the geometry. If your coordinates are in the longitude and latitude system, the unit of `distance` should be degree instead of meter or mile. To change the geometry's unit, please either transform the coordinate reference system to a meter-based system. See [ST_Transform](Function.md#st_transform). If you don't want to transform your data and are ok with sacrificing the query accuracy, you can use an approxima [...]
+       If you use `ST_Distance` as the predicate, Sedona doesn't control the distance's unit (degree or meter); it is the same as the unit of the geometry's coordinates. If your coordinates are in the longitude and latitude system, the unit of `distance` should be degrees instead of meters or miles. To change the geometry's unit, please transform the coordinate reference system to a meter-based system. See [ST_Transform](Function.md#st_transform). If you don't want to transform your data, please consider using `ST_Di [...]
+
+Spark SQL Example for meter-based geodesic distance `ST_DistanceSpheroid` (works for `ST_DistanceSphere` too):
+
+*==Less than a certain distance==*
+```sql
+SELECT *
+FROM pointdf1, pointdf2
+WHERE ST_DistanceSpheroid(pointdf1.pointshape1,pointdf2.pointshape2) < 2
+```
+
+*==Less than or equal to a certain distance==*
+```sql
+SELECT *
+FROM pointdf1, pointdf2
+WHERE ST_DistanceSpheroid(pointdf1.pointshape1,pointdf2.pointshape2) <= 2
+```
+
+!!!warning
+       If you use `ST_DistanceSpheroid` or `ST_DistanceSphere` as the predicate, the unit of the distance is meters. Currently, distance joins with geodesic distance calculators work best for point data; for non-point data, only their centroids are considered. The distance join algorithm internally uses an approximate distance buffer, which might lead to inaccurate results if your data is close to the poles or the antimeridian.
 
 ## Broadcast index join
 
@@ -105,7 +126,7 @@ BroadcastIndexJoin pointshape#52: geometry, BuildRight, BuildRight, false ST_Con
       +- FileScan csv
 ```
 
-This also works for distance joins:
+This also works for distance joins with `ST_Distance`, `ST_DistanceSpheroid` or `ST_DistanceSphere`:
 
 ```scala
 pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
@@ -202,14 +223,16 @@ GROUP BY (lcs_geom, rcs_geom)
 
 This also works for distance join. You first need to use `ST_Buffer(geometry, distance)` to wrap one of your original geometry column. If your original geometry column contains points, this `ST_Buffer` will make them become circles with a radius of `distance`.
 
-For example. run this query first on the left table before Step 1.
+Since the coordinates are in the longitude and latitude system, the unit of `distance` should be degrees instead of meters or miles. You can get an approximation by computing `METER_DISTANCE/111000.0`, then filter out false positives. Note that this might lead to inaccurate results if your data is close to the poles or the antimeridian.
+
+In a nutshell, run this query first on the left table before Step 1. Please replace `METER_DISTANCE` with a distance in meters. In Step 1, generate S2 IDs based on the `buffered_geom` column. Then run Steps 2, 3, and 4 on the original `geom` column.
 
 ```sql
-SELECT id, ST_Buffer(geom, DISTANCE), name
+SELECT id, geom, ST_Buffer(geom, METER_DISTANCE/111000.0) as buffered_geom, name
 FROM lefts
 ```
 
-Since the coordinates are in the longitude and latitude system, so the unit of `distance` should be degree instead of meter or mile. You will have to estimate the corresponding degrees based on your meter values. Please use [this calculator](https://lucidar.me/en/online-unit-converter-length-to-angle/convert-degrees-to-meters/#online-converter).
+
 
 ## Regular spatial predicate pushdown
 Introduction: Given a join query and a predicate in the same WHERE clause, first executes the Predicate as a filter, then executes the join query.
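The S2-based recipe in the docs converts a meter distance to degrees with `METER_DISTANCE/111000.0`. The Scala sketch below is illustrative only: `MeterToDegreeSketch` and its methods are hypothetical names, and it assumes a spherical Earth with the 6371008.0 m radius that the `ST_DistanceSphere` test in this commit passes explicitly. It shows that 111000 m roughly corresponds to one degree of latitude, and that a fixed degree buffer covers less east-west distance as latitude grows, which is one way to read the docs' warning about data near the poles.

```scala
// Illustrative sketch (hypothetical names): meter -> degree approximation.
// Assumes a spherical Earth with a 6371008.0 m mean radius.
object MeterToDegreeSketch {
  val MetersPerDegree = 111000.0 // the constant used in the docs
  val MeanEarthRadius = 6371008.0

  // The docs' approximation: a distance in meters expressed in degrees.
  def approxDegrees(meters: Double): Double = meters / MetersPerDegree

  // Meters spanned by one degree of latitude on the sphere (about 111195 m).
  def metersPerLatDegree: Double = 2 * math.Pi * MeanEarthRadius / 360.0

  // Meters spanned by one degree of longitude at a given latitude;
  // this shrinks with cos(latitude) and reaches 0 at the poles.
  def metersPerLonDegree(latDegrees: Double): Double =
    metersPerLatDegree * math.cos(math.toRadians(latDegrees))

  def main(args: Array[String]): Unit = {
    val bufferDegrees = approxDegrees(5000.0) // a 5 km buffer in degrees
    for (lat <- Seq(0.0, 45.0, 80.0)) {
      // East-west ground distance actually covered by that degree buffer.
      val eastWestMeters = bufferDegrees * metersPerLonDegree(lat)
      println(f"lat=$lat%5.1f  buffer=$bufferDegrees%.5f deg  east-west extent=$eastWestMeters%.0f m")
    }
  }
}
```

Running `main` prints how the east-west extent of the same degree buffer shrinks from roughly 5 km at the equator toward a fraction of that at 80 degrees of latitude.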
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
index b393f2d0..615f88a2 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
@@ -54,6 +54,7 @@ case class DistanceJoinExec(left: SparkPlan,
                             distance: Expression,
                             distanceBoundToLeft: Boolean,
                             spatialPredicate: SpatialPredicate,
+                            isGeography: Boolean,
                             extraCondition: Option[Expression] = None)
   extends SedonaBinaryExecNode
     with TraitJoinQueryExec
@@ -70,9 +71,9 @@ case class DistanceJoinExec(left: SparkPlan,
                                 rightRdd: RDD[UnsafeRow],
                                 rightShapeExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) = {
     if (distanceBoundToLeft) {
-      (toExpandedEnvelopeRDD(leftRdd, leftShapeExpr, boundRadius), toSpatialRDD(rightRdd, rightShapeExpr))
+      (toExpandedEnvelopeRDD(leftRdd, leftShapeExpr, boundRadius, isGeography), toSpatialRDD(rightRdd, rightShapeExpr))
     } else {
-      (toSpatialRDD(leftRdd, leftShapeExpr), toExpandedEnvelopeRDD(rightRdd, rightShapeExpr, boundRadius))
+      (toSpatialRDD(leftRdd, leftShapeExpr), toExpandedEnvelopeRDD(rightRdd, rightShapeExpr, boundRadius, isGeography))
     }
   }
 
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index 464db3e7..8a8f411b 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -21,13 +21,13 @@ package org.apache.spark.sql.sedona_sql.strategy.join
 import org.apache.sedona.core.enums.{IndexType, SpatialJoinOptimizationMode}
 import org.apache.sedona.core.spatialOperator.SpatialPredicate
 import org.apache.sedona.core.utils.SedonaConf
-import org.apache.spark.sql.{SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.expressions.{And, EqualNullSafe, EqualTo, Expression, LessThan, LessThanOrEqual}
-import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, NaturalJoin, RightOuter, UsingJoin}
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.sedona_sql.expressions._
 import org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.splitConjunctivePredicates
+import org.apache.spark.sql.{SparkSession, Strategy}
 
 
 case class JoinQueryDetection(
@@ -36,6 +36,7 @@ case class JoinQueryDetection(
   leftShape: Expression,
   rightShape: Expression,
   spatialPredicate: SpatialPredicate,
+  isGeography: Boolean,
   extraCondition: Option[Expression] = None,
   distance: Option[Expression] = None
 )
@@ -57,23 +58,23 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
     extraCondition: Option[Expression] = None): Option[JoinQueryDetection] = {
       predicate match {
         case ST_Contains(Seq(leftShape, rightShape)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.CONTAINS, extraCondition))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.CONTAINS, false, extraCondition))
         case ST_Intersects(Seq(leftShape, rightShape)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, extraCondition))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, extraCondition))
         case ST_Within(Seq(leftShape, rightShape)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.WITHIN, extraCondition))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.WITHIN, false, extraCondition))
         case ST_Covers(Seq(leftShape, rightShape)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERS, extraCondition))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERS, false, extraCondition))
         case ST_CoveredBy(Seq(leftShape, rightShape)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERED_BY, extraCondition))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERED_BY, false, extraCondition))
         case ST_Overlaps(Seq(leftShape, rightShape)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.OVERLAPS, extraCondition))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.OVERLAPS, false, extraCondition))
         case ST_Touches(Seq(leftShape, rightShape)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.TOUCHES, extraCondition))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.TOUCHES, false, extraCondition))
         case ST_Equals(Seq(leftShape, rightShape)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.EQUALS, extraCondition))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.EQUALS, false, extraCondition))
         case ST_Crosses(Seq(leftShape, rightShape)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.CROSSES, extraCondition))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.CROSSES, false, extraCondition))
         case _ => None
       }
     }
@@ -109,20 +110,45 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
           getJoinDetection(left, right, predicate, Some(extraCondition))
         case Some(And(extraCondition, predicate: ST_Predicate)) =>
           getJoinDetection(left, right, predicate, Some(extraCondition))
-
         // For distance joins we execute the actual predicate (condition) and not only extraConditions.
         case Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
         case Some(And(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
         case Some(And(_, LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
         case Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), distance)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
         case Some(And(LessThan(ST_Distance(Seq(leftShape, rightShape)), distance), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
         case Some(And(_, LessThan(ST_Distance(Seq(leftShape, rightShape)), distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
+        // ST_DistanceSphere
+        case Some(LessThanOrEqual(ST_DistanceSphere(Seq(leftShape, rightShape, radius)), distance)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+        case Some(And(LessThanOrEqual(ST_DistanceSphere(Seq(leftShape, rightShape, radius)), distance), _)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+        case Some(And(_, LessThanOrEqual(ST_DistanceSphere(Seq(leftShape, rightShape, radius)), distance))) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+        case Some(LessThan(ST_DistanceSphere(Seq(leftShape, rightShape, radius)), distance)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+        case Some(And(LessThan(ST_DistanceSphere(Seq(leftShape, rightShape, radius)), distance), _)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+        case Some(And(_, LessThan(ST_DistanceSphere(Seq(leftShape, rightShape, radius)), distance))) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+        // ST_DistanceSpheroid
+        case Some(LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+        case Some(And(LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance), _)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+        case Some(And(_, LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance))) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+        case Some(LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+        case Some(And(LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance), _)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+        case Some(And(_, LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance))) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
         case _ =>
           None
       }
@@ -131,20 +157,20 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
 
       if ((broadcastLeft || broadcastRight) && sedonaConf.getUseIndex) {
         queryDetection match {
-          case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, distance)) =>
+          case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, isGeography, extraCondition, distance)) =>
             planBroadcastJoin(
               left, right, Seq(leftShape, rightShape), joinType,
               spatialPredicate, sedonaConf.getIndexType,
-              broadcastLeft, broadcastRight, extraCondition, distance)
+              broadcastLeft, broadcastRight, isGeography, extraCondition, distance)
           case _ =>
             Nil
         }
       } else {
         queryDetection match {
-          case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, None)) =>
+          case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, isGeography, extraCondition, None)) =>
             planSpatialJoin(left, right, Seq(leftShape, rightShape), joinType, spatialPredicate, extraCondition)
-          case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, Some(distance))) =>
-            planDistanceJoin(left, right, Seq(leftShape, rightShape), joinType, distance, spatialPredicate, extraCondition)
+          case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, isGeography, extraCondition, Some(distance))) =>
+            planDistanceJoin(left, right, Seq(leftShape, rightShape), joinType, distance, spatialPredicate, isGeography, extraCondition)
           case None =>
             Nil
         }
@@ -236,6 +262,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
     joinType: JoinType,
     distance: Expression,
     spatialPredicate: SpatialPredicate,
+    isGeography: Boolean,
     extraCondition: Option[Expression] = None): Seq[SparkPlan] = {
 
     if (joinType != Inner) {
@@ -252,11 +279,11 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
           case Some(LeftSide) =>
             logInfo("Planning spatial distance join, distance bound to left relation")
             DistanceJoinExec(planLater(left), planLater(right), leftShape, rightShape, distance, distanceBoundToLeft = true,
-              spatialPredicate, extraCondition) :: Nil
+              spatialPredicate, isGeography, extraCondition) :: Nil
           case Some(RightSide) =>
             logInfo("Planning spatial distance join, distance bound to right relation")
             DistanceJoinExec(planLater(left), planLater(right), leftShape, rightShape, distance, distanceBoundToLeft = false,
-              spatialPredicate, extraCondition) :: Nil
+              spatialPredicate, isGeography, extraCondition) :: Nil
           case _ =>
             logInfo(
               "Spatial distance join for ST_Distance with non-scalar distance " +
@@ -280,6 +307,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
     indexType: IndexType,
     broadcastLeft: Boolean,
     broadcastRight: Boolean,
+    isGeography: Boolean,
     extraCondition: Option[Expression],
     distance: Option[Expression]): Seq[SparkPlan] = {
 
@@ -300,12 +328,13 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
     val a = children.head
     val b = children.tail.head
 
-    val relationship = (distance, spatialPredicate) match {
-      case (Some(_), SpatialPredicate.INTERSECTS) => "ST_Distance <="
-      case (Some(_), _) => "ST_Distance <"
-      case (None, _) => s"ST_$spatialPredicate"
+    val relationship = (distance, spatialPredicate, isGeography) match {
+      case (Some(_), SpatialPredicate.INTERSECTS, false) => "ST_Distance <="
+      case (Some(_), _, false) => "ST_Distance <"
+      case (Some(_), SpatialPredicate.INTERSECTS, true) => "ST_Distance (Geography) <="
+      case (Some(_), _, true) => "ST_Distance (Geography) <"
+      case (None, _, false) => s"ST_$spatialPredicate"
     }
-
     val (distanceOnIndexSide, distanceOnStreamSide) = distance.map { distanceExpr =>
       matchDistanceExpressionToJoinSide(distanceExpr, left, right) match {
         case Some(side) =>
@@ -321,13 +350,13 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
         logInfo(s"Planning spatial join for $relationship relationship")
         val (leftPlan, rightPlan, streamShape, windowSide) = (broadcastSide.get, swapped) match {
           case (LeftSide, false) => // Broadcast the left side, windows on the left
-            (SpatialIndexExec(planLater(left), a, indexType, distanceOnIndexSide), planLater(right), b, LeftSide)
+            (SpatialIndexExec(planLater(left), a, indexType, isGeography, distanceOnIndexSide), planLater(right), b, LeftSide)
           case (LeftSide, true) => // Broadcast the left side, objects on the left
-            (SpatialIndexExec(planLater(left), b, indexType, distanceOnIndexSide), planLater(right), a, RightSide)
+            (SpatialIndexExec(planLater(left), b, indexType, isGeography, distanceOnIndexSide), planLater(right), a, RightSide)
           case (RightSide, false) => // Broadcast the right side, windows on the left
-            (planLater(left), SpatialIndexExec(planLater(right), b, indexType, distanceOnIndexSide), a, LeftSide)
+            (planLater(left), SpatialIndexExec(planLater(right), b, indexType, isGeography, distanceOnIndexSide), a, LeftSide)
           case (RightSide, true) => // Broadcast the right side, objects on the left
-            (planLater(left), SpatialIndexExec(planLater(right), a, indexType, distanceOnIndexSide), b, RightSide)
+            (planLater(left), SpatialIndexExec(planLater(right), a, indexType, isGeography, distanceOnIndexSide), b, RightSide)
         }
         BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, broadcastSide.get, windowSide, joinType,
           spatialPredicate, extraCondition, distanceOnStreamSide) :: Nil
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
index f9f9a4ed..2c24a34e 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.sedona_sql.execution.SedonaUnaryExecNode
 case class SpatialIndexExec(child: SparkPlan,
                             shape: Expression,
                             indexType: IndexType,
+                            isGeography: Boolean,
                             distance: Option[Expression] = None)
   extends SedonaUnaryExecNode
     with TraitJoinQueryBase
@@ -51,7 +52,7 @@ case class SpatialIndexExec(child: SparkPlan,
     val resultRaw = child.execute().asInstanceOf[RDD[UnsafeRow]].coalesce(1)
 
     val spatialRDD = distance match {
-      case Some(distanceExpression) => toExpandedEnvelopeRDD(resultRaw, boundShape, BindReferences.bindReference(distanceExpression, child.output))
+      case Some(distanceExpression) => toExpandedEnvelopeRDD(resultRaw, boundShape, BindReferences.bindReference(distanceExpression, child.output), isGeography)
       case None => toSpatialRDD(resultRaw, boundShape)
     }
 
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
index df8f4cd3..0bdeeb83 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
@@ -24,7 +24,7 @@ import org.apache.sedona.sql.utils.GeometrySerializer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeRow}
 import org.apache.spark.sql.execution.SparkPlan
-import org.locationtech.jts.geom.Geometry
+import org.locationtech.jts.geom.{Envelope, Geometry}
 
 trait TraitJoinQueryBase {
   self: SparkPlan =>
@@ -48,14 +48,14 @@ trait TraitJoinQueryBase {
     spatialRdd
   }
 
-  def toExpandedEnvelopeRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression, boundRadius: Expression): SpatialRDD[Geometry] = {
+  def toExpandedEnvelopeRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression, boundRadius: Expression, isGeography: Boolean): SpatialRDD[Geometry] = {
     val spatialRdd = new SpatialRDD[Geometry]
     spatialRdd.setRawSpatialRDD(
       rdd
         .map { x =>
           val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
           val envelope = shape.getEnvelopeInternal.copy()
-          envelope.expandBy(boundRadius.eval(x).asInstanceOf[Double])
+          expandEnvelope(envelope, boundRadius.eval(x).asInstanceOf[Double], 6357000.0, isGeography)
 
           val expandedEnvelope = shape.getFactory.toGeometry(envelope)
           expandedEnvelope.setUserData(x.copy)
@@ -72,4 +72,27 @@ trait TraitJoinQueryBase {
       followerShapes.spatialPartitioning(dominantShapes.getPartitioner)
     }
   }
-}
+
+  /**
+    * Expand the given envelope by the given distance in meter.
+    * For geography, we expand the envelope by the given distance in both longitude and latitude.
+    * @param envelope
+    * @param distance in meter
+    * @param radius in meter
+    * @param isGeography
+    */
+  private def expandEnvelope(envelope:Envelope, distance:Double, radius:Double, isGeography:Boolean):Unit = {
+    if (isGeography) {
+      val scaleFactor = 1.1 // 10% buffer to get rid of false negatives
+      val latRadian = Math.toRadians((envelope.getMinX + envelope.getMaxX) / 2.0)
+      val latDeltaRadian = distance / radius;
+      val latDeltaDegree = Math.toDegrees(latDeltaRadian)
+      val lonDeltaRadian = Math.max(Math.abs(distance / (radius * Math.cos(latRadian + latDeltaRadian))),
+        Math.abs(distance / (radius * Math.cos(latRadian - latDeltaRadian))))
+      val lonDeltaDegree = Math.toDegrees(lonDeltaRadian)
+      envelope.expandBy(latDeltaDegree * scaleFactor, lonDeltaDegree * scaleFactor)
+    } else {
+      envelope.expandBy(distance)
+    }
+  }
+}
\ No newline at end of file
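`expandEnvelope` pads the envelope by a latitude delta of `distance / radius` and a longitude delta of `distance / (radius * cos(lat ± latDelta))`, both scaled by 1.1. The sketch below restates only that geography branch as a pure Scala function that returns the two paddings in degrees instead of mutating a JTS `Envelope`, so the numbers can be inspected without Spark. It mirrors the commit's convention that the envelope's X axis holds latitude (the center latitude is taken from `getMinX`/`getMaxX`, and X is expanded by the latitude delta); `GeographyEnvelopeSketch` and `padding` are hypothetical names.

```scala
// Illustrative restatement (hypothetical names) of the geography branch of
// expandEnvelope. Mirrors the commit's convention that envelope X = latitude.
object GeographyEnvelopeSketch {
  val PolarRadius = 6357000.0 // radius passed by toExpandedEnvelopeRDD in the commit
  val ScaleFactor = 1.1       // 10% margin against false negatives, as in the commit

  /** Returns (xPadDegrees, yPadDegrees), i.e. (latitude pad, longitude pad). */
  def padding(minX: Double, maxX: Double, distanceMeters: Double,
              radius: Double = PolarRadius): (Double, Double) = {
    val latRadian = math.toRadians((minX + maxX) / 2.0) // center latitude
    val latDeltaRadian = distanceMeters / radius         // arc angle of the distance
    val latDeltaDegree = math.toDegrees(latDeltaRadian)
    // A degree of longitude is shorter at higher latitudes, so take the larger
    // longitude delta at the two edges of the band (lat + delta, lat - delta).
    val lonDeltaRadian = math.max(
      math.abs(distanceMeters / (radius * math.cos(latRadian + latDeltaRadian))),
      math.abs(distanceMeters / (radius * math.cos(latRadian - latDeltaRadian))))
    val lonDeltaDegree = math.toDegrees(lonDeltaRadian)
    (latDeltaDegree * ScaleFactor, lonDeltaDegree * ScaleFactor)
  }

  def main(args: Array[String]): Unit = {
    for (lat <- Seq(0.0, 45.0, 80.0)) {
      val (latPad, lonPad) = padding(lat, lat, 100000.0) // 100 km search distance
      println(f"lat=$lat%5.1f  latPad=$latPad%.4f deg  lonPad=$lonPad%.4f deg")
    }
  }
}
```

Because the latitude band is not clamped at ±90°, `cos(...)` approaches 0 as the band nears a pole and the longitude padding grows without bound, and nothing wraps across ±180°; this is consistent with the documented caveat about data close to the poles or the antimeridian.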
diff --git a/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala b/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
index 78f3aaf5..6f569128 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
@@ -302,6 +302,54 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
       assert(rows2(0) == Row(2.0, 2.0, "left_2", 2.0, 2.0, "right_2"))
 
     }
+
+    it("Passed ST_DistanceSpheroid in a broadcast join") {
+      val distanceCandidates = Seq(130000, 160000, 500000)
+      val sampleCount = 50
+      distanceCandidates.foreach(distance => {
+        val expected = bruteForceDistanceJoinCountSpheroid(sampleCount, distance)
+        val pointDf1 = buildPointDf.limit(sampleCount).repartition(4)
+        val pointDf2 = pointDf1
+        var distanceJoinDf = pointDf1.alias("pointDf1").join(
+          broadcast(pointDf2).alias("pointDf2"), expr(s"ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) <= $distance"))
+        assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(distanceJoinDf.count() == expected)
+
+        distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+          pointDf2.alias("pointDf2"), expr(s"ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) <= $distance"))
+        assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(distanceJoinDf.count() == expected)
+
+        distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+          pointDf2.alias("pointDf2"), expr(s"ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) < $distance"))
+        assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(distanceJoinDf.count() == expected)
+      })
+    }
+
+    it("Passed ST_DistanceSphere in a broadcast join") {
+      val distanceCandidates = Seq(130000, 160000, 500000)
+      val sampleCount = 50
+      distanceCandidates.foreach(distance => {
+        val expected = bruteForceDistanceJoinCountSphere(sampleCount, distance)
+        val pointDf1 = buildPointDf.limit(sampleCount).repartition(4)
+        val pointDf2 = pointDf1
+        var distanceJoinDf = pointDf1.alias("pointDf1").join(
+          broadcast(pointDf2).alias("pointDf2"), 
expr(s"ST_DistanceSphere(pointDf1.pointshape, pointDf2.pointshape) <= 
$distance"))
+        assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: 
BroadcastIndexJoinExec => p }.size === 1)
+        assert(distanceJoinDf.count() == expected)
+
+        distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+          pointDf2.alias("pointDf2"), 
expr(s"ST_DistanceSphere(pointDf1.pointshape, pointDf2.pointshape) <= 
$distance"))
+        assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: 
BroadcastIndexJoinExec => p }.size === 1)
+        assert(distanceJoinDf.count() == expected)
+
+        distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+          pointDf2.alias("pointDf2"), expr(s"ST_DistanceSphere(pointDf1.pointshape, pointDf2.pointshape, 6371008.0) < $distance"))
+        assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(distanceJoinDf.count() == expected)
+      })
+    }
   }
 
   describe("Sedona-SQL Broadcast Index Join Test for left semi joins") {
diff --git a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index f12b5874..16d97888 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -20,6 +20,7 @@ package org.apache.sedona.sql
 
 import com.google.common.math.DoubleMath
 import org.apache.log4j.{Level, Logger}
+import org.apache.sedona.common.sphere.{Haversine, Spheroid}
 import org.apache.sedona.core.serde.SedonaKryoRegistrator
 import org.apache.sedona.sql.utils.SedonaSQLRegistrator
 import org.apache.spark.serializer.KryoSerializer
@@ -97,4 +98,26 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
     }
   }
 
+  protected def bruteForceDistanceJoinCountSpheroid(sampleCount:Int, distance: Double): Int = {
+    val input = buildPointDf.limit(sampleCount).collect()
+    input.map(row => {
+      val point1 = row.getAs[org.locationtech.jts.geom.Point](0)
+      input.map(row => {
+        val point2 = row.getAs[org.locationtech.jts.geom.Point](0)
+        if (Spheroid.distance(point1, point2) <= distance) 1 else 0
+      }).sum
+    }).sum
+  }
+
+  protected def bruteForceDistanceJoinCountSphere(sampleCount: Int, distance: Double): Int = {
+    val input = buildPointDf.limit(sampleCount).collect()
+    input.map(row => {
+      val point1 = row.getAs[org.locationtech.jts.geom.Point](0)
+      input.map(row => {
+        val point2 = row.getAs[org.locationtech.jts.geom.Point](0)
+        if (Haversine.distance(point1, point2) <= distance) 1 else 0
+      }).sum
+    }).sum
+  }
+
 }
diff --git 
a/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala b/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
index 596e561b..66ee94ca 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
@@ -21,10 +21,10 @@ package org.apache.sedona.sql
 
 import org.apache.sedona.core.utils.SedonaConf
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.execution.ExplainMode
+import org.apache.spark.sql.functions.expr
+import org.apache.spark.sql.sedona_sql.strategy.join.DistanceJoinExec
 import org.apache.spark.sql.types._
 import org.locationtech.jts.geom.Geometry
-import org.locationtech.jts.io.WKTWriter
 
 class predicateJoinTestScala extends TestBaseScala {
 
@@ -361,5 +361,49 @@ class predicateJoinTestScala extends TestBaseScala {
 
       assert(equalJoinDf.count() == 0, s"Expected 0 but got ${equalJoinDf.count()}")
     }
+
+    it("Passed ST_DistanceSpheroid in a spatial join") {
+      val distanceCandidates = Seq(130000, 160000, 500000)
+      val sampleCount = 50
+      distanceCandidates.foreach(distance => {
+        val expected = bruteForceDistanceJoinCountSpheroid(sampleCount, distance)
+        val pointDf1 = buildPointDf.limit(sampleCount).repartition(4)
+        val pointDf2 = pointDf1
+        var distanceJoinDf = pointDf1.alias("pointDf1").join(
+          pointDf2.alias("pointDf2"), expr(s"ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) <= $distance"))
+        assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: DistanceJoinExec => p }.size === 1)
+        assert(distanceJoinDf.count() == expected)
+
+        distanceJoinDf = pointDf1.alias("pointDf1").join(
+          pointDf2.alias("pointDf2"), expr(s"ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) < $distance"))
+        assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: DistanceJoinExec => p }.size === 1)
+        assert(distanceJoinDf.count() == expected)
+      })
+    }
+
+    it("Passed ST_DistanceSphere in a spatial join") {
+      val distanceCandidates = Seq(130000, 160000, 500000)
+      val sampleCount = 50
+      distanceCandidates.foreach(distance => {
+        val expected = bruteForceDistanceJoinCountSphere(sampleCount, distance)
+        val pointDf1 = buildPointDf.limit(sampleCount).repartition(4)
+        val pointDf2 = pointDf1
+        var distanceJoinDf = pointDf1.alias("pointDf1").join(
+          pointDf2.alias("pointDf2"), expr(s"ST_DistanceSphere(pointDf1.pointshape, pointDf2.pointshape) <= $distance"))
+        assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: DistanceJoinExec => p }.size === 1)
+        assert(distanceJoinDf.count() == expected)
+
+        distanceJoinDf = pointDf1.alias("pointDf1").join(
+          pointDf2.alias("pointDf2"), expr(s"ST_DistanceSphere(pointDf1.pointshape, pointDf2.pointshape) < $distance"))
+        assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: DistanceJoinExec => p }.size === 1)
+        assert(distanceJoinDf.count() == expected)
+
+        distanceJoinDf = pointDf1.alias("pointDf1").join(
+          pointDf2.alias("pointDf2"), expr(s"ST_DistanceSphere(pointDf1.pointshape, pointDf2.pointshape, 6371008.0) < $distance"))
+        assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: DistanceJoinExec => p }.size === 1)
+        assert(distanceJoinDf.count() == expected)
+      })
+
+    }
   }
 }
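
The following standalone Scala sketch may help readers who do not have Sedona on the classpath. It shows the check that `bruteForceDistanceJoinCountSphere` performs: a haversine great-circle distance plus a nested loop that counts every ordered pair within the threshold, self-pairs included.

It makes several assumptions:

- It takes plain (lat, lon) tuples in degrees instead of JTS points.
- It uses a mean Earth radius of 6371008.0 m, which is the value the three-argument `ST_DistanceSphere` calls above pass explicitly.
- The names `HaversineSketch`, `haversine`, and `bruteForceCount` are hypothetical.

This is not Sedona's `Haversine.distance` implementation, and it makes no claim about the coordinate order Sedona expects.

```scala
// Hypothetical standalone sketch; not Sedona's Haversine implementation.
object HaversineSketch {
  // Mean Earth radius in meters (same literal the tests pass to ST_DistanceSphere).
  val MeanEarthRadius: Double = 6371008.0

  /** Great-circle distance in meters between two (lat, lon) points given in degrees. */
  def haversine(lat1: Double, lon1: Double, lat2: Double, lon2: Double,
                radius: Double = MeanEarthRadius): Double = {
    val phi1 = math.toRadians(lat1)
    val phi2 = math.toRadians(lat2)
    val dPhi = math.toRadians(lat2 - lat1)
    val dLambda = math.toRadians(lon2 - lon1)
    // Haversine term: sin^2(dPhi/2) + cos(phi1) * cos(phi2) * sin^2(dLambda/2)
    val a = math.pow(math.sin(dPhi / 2), 2) +
      math.cos(phi1) * math.cos(phi2) * math.pow(math.sin(dLambda / 2), 2)
    // Clamp to 1.0 so floating-point rounding never pushes asin out of its domain
    2 * radius * math.asin(math.min(1.0, math.sqrt(a)))
  }

  /** Counts ordered pairs (a, b), including a == b, whose distance is <= `distance` meters. */
  def bruteForceCount(points: Seq[(Double, Double)], distance: Double): Int =
    points.map { case (lat1, lon1) =>
      points.count { case (lat2, lon2) => haversine(lat1, lon1, lat2, lon2) <= distance }
    }.sum
}
```

The count includes both (a, b) and (b, a), plus every point paired with itself. That is the same way the self-join of `pointDf1` with `pointDf2` (the same DataFrame) is counted in the tests above.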

