This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new dc3c5634 [SEDONA-286] Support optimized distance join on
ST_DistanceSphere and ST_DistanceSpheroid (#845)
dc3c5634 is described below
commit dc3c563492da6dccf116a98b3b381f178daf3c8c
Author: Jia Yu <[email protected]>
AuthorDate: Wed May 31 01:01:37 2023 -0700
[SEDONA-286] Support optimized distance join on ST_DistanceSphere and
ST_DistanceSpheroid (#845)
---
docs/api/sql/Optimizer.md | 37 +++++++--
.../strategy/join/DistanceJoinExec.scala | 5 +-
.../strategy/join/JoinQueryDetector.scala | 97 ++++++++++++++--------
.../strategy/join/SpatialIndexExec.scala | 3 +-
.../strategy/join/TraitJoinQueryBase.scala | 31 ++++++-
.../sedona/sql/BroadcastIndexJoinSuite.scala | 48 +++++++++++
.../org/apache/sedona/sql/TestBaseScala.scala | 23 +++++
.../apache/sedona/sql/predicateJoinTestScala.scala | 48 ++++++++++-
8 files changed, 242 insertions(+), 50 deletions(-)
diff --git a/docs/api/sql/Optimizer.md b/docs/api/sql/Optimizer.md
index 025b964b..3fa0242b 100644
--- a/docs/api/sql/Optimizer.md
+++ b/docs/api/sql/Optimizer.md
@@ -28,7 +28,9 @@ SELECT *
FROM pointdf, polygondf
WHERE ST_Within(pointdf.pointshape, polygondf.polygonshape)
```
+
Spark SQL Physical plan:
+
```
== Physical Plan ==
RangeJoin polygonshape#20: geometry, pointshape#43: geometry, false
@@ -44,9 +46,9 @@ RangeJoin polygonshape#20: geometry, pointshape#43: geometry,
false
## Distance join
-Introduction: Find geometries from A and geometries from B such that the
internal Euclidean distance of each geometry pair is less or equal than a
certain distance
+Introduction: Find geometries from A and geometries from B such that the
distance of each geometry pair is less than or equal to a certain distance. It
supports the planar Euclidean distance calculator `ST_Distance` and the
meter-based geodesic distance calculators `ST_DistanceSpheroid` and
`ST_DistanceSphere`.
-Spark SQL Example:
+Spark SQL Example for planar Euclidean distance:
*Only consider ==fully within a certain distance==*
```sql
@@ -73,7 +75,26 @@ DistanceJoin pointshape1#12: geometry, pointshape2#33:
geometry, 2.0, true
```
!!!warning
- Sedona doesn't control the distance's unit (degree or meter). It is
same with the geometry. If your coordinates are in the longitude and latitude
system, the unit of `distance` should be degree instead of meter or mile. To
change the geometry's unit, please either transform the coordinate reference
system to a meter-based system. See [ST_Transform](Function.md#st_transform).
If you don't want to transform your data and are ok with sacrificing the query
accuracy, you can use an approxima [...]
+ If you use `ST_Distance` as the predicate, Sedona doesn't control the
distance's unit (degree or meter). It is the same as the geometry's unit. If your
coordinates are in the longitude and latitude system, the unit of `distance`
should be degree instead of meter or mile. To change the geometry's unit,
please transform the coordinate reference system to a meter-based
system. See [ST_Transform](Function.md#st_transform). If you don't want to
transform your data, please consider using `ST_Di [...]
+
+Spark SQL Example for meter-based geodesic distance `ST_DistanceSpheroid`
(works for `ST_DistanceSphere` too):
+
+*Less than a certain distance*
+```sql
+SELECT *
+FROM pointdf1, pointdf2
+WHERE ST_DistanceSpheroid(pointdf1.pointshape1,pointdf2.pointshape2) < 2
+```
+
+*Less than or equal to a certain distance*
+```sql
+SELECT *
+FROM pointdf1, pointdf2
+WHERE ST_DistanceSpheroid(pointdf1.pointshape1,pointdf2.pointshape2) <= 2
+```
+
+!!!warning
+ If you use `ST_DistanceSpheroid` or `ST_DistanceSphere` as the
predicate, the unit of the distance is meter. Currently, distance join with
geodesic distance calculators works best for point data. For non-point data, it
only considers their centroids. The distance join algorithm internally uses an
approximate distance buffer which might lead to inaccurate results if your data
is close to the poles or antimeridian.
## Broadcast index join
@@ -105,7 +126,7 @@ BroadcastIndexJoin pointshape#52: geometry, BuildRight,
BuildRight, false ST_Con
+- FileScan csv
```
-This also works for distance joins:
+This also works for distance joins with `ST_Distance`, `ST_DistanceSpheroid`
or `ST_DistanceSphere`:
```scala
pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
@@ -202,14 +223,16 @@ GROUP BY (lcs_geom, rcs_geom)
This also works for distance join. You first need to use `ST_Buffer(geometry,
distance)` to wrap one of your original geometry column. If your original
geometry column contains points, this `ST_Buffer` will make them become circles
with a radius of `distance`.
-For example. run this query first on the left table before Step 1.
+Since the coordinates are in the longitude and latitude system, the unit of
`distance` should be degree instead of meter or mile. You can get an
approximation by performing `METER_DISTANCE/111000.0`, then filter out
false-positives. Note that this might lead to inaccurate results if your data
is close to the poles or antimeridian.
+
+In a nutshell, run this query first on the left table before Step 1. Please
replace `METER_DISTANCE` with a meter distance. In Step 1, generate S2 IDs
based on the `buffered_geom` column. Then run Step 2, 3, 4 on the original
`geom` column.
```sql
-SELECT id, ST_Buffer(geom, DISTANCE), name
+SELECT id, geom , ST_Buffer(geom, METER_DISTANCE/111000.0) as buffered_geom,
name
FROM lefts
```
-Since the coordinates are in the longitude and latitude system, so the unit of
`distance` should be degree instead of meter or mile. You will have to estimate
the corresponding degrees based on your meter values. Please use [this
calculator](https://lucidar.me/en/online-unit-converter-length-to-angle/convert-degrees-to-meters/#online-converter).
+
## Regular spatial predicate pushdown
Introduction: Given a join query and a predicate in the same WHERE clause,
first executes the Predicate as a filter, then executes the join query.
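
Note on the `METER_DISTANCE/111000.0` approximation in the Optimizer.md hunk above: the following is a minimal Scala sketch, not part of the commit. The object and method names are hypothetical; the only value taken from the docs is the 111000.0 divisor. The second helper relies on the standard fact that a degree of longitude covers less ground as latitude grows (it scales with the cosine of the latitude), which is one reason the docs ask for a false-positive filter after the buffered join.

```scala
object MeterToDegreeSketch {
  // Divisor used in the updated Optimizer.md: roughly the length of one degree in meters.
  val MetersPerDegree: Double = 111000.0

  // Approximate buffer radius in degrees for a distance given in meters.
  def approxDegrees(meterDistance: Double): Double =
    meterDistance / MetersPerDegree

  // Approximate east-west ground length (meters) covered by `degrees` of longitude
  // at the given latitude. Shrinks toward zero near the poles.
  def eastWestMeters(degrees: Double, latitudeDegrees: Double): Double =
    degrees * MetersPerDegree * math.cos(math.toRadians(latitudeDegrees))

  def main(args: Array[String]): Unit = {
    val deg = approxDegrees(2000.0)
    println(f"2000 m is about $deg%.6f degrees")
    println(f"east-west reach at latitude 60: ${eastWestMeters(deg, 60.0)}%.1f m")
  }
}
```

Because the buffer is a fixed number of degrees, its east-west reach in meters depends on latitude, so the exact meter-based check still has to run on the candidate pairs.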
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
index b393f2d0..615f88a2 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
@@ -54,6 +54,7 @@ case class DistanceJoinExec(left: SparkPlan,
distance: Expression,
distanceBoundToLeft: Boolean,
spatialPredicate: SpatialPredicate,
+ isGeography: Boolean,
extraCondition: Option[Expression] = None)
extends SedonaBinaryExecNode
with TraitJoinQueryExec
@@ -70,9 +71,9 @@ case class DistanceJoinExec(left: SparkPlan,
rightRdd: RDD[UnsafeRow],
rightShapeExpr: Expression):
(SpatialRDD[Geometry], SpatialRDD[Geometry]) = {
if (distanceBoundToLeft) {
- (toExpandedEnvelopeRDD(leftRdd, leftShapeExpr, boundRadius),
toSpatialRDD(rightRdd, rightShapeExpr))
+ (toExpandedEnvelopeRDD(leftRdd, leftShapeExpr, boundRadius,
isGeography), toSpatialRDD(rightRdd, rightShapeExpr))
} else {
- (toSpatialRDD(leftRdd, leftShapeExpr), toExpandedEnvelopeRDD(rightRdd,
rightShapeExpr, boundRadius))
+ (toSpatialRDD(leftRdd, leftShapeExpr), toExpandedEnvelopeRDD(rightRdd,
rightShapeExpr, boundRadius, isGeography))
}
}
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index 464db3e7..8a8f411b 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -21,13 +21,13 @@ package org.apache.spark.sql.sedona_sql.strategy.join
import org.apache.sedona.core.enums.{IndexType, SpatialJoinOptimizationMode}
import org.apache.sedona.core.spatialOperator.SpatialPredicate
import org.apache.sedona.core.utils.SedonaConf
-import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.expressions.{And, EqualNullSafe, EqualTo,
Expression, LessThan, LessThanOrEqual}
-import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner,
InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, NaturalJoin, RightOuter,
UsingJoin}
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.sedona_sql.expressions._
import
org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.splitConjunctivePredicates
+import org.apache.spark.sql.{SparkSession, Strategy}
case class JoinQueryDetection(
@@ -36,6 +36,7 @@ case class JoinQueryDetection(
leftShape: Expression,
rightShape: Expression,
spatialPredicate: SpatialPredicate,
+ isGeography: Boolean,
extraCondition: Option[Expression] = None,
distance: Option[Expression] = None
)
@@ -57,23 +58,23 @@ class JoinQueryDetector(sparkSession: SparkSession) extends
Strategy {
extraCondition: Option[Expression] = None): Option[JoinQueryDetection] = {
predicate match {
case ST_Contains(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.CONTAINS, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.CONTAINS, false, extraCondition))
case ST_Intersects(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, extraCondition))
case ST_Within(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.WITHIN, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.WITHIN, false, extraCondition))
case ST_Covers(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.COVERS, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.COVERS, false, extraCondition))
case ST_CoveredBy(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.COVERED_BY, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.COVERED_BY, false, extraCondition))
case ST_Overlaps(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.OVERLAPS, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.OVERLAPS, false, extraCondition))
case ST_Touches(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.TOUCHES, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.TOUCHES, false, extraCondition))
case ST_Equals(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.EQUALS, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.EQUALS, false, extraCondition))
case ST_Crosses(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.CROSSES, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.CROSSES, false, extraCondition))
case _ => None
}
}
@@ -109,20 +110,45 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends Strategy {
getJoinDetection(left, right, predicate, Some(extraCondition))
case Some(And(extraCondition, predicate: ST_Predicate)) =>
getJoinDetection(left, right, predicate, Some(extraCondition))
-
// For distance joins we execute the actual predicate (condition) and
not only extraConditions.
case Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)),
distance)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
case Some(And(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)),
distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
case Some(And(_, LessThanOrEqual(ST_Distance(Seq(leftShape,
rightShape)), distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
case Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), distance))
=>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
case Some(And(LessThan(ST_Distance(Seq(leftShape, rightShape)),
distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
case Some(And(_, LessThan(ST_Distance(Seq(leftShape, rightShape)),
distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
+ // ST_DistanceSphere
+ case Some(LessThanOrEqual(ST_DistanceSphere(Seq(leftShape, rightShape,
radius)), distance)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(And(LessThanOrEqual(ST_DistanceSphere(Seq(leftShape,
rightShape, radius)), distance), _)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(And(_, LessThanOrEqual(ST_DistanceSphere(Seq(leftShape,
rightShape, radius)), distance))) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(LessThan(ST_DistanceSphere(Seq(leftShape, rightShape,
radius)), distance)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(And(LessThan(ST_DistanceSphere(Seq(leftShape, rightShape,
radius)), distance), _)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(And(_, LessThan(ST_DistanceSphere(Seq(leftShape, rightShape,
radius)), distance))) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ // ST_DistanceSpheroid
+ case Some(LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape,
rightShape)), distance)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(And(LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape,
rightShape)), distance), _)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(And(_, LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape,
rightShape)), distance))) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)),
distance)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(And(LessThan(ST_DistanceSpheroid(Seq(leftShape,
rightShape)), distance), _)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(And(_, LessThan(ST_DistanceSpheroid(Seq(leftShape,
rightShape)), distance))) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
case _ =>
None
}
@@ -131,20 +157,20 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends Strategy {
if ((broadcastLeft || broadcastRight) && sedonaConf.getUseIndex) {
queryDetection match {
- case Some(JoinQueryDetection(left, right, leftShape, rightShape,
spatialPredicate, extraCondition, distance)) =>
+ case Some(JoinQueryDetection(left, right, leftShape, rightShape,
spatialPredicate, isGeography, extraCondition, distance)) =>
planBroadcastJoin(
left, right, Seq(leftShape, rightShape), joinType,
spatialPredicate, sedonaConf.getIndexType,
- broadcastLeft, broadcastRight, extraCondition, distance)
+ broadcastLeft, broadcastRight, isGeography, extraCondition,
distance)
case _ =>
Nil
}
} else {
queryDetection match {
- case Some(JoinQueryDetection(left, right, leftShape, rightShape,
spatialPredicate, extraCondition, None)) =>
+ case Some(JoinQueryDetection(left, right, leftShape, rightShape,
spatialPredicate, isGeography, extraCondition, None)) =>
planSpatialJoin(left, right, Seq(leftShape, rightShape), joinType,
spatialPredicate, extraCondition)
- case Some(JoinQueryDetection(left, right, leftShape, rightShape,
spatialPredicate, extraCondition, Some(distance))) =>
- planDistanceJoin(left, right, Seq(leftShape, rightShape),
joinType, distance, spatialPredicate, extraCondition)
+ case Some(JoinQueryDetection(left, right, leftShape, rightShape,
spatialPredicate, isGeography, extraCondition, Some(distance))) =>
+ planDistanceJoin(left, right, Seq(leftShape, rightShape),
joinType, distance, spatialPredicate, isGeography, extraCondition)
case None =>
Nil
}
@@ -236,6 +262,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends
Strategy {
joinType: JoinType,
distance: Expression,
spatialPredicate: SpatialPredicate,
+ isGeography: Boolean,
extraCondition: Option[Expression] = None): Seq[SparkPlan] = {
if (joinType != Inner) {
@@ -252,11 +279,11 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends Strategy {
case Some(LeftSide) =>
logInfo("Planning spatial distance join, distance bound to left
relation")
DistanceJoinExec(planLater(left), planLater(right), leftShape,
rightShape, distance, distanceBoundToLeft = true,
- spatialPredicate, extraCondition) :: Nil
+ spatialPredicate, isGeography, extraCondition) :: Nil
case Some(RightSide) =>
logInfo("Planning spatial distance join, distance bound to right
relation")
DistanceJoinExec(planLater(left), planLater(right), leftShape,
rightShape, distance, distanceBoundToLeft = false,
- spatialPredicate, extraCondition) :: Nil
+ spatialPredicate, isGeography, extraCondition) :: Nil
case _ =>
logInfo(
"Spatial distance join for ST_Distance with non-scalar distance
" +
@@ -280,6 +307,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends
Strategy {
indexType: IndexType,
broadcastLeft: Boolean,
broadcastRight: Boolean,
+ isGeography: Boolean,
extraCondition: Option[Expression],
distance: Option[Expression]): Seq[SparkPlan] = {
@@ -300,12 +328,13 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends Strategy {
val a = children.head
val b = children.tail.head
- val relationship = (distance, spatialPredicate) match {
- case (Some(_), SpatialPredicate.INTERSECTS) => "ST_Distance <="
- case (Some(_), _) => "ST_Distance <"
- case (None, _) => s"ST_$spatialPredicate"
+ val relationship = (distance, spatialPredicate, isGeography) match {
+ case (Some(_), SpatialPredicate.INTERSECTS, false) => "ST_Distance <="
+ case (Some(_), _, false) => "ST_Distance <"
+ case (Some(_), SpatialPredicate.INTERSECTS, true) => "ST_Distance
(Geography) <="
+ case (Some(_), _, true) => "ST_Distance (Geography) <"
+ case (None, _, false) => s"ST_$spatialPredicate"
}
-
val (distanceOnIndexSide, distanceOnStreamSide) = distance.map {
distanceExpr =>
matchDistanceExpressionToJoinSide(distanceExpr, left, right) match {
case Some(side) =>
@@ -321,13 +350,13 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends Strategy {
logInfo(s"Planning spatial join for $relationship relationship")
val (leftPlan, rightPlan, streamShape, windowSide) =
(broadcastSide.get, swapped) match {
case (LeftSide, false) => // Broadcast the left side, windows on the
left
- (SpatialIndexExec(planLater(left), a, indexType,
distanceOnIndexSide), planLater(right), b, LeftSide)
+ (SpatialIndexExec(planLater(left), a, indexType, isGeography,
distanceOnIndexSide), planLater(right), b, LeftSide)
case (LeftSide, true) => // Broadcast the left side, objects on the
left
- (SpatialIndexExec(planLater(left), b, indexType,
distanceOnIndexSide), planLater(right), a, RightSide)
+ (SpatialIndexExec(planLater(left), b, indexType, isGeography,
distanceOnIndexSide), planLater(right), a, RightSide)
case (RightSide, false) => // Broadcast the right side, windows on
the left
- (planLater(left), SpatialIndexExec(planLater(right), b, indexType,
distanceOnIndexSide), a, LeftSide)
+ (planLater(left), SpatialIndexExec(planLater(right), b, indexType,
isGeography, distanceOnIndexSide), a, LeftSide)
case (RightSide, true) => // Broadcast the right side, objects on
the left
- (planLater(left), SpatialIndexExec(planLater(right), a, indexType,
distanceOnIndexSide), b, RightSide)
+ (planLater(left), SpatialIndexExec(planLater(right), a, indexType,
isGeography, distanceOnIndexSide), b, RightSide)
}
BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape,
broadcastSide.get, windowSide, joinType,
spatialPredicate, extraCondition, distanceOnStreamSide) :: Nil
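
Note on the JoinQueryDetector.scala changes above: the new cases all follow one shape, a `<` or `<=` comparison whose left side is a distance function, optionally wrapped in a single `And`, with `isGeography` set to `true` only for the sphere and spheroid functions. The sketch below restates that shape on a toy expression tree. Everything in it is hypothetical: the case classes are not Catalyst or Sedona classes, and accepting both the two-argument and three-argument forms of `ST_DistanceSphere` is an assumption based on how the tests call it, not a statement about how the real pattern binds the radius.

```scala
object DistancePredicateSketch {
  // Toy expression tree for illustration only (NOT Catalyst or Sedona classes).
  sealed trait Expr
  final case class Call(name: String, args: Seq[Expr]) extends Expr
  final case class Col(name: String) extends Expr
  final case class Lit(value: Double) extends Expr
  final case class LessThan(left: Expr, right: Expr) extends Expr
  final case class LessThanOrEqual(left: Expr, right: Expr) extends Expr
  final case class And(left: Expr, right: Expr) extends Expr

  // What detection yields: both shapes, the distance bound, and the geography flag.
  final case class Detection(leftShape: Expr, rightShape: Expr, distance: Expr, isGeography: Boolean)

  // Map a distance function call to a detection; planar vs. geodesic decides the flag.
  private def fromCall(call: Expr, distance: Expr): Option[Detection] = call match {
    case Call("ST_Distance", Seq(l, r)) =>
      Some(Detection(l, r, distance, isGeography = false))
    case Call("ST_DistanceSpheroid", Seq(l, r)) =>
      Some(Detection(l, r, distance, isGeography = true))
    // Assumption: an optional third argument (radius) may follow the two shapes.
    case Call("ST_DistanceSphere", args) if args.length == 2 || args.length == 3 =>
      Some(Detection(args(0), args(1), distance, isGeography = true))
    case _ => None
  }

  // A bare comparison against a distance function.
  private def leaf(e: Expr): Option[Detection] = e match {
    case LessThan(call, d)        => fromCall(call, d)
    case LessThanOrEqual(call, d) => fromCall(call, d)
    case _                        => None
  }

  // Either a bare comparison, or one side of a single top-level And.
  def detect(condition: Expr): Option[Detection] = condition match {
    case And(l, r) => leaf(l).orElse(leaf(r))
    case other     => leaf(other)
  }
}
```

In the commit itself the flag is threaded on from the detection result into `planDistanceJoin` and `planBroadcastJoin`, which is the role `Detection.isGeography` plays in this toy version.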
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
index f9f9a4ed..2c24a34e 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
@@ -33,6 +33,7 @@ import
org.apache.spark.sql.sedona_sql.execution.SedonaUnaryExecNode
case class SpatialIndexExec(child: SparkPlan,
shape: Expression,
indexType: IndexType,
+ isGeography: Boolean,
distance: Option[Expression] = None)
extends SedonaUnaryExecNode
with TraitJoinQueryBase
@@ -51,7 +52,7 @@ case class SpatialIndexExec(child: SparkPlan,
val resultRaw = child.execute().asInstanceOf[RDD[UnsafeRow]].coalesce(1)
val spatialRDD = distance match {
- case Some(distanceExpression) => toExpandedEnvelopeRDD(resultRaw,
boundShape, BindReferences.bindReference(distanceExpression, child.output))
+ case Some(distanceExpression) => toExpandedEnvelopeRDD(resultRaw,
boundShape, BindReferences.bindReference(distanceExpression, child.output),
isGeography)
case None => toSpatialRDD(resultRaw, boundShape)
}
diff --git
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
index df8f4cd3..0bdeeb83 100644
---
a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
+++
b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
@@ -24,7 +24,7 @@ import org.apache.sedona.sql.utils.GeometrySerializer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeRow}
import org.apache.spark.sql.execution.SparkPlan
-import org.locationtech.jts.geom.Geometry
+import org.locationtech.jts.geom.{Envelope, Geometry}
trait TraitJoinQueryBase {
self: SparkPlan =>
@@ -48,14 +48,14 @@ trait TraitJoinQueryBase {
spatialRdd
}
- def toExpandedEnvelopeRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression,
boundRadius: Expression): SpatialRDD[Geometry] = {
+ def toExpandedEnvelopeRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression,
boundRadius: Expression, isGeography: Boolean): SpatialRDD[Geometry] = {
val spatialRdd = new SpatialRDD[Geometry]
spatialRdd.setRawSpatialRDD(
rdd
.map { x =>
val shape =
GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
val envelope = shape.getEnvelopeInternal.copy()
- envelope.expandBy(boundRadius.eval(x).asInstanceOf[Double])
+ expandEnvelope(envelope, boundRadius.eval(x).asInstanceOf[Double],
6357000.0, isGeography)
val expandedEnvelope = shape.getFactory.toGeometry(envelope)
expandedEnvelope.setUserData(x.copy)
@@ -72,4 +72,27 @@ trait TraitJoinQueryBase {
followerShapes.spatialPartitioning(dominantShapes.getPartitioner)
}
}
-}
+
+ /**
+ * Expand the given envelope by the given distance in meter.
+ * For geography, we expand the envelope by the given distance in both
longitude and latitude.
+ * @param envelope
+ * @param distance in meter
+ * @param radius in meter
+ * @param isGeography
+ */
+ private def expandEnvelope(envelope:Envelope, distance:Double,
radius:Double, isGeography:Boolean):Unit = {
+ if (isGeography) {
+ val scaleFactor = 1.1 // 10% buffer to get rid of false negatives
+ val latRadian = Math.toRadians((envelope.getMinX + envelope.getMaxX) /
2.0)
+ val latDeltaRadian = distance / radius;
+ val latDeltaDegree = Math.toDegrees(latDeltaRadian)
+ val lonDeltaRadian = Math.max(Math.abs(distance / (radius *
Math.cos(latRadian + latDeltaRadian))),
+ Math.abs(distance / (radius * Math.cos(latRadian - latDeltaRadian))))
+ val lonDeltaDegree = Math.toDegrees(lonDeltaRadian)
+ envelope.expandBy(latDeltaDegree * scaleFactor, lonDeltaDegree *
scaleFactor)
+ } else {
+ envelope.expandBy(distance)
+ }
+ }
+}
\ No newline at end of file
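
Note on `expandEnvelope` above: the sketch below reproduces the same arithmetic in plain Scala so the numbers can be inspected without Spark or JTS. It is an illustration under stated assumptions, not the committed code. `Box` is a hypothetical stand-in for `org.locationtech.jts.geom.Envelope`, and the sketch follows the commit in reading the X axis as latitude. The default radius (6357000.0) and scale factor (1.1) are the values that appear in the diff.

```scala
object EnvelopeExpansionSketch {
  // Hypothetical stand-in for the JTS Envelope, immutable for simplicity.
  final case class Box(minX: Double, maxX: Double, minY: Double, maxY: Double) {
    def expandBy(dx: Double, dy: Double): Box =
      Box(minX - dx, maxX + dx, minY - dy, maxY + dy)
  }

  // Returns (latDeltaDegree, lonDeltaDegree) before the scale factor is applied.
  def deltasInDegrees(centerLatDegrees: Double,
                      distanceMeters: Double,
                      radiusMeters: Double): (Double, Double) = {
    val latRadian = math.toRadians(centerLatDegrees)
    // Angular distance along a meridian: arc length divided by radius.
    val latDeltaRadian = distanceMeters / radiusMeters
    // A parallel at latitude phi has radius R * cos(phi); take the larger delta
    // of the two latitudes reached by moving north or south by latDeltaRadian.
    val lonDeltaRadian = math.max(
      math.abs(distanceMeters / (radiusMeters * math.cos(latRadian + latDeltaRadian))),
      math.abs(distanceMeters / (radiusMeters * math.cos(latRadian - latDeltaRadian))))
    (math.toDegrees(latDeltaRadian), math.toDegrees(lonDeltaRadian))
  }

  // Expand a box by a meter distance, padding both deltas by the scale factor.
  def expand(box: Box,
             distanceMeters: Double,
             radiusMeters: Double = 6357000.0,
             scaleFactor: Double = 1.1): Box = {
    // As in the commit, the X axis is read as latitude.
    val centerLat = (box.minX + box.maxX) / 2.0
    val (latDelta, lonDelta) = deltasInDegrees(centerLat, distanceMeters, radiusMeters)
    box.expandBy(latDelta * scaleFactor, lonDelta * scaleFactor)
  }
}
```

The longitude delta grows as the cosine term approaches zero, which matches the documentation warning that results may be inaccurate for data close to the poles.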
diff --git
a/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
b/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
index 78f3aaf5..6f569128 100644
---
a/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
+++
b/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
@@ -302,6 +302,54 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
assert(rows2(0) == Row(2.0, 2.0, "left_2", 2.0, 2.0, "right_2"))
}
+
+ it("Passed ST_DistanceSpheroid in a broadcast join") {
+ val distanceCandidates = Seq(130000, 160000, 500000)
+ val sampleCount = 50
+ distanceCandidates.foreach(distance => {
+ val expected = bruteForceDistanceJoinCountSpheroid(sampleCount,
distance)
+ val pointDf1 = buildPointDf.limit(sampleCount).repartition(4)
+ val pointDf2 = pointDf1
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2).alias("pointDf2"),
expr(s"ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) <=
$distance"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == expected)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"),
expr(s"ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) <=
$distance"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == expected)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"),
expr(s"ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) <
$distance"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == expected)
+ })
+ }
+
+ it("Passed ST_DistanceSphere in a broadcast join") {
+ val distanceCandidates = Seq(130000, 160000, 500000)
+ val sampleCount = 50
+ distanceCandidates.foreach(distance => {
+ val expected = bruteForceDistanceJoinCountSphere(sampleCount, distance)
+ val pointDf1 = buildPointDf.limit(sampleCount).repartition(4)
+ val pointDf2 = pointDf1
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2).alias("pointDf2"),
expr(s"ST_DistanceSphere(pointDf1.pointshape, pointDf2.pointshape) <=
$distance"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == expected)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"),
expr(s"ST_DistanceSphere(pointDf1.pointshape, pointDf2.pointshape) <=
$distance"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == expected)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"),
expr(s"ST_DistanceSphere(pointDf1.pointshape, pointDf2.pointshape, 6371008.0) <
$distance"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == expected)
+ })
+ }
}
describe("Sedona-SQL Broadcast Index Join Test for left semi joins") {
diff --git
a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index f12b5874..16d97888 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -20,6 +20,7 @@ package org.apache.sedona.sql
import com.google.common.math.DoubleMath
import org.apache.log4j.{Level, Logger}
+import org.apache.sedona.common.sphere.{Haversine, Spheroid}
import org.apache.sedona.core.serde.SedonaKryoRegistrator
import org.apache.sedona.sql.utils.SedonaSQLRegistrator
import org.apache.spark.serializer.KryoSerializer
@@ -97,4 +98,26 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
}
}
+ protected def bruteForceDistanceJoinCountSpheroid(sampleCount:Int, distance:
Double): Int = {
+ val input = buildPointDf.limit(sampleCount).collect()
+ input.map(row => {
+ val point1 = row.getAs[org.locationtech.jts.geom.Point](0)
+ input.map(row => {
+ val point2 = row.getAs[org.locationtech.jts.geom.Point](0)
+ if (Spheroid.distance(point1, point2) <= distance) 1 else 0
+ }).sum
+ }).sum
+ }
+
+ protected def bruteForceDistanceJoinCountSphere(sampleCount: Int, distance:
Double): Int = {
+ val input = buildPointDf.limit(sampleCount).collect()
+ input.map(row => {
+ val point1 = row.getAs[org.locationtech.jts.geom.Point](0)
+ input.map(row => {
+ val point2 = row.getAs[org.locationtech.jts.geom.Point](0)
+ if (Haversine.distance(point1, point2) <= distance) 1 else 0
+ }).sum
+ }).sum
+ }
+
}
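
Note on the brute-force helpers above: they compute the expected join size by checking every ordered pair of sampled points, including each point paired with itself. The sketch below shows the same counting scheme in plain Scala. It is an illustration only: `LatLon` is a hypothetical stand-in for the JTS `Point`, and the distance function is the textbook haversine formula, which is assumed here and not claimed to match Sedona's `Haversine.distance` line for line. The default radius 6371008.0 is the value the tests pass to `ST_DistanceSphere`.

```scala
object BruteForceSphereSketch {
  // (latitude, longitude) in degrees; hypothetical stand-in for a JTS Point.
  final case class LatLon(lat: Double, lon: Double)

  // Textbook haversine great-circle distance in meters.
  def haversine(a: LatLon, b: LatLon, radiusMeters: Double = 6371008.0): Double = {
    val dLat = math.toRadians(b.lat - a.lat)
    val dLon = math.toRadians(b.lon - a.lon)
    val h = math.pow(math.sin(dLat / 2), 2) +
      math.cos(math.toRadians(a.lat)) * math.cos(math.toRadians(b.lat)) *
        math.pow(math.sin(dLon / 2), 2)
    // Clamp to 1.0 to guard against rounding slightly above the asin domain.
    2 * radiusMeters * math.asin(math.min(1.0, math.sqrt(h)))
  }

  // Counts ordered pairs (p1, p2), self-pairs included, within the given distance.
  // This is O(n^2) and only meant as a reference answer for small samples.
  def bruteForceCount(points: Seq[LatLon], distanceMeters: Double): Int =
    points.map(p1 => points.count(p2 => haversine(p1, p2) <= distanceMeters)).sum
}
```

Counting ordered pairs with self-pairs matches what a self-join of `pointDf1` with `pointDf2 = pointDf1` returns, which is why the tests can compare `distanceJoinDf.count()` directly against this number.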
diff --git
a/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
b/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
index 596e561b..66ee94ca 100644
---
a/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
+++
b/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
@@ -21,10 +21,10 @@ package org.apache.sedona.sql
import org.apache.sedona.core.utils.SedonaConf
import org.apache.spark.sql.Row
-import org.apache.spark.sql.execution.ExplainMode
+import org.apache.spark.sql.functions.expr
+import org.apache.spark.sql.sedona_sql.strategy.join.DistanceJoinExec
import org.apache.spark.sql.types._
import org.locationtech.jts.geom.Geometry
-import org.locationtech.jts.io.WKTWriter
class predicateJoinTestScala extends TestBaseScala {
@@ -361,5 +361,49 @@ class predicateJoinTestScala extends TestBaseScala {
assert(equalJoinDf.count() == 0, s"Expected 0 but got
${equalJoinDf.count()}")
}
+
+ it("Passed ST_DistanceSpheroid in a spatial join") {
+ val distanceCandidates = Seq(130000, 160000, 500000)
+ val sampleCount = 50
+ distanceCandidates.foreach(distance => {
+ val expected = bruteForceDistanceJoinCountSpheroid(sampleCount,
distance)
+ val pointDf1 = buildPointDf.limit(sampleCount).repartition(4)
+ val pointDf2 = pointDf1
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ pointDf2.alias("pointDf2"),
expr(s"ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) <=
$distance"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
DistanceJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == expected)
+
+ distanceJoinDf = pointDf1.alias("pointDf1").join(
+ pointDf2.alias("pointDf2"),
expr(s"ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) <
$distance"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
DistanceJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == expected)
+ })
+ }
+
+ it("Passed ST_DistanceSphere in a spatial join") {
+ val distanceCandidates = Seq(130000, 160000, 500000)
+ val sampleCount = 50
+ distanceCandidates.foreach(distance => {
+ val expected = bruteForceDistanceJoinCountSphere(sampleCount, distance)
+ val pointDf1 = buildPointDf.limit(sampleCount).repartition(4)
+ val pointDf2 = pointDf1
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ pointDf2.alias("pointDf2"),
expr(s"ST_DistanceSphere(pointDf1.pointshape, pointDf2.pointshape) <=
$distance"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
DistanceJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == expected)
+
+ distanceJoinDf = pointDf1.alias("pointDf1").join(
+ pointDf2.alias("pointDf2"),
expr(s"ST_DistanceSphere(pointDf1.pointshape, pointDf2.pointshape) <
$distance"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
DistanceJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == expected)
+
+ distanceJoinDf = pointDf1.alias("pointDf1").join(
+ pointDf2.alias("pointDf2"),
expr(s"ST_DistanceSphere(pointDf1.pointshape, pointDf2.pointshape, 6371008.0) <
$distance"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
DistanceJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == expected)
+ })
+
+ }
}
}