This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 8489b2d6 [SEDONA-191] Support join types other than inner in
BroadcastIndexJoinExec (#711)
8489b2d6 is described below
commit 8489b2d604f8b5039707d805386292e0a3b9c245
Author: Tanel Kiis <[email protected]>
AuthorDate: Sat Nov 26 09:43:29 2022 +0200
[SEDONA-191] Support join types other than inner in BroadcastIndexJoinExec
(#711)
---
docs/api/sql/Optimizer.md | 12 +-
.../strategy/join/BroadcastIndexJoinExec.scala | 222 +++-
.../strategy/join/JoinQueryDetector.scala | 102 +-
.../sedona/sql/BroadcastIndexJoinSuite.scala | 1223 +++++++++++++++++++-
4 files changed, 1436 insertions(+), 123 deletions(-)
diff --git a/docs/api/sql/Optimizer.md b/docs/api/sql/Optimizer.md
index eab7dbff..35f05878 100644
--- a/docs/api/sql/Optimizer.md
+++ b/docs/api/sql/Optimizer.md
@@ -73,7 +73,15 @@ DistanceJoin pointshape1#12: geometry, pointshape2#33:
geometry, 2.0, true
Sedona doesn't control the distance's unit (degree or meter). It is
same with the geometry. To change the geometry's unit, please transform the
coordinate reference system. See [ST_Transform](Function.md#st_transform).
## Broadcast join
-Introduction: Perform a range join or distance join but broadcast one of the
sides of the join. This maintains the partitioning of the non-broadcast side
and doesn't require a shuffle.
+Introduction: Perform a range join or distance join but broadcast one of the
sides of the join.
+This maintains the partitioning of the non-broadcast side and doesn't require
a shuffle.
+Sedona uses a broadcast join only if the side that must be broadcast for the given join type has a broadcast hint.
+The supported combinations of join type and broadcast side are:
+* Inner - either side, preferring to broadcast left if both sides have the hint
+* Left semi - broadcast right
+* Left anti - broadcast right
+* Left outer - broadcast right
+* Right outer - broadcast left
```Scala
pointDf.alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
@@ -107,7 +115,7 @@ BroadcastIndexJoin pointshape#52: geometry, BuildRight,
BuildLeft, true, 2.0 ST_
+- FileScan csv
```
-Note: Ff the distance is an expression, it is only evaluated on the first
argument to ST_Distance (`pointDf1` above).
+Note: If the distance is an expression, it is only evaluated on the first
argument to ST_Distance (`pointDf1` above).
## Predicate pushdown
diff --git
a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
index 0d089873..34ec41f1 100644
---
a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
+++
b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
@@ -18,19 +18,18 @@
*/
package org.apache.spark.sql.sedona_sql.strategy.join
-import org.apache.sedona.core.spatialOperator.SpatialPredicate
-import org.apache.sedona.core.spatialOperator.SpatialPredicateEvaluators
+import org.apache.sedona.core.spatialOperator.{SpatialPredicate,
SpatialPredicateEvaluators}
+import
org.apache.sedona.core.spatialOperator.SpatialPredicateEvaluators.SpatialPredicateEvaluator
-import collection.JavaConverters._
-import org.apache.sedona.core.spatialRDD.SpatialRDD
+import scala.collection.JavaConverters._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences,
Expression, Predicate, UnsafeRow}
-import
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences,
Expression, GenericInternalRow, JoinedRow, Predicate, UnsafeProjection,
UnsafeRow}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
import org.apache.spark.sql.sedona_sql.execution.SedonaBinaryExecNode
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.geom.prep.{PreparedGeometry,
PreparedGeometryFactory}
@@ -38,32 +37,59 @@ import org.locationtech.jts.index.SpatialIndex
import scala.collection.mutable
-case class BroadcastIndexJoinExec(left: SparkPlan,
- right: SparkPlan,
- streamShape: Expression,
- indexBuildSide: JoinSide,
- windowJoinSide: JoinSide,
- spatialPredicate: SpatialPredicate,
- extraCondition: Option[Expression] = None,
- distance: Option[Expression] = None)
+case class BroadcastIndexJoinExec(
+ left: SparkPlan,
+ right: SparkPlan,
+ streamShape: Expression,
+ indexBuildSide: JoinSide,
+ windowJoinSide: JoinSide,
+ joinType: JoinType,
+ spatialPredicate: SpatialPredicate,
+ extraCondition: Option[Expression] = None,
+ distance: Option[Expression] = None)
extends SedonaBinaryExecNode
with TraitJoinQueryBase
with Logging {
- override def output: Seq[Attribute] = left.output ++ right.output
+ override def output: Seq[Attribute] = {
+ joinType match {
+ case _: InnerLike =>
+ left.output ++ right.output
+ case LeftOuter =>
+ left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output
+ case j: ExistenceJoin =>
+ left.output :+ j.exists
+ case LeftExistence(_) =>
+ left.output
+ case x =>
+ throw new IllegalArgumentException(s"BroadcastIndexJoinExec should not
take $x as the JoinType")
+ }
+ }
+
+ private val (streamed, broadcast) = indexBuildSide match {
+ case LeftSide => (right, left.asInstanceOf[SpatialIndexExec])
+ case RightSide => (left, right.asInstanceOf[SpatialIndexExec])
+ }
// Using lazy val to avoid serialization
@transient private lazy val boundCondition: (InternalRow => Boolean) =
extraCondition match {
case Some(condition) =>
- Predicate.create(condition, output).eval _ // SPARK3 anchor
-// newPredicate(condition, output).eval _ // SPARK2 anchor
+ Predicate.create(condition, streamed.output ++ broadcast.output).eval _
// SPARK3 anchor
+ // newPredicate(condition, broadcast.output ++ streamed.output).eval
_ // SPARK2 anchor
case None =>
(r: InternalRow) => true
}
- private val (streamed, broadcast) = indexBuildSide match {
- case LeftSide => (right, left.asInstanceOf[SpatialIndexExec])
- case RightSide => (left, right.asInstanceOf[SpatialIndexExec])
+ protected def createResultProjection(): InternalRow => InternalRow =
joinType match {
+ case LeftExistence(_) =>
+ UnsafeProjection.create(output, output)
+ case _ =>
+ // The joined row always has the streamed side on the left, which simplifies the implementation.
+ // Either side may be null, so every input attribute is marked nullable.
+ UnsafeProjection.create(
+ output, (streamed.output ++
broadcast.output).map(_.withNullability(true)))
}
override def outputPartitioning: Partitioning = streamed.outputPartitioning
@@ -83,38 +109,120 @@ case class BroadcastIndexJoinExec(left: SparkPlan,
override def simpleString(maxFields: Int): String =
super.simpleString(maxFields) + s" $spatialExpression" // SPARK3 anchor
// override def simpleString: String = super.simpleString + s"
$spatialExpression" // SPARK2 anchor
- private def windowBroadcastJoin(index: Broadcast[SpatialIndex], spatialRdd:
SpatialRDD[Geometry]): RDD[(Geometry, Geometry)] = {
- spatialRdd.getRawSpatialRDD.rdd.mapPartitions { rows =>
- val factory = new PreparedGeometryFactory()
- val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
- val evaluator = SpatialPredicateEvaluators.create(spatialPredicate)
- rows.flatMap { row =>
- val candidates =
index.value.query(row.getEnvelopeInternal).iterator.asScala.asInstanceOf[Iterator[Geometry]]
- candidates
- .filter(candidate =>
evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
factory.create(candidate) }), row))
- .map(candidate => (candidate, row))
+ private lazy val evaluator: SpatialPredicateEvaluator = if (indexBuildSide
== windowJoinSide) {
+ SpatialPredicateEvaluators.create(spatialPredicate)
+ } else {
+
SpatialPredicateEvaluators.create(SpatialPredicate.inverse(spatialPredicate))
+ }
+
+ private def innerJoin(streamIter: Iterator[Geometry], index:
Broadcast[SpatialIndex]): Iterator[InternalRow] = {
+ val factory = new PreparedGeometryFactory()
+ val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
+ val joinedRow = new JoinedRow
+ streamIter.flatMap { srow =>
+ joinedRow.withLeft(srow.getUserData.asInstanceOf[UnsafeRow])
+ index.value.query(srow.getEnvelopeInternal)
+ .iterator.asScala.asInstanceOf[Iterator[Geometry]]
+ .filter(candidate =>
evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
factory.create(candidate) }), srow))
+ .map(candidate =>
joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow]))
+ .filter(boundCondition)
+ }
+ }
+
+ private def semiJoin(
+ streamIter: Iterator[Geometry], index: Broadcast[SpatialIndex]
+ ): Iterator[InternalRow] = {
+ val factory = new PreparedGeometryFactory()
+ val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
+ val joinedRow = new JoinedRow
+ streamIter.flatMap { srow =>
+ val left = srow.getUserData.asInstanceOf[UnsafeRow]
+ joinedRow.withLeft(left)
+ val anyMatches = index.value.query(srow.getEnvelopeInternal)
+ .iterator.asScala.asInstanceOf[Iterator[Geometry]]
+ .filter(candidate =>
evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
+ factory.create(candidate)
+ }), srow))
+ .map(candidate =>
joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow]))
+ .exists(boundCondition)
+
+ if (anyMatches) {
+ Iterator.single(left)
+ } else {
+ Iterator.empty
}
}
}
- private def objectBroadcastJoin(index: Broadcast[SpatialIndex], spatialRdd:
SpatialRDD[Geometry]): RDD[(Geometry, Geometry)] = {
- spatialRdd.getRawSpatialRDD.rdd.mapPartitions { rows =>
- val factory = new PreparedGeometryFactory()
- val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
- val evaluator =
SpatialPredicateEvaluators.create(SpatialPredicate.inverse(spatialPredicate))
- rows.flatMap { row =>
- val candidates =
index.value.query(row.getEnvelopeInternal).iterator.asScala.asInstanceOf[Iterator[Geometry]]
- candidates
- .filter(candidate =>
evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
factory.create(candidate) }), row))
- .map(candidate => (row, candidate))
+ private def antiJoin(
+ streamIter: Iterator[Geometry], index: Broadcast[SpatialIndex]
+ ): Iterator[InternalRow] = {
+ val factory = new PreparedGeometryFactory()
+ val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
+ val joinedRow = new JoinedRow
+ streamIter.flatMap { srow =>
+ val left = srow.getUserData.asInstanceOf[UnsafeRow]
+ joinedRow.withLeft(left)
+ val anyMatches = index.value.query(srow.getEnvelopeInternal)
+ .iterator.asScala.asInstanceOf[Iterator[Geometry]]
+ .filter(candidate =>
evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
+ factory.create(candidate)
+ }), srow))
+ .map(candidate =>
joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow]))
+ .exists(boundCondition)
+
+ if (anyMatches) {
+ Iterator.empty
+ } else {
+ Iterator.single(left)
}
}
}
+ private def outerJoin(
+ streamIter: Iterator[Geometry], index: Broadcast[SpatialIndex]
+ ): Iterator[InternalRow] = {
+ val factory = new PreparedGeometryFactory()
+ val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
+ val joinedRow = new JoinedRow
+ val nullRow = new GenericInternalRow(broadcast.output.length)
+
+ streamIter.flatMap { srow =>
+ joinedRow.withLeft(srow.getUserData.asInstanceOf[UnsafeRow])
+ val candidates = index.value.query(srow.getEnvelopeInternal)
+ .iterator.asScala.asInstanceOf[Iterator[Geometry]]
+ .filter(candidate =>
evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
+ factory.create(candidate)
+ }), srow))
+
+ new RowIterator {
+ private var found = false
+ override def advanceNext(): Boolean = {
+ while (candidates.hasNext) {
+ val candidateRow =
candidates.next().getUserData.asInstanceOf[UnsafeRow]
+ if (boundCondition(joinedRow.withRight(candidateRow))) {
+ found = true
+ return true
+ }
+ }
+ if (!found) {
+ joinedRow.withRight(nullRow)
+ found = true
+ return true
+ }
+ false
+ }
+ override def getRow: InternalRow = joinedRow
+ }.toScala
+ }
+ }
+
override protected def doExecute(): RDD[InternalRow] = {
val boundStreamShape = BindReferences.bindReference(streamShape,
streamed.output)
val streamResultsRaw = streamed.execute().asInstanceOf[RDD[UnsafeRow]]
+ val broadcastIndex = broadcast.executeBroadcast[SpatialIndex]()
+
// If there's a distance and the objects are being broadcast, we need to
build the expanded envelope on the window stream side
val streamShapes = distance match {
case Some(distanceExpression) if indexBuildSide != windowJoinSide =>
@@ -123,23 +231,25 @@ case class BroadcastIndexJoinExec(left: SparkPlan,
toSpatialRDD(streamResultsRaw, boundStreamShape)
}
- val broadcastIndex = broadcast.executeBroadcast[SpatialIndex]()
+ streamShapes.getRawSpatialRDD.rdd.mapPartitions { streamedIter =>
+ val joinedIter = joinType match {
+ case _: InnerLike =>
+ innerJoin(streamedIter, broadcastIndex)
+ case LeftSemi =>
+ semiJoin(streamedIter, broadcastIndex)
+ case LeftAnti =>
+ antiJoin(streamedIter, broadcastIndex)
+ case LeftOuter | RightOuter =>
+ outerJoin(streamedIter, broadcastIndex)
+ case x =>
+ throw new IllegalArgumentException(s"BroadcastIndexJoinExec should
not take $x as the JoinType")
- val pairs = (indexBuildSide, windowJoinSide) match {
- case (LeftSide, LeftSide) => windowBroadcastJoin(broadcastIndex,
streamShapes)
- case (LeftSide, RightSide) => objectBroadcastJoin(broadcastIndex,
streamShapes).map { case (left, right) => (right, left) }
- case (RightSide, LeftSide) => objectBroadcastJoin(broadcastIndex,
streamShapes)
- case (RightSide, RightSide) => windowBroadcastJoin(broadcastIndex,
streamShapes).map { case (left, right) => (right, left) }
- }
+ }
- pairs.mapPartitions { iter =>
- val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
- iter.map {
- case (l, r) =>
- val leftRow = l.getUserData.asInstanceOf[UnsafeRow]
- val rightRow = r.getUserData.asInstanceOf[UnsafeRow]
- joiner.join(leftRow, rightRow)
- }.filter(boundCondition(_))
+ val resultProj = createResultProjection()
+ joinedIter.map { r =>
+ resultProj(r)
+ }
}
}
diff --git
a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index f5e086c0..fbd1f958 100644
---
a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++
b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -23,20 +23,20 @@ import
org.apache.sedona.core.spatialOperator.SpatialPredicate
import org.apache.sedona.core.utils.SedonaConf
import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan,
LessThanOrEqual}
-import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner,
InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, NaturalJoin, RightOuter,
UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.sedona_sql.expressions._
case class JoinQueryDetection(
- left: LogicalPlan,
- right: LogicalPlan,
- leftShape: Expression,
- rightShape: Expression,
- spatialPredicate: SpatialPredicate,
- extraCondition: Option[Expression] = None,
- distance: Option[Expression] = None
+ left: LogicalPlan,
+ right: LogicalPlan,
+ leftShape: Expression,
+ rightShape: Expression,
+ spatialPredicate: SpatialPredicate,
+ extraCondition: Option[Expression] = None,
+ distance: Option[Expression] = None
)
/**
@@ -78,7 +78,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends
Strategy {
}
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case Join(left, right, Inner, condition, JoinHint(leftHint, rightHint)) =>
{ // SPARK3 anchor
+ case Join(left, right, joinType, condition, JoinHint(leftHint, rightHint))
=> { // SPARK3 anchor
// case Join(left, right, Inner, condition) => { // SPARK2 anchor
val broadcastLeft = leftHint.exists(_.strategy.contains(BROADCAST)) //
SPARK3 anchor
val broadcastRight = rightHint.exists(_.strategy.contains(BROADCAST)) //
SPARK3 anchor
@@ -115,16 +115,19 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends Strategy {
if ((broadcastLeft || broadcastRight) && sedonaConf.getUseIndex) {
queryDetection match {
case Some(JoinQueryDetection(left, right, leftShape, rightShape,
spatialPredicate, extraCondition, distance)) =>
- planBroadcastJoin(left, right, Seq(leftShape, rightShape),
spatialPredicate, sedonaConf.getIndexType, broadcastLeft, extraCondition,
distance)
+ planBroadcastJoin(
+ left, right, Seq(leftShape, rightShape), joinType,
+ spatialPredicate, sedonaConf.getIndexType,
+ broadcastLeft, broadcastRight, extraCondition, distance)
case _ =>
Nil
}
} else {
queryDetection match {
case Some(JoinQueryDetection(left, right, leftShape, rightShape,
spatialPredicate, extraCondition, None)) =>
- planSpatialJoin(left, right, Seq(leftShape, rightShape),
spatialPredicate, extraCondition)
+ planSpatialJoin(left, right, Seq(leftShape, rightShape), joinType,
spatialPredicate, extraCondition)
case Some(JoinQueryDetection(left, right, leftShape, rightShape,
spatialPredicate, extraCondition, Some(distance))) =>
- planDistanceJoin(left, right, Seq(leftShape, rightShape),
distance, spatialPredicate, extraCondition)
+ planDistanceJoin(left, right, Seq(leftShape, rightShape),
joinType, distance, spatialPredicate, extraCondition)
case None =>
Nil
}
@@ -153,11 +156,18 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends Strategy {
None
}
- private def planSpatialJoin(left: LogicalPlan,
- right: LogicalPlan,
- children: Seq[Expression],
- spatialPredicate: SpatialPredicate,
- extraCondition: Option[Expression] = None):
Seq[SparkPlan] = {
+ private def planSpatialJoin(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ children: Seq[Expression],
+ joinType: JoinType,
+ spatialPredicate: SpatialPredicate,
+ extraCondition: Option[Expression] = None): Seq[SparkPlan] = {
+
+ if (joinType != Inner) {
+ return Nil
+ }
+
val a = children.head
val b = children.tail.head
@@ -175,12 +185,19 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends Strategy {
}
}
- private def planDistanceJoin(left: LogicalPlan,
- right: LogicalPlan,
- children: Seq[Expression],
- distance: Expression,
- spatialPredicate: SpatialPredicate,
- extraCondition: Option[Expression] = None):
Seq[SparkPlan] = {
+ private def planDistanceJoin(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ children: Seq[Expression],
+ joinType: JoinType,
+ distance: Expression,
+ spatialPredicate: SpatialPredicate,
+ extraCondition: Option[Expression] = None): Seq[SparkPlan] = {
+
+ if (joinType != Inner) {
+ return Nil
+ }
+
val a = children.head
val b = children.tail.head
@@ -206,14 +223,32 @@ class JoinQueryDetector(sparkSession: SparkSession)
extends Strategy {
}
}
- private def planBroadcastJoin(left: LogicalPlan,
- right: LogicalPlan,
- children: Seq[Expression],
- spatialPredicate: SpatialPredicate,
- indexType: IndexType,
- broadcastLeft: Boolean,
- extraCondition: Option[Expression],
- distance: Option[Expression]): Seq[SparkPlan]
= {
+ private def planBroadcastJoin(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ children: Seq[Expression],
+ joinType: JoinType,
+ spatialPredicate: SpatialPredicate,
+ indexType: IndexType,
+ broadcastLeft: Boolean,
+ broadcastRight: Boolean,
+ extraCondition: Option[Expression],
+ distance: Option[Expression]): Seq[SparkPlan] = {
+
+ val broadcastSide = joinType match {
+ case Inner if broadcastLeft => Some(LeftSide)
+ case Inner if broadcastRight => Some(RightSide)
+ case LeftSemi if broadcastRight => Some(RightSide)
+ case LeftAnti if broadcastRight => Some(RightSide)
+ case LeftOuter if broadcastRight => Some(RightSide)
+ case RightOuter if broadcastLeft => Some(LeftSide)
+ case _ => None
+ }
+
+ if (broadcastSide.isEmpty) {
+ return Nil
+ }
+
val a = children.head
val b = children.tail.head
@@ -226,8 +261,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends
Strategy {
matchExpressionsToPlans(a, b, left, right) match {
case Some((_, _, swapped)) =>
logInfo(s"Planning spatial join for $relationship relationship")
- val broadcastSide = if (broadcastLeft) LeftSide else RightSide
- val (leftPlan, rightPlan, streamShape, windowSide) = (broadcastSide,
swapped) match {
+ val (leftPlan, rightPlan, streamShape, windowSide) =
(broadcastSide.get, swapped) match {
case (LeftSide, false) => // Broadcast the left side, windows on the
left
(SpatialIndexExec(planLater(left), a, indexType, distance),
planLater(right), b, LeftSide)
case (LeftSide, true) => // Broadcast the left side, objects on the
left
@@ -237,7 +271,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends
Strategy {
case (RightSide, true) => // Broadcast the right side, objects on
the left
(planLater(left), SpatialIndexExec(planLater(right), a, indexType,
distance), b, RightSide)
}
- BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape,
broadcastSide, windowSide, spatialPredicate, extraCondition, distance) :: Nil
+ BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape,
broadcastSide.get, windowSide, joinType, spatialPredicate, extraCondition,
distance) :: Nil
case None =>
logInfo(
s"Spatial join for $relationship with arguments not aligned " +
diff --git
a/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
b/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
index 445af754..30a22f08 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
@@ -19,37 +19,43 @@
package org.apache.sedona.sql
+import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
import org.apache.spark.sql.sedona_sql.strategy.join.BroadcastIndexJoinExec
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.Row
class BroadcastIndexJoinSuite extends TestBaseScala {
- describe("Sedona-SQL Broadcast Index Join Test") {
+ describe("Sedona-SQL Broadcast Index Join Test for inner joins") {
// Using UDFs rather than lit prevents optimizations that would circumvent
the checks we want to test
- val one = udf(() => 1)
- val two = udf(() => 2)
+ val one = udf(() => 1).asNondeterministic()
+ val two = udf(() => 2).asNondeterministic()
it("Passed Correct partitioning for broadcast join for ST_Polygon and
ST_Point") {
val polygonDf = buildPolygonDf.repartition(3)
val pointDf = buildPointDf.repartition(5)
- var broadcastJoinDf =
pointDf.alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ var broadcastJoinDf = pointDf.alias("pointDf").join(
+ broadcast(polygonDf).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(broadcastJoinDf.rdd.getNumPartitions ==
pointDf.rdd.getNumPartitions)
assert(broadcastJoinDf.count() == 1000)
- broadcastJoinDf =
broadcast(polygonDf).alias("polygonDf").join(pointDf.alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+ pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape,
pointDf.pointshape)"))
assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(broadcastJoinDf.rdd.getNumPartitions ==
pointDf.rdd.getNumPartitions)
assert(broadcastJoinDf.count() == 1000)
- broadcastJoinDf =
broadcast(pointDf).alias("pointDf").join(polygonDf.alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+ polygonDf.alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(broadcastJoinDf.rdd.getNumPartitions ==
polygonDf.rdd.getNumPartitions)
assert(broadcastJoinDf.count() == 1000)
- broadcastJoinDf =
polygonDf.alias("polygonDf").join(broadcast(pointDf).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ broadcastJoinDf = polygonDf.alias("polygonDf").join(
+ broadcast(pointDf).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(broadcastJoinDf.rdd.getNumPartitions ==
polygonDf.rdd.getNumPartitions)
assert(broadcastJoinDf.count() == 1000)
@@ -59,7 +65,8 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
val polygonDf = buildPolygonDf.repartition(3)
val pointDf = buildPointDf.repartition(5)
- var broadcastJoinDf =
broadcast(pointDf).alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ val broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+ broadcast(polygonDf).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(broadcastJoinDf.rdd.getNumPartitions ==
polygonDf.rdd.getNumPartitions)
assert(broadcastJoinDf.count() == 1000)
@@ -68,20 +75,28 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
it("Passed Can access attributes of both sides of broadcast join") {
val polygonDf = buildPolygonDf.withColumn("window_extra", one())
val pointDf = buildPointDf.withColumn("object_extra", one())
-
- var broadcastJoinDf =
polygonDf.alias("polygonDf").join(broadcast(pointDf).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+
+ var broadcastJoinDf = polygonDf.alias("polygonDf").join(
+ broadcast(pointDf).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) ==
1000)
assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) ==
1000)
- broadcastJoinDf =
broadcast(polygonDf).alias("polygonDf").join(pointDf.alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+ pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape,
pointDf.pointshape)"))
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) ==
1000)
assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) ==
1000)
- broadcastJoinDf =
broadcast(pointDf).alias("pointDf").join(polygonDf.alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+ polygonDf.alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) ==
1000)
assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) ==
1000)
- broadcastJoinDf =
pointDf.alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ broadcastJoinDf = pointDf.alias("pointDf").join(
+ broadcast(polygonDf).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) ==
1000)
assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) ==
1000)
}
@@ -157,48 +172,56 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
}
it("Passed ST_Distance <= distance in a broadcast join") {
- var pointDf1 = buildPointDf
- var pointDf2 = buildPointDf
+ val pointDf1 = buildPointDf
+ val pointDf2 = buildPointDf
- var distanceJoinDf =
pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(distanceJoinDf.count() == 2998)
- distanceJoinDf =
broadcast(pointDf1).alias("pointDf1").join(pointDf2.alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape,
pointDf2.pointshape) <= 2"))
assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(distanceJoinDf.count() == 2998)
}
it("Passed ST_Distance < distance in a broadcast join") {
- var pointDf1 = buildPointDf
- var pointDf2 = buildPointDf
+ val pointDf1 = buildPointDf
+ val pointDf2 = buildPointDf
- var distanceJoinDf =
pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"))
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"))
assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(distanceJoinDf.count() == 2998)
- distanceJoinDf =
broadcast(pointDf1).alias("pointDf1").join(pointDf2.alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"))
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape,
pointDf2.pointshape) < 2"))
assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(distanceJoinDf.count() == 2998)
}
it("Passed ST_Distance distance is bound to first expression") {
- var pointDf1 = buildPointDf.withColumn("radius", two())
- var pointDf2 = buildPointDf
+ val pointDf1 = buildPointDf.withColumn("radius", two())
+ val pointDf2 = buildPointDf
- var distanceJoinDf =
pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(distanceJoinDf.count() == 2998)
- distanceJoinDf =
broadcast(pointDf1).alias("pointDf1").join(pointDf2.alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape,
pointDf2.pointshape) < radius"))
assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(distanceJoinDf.count() == 2998)
- distanceJoinDf =
pointDf2.alias("pointDf2").join(broadcast(pointDf1).alias("pointDf1"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+ distanceJoinDf = pointDf2.alias("pointDf2").join(
+ broadcast(pointDf1).alias("pointDf1"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(distanceJoinDf.count() == 2998)
- distanceJoinDf =
broadcast(pointDf2).alias("pointDf2").join(pointDf1.alias("pointDf1"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+ distanceJoinDf = broadcast(pointDf2).alias("pointDf2").join(
+ pointDf1.alias("pointDf1"), expr("ST_Distance(pointDf1.pointshape,
pointDf2.pointshape) < radius"))
assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(distanceJoinDf.count() == 2998)
}
@@ -208,22 +231,26 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
val polygonDf = buildPolygonDf.repartition(3)
val pointDf = buildPointDf.repartition(5)
- var broadcastJoinDf =
pointDf.alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ var broadcastJoinDf = pointDf.alias("pointDf").join(
+ broadcast(polygonDf).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(broadcastJoinDf.rdd.getNumPartitions ==
pointDf.rdd.getNumPartitions)
assert(broadcastJoinDf.count() == 1000)
- broadcastJoinDf =
broadcast(polygonDf).alias("polygonDf").join(pointDf.alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+ pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape,
pointDf.pointshape)"))
assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(broadcastJoinDf.rdd.getNumPartitions ==
pointDf.rdd.getNumPartitions)
assert(broadcastJoinDf.count() == 1000)
- broadcastJoinDf =
broadcast(pointDf).alias("pointDf").join(polygonDf.alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+ polygonDf.alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(broadcastJoinDf.rdd.getNumPartitions ==
polygonDf.rdd.getNumPartitions)
assert(broadcastJoinDf.count() == 1000)
- broadcastJoinDf =
polygonDf.alias("polygonDf").join(broadcast(pointDf).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ broadcastJoinDf = polygonDf.alias("polygonDf").join(
+ broadcast(pointDf).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
assert(broadcastJoinDf.rdd.getNumPartitions ==
polygonDf.rdd.getNumPartitions)
assert(broadcastJoinDf.count() == 1000)
@@ -246,5 +273,1139 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
|on ST_Distance(a.geom, b.geom) < 1.5
|""".stripMargin).count() == 1)
}
+
+ // Checks the concrete rows produced by an inner broadcast index join, once
+ // with the right input hinted and once with the left input hinted. Both
+ // variants are expected to yield the single combined row for the (2.0, 2.0)
+ // points.
+ it("Passed validate output rows") {
+ val left = sparkSession.createDataFrame(Seq(
+ (1.0, 1.0, "left_1"),
+ (2.0, 2.0, "left_2")
+ )).toDF("l_x", "l_y", "l_data")
+
+ val right = sparkSession.createDataFrame(Seq(
+ (2.0, 2.0, "right_2"),
+ (3.0, 3.0, "right_3")
+ )).toDF("r_x", "r_y", "r_data")
+
+ val joined = left.join(broadcast(right),
+ expr("ST_Intersects(ST_Point(l_x, l_y), ST_Point(r_x, r_y))"), "inner")
+ assert(joined.queryExecution.sparkPlan.collect{ case p:
BroadcastIndexJoinExec => p }.size === 1)
+
+ val rows = joined.collect()
+ assert(rows.length == 1)
+ assert(rows(0) == Row(2.0, 2.0, "left_2", 2.0, 2.0, "right_2"))
+
+ val joined2 = broadcast(left).join(right,
+ expr("ST_Intersects(ST_Point(l_x, l_y), ST_Point(r_x, r_y))"), "inner")
+ // Inspect the plan of joined2 itself (not the first join) so the
+ // left-broadcast variant is actually verified to use the index join.
+ assert(joined2.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+
+ val rows2 = joined2.collect()
+ assert(rows2.length == 1)
+ assert(rows2(0) == Row(2.0, 2.0, "left_2", 2.0, 2.0, "right_2"))
+
+ }
+ }
+
+ describe("Sedona-SQL Broadcast Index Join Test for left semi joins") {
+
+ // Using UDFs rather than lit prevents optimizations that would circumvent
+ // the checks we want to test
+ val one = udf(() => 1).asNondeterministic()
+ val two = udf(() => 2).asNondeterministic()
+
+ // For each limit, runs a left_semi ST_Contains join in four arrangements of
+ // join order and broadcast hint. When the right input is hinted the plan is
+ // expected to contain one BroadcastIndexJoinExec and the result to keep the
+ // left input's partition count; when only the left input is hinted the plan
+ // is expected to contain one BroadcastNestedLoopJoinExec instead.
+ it("Passed Correct partitioning for broadcast join for ST_Polygon and
ST_Point") {
+ val polygonDf = buildPolygonDf.repartition(3)
+ val pointDf = buildPointDf.repartition(5)
+
+ Seq(500, 900, 1000).foreach { limit =>
+ // Points on the left, limited polygons hinted on the right.
+ var broadcastJoinDf = pointDf.alias("pointDf").join(
+ broadcast(polygonDf.limit(limit)).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.rdd.getNumPartitions ==
pointDf.rdd.getNumPartitions)
+ assert(broadcastJoinDf.count() == limit)
+
+ // Hint on the left input only: the index join is not expected here.
+ broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+ pointDf.limit(limit).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == limit)
+
+ // Same as above with the inputs swapped; hint still on the left only.
+ broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+ polygonDf.limit(limit).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == limit)
+
+ // Polygons on the left, limited points hinted on the right.
+ broadcastJoinDf = polygonDf.alias("polygonDf").join(
+ broadcast(pointDf.limit(limit)).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.rdd.getNumPartitions ==
polygonDf.rdd.getNumPartitions)
+ assert(broadcastJoinDf.count() == limit)
+ }
+ }
+
+ // Both inputs carry a broadcast hint. The plan is expected to contain one
+ // BroadcastIndexJoinExec and the result to keep pointDf's (the left input's)
+ // partition count, which is how this test observes that the right input was
+ // the one broadcast.
+ it("Passed Broadcasts the right side if both sides have a broadcast hint")
{
+ val polygonDf = buildPolygonDf.repartition(3)
+ val pointDf = buildPointDf.repartition(5)
+
+ val broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+ broadcast(polygonDf).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.rdd.getNumPartitions ==
pointDf.rdd.getNumPartitions)
+ assert(broadcastJoinDf.count() == 1000)
+ }
+
+ // Verifies that the extra column added to the left input (window_extra on
+ // polygons, object_extra on points; both produced by the one() UDF) can be
+ // selected from the left_semi join result. Each sum is expected to be 1000.
+ // This test makes no assertion about which physical join operator is used.
+ it("Passed Can access attributes of left side of broadcast join") {
+ val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+ val pointDf = buildPointDf.withColumn("object_extra", one())
+
+ // Polygons on the left: window_extra must be available in the output.
+ var broadcastJoinDf = polygonDf.alias("polygonDf").join(
+ broadcast(pointDf).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+ assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) ==
1000)
+
+ broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+ pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape,
pointDf.pointshape)"), "left_semi")
+ assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) ==
1000)
+
+ // Points on the left: object_extra must be available in the output.
+ broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+ polygonDf.alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+ assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) ==
1000)
+
+ broadcastJoinDf = pointDf.alias("pointDf").join(
+ broadcast(polygonDf).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+ assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) ==
1000)
+ }
+
+ // Combines ST_Contains with one non-spatial comparison, placed either after
+ // or before the spatial predicate. window_extra comes from one() and
+ // object_extra from two(), so `<=` holds for every pair and `>` for none;
+ // the expected left_semi counts are therefore `limit` and 0 respectively.
+ // Every variant is expected to plan one BroadcastIndexJoinExec.
+ it("Passed Handles extra conditions on a broadcast join") {
+ val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+ val pointDf = buildPointDf.withColumn("object_extra", two())
+
+ Seq(500, 900, 1000).foreach { limit =>
+ // Spatial predicate first, always-true extra condition.
+ var broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.limit(limit).alias("polygonDf")),
+ expr("ST_Contains(polygonshape, pointshape) AND window_extra <=
object_extra"),
+ "left_semi"
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == limit)
+
+ // Spatial predicate first, always-false extra condition.
+ broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.limit(limit).alias("polygonDf")),
+ expr("ST_Contains(polygonshape, pointshape) AND window_extra >
object_extra"),
+ "left_semi"
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 0)
+
+ // Always-true extra condition first, spatial predicate last.
+ broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.limit(limit).alias("polygonDf")),
+ expr("window_extra <= object_extra AND ST_Contains(polygonshape,
pointshape)"),
+ "left_semi"
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == limit)
+
+ // Always-false extra condition first, spatial predicate last.
+ broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.limit(limit).alias("polygonDf")),
+ expr("window_extra > object_extra AND ST_Contains(polygonshape,
pointshape)"),
+ "left_semi"
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 0)
+ }
+ }
+
+ // Same idea as the single-extra-condition test, but with two non-spatial
+ // comparisons placed before ST_Contains. The window_* columns come from
+ // one() and the object_* columns from two(), so the `<=` pair is always true
+ // (expected count: limit) and the `>` pair is always false (expected: 0).
+ it("Passed Handles multiple extra conditions on a broadcast join with the
ST predicate last") {
+ val polygonDf = buildPolygonDf.withColumn("window_extra",
one()).withColumn("window_extra2", one())
+ val pointDf = buildPointDf.withColumn("object_extra",
two()).withColumn("object_extra2", two())
+
+ Seq(500, 900, 1000).foreach { limit =>
+ var broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.limit(limit).alias("polygonDf")),
+ expr("window_extra <= object_extra AND window_extra2 <=
object_extra2 AND ST_Contains(polygonshape, pointshape)"),
+ "left_semi"
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == limit)
+
+ broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.limit(limit).alias("polygonDf")),
+ expr("window_extra > object_extra AND window_extra2 >
object_extra2 AND ST_Contains(polygonshape, pointshape)"),
+ "left_semi"
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 0)
+ }
+ }
+
+ // left_semi distance join with an inclusive literal bound (<= 2). A hint on
+ // the right input is expected to plan BroadcastIndexJoinExec; a hint on the
+ // left input only is expected to plan BroadcastNestedLoopJoinExec. Expected
+ // counts are 1000 against the full right input and 501 when the right input
+ // is limited to 500 rows.
+ it("Passed ST_Distance <= distance in a broadcast join") {
+ val pointDf1 = buildPointDf
+ val pointDf2 = buildPointDf
+
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"), "left_semi")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 1000)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape,
pointDf2.pointshape) <= 2"), "left_semi")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 1000)
+
+ // Right input limited to 500 rows.
+ distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2.limit(500)).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"), "left_semi")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 501)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.limit(500).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"), "left_semi")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 501)
+ }
+
+ // left_semi distance join with a strict literal bound (< 2). The plan and
+ // count expectations mirror the inclusive-bound test: index join when the
+ // right input is hinted, nested-loop join when only the left input is
+ // hinted; 1000 rows against the full right input, 501 against 500 rows.
+ it("Passed ST_Distance < distance in a broadcast join") {
+ val pointDf1 = buildPointDf
+ val pointDf2 = buildPointDf
+
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), "left_semi")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 1000)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape,
pointDf2.pointshape) < 2"), "left_semi")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 1000)
+
+ // Right input limited to 500 rows.
+ distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2.limit(500)).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), "left_semi")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 501)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.limit(500).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), "left_semi")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 501)
+ }
+
+ // The distance bound is the `radius` column of pointDf1 (filled by two())
+ // rather than a literal. The four cases vary whether pointDf1 is the left or
+ // the right input and which input carries the broadcast hint; a hint on the
+ // right input is expected to plan BroadcastIndexJoinExec, a hint on the left
+ // input only BroadcastNestedLoopJoinExec. All cases expect 1000 rows.
+ it("Passed ST_Distance distance is bound to first expression") {
+ val pointDf1 = buildPointDf.withColumn("radius", two())
+ val pointDf2 = buildPointDf
+
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"),
"left_semi")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 1000)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape,
pointDf2.pointshape) < radius"), "left_semi")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 1000)
+
+ // radius now comes from the hinted right input.
+ distanceJoinDf = pointDf2.alias("pointDf2").join(
+ broadcast(pointDf1).alias("pointDf1"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"),
"left_semi")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 1000)
+
+ distanceJoinDf = broadcast(pointDf2).alias("pointDf2").join(
+ pointDf1.alias("pointDf1"), expr("ST_Distance(pointDf1.pointshape,
pointDf2.pointshape) < radius"), "left_semi")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 1000)
+ }
+
+ // Repeats the left_semi partitioning checks with adaptive query execution
+ // switched on: a hint on the right input is expected to plan
+ // BroadcastIndexJoinExec and keep the left input's partition count; a hint
+ // on the left input only is expected to plan BroadcastNestedLoopJoinExec.
+ it("Passed Correct partitioning for broadcast join for ST_Polygon and
ST_Point with AQE enabled") {
+ sparkSession.conf.set("spark.sql.adaptive.enabled", true)
+ // Restore the setting in a finally block so that a failing assertion cannot
+ // leave AQE enabled for the tests that run afterwards.
+ try {
+ val polygonDf = buildPolygonDf.repartition(3)
+ val pointDf = buildPointDf.repartition(5)
+
+ var broadcastJoinDf = pointDf.alias("pointDf").join(
+ broadcast(polygonDf).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.rdd.getNumPartitions ==
pointDf.rdd.getNumPartitions)
+ assert(broadcastJoinDf.count() == 1000)
+
+ broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+ pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape,
pointDf.pointshape)"), "left_semi")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000)
+
+ broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+ polygonDf.alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000)
+
+ broadcastJoinDf = polygonDf.alias("polygonDf").join(
+ broadcast(pointDf).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.rdd.getNumPartitions ==
polygonDf.rdd.getNumPartitions)
+ assert(broadcastJoinDf.count() == 1000)
+ } finally {
+ sparkSession.conf.set("spark.sql.adaptive.enabled", false)
+ }
+ }
+
+ // Checks the concrete output of a left_semi broadcast index join: the plan
+ // is expected to contain one BroadcastIndexJoinExec, and the single expected
+ // row is the left_2 row with only the left input's columns.
+ it("Passed validate output rows") {
+ val left = sparkSession.createDataFrame(Seq(
+ (1.0, 1.0, "left_1"),
+ (2.0, 2.0, "left_2")
+ )).toDF("l_x", "l_y", "l_data")
+
+ val right = sparkSession.createDataFrame(Seq(
+ (2.0, 2.0, "right_2"),
+ (3.0, 3.0, "right_3")
+ )).toDF("r_x", "r_y", "r_data")
+
+ val joined = left.join(broadcast(right),
+ expr("ST_Intersects(ST_Point(l_x, l_y), ST_Point(r_x, r_y))"),
"left_semi")
+ assert(joined.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+
+ val rows = joined.collect()
+ assert(rows.length == 1)
+ assert(rows(0) == Row(2.0, 2.0, "left_2"))
+ }
+ }
+
+ describe("Sedona-SQL Broadcast Index Join Test for left anti joins") {
+
+ // Using UDFs rather than lit prevents optimizations that would circumvent
+ // the checks we want to test
+ val one = udf(() => 1).asNondeterministic()
+ val two = udf(() => 2).asNondeterministic()
+
+ // For each limit, runs a left_anti ST_Contains join in four arrangements of
+ // join order and broadcast hint. A hint on the right input is expected to
+ // plan BroadcastIndexJoinExec; a hint on the left input only is expected to
+ // plan BroadcastNestedLoopJoinExec. Every arrangement expects 1000 - limit
+ // rows to survive the anti join.
+ it("Passed Correct partitioning for broadcast join for ST_Polygon and
ST_Point") {
+ val polygonDf = buildPolygonDf.repartition(3)
+ val pointDf = buildPointDf.repartition(5)
+
+ Seq(500, 900, 1000).foreach { limit =>
+ var broadcastJoinDf = pointDf.alias("pointDf").join(
+ broadcast(polygonDf.limit(limit)).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000 - limit)
+
+ // Hint on the left input only: the index join is not expected here.
+ broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+ pointDf.limit(limit).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000 - limit)
+
+ broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+ polygonDf.limit(limit).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000 - limit)
+
+ broadcastJoinDf = polygonDf.alias("polygonDf").join(
+ broadcast(pointDf.limit(limit)).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000 - limit)
+ }
+ }
+
+ // Both inputs carry a broadcast hint. The plan is expected to contain one
+ // BroadcastIndexJoinExec and the result to keep pointDf's (the left input's)
+ // partition count, which is how this test observes that the right input was
+ // the one broadcast. The anti join is expected to return no rows.
+ it("Passed Broadcasts the right side if both sides have a broadcast hint")
{
+ val polygonDf = buildPolygonDf.repartition(3)
+ val pointDf = buildPointDf.repartition(5)
+
+ val broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+ broadcast(polygonDf).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.rdd.getNumPartitions ==
pointDf.rdd.getNumPartitions)
+ assert(broadcastJoinDf.count() == 0)
+ }
+
+ // Verifies that the extra column added to the left input (filled by one())
+ // can be selected from the left_anti join result. Each sum is expected to be
+ // 1000 - limit. The physical operator is also checked: index join when the
+ // right input is hinted, nested-loop join when only the left input is.
+ it("Passed Can access attributes of left side of broadcast join") {
+ val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+ val pointDf = buildPointDf.withColumn("object_extra", one())
+
+ Seq(500, 900).foreach { limit =>
+ // Polygons on the left: window_extra must be available in the output.
+ var broadcastJoinDf = polygonDf.alias("polygonDf").join(
+ broadcast(pointDf.limit(limit)).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0)
== 1000 - limit)
+
+ broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+ pointDf.limit(limit).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0)
== 1000 - limit)
+
+ // Points on the left: object_extra must be available in the output.
+ broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+ polygonDf.limit(limit).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0)
== 1000 - limit)
+
+ broadcastJoinDf = pointDf.alias("pointDf").join(
+ broadcast(polygonDf.limit(limit)).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0)
== 1000 - limit)
+ }
+ }
+
+ // Combines ST_Contains with one non-spatial comparison, placed either after
+ // or before the spatial predicate. window_extra comes from one() and
+ // object_extra from two(), so `<=` holds for every pair (expected anti-join
+ // count: 1000 - limit) and `>` for none (nothing matches, so all 1000 left
+ // rows are expected). Every variant should plan one BroadcastIndexJoinExec.
+ it("Passed Handles extra conditions on a broadcast join") {
+ val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+ val pointDf = buildPointDf.withColumn("object_extra", two())
+
+ Seq(500, 900).foreach { limit =>
+ // Spatial predicate first, always-true extra condition.
+ var broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.limit(limit).alias("polygonDf")),
+ expr("ST_Contains(polygonshape, pointshape) AND window_extra <=
object_extra"),
+ "left_anti"
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000 - limit)
+
+ // Spatial predicate first, always-false extra condition.
+ broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.limit(limit).alias("polygonDf")),
+ expr("ST_Contains(polygonshape, pointshape) AND window_extra >
object_extra"),
+ "left_anti"
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000)
+
+ // Always-true extra condition first, spatial predicate last.
+ broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.limit(limit).alias("polygonDf")),
+ expr("window_extra <= object_extra AND ST_Contains(polygonshape,
pointshape)"),
+ "left_anti"
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000 - limit)
+
+ // Always-false extra condition first, spatial predicate last.
+ broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.limit(limit).alias("polygonDf")),
+ expr("window_extra > object_extra AND ST_Contains(polygonshape,
pointshape)"),
+ "left_anti"
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000)
+ }
+ }
+
+ // Two non-spatial comparisons precede ST_Contains in a left_anti join
+ // against the full (unlimited) polygon input. The window_* columns come from
+ // one() and the object_* columns from two(): with the always-true `<=` pair
+ // the anti join is expected to return 0 rows, with the always-false `>` pair
+ // all 1000 left rows.
+ it("Passed Handles multiple extra conditions on a broadcast join with the
ST predicate last") {
+ val polygonDf = buildPolygonDf.withColumn("window_extra",
one()).withColumn("window_extra2", one())
+ val pointDf = buildPointDf.withColumn("object_extra",
two()).withColumn("object_extra2", two())
+
+ var broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.alias("polygonDf")),
+ expr("window_extra <= object_extra AND window_extra2 <=
object_extra2 AND ST_Contains(polygonshape, pointshape)"),
+ "left_anti"
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 0)
+
+ broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.alias("polygonDf")),
+ expr("window_extra > object_extra AND window_extra2 > object_extra2
AND ST_Contains(polygonshape, pointshape)"),
+ "left_anti"
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000)
+ }
+
+ // left_anti distance join with an inclusive literal bound (<= 2). A hint on
+ // the right input is expected to plan BroadcastIndexJoinExec; a hint on the
+ // left input only is expected to plan BroadcastNestedLoopJoinExec. Expected
+ // counts are 0 against the full right input and 499 when the right input is
+ // limited to 500 rows.
+ it("Passed ST_Distance <= distance in a broadcast join") {
+ val pointDf1 = buildPointDf
+ val pointDf2 = buildPointDf
+
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"), "left_anti")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 0)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape,
pointDf2.pointshape) <= 2"), "left_anti")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 0)
+
+ // Right input limited to 500 rows.
+ distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2.limit(500)).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"), "left_anti")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 499)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.limit(500).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"), "left_anti")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 499)
+ }
+
+ // left_anti distance join with a strict literal bound (< 2). The plan and
+ // count expectations mirror the inclusive-bound test: index join when the
+ // right input is hinted, nested-loop join when only the left input is
+ // hinted; 0 rows against the full right input, 499 against 500 rows.
+ it("Passed ST_Distance < distance in a broadcast join") {
+ val pointDf1 = buildPointDf
+ val pointDf2 = buildPointDf
+
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), "left_anti")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 0)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape,
pointDf2.pointshape) < 2"), "left_anti")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 0)
+
+ // Right input limited to 500 rows.
+ distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2.limit(500)).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), "left_anti")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 499)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.limit(500).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), "left_anti")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 499)
+ }
+
+ // The distance bound is the `radius` column of pointDf1 (filled by two())
+ // rather than a literal. The four cases vary whether pointDf1 is the left or
+ // the right input and which input carries the broadcast hint; a hint on the
+ // right input is expected to plan BroadcastIndexJoinExec, a hint on the left
+ // input only BroadcastNestedLoopJoinExec. All cases expect 0 rows.
+ it("Passed ST_Distance distance is bound to first expression") {
+ val pointDf1 = buildPointDf.withColumn("radius", two())
+ val pointDf2 = buildPointDf
+
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2).alias("pointDf2"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"),
"left_anti")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 0)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape,
pointDf2.pointshape) < radius"), "left_anti")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 0)
+
+ // radius now comes from the hinted right input.
+ distanceJoinDf = pointDf2.alias("pointDf2").join(
+ broadcast(pointDf1).alias("pointDf1"),
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"),
"left_anti")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 0)
+
+ distanceJoinDf = broadcast(pointDf2).alias("pointDf2").join(
+ pointDf1.alias("pointDf1"), expr("ST_Distance(pointDf1.pointshape,
pointDf2.pointshape) < radius"), "left_anti")
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 0)
+ }
+
+ // Repeats the left_anti plan checks with adaptive query execution switched
+ // on: a hint on the right input is expected to plan BroadcastIndexJoinExec,
+ // a hint on the left input only BroadcastNestedLoopJoinExec. Every
+ // arrangement expects the anti join to return 0 rows.
+ it("Passed Correct partitioning for broadcast join for ST_Polygon and
ST_Point with AQE enabled") {
+ sparkSession.conf.set("spark.sql.adaptive.enabled", true)
+ // Restore the setting in a finally block so that a failing assertion cannot
+ // leave AQE enabled for the tests that run afterwards.
+ try {
+ val polygonDf = buildPolygonDf.repartition(3)
+ val pointDf = buildPointDf.repartition(5)
+
+ var broadcastJoinDf = pointDf.alias("pointDf").join(
+ broadcast(polygonDf).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 0)
+
+ broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+ pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape,
pointDf.pointshape)"), "left_anti")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 0)
+
+ broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+ polygonDf.alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 0)
+
+ broadcastJoinDf = polygonDf.alias("polygonDf").join(
+ broadcast(pointDf).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 0)
+ } finally {
+ sparkSession.conf.set("spark.sql.adaptive.enabled", false)
+ }
+ }
+
+ // Checks the concrete output of a left_anti broadcast index join: the plan
+ // is expected to contain one BroadcastIndexJoinExec, and the single expected
+ // row is the left_1 row with only the left input's columns.
+ it("Passed validate output rows") {
+ val left = sparkSession.createDataFrame(Seq(
+ (1.0, 1.0, "left_1"),
+ (2.0, 2.0, "left_2")
+ )).toDF("l_x", "l_y", "l_data")
+
+ val right = sparkSession.createDataFrame(Seq(
+ (2.0, 2.0, "right_2"),
+ (3.0, 3.0, "right_3")
+ )).toDF("r_x", "r_y", "r_data")
+
+ val joined = left.join(broadcast(right),
+ expr("ST_Intersects(ST_Point(l_x, l_y), ST_Point(r_x, r_y))"),
"left_anti")
+ assert(joined.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+
+ val rows = joined.collect()
+ assert(rows.length == 1)
+ assert(rows(0) == Row(1.0, 1.0, "left_1"))
+ }
+ }
+
+ describe("Sedona-SQL Broadcast Index Join Test for left outer joins") {
+
+ // Using UDFs rather than lit prevents optimizations that would circumvent
+ // the checks we want to test
+ val one = udf(() => 1).asNondeterministic()
+ val two = udf(() => 2).asNondeterministic()
+
+ // For each limit, runs a left_outer ST_Contains join in four arrangements of
+ // join order and broadcast hint. When the right input is hinted the plan is
+ // expected to contain one BroadcastIndexJoinExec and the result to keep the
+ // left input's partition count; when only the left input is hinted the plan
+ // is expected to contain one BroadcastNestedLoopJoinExec. Every arrangement
+ // expects 1000 result rows regardless of the limit.
+ it("Passed Correct partitioning for broadcast join for ST_Polygon and
ST_Point") {
+ val polygonDf = buildPolygonDf.repartition(3)
+ val pointDf = buildPointDf.repartition(5)
+
+ Seq(500, 900, 1000).foreach { limit =>
+ var broadcastJoinDf = pointDf.alias("pointDf").join(
+ broadcast(polygonDf.limit(limit)).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"),
+ "left_outer")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.rdd.getNumPartitions ==
pointDf.rdd.getNumPartitions)
+ assert(broadcastJoinDf.count() == 1000)
+
+ // Hint on the left input only: the index join is not expected here.
+ broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+ pointDf.limit(limit).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"),
+ "left_outer")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000)
+
+ broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+ polygonDf.limit(limit).alias("polygonDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"),
+ "left_outer")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastNestedLoopJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000)
+
+ broadcastJoinDf = polygonDf.alias("polygonDf").join(
+ broadcast(pointDf.limit(limit)).alias("pointDf"),
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"),
+ "left_outer")
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p:
BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.rdd.getNumPartitions ==
polygonDf.rdd.getNumPartitions)
+ assert(broadcastJoinDf.count() == 1000)
+ }
+ }
+
+ // Title corrected: in this left outer join suite the assertion below pins the
+ // output partitioning to the LEFT (pointDf) input, and a broadcast join keeps
+ // the partitioning of the non-broadcast side. The side being broadcast when
+ // both inputs carry a hint is therefore the right one, not the left.
+ it("Passed Broadcasts the right side if both sides have a broadcast hint") {
+   val polygonDf = buildPolygonDf.repartition(3)
+   val pointDf = buildPointDf.repartition(5)
+
+   // Both inputs carry a broadcast hint.
+   val broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+     broadcast(polygonDf).alias("polygonDf"),
+     expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"),
+     "left_outer")
+
+   // Exactly one BroadcastIndexJoinExec is expected in the physical plan.
+   assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+   // The result must be partitioned like the left (streamed) input.
+   assert(broadcastJoinDf.rdd.getNumPartitions == pointDf.rdd.getNumPartitions)
+   assert(broadcastJoinDf.count() == 1000)
+ }
+
+ it("Passed Can access attributes of both sides of broadcast join") {
+   // Count the physical join operators of each kind in a query's Spark plan.
+   def indexJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size
+   def nestedLoopJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size
+   // Sum of one column over the whole join result (first cell of the single aggregate row).
+   def columnSum(df: org.apache.spark.sql.DataFrame, column: String): Any =
+     df.select(sum(column)).collect().head(0)
+
+   val polygons = buildPolygonDf.withColumn("window_extra", one())
+   val points = buildPointDf.withColumn("object_extra", one())
+   val contains = expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)")
+
+   for (limit <- Seq(500, 900, 1000)) {
+     // Polygons on the left, broadcast hint on the limited points.
+     val polygonsHintRight = polygons.alias("polygonDf")
+       .join(broadcast(points.limit(limit)).alias("pointDf"), contains, "left_outer")
+     assert(indexJoinCount(polygonsHintRight) === 1)
+     assert(columnSum(polygonsHintRight, "object_extra") == limit)
+     assert(columnSum(polygonsHintRight, "window_extra") == 1000)
+
+     // Polygons on the left carrying the hint, limited points on the right.
+     val polygonsHintLeft = broadcast(polygons).alias("polygonDf")
+       .join(points.limit(limit).alias("pointDf"), contains, "left_outer")
+     assert(nestedLoopJoinCount(polygonsHintLeft) === 1)
+     assert(columnSum(polygonsHintLeft, "object_extra") == limit)
+     assert(columnSum(polygonsHintLeft, "window_extra") == 1000)
+
+     // Points on the left carrying the hint, limited polygons on the right.
+     val pointsHintLeft = broadcast(points).alias("pointDf")
+       .join(polygons.limit(limit).alias("polygonDf"), contains, "left_outer")
+     assert(nestedLoopJoinCount(pointsHintLeft) === 1)
+     assert(columnSum(pointsHintLeft, "object_extra") == 1000)
+     assert(columnSum(pointsHintLeft, "window_extra") == limit)
+
+     // Points on the left, broadcast hint on the limited polygons.
+     val pointsHintRight = points.alias("pointDf")
+       .join(broadcast(polygons.limit(limit)).alias("polygonDf"), contains, "left_outer")
+     assert(indexJoinCount(pointsHintRight) === 1)
+     assert(columnSum(pointsHintRight, "object_extra") == 1000)
+     assert(columnSum(pointsHintRight, "window_extra") == limit)
+   }
+ }
+
+ it("Passed Handles extra conditions on a broadcast join") {
+   val polygons = buildPolygonDf.withColumn("window_extra", one())
+   val points = buildPointDf.withColumn("object_extra", two())
+
+   // The spatial predicate combined with an extra comparison, in both operand
+   // orders and with both comparison directions.
+   val joinConditions = Seq(
+     "ST_Contains(polygonshape, pointshape) AND window_extra <= object_extra",
+     "ST_Contains(polygonshape, pointshape) AND window_extra > object_extra",
+     "window_extra <= object_extra AND ST_Contains(polygonshape, pointshape)",
+     "window_extra > object_extra AND ST_Contains(polygonshape, pointshape)"
+   )
+
+   // Outer generator: limit; inner generator: condition (same order as writing them out by hand).
+   for (limit <- Seq(500, 900, 1000); condition <- joinConditions) {
+     val joined = points
+       .alias("pointDf")
+       .join(broadcast(polygons.limit(limit).alias("polygonDf")), expr(condition), "left_outer")
+
+     val indexJoins = joined.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }
+     assert(indexJoins.size === 1)
+     // Every one of the 1000 left rows is expected exactly once, whatever the extra condition.
+     assert(joined.count() == 1000)
+   }
+ }
+
+ it("Passed Handles multiple extra conditions on a broadcast join with the ST predicate last") {
+   val polygons = buildPolygonDf
+     .withColumn("window_extra", one())
+     .withColumn("window_extra2", one())
+   val points = buildPointDf
+     .withColumn("object_extra", two())
+     .withColumn("object_extra2", two())
+
+   // Two extra comparisons placed before the spatial predicate, once with `<=` and once with `>`.
+   Seq(
+     "window_extra <= object_extra AND window_extra2 <= object_extra2 AND ST_Contains(polygonshape, pointshape)",
+     "window_extra > object_extra AND window_extra2 > object_extra2 AND ST_Contains(polygonshape, pointshape)"
+   ).foreach { condition =>
+     val joined = points
+       .alias("pointDf")
+       .join(broadcast(polygons.alias("polygonDf")), expr(condition), "left_outer")
+
+     val indexJoins = joined.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }
+     assert(indexJoins.size === 1)
+     assert(joined.count() == 1000)
+   }
+ }
+
+ it("Passed ST_Distance <= distance in a broadcast join") {
+   // Count the physical join operators of each kind in a query's Spark plan.
+   def indexJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size
+   def nestedLoopJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size
+
+   val leftPoints = buildPointDf
+   val rightPoints = buildPointDf
+   val withinDistance = expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2")
+
+   // Hint on the right side: index join expected.
+   val hintRight = leftPoints.alias("pointDf1")
+     .join(broadcast(rightPoints).alias("pointDf2"), withinDistance, "left_outer")
+   assert(indexJoinCount(hintRight) === 1)
+   assert(hintRight.count() == 2998)
+
+   // Hint on the left side: nested loop join expected.
+   val hintLeft = broadcast(leftPoints).alias("pointDf1")
+     .join(rightPoints.alias("pointDf2"), withinDistance, "left_outer")
+   assert(nestedLoopJoinCount(hintLeft) === 1)
+   assert(hintLeft.count() == 2998)
+
+   // Same two shapes with the right side limited to 500 rows.
+   val hintRightLimited = leftPoints.alias("pointDf1")
+     .join(broadcast(rightPoints.limit(500)).alias("pointDf2"), withinDistance, "left_outer")
+   assert(indexJoinCount(hintRightLimited) === 1)
+   assert(hintRightLimited.count() == 1998)
+
+   val hintLeftLimited = broadcast(leftPoints).alias("pointDf1")
+     .join(rightPoints.limit(500).alias("pointDf2"), withinDistance, "left_outer")
+   assert(nestedLoopJoinCount(hintLeftLimited) === 1)
+   assert(hintLeftLimited.count() == 1998)
+ }
+
+ it("Passed ST_Distance < distance in a broadcast join") {
+   // Count the physical join operators of each kind in a query's Spark plan.
+   def indexJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size
+   def nestedLoopJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size
+
+   val leftPoints = buildPointDf
+   val rightPoints = buildPointDf
+   val closerThan = expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2")
+
+   // Hint on the right side: index join expected.
+   val hintRight = leftPoints.alias("pointDf1")
+     .join(broadcast(rightPoints).alias("pointDf2"), closerThan, "left_outer")
+   assert(indexJoinCount(hintRight) === 1)
+   assert(hintRight.count() == 2998)
+
+   // Hint on the left side: nested loop join expected.
+   val hintLeft = broadcast(leftPoints).alias("pointDf1")
+     .join(rightPoints.alias("pointDf2"), closerThan, "left_outer")
+   assert(nestedLoopJoinCount(hintLeft) === 1)
+   assert(hintLeft.count() == 2998)
+
+   // Same two shapes with the right side limited to 500 rows.
+   val hintRightLimited = leftPoints.alias("pointDf1")
+     .join(broadcast(rightPoints.limit(500)).alias("pointDf2"), closerThan, "left_outer")
+   assert(indexJoinCount(hintRightLimited) === 1)
+   assert(hintRightLimited.count() == 1998)
+
+   val hintLeftLimited = broadcast(leftPoints).alias("pointDf1")
+     .join(rightPoints.limit(500).alias("pointDf2"), closerThan, "left_outer")
+   assert(nestedLoopJoinCount(hintLeftLimited) === 1)
+   assert(hintLeftLimited.count() == 1998)
+ }
+
+ it("Passed ST_Distance distance is bound to first expression") {
+   // Count the physical join operators of each kind in a query's Spark plan.
+   def indexJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size
+   def nestedLoopJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size
+
+   // Only the first table carries the `radius` column used as the distance bound.
+   val pointsWithRadius = buildPointDf.withColumn("radius", two())
+   val plainPoints = buildPointDf
+   val withinRadius = expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius")
+
+   // Radius table on the left, hint on the right.
+   val radiusLeftHintRight = pointsWithRadius.alias("pointDf1")
+     .join(broadcast(plainPoints).alias("pointDf2"), withinRadius, "left_outer")
+   assert(indexJoinCount(radiusLeftHintRight) === 1)
+   assert(radiusLeftHintRight.count() == 2998)
+
+   // Radius table on the left and carrying the hint.
+   val radiusLeftHintLeft = broadcast(pointsWithRadius).alias("pointDf1")
+     .join(plainPoints.alias("pointDf2"), withinRadius, "left_outer")
+   assert(nestedLoopJoinCount(radiusLeftHintLeft) === 1)
+   assert(radiusLeftHintLeft.count() == 2998)
+
+   // Radius table on the right and carrying the hint.
+   val radiusRightHintRight = plainPoints.alias("pointDf2")
+     .join(broadcast(pointsWithRadius).alias("pointDf1"), withinRadius, "left_outer")
+   assert(indexJoinCount(radiusRightHintRight) === 1)
+   assert(radiusRightHintRight.count() == 2998)
+
+   // Radius table on the right, hint on the left.
+   val radiusRightHintLeft = broadcast(plainPoints).alias("pointDf2")
+     .join(pointsWithRadius.alias("pointDf1"), withinRadius, "left_outer")
+   assert(nestedLoopJoinCount(radiusRightHintLeft) === 1)
+   assert(radiusRightHintLeft.count() == 2998)
+ }
+
+ it("Passed Correct partitioning for broadcast join for ST_Polygon and ST_Point with AQE enabled") {
+   // Count the physical join operators of each kind in a query's Spark plan.
+   def indexJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size
+   def nestedLoopJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size
+
+   sparkSession.conf.set("spark.sql.adaptive.enabled", true)
+   // The setting lives on the shared session, so it is reset in `finally`:
+   // a failing assertion must not leave AQE switched on for the tests that follow.
+   try {
+     val polygonDf = buildPolygonDf.repartition(3)
+     val pointDf = buildPointDf.repartition(5)
+     val contains = expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)")
+
+     // Hint on the right (polygons): index join, partitioned like the points input.
+     var broadcastJoinDf = pointDf.alias("pointDf")
+       .join(broadcast(polygonDf).alias("polygonDf"), contains, "left_outer")
+     assert(indexJoinCount(broadcastJoinDf) === 1)
+     assert(broadcastJoinDf.rdd.getNumPartitions == pointDf.rdd.getNumPartitions)
+     assert(broadcastJoinDf.count() == 1000)
+
+     // Hint on the left (polygons): nested loop join expected.
+     broadcastJoinDf = broadcast(polygonDf).alias("polygonDf")
+       .join(pointDf.alias("pointDf"), contains, "left_outer")
+     assert(nestedLoopJoinCount(broadcastJoinDf) === 1)
+     assert(broadcastJoinDf.count() == 1000)
+
+     // Hint on the left (points): nested loop join expected.
+     broadcastJoinDf = broadcast(pointDf).alias("pointDf")
+       .join(polygonDf.alias("polygonDf"), contains, "left_outer")
+     assert(nestedLoopJoinCount(broadcastJoinDf) === 1)
+     assert(broadcastJoinDf.count() == 1000)
+
+     // Hint on the right (points): index join, partitioned like the polygons input.
+     broadcastJoinDf = polygonDf.alias("polygonDf")
+       .join(broadcast(pointDf).alias("pointDf"), contains, "left_outer")
+     assert(indexJoinCount(broadcastJoinDf) === 1)
+     assert(broadcastJoinDf.rdd.getNumPartitions == polygonDf.rdd.getNumPartitions)
+     assert(broadcastJoinDf.count() == 1000)
+   } finally {
+     sparkSession.conf.set("spark.sql.adaptive.enabled", false)
+   }
+ }
+
+ it("Passed validate output rows") {
+   // Two tiny point tables whose only common coordinate is (2.0, 2.0).
+   val leftPoints = sparkSession
+     .createDataFrame(Seq((1.0, 1.0, "left_1"), (2.0, 2.0, "left_2")))
+     .toDF("l_x", "l_y", "l_data")
+   val rightPoints = sparkSession
+     .createDataFrame(Seq((2.0, 2.0, "right_2"), (3.0, 3.0, "right_3")))
+     .toDF("r_x", "r_y", "r_data")
+
+   val intersects = expr("ST_Intersects(ST_Point(l_x, l_y), ST_Point(r_x, r_y))")
+   val outerJoined = leftPoints.join(broadcast(rightPoints), intersects, "left_outer")
+
+   // Exactly one BroadcastIndexJoinExec is expected in the physical plan.
+   val indexJoins = outerJoined.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }
+   assert(indexJoins.size === 1)
+
+   // Both left rows are expected; the unmatched one has nulls in the right-hand columns.
+   val result = outerJoined.collect()
+   assert(result.length == 2)
+   assert(result(0) == Row(1.0, 1.0, "left_1", null, null, null))
+   assert(result(1) == Row(2.0, 2.0, "left_2", 2.0, 2.0, "right_2"))
+ }
+ }
+
+ describe("Sedona-SQL Broadcast Index Join Test for right outer joins") {
+
+ // `one` and `two` are zero-argument UDFs, marked non-deterministic, that
+ // return the constants 1 and 2. Using UDFs rather than lit prevents
+ // optimizations that would circumvent the checks we want to test.
+ val one = udf(() => 1).asNondeterministic()
+ val two = udf(() => 2).asNondeterministic()
+
+ it("Passed Correct partitioning for broadcast join for ST_Polygon and ST_Point") {
+   // Count the physical join operators of each kind in a query's Spark plan.
+   def indexJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size
+   def nestedLoopJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size
+
+   val polygons = buildPolygonDf.repartition(3)
+   val points = buildPointDf.repartition(5)
+   val contains = expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)")
+
+   for (limit <- Seq(500, 900, 1000)) {
+     // Hint on the right (limited polygons): nested loop join expected.
+     val pointsLeftHintRight = points.alias("pointDf")
+       .join(broadcast(polygons.limit(limit)).alias("polygonDf"), contains, "right_outer")
+     assert(nestedLoopJoinCount(pointsLeftHintRight) === 1)
+     assert(pointsLeftHintRight.count() == limit)
+
+     // Hint on the left (polygons): index join expected.
+     val polygonsLeftHintLeft = broadcast(polygons).alias("polygonDf")
+       .join(points.limit(limit).alias("pointDf"), contains, "right_outer")
+     assert(indexJoinCount(polygonsLeftHintLeft) === 1)
+     assert(polygonsLeftHintLeft.count() == limit)
+
+     // Hint on the left (points): index join expected.
+     val pointsLeftHintLeft = broadcast(points).alias("pointDf")
+       .join(polygons.limit(limit).alias("polygonDf"), contains, "right_outer")
+     assert(indexJoinCount(pointsLeftHintLeft) === 1)
+     assert(pointsLeftHintLeft.count() == limit)
+
+     // Hint on the right (limited points): nested loop join expected.
+     val polygonsLeftHintRight = polygons.alias("polygonDf")
+       .join(broadcast(points.limit(limit)).alias("pointDf"), contains, "right_outer")
+     assert(nestedLoopJoinCount(polygonsLeftHintRight) === 1)
+     assert(polygonsLeftHintRight.count() == limit)
+   }
+ }
+
+ it("Passed Broadcasts the left side if both sides have a broadcast hint") {
+   val polygons = buildPolygonDf.repartition(3)
+   val points = buildPointDf.repartition(5)
+
+   // Both inputs carry a broadcast hint.
+   val hintedPoints = broadcast(points).alias("pointDf")
+   val hintedPolygons = broadcast(polygons).alias("polygonDf")
+   val contains = expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)")
+   val joined = hintedPoints.join(hintedPolygons, contains, "right_outer")
+
+   val indexJoins = joined.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }
+   assert(indexJoins.size === 1)
+   // The result must be partitioned like the right (polygons) input.
+   assert(joined.rdd.getNumPartitions == polygons.rdd.getNumPartitions)
+   assert(joined.count() == 1000)
+ }
+
+ it("Passed Can access attributes of both sides of broadcast join") {
+   // Count the physical join operators of each kind in a query's Spark plan.
+   def indexJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size
+   def nestedLoopJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size
+   // Sum of one column over the whole join result (first cell of the single aggregate row).
+   def columnSum(df: org.apache.spark.sql.DataFrame, column: String): Any =
+     df.select(sum(column)).collect().head(0)
+
+   val polygons = buildPolygonDf.withColumn("window_extra", one())
+   val points = buildPointDf.withColumn("object_extra", one())
+   val contains = expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)")
+
+   // Polygons on the left, hint on the right (points): nested loop join expected.
+   val polygonsHintRight = polygons.alias("polygonDf")
+     .join(broadcast(points).alias("pointDf"), contains, "right_outer")
+   assert(nestedLoopJoinCount(polygonsHintRight) === 1)
+   assert(columnSum(polygonsHintRight, "object_extra") == 1000)
+   assert(columnSum(polygonsHintRight, "window_extra") == 1000)
+
+   // Polygons on the left carrying the hint: index join expected.
+   val polygonsHintLeft = broadcast(polygons).alias("polygonDf")
+     .join(points.alias("pointDf"), contains, "right_outer")
+   assert(indexJoinCount(polygonsHintLeft) === 1)
+   assert(columnSum(polygonsHintLeft, "object_extra") == 1000)
+   assert(columnSum(polygonsHintLeft, "window_extra") == 1000)
+
+   // Points on the left carrying the hint: index join expected.
+   val pointsHintLeft = broadcast(points).alias("pointDf")
+     .join(polygons.alias("polygonDf"), contains, "right_outer")
+   assert(indexJoinCount(pointsHintLeft) === 1)
+   assert(columnSum(pointsHintLeft, "object_extra") == 1000)
+   assert(columnSum(pointsHintLeft, "window_extra") == 1000)
+
+   // Points on the left, hint on the right (polygons): nested loop join expected.
+   val pointsHintRight = points.alias("pointDf")
+     .join(broadcast(polygons).alias("polygonDf"), contains, "right_outer")
+   assert(nestedLoopJoinCount(pointsHintRight) === 1)
+   assert(columnSum(pointsHintRight, "object_extra") == 1000)
+   assert(columnSum(pointsHintRight, "window_extra") == 1000)
+ }
+
+ it("Passed Handles extra conditions on a broadcast join") {
+   // Count the physical join operators of each kind in a query's Spark plan.
+   def indexJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size
+   def nestedLoopJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size
+
+   val polygons = buildPolygonDf.withColumn("window_extra", one())
+   val points = buildPointDf.withColumn("object_extra", two())
+
+   // Hint on the right (polygons), spatial predicate first: nested loop join expected.
+   val hintRightSpatialFirst = points
+     .alias("pointDf")
+     .join(
+       broadcast(polygons.alias("polygonDf")),
+       expr("ST_Contains(polygonshape, pointshape) AND window_extra <= object_extra"),
+       "right_outer")
+   assert(nestedLoopJoinCount(hintRightSpatialFirst) === 1)
+   assert(hintRightSpatialFirst.count() == 1000)
+
+   // Hint on the left (points), spatial predicate first: index join expected.
+   val hintLeftSpatialFirst = broadcast(points)
+     .alias("pointDf")
+     .join(
+       polygons.alias("polygonDf"),
+       expr("ST_Contains(polygonshape, pointshape) AND window_extra > object_extra"),
+       "right_outer")
+   assert(indexJoinCount(hintLeftSpatialFirst) === 1)
+   assert(hintLeftSpatialFirst.count() == 1000)
+
+   // Hint on the left (points), spatial predicate last: index join expected.
+   val hintLeftSpatialLast = broadcast(points)
+     .alias("pointDf")
+     .join(
+       polygons.alias("polygonDf"),
+       expr("window_extra <= object_extra AND ST_Contains(polygonshape, pointshape)"),
+       "right_outer")
+   assert(indexJoinCount(hintLeftSpatialLast) === 1)
+   assert(hintLeftSpatialLast.count() == 1000)
+
+   // Hint on the right (polygons), spatial predicate last: nested loop join expected.
+   val hintRightSpatialLast = points
+     .alias("pointDf")
+     .join(
+       broadcast(polygons.alias("polygonDf")),
+       expr("window_extra > object_extra AND ST_Contains(polygonshape, pointshape)"),
+       "right_outer")
+   assert(nestedLoopJoinCount(hintRightSpatialLast) === 1)
+   assert(hintRightSpatialLast.count() == 1000)
+ }
+
+ it("Passed Handles multiple extra conditions on a broadcast join with the ST predicate last") {
+   val polygons = buildPolygonDf
+     .withColumn("window_extra", one())
+     .withColumn("window_extra2", one())
+   val points = buildPointDf
+     .withColumn("object_extra", two())
+     .withColumn("object_extra2", two())
+
+   // Two extra comparisons placed before the spatial predicate, once with `<=` and once with `>`.
+   Seq(
+     "window_extra <= object_extra AND window_extra2 <= object_extra2 AND ST_Contains(polygonshape, pointshape)",
+     "window_extra > object_extra AND window_extra2 > object_extra2 AND ST_Contains(polygonshape, pointshape)"
+   ).foreach { condition =>
+     val joined = broadcast(points)
+       .alias("pointDf")
+       .join(polygons.alias("polygonDf"), expr(condition), "right_outer")
+
+     val indexJoins = joined.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }
+     assert(indexJoins.size === 1)
+     assert(joined.count() == 1000)
+   }
+ }
+
+ it("Passed ST_Distance <= distance in a broadcast join") {
+   // Count the physical join operators of each kind in a query's Spark plan.
+   def indexJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size
+   def nestedLoopJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size
+
+   val leftPoints = buildPointDf
+   val rightPoints = buildPointDf
+   val withinDistance = expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2")
+
+   // Hint on the right side: nested loop join expected.
+   val hintRight = leftPoints.alias("pointDf1")
+     .join(broadcast(rightPoints).alias("pointDf2"), withinDistance, "right_outer")
+   assert(nestedLoopJoinCount(hintRight) === 1)
+   assert(hintRight.count() == 2998)
+
+   // Hint on the left side: index join expected.
+   val hintLeft = broadcast(leftPoints).alias("pointDf1")
+     .join(rightPoints.alias("pointDf2"), withinDistance, "right_outer")
+   assert(indexJoinCount(hintLeft) === 1)
+   assert(hintLeft.count() == 2998)
+
+   // Same two shapes with the right side limited to 500 rows.
+   val hintRightLimited = leftPoints.alias("pointDf1")
+     .join(broadcast(rightPoints.limit(500)).alias("pointDf2"), withinDistance, "right_outer")
+   assert(nestedLoopJoinCount(hintRightLimited) === 1)
+   assert(hintRightLimited.count() == 1499)
+
+   val hintLeftLimited = broadcast(leftPoints).alias("pointDf1")
+     .join(rightPoints.limit(500).alias("pointDf2"), withinDistance, "right_outer")
+   assert(indexJoinCount(hintLeftLimited) === 1)
+   assert(hintLeftLimited.count() == 1499)
+ }
+
+ it("Passed ST_Distance < distance in a broadcast join") {
+   // Count the physical join operators of each kind in a query's Spark plan.
+   def indexJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size
+   def nestedLoopJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size
+
+   val leftPoints = buildPointDf
+   val rightPoints = buildPointDf
+   val closerThan = expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2")
+
+   // Hint on the right side: nested loop join expected.
+   val hintRight = leftPoints.alias("pointDf1")
+     .join(broadcast(rightPoints).alias("pointDf2"), closerThan, "right_outer")
+   assert(nestedLoopJoinCount(hintRight) === 1)
+   assert(hintRight.count() == 2998)
+
+   // Hint on the left side: index join expected.
+   val hintLeft = broadcast(leftPoints).alias("pointDf1")
+     .join(rightPoints.alias("pointDf2"), closerThan, "right_outer")
+   assert(indexJoinCount(hintLeft) === 1)
+   assert(hintLeft.count() == 2998)
+
+   // Same two shapes with the right side limited to 500 rows.
+   val hintRightLimited = leftPoints.alias("pointDf1")
+     .join(broadcast(rightPoints.limit(500)).alias("pointDf2"), closerThan, "right_outer")
+   assert(nestedLoopJoinCount(hintRightLimited) === 1)
+   assert(hintRightLimited.count() == 1499)
+
+   val hintLeftLimited = broadcast(leftPoints).alias("pointDf1")
+     .join(rightPoints.limit(500).alias("pointDf2"), closerThan, "right_outer")
+   assert(indexJoinCount(hintLeftLimited) === 1)
+   assert(hintLeftLimited.count() == 1499)
+ }
+
+ it("Passed ST_Distance distance is bound to first expression") {
+   // Count the physical join operators of each kind in a query's Spark plan.
+   def indexJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size
+   def nestedLoopJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size
+
+   // Only the first table carries the `radius` column used as the distance bound.
+   val pointsWithRadius = buildPointDf.withColumn("radius", two())
+   val plainPoints = buildPointDf
+   val withinRadius = expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius")
+
+   // Radius table on the left, hint on the right.
+   val radiusLeftHintRight = pointsWithRadius.alias("pointDf1")
+     .join(broadcast(plainPoints).alias("pointDf2"), withinRadius, "right_outer")
+   assert(nestedLoopJoinCount(radiusLeftHintRight) === 1)
+   assert(radiusLeftHintRight.count() == 2998)
+
+   // Radius table on the left and carrying the hint.
+   val radiusLeftHintLeft = broadcast(pointsWithRadius).alias("pointDf1")
+     .join(plainPoints.alias("pointDf2"), withinRadius, "right_outer")
+   assert(indexJoinCount(radiusLeftHintLeft) === 1)
+   assert(radiusLeftHintLeft.count() == 2998)
+
+   // Radius table on the right and carrying the hint.
+   val radiusRightHintRight = plainPoints.alias("pointDf2")
+     .join(broadcast(pointsWithRadius).alias("pointDf1"), withinRadius, "right_outer")
+   assert(nestedLoopJoinCount(radiusRightHintRight) === 1)
+   assert(radiusRightHintRight.count() == 2998)
+
+   // Radius table on the right, hint on the left.
+   val radiusRightHintLeft = broadcast(plainPoints).alias("pointDf2")
+     .join(pointsWithRadius.alias("pointDf1"), withinRadius, "right_outer")
+   assert(indexJoinCount(radiusRightHintLeft) === 1)
+   assert(radiusRightHintLeft.count() == 2998)
+ }
+
+ it("Passed Correct partitioning for broadcast join for ST_Polygon and ST_Point with AQE enabled") {
+   // Count the physical join operators of each kind in a query's Spark plan.
+   def indexJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size
+   def nestedLoopJoinCount(df: org.apache.spark.sql.DataFrame): Int =
+     df.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size
+
+   sparkSession.conf.set("spark.sql.adaptive.enabled", true)
+   // The setting lives on the shared session, so it is reset in `finally`:
+   // a failing assertion must not leave AQE switched on for the tests that follow.
+   try {
+     val polygonDf = buildPolygonDf.repartition(3)
+     val pointDf = buildPointDf.repartition(5)
+     val contains = expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)")
+
+     // Hint on the right (polygons): nested loop join expected.
+     var broadcastJoinDf = pointDf.alias("pointDf")
+       .join(broadcast(polygonDf).alias("polygonDf"), contains, "right_outer")
+     assert(nestedLoopJoinCount(broadcastJoinDf) === 1)
+     assert(broadcastJoinDf.count() == 1000)
+
+     // Hint on the left (polygons): index join, partitioned like the points input.
+     broadcastJoinDf = broadcast(polygonDf).alias("polygonDf")
+       .join(pointDf.alias("pointDf"), contains, "right_outer")
+     assert(indexJoinCount(broadcastJoinDf) === 1)
+     assert(broadcastJoinDf.rdd.getNumPartitions == pointDf.rdd.getNumPartitions)
+     assert(broadcastJoinDf.count() == 1000)
+
+     // Hint on the left (points): index join, partitioned like the polygons input.
+     broadcastJoinDf = broadcast(pointDf).alias("pointDf")
+       .join(polygonDf.alias("polygonDf"), contains, "right_outer")
+     assert(indexJoinCount(broadcastJoinDf) === 1)
+     assert(broadcastJoinDf.rdd.getNumPartitions == polygonDf.rdd.getNumPartitions)
+     assert(broadcastJoinDf.count() == 1000)
+
+     // Hint on the right (points): nested loop join expected.
+     broadcastJoinDf = polygonDf.alias("polygonDf")
+       .join(broadcast(pointDf).alias("pointDf"), contains, "right_outer")
+     assert(nestedLoopJoinCount(broadcastJoinDf) === 1)
+     assert(broadcastJoinDf.count() == 1000)
+   } finally {
+     sparkSession.conf.set("spark.sql.adaptive.enabled", false)
+   }
+ }
+
+ it("Passed validate output rows") {
+   // Two tiny point tables whose only common coordinate is (2.0, 2.0).
+   val leftPoints = sparkSession
+     .createDataFrame(Seq((1.0, 1.0, "left_1"), (2.0, 2.0, "left_2")))
+     .toDF("l_x", "l_y", "l_data")
+   val rightPoints = sparkSession
+     .createDataFrame(Seq((2.0, 2.0, "right_2"), (3.0, 3.0, "right_3")))
+     .toDF("r_x", "r_y", "r_data")
+
+   // Here the broadcast hint is placed on the left input.
+   val intersects = expr("ST_Intersects(ST_Point(l_x, l_y), ST_Point(r_x, r_y))")
+   val outerJoined = broadcast(leftPoints).join(rightPoints, intersects, "right_outer")
+
+   // Exactly one BroadcastIndexJoinExec is expected in the physical plan.
+   val indexJoins = outerJoined.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }
+   assert(indexJoins.size === 1)
+
+   // Both right rows are expected; the unmatched one has nulls in the left-hand columns.
+   val result = outerJoined.collect()
+   assert(result.length == 2)
+   assert(result(0) == Row(2.0, 2.0, "left_2", 2.0, 2.0, "right_2"))
+   assert(result(1) == Row(null, null, null, 3.0, 3.0, "right_3"))
+ }
}
+
}