This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new af74a17cf1 [SEDONA-690] Optimize query side broadcast knn join (#1741)
af74a17cf1 is described below

commit af74a17cf1a274431542fd2ce74b86fb5cb2de52
Author: Feng Zhang <[email protected]>
AuthorDate: Sun Jan 5 06:49:01 2025 -0800

    [SEDONA-690] Optimize query side broadcast knn join (#1741)
    
    * [SEDONA-688] Verify KNN parameter K must be equal or larger than 1
    
    * [SEDONA-690] Optimize query side broadcast knn join
    
    * fix isGeography parameter
---
 .../core/joinJudgement/KnnJoinIndexJudgement.java  | 189 +++++++++++++++-----
 .../core/knnJudgement/EuclideanItemDistance.java   |   8 +
 .../core/knnJudgement/HaversineItemDistance.java   |   8 +
 .../sedona/core/knnJudgement/SpheroidDistance.java |   8 +
 .../sedona/core/spatialOperator/JoinQuery.java     | 190 ++++++++++++++++++---
 .../apache/sedona/core/wrapper/UniqueGeometry.java | 168 ++++++++++++++++++
 .../join/BroadcastQuerySideKNNJoinExec.scala       |  17 +-
 .../strategy/join/JoinQueryDetector.scala          |  43 ++---
 8 files changed, 535 insertions(+), 96 deletions(-)

diff --git 
a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
 
b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
index 1c7fe7a0ae..f5375009ed 100644
--- 
a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
+++ 
b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java
@@ -25,6 +25,7 @@ import org.apache.sedona.core.enums.DistanceMetric;
 import org.apache.sedona.core.knnJudgement.EuclideanItemDistance;
 import org.apache.sedona.core.knnJudgement.HaversineItemDistance;
 import org.apache.sedona.core.knnJudgement.SpheroidDistance;
+import org.apache.sedona.core.wrapper.UniqueGeometry;
 import org.apache.spark.api.java.function.FlatMapFunction2;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.util.LongAccumulator;
@@ -46,35 +47,43 @@ public class KnnJoinIndexJudgement<T extends Geometry, U 
extends Geometry>
     extends JudgementBase<T, U>
     implements FlatMapFunction2<Iterator<T>, Iterator<SpatialIndex>, Pair<U, 
T>>, Serializable {
   private final int k;
+  private final Double searchRadius;
   private final DistanceMetric distanceMetric;
   private final boolean includeTies;
-  private final Broadcast<STRtree> broadcastedTreeIndex;
+  private final Broadcast<List> broadcastQueryObjects;
+  private final Broadcast<STRtree> broadcastObjectsTreeIndex;
 
   /**
    * Constructor for the KnnJoinIndexJudgement class.
    *
    * @param k the number of nearest neighbors to find
+   * @param searchRadius optional maximum neighbor distance; when null no radius filter is applied
    * @param distanceMetric the distance metric to use
+   * @param broadcastQueryObjects the broadcast geometries on queries
+   * @param broadcastObjectsTreeIndex the broadcast spatial index on objects
    * @param buildCount accumulator for the number of geometries processed from 
the build side
    * @param streamCount accumulator for the number of geometries processed 
from the stream side
    * @param resultCount accumulator for the number of join results
    * @param candidateCount accumulator for the number of candidate matches
-   * @param broadcastedTreeIndex the broadcasted spatial index
    */
   public KnnJoinIndexJudgement(
       int k,
+      Double searchRadius,
       DistanceMetric distanceMetric,
       boolean includeTies,
-      Broadcast<STRtree> broadcastedTreeIndex,
+      Broadcast<List> broadcastQueryObjects,
+      Broadcast<STRtree> broadcastObjectsTreeIndex,
       LongAccumulator buildCount,
       LongAccumulator streamCount,
       LongAccumulator resultCount,
       LongAccumulator candidateCount) {
     super(null, buildCount, streamCount, resultCount, candidateCount);
     this.k = k;
+    this.searchRadius = searchRadius;
     this.distanceMetric = distanceMetric;
     this.includeTies = includeTies;
-    this.broadcastedTreeIndex = broadcastedTreeIndex;
+    this.broadcastQueryObjects = broadcastQueryObjects;
+    this.broadcastObjectsTreeIndex = broadcastObjectsTreeIndex;
   }
 
   /**
@@ -90,7 +99,7 @@ public class KnnJoinIndexJudgement<T extends Geometry, U 
extends Geometry>
   @Override
   public Iterator<Pair<U, T>> call(Iterator<T> streamShapes, 
Iterator<SpatialIndex> treeIndexes)
       throws Exception {
-    if (!treeIndexes.hasNext() || !streamShapes.hasNext()) {
+    if (!treeIndexes.hasNext() || (streamShapes != null && 
!streamShapes.hasNext())) {
       buildCount.add(0);
       streamCount.add(0);
       resultCount.add(0);
@@ -99,10 +108,9 @@ public class KnnJoinIndexJudgement<T extends Geometry, U 
extends Geometry>
     }
 
     STRtree strTree;
-    if (broadcastedTreeIndex != null) {
-      // get the broadcasted spatial index if available
-      // this is to support the broadcast join
-      strTree = broadcastedTreeIndex.getValue();
+    if (broadcastObjectsTreeIndex != null) {
+      // get the broadcast spatial index on objects side if available
+      strTree = broadcastObjectsTreeIndex.getValue();
     } else {
       // get the spatial index from the iterator
       SpatialIndex treeIndex = treeIndexes.next();
@@ -113,44 +121,133 @@ public class KnnJoinIndexJudgement<T extends Geometry, U 
extends Geometry>
       strTree = (STRtree) treeIndex;
     }
 
+    // TODO: For future improvement, instead of using a list to store the 
results,
+    // we can use lazy evaluation to avoid storing all the results in memory.
     List<Pair<U, T>> result = new ArrayList<>();
-    ItemDistance itemDistance;
 
-    while (streamShapes.hasNext()) {
-      T streamShape = streamShapes.next();
-      streamCount.add(1);
-
-      Object[] localK;
-      switch (distanceMetric) {
-        case EUCLIDEAN:
-          itemDistance = new EuclideanItemDistance();
-          break;
-        case HAVERSINE:
-          itemDistance = new HaversineItemDistance();
-          break;
-        case SPHEROID:
-          itemDistance = new SpheroidDistance();
-          break;
-        default:
-          itemDistance = new GeometryItemDistance();
-          break;
-      }
+    List queryItems;
+    if (broadcastQueryObjects != null) {
+      // get the broadcast list of query geometries if available
+      queryItems = broadcastQueryObjects.getValue();
+      for (Object item : queryItems) {
+        T queryGeom;
+        if (item instanceof UniqueGeometry) {
+          queryGeom = (T) ((UniqueGeometry) item).getOriginalGeometry();
+        } else {
+          queryGeom = (T) item;
+        }
+        streamCount.add(1);
 
-      localK =
-          strTree.nearestNeighbour(streamShape.getEnvelopeInternal(), 
streamShape, itemDistance, k);
-      if (includeTies) {
-        localK = getUpdatedLocalKWithTies(streamShape, localK, strTree);
+        Object[] localK =
+            strTree.nearestNeighbour(
+                queryGeom.getEnvelopeInternal(), queryGeom, getItemDistance(), 
k);
+        if (includeTies) {
+          localK = getUpdatedLocalKWithTies(queryGeom, localK, strTree);
+        }
+        if (searchRadius != null) {
+          localK = getInSearchRadius(localK, queryGeom);
+        }
+
+        for (Object obj : localK) {
+          T candidate = (T) obj;
+          Pair<U, T> pair = Pair.of((U) item, candidate);
+          result.add(pair);
+          resultCount.add(1);
+        }
       }
+      return result.iterator();
+    } else {
+      while (streamShapes.hasNext()) {
+        T streamShape = streamShapes.next();
+        streamCount.add(1);
+
+        Object[] localK =
+            strTree.nearestNeighbour(
+                streamShape.getEnvelopeInternal(), streamShape, 
getItemDistance(), k);
+        if (includeTies) {
+          localK = getUpdatedLocalKWithTies(streamShape, localK, strTree);
+        }
+        if (searchRadius != null) {
+          localK = getInSearchRadius(localK, streamShape);
+        }
 
-      for (Object obj : localK) {
-        T candidate = (T) obj;
-        Pair<U, T> pair = Pair.of((U) streamShape, candidate);
-        result.add(pair);
-        resultCount.add(1);
+        for (Object obj : localK) {
+          T candidate = (T) obj;
+          Pair<U, T> pair = Pair.of((U) streamShape, candidate);
+          result.add(pair);
+          resultCount.add(1);
+        }
       }
+      return result.iterator();
     }
+  }
 
-    return result.iterator();
+  private Object[] getInSearchRadius(Object[] localK, T queryGeom) {
+    localK =
+        Arrays.stream(localK)
+            .filter(
+                candidate -> {
+                  Geometry candidateGeom = (Geometry) candidate;
+                  return distanceByMetric(queryGeom, candidateGeom, 
distanceMetric) <= searchRadius;
+                })
+            .toArray();
+    return localK;
+  }
+
+  /**
+   * This method calculates the distance between two geometries using the 
specified distance metric.
+   *
+   * @param queryGeom the query geometry
+   * @param candidateGeom the candidate geometry
+   * @param distanceMetric the distance metric to use
+   * @return the distance between the two geometries
+   */
+  public static double distanceByMetric(
+      Geometry queryGeom, Geometry candidateGeom, DistanceMetric 
distanceMetric) {
+    switch (distanceMetric) {
+      case EUCLIDEAN:
+        EuclideanItemDistance euclideanItemDistance = new 
EuclideanItemDistance();
+        return euclideanItemDistance.distance(queryGeom, candidateGeom);
+      case HAVERSINE:
+        HaversineItemDistance haversineItemDistance = new 
HaversineItemDistance();
+        return haversineItemDistance.distance(queryGeom, candidateGeom);
+      case SPHEROID:
+        SpheroidDistance spheroidDistance = new SpheroidDistance();
+        return spheroidDistance.distance(queryGeom, candidateGeom);
+      default:
+        return queryGeom.distance(candidateGeom);
+    }
+  }
+
+  private ItemDistance getItemDistance() {
+    ItemDistance itemDistance;
+    itemDistance = getItemDistanceByMetric(distanceMetric);
+    return itemDistance;
+  }
+
+  /**
+   * This method returns the ItemDistance object based on the specified 
distance metric.
+   *
+   * @param distanceMetric the distance metric to use
+   * @return the ItemDistance object
+   */
+  public static ItemDistance getItemDistanceByMetric(DistanceMetric 
distanceMetric) {
+    ItemDistance itemDistance;
+    switch (distanceMetric) {
+      case EUCLIDEAN:
+        itemDistance = new EuclideanItemDistance();
+        break;
+      case HAVERSINE:
+        itemDistance = new HaversineItemDistance();
+        break;
+      case SPHEROID:
+        itemDistance = new SpheroidDistance();
+        break;
+      default:
+        itemDistance = new GeometryItemDistance();
+        break;
+    }
+    return itemDistance;
   }
 
   private Object[] getUpdatedLocalKWithTies(T streamShape, Object[] localK, 
STRtree strTree) {
@@ -184,4 +281,18 @@ public class KnnJoinIndexJudgement<T extends Geometry, U 
extends Geometry>
     }
     return localK;
   }
+
+  public static <U extends Geometry, T extends Geometry> double distance(
+      U key, T value, DistanceMetric distanceMetric) {
+    switch (distanceMetric) {
+      case EUCLIDEAN:
+        return new EuclideanItemDistance().distance(key, value);
+      case HAVERSINE:
+        return new HaversineItemDistance().distance(key, value);
+      case SPHEROID:
+        return new SpheroidDistance().distance(key, value);
+      default:
+        return new EuclideanItemDistance().distance(key, value);
+    }
+  }
 }
diff --git 
a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/EuclideanItemDistance.java
 
b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/EuclideanItemDistance.java
index a27bf543b1..1aba8f87f7 100644
--- 
a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/EuclideanItemDistance.java
+++ 
b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/EuclideanItemDistance.java
@@ -36,4 +36,12 @@ public class EuclideanItemDistance implements ItemDistance {
       return g1.distance(g2);
     }
   }
+
+  public double distance(Geometry geometry1, Geometry geometry2) {
+    if (geometry1 == geometry2) {
+      return Double.MAX_VALUE;
+    } else {
+      return geometry1.distance(geometry2);
+    }
+  }
 }
diff --git 
a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/HaversineItemDistance.java
 
b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/HaversineItemDistance.java
index 9ad1bfbee4..b04627074e 100644
--- 
a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/HaversineItemDistance.java
+++ 
b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/HaversineItemDistance.java
@@ -37,4 +37,12 @@ public class HaversineItemDistance implements ItemDistance {
       return Haversine.distance(g1, g2);
     }
   }
+
+  public double distance(Geometry geometry1, Geometry geometry2) {
+    if (geometry1 == geometry2) {
+      return Double.MAX_VALUE;
+    } else {
+      return Haversine.distance(geometry1, geometry2);
+    }
+  }
 }
diff --git 
a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/SpheroidDistance.java
 
b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/SpheroidDistance.java
index df22d3565e..4ecdbf84c6 100644
--- 
a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/SpheroidDistance.java
+++ 
b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/SpheroidDistance.java
@@ -37,4 +37,12 @@ public class SpheroidDistance implements ItemDistance {
       return Spheroid.distance(g1, g2);
     }
   }
+
+  public double distance(Geometry geometry1, Geometry geometry2) {
+    if (geometry1 == geometry2) {
+      return Double.MAX_VALUE;
+    } else {
+      return Spheroid.distance(geometry1, geometry2);
+    }
+  }
 }
diff --git 
a/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
 
b/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
index d20563d279..a5665726e0 100644
--- 
a/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
+++ 
b/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java
@@ -18,10 +18,7 @@
  */
 package org.apache.sedona.core.spatialOperator;
 
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Objects;
+import java.util.*;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.log4j.LogManager;
 import org.apache.log4j.Logger;
@@ -35,15 +32,18 @@ import org.apache.sedona.core.monitoring.Metrics;
 import org.apache.sedona.core.spatialPartitioning.SpatialPartitioner;
 import org.apache.sedona.core.spatialRDD.CircleRDD;
 import org.apache.sedona.core.spatialRDD.SpatialRDD;
+import org.apache.sedona.core.wrapper.UniqueGeometry;
 import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.util.LongAccumulator;
 import org.locationtech.jts.geom.Geometry;
+import org.locationtech.jts.index.SpatialIndex;
 import org.locationtech.jts.index.strtree.STRtree;
 import scala.Tuple2;
 
@@ -784,47 +784,82 @@ public class JoinQuery {
     LongAccumulator resultCount = Metrics.createMetric(sparkContext, 
"resultCount");
     LongAccumulator candidateCount = Metrics.createMetric(sparkContext, 
"candidateCount");
 
-    final Broadcast<STRtree> broadcastedTreeIndex;
-    if (broadcastJoin) {
-      // adjust auto broadcast threshold to avoid building index on large RDDs
+    final Broadcast<STRtree> broadcastObjectsTreeIndex;
+    final Broadcast<List> broadcastQueryObjects;
+    if (broadcastJoin && objectRDD.indexedRawRDD != null && 
objectRDD.indexedRDD == null) {
+      // If broadcastJoin is true and rawIndex is created on object side
+      // we will broadcast queryRDD to objectRDD
+      List<UniqueGeometry<U>> uniqueQueryObjects = new ArrayList<>();
+      for (U queryObject : queryRDD.rawSpatialRDD.collect()) {
+        // Wrap the query objects in a UniqueGeometry object to account for 
duplicate queries in the
+        // join
+        uniqueQueryObjects.add(new UniqueGeometry<>(queryObject));
+      }
+      broadcastQueryObjects =
+          
JavaSparkContext.fromSparkContext(sparkContext).broadcast(uniqueQueryObjects);
+      broadcastObjectsTreeIndex = null;
+    } else if (broadcastJoin && objectRDD.indexedRawRDD == null && 
objectRDD.indexedRDD == null) {
+      // If broadcastJoin is true and index and rawIndex are NOT created on 
object side
+      // we will broadcast objectRDD to queryRDD
       STRtree strTree = objectRDD.coalesceAndBuildRawIndex(IndexType.RTREE);
-      broadcastedTreeIndex = 
JavaSparkContext.fromSparkContext(sparkContext).broadcast(strTree);
+      broadcastObjectsTreeIndex =
+          JavaSparkContext.fromSparkContext(sparkContext).broadcast(strTree);
+      broadcastQueryObjects = null;
     } else {
-      broadcastedTreeIndex = null;
+      // Regular join does not need to set broadcast index
+      broadcastQueryObjects = null;
+      broadcastObjectsTreeIndex = null;
     }
 
     // The reason for using objectRDD as the right side is that the partitions 
are built on the
     // right side.
     final JavaRDD<Pair<U, T>> joinResult;
-    if (objectRDD.indexedRDD != null) {
+    if (broadcastObjectsTreeIndex == null && broadcastQueryObjects == null) {
+      // no broadcast join
       final KnnJoinIndexJudgement judgement =
           new KnnJoinIndexJudgement(
               joinParams.k,
+              joinParams.searchRadius,
               joinParams.distanceMetric,
               includeTies,
-              broadcastedTreeIndex,
+              null,
+              null,
               buildCount,
               streamCount,
               resultCount,
               candidateCount);
       joinResult = 
queryRDD.spatialPartitionedRDD.zipPartitions(objectRDD.indexedRDD, judgement);
-    } else if (broadcastedTreeIndex != null) {
+    } else if (broadcastObjectsTreeIndex != null) {
+      // broadcast join with objectRDD as broadcast side
       final KnnJoinIndexJudgement judgement =
           new KnnJoinIndexJudgement(
               joinParams.k,
+              joinParams.searchRadius,
               joinParams.distanceMetric,
               includeTies,
-              broadcastedTreeIndex,
+              null,
+              broadcastObjectsTreeIndex,
               buildCount,
               streamCount,
               resultCount,
               candidateCount);
-      int numPartitionsObjects = objectRDD.rawSpatialRDD.getNumPartitions();
-      joinResult =
-          queryRDD
-              .rawSpatialRDD
-              .repartition(numPartitionsObjects)
-              .zipPartitions(objectRDD.rawSpatialRDD, judgement);
+      // won't need inputs from the shapes in the objectRDD
+      joinResult = 
queryRDD.rawSpatialRDD.zipPartitions(queryRDD.rawSpatialRDD, judgement);
+    } else if (broadcastQueryObjects != null) {
+      // broadcast join with queryRDD as broadcast side
+      final KnnJoinIndexJudgement judgement =
+          new KnnJoinIndexJudgement(
+              joinParams.k,
+              joinParams.searchRadius,
+              joinParams.distanceMetric,
+              includeTies,
+              broadcastQueryObjects,
+              null,
+              buildCount,
+              streamCount,
+              resultCount,
+              candidateCount);
+      joinResult = querySideBroadcastKNNJoin(objectRDD, joinParams, judgement, 
includeTies);
     } else {
       throw new IllegalArgumentException("No index found on the input RDDs.");
     }
@@ -833,6 +868,123 @@ public class JoinQuery {
         (PairFunction<Pair<U, T>, U, T>) pair -> new Tuple2<>(pair.getKey(), 
pair.getValue()));
   }
 
+  /**
+   * Performs a KNN join where the query side is broadcasted.
+   *
+   * <p>This function performs a K-Nearest Neighbors (KNN) join operation 
where the query geometries
+   * are broadcasted to all partitions of the object geometries.
+   *
+   * <p>The function first maps partitions of the indexed raw RDD to perform 
the KNN join, then
+   * groups the results by the query geometry and keeps the top K pair for 
each query geometry based
+   * on the distance.
+   *
+   * @param objectRDD The set of geometries (neighbors) to be queried.
+   * @param joinParams The parameters for the join, including index type, 
number of neighbors (k),
+   *     and distance metric.
+   * @param judgement The judgement function used to perform the KNN join.
+   * @param <U> The type of the geometries in the queryRDD set.
+   * @param <T> The type of the geometries in the objectRDD set.
+   * @return A JavaRDD of pairs where each pair contains a geometry from the 
queryRDD and a matching
+   *     geometry from the objectRDD.
+   */
+  private static <U extends Geometry, T extends Geometry>
+      JavaRDD<Pair<U, T>> querySideBroadcastKNNJoin(
+          SpatialRDD<T> objectRDD,
+          JoinParams joinParams,
+          KnnJoinIndexJudgement judgement,
+          boolean includeTies) {
+    final JavaRDD<Pair<U, T>> joinResult;
+    JavaRDD<Pair<U, T>> joinResultMapped =
+        objectRDD.indexedRawRDD.mapPartitions(
+            iterator -> {
+              List<Pair<U, T>> results = new ArrayList<>();
+              if (iterator.hasNext()) {
+                SpatialIndex spatialIndex = iterator.next();
+                // the broadcast join won't need inputs from the query's shape 
stream
+                Iterator<Pair<U, T>> callResult =
+                    judgement.call(null, 
Collections.singletonList(spatialIndex).iterator());
+                callResult.forEachRemaining(results::add);
+              }
+              return results.iterator();
+            });
+    // this is to avoid serializable issues with the broadcast variable
+    int k = joinParams.k;
+    DistanceMetric distanceMetric = joinParams.distanceMetric;
+
+    // Transform joinResultMapped to keep the top k pairs for each geometry
+    // (based on a grouping key and distance)
+    joinResult =
+        joinResultMapped
+            .groupBy(pair -> pair.getKey()) // Group by the first geometry
+            .flatMap(
+                (FlatMapFunction<Tuple2<U, Iterable<Pair<U, T>>>, Pair<U, T>>)
+                    pair -> {
+                      Iterable<Pair<U, T>> values = pair._2;
+
+                      // Extract and sort values by distance
+                      List<Pair<U, T>> sortedPairs = new ArrayList<>();
+                      for (Pair<U, T> p : values) {
+                        Pair<U, T> newPair =
+                            Pair.of(
+                                (U) ((UniqueGeometry<?>) 
p.getKey()).getOriginalGeometry(),
+                                p.getValue());
+                        sortedPairs.add(newPair);
+                      }
+
+                      // Sort pairs based on the distance function between the 
two geometries
+                      sortedPairs.sort(
+                          (p1, p2) -> {
+                            double distance1 =
+                                KnnJoinIndexJudgement.distance(
+                                    p1.getKey(), p1.getValue(), 
distanceMetric);
+                            double distance2 =
+                                KnnJoinIndexJudgement.distance(
+                                    p2.getKey(), p2.getValue(), 
distanceMetric);
+                            return Double.compare(
+                                distance1, distance2); // Sort ascending by 
distance
+                          });
+
+                      if (includeTies) {
+                        // Keep the top k pairs, including ties
+                        List<Pair<U, T>> topPairs = new ArrayList<>();
+                        double kthDistance = -1;
+                        for (int i = 0; i < sortedPairs.size(); i++) {
+                          if (i < k) {
+                            topPairs.add(sortedPairs.get(i));
+                            if (i == k - 1) {
+                              kthDistance =
+                                  KnnJoinIndexJudgement.distance(
+                                      sortedPairs.get(i).getKey(),
+                                      sortedPairs.get(i).getValue(),
+                                      distanceMetric);
+                            }
+                          } else {
+                            double currentDistance =
+                                KnnJoinIndexJudgement.distance(
+                                    sortedPairs.get(i).getKey(),
+                                    sortedPairs.get(i).getValue(),
+                                    distanceMetric);
+                            if (currentDistance == kthDistance) {
+                              topPairs.add(sortedPairs.get(i));
+                            } else {
+                              break;
+                            }
+                          }
+                        }
+                        return topPairs.iterator();
+                      } else {
+                        // Keep the top k pairs without ties
+                        List<Pair<U, T>> topPairs = new ArrayList<>();
+                        for (int i = 0; i < Math.min(k, sortedPairs.size()); 
i++) {
+                          topPairs.add(sortedPairs.get(i));
+                        }
+                        return topPairs.iterator();
+                      }
+                    });
+
+    return joinResult;
+  }
+
   public static final class JoinParams {
     public final boolean useIndex;
     public final SpatialPredicate spatialPredicate;
diff --git 
a/spark/common/src/main/java/org/apache/sedona/core/wrapper/UniqueGeometry.java 
b/spark/common/src/main/java/org/apache/sedona/core/wrapper/UniqueGeometry.java
new file mode 100644
index 0000000000..01f20f2fa6
--- /dev/null
+++ 
b/spark/common/src/main/java/org/apache/sedona/core/wrapper/UniqueGeometry.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.core.wrapper;
+
+import java.util.UUID;
+import org.apache.commons.lang3.NotImplementedException;
+import org.locationtech.jts.geom.*;
+
+public class UniqueGeometry<T> extends Geometry {
+  private final T originalGeometry;
+  private final String uniqueId;
+
+  public UniqueGeometry(T originalGeometry) {
+    super(new GeometryFactory());
+    this.originalGeometry = originalGeometry;
+    this.uniqueId = UUID.randomUUID().toString();
+  }
+
+  public T getOriginalGeometry() {
+    return originalGeometry;
+  }
+
+  public String getUniqueId() {
+    return uniqueId;
+  }
+
+  @Override
+  public int hashCode() {
+    return uniqueId.hashCode(); // Uniqueness ensured by uniqueId
+  }
+
+  @Override
+  public String getGeometryType() {
+    throw new NotImplementedException("getGeometryType is not implemented.");
+  }
+
+  @Override
+  public Coordinate getCoordinate() {
+    throw new NotImplementedException("getCoordinate is not implemented.");
+  }
+
+  @Override
+  public Coordinate[] getCoordinates() {
+    throw new NotImplementedException("getCoordinates is not implemented.");
+  }
+
+  @Override
+  public int getNumPoints() {
+    throw new NotImplementedException("getNumPoints is not implemented.");
+  }
+
+  @Override
+  public boolean isEmpty() {
+    throw new NotImplementedException("isEmpty is not implemented.");
+  }
+
+  @Override
+  public int getDimension() {
+    throw new NotImplementedException("getDimension is not implemented.");
+  }
+
+  @Override
+  public Geometry getBoundary() {
+    throw new NotImplementedException("getBoundary is not implemented.");
+  }
+
+  @Override
+  public int getBoundaryDimension() {
+    throw new NotImplementedException("getBoundaryDimension is not 
implemented.");
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) return true;
+    if (obj == null || getClass() != obj.getClass()) return false;
+    UniqueGeometry<?> that = (UniqueGeometry<?>) obj;
+    return uniqueId.equals(that.uniqueId);
+  }
+
+  @Override
+  public String toString() {
+    return "UniqueGeometry{"
+        + "originalGeometry="
+        + originalGeometry
+        + ", uniqueId='"
+        + uniqueId
+        + '\''
+        + '}';
+  }
+
+  @Override
+  protected Geometry reverseInternal() {
+    throw new NotImplementedException("reverseInternal is not implemented.");
+  }
+
+  @Override
+  public boolean equalsExact(Geometry geometry, double v) {
+    throw new NotImplementedException("equalsExact is not implemented.");
+  }
+
+  @Override
+  public void apply(CoordinateFilter coordinateFilter) {
+    throw new NotImplementedException("apply(CoordinateFilter) is not 
implemented.");
+  }
+
+  @Override
+  public void apply(CoordinateSequenceFilter coordinateSequenceFilter) {
+    throw new NotImplementedException("apply(CoordinateSequenceFilter) is not 
implemented.");
+  }
+
+  @Override
+  public void apply(GeometryFilter geometryFilter) {
+    throw new NotImplementedException("apply(GeometryFilter) is not 
implemented.");
+  }
+
+  @Override
+  public void apply(GeometryComponentFilter geometryComponentFilter) {
+    throw new NotImplementedException("apply(GeometryComponentFilter) is not 
implemented.");
+  }
+
+  @Override
+  protected Geometry copyInternal() {
+    throw new NotImplementedException("copyInternal is not implemented.");
+  }
+
+  @Override
+  public void normalize() {
+    throw new NotImplementedException("normalize is not implemented.");
+  }
+
+  @Override
+  protected Envelope computeEnvelopeInternal() {
+    throw new NotImplementedException("computeEnvelopeInternal is not 
implemented.");
+  }
+
+  @Override
+  protected int compareToSameClass(Object o) {
+    throw new NotImplementedException("compareToSameClass(Object) is not 
implemented.");
+  }
+
+  @Override
+  protected int compareToSameClass(
+      Object o, CoordinateSequenceComparator coordinateSequenceComparator) {
+    throw new NotImplementedException(
+        "compareToSameClass(Object, CoordinateSequenceComparator) is not 
implemented.");
+  }
+
+  @Override
+  protected int getTypeCode() {
+    throw new NotImplementedException("getTypeCode is not implemented.");
+  }
+}
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
index 001c0a1ca3..9ce40c6d42 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala
@@ -130,19 +130,10 @@ case class BroadcastQuerySideKNNJoinExec(
     require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.")
     objectsShapes.setNeighborSampleNumber(kValue)
 
-    val joinPartitions: Integer = numPartitions
-    broadcastJoin = false
-
-    // expand the boundary for partition to include both RDDs
-    objectsShapes.analyze()
-    queryShapes.analyze()
-    objectsShapes.boundaryEnvelope.expandToInclude(queryShapes.boundaryEnvelope)
-
-    objectsShapes.spatialPartitioning(GridType.QUADTREE_RTREE, joinPartitions)
-    queryShapes.spatialPartitioning(
-      objectsShapes.getPartitioner.asInstanceOf[QuadTreeRTPartitioner].nonOverlappedPartitioner())
-
-    objectsShapes.buildIndex(IndexType.RTREE, true)
+    // index the objects on regular partitions (not spatial partitions)
+    // this avoids the cost of spatial partitioning
+    objectsShapes.buildIndex(IndexType.RTREE, false)
+    broadcastJoin = true
   }
 
   /**
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index da9bd5359b..b89b1adeda 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -589,7 +589,14 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
     val leftShape = children.head
     val rightShape = children.tail.head
 
-    val querySide = getKNNQuerySide(left, leftShape)
+    val querySide = matchExpressionsToPlans(leftShape, rightShape, left, right) match {
+      case Some((_, _, false)) =>
+        LeftSide
+      case Some((_, _, true)) =>
+        RightSide
+      case None =>
+        Nil
+    }
     val objectSidePlan = if (querySide == LeftSide) right else left
 
     checkObjectPlanFilterPushdown(objectSidePlan)
@@ -722,7 +729,14 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
         val leftShape = children.head
         val rightShape = children.tail.head
 
-        val querySide = getKNNQuerySide(left, leftShape)
+        val querySide = matchExpressionsToPlans(leftShape, rightShape, left, right) match {
+          case Some((_, _, false)) =>
+            LeftSide
+          case Some((_, _, true)) =>
+            RightSide
+          case None =>
+            Nil
+        }
         val objectSidePlan = if (querySide == LeftSide) right else left
 
         checkObjectPlanFilterPushdown(objectSidePlan)
@@ -739,7 +753,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
             k = distance.get,
             useApproximate = false,
             spatialPredicate,
-            isGeography = false,
+            isGeography,
             condition = null,
             extraCondition = None) :: Nil
         } else {
@@ -754,7 +768,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
             k = distance.get,
             useApproximate = false,
             spatialPredicate,
-            isGeography = false,
+            isGeography,
             condition = null,
             extraCondition = None) :: Nil
         }
@@ -865,27 +879,6 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
     }
   }
 
-  /**
-   * Gets the query and object plans based on the left shape.
-   *
-   * This method checks if the left shape is part of the left or right plan and returns the query
-   * and object plans accordingly.
-   *
-   * @param leftShape
-   *   The left shape expression.
-   * @return
-   *   The join side where the left shape is located.
-   */
-  private def getKNNQuerySide(left: LogicalPlan, leftShape: Expression) = {
-    val isLeftQuerySide =
-      left.toString().toLowerCase().contains(leftShape.toString().toLowerCase())
-    if (isLeftQuerySide) {
-      LeftSide
-    } else {
-      RightSide
-    }
-  }
-
   /**
    * Check if the given condition is an equi-join between the given plans. This method basically
    * replicates the logic of


Reply via email to