This is an automated email from the ASF dual-hosted git repository. jiayu pushed a commit to branch prepare-1.7.2 in repository https://gitbox.apache.org/repos/asf/sedona.git
commit d930769adfaa69ec49f59deb92719325b67f6478 Author: Feng Zhang <[email protected]> AuthorDate: Thu Apr 10 09:52:58 2025 -0700 [SEDONA-690] Set default metric to use Haversine for KNN join and code refactoring (#1909) * [SEDONA-690] Set default metric to use Haversine for KNN join and some code refactor * fix unit tests * clean up join params --- .../joinJudgement/InMemoryKNNJoinIterator.java | 155 ++++++++++++++++ .../core/joinJudgement/KnnJoinIndexJudgement.java | 200 +++++++-------------- .../sedona/core/spatialOperator/JoinQuery.java | 169 ++++++++--------- .../join/BroadcastObjectSideKNNJoinExec.scala | 4 +- .../join/BroadcastQuerySideKNNJoinExec.scala | 4 +- .../sql/sedona_sql/strategy/join/KNNJoinExec.scala | 4 +- .../scala/org/apache/sedona/sql/KnnJoinSuite.scala | 28 +++ 7 files changed, 329 insertions(+), 235 deletions(-) diff --git a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/InMemoryKNNJoinIterator.java b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/InMemoryKNNJoinIterator.java new file mode 100644 index 0000000000..54ba42485e --- /dev/null +++ b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/InMemoryKNNJoinIterator.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.core.joinJudgement; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.NoSuchElementException; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.sedona.core.enums.DistanceMetric; +import org.apache.sedona.core.wrapper.UniqueGeometry; +import org.apache.spark.util.LongAccumulator; +import org.locationtech.jts.geom.Envelope; +import org.locationtech.jts.geom.Geometry; +import org.locationtech.jts.index.strtree.ItemDistance; +import org.locationtech.jts.index.strtree.STRtree; + +public class InMemoryKNNJoinIterator<T extends Geometry, U extends Geometry> + implements Iterator<Pair<T, U>> { + private final Iterator<T> querySideIterator; + private final STRtree strTree; + + private final int k; + private final DistanceMetric distanceMetric; + private final boolean includeTies; + private final ItemDistance itemDistance; + + private final LongAccumulator streamCount; + private final LongAccumulator resultCount; + + private final List<Pair<T, U>> currentResults = new ArrayList<>(); + private int currentResultIndex = 0; + + public InMemoryKNNJoinIterator( + Iterator<T> querySideIterator, + STRtree strTree, + int k, + DistanceMetric distanceMetric, + boolean includeTies, + LongAccumulator streamCount, + LongAccumulator resultCount) { + this.querySideIterator = querySideIterator; + this.strTree = strTree; + + this.k = k; + this.distanceMetric = distanceMetric; + this.includeTies = includeTies; + this.itemDistance = KnnJoinIndexJudgement.getItemDistance(distanceMetric); + + this.streamCount = streamCount; + this.resultCount = resultCount; + } + + @Override + public boolean hasNext() { + if (currentResultIndex < currentResults.size()) { + return true; + } + + currentResultIndex = 0; + currentResults.clear(); + while (querySideIterator.hasNext()) { + populateNextBatch(); + if (!currentResults.isEmpty()) { + return true; + } + } + + return false; + } + + @Override + public Pair<T, U> next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + return currentResults.get(currentResultIndex++); + } + + private void populateNextBatch() { + T queryItem = querySideIterator.next(); + Geometry queryGeom; + if (queryItem instanceof UniqueGeometry) { + queryGeom = (Geometry) ((UniqueGeometry<?>) queryItem).getOriginalGeometry(); + } else { + queryGeom = queryItem; + } + streamCount.add(1); + + Object[] localK = + strTree.nearestNeighbour(queryGeom.getEnvelopeInternal(), queryGeom, itemDistance, k); + if (includeTies) { + localK = getUpdatedLocalKWithTies(queryGeom, localK, strTree); + } + + for (Object obj : localK) { + U candidate = (U) obj; + Pair<T, U> pair = Pair.of(queryItem, candidate); + currentResults.add(pair); + resultCount.add(1); + } + } + + private Object[] getUpdatedLocalKWithTies( + Geometry streamShape, Object[] localK, STRtree strTree) { + Envelope searchEnvelope = streamShape.getEnvelopeInternal(); + // get the maximum distance from the k nearest neighbors + double maxDistance = 0.0; + LinkedHashSet<U> uniqueCandidates = new LinkedHashSet<>(); + for (Object obj : localK) { + U candidate = (U) obj; + uniqueCandidates.add(candidate); + double distance = streamShape.distance(candidate); + if (distance > maxDistance) { + maxDistance = distance; + } + } + searchEnvelope.expandBy(maxDistance); + List<U> candidates = strTree.query(searchEnvelope); + if (!candidates.isEmpty()) { + // update localK with all candidates that are within the maxDistance + List<Object> tiedResults = new ArrayList<>(); + // add all localK + Collections.addAll(tiedResults, localK); + + for (U candidate : candidates) { + double distance = streamShape.distance(candidate); + if (distance == maxDistance && !uniqueCandidates.contains(candidate)) { + tiedResults.add(candidate); + } + } + localK = tiedResults.toArray(); + } + return localK; + } +} diff --git a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java index f5375009ed..0dda586986 100644 --- a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java +++ b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java @@ -19,19 +19,18 @@ package org.apache.sedona.core.joinJudgement; import java.io.Serializable; -import java.util.*; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; import org.apache.commons.lang3.tuple.Pair; import org.apache.sedona.core.enums.DistanceMetric; import org.apache.sedona.core.knnJudgement.EuclideanItemDistance; import org.apache.sedona.core.knnJudgement.HaversineItemDistance; import org.apache.sedona.core.knnJudgement.SpheroidDistance; -import org.apache.sedona.core.wrapper.UniqueGeometry; import org.apache.spark.api.java.function.FlatMapFunction2; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.util.LongAccumulator; -import org.locationtech.jts.geom.Envelope; import org.locationtech.jts.geom.Geometry; -import org.locationtech.jts.index.SpatialIndex; import org.locationtech.jts.index.strtree.GeometryItemDistance; import org.locationtech.jts.index.strtree.ItemDistance; import org.locationtech.jts.index.strtree.STRtree; @@ -45,19 +44,17 @@ import org.locationtech.jts.index.strtree.STRtree; */ public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry> extends JudgementBase<T, U> - implements FlatMapFunction2<Iterator<T>, Iterator<SpatialIndex>, Pair<U, T>>, Serializable { + implements FlatMapFunction2<Iterator<T>, Iterator<U>, Pair<T, U>>, Serializable { private final int k; - private final Double searchRadius; private final DistanceMetric distanceMetric; private final boolean includeTies; - private final Broadcast<List> broadcastQueryObjects; + private final Broadcast<List<T>> broadcastQueryObjects; private final Broadcast<STRtree> broadcastObjectsTreeIndex; /** * Constructor for the KnnJoinIndexJudgement class. * * @param k the number of nearest neighbors to find - * @param searchRadius * @param distanceMetric the distance metric to use * @param broadcastQueryObjects the broadcast geometries on queries * @param broadcastObjectsTreeIndex the broadcast spatial index on objects @@ -68,10 +65,9 @@ public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry> */ public KnnJoinIndexJudgement( int k, - Double searchRadius, DistanceMetric distanceMetric, boolean includeTies, - Broadcast<List> broadcastQueryObjects, + Broadcast<List<T>> broadcastQueryObjects, Broadcast<STRtree> broadcastObjectsTreeIndex, LongAccumulator buildCount, LongAccumulator streamCount, @@ -79,7 +75,6 @@ public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry> LongAccumulator candidateCount) { super(null, buildCount, streamCount, resultCount, candidateCount); this.k = k; - this.searchRadius = searchRadius; this.distanceMetric = distanceMetric; this.includeTies = includeTies; this.broadcastQueryObjects = broadcastQueryObjects; @@ -91,15 +86,15 @@ public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry> * and uses the spatial index to find the k nearest neighbors for each geometry. The method * returns an iterator over the join results. * - * @param streamShapes iterator over the geometries in the stream side - * @param treeIndexes iterator over the spatial indexes + * @param queryShapes iterator over the geometries in the query side + * @param objectShapes iterator over the geometries in the object side * @return an iterator over the join results * @throws Exception if the spatial index is not of type STRtree */ @Override - public Iterator<Pair<U, T>> call(Iterator<T> streamShapes, Iterator<SpatialIndex> treeIndexes) + public Iterator<Pair<T, U>> call(Iterator<T> queryShapes, Iterator<U> objectShapes) throws Exception { - if (!treeIndexes.hasNext() || (streamShapes != null && !streamShapes.hasNext())) { + if (!objectShapes.hasNext() || (queryShapes != null && !queryShapes.hasNext())) { buildCount.add(0); streamCount.add(0); resultCount.add(0); @@ -107,91 +102,64 @@ public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry> return Collections.emptyIterator(); } - STRtree strTree; - if (broadcastObjectsTreeIndex != null) { - // get the broadcast spatial index on objects side if available - strTree = broadcastObjectsTreeIndex.getValue(); - } else { - // get the spatial index from the iterator - SpatialIndex treeIndex = treeIndexes.next(); - if (!(treeIndex instanceof STRtree)) { - throw new Exception( - "[KnnJoinIndexJudgement][Call] Only STRtree index supports KNN search."); - } - strTree = (STRtree) treeIndex; - } - - // TODO: For future improvement, instead of using a list to store the results, - // we can use lazy evaluation to avoid storing all the results in memory. - List<Pair<U, T>> result = new ArrayList<>(); - - List queryItems; - if (broadcastQueryObjects != null) { - // get the broadcast spatial index on queries side if available - queryItems = broadcastQueryObjects.getValue(); - for (Object item : queryItems) { - T queryGeom; - if (item instanceof UniqueGeometry) { - queryGeom = (T) ((UniqueGeometry) item).getOriginalGeometry(); - } else { - queryGeom = (T) item; - } - streamCount.add(1); - - Object[] localK = - strTree.nearestNeighbour( - queryGeom.getEnvelopeInternal(), queryGeom, getItemDistance(), k); - if (includeTies) { - localK = getUpdatedLocalKWithTies(queryGeom, localK, strTree); - } - if (searchRadius != null) { - localK = getInSearchRadius(localK, queryGeom); - } + STRtree strTree = buildSTRtree(objectShapes); + return new InMemoryKNNJoinIterator<>( + queryShapes, strTree, k, distanceMetric, includeTies, streamCount, resultCount); + } - for (Object obj : localK) { - T candidate = (T) obj; - Pair<U, T> pair = Pair.of((U) item, candidate); - result.add(pair); - resultCount.add(1); - } - } - return result.iterator(); - } else { - while (streamShapes.hasNext()) { - T streamShape = streamShapes.next(); - streamCount.add(1); + /** + * This method performs the KNN join operation using the broadcast spatial index built using all + * geometries in the object side. + * + * @param queryShapes iterator over the geometries in the query side + * @return an iterator over the join results + */ + public Iterator<Pair<T, U>> callUsingBroadcastObjectIndex(Iterator<T> queryShapes) { + if (!queryShapes.hasNext()) { + buildCount.add(0); + streamCount.add(0); + resultCount.add(0); + candidateCount.add(0); + return Collections.emptyIterator(); + } - Object[] localK = - strTree.nearestNeighbour( - streamShape.getEnvelopeInternal(), streamShape, getItemDistance(), k); - if (includeTies) { - localK = getUpdatedLocalKWithTies(streamShape, localK, strTree); - } - if (searchRadius != null) { - localK = getInSearchRadius(localK, streamShape); - } + // There's no need to use external spatial index, since the object side is small enough to be + // broadcasted, the STRtree built from the broadcasted object should be able to fit into memory. + STRtree strTree = broadcastObjectsTreeIndex.getValue(); + return new InMemoryKNNJoinIterator<>( + queryShapes, strTree, k, distanceMetric, includeTies, streamCount, resultCount); + } - for (Object obj : localK) { - T candidate = (T) obj; - Pair<U, T> pair = Pair.of((U) streamShape, candidate); - result.add(pair); - resultCount.add(1); - } - } - return result.iterator(); + /** + * This method performs the KNN join operation using the broadcast query geometries. + * + * @param objectShapes iterator over the geometries in the object side + * @return an iterator over the join results + */ + public Iterator<Pair<T, U>> callUsingBroadcastQueryList(Iterator<U> objectShapes) { + if (!objectShapes.hasNext()) { + buildCount.add(0); + streamCount.add(0); + resultCount.add(0); + candidateCount.add(0); + return Collections.emptyIterator(); } + + List<T> queryItems = broadcastQueryObjects.getValue(); + STRtree strTree = buildSTRtree(objectShapes); + return new InMemoryKNNJoinIterator<>( + queryItems.iterator(), strTree, k, distanceMetric, includeTies, streamCount, resultCount); } - private Object[] getInSearchRadius(Object[] localK, T queryGeom) { - localK = - Arrays.stream(localK) - .filter( - candidate -> { - Geometry candidateGeom = (Geometry) candidate; - return distanceByMetric(queryGeom, candidateGeom, distanceMetric) <= searchRadius; - }) - .toArray(); - return localK; + private STRtree buildSTRtree(Iterator<U> objectShapes) { + STRtree strTree = new STRtree(); + while (objectShapes.hasNext()) { + U spatialObject = objectShapes.next(); + strTree.insert(spatialObject.getEnvelopeInternal(), spatialObject); + buildCount.add(1); + } + strTree.build(); + return strTree; } /** @@ -219,12 +187,6 @@ public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry> } } - private ItemDistance getItemDistance() { - ItemDistance itemDistance; - itemDistance = getItemDistanceByMetric(distanceMetric); - return itemDistance; - } - /** * This method returns the ItemDistance object based on the specified distance metric. * @@ -250,38 +212,6 @@ public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry> return itemDistance; } - private Object[] getUpdatedLocalKWithTies(T streamShape, Object[] localK, STRtree strTree) { - Envelope searchEnvelope = streamShape.getEnvelopeInternal(); - // get the maximum distance from the k nearest neighbors - double maxDistance = 0.0; - LinkedHashSet<T> uniqueCandidates = new LinkedHashSet<>(); - for (Object obj : localK) { - T candidate = (T) obj; - uniqueCandidates.add(candidate); - double distance = streamShape.distance(candidate); - if (distance > maxDistance) { - maxDistance = distance; - } - } - searchEnvelope.expandBy(maxDistance); - List<T> candidates = strTree.query(searchEnvelope); - if (!candidates.isEmpty()) { - // update localK with all candidates that are within the maxDistance - List<Object> tiedResults = new ArrayList<>(); - // add all localK - Collections.addAll(tiedResults, localK); - - for (T candidate : candidates) { - double distance = streamShape.distance(candidate); - if (distance == maxDistance && !uniqueCandidates.contains(candidate)) { - tiedResults.add(candidate); - } - } - localK = tiedResults.toArray(); - } - return localK; - } - public static <U extends Geometry, T extends Geometry> double distance( U key, T value, DistanceMetric distanceMetric) { switch (distanceMetric) { @@ -295,4 +225,10 @@ public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry> return new EuclideanItemDistance().distance(key, value); } } + + public static ItemDistance getItemDistance(DistanceMetric distanceMetric) { + ItemDistance itemDistance; + itemDistance = getItemDistanceByMetric(distanceMetric); + return itemDistance; + } } diff --git a/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java b/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java index a5665726e0..7b55dd0763 100644 --- a/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java +++ b/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java @@ -37,13 +37,11 @@ import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.util.LongAccumulator; import org.locationtech.jts.geom.Geometry; -import org.locationtech.jts.index.SpatialIndex; import org.locationtech.jts.index.strtree.STRtree; import scala.Tuple2; @@ -401,7 +399,7 @@ public class JoinQuery { DistanceMetric distanceMetric) throws Exception { final JoinParams joinParams = - new JoinParams(true, null, IndexType.RTREE, null, k, distanceMetric, null); + new JoinParams(true, null, IndexType.RTREE, null, k, distanceMetric); final JavaPairRDD<U, T> joinResults = knnJoin(queryRDD, objectRDD, joinParams, false, false); return collectGeometriesByKey(joinResults); @@ -785,7 +783,7 @@ public class JoinQuery { LongAccumulator candidateCount = Metrics.createMetric(sparkContext, "candidateCount"); final Broadcast<STRtree> broadcastObjectsTreeIndex; - final Broadcast<List> broadcastQueryObjects; + final Broadcast<List<UniqueGeometry<U>>> broadcastQueryObjects; if (broadcastJoin && objectRDD.indexedRawRDD != null && objectRDD.indexedRDD == null) { // If broadcastJoin is true and rawIndex is created on object side // we will broadcast queryRDD to objectRDD @@ -816,10 +814,9 @@ public class JoinQuery { final JavaRDD<Pair<U, T>> joinResult; if (broadcastObjectsTreeIndex == null && broadcastQueryObjects == null) { // no broadcast join - final KnnJoinIndexJudgement judgement = - new KnnJoinIndexJudgement( + final KnnJoinIndexJudgement<U, T> judgement = + new KnnJoinIndexJudgement<>( joinParams.k, - joinParams.searchRadius, joinParams.distanceMetric, includeTies, null, @@ -828,13 +825,13 @@ public class JoinQuery { streamCount, resultCount, candidateCount); - joinResult = queryRDD.spatialPartitionedRDD.zipPartitions(objectRDD.indexedRDD, judgement); + joinResult = + queryRDD.spatialPartitionedRDD.zipPartitions(objectRDD.spatialPartitionedRDD, judgement); } else if (broadcastObjectsTreeIndex != null) { // broadcast join with objectRDD as broadcast side - final KnnJoinIndexJudgement judgement = - new KnnJoinIndexJudgement( + final KnnJoinIndexJudgement<U, T> judgement = + new KnnJoinIndexJudgement<>( joinParams.k, - joinParams.searchRadius, joinParams.distanceMetric, includeTies, null, @@ -844,13 +841,12 @@ public class JoinQuery { resultCount, candidateCount); // won't need inputs from the shapes in the objectRDD - joinResult = queryRDD.rawSpatialRDD.zipPartitions(queryRDD.rawSpatialRDD, judgement); - } else if (broadcastQueryObjects != null) { + joinResult = queryRDD.rawSpatialRDD.mapPartitions(judgement::callUsingBroadcastObjectIndex); + } else { // broadcast join with queryRDD as broadcast side - final KnnJoinIndexJudgement judgement = - new KnnJoinIndexJudgement( + final KnnJoinIndexJudgement<UniqueGeometry<U>, T> judgement = + new KnnJoinIndexJudgement<>( joinParams.k, - joinParams.searchRadius, joinParams.distanceMetric, includeTies, broadcastQueryObjects, @@ -860,8 +856,6 @@ public class JoinQuery { resultCount, candidateCount); joinResult = querySideBroadcastKNNJoin(objectRDD, joinParams, judgement, includeTies); - } else { - throw new IllegalArgumentException("No index found on the input RDDs."); } return joinResult.mapToPair( @@ -891,22 +885,11 @@ public class JoinQuery { JavaRDD<Pair<U, T>> querySideBroadcastKNNJoin( SpatialRDD<T> objectRDD, JoinParams joinParams, - KnnJoinIndexJudgement judgement, + KnnJoinIndexJudgement<UniqueGeometry<U>, T> judgement, boolean includeTies) { final JavaRDD<Pair<U, T>> joinResult; - JavaRDD<Pair<U, T>> joinResultMapped = - objectRDD.indexedRawRDD.mapPartitions( - iterator -> { - List<Pair<U, T>> results = new ArrayList<>(); - if (iterator.hasNext()) { - SpatialIndex spatialIndex = iterator.next(); - // the broadcast join won't need inputs from the query's shape stream - Iterator<Pair<U, T>> callResult = - judgement.call(null, Collections.singletonList(spatialIndex).iterator()); - callResult.forEachRemaining(results::add); - } - return results.iterator(); - }); + JavaRDD<Pair<UniqueGeometry<U>, T>> joinResultMapped = + objectRDD.rawSpatialRDD.mapPartitions(judgement::callUsingBroadcastQueryList); // this is to avoid serializable issues with the broadcast variable int k = joinParams.k; DistanceMetric distanceMetric = joinParams.distanceMetric; @@ -915,72 +898,67 @@ public class JoinQuery { // (based on a grouping key and distance) joinResult = joinResultMapped - .groupBy(pair -> pair.getKey()) // Group by the first geometry + .groupBy(pair -> pair.getKey().getUniqueId()) .flatMap( - (FlatMapFunction<Tuple2<U, Iterable<Pair<U, T>>>, Pair<U, T>>) - pair -> { - Iterable<Pair<U, T>> values = pair._2; - - // Extract and sort values by distance - List<Pair<U, T>> sortedPairs = new ArrayList<>(); - for (Pair<U, T> p : values) { - Pair<U, T> newPair = - Pair.of( - (U) ((UniqueGeometry<?>) p.getKey()).getOriginalGeometry(), - p.getValue()); - sortedPairs.add(newPair); - } - - // Sort pairs based on the distance function between the two geometries - sortedPairs.sort( - (p1, p2) -> { - double distance1 = - KnnJoinIndexJudgement.distance( - p1.getKey(), p1.getValue(), distanceMetric); - double distance2 = - KnnJoinIndexJudgement.distance( - p2.getKey(), p2.getValue(), distanceMetric); - return Double.compare( - distance1, distance2); // Sort ascending by distance - }); - - if (includeTies) { - // Keep the top k pairs, including ties - List<Pair<U, T>> topPairs = new ArrayList<>(); - double kthDistance = -1; - for (int i = 0; i < sortedPairs.size(); i++) { - if (i < k) { - topPairs.add(sortedPairs.get(i)); - if (i == k - 1) { - kthDistance = - KnnJoinIndexJudgement.distance( - sortedPairs.get(i).getKey(), - sortedPairs.get(i).getValue(), - distanceMetric); - } - } else { - double currentDistance = - KnnJoinIndexJudgement.distance( - sortedPairs.get(i).getKey(), - sortedPairs.get(i).getValue(), - distanceMetric); - if (currentDistance == kthDistance) { - topPairs.add(sortedPairs.get(i)); - } else { - break; - } - } + pair -> { + Iterable<Pair<UniqueGeometry<U>, T>> values = pair._2; + + // Extract and sort values by distance + List<Pair<U, T>> sortedPairs = new ArrayList<>(); + for (Pair<UniqueGeometry<U>, T> p : values) { + Pair<U, T> newPair = Pair.of(p.getKey().getOriginalGeometry(), p.getValue()); + sortedPairs.add(newPair); + } + + // Sort pairs based on the distance function between the two geometries + sortedPairs.sort( + (p1, p2) -> { + double distance1 = + KnnJoinIndexJudgement.distance( + p1.getKey(), p1.getValue(), distanceMetric); + double distance2 = + KnnJoinIndexJudgement.distance( + p2.getKey(), p2.getValue(), distanceMetric); + return Double.compare(distance1, distance2); // Sort ascending by distance + }); + + if (includeTies) { + // Keep the top k pairs, including ties + List<Pair<U, T>> topPairs = new ArrayList<>(); + double kthDistance = -1; + for (int i = 0; i < sortedPairs.size(); i++) { + if (i < k) { + topPairs.add(sortedPairs.get(i)); + if (i == k - 1) { + kthDistance = + KnnJoinIndexJudgement.distance( + sortedPairs.get(i).getKey(), + sortedPairs.get(i).getValue(), + distanceMetric); } - return topPairs.iterator(); } else { - // Keep the top k pairs without ties - List<Pair<U, T>> topPairs = new ArrayList<>(); - for (int i = 0; i < Math.min(k, sortedPairs.size()); i++) { + double currentDistance = + KnnJoinIndexJudgement.distance( + sortedPairs.get(i).getKey(), + sortedPairs.get(i).getValue(), + distanceMetric); + if (currentDistance == kthDistance) { topPairs.add(sortedPairs.get(i)); + } else { + break; } - return topPairs.iterator(); } - }); + } + return topPairs.iterator(); + } else { + // Keep the top k pairs without ties + List<Pair<U, T>> topPairs = new ArrayList<>(); + for (int i = 0; i < Math.min(k, sortedPairs.size()); i++) { + topPairs.add(sortedPairs.get(i)); + } + return topPairs.iterator(); + } + }); return joinResult; } @@ -994,14 +972,13 @@ public class JoinQuery { // KNN specific parameters public final int k; public final DistanceMetric distanceMetric; - public final Double searchRadius; public JoinParams( boolean useIndex, SpatialPredicate spatialPredicate, IndexType polygonIndexType, JoinBuildSide joinBuildSide) { - this(useIndex, spatialPredicate, polygonIndexType, joinBuildSide, -1, null, null); + this(useIndex, spatialPredicate, polygonIndexType, joinBuildSide, -1, null); } public JoinParams( @@ -1010,15 +987,13 @@ public class JoinQuery { IndexType polygonIndexType, JoinBuildSide joinBuildSide, int k, - DistanceMetric distanceMetric, - Double searchRadius) { + DistanceMetric distanceMetric) { this.useIndex = useIndex; this.spatialPredicate = spatialPredicate; this.indexType = polygonIndexType; this.joinBuildSide = joinBuildSide; this.k = k; this.distanceMetric = distanceMetric; - this.searchRadius = searchRadius; } public JoinParams(boolean useIndex, SpatialPredicate spatialPredicate) { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala index c5777be3c1..f4bdae40d5 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala @@ -138,9 +138,9 @@ case class BroadcastObjectSideKNNJoinExec( // Number of neighbors to find val kValue: Int = this.k.eval().asInstanceOf[Int] // Metric to use in the join to calculate the distance, only Euclidean and Spheroid are supported - val distanceMetric = if (isGeography) DistanceMetric.SPHEROID else DistanceMetric.EUCLIDEAN + val distanceMetric = if (isGeography) DistanceMetric.HAVERSINE else DistanceMetric.EUCLIDEAN val joinParams = - new JoinParams(true, null, IndexType.RTREE, null, kValue, distanceMetric, null) + new JoinParams(true, null, IndexType.RTREE, null, kValue, distanceMetric) joinParams } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala index 9ce40c6d42..575ded9125 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala @@ -149,9 +149,9 @@ case class BroadcastQuerySideKNNJoinExec( // Number of neighbors to find val kValue: Int = this.k.eval().asInstanceOf[Int] // Metric to use in the join to calculate the distance, only Euclidean and Spheroid are supported - val distanceMetric = if (isGeography) DistanceMetric.SPHEROID else DistanceMetric.EUCLIDEAN + val distanceMetric = if (isGeography) DistanceMetric.HAVERSINE else DistanceMetric.EUCLIDEAN val joinParams = - new JoinParams(true, null, IndexType.RTREE, null, kValue, distanceMetric, null) + new JoinParams(true, null, IndexType.RTREE, null, kValue, distanceMetric) joinParams } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala index fdc53d13ce..a879447d40 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala @@ -209,9 +209,9 @@ case class KNNJoinExec( // Number of neighbors to find val kValue: Int = this.k.eval().asInstanceOf[Int] // Metric to use in the join to calculate the distance, only Euclidean and Spheroid are supported - val distanceMetric = if (isGeography) DistanceMetric.SPHEROID else DistanceMetric.EUCLIDEAN + val distanceMetric = if (isGeography) DistanceMetric.HAVERSINE else DistanceMetric.EUCLIDEAN val joinParams = - new JoinParams(true, null, IndexType.RTREE, null, kValue, distanceMetric, null) + new JoinParams(true, null, IndexType.RTREE, null, kValue, distanceMetric) joinParams } } diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala index ab2c64898a..50696337f7 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala @@ -458,6 +458,34 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks { df2.cache() df1.join(df2, expr("ST_KNN(geom1, geom2, 1)")).count() should be(0) } + + it("KNN Join using spider data source") { + val dfRandomSquares = sparkSession.read + .format("spider") + .option("n", "10000") + .option("distribution", "parcel") + .option("dither", "0.5") + .option("splitRange", "0.5") + .load() + + dfRandomSquares.createOrReplaceTempView("df_random_squares") + + val dfRandomPoints = sparkSession.read + .format("spider") + .option("n", "1000") + .option("distribution", "uniform") + .load() + + dfRandomPoints.createOrReplaceTempView("df_random_points") + + // Execute a KNN join query: attribute points to the nearest square + val knnJoined = sparkSession.sql("""SELECT sq.id, pt.id + |FROM df_random_squares sq + |JOIN df_random_points pt + |ON ST_KNN(sq.geometry, pt.geometry, 1, TRUE)""".stripMargin) + + assert(knnJoined.count() > 0) + } } private def withOptimizationMode(mode: String)(body: => Unit): Unit = {
