This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 8489b2d6 [SEDONA-191] Support join types other than inner in 
BroadcastIndexJoinExec (#711)
8489b2d6 is described below

commit 8489b2d604f8b5039707d805386292e0a3b9c245
Author: Tanel Kiis <[email protected]>
AuthorDate: Sat Nov 26 09:43:29 2022 +0200

    [SEDONA-191] Support join types other than inner in BroadcastIndexJoinExec 
(#711)
---
 docs/api/sql/Optimizer.md                          |   12 +-
 .../strategy/join/BroadcastIndexJoinExec.scala     |  222 +++-
 .../strategy/join/JoinQueryDetector.scala          |  102 +-
 .../sedona/sql/BroadcastIndexJoinSuite.scala       | 1223 +++++++++++++++++++-
 4 files changed, 1436 insertions(+), 123 deletions(-)

diff --git a/docs/api/sql/Optimizer.md b/docs/api/sql/Optimizer.md
index eab7dbff..35f05878 100644
--- a/docs/api/sql/Optimizer.md
+++ b/docs/api/sql/Optimizer.md
@@ -73,7 +73,15 @@ DistanceJoin pointshape1#12: geometry, pointshape2#33: 
geometry, 2.0, true
        Sedona doesn't control the distance's unit (degree or meter). It is 
same with the geometry. To change the geometry's unit, please transform the 
coordinate reference system. See [ST_Transform](Function.md#st_transform).
 
 ## Broadcast join
-Introduction: Perform a range join or distance join but broadcast one of the 
sides of the join. This maintains the partitioning of the non-broadcast side 
and doesn't require a shuffle.
+Introduction: Perform a range join or distance join but broadcast one of the 
sides of the join.
+This maintains the partitioning of the non-broadcast side and doesn't require 
a shuffle.
+Sedona uses a broadcast join only if the side to be broadcast for the given join type has a broadcast hint. The supported combinations of join type and broadcast side are:
+
+* Inner - either side, preferring to broadcast left if both sides have the hint
+* Left semi - broadcast right
+* Left anti - broadcast right
+* Left outer - broadcast right
+* Right outer - broadcast left
 
 ```Scala
 pointDf.alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
@@ -107,7 +115,7 @@ BroadcastIndexJoin pointshape#52: geometry, BuildRight, 
BuildLeft, true, 2.0 ST_
       +- FileScan csv
 ```
 
-Note: Ff the distance is an expression, it is only evaluated on the first 
argument to ST_Distance (`pointDf1` above).
+Note: If the distance is an expression, it is only evaluated on the first 
argument to ST_Distance (`pointDf1` above).
 
 ## Predicate pushdown
 
diff --git 
a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
 
b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
index 0d089873..34ec41f1 100644
--- 
a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
+++ 
b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
@@ -18,19 +18,18 @@
  */
 package org.apache.spark.sql.sedona_sql.strategy.join
 
-import org.apache.sedona.core.spatialOperator.SpatialPredicate
-import org.apache.sedona.core.spatialOperator.SpatialPredicateEvaluators
+import org.apache.sedona.core.spatialOperator.{SpatialPredicate, 
SpatialPredicateEvaluators}
+import 
org.apache.sedona.core.spatialOperator.SpatialPredicateEvaluators.SpatialPredicateEvaluator
 
-import collection.JavaConverters._
-import org.apache.sedona.core.spatialRDD.SpatialRDD
+import scala.collection.JavaConverters._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, 
Expression, Predicate, UnsafeRow}
-import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, 
Expression, GenericInternalRow, JoinedRow, Predicate, UnsafeProjection, 
UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
 import org.apache.spark.sql.sedona_sql.execution.SedonaBinaryExecNode
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.geom.prep.{PreparedGeometry, 
PreparedGeometryFactory}
@@ -38,32 +37,59 @@ import org.locationtech.jts.index.SpatialIndex
 
 import scala.collection.mutable
 
-case class BroadcastIndexJoinExec(left: SparkPlan,
-                                  right: SparkPlan,
-                                  streamShape: Expression,
-                                  indexBuildSide: JoinSide,
-                                  windowJoinSide: JoinSide,
-                                  spatialPredicate: SpatialPredicate,
-                                  extraCondition: Option[Expression] = None,
-                                  distance: Option[Expression] = None)
+case class BroadcastIndexJoinExec(
+  left: SparkPlan,
+  right: SparkPlan,
+  streamShape: Expression,
+  indexBuildSide: JoinSide,
+  windowJoinSide: JoinSide,
+  joinType: JoinType,
+  spatialPredicate: SpatialPredicate,
+  extraCondition: Option[Expression] = None,
+  distance: Option[Expression] = None)
   extends SedonaBinaryExecNode
     with TraitJoinQueryBase
     with Logging {
 
-  override def output: Seq[Attribute] = left.output ++ right.output
+  override def output: Seq[Attribute] = {
+    joinType match {
+      case _: InnerLike =>
+        left.output ++ right.output
+      case LeftOuter =>
+        left.output ++ right.output.map(_.withNullability(true))
+      case RightOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output
+      case j: ExistenceJoin =>
+        left.output :+ j.exists
+      case LeftExistence(_) =>
+        left.output
+      case x =>
+        throw new IllegalArgumentException(s"BroadcastIndexJoinExec should not 
take $x as the JoinType")
+    }
+  }
+
+  private val (streamed, broadcast) = indexBuildSide match {
+    case LeftSide => (right, left.asInstanceOf[SpatialIndexExec])
+    case RightSide => (left, right.asInstanceOf[SpatialIndexExec])
+  }
 
   // Using lazy val to avoid serialization
   @transient private lazy val boundCondition: (InternalRow => Boolean) = 
extraCondition match {
     case Some(condition) =>
-      Predicate.create(condition, output).eval _ // SPARK3 anchor
-//      newPredicate(condition, output).eval _ // SPARK2 anchor
+      Predicate.create(condition, streamed.output ++ broadcast.output).eval _ 
// SPARK3 anchor
+    //      newPredicate(condition, streamed.output ++ broadcast.output).eval 
_ // SPARK2 anchor
     case None =>
       (r: InternalRow) => true
   }
 
-  private val (streamed, broadcast) = indexBuildSide match {
-    case LeftSide => (right, left.asInstanceOf[SpatialIndexExec])
-    case RightSide => (left, right.asInstanceOf[SpatialIndexExec])
+  protected def createResultProjection(): InternalRow => InternalRow = 
joinType match {
+    case LeftExistence(_) =>
+      UnsafeProjection.create(output, output)
+    case _ =>
+      // Always put the stream side on left to simplify implementation
+      // Either side may be null-padded by an outer join, so mark all inputs nullable
+      UnsafeProjection.create(
+        output, (streamed.output ++ 
broadcast.output).map(_.withNullability(true)))
   }
 
   override def outputPartitioning: Partitioning = streamed.outputPartitioning
@@ -83,38 +109,120 @@ case class BroadcastIndexJoinExec(left: SparkPlan,
   override def simpleString(maxFields: Int): String = 
super.simpleString(maxFields) + s" $spatialExpression" // SPARK3 anchor
 //  override def simpleString: String = super.simpleString + s" 
$spatialExpression" // SPARK2 anchor
 
-  private def windowBroadcastJoin(index: Broadcast[SpatialIndex], spatialRdd: 
SpatialRDD[Geometry]): RDD[(Geometry, Geometry)] = {
-    spatialRdd.getRawSpatialRDD.rdd.mapPartitions { rows =>
-      val factory = new PreparedGeometryFactory()
-      val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
-      val evaluator = SpatialPredicateEvaluators.create(spatialPredicate)
-      rows.flatMap { row =>
-        val candidates = 
index.value.query(row.getEnvelopeInternal).iterator.asScala.asInstanceOf[Iterator[Geometry]]
-        candidates
-          .filter(candidate => 
evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, { 
factory.create(candidate) }), row))
-          .map(candidate => (candidate, row))
+  private lazy val evaluator: SpatialPredicateEvaluator = if (indexBuildSide 
== windowJoinSide) {
+    SpatialPredicateEvaluators.create(spatialPredicate)
+  } else {
+    
SpatialPredicateEvaluators.create(SpatialPredicate.inverse(spatialPredicate))
+  }
+
+  private def innerJoin(streamIter: Iterator[Geometry], index: 
Broadcast[SpatialIndex]): Iterator[InternalRow] = {
+    val factory = new PreparedGeometryFactory()
+    val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
+    val joinedRow = new JoinedRow
+    streamIter.flatMap { srow =>
+      joinedRow.withLeft(srow.getUserData.asInstanceOf[UnsafeRow])
+      index.value.query(srow.getEnvelopeInternal)
+        .iterator.asScala.asInstanceOf[Iterator[Geometry]]
+        .filter(candidate => 
evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, { 
factory.create(candidate) }), srow))
+        .map(candidate => 
joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow]))
+        .filter(boundCondition)
+    }
+  }
+
+  private def semiJoin(
+    streamIter: Iterator[Geometry], index: Broadcast[SpatialIndex]
+  ): Iterator[InternalRow] = {
+    val factory = new PreparedGeometryFactory()
+    val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
+    val joinedRow = new JoinedRow
+    streamIter.flatMap { srow =>
+      val left = srow.getUserData.asInstanceOf[UnsafeRow]
+      joinedRow.withLeft(left)
+      val anyMatches = index.value.query(srow.getEnvelopeInternal)
+        .iterator.asScala.asInstanceOf[Iterator[Geometry]]
+        .filter(candidate => 
evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
+          factory.create(candidate)
+        }), srow))
+        .map(candidate => 
joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow]))
+        .exists(boundCondition)
+
+      if (anyMatches) {
+        Iterator.single(left)
+      } else {
+        Iterator.empty
       }
     }
   }
 
-  private def objectBroadcastJoin(index: Broadcast[SpatialIndex], spatialRdd: 
SpatialRDD[Geometry]): RDD[(Geometry, Geometry)] = {
-    spatialRdd.getRawSpatialRDD.rdd.mapPartitions { rows =>
-      val factory = new PreparedGeometryFactory()
-      val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
-      val evaluator = 
SpatialPredicateEvaluators.create(SpatialPredicate.inverse(spatialPredicate))
-      rows.flatMap { row =>
-        val candidates = 
index.value.query(row.getEnvelopeInternal).iterator.asScala.asInstanceOf[Iterator[Geometry]]
-        candidates
-          .filter(candidate => 
evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, { 
factory.create(candidate) }), row))
-          .map(candidate => (row, candidate))
+  private def antiJoin(
+    streamIter: Iterator[Geometry], index: Broadcast[SpatialIndex]
+  ): Iterator[InternalRow] = {
+    val factory = new PreparedGeometryFactory()
+    val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
+    val joinedRow = new JoinedRow
+    streamIter.flatMap { srow =>
+      val left = srow.getUserData.asInstanceOf[UnsafeRow]
+      joinedRow.withLeft(left)
+      val anyMatches = index.value.query(srow.getEnvelopeInternal)
+        .iterator.asScala.asInstanceOf[Iterator[Geometry]]
+        .filter(candidate => 
evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
+          factory.create(candidate)
+        }), srow))
+        .map(candidate => 
joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow]))
+        .exists(boundCondition)
+
+      if (anyMatches) {
+        Iterator.empty
+      } else {
+        Iterator.single(left)
       }
     }
   }
 
+  private def outerJoin(
+    streamIter: Iterator[Geometry], index: Broadcast[SpatialIndex]
+  ): Iterator[InternalRow] = {
+    val factory = new PreparedGeometryFactory()
+    val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
+    val joinedRow = new JoinedRow
+    val nullRow = new GenericInternalRow(broadcast.output.length)
+
+    streamIter.flatMap { srow =>
+      joinedRow.withLeft(srow.getUserData.asInstanceOf[UnsafeRow])
+      val candidates = index.value.query(srow.getEnvelopeInternal)
+        .iterator.asScala.asInstanceOf[Iterator[Geometry]]
+        .filter(candidate => 
evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
+          factory.create(candidate)
+        }), srow))
+
+      new RowIterator {
+        private var found = false
+        override def advanceNext(): Boolean = {
+          while (candidates.hasNext) {
+            val candidateRow = 
candidates.next().getUserData.asInstanceOf[UnsafeRow]
+            if (boundCondition(joinedRow.withRight(candidateRow))) {
+              found = true
+              return true
+            }
+          }
+          if (!found) {
+            joinedRow.withRight(nullRow)
+            found = true
+            return true
+          }
+          false
+        }
+        override def getRow: InternalRow = joinedRow
+      }.toScala
+    }
+  }
+
   override protected def doExecute(): RDD[InternalRow] = {
     val boundStreamShape = BindReferences.bindReference(streamShape, 
streamed.output)
     val streamResultsRaw = streamed.execute().asInstanceOf[RDD[UnsafeRow]]
 
+    val broadcastIndex = broadcast.executeBroadcast[SpatialIndex]()
+
     // If there's a distance and the objects are being broadcast, we need to 
build the expanded envelope on the window stream side
     val streamShapes = distance match {
       case Some(distanceExpression) if indexBuildSide != windowJoinSide =>
@@ -123,23 +231,25 @@ case class BroadcastIndexJoinExec(left: SparkPlan,
         toSpatialRDD(streamResultsRaw, boundStreamShape)
     }
 
-    val broadcastIndex = broadcast.executeBroadcast[SpatialIndex]()
+    streamShapes.getRawSpatialRDD.rdd.mapPartitions { streamedIter =>
+      val joinedIter = joinType match {
+        case _: InnerLike =>
+          innerJoin(streamedIter, broadcastIndex)
+        case LeftSemi =>
+          semiJoin(streamedIter, broadcastIndex)
+        case LeftAnti =>
+          antiJoin(streamedIter, broadcastIndex)
+        case LeftOuter | RightOuter =>
+          outerJoin(streamedIter, broadcastIndex)
+        case x =>
+          throw new IllegalArgumentException(s"BroadcastIndexJoinExec should 
not take $x as the JoinType")
 
-    val pairs = (indexBuildSide, windowJoinSide) match {
-      case (LeftSide, LeftSide) => windowBroadcastJoin(broadcastIndex, 
streamShapes)
-      case (LeftSide, RightSide) => objectBroadcastJoin(broadcastIndex, 
streamShapes).map { case (left, right) => (right, left) }
-      case (RightSide, LeftSide) => objectBroadcastJoin(broadcastIndex, 
streamShapes)
-      case (RightSide, RightSide) => windowBroadcastJoin(broadcastIndex, 
streamShapes).map { case (left, right) => (right, left) }
-    }
+      }
 
-    pairs.mapPartitions { iter =>
-      val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
-      iter.map {
-        case (l, r) =>
-          val leftRow = l.getUserData.asInstanceOf[UnsafeRow]
-          val rightRow = r.getUserData.asInstanceOf[UnsafeRow]
-          joiner.join(leftRow, rightRow)
-      }.filter(boundCondition(_))
+      val resultProj = createResultProjection()
+      joinedIter.map { r =>
+        resultProj(r)
+      }
     }
   }
 
diff --git 
a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
 
b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index f5e086c0..fbd1f958 100644
--- 
a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ 
b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -23,20 +23,20 @@ import 
org.apache.sedona.core.spatialOperator.SpatialPredicate
 import org.apache.sedona.core.utils.SedonaConf
 import org.apache.spark.sql.{SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan, 
LessThanOrEqual}
-import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner, 
InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, NaturalJoin, RightOuter, 
UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.sedona_sql.expressions._
 
 
 case class JoinQueryDetection(
-                               left: LogicalPlan,
-                               right: LogicalPlan,
-                               leftShape: Expression,
-                               rightShape: Expression,
-                               spatialPredicate: SpatialPredicate,
-                               extraCondition: Option[Expression] = None,
-                               distance: Option[Expression] = None
+  left: LogicalPlan,
+  right: LogicalPlan,
+  leftShape: Expression,
+  rightShape: Expression,
+  spatialPredicate: SpatialPredicate,
+  extraCondition: Option[Expression] = None,
+  distance: Option[Expression] = None
 )
 
 /**
@@ -78,7 +78,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends 
Strategy {
     }
 
   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-    case Join(left, right, Inner, condition, JoinHint(leftHint, rightHint)) => 
{ // SPARK3 anchor
+    case Join(left, right, joinType, condition, JoinHint(leftHint, rightHint)) 
=> { // SPARK3 anchor
 //    case Join(left, right, Inner, condition) => { // SPARK2 anchor
       val broadcastLeft = leftHint.exists(_.strategy.contains(BROADCAST)) // 
SPARK3 anchor
       val broadcastRight = rightHint.exists(_.strategy.contains(BROADCAST)) // 
SPARK3 anchor
@@ -115,16 +115,19 @@ class JoinQueryDetector(sparkSession: SparkSession) 
extends Strategy {
       if ((broadcastLeft || broadcastRight) && sedonaConf.getUseIndex) {
         queryDetection match {
           case Some(JoinQueryDetection(left, right, leftShape, rightShape, 
spatialPredicate, extraCondition, distance)) =>
-            planBroadcastJoin(left, right, Seq(leftShape, rightShape), 
spatialPredicate, sedonaConf.getIndexType, broadcastLeft, extraCondition, 
distance)
+            planBroadcastJoin(
+              left, right, Seq(leftShape, rightShape), joinType,
+              spatialPredicate, sedonaConf.getIndexType,
+              broadcastLeft, broadcastRight, extraCondition, distance)
           case _ =>
             Nil
         }
       } else {
         queryDetection match {
           case Some(JoinQueryDetection(left, right, leftShape, rightShape, 
spatialPredicate, extraCondition, None)) =>
-            planSpatialJoin(left, right, Seq(leftShape, rightShape), 
spatialPredicate, extraCondition)
+            planSpatialJoin(left, right, Seq(leftShape, rightShape), joinType, 
spatialPredicate, extraCondition)
           case Some(JoinQueryDetection(left, right, leftShape, rightShape, 
spatialPredicate, extraCondition, Some(distance))) =>
-            planDistanceJoin(left, right, Seq(leftShape, rightShape), 
distance, spatialPredicate, extraCondition)
+            planDistanceJoin(left, right, Seq(leftShape, rightShape), 
joinType, distance, spatialPredicate, extraCondition)
           case None => 
             Nil
         }
@@ -153,11 +156,18 @@ class JoinQueryDetector(sparkSession: SparkSession) 
extends Strategy {
       None
     }
 
-  private def planSpatialJoin(left: LogicalPlan,
-                              right: LogicalPlan,
-                              children: Seq[Expression],
-                              spatialPredicate: SpatialPredicate,
-                              extraCondition: Option[Expression] = None): 
Seq[SparkPlan] = {
+  private def planSpatialJoin(
+    left: LogicalPlan,
+    right: LogicalPlan,
+    children: Seq[Expression],
+    joinType: JoinType,
+    spatialPredicate: SpatialPredicate,
+    extraCondition: Option[Expression] = None): Seq[SparkPlan] = {
+
+    if (joinType != Inner) {
+      return Nil
+    }
+
     val a = children.head
     val b = children.tail.head
 
@@ -175,12 +185,19 @@ class JoinQueryDetector(sparkSession: SparkSession) 
extends Strategy {
     }
   }
 
-  private def planDistanceJoin(left: LogicalPlan,
-                               right: LogicalPlan,
-                               children: Seq[Expression],
-                               distance: Expression,
-                               spatialPredicate: SpatialPredicate,
-                               extraCondition: Option[Expression] = None): 
Seq[SparkPlan] = {
+  private def planDistanceJoin(
+    left: LogicalPlan,
+    right: LogicalPlan,
+    children: Seq[Expression],
+    joinType: JoinType,
+    distance: Expression,
+    spatialPredicate: SpatialPredicate,
+    extraCondition: Option[Expression] = None): Seq[SparkPlan] = {
+
+    if (joinType != Inner) {
+      return Nil
+    }
+
     val a = children.head
     val b = children.tail.head
 
@@ -206,14 +223,32 @@ class JoinQueryDetector(sparkSession: SparkSession) 
extends Strategy {
     }
   }
 
-  private def planBroadcastJoin(left: LogicalPlan,
-                                right: LogicalPlan,
-                                children: Seq[Expression],
-                                spatialPredicate: SpatialPredicate,
-                                indexType: IndexType,
-                                broadcastLeft: Boolean,
-                                extraCondition: Option[Expression],
-                                distance: Option[Expression]): Seq[SparkPlan] 
= {
+  private def planBroadcastJoin(
+    left: LogicalPlan,
+    right: LogicalPlan,
+    children: Seq[Expression],
+    joinType: JoinType,
+    spatialPredicate: SpatialPredicate,
+    indexType: IndexType,
+    broadcastLeft: Boolean,
+    broadcastRight: Boolean,
+    extraCondition: Option[Expression],
+    distance: Option[Expression]): Seq[SparkPlan] = {
+
+    val broadcastSide = joinType match {
+      case Inner if broadcastLeft => Some(LeftSide)
+      case Inner if broadcastRight => Some(RightSide)
+      case LeftSemi if broadcastRight => Some(RightSide)
+      case LeftAnti if broadcastRight => Some(RightSide)
+      case LeftOuter if broadcastRight => Some(RightSide)
+      case RightOuter if broadcastLeft => Some(LeftSide)
+      case _ => None
+    }
+
+    if (broadcastSide.isEmpty) {
+      return Nil
+    }
+
     val a = children.head
     val b = children.tail.head
 
@@ -226,8 +261,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends 
Strategy {
     matchExpressionsToPlans(a, b, left, right) match {
       case Some((_, _, swapped)) =>
         logInfo(s"Planning spatial join for $relationship relationship")
-        val broadcastSide = if (broadcastLeft) LeftSide else RightSide
-        val (leftPlan, rightPlan, streamShape, windowSide) = (broadcastSide, 
swapped) match {
+        val (leftPlan, rightPlan, streamShape, windowSide) = 
(broadcastSide.get, swapped) match {
           case (LeftSide, false) => // Broadcast the left side, windows on the 
left
             (SpatialIndexExec(planLater(left), a, indexType, distance), 
planLater(right), b, LeftSide)
           case (LeftSide, true) => // Broadcast the left side, objects on the 
left
@@ -237,7 +271,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends 
Strategy {
           case (RightSide, true) => // Broadcast the right side, objects on 
the left
             (planLater(left), SpatialIndexExec(planLater(right), a, indexType, 
distance), b, RightSide)
         }
-        BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, 
broadcastSide, windowSide, spatialPredicate, extraCondition, distance) :: Nil
+        BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, 
broadcastSide.get, windowSide, joinType, spatialPredicate, extraCondition, 
distance) :: Nil
       case None =>
         logInfo(
           s"Spatial join for $relationship with arguments not aligned " +
diff --git 
a/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala 
b/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
index 445af754..30a22f08 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
@@ -19,37 +19,43 @@
 
 package org.apache.sedona.sql
 
+import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
 import org.apache.spark.sql.sedona_sql.strategy.join.BroadcastIndexJoinExec
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.Row
 
 class BroadcastIndexJoinSuite extends TestBaseScala {
 
-  describe("Sedona-SQL Broadcast Index Join Test") {
+  describe("Sedona-SQL Broadcast Index Join Test for inner joins") {
 
     // Using UDFs rather than lit prevents optimizations that would circumvent 
the checks we want to test
-    val one = udf(() => 1)
-    val two = udf(() => 2)
+    val one = udf(() => 1).asNondeterministic()
+    val two = udf(() => 2).asNondeterministic()
 
     it("Passed Correct partitioning for broadcast join for ST_Polygon and 
ST_Point") {
       val polygonDf = buildPolygonDf.repartition(3)
       val pointDf = buildPointDf.repartition(5)
 
-      var broadcastJoinDf = 
pointDf.alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      var broadcastJoinDf = pointDf.alias("pointDf").join(
+        broadcast(polygonDf).alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
       assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(broadcastJoinDf.rdd.getNumPartitions == 
pointDf.rdd.getNumPartitions)
       assert(broadcastJoinDf.count() == 1000)
 
-      broadcastJoinDf = 
broadcast(polygonDf).alias("polygonDf").join(pointDf.alias("pointDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+        pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, 
pointDf.pointshape)"))
       assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(broadcastJoinDf.rdd.getNumPartitions == 
pointDf.rdd.getNumPartitions)
       assert(broadcastJoinDf.count() == 1000)
 
-      broadcastJoinDf = 
broadcast(pointDf).alias("pointDf").join(polygonDf.alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+        polygonDf.alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
       assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(broadcastJoinDf.rdd.getNumPartitions == 
polygonDf.rdd.getNumPartitions)
       assert(broadcastJoinDf.count() == 1000)
 
-      broadcastJoinDf = 
polygonDf.alias("polygonDf").join(broadcast(pointDf).alias("pointDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      broadcastJoinDf = polygonDf.alias("polygonDf").join(
+        broadcast(pointDf).alias("pointDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
       assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(broadcastJoinDf.rdd.getNumPartitions == 
polygonDf.rdd.getNumPartitions)
       assert(broadcastJoinDf.count() == 1000)
@@ -59,7 +65,8 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
       val polygonDf = buildPolygonDf.repartition(3)
       val pointDf = buildPointDf.repartition(5)
 
-      var broadcastJoinDf = 
broadcast(pointDf).alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"),
 expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      val broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+        broadcast(polygonDf).alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
       assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(broadcastJoinDf.rdd.getNumPartitions == 
polygonDf.rdd.getNumPartitions)
       assert(broadcastJoinDf.count() == 1000)
@@ -68,20 +75,28 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
     it("Passed Can access attributes of both sides of broadcast join") {
       val polygonDf = buildPolygonDf.withColumn("window_extra", one())
       val pointDf = buildPointDf.withColumn("object_extra", one())
-      
-      var broadcastJoinDf = 
polygonDf.alias("polygonDf").join(broadcast(pointDf).alias("pointDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+
+      var broadcastJoinDf = polygonDf.alias("polygonDf").join(
+        broadcast(pointDf).alias("pointDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 
1000)
       assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 
1000)
 
-      broadcastJoinDf = 
broadcast(polygonDf).alias("polygonDf").join(pointDf.alias("pointDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+        pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, 
pointDf.pointshape)"))
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 
1000)
       assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 
1000)
 
-      broadcastJoinDf = 
broadcast(pointDf).alias("pointDf").join(polygonDf.alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+        polygonDf.alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 
1000)
       assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 
1000)
 
-      broadcastJoinDf = 
pointDf.alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      broadcastJoinDf = pointDf.alias("pointDf").join(
+        broadcast(polygonDf).alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 
1000)
       assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 
1000)
     }
@@ -157,48 +172,56 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
     }
 
     it("Passed ST_Distance <= distance in a broadcast join") {
-      var pointDf1 = buildPointDf
-      var pointDf2 = buildPointDf
+      val pointDf1 = buildPointDf
+      val pointDf2 = buildPointDf
+      // Broadcasting either side must be planned as a BroadcastIndexJoinExec and return the same 2998 pairs.
 
-      var distanceJoinDf = 
pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
+      var distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2).alias("pointDf2"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
       assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(distanceJoinDf.count() == 2998)
 
-      distanceJoinDf = 
broadcast(pointDf1).alias("pointDf1").join(pointDf2.alias("pointDf2"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, 
pointDf2.pointshape) <= 2"))
       assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(distanceJoinDf.count() == 2998)
     }
 
     it("Passed ST_Distance < distance in a broadcast join") {
-      var pointDf1 = buildPointDf
-      var pointDf2 = buildPointDf
+      val pointDf1 = buildPointDf
+      val pointDf2 = buildPointDf
+      // Strict "< 2" variant: both broadcast sides must still use BroadcastIndexJoinExec and return 2998 pairs.
 
-      var distanceJoinDf = 
pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"))
+      var distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2).alias("pointDf2"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"))
       assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(distanceJoinDf.count() == 2998)
 
-      distanceJoinDf = 
broadcast(pointDf1).alias("pointDf1").join(pointDf2.alias("pointDf2"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"))
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, 
pointDf2.pointshape) < 2"))
       assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(distanceJoinDf.count() == 2998)
     }
 
     it("Passed ST_Distance distance is bound to first expression") {
-      var pointDf1 = buildPointDf.withColumn("radius", two())
-      var pointDf2 = buildPointDf
+      val pointDf1 = buildPointDf.withColumn("radius", two())
+      val pointDf2 = buildPointDf
+      // `radius` exists only on pointDf1; all four side/broadcast orderings below must still be
+      // planned as a BroadcastIndexJoinExec and return the same 2998 pairs.
 
-      var distanceJoinDf = 
pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+      var distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2).alias("pointDf2"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
       assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(distanceJoinDf.count() == 2998)
 
-      distanceJoinDf = 
broadcast(pointDf1).alias("pointDf1").join(pointDf2.alias("pointDf2"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, 
pointDf2.pointshape) < radius"))
       assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(distanceJoinDf.count() == 2998)
 
-      distanceJoinDf = 
pointDf2.alias("pointDf2").join(broadcast(pointDf1).alias("pointDf1"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+      distanceJoinDf = pointDf2.alias("pointDf2").join(
+        broadcast(pointDf1).alias("pointDf1"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
       assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(distanceJoinDf.count() == 2998)
 
-      distanceJoinDf = 
broadcast(pointDf2).alias("pointDf2").join(pointDf1.alias("pointDf1"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+      distanceJoinDf = broadcast(pointDf2).alias("pointDf2").join(
+        pointDf1.alias("pointDf1"), expr("ST_Distance(pointDf1.pointshape, 
pointDf2.pointshape) < radius"))
       assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(distanceJoinDf.count() == 2998)
     }
@@ -208,22 +231,26 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
       val polygonDf = buildPolygonDf.repartition(3)
       val pointDf = buildPointDf.repartition(5)
 
-      var broadcastJoinDf = 
pointDf.alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      var broadcastJoinDf = pointDf.alias("pointDf").join(
+        broadcast(polygonDf).alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
       assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(broadcastJoinDf.rdd.getNumPartitions == 
pointDf.rdd.getNumPartitions)
       assert(broadcastJoinDf.count() == 1000)
 
-      broadcastJoinDf = 
broadcast(polygonDf).alias("polygonDf").join(pointDf.alias("pointDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+        pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, 
pointDf.pointshape)"))
       assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(broadcastJoinDf.rdd.getNumPartitions == 
pointDf.rdd.getNumPartitions)
       assert(broadcastJoinDf.count() == 1000)
 
-      broadcastJoinDf = 
broadcast(pointDf).alias("pointDf").join(polygonDf.alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+        polygonDf.alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
       assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(broadcastJoinDf.rdd.getNumPartitions == 
polygonDf.rdd.getNumPartitions)
       assert(broadcastJoinDf.count() == 1000)
 
-      broadcastJoinDf = 
polygonDf.alias("polygonDf").join(broadcast(pointDf).alias("pointDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      broadcastJoinDf = polygonDf.alias("polygonDf").join(
+        broadcast(pointDf).alias("pointDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
       assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: 
BroadcastIndexJoinExec => p }.size === 1)
       assert(broadcastJoinDf.rdd.getNumPartitions == 
polygonDf.rdd.getNumPartitions)
       assert(broadcastJoinDf.count() == 1000)
@@ -246,5 +273,1139 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
           |on ST_Distance(a.geom, b.geom) < 1.5
           |""".stripMargin).count() == 1)
     }
+
+    it("Passed validate output rows") {
+      // Two tiny inputs that intersect at exactly one point, (2.0, 2.0), so the join output is fully
+      // known and can be checked row-for-row for both choices of broadcast side.
+      val left = sparkSession.createDataFrame(Seq(
+        (1.0, 1.0, "left_1"),
+        (2.0, 2.0, "left_2")
+      )).toDF("l_x", "l_y", "l_data")
+
+      val right = sparkSession.createDataFrame(Seq(
+        (2.0, 2.0, "right_2"),
+        (3.0, 3.0, "right_3")
+      )).toDF("r_x", "r_y", "r_data")
+
+      // Right side broadcast.
+      val joined = left.join(broadcast(right),
+        expr("ST_Intersects(ST_Point(l_x, l_y), ST_Point(r_x, r_y))"), "inner")
+      assert(joined.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+
+      val rows = joined.collect()
+      assert(rows.length == 1)
+      assert(rows(0) == Row(2.0, 2.0, "left_2", 2.0, 2.0, "right_2"))
+
+      // Left side broadcast. The plan check must inspect joined2 itself: re-checking `joined` would
+      // pass no matter how this join was planned.
+      val joined2 = broadcast(left).join(right,
+        expr("ST_Intersects(ST_Point(l_x, l_y), ST_Point(r_x, r_y))"), "inner")
+      assert(joined2.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+
+      val rows2 = joined2.collect()
+      assert(rows2.length == 1)
+      assert(rows2(0) == Row(2.0, 2.0, "left_2", 2.0, 2.0, "right_2"))
+    }
+  }
+
+  describe("Sedona-SQL Broadcast Index Join Test for left semi joins") {
+
+    // Using UDFs rather than lit prevents optimizations that would circumvent the checks we want to test
+    val one = udf(() => 1).asNondeterministic()
+    val two = udf(() => 2).asNondeterministic()
+
+    // For left semi joins the assertions below expect a BroadcastIndexJoinExec only when the RIGHT
+    // side carries the broadcast hint; a hint on the left side alone yields a BroadcastNestedLoopJoinExec.
+
+    it("Passed Correct partitioning for broadcast join for ST_Polygon and ST_Point") {
+      val polygonDf = buildPolygonDf.repartition(3)
+      val pointDf = buildPointDf.repartition(5)
+
+      Seq(500, 900, 1000).foreach { limit =>
+        var broadcastJoinDf = pointDf.alias("pointDf").join(
+          broadcast(polygonDf.limit(limit)).alias("polygonDf"),
+          expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.rdd.getNumPartitions == pointDf.rdd.getNumPartitions)
+        assert(broadcastJoinDf.count() == limit)
+
+        broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+          pointDf.limit(limit).alias("pointDf"),
+          expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == limit)
+
+        broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+          polygonDf.limit(limit).alias("polygonDf"),
+          expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == limit)
+
+        broadcastJoinDf = polygonDf.alias("polygonDf").join(
+          broadcast(pointDf.limit(limit)).alias("pointDf"),
+          expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.rdd.getNumPartitions == polygonDf.rdd.getNumPartitions)
+        assert(broadcastJoinDf.count() == limit)
+      }
+    }
+
+    it("Passed Broadcasts the right side if both sides have a broadcast hint") {
+      val polygonDf = buildPolygonDf.repartition(3)
+      val pointDf = buildPointDf.repartition(5)
+
+      // With both hints present the right side is broadcast, so the output keeps the left side's partitioning.
+      val broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+        broadcast(polygonDf).alias("polygonDf"),
+        expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.rdd.getNumPartitions == pointDf.rdd.getNumPartitions)
+      assert(broadcastJoinDf.count() == 1000)
+    }
+
+    it("Passed Can access attributes of left side of broadcast join") {
+      val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+      val pointDf = buildPointDf.withColumn("object_extra", one())
+
+      // A left semi join only exposes left-side columns, so each case sums a column of its left input.
+      var broadcastJoinDf = polygonDf.alias("polygonDf").join(
+        broadcast(pointDf).alias("pointDf"),
+        expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+      assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 1000)
+
+      broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+        pointDf.alias("pointDf"),
+        expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+      assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 1000)
+
+      broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+        polygonDf.alias("polygonDf"),
+        expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+      assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 1000)
+
+      broadcastJoinDf = pointDf.alias("pointDf").join(
+        broadcast(polygonDf).alias("polygonDf"),
+        expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+      assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 1000)
+    }
+
+    it("Passed Handles extra conditions on a broadcast join") {
+      val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+      val pointDf = buildPointDf.withColumn("object_extra", two())
+
+      // window_extra (1) <= object_extra (2) always holds and "<" never does, so each pair of cases
+      // expects either every candidate match or none, regardless of predicate order.
+      Seq(500, 900, 1000).foreach { limit =>
+        var broadcastJoinDf = pointDf
+          .alias("pointDf")
+          .join(
+            broadcast(polygonDf.limit(limit).alias("polygonDf")),
+            expr("ST_Contains(polygonshape, pointshape) AND window_extra <= object_extra"),
+            "left_semi"
+          )
+
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == limit)
+
+        broadcastJoinDf = pointDf
+          .alias("pointDf")
+          .join(
+            broadcast(polygonDf.limit(limit).alias("polygonDf")),
+            expr("ST_Contains(polygonshape, pointshape) AND window_extra > object_extra"),
+            "left_semi"
+          )
+
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 0)
+
+        broadcastJoinDf = pointDf
+          .alias("pointDf")
+          .join(
+            broadcast(polygonDf.limit(limit).alias("polygonDf")),
+            expr("window_extra <= object_extra AND ST_Contains(polygonshape, pointshape)"),
+            "left_semi"
+          )
+
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == limit)
+
+        broadcastJoinDf = pointDf
+          .alias("pointDf")
+          .join(
+            broadcast(polygonDf.limit(limit).alias("polygonDf")),
+            expr("window_extra > object_extra AND ST_Contains(polygonshape, pointshape)"),
+            "left_semi"
+          )
+
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 0)
+      }
+    }
+
+    it("Passed Handles multiple extra conditions on a broadcast join with the ST predicate last") {
+      val polygonDf = buildPolygonDf.withColumn("window_extra", one()).withColumn("window_extra2", one())
+      val pointDf = buildPointDf.withColumn("object_extra", two()).withColumn("object_extra2", two())
+
+      Seq(500, 900, 1000).foreach { limit =>
+        var broadcastJoinDf = pointDf
+          .alias("pointDf")
+          .join(
+            broadcast(polygonDf.limit(limit).alias("polygonDf")),
+            expr("window_extra <= object_extra AND window_extra2 <= object_extra2 AND ST_Contains(polygonshape, pointshape)"),
+            "left_semi"
+          )
+
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == limit)
+
+        broadcastJoinDf = pointDf
+          .alias("pointDf")
+          .join(
+            broadcast(polygonDf.limit(limit).alias("polygonDf")),
+            expr("window_extra > object_extra AND window_extra2 > object_extra2 AND ST_Contains(polygonshape, pointshape)"),
+            "left_semi"
+          )
+
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 0)
+      }
+    }
+
+    it("Passed ST_Distance <= distance in a broadcast join") {
+      val pointDf1 = buildPointDf
+      val pointDf2 = buildPointDf
+
+      var distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2).alias("pointDf2"),
+        expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"), "left_semi")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 1000)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.alias("pointDf2"),
+        expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"), "left_semi")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 1000)
+
+      distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2.limit(500)).alias("pointDf2"),
+        expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"), "left_semi")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 501)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.limit(500).alias("pointDf2"),
+        expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"), "left_semi")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 501)
+    }
+
+    it("Passed ST_Distance < distance in a broadcast join") {
+      val pointDf1 = buildPointDf
+      val pointDf2 = buildPointDf
+
+      var distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2).alias("pointDf2"),
+        expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), "left_semi")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 1000)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.alias("pointDf2"),
+        expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), "left_semi")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 1000)
+
+      distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2.limit(500)).alias("pointDf2"),
+        expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), "left_semi")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 501)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.limit(500).alias("pointDf2"),
+        expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), "left_semi")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 501)
+    }
+
+    it("Passed ST_Distance distance is bound to first expression") {
+      val pointDf1 = buildPointDf.withColumn("radius", two())
+      val pointDf2 = buildPointDf
+
+      var distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2).alias("pointDf2"),
+        expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"), "left_semi")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 1000)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.alias("pointDf2"),
+        expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"), "left_semi")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 1000)
+
+      distanceJoinDf = pointDf2.alias("pointDf2").join(
+        broadcast(pointDf1).alias("pointDf1"),
+        expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"), "left_semi")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 1000)
+
+      distanceJoinDf = broadcast(pointDf2).alias("pointDf2").join(
+        pointDf1.alias("pointDf1"),
+        expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"), "left_semi")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 1000)
+    }
+
+    it("Passed Correct partitioning for broadcast join for ST_Polygon and ST_Point with AQE enabled") {
+      sparkSession.conf.set("spark.sql.adaptive.enabled", true)
+      // The reset lives in `finally` so a failing assertion cannot leave AQE enabled for every test
+      // that runs after this one (which would change their plans and partition counts).
+      try {
+        val polygonDf = buildPolygonDf.repartition(3)
+        val pointDf = buildPointDf.repartition(5)
+
+        var broadcastJoinDf = pointDf.alias("pointDf").join(
+          broadcast(polygonDf).alias("polygonDf"),
+          expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.rdd.getNumPartitions == pointDf.rdd.getNumPartitions)
+        assert(broadcastJoinDf.count() == 1000)
+
+        broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+          pointDf.alias("pointDf"),
+          expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 1000)
+
+        broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+          polygonDf.alias("polygonDf"),
+          expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 1000)
+
+        broadcastJoinDf = polygonDf.alias("polygonDf").join(
+          broadcast(pointDf).alias("pointDf"),
+          expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_semi")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.rdd.getNumPartitions == polygonDf.rdd.getNumPartitions)
+        assert(broadcastJoinDf.count() == 1000)
+      } finally {
+        sparkSession.conf.set("spark.sql.adaptive.enabled", false)
+      }
+    }
+
+    it("Passed validate output rows") {
+      val left = sparkSession.createDataFrame(Seq(
+        (1.0, 1.0, "left_1"),
+        (2.0, 2.0, "left_2")
+      )).toDF("l_x", "l_y", "l_data")
+
+      val right = sparkSession.createDataFrame(Seq(
+        (2.0, 2.0, "right_2"),
+        (3.0, 3.0, "right_3")
+      )).toDF("r_x", "r_y", "r_data")
+
+      // Only left_2 intersects a right row; a semi join returns it with left-side columns only.
+      val joined = left.join(broadcast(right),
+        expr("ST_Intersects(ST_Point(l_x, l_y), ST_Point(r_x, r_y))"), "left_semi")
+      assert(joined.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+
+      val rows = joined.collect()
+      assert(rows.length == 1)
+      assert(rows(0) == Row(2.0, 2.0, "left_2"))
+    }
+  }
+
+  describe("Sedona-SQL Broadcast Index Join Test for left anti joins") {
+
+    // Plain literals (lit) could be folded away by the optimizer, bypassing the join-condition
+    // checks under test, so the constants come from nondeterministic UDFs instead.
+    val one = udf { () => 1 }.asNondeterministic()
+    val two = udf { () => 2 }.asNondeterministic()
+
+    it("Passed Correct partitioning for broadcast join for ST_Polygon and ST_Point") {
+      val polygonDf = buildPolygonDf.repartition(3)
+      val pointDf = buildPointDf.repartition(5)
+      val containsCond = expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)")
+
+      for (limit <- Seq(500, 900, 1000)) {
+        // Limiting one side to `limit` rows leaves 1000 - limit unmatched rows on the other side.
+        val unmatched = 1000 - limit
+
+        // Right side hinted: planned as an index join.
+        val pointsAntiHintedPolygons = pointDf.alias("pointDf")
+          .join(broadcast(polygonDf.limit(limit)).alias("polygonDf"), containsCond, "left_anti")
+        assert(pointsAntiHintedPolygons.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(pointsAntiHintedPolygons.count() == unmatched)
+
+        // Left side hinted: falls back to a broadcast nested loop join.
+        val hintedPolygonsAntiPoints = broadcast(polygonDf).alias("polygonDf")
+          .join(pointDf.limit(limit).alias("pointDf"), containsCond, "left_anti")
+        assert(hintedPolygonsAntiPoints.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(hintedPolygonsAntiPoints.count() == unmatched)
+
+        val hintedPointsAntiPolygons = broadcast(pointDf).alias("pointDf")
+          .join(polygonDf.limit(limit).alias("polygonDf"), containsCond, "left_anti")
+        assert(hintedPointsAntiPolygons.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(hintedPointsAntiPolygons.count() == unmatched)
+
+        val polygonsAntiHintedPoints = polygonDf.alias("polygonDf")
+          .join(broadcast(pointDf.limit(limit)).alias("pointDf"), containsCond, "left_anti")
+        assert(polygonsAntiHintedPoints.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(polygonsAntiHintedPoints.count() == unmatched)
+      }
+    }
+
+    it("Passed Broadcasts the right side if both sides have a broadcast hint") {
+      val polygonDf = buildPolygonDf.repartition(3)
+      val pointDf = buildPointDf.repartition(5)
+
+      // Both inputs are hinted; the right (polygon) side is the one broadcast, so the output keeps
+      // the point side's partitioning. The test expects every point to match, leaving nothing.
+      val bothHinted = broadcast(pointDf).alias("pointDf")
+        .join(
+          broadcast(polygonDf).alias("polygonDf"),
+          expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"),
+          "left_anti")
+      val indexJoins = bothHinted.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }
+      assert(indexJoins.size === 1)
+      assert(bothHinted.rdd.getNumPartitions == pointDf.rdd.getNumPartitions)
+      assert(bothHinted.count() == 0)
+    }
+
+    it("Passed Can access attributes of left side of broadcast join") {
+      val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+      val pointDf = buildPointDf.withColumn("object_extra", one())
+      val containsCond = expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)")
+
+      // NOTE(review): 1000 is presumably omitted because an empty anti-join result would make sum() null.
+      for (limit <- Seq(500, 900)) {
+        val unmatched = 1000 - limit
+
+        val polygonsAntiHintedPoints = polygonDf.alias("polygonDf")
+          .join(broadcast(pointDf.limit(limit)).alias("pointDf"), containsCond, "left_anti")
+        assert(polygonsAntiHintedPoints.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(polygonsAntiHintedPoints.select(sum("window_extra")).collect().head(0) == unmatched)
+
+        val hintedPolygonsAntiPoints = broadcast(polygonDf).alias("polygonDf")
+          .join(pointDf.limit(limit).alias("pointDf"), containsCond, "left_anti")
+        assert(hintedPolygonsAntiPoints.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(hintedPolygonsAntiPoints.select(sum("window_extra")).collect().head(0) == unmatched)
+
+        val hintedPointsAntiPolygons = broadcast(pointDf).alias("pointDf")
+          .join(polygonDf.limit(limit).alias("polygonDf"), containsCond, "left_anti")
+        assert(hintedPointsAntiPolygons.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(hintedPointsAntiPolygons.select(sum("object_extra")).collect().head(0) == unmatched)
+
+        val pointsAntiHintedPolygons = pointDf.alias("pointDf")
+          .join(broadcast(polygonDf.limit(limit)).alias("polygonDf"), containsCond, "left_anti")
+        assert(pointsAntiHintedPolygons.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(pointsAntiHintedPolygons.select(sum("object_extra")).collect().head(0) == unmatched)
+      }
+    }
+
+    it("Passed Handles extra conditions on a broadcast join") {
+      val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+      val pointDf = buildPointDf.withColumn("object_extra", two())
+
+      // Each condition is paired with whether its non-spatial part can hold (1 <= 2 does, 1 > 2 does not).
+      // When it holds, the limited polygons remove `limit` points; otherwise no point is removed.
+      val conditions = Seq(
+        "ST_Contains(polygonshape, pointshape) AND window_extra <= object_extra" -> true,
+        "ST_Contains(polygonshape, pointshape) AND window_extra > object_extra" -> false,
+        "window_extra <= object_extra AND ST_Contains(polygonshape, pointshape)" -> true,
+        "window_extra > object_extra AND ST_Contains(polygonshape, pointshape)" -> false
+      )
+
+      for (limit <- Seq(500, 900); (condition, extraHolds) <- conditions) {
+        val antiJoined = pointDf
+          .alias("pointDf")
+          .join(broadcast(polygonDf.limit(limit).alias("polygonDf")), expr(condition), "left_anti")
+
+        assert(antiJoined.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(antiJoined.count() == (if (extraHolds) 1000 - limit else 1000))
+      }
+    }
+
+    it("Passed Handles multiple extra conditions on a broadcast join with the ST predicate last") {
+      val polygonDf = buildPolygonDf.withColumn("window_extra", one()).withColumn("window_extra2", one())
+      val pointDf = buildPointDf.withColumn("object_extra", two()).withColumn("object_extra2", two())
+
+      // (condition, expected surviving points): satisfiable extras remove every point, unsatisfiable keep all.
+      val cases = Seq(
+        "window_extra <= object_extra AND window_extra2 <= object_extra2 AND ST_Contains(polygonshape, pointshape)" -> 0,
+        "window_extra > object_extra AND window_extra2 > object_extra2 AND ST_Contains(polygonshape, pointshape)" -> 1000
+      )
+
+      for ((condition, survivors) <- cases) {
+        val antiJoined = pointDf
+          .alias("pointDf")
+          .join(broadcast(polygonDf.alias("polygonDf")), expr(condition), "left_anti")
+
+        assert(antiJoined.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(antiJoined.count() == survivors)
+      }
+    }
+
+    it("Passed ST_Distance <= distance in a broadcast join") {
+      val pointDf1 = buildPointDf
+      val pointDf2 = buildPointDf
+      val withinTwo = expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2")
+
+      // Full right side: the test expects every left point to find a neighbour, so nothing survives.
+      val fullIndexJoin = pointDf1.alias("pointDf1")
+        .join(broadcast(pointDf2).alias("pointDf2"), withinTwo, "left_anti")
+      assert(fullIndexJoin.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(fullIndexJoin.count() == 0)
+
+      val fullNestedLoop = broadcast(pointDf1).alias("pointDf1")
+        .join(pointDf2.alias("pointDf2"), withinTwo, "left_anti")
+      assert(fullNestedLoop.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(fullNestedLoop.count() == 0)
+
+      // Right side limited to 500 rows: the test expects 499 left points to stay out of range.
+      val limitedIndexJoin = pointDf1.alias("pointDf1")
+        .join(broadcast(pointDf2.limit(500)).alias("pointDf2"), withinTwo, "left_anti")
+      assert(limitedIndexJoin.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(limitedIndexJoin.count() == 499)
+
+      val limitedNestedLoop = broadcast(pointDf1).alias("pointDf1")
+        .join(pointDf2.limit(500).alias("pointDf2"), withinTwo, "left_anti")
+      assert(limitedNestedLoop.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(limitedNestedLoop.count() == 499)
+    }
+
+    it("Passed ST_Distance < distance in a broadcast join") {
+      val pointDf1 = buildPointDf
+      val pointDf2 = buildPointDf
+      val condition = "ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"
+
+      // Hint on the right side: planned as an index join; every point has a neighbour.
+      val rightHinted = pointDf1.alias("pointDf1")
+        .join(broadcast(pointDf2).alias("pointDf2"), expr(condition), "left_anti")
+      assert(rightHinted.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(rightHinted.count() == 0)
+
+      // Hint on the left side: planned as a broadcast nested loop join instead.
+      val leftHinted = broadcast(pointDf1).alias("pointDf1")
+        .join(pointDf2.alias("pointDf2"), expr(condition), "left_anti")
+      assert(leftHinted.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(leftHinted.count() == 0)
+
+      // Only 500 candidates on the right: 499 points of pointDf1 are left without a neighbour.
+      val rightHintedLimited = pointDf1.alias("pointDf1")
+        .join(broadcast(pointDf2.limit(500)).alias("pointDf2"), expr(condition), "left_anti")
+      assert(rightHintedLimited.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(rightHintedLimited.count() == 499)
+
+      val leftHintedLimited = broadcast(pointDf1).alias("pointDf1")
+        .join(pointDf2.limit(500).alias("pointDf2"), expr(condition), "left_anti")
+      assert(leftHintedLimited.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(leftHintedLimited.count() == 499)
+    }
+
+    it("Passed ST_Distance distance is bound to first expression") {
+      // Only pointDf1 carries the `radius` column used as the distance bound.
+      val pointDf1 = buildPointDf.withColumn("radius", two())
+      val pointDf2 = buildPointDf
+      val condition = "ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"
+
+      val radiusLeftRightHinted = pointDf1.alias("pointDf1")
+        .join(broadcast(pointDf2).alias("pointDf2"), expr(condition), "left_anti")
+      assert(radiusLeftRightHinted.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(radiusLeftRightHinted.count() == 0)
+
+      val radiusLeftLeftHinted = broadcast(pointDf1).alias("pointDf1")
+        .join(pointDf2.alias("pointDf2"), expr(condition), "left_anti")
+      assert(radiusLeftLeftHinted.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(radiusLeftLeftHinted.count() == 0)
+
+      // Same join with the radius-carrying side swapped onto the right.
+      val radiusRightRightHinted = pointDf2.alias("pointDf2")
+        .join(broadcast(pointDf1).alias("pointDf1"), expr(condition), "left_anti")
+      assert(radiusRightRightHinted.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(radiusRightRightHinted.count() == 0)
+
+      val radiusRightLeftHinted = broadcast(pointDf2).alias("pointDf2")
+        .join(pointDf1.alias("pointDf1"), expr(condition), "left_anti")
+      assert(radiusRightLeftHinted.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(radiusRightLeftHinted.count() == 0)
+    }
+
+    it("Passed Correct partitioning for broadcast join for ST_Polygon and ST_Point with AQE enabled") {
+      val aqeConf = "spark.sql.adaptive.enabled"
+      // Remember any explicitly-set value and restore it in `finally`: previously a failing
+      // assertion left AQE enabled for every subsequent test in the suite.
+      val previousAqe = sparkSession.conf.getAll.get(aqeConf)
+      sparkSession.conf.set(aqeConf, true)
+      try {
+        val polygonDf = buildPolygonDf.repartition(3)
+        val pointDf = buildPointDf.repartition(5)
+
+        var broadcastJoinDf = pointDf.alias("pointDf").join(
+          broadcast(polygonDf).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 0)
+
+        broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+          pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 0)
+
+        broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+          polygonDf.alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 0)
+
+        broadcastJoinDf = polygonDf.alias("polygonDf").join(
+          broadcast(pointDf).alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_anti")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 0)
+      } finally {
+        previousAqe match {
+          case Some(value) => sparkSession.conf.set(aqeConf, value)
+          case None => sparkSession.conf.unset(aqeConf)
+        }
+      }
+    }
+
+    it("Passed validate output rows") {
+      val leftRows = Seq((1.0, 1.0, "left_1"), (2.0, 2.0, "left_2"))
+      val rightRows = Seq((2.0, 2.0, "right_2"), (3.0, 3.0, "right_3"))
+      val left = sparkSession.createDataFrame(leftRows).toDF("l_x", "l_y", "l_data")
+      val right = sparkSession.createDataFrame(rightRows).toDF("r_x", "r_y", "r_data")
+
+      // Only the point (2.0, 2.0) exists on both sides, so the anti join keeps just left_1.
+      val joined = left.join(
+        broadcast(right),
+        expr("ST_Intersects(ST_Point(l_x, l_y), ST_Point(r_x, r_y))"),
+        "left_anti")
+      assert(joined.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+
+      val rows = joined.collect()
+      assert(rows.length == 1)
+      assert(rows.head == Row(1.0, 1.0, "left_1"))
+    }
+  }
+
+  // Left outer joins: the tests below expect BroadcastIndexJoinExec when the right side carries the
+  // broadcast hint, and BroadcastNestedLoopJoinExec when only the left side does.
+  describe("Sedona-SQL Broadcast Index Join Test for left outer joins") {
+
+    // Using UDFs rather than lit prevents optimizations that would circumvent the checks we want to test
+    val one = udf(() => 1).asNondeterministic()
+    val two = udf(() => 2).asNondeterministic()
+
+    it("Passed Correct partitioning for broadcast join for ST_Polygon and ST_Point") {
+      val polygonDf = buildPolygonDf.repartition(3)
+      val pointDf = buildPointDf.repartition(5)
+
+      Seq(500, 900, 1000).foreach { limit =>
+        var broadcastJoinDf = pointDf.alias("pointDf").join(
+          broadcast(polygonDf.limit(limit)).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"),
+          "left_outer")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.rdd.getNumPartitions == pointDf.rdd.getNumPartitions)
+        assert(broadcastJoinDf.count() == 1000)
+
+        broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+          pointDf.limit(limit).alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"),
+          "left_outer")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 1000)
+
+        broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+          polygonDf.limit(limit).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"),
+          "left_outer")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 1000)
+
+        broadcastJoinDf = polygonDf.alias("polygonDf").join(
+          broadcast(pointDf.limit(limit)).alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"),
+          "left_outer")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.rdd.getNumPartitions == polygonDf.rdd.getNumPartitions)
+        assert(broadcastJoinDf.count() == 1000)
+      }
+    }
+
+    // The output keeps pointDf's (the left side's) partitioning, i.e. the right side is the one
+    // broadcast; the title previously claimed the left side was broadcast.
+    it("Passed Broadcasts the right side if both sides have a broadcast hint") {
+      val polygonDf = buildPolygonDf.repartition(3)
+      val pointDf = buildPointDf.repartition(5)
+
+      val broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+        broadcast(polygonDf).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_outer")
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.rdd.getNumPartitions == pointDf.rdd.getNumPartitions)
+      assert(broadcastJoinDf.count() == 1000)
+    }
+
+    it("Passed Can access attributes of both sides of broadcast join") {
+      val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+      val pointDf = buildPointDf.withColumn("object_extra", one())
+
+      Seq(500, 900, 1000).foreach { limit =>
+        var broadcastJoinDf = polygonDf.alias("polygonDf").join(
+          broadcast(pointDf.limit(limit)).alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_outer")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == limit)
+        assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 1000)
+
+        broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+          pointDf.limit(limit).alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_outer")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == limit)
+        assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 1000)
+
+        broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+          polygonDf.limit(limit).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_outer")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 1000)
+        assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == limit)
+
+        broadcastJoinDf = pointDf.alias("pointDf").join(
+          broadcast(polygonDf.limit(limit)).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_outer")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 1000)
+        assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == limit)
+      }
+    }
+
+    it("Passed Handles extra conditions on a broadcast join") {
+      val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+      val pointDf = buildPointDf.withColumn("object_extra", two())
+
+      // A left outer join keeps all 1000 points whether or not the extra predicate matches.
+      Seq(500, 900, 1000).foreach { limit =>
+        var broadcastJoinDf = pointDf
+          .alias("pointDf")
+          .join(
+            broadcast(polygonDf.limit(limit).alias("polygonDf")),
+            expr("ST_Contains(polygonshape, pointshape) AND window_extra <= object_extra"),
+            "left_outer"
+          )
+
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 1000)
+
+        broadcastJoinDf = pointDf
+          .alias("pointDf")
+          .join(
+            broadcast(polygonDf.limit(limit).alias("polygonDf")),
+            expr("ST_Contains(polygonshape, pointshape) AND window_extra > object_extra"),
+            "left_outer"
+          )
+
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 1000)
+
+        broadcastJoinDf = pointDf
+          .alias("pointDf")
+          .join(
+            broadcast(polygonDf.limit(limit).alias("polygonDf")),
+            expr("window_extra <= object_extra AND ST_Contains(polygonshape, pointshape)"),
+            "left_outer"
+          )
+
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 1000)
+
+        broadcastJoinDf = pointDf
+          .alias("pointDf")
+          .join(
+            broadcast(polygonDf.limit(limit).alias("polygonDf")),
+            expr("window_extra > object_extra AND ST_Contains(polygonshape, pointshape)"),
+            "left_outer"
+          )
+
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 1000)
+      }
+    }
+
+    it("Passed Handles multiple extra conditions on a broadcast join with the ST predicate last") {
+      val polygonDf = buildPolygonDf.withColumn("window_extra", one()).withColumn("window_extra2", one())
+      val pointDf = buildPointDf.withColumn("object_extra", two()).withColumn("object_extra2", two())
+
+      var broadcastJoinDf = pointDf
+        .alias("pointDf")
+        .join(
+          broadcast(polygonDf.alias("polygonDf")),
+          expr("window_extra <= object_extra AND window_extra2 <= object_extra2 AND ST_Contains(polygonshape, pointshape)"),
+          "left_outer"
+        )
+
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.count() == 1000)
+
+      broadcastJoinDf = pointDf
+        .alias("pointDf")
+        .join(
+          broadcast(polygonDf.alias("polygonDf")),
+          expr("window_extra > object_extra AND window_extra2 > object_extra2 AND ST_Contains(polygonshape, pointshape)"),
+          "left_outer"
+        )
+
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.count() == 1000)
+    }
+
+    it("Passed ST_Distance <= distance in a broadcast join") {
+      val pointDf1 = buildPointDf
+      val pointDf2 = buildPointDf
+
+      var distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"), "left_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"), "left_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2.limit(500)).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"), "left_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 1998)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.limit(500).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"), "left_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 1998)
+    }
+
+    it("Passed ST_Distance < distance in a broadcast join") {
+      val pointDf1 = buildPointDf
+      val pointDf2 = buildPointDf
+
+      var distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), "left_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), "left_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2.limit(500)).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), "left_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 1998)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.limit(500).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), "left_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 1998)
+    }
+
+    it("Passed ST_Distance distance is bound to first expression") {
+      val pointDf1 = buildPointDf.withColumn("radius", two())
+      val pointDf2 = buildPointDf
+
+      var distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"), "left_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"), "left_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = pointDf2.alias("pointDf2").join(
+        broadcast(pointDf1).alias("pointDf1"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"), "left_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = broadcast(pointDf2).alias("pointDf2").join(
+        pointDf1.alias("pointDf1"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"), "left_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+    }
+
+    it("Passed Correct partitioning for broadcast join for ST_Polygon and ST_Point with AQE enabled") {
+      val aqeConf = "spark.sql.adaptive.enabled"
+      // Remember any explicitly-set value and restore it in `finally` so that a failing assertion
+      // cannot leave AQE enabled for the rest of the suite.
+      val previousAqe = sparkSession.conf.getAll.get(aqeConf)
+      sparkSession.conf.set(aqeConf, true)
+      try {
+        val polygonDf = buildPolygonDf.repartition(3)
+        val pointDf = buildPointDf.repartition(5)
+
+        var broadcastJoinDf = pointDf.alias("pointDf").join(
+          broadcast(polygonDf).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_outer")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.rdd.getNumPartitions == pointDf.rdd.getNumPartitions)
+        assert(broadcastJoinDf.count() == 1000)
+
+        broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+          pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_outer")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 1000)
+
+        broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+          polygonDf.alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_outer")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.count() == 1000)
+
+        broadcastJoinDf = polygonDf.alias("polygonDf").join(
+          broadcast(pointDf).alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "left_outer")
+        assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(broadcastJoinDf.rdd.getNumPartitions == polygonDf.rdd.getNumPartitions)
+        assert(broadcastJoinDf.count() == 1000)
+      } finally {
+        previousAqe match {
+          case Some(value) => sparkSession.conf.set(aqeConf, value)
+          case None => sparkSession.conf.unset(aqeConf)
+        }
+      }
+    }
+
+    it("Passed validate output rows") {
+      val left = sparkSession.createDataFrame(Seq(
+        (1.0, 1.0, "left_1"),
+        (2.0, 2.0, "left_2")
+      )).toDF("l_x", "l_y", "l_data")
+
+      val right = sparkSession.createDataFrame(Seq(
+        (2.0, 2.0, "right_2"),
+        (3.0, 3.0, "right_3")
+      )).toDF("r_x", "r_y", "r_data")
+
+      val joined = left.join(broadcast(right),
+        expr("ST_Intersects(ST_Point(l_x, l_y), ST_Point(r_x, r_y))"), "left_outer")
+      assert(joined.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+
+      // left_1 has no intersecting right row, so its right-side columns are null-padded.
+      val rows = joined.collect()
+      assert(rows.length == 2)
+      assert(rows(0) == Row(1.0, 1.0, "left_1", null, null, null))
+      assert(rows(1) == Row(2.0, 2.0, "left_2", 2.0, 2.0, "right_2"))
+    }
+  }
+
+  describe("Sedona-SQL Broadcast Index Join Test for right outer joins") {
+
+    // Using UDFs rather than lit prevents optimizations that would circumvent 
the checks we want to test
+    val one = udf(() => 1).asNondeterministic()
+    val two = udf(() => 2).asNondeterministic()
+
+    it("Passed Correct partitioning for broadcast join for ST_Polygon and ST_Point") {
+      val polygonDf = buildPolygonDf.repartition(3)
+      val pointDf = buildPointDf.repartition(5)
+      val contains = "ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"
+
+      // A right outer join keeps every row of the (limited) right side, hence `limit` rows each time.
+      for (limit <- Seq(500, 900, 1000)) {
+        val polygonsRightHinted = pointDf.alias("pointDf")
+          .join(broadcast(polygonDf.limit(limit)).alias("polygonDf"), expr(contains), "right_outer")
+        assert(polygonsRightHinted.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(polygonsRightHinted.count() == limit)
+
+        val polygonsLeftHinted = broadcast(polygonDf).alias("polygonDf")
+          .join(pointDf.limit(limit).alias("pointDf"), expr(contains), "right_outer")
+        assert(polygonsLeftHinted.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(polygonsLeftHinted.count() == limit)
+
+        val pointsLeftHinted = broadcast(pointDf).alias("pointDf")
+          .join(polygonDf.limit(limit).alias("polygonDf"), expr(contains), "right_outer")
+        assert(pointsLeftHinted.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        assert(pointsLeftHinted.count() == limit)
+
+        val pointsRightHinted = polygonDf.alias("polygonDf")
+          .join(broadcast(pointDf.limit(limit)).alias("pointDf"), expr(contains), "right_outer")
+        assert(pointsRightHinted.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+        assert(pointsRightHinted.count() == limit)
+      }
+    }
+
+    it("Passed Broadcasts the left side if both sides have a broadcast hint") {
+      val polygonDf = buildPolygonDf.repartition(3)
+      val pointDf = buildPointDf.repartition(5)
+
+      // Both sides hinted: the result keeps polygonDf's (right side's) partitioning, i.e. the
+      // left side is the one broadcast.
+      val joined = broadcast(pointDf)
+        .alias("pointDf")
+        .join(
+          broadcast(polygonDf).alias("polygonDf"),
+          expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"),
+          "right_outer")
+      assert(joined.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(joined.rdd.getNumPartitions == polygonDf.rdd.getNumPartitions)
+      assert(joined.count() == 1000)
+    }
+
+    it("Passed Can access attributes of both sides of broadcast join") {
+      val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+      val pointDf = buildPointDf.withColumn("object_extra", one())
+      val contains = "ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"
+
+      // Columns from both sides must be readable in the output; each sums to 1000.
+      val pointsRightHinted = polygonDf.alias("polygonDf")
+        .join(broadcast(pointDf).alias("pointDf"), expr(contains), "right_outer")
+      assert(pointsRightHinted.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(pointsRightHinted.select(sum("object_extra")).collect().head(0) == 1000)
+      assert(pointsRightHinted.select(sum("window_extra")).collect().head(0) == 1000)
+
+      val polygonsLeftHinted = broadcast(polygonDf).alias("polygonDf")
+        .join(pointDf.alias("pointDf"), expr(contains), "right_outer")
+      assert(polygonsLeftHinted.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(polygonsLeftHinted.select(sum("object_extra")).collect().head(0) == 1000)
+      assert(polygonsLeftHinted.select(sum("window_extra")).collect().head(0) == 1000)
+
+      val pointsLeftHinted = broadcast(pointDf).alias("pointDf")
+        .join(polygonDf.alias("polygonDf"), expr(contains), "right_outer")
+      assert(pointsLeftHinted.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(pointsLeftHinted.select(sum("object_extra")).collect().head(0) == 1000)
+      assert(pointsLeftHinted.select(sum("window_extra")).collect().head(0) == 1000)
+
+      val polygonsRightHinted = pointDf.alias("pointDf")
+        .join(broadcast(polygonDf).alias("polygonDf"), expr(contains), "right_outer")
+      assert(polygonsRightHinted.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(polygonsRightHinted.select(sum("object_extra")).collect().head(0) == 1000)
+      assert(polygonsRightHinted.select(sum("window_extra")).collect().head(0) == 1000)
+    }
+
+    it("Passed Handles extra conditions on a broadcast join") {
+      val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+      val pointDf = buildPointDf.withColumn("object_extra", two())
+
+      // Hint on the right side: planned as a broadcast nested loop join.
+      val rightHinted = pointDf
+        .alias("pointDf")
+        .join(
+          broadcast(polygonDf.alias("polygonDf")),
+          expr("ST_Contains(polygonshape, pointshape) AND window_extra <= object_extra"),
+          "right_outer")
+      assert(rightHinted.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(rightHinted.count() == 1000)
+
+      // Hint on the left side: planned as an index join, whichever order the predicates appear in.
+      val leftHinted = broadcast(pointDf)
+        .alias("pointDf")
+        .join(
+          polygonDf.alias("polygonDf"),
+          expr("ST_Contains(polygonshape, pointshape) AND window_extra > object_extra"),
+          "right_outer")
+      assert(leftHinted.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(leftHinted.count() == 1000)
+
+      val leftHintedExtraFirst = broadcast(pointDf)
+        .alias("pointDf")
+        .join(
+          polygonDf.alias("polygonDf"),
+          expr("window_extra <= object_extra AND ST_Contains(polygonshape, pointshape)"),
+          "right_outer")
+      assert(leftHintedExtraFirst.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(leftHintedExtraFirst.count() == 1000)
+
+      val rightHintedExtraFirst = pointDf
+        .alias("pointDf")
+        .join(
+          broadcast(polygonDf.alias("polygonDf")),
+          expr("window_extra > object_extra AND ST_Contains(polygonshape, pointshape)"),
+          "right_outer")
+      assert(rightHintedExtraFirst.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(rightHintedExtraFirst.count() == 1000)
+    }
+
+    it("Passed Handles multiple extra conditions on a broadcast join with the ST predicate last") {
+      val polygonDf = buildPolygonDf
+        .withColumn("window_extra", one())
+        .withColumn("window_extra2", one())
+      val pointDf = buildPointDf
+        .withColumn("object_extra", two())
+        .withColumn("object_extra2", two())
+
+      // Right-outer-joins the broadcast points with the polygons under `condition`, checks that the
+      // index join was planned, and returns the output row count.
+      def rightOuterCount(condition: String): Long = {
+        val joined = broadcast(pointDf)
+          .alias("pointDf")
+          .join(polygonDf.alias("polygonDf"), expr(condition), "right_outer")
+        assert(joined.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+        joined.count()
+      }
+
+      // Every polygon is kept whether or not the extra predicates hold.
+      assert(rightOuterCount("window_extra <= object_extra AND window_extra2 <= object_extra2 AND ST_Contains(polygonshape, pointshape)") == 1000)
+      assert(rightOuterCount("window_extra > object_extra AND window_extra2 > object_extra2 AND ST_Contains(polygonshape, pointshape)") == 1000)
+    }
+
+    it("Passed ST_Distance <= distance in a broadcast join") {
+      val pointDf1 = buildPointDf
+      val pointDf2 = buildPointDf
+      val condition = "ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"
+
+      // Hint on the right side: planned as a broadcast nested loop join.
+      val rightHinted = pointDf1.alias("pointDf1")
+        .join(broadcast(pointDf2).alias("pointDf2"), expr(condition), "right_outer")
+      assert(rightHinted.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(rightHinted.count() == 2998)
+
+      // Hint on the left side: planned as an index join.
+      val leftHinted = broadcast(pointDf1).alias("pointDf1")
+        .join(pointDf2.alias("pointDf2"), expr(condition), "right_outer")
+      assert(leftHinted.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(leftHinted.count() == 2998)
+
+      val rightHintedLimited = pointDf1.alias("pointDf1")
+        .join(broadcast(pointDf2.limit(500)).alias("pointDf2"), expr(condition), "right_outer")
+      assert(rightHintedLimited.queryExecution.sparkPlan.collect { case p: BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(rightHintedLimited.count() == 1499)
+
+      val leftHintedLimited = broadcast(pointDf1).alias("pointDf1")
+        .join(pointDf2.limit(500).alias("pointDf2"), expr(condition), "right_outer")
+      assert(leftHintedLimited.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(leftHintedLimited.count() == 1499)
+    }
+
+    it("Passed ST_Distance < distance in a broadcast join") {
+      val pointDf1 = buildPointDf
+      val pointDf2 = buildPointDf
+
+      var distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2).alias("pointDf2"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), 
"right_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: 
BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, 
pointDf2.pointshape) < 2"), "right_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: 
BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2.limit(500)).alias("pointDf2"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), 
"right_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: 
BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 1499)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.limit(500).alias("pointDf2"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"), 
"right_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: 
BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 1499)
+    }
+
+    it("Passed ST_Distance distance is bound to first expression") {
+      val pointDf1 = buildPointDf.withColumn("radius", two())
+      val pointDf2 = buildPointDf
+
+      var distanceJoinDf = pointDf1.alias("pointDf1").join(
+        broadcast(pointDf2).alias("pointDf2"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"), 
"right_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: 
BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+        pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, 
pointDf2.pointshape) < radius"), "right_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: 
BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = pointDf2.alias("pointDf2").join(
+        broadcast(pointDf1).alias("pointDf1"), 
expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"), 
"right_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: 
BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = broadcast(pointDf2).alias("pointDf2").join(
+        pointDf1.alias("pointDf1"), expr("ST_Distance(pointDf1.pointshape, 
pointDf2.pointshape) < radius"), "right_outer")
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: 
BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+    }
+
+    it("Passed Correct partitioning for broadcast join for ST_Polygon and 
ST_Point with AQE enabled") {
+      sparkSession.conf.set("spark.sql.adaptive.enabled", true)
+      val polygonDf = buildPolygonDf.repartition(3)
+      val pointDf = buildPointDf.repartition(5)
+
+      var broadcastJoinDf = pointDf.alias("pointDf").join(
+        broadcast(polygonDf).alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "right_outer")
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: 
BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.count() == 1000)
+
+      broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(
+        pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, 
pointDf.pointshape)"), "right_outer")
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: 
BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.rdd.getNumPartitions == 
pointDf.rdd.getNumPartitions)
+      assert(broadcastJoinDf.count() == 1000)
+
+      broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(
+        polygonDf.alias("polygonDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "right_outer")
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: 
BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.rdd.getNumPartitions == 
polygonDf.rdd.getNumPartitions)
+      assert(broadcastJoinDf.count() == 1000)
+
+      broadcastJoinDf = polygonDf.alias("polygonDf").join(
+        broadcast(pointDf).alias("pointDf"), 
expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"), "right_outer")
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect { case p: 
BroadcastNestedLoopJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.count() == 1000)
+      sparkSession.conf.set("spark.sql.adaptive.enabled", false)
+    }
+
+    it("Passed validate output rows") {
+      val left = sparkSession.createDataFrame(Seq(
+        (1.0, 1.0, "left_1"),
+        (2.0, 2.0, "left_2")
+      )).toDF("l_x", "l_y", "l_data")
+
+      val right = sparkSession.createDataFrame(Seq(
+        (2.0, 2.0, "right_2"),
+        (3.0, 3.0, "right_3")
+      )).toDF("r_x", "r_y", "r_data")
+
+      val joined = broadcast(left).join(right,
+        expr("ST_Intersects(ST_Point(l_x, l_y), ST_Point(r_x, r_y))"), 
"right_outer")
+      assert(joined.queryExecution.sparkPlan.collect { case p: 
BroadcastIndexJoinExec => p }.size === 1)
+
+      val rows = joined.collect()
+      assert(rows.length == 2)
+      assert(rows(0) == Row(2.0, 2.0, "left_2", 2.0, 2.0, "right_2"))
+      assert(rows(1) == Row(null, null, null, 3.0, 3.0, "right_3"))
+    }
   }
+
 }

Reply via email to