This is an automated email from the ASF dual-hosted git repository. jiayu pushed a commit to branch geography-join in repository https://gitbox.apache.org/repos/asf/sedona.git
commit dadc81861cb0a912aacf69881efbe3d579fd3673 Author: Jia Yu <[email protected]> AuthorDate: Mon May 29 01:23:36 2023 -0700 Add distanceJoin support for ST_DistanceSpheroid --- docs/api/sql/Function.md | 6 +- docs/api/sql/Optimizer.md | 37 ++++++++-- .../strategy/join/DistanceJoinExec.scala | 5 +- .../strategy/join/JoinQueryDetector.scala | 84 +++++++++++++--------- .../strategy/join/SpatialIndexExec.scala | 3 +- .../strategy/join/TraitJoinQueryBase.scala | 21 +++++- .../sedona/sql/BroadcastIndexJoinSuite.scala | 19 +++++ .../org/apache/sedona/sql/TestBaseScala.scala | 21 ++++++ .../apache/sedona/sql/predicateJoinTestScala.scala | 18 ++++- 9 files changed, 163 insertions(+), 51 deletions(-) diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md index b8643268..e92ce4f3 100644 --- a/docs/api/sql/Function.md +++ b/docs/api/sql/Function.md @@ -51,7 +51,7 @@ FROM polygondf ## ST_AreaSpheroid -Introduction: Return the geodesic area of A using WGS84 spheroid. Unit is meter. Works better for large geometries (country level) compared to `ST_Area` + `ST_Transform`. It is equivalent to PostGIS `ST_Area(geography, use_spheroid=true)` function and produces nearly identical results. +Introduction: Return the geodesic area of A using WGS84 spheroid. Unit is square meter. Works better for large geometries (country level) compared to `ST_Area` + `ST_Transform`. It is equivalent to PostGIS `ST_Area(geography, use_spheroid=true)` function and produces nearly identical results. Geometry must be in EPSG:4326 (WGS84) projection and must be in ==lat/lon== order. You can use ==ST_FlipCoordinates== to swap lat and lon. @@ -416,7 +416,7 @@ FROM polygondf ## ST_DistanceSphere -Introduction: Return the haversine / great-circle distance of A using a given earth radius (default radius: 6378137.0). Unit is meter. Works better for large geometries (country level) compared to `ST_Distance` + `ST_Transform`. 
It is equivalent to PostGIS `ST_Distance(geography, use_spheroid=false)` and `ST_DistanceSphere` function and produces nearly identical results. It provides faster but less accurate result compared to `ST_DistanceSpheroid`. +Introduction: Return the haversine / great-circle distance of A using a given earth radius (default radius: 6378137.0). Unit is meter. Compared to `ST_Distance` + `ST_Transform`, it works better for datasets that cover large regions such as continents or the entire planet. It is equivalent to PostGIS `ST_Distance(geography, use_spheroid=false)` and `ST_DistanceSphere` function and produces nearly identical results. It provides faster but less accurate result compared to `ST_DistanceSpheroid`. Geometry must be in EPSG:4326 (WGS84) projection and must be in ==lat/lon== order. You can use ==ST_FlipCoordinates== to swap lat and lon. For non-point data, we first take the centroids of both geometries and then compute the distance. @@ -441,7 +441,7 @@ Output: `544405.4459192449` ## ST_DistanceSpheroid -Introduction: Return the geodesic distance of A using WGS84 spheroid. Unit is meter. Works better for large geometries (country level) compared to `ST_Distance` + `ST_Transform`. It is equivalent to PostGIS `ST_Distance(geography, use_spheroid=true)` and `ST_DistanceSpheroid` function and produces nearly identical results. It provides slower but more accurate result compared to `ST_DistanceSphere`. +Introduction: Return the geodesic distance of A using WGS84 spheroid. Unit is meter. Compared to `ST_Distance` + `ST_Transform`, it works better for datasets that cover large regions such as continents or the entire planet. It is equivalent to PostGIS `ST_Distance(geography, use_spheroid=true)` and `ST_DistanceSpheroid` function and produces nearly identical results. It provides slower but more accurate result compared to `ST_DistanceSphere`. Geometry must be in EPSG:4326 (WGS84) projection and must be in ==lat/lon== order. 
You can use ==ST_FlipCoordinates== to swap lat and lon. For non-point data, we first take the centroids of both geometries and then compute the distance. diff --git a/docs/api/sql/Optimizer.md b/docs/api/sql/Optimizer.md index 025b964b..ca6cc09d 100644 --- a/docs/api/sql/Optimizer.md +++ b/docs/api/sql/Optimizer.md @@ -28,7 +28,9 @@ SELECT * FROM pointdf, polygondf WHERE ST_Within(pointdf.pointshape, polygondf.polygonshape) ``` + Spark SQL Physical plan: + ``` == Physical Plan == RangeJoin polygonshape#20: geometry, pointshape#43: geometry, false @@ -44,9 +46,9 @@ RangeJoin polygonshape#20: geometry, pointshape#43: geometry, false ## Distance join -Introduction: Find geometries from A and geometries from B such that the internal Euclidean distance of each geometry pair is less or equal than a certain distance +Introduction: Find geometries from A and geometries from B such that the distance of each geometry pair is less or equal than a certain distance. It supports the planar Euclidean distance calculator `ST_Distance` and the meter-based geodesic distance calculator `ST_DistanceSpheroid`. -Spark SQL Example: +Spark SQL Example for planar Euclidean distance: *Only consider ==fully within a certain distance==* ```sql @@ -73,7 +75,26 @@ DistanceJoin pointshape1#12: geometry, pointshape2#33: geometry, 2.0, true ``` !!!warning - Sedona doesn't control the distance's unit (degree or meter). It is same with the geometry. If your coordinates are in the longitude and latitude system, the unit of `distance` should be degree instead of meter or mile. To change the geometry's unit, please either transform the coordinate reference system to a meter-based system. See [ST_Transform](Function.md#st_transform). If you don't want to transform your data and are ok with sacrificing the query accuracy, you can use an approxima [...] + If you use `ST_Distance` as the predicate, Sedona doesn't control the distance's unit (degree or meter). It is same with the geometry. 
If your coordinates are in the longitude and latitude system, the unit of `distance` should be degree instead of meter or mile. To change the geometry's unit, please either transform the coordinate reference system to a meter-based system. See [ST_Transform](Function.md#st_transform). If you don't want to transform your data, please consider using `ST_Di [...] + +Spark SQL Example for meter-based geodesic distance: + +*Less than a certain distance==* +```sql +SELECT * +FROM pointdf1, pointdf2 +WHERE ST_DistanceSpheroid(pointdf1.pointshape1,pointdf2.pointshape2) < 2 +``` + +*Less than or equal to a certain distance==* +```sql +SELECT * +FROM pointdf1, pointdf2 +WHERE ST_DistanceSpheroid(pointdf1.pointshape1,pointdf2.pointshape2) <= 2 +``` + +!!!warning + If you use `ST_DistanceSpheroid ` as the predicate, the unit of the distance is meter. Currently, distance join with geodesic distance calculators work best for point data. For non-point data, it only considers their centroids. ## Broadcast index join @@ -105,7 +126,7 @@ BroadcastIndexJoin pointshape#52: geometry, BuildRight, BuildRight, false ST_Con +- FileScan csv ``` -This also works for distance joins: +This also works for distance joins with `ST_Distance` or `ST_DistanceSpheroid`: ```scala pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2")) @@ -202,14 +223,16 @@ GROUP BY (lcs_geom, rcs_geom) This also works for distance join. You first need to use `ST_Buffer(geometry, distance)` to wrap one of your original geometry column. If your original geometry column contains points, this `ST_Buffer` will make them become circles with a radius of `distance`. -For example. run this query first on the left table before Step 1. +Since the coordinates are in the longitude and latitude system, so the unit of `distance` should be degree instead of meter or mile. 
You can get an approximation by performing `METER_DISTANCE/111000.0`, then filter out false-positives. + +In a nutshell, run this query first on the left table before Step 1. Please replace `METER_DISTANCE` with a meter distance. In Step 1, generate S2 IDs based on the `buffered_geom` column. Then run Step 2, 3, 4 on the original `geom` column. ```sql -SELECT id, ST_Buffer(geom, DISTANCE), name +SELECT id, geom , ST_Buffer(geom, METER_DISTANCE/111000.0) as buffered_geom, name FROM lefts ``` -Since the coordinates are in the longitude and latitude system, so the unit of `distance` should be degree instead of meter or mile. You will have to estimate the corresponding degrees based on your meter values. Please use [this calculator](https://lucidar.me/en/online-unit-converter-length-to-angle/convert-degrees-to-meters/#online-converter). + ## Regular spatial predicate pushdown Introduction: Given a join query and a predicate in the same WHERE clause, first executes the Predicate as a filter, then executes the join query. 
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala index b393f2d0..615f88a2 100644 --- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala @@ -54,6 +54,7 @@ case class DistanceJoinExec(left: SparkPlan, distance: Expression, distanceBoundToLeft: Boolean, spatialPredicate: SpatialPredicate, + isGeography: Boolean, extraCondition: Option[Expression] = None) extends SedonaBinaryExecNode with TraitJoinQueryExec @@ -70,9 +71,9 @@ case class DistanceJoinExec(left: SparkPlan, rightRdd: RDD[UnsafeRow], rightShapeExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) = { if (distanceBoundToLeft) { - (toExpandedEnvelopeRDD(leftRdd, leftShapeExpr, boundRadius), toSpatialRDD(rightRdd, rightShapeExpr)) + (toExpandedEnvelopeRDD(leftRdd, leftShapeExpr, boundRadius, isGeography), toSpatialRDD(rightRdd, rightShapeExpr)) } else { - (toSpatialRDD(leftRdd, leftShapeExpr), toExpandedEnvelopeRDD(rightRdd, rightShapeExpr, boundRadius)) + (toSpatialRDD(leftRdd, leftShapeExpr), toExpandedEnvelopeRDD(rightRdd, rightShapeExpr, boundRadius, isGeography)) } } diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala index 464db3e7..4536ec50 100644 --- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala @@ -21,13 +21,13 @@ package org.apache.spark.sql.sedona_sql.strategy.join import org.apache.sedona.core.enums.{IndexType, SpatialJoinOptimizationMode} import 
org.apache.sedona.core.spatialOperator.SpatialPredicate import org.apache.sedona.core.utils.SedonaConf -import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.expressions.{And, EqualNullSafe, EqualTo, Expression, LessThan, LessThanOrEqual} -import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, NaturalJoin, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.sedona_sql.expressions._ import org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.splitConjunctivePredicates +import org.apache.spark.sql.{SparkSession, Strategy} case class JoinQueryDetection( @@ -36,6 +36,7 @@ case class JoinQueryDetection( leftShape: Expression, rightShape: Expression, spatialPredicate: SpatialPredicate, + isGeography: Boolean, extraCondition: Option[Expression] = None, distance: Option[Expression] = None ) @@ -57,23 +58,23 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { extraCondition: Option[Expression] = None): Option[JoinQueryDetection] = { predicate match { case ST_Contains(Seq(leftShape, rightShape)) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.CONTAINS, extraCondition)) + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.CONTAINS, false, extraCondition)) case ST_Intersects(Seq(leftShape, rightShape)) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, extraCondition)) + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, extraCondition)) case ST_Within(Seq(leftShape, rightShape)) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.WITHIN, extraCondition)) + Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.WITHIN, false, extraCondition)) case ST_Covers(Seq(leftShape, rightShape)) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERS, extraCondition)) + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERS, false, extraCondition)) case ST_CoveredBy(Seq(leftShape, rightShape)) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERED_BY, extraCondition)) + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERED_BY, false, extraCondition)) case ST_Overlaps(Seq(leftShape, rightShape)) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.OVERLAPS, extraCondition)) + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.OVERLAPS, false, extraCondition)) case ST_Touches(Seq(leftShape, rightShape)) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.TOUCHES, extraCondition)) + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.TOUCHES, false, extraCondition)) case ST_Equals(Seq(leftShape, rightShape)) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.EQUALS, extraCondition)) + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.EQUALS, false, extraCondition)) case ST_Crosses(Seq(leftShape, rightShape)) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.CROSSES, extraCondition)) + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.CROSSES, false, extraCondition)) case _ => None } } @@ -109,20 +110,32 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { getJoinDetection(left, right, predicate, Some(extraCondition)) case Some(And(extraCondition, predicate: ST_Predicate)) => getJoinDetection(left, right, predicate, Some(extraCondition)) - // For distance joins we execute the actual predicate 
(condition) and not only extraConditions. case Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance)) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance))) + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance))) case Some(And(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance), _)) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance))) + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance))) case Some(And(_, LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance))) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance))) + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance))) case Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), distance)) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance))) + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance))) case Some(And(LessThan(ST_Distance(Seq(leftShape, rightShape)), distance), _)) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance))) + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance))) case Some(And(_, LessThan(ST_Distance(Seq(leftShape, rightShape)), distance))) => - Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance))) + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, 
Some(distance))) + // ST_DistanceSpheroid + case Some(LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance)) => + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance))) + case Some(And(LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance), _)) => + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance))) + case Some(And(_, LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance))) => + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance))) + case Some(LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance)) => + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance))) + case Some(And(LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance), _)) => + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance))) + case Some(And(_, LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance))) => + Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance))) case _ => None } @@ -131,20 +144,20 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { if ((broadcastLeft || broadcastRight) && sedonaConf.getUseIndex) { queryDetection match { - case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, distance)) => + case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, isGeography, extraCondition, distance)) => planBroadcastJoin( left, right, Seq(leftShape, rightShape), joinType, spatialPredicate, sedonaConf.getIndexType, - broadcastLeft, broadcastRight, extraCondition, distance) + broadcastLeft, broadcastRight, 
isGeography, extraCondition, distance) case _ => Nil } } else { queryDetection match { - case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, None)) => + case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, isGeography, extraCondition, None)) => planSpatialJoin(left, right, Seq(leftShape, rightShape), joinType, spatialPredicate, extraCondition) - case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, Some(distance))) => - planDistanceJoin(left, right, Seq(leftShape, rightShape), joinType, distance, spatialPredicate, extraCondition) + case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, isGeography, extraCondition, Some(distance))) => + planDistanceJoin(left, right, Seq(leftShape, rightShape), joinType, distance, spatialPredicate, isGeography, extraCondition) case None => Nil } @@ -236,6 +249,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { joinType: JoinType, distance: Expression, spatialPredicate: SpatialPredicate, + isGeography: Boolean, extraCondition: Option[Expression] = None): Seq[SparkPlan] = { if (joinType != Inner) { @@ -252,11 +266,11 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { case Some(LeftSide) => logInfo("Planning spatial distance join, distance bound to left relation") DistanceJoinExec(planLater(left), planLater(right), leftShape, rightShape, distance, distanceBoundToLeft = true, - spatialPredicate, extraCondition) :: Nil + spatialPredicate, isGeography, extraCondition) :: Nil case Some(RightSide) => logInfo("Planning spatial distance join, distance bound to right relation") DistanceJoinExec(planLater(left), planLater(right), leftShape, rightShape, distance, distanceBoundToLeft = false, - spatialPredicate, extraCondition) :: Nil + spatialPredicate, isGeography, extraCondition) :: Nil case _ => logInfo( "Spatial distance join for ST_Distance with 
non-scalar distance " + @@ -280,6 +294,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { indexType: IndexType, broadcastLeft: Boolean, broadcastRight: Boolean, + isGeography: Boolean, extraCondition: Option[Expression], distance: Option[Expression]): Seq[SparkPlan] = { @@ -300,12 +315,13 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { val a = children.head val b = children.tail.head - val relationship = (distance, spatialPredicate) match { - case (Some(_), SpatialPredicate.INTERSECTS) => "ST_Distance <=" - case (Some(_), _) => "ST_Distance <" - case (None, _) => s"ST_$spatialPredicate" + val relationship = (distance, spatialPredicate, isGeography) match { + case (Some(_), SpatialPredicate.INTERSECTS, false) => "ST_Distance <=" + case (Some(_), _, false) => "ST_Distance <" + case (Some(_), SpatialPredicate.INTERSECTS, true) => "ST_Distance (Geography) <=" + case (Some(_), _, true) => "ST_Distance (Geography) <" + case (None, _, false) => s"ST_$spatialPredicate" } - val (distanceOnIndexSide, distanceOnStreamSide) = distance.map { distanceExpr => matchDistanceExpressionToJoinSide(distanceExpr, left, right) match { case Some(side) => @@ -321,13 +337,13 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { logInfo(s"Planning spatial join for $relationship relationship") val (leftPlan, rightPlan, streamShape, windowSide) = (broadcastSide.get, swapped) match { case (LeftSide, false) => // Broadcast the left side, windows on the left - (SpatialIndexExec(planLater(left), a, indexType, distanceOnIndexSide), planLater(right), b, LeftSide) + (SpatialIndexExec(planLater(left), a, indexType, isGeography, distanceOnIndexSide), planLater(right), b, LeftSide) case (LeftSide, true) => // Broadcast the left side, objects on the left - (SpatialIndexExec(planLater(left), b, indexType, distanceOnIndexSide), planLater(right), a, RightSide) + (SpatialIndexExec(planLater(left), b, indexType, isGeography, 
distanceOnIndexSide), planLater(right), a, RightSide) case (RightSide, false) => // Broadcast the right side, windows on the left - (planLater(left), SpatialIndexExec(planLater(right), b, indexType, distanceOnIndexSide), a, LeftSide) + (planLater(left), SpatialIndexExec(planLater(right), b, indexType, isGeography, distanceOnIndexSide), a, LeftSide) case (RightSide, true) => // Broadcast the right side, objects on the left - (planLater(left), SpatialIndexExec(planLater(right), a, indexType, distanceOnIndexSide), b, RightSide) + (planLater(left), SpatialIndexExec(planLater(right), a, indexType, isGeography, distanceOnIndexSide), b, RightSide) } BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, broadcastSide.get, windowSide, joinType, spatialPredicate, extraCondition, distanceOnStreamSide) :: Nil diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala index f9f9a4ed..2c24a34e 100644 --- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.sedona_sql.execution.SedonaUnaryExecNode case class SpatialIndexExec(child: SparkPlan, shape: Expression, indexType: IndexType, + isGeography: Boolean, distance: Option[Expression] = None) extends SedonaUnaryExecNode with TraitJoinQueryBase @@ -51,7 +52,7 @@ case class SpatialIndexExec(child: SparkPlan, val resultRaw = child.execute().asInstanceOf[RDD[UnsafeRow]].coalesce(1) val spatialRDD = distance match { - case Some(distanceExpression) => toExpandedEnvelopeRDD(resultRaw, boundShape, BindReferences.bindReference(distanceExpression, child.output)) + case Some(distanceExpression) => toExpandedEnvelopeRDD(resultRaw, boundShape, BindReferences.bindReference(distanceExpression, child.output), 
isGeography) case None => toSpatialRDD(resultRaw, boundShape) } diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala index df8f4cd3..0f636ab3 100644 --- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala @@ -48,14 +48,14 @@ trait TraitJoinQueryBase { spatialRdd } - def toExpandedEnvelopeRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression, boundRadius: Expression): SpatialRDD[Geometry] = { + def toExpandedEnvelopeRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression, boundRadius: Expression, isGeography: Boolean): SpatialRDD[Geometry] = { val spatialRdd = new SpatialRDD[Geometry] spatialRdd.setRawSpatialRDD( rdd .map { x => val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]]) val envelope = shape.getEnvelopeInternal.copy() - envelope.expandBy(boundRadius.eval(x).asInstanceOf[Double]) + envelope.expandBy(distanceToDegree(boundRadius.eval(x).asInstanceOf[Double], isGeography)) val expandedEnvelope = shape.getFactory.toGeometry(envelope) expandedEnvelope.setUserData(x.copy) @@ -72,4 +72,21 @@ trait TraitJoinQueryBase { followerShapes.spatialPartitioning(dominantShapes.getPartitioner) } } + + /** + * Convert distance to degree based on the given isGeography flag. + * Note that this is an approximation since the degree of longitude is not constant. + * We assume that the degree of longitude is 111000 meters without considering the latitude. + * For latitude, the degree is always 111000 meters. 
+ * @param distance + * @param isGeography + * @return + */ + private def distanceToDegree(distance: Double, isGeography: Boolean): Double = { + if (isGeography) { + distance / 111000.0 + } else { + distance + } + } } diff --git a/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala b/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala index 78f3aaf5..0dd17104 100644 --- a/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala +++ b/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala @@ -302,6 +302,25 @@ class BroadcastIndexJoinSuite extends TestBaseScala { assert(rows2(0) == Row(2.0, 2.0, "left_2", 2.0, 2.0, "right_2")) } + + it("Passed ST_DistanceSpheroid in a broadcast join") { + val pointDf1 = buildPointDf + val pointDf2 = buildPointDf + var distanceJoinDf = pointDf1.alias("pointDf1").join( + broadcast(pointDf2).alias("pointDf2"), expr("ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) <= 2.0")) + assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1) + assert(distanceJoinDf.count() == 89) + + distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join( + pointDf2.alias("pointDf2"), expr("ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) <= 2.0")) + assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1) + assert(distanceJoinDf.count() == 89) + + distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join( + pointDf2.alias("pointDf2"), expr("ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) < 2.0")) + assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1) + assert(distanceJoinDf.count() == 89) + } } describe("Sedona-SQL Broadcast Index Join Test for left semi joins") { diff --git a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala 
b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala index f12b5874..c9116533 100644 --- a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala +++ b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -20,6 +20,7 @@ package org.apache.sedona.sql import com.google.common.math.DoubleMath import org.apache.log4j.{Level, Logger} +import org.apache.sedona.common.sphere.{Haversine, Spheroid} import org.apache.sedona.core.serde.SedonaKryoRegistrator import org.apache.sedona.sql.utils.SedonaSQLRegistrator import org.apache.spark.serializer.KryoSerializer @@ -97,4 +98,24 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll { } } + protected def bruteForceDistanceJoinCountSpheroid(distance: Double): Int = { + buildPointDf.collect().map(row => { + val point1 = row.getAs[org.locationtech.jts.geom.Point](0) + buildPointDf.collect().map(row => { + val point2 = row.getAs[org.locationtech.jts.geom.Point](0) + if (Spheroid.distance(point1, point2) <= distance) 1 else 0 + }).sum + }).sum + } + + protected def bruteForceDistanceJoinCountSphere(distance: Double): Int = { + buildPointDf.collect().map(row => { + val point1 = row.getAs[org.locationtech.jts.geom.Point](0) + buildPointDf.collect().map(row => { + val point2 = row.getAs[org.locationtech.jts.geom.Point](0) + if (Haversine.distance(point1, point2) <= distance) 1 else 0 + }).sum + }).sum + } + } diff --git a/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala b/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala index 596e561b..037b75c5 100644 --- a/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala +++ b/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala @@ -21,10 +21,10 @@ package org.apache.sedona.sql import org.apache.sedona.core.utils.SedonaConf import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.ExplainMode +import 
org.apache.spark.sql.functions.expr +import org.apache.spark.sql.sedona_sql.strategy.join.DistanceJoinExec import org.apache.spark.sql.types._ import org.locationtech.jts.geom.Geometry -import org.locationtech.jts.io.WKTWriter class predicateJoinTestScala extends TestBaseScala { @@ -361,5 +361,19 @@ class predicateJoinTestScala extends TestBaseScala { assert(equalJoinDf.count() == 0, s"Expected 0 but got ${equalJoinDf.count()}") } + + it("Passed ST_DistanceSpheroid in a spatial join") { + val pointDf1 = buildPointDf + val pointDf2 = buildPointDf + var distanceJoinDf = pointDf1.alias("pointDf1").join( + pointDf2.alias("pointDf2"), expr("ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) <= 2.0")) + assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: DistanceJoinExec => p }.size === 1) + assert(distanceJoinDf.count() == 89) + + distanceJoinDf = pointDf1.alias("pointDf1").join( + pointDf2.alias("pointDf2"), expr("ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) < 2.0")) + assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: DistanceJoinExec => p }.size === 1) + assert(distanceJoinDf.count() == 89) + } } }
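Note on the `distanceToDegree` approximation in `TraitJoinQueryBase.scala`: the helper divides the meter distance by 111000.0 for both axes, and its own comment acknowledges that this ignores latitude. The sketch below is not part of the commit; it only illustrates how far that uniform conversion could drift from a latitude-aware one. The object name `DegreeApproximation` and the function `longitudeDegrees` are hypothetical, and the cosine scaling is a common spherical approximation assumed here, not something Sedona is known to implement.

```scala
// Hypothetical sketch, not part of the commit; all names are illustrative.
object DegreeApproximation {
  // Same constant the commit uses: roughly the length of one degree of latitude, in meters.
  val MetersPerDegree: Double = 111000.0

  // Mirrors the commit's distanceToDegree for the geography case:
  // one conversion factor applied to both axes.
  def uniformDegrees(meters: Double): Double = meters / MetersPerDegree

  // Assumption: a latitude-aware conversion for the longitude axis.
  // One degree of longitude shrinks by about cos(latitude), so the same
  // meter distance spans more degrees of longitude away from the equator.
  def longitudeDegrees(meters: Double, latitudeDeg: Double): Double = {
    val cosLat = math.cos(math.toRadians(latitudeDeg))
    require(cosLat > 1e-9, "longitude span is undefined at the poles")
    meters / (MetersPerDegree * cosLat)
  }

  def main(args: Array[String]): Unit = {
    val meters = 1000.0
    // Print both conversions at a few latitudes for comparison.
    for (lat <- Seq(0.0, 45.0, 60.0)) {
      val uniform = uniformDegrees(meters)
      val lonAware = longitudeDegrees(meters, lat)
      println(f"lat=$lat%.1f uniform=$uniform%.6f lonAware=$lonAware%.6f")
    }
  }
}
```

At the equator the two conversions agree. At 60 degrees latitude the latitude-aware value is twice the uniform one, which suggests that an envelope expanded by the uniform amount may be too narrow along the longitude axis for data far from the equator. Whether that matters in practice depends on how the join refines candidates after the envelope filter, which this sketch does not model.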
