This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch fix-performance-issue-with-weighting
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit f012d4200e9706e0a64e0cf1a35ad07c4a6cc5e9
Author: pawelkocinski <[email protected]>
AuthorDate: Tue Nov 11 14:22:19 2025 +0100

    SEDONA-748 Fix missing spatial join optimization in the weighting function.
---
 .../scala/org/apache/sedona/stats/Weighting.scala  | 50 ++++++++++++++--------
 .../org/apache/sedona/stats/WeightingTest.scala    | 40 ++++++++++++++--
 2 files changed, 68 insertions(+), 22 deletions(-)
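
In brief, the fix reshapes the join: the spatial predicate now drives a plain inner self-join, and a separate equi-join on the hashed row id adds back the rows that have no neighbors. A minimal sketch of that pattern follows; addNeighbors and row_id are hypothetical names standing in for the commit's helpers, with Sedona's distance predicate (built against the l/r aliases) passed in as joinCondition:

    import org.apache.spark.sql.{Column, DataFrame}
    import org.apache.spark.sql.functions.{col, sha2, struct, to_json}

    // Illustrative sketch; the commit derives the id the same way
    // (sha2 over the row's JSON) and supplies joinCondition itself.
    def addNeighbors(df: DataFrame, joinCondition: Column): DataFrame = {
      val withId = df.withColumn("row_id", sha2(to_json(struct("*")), 256))

      // Plain inner spatial join: a shape the optimizer can plan as a DistanceJoin.
      val pairs = withId
        .alias("l")
        .join(withId.alias("r"), joinCondition && col("l.row_id") =!= col("r.row_id"), "inner")
        .select(struct("l.*").alias("left"), struct("r.*").alias("right"))

      // Separate equi left join, so rows without neighbors survive (null right side)
      // without forcing the spatial join itself to be an outer join.
      withId
        .alias("f")
        .join(pairs.alias("s"), col("s.left.row_id") === col("f.row_id"), "left")
    }

The earlier single left self-join evidently did not get this rewrite; the test changes below pin the optimized plan by asserting that exactly one DistanceJoin operator appears in it.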

diff --git a/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala b/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
index d404f2c2db..aaad1b007b 100644
--- a/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
@@ -109,30 +109,42 @@ object Weighting {
 
     val formattedDataFrame = dataframe.withColumn(ID_COLUMN, sha2(to_json(struct("*")), 256))
 
-    formattedDataFrame
+    val spatiallyJoined = formattedDataFrame
       .alias("l")
       .join(
         formattedDataFrame.alias("r"),
-        joinCondition && col(s"l.$ID_COLUMN") =!= col(
-          s"r.$ID_COLUMN"
-        ), // we will add self back later if self.includeSelf
+        joinCondition && col(s"l.$ID_COLUMN") =!= col(s"r.$ID_COLUMN"),
+        "inner")
+      .select(struct("l.*").alias("left"), struct("r.*").alias("right"))
+
+    val mapped = formattedDataFrame
+      .alias("f")
+      .join(
+        spatiallyJoined.alias("s"),
+        col(s"s.left.$ID_COLUMN") === col(s"f.$ID_COLUMN"),
         "left")
       .select(
-        col(s"l.$ID_COLUMN"),
-        struct("l.*").alias("left_contents"),
-        struct(
-          (
-            savedAttributesWithGeom match {
-              case null => struct(col("r.*")).dropFields(ID_COLUMN)
-              case _ =>
-                struct(savedAttributesWithGeom.map(c => col(s"r.$c")): _*)
-            }
-          ).alias("neighbor"),
-          if (!binary)
-            pow(distanceFunction(col(s"l.$geometryColumn"), col(s"r.$geometryColumn")), alpha)
-              .alias("value")
-          else lit(1.0).alias("value")).alias("weight"))
-      .groupBy(s"l.$ID_COLUMN")
+        col(ID_COLUMN),
+        struct("f.*").alias("left_contents"),
+        when(col(ID_COLUMN).isNull, lit(null))
+          .otherwise(struct(
+            (
+              savedAttributesWithGeom match {
+                case null => struct(col("s.right.*")).dropFields(ID_COLUMN)
+                case _ =>
+                  struct(savedAttributesWithGeom.map(c => col(s"s.right.$c")): _*)
+              }
+            ).alias("neighbor"),
+            if (!binary)
+              pow(
+                distanceFunction(col(s"s.left.$geometryColumn"), col(s"s.right.$geometryColumn")),
+                alpha)
+                .alias("value")
+            else lit(1.0).alias("value")))
+          .alias("weight"))
+
+    mapped
+      .groupBy(ID_COLUMN)
       .agg(
         first("left_contents").alias("left_contents"),
         concat(
diff --git a/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala b/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala
index a7a8865dda..4fcdaa654d 100644
--- a/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala
@@ -22,6 +22,8 @@ import org.apache.sedona.sql.TestBaseScala
 import org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_MakePoint
 import org.apache.spark.sql.{DataFrame, Row, functions => f}
 
+import java.io.{ByteArrayOutputStream, PrintStream}
+
 class WeightingTest extends TestBaseScala {
 
   case class Neighbors(id: Int, neighbor: Seq[Int])
@@ -78,6 +80,9 @@ class WeightingTest extends TestBaseScala {
           f.col("id"),
           f.array_sort(
            f.transform(f.col("weights"), w => w("neighbor")("id")).as("neighbor_ids")))
+
+      hasOptimizationTurnedOn(actualDf)
+
       val expectedDf = sparkSession.createDataFrame(
         Seq(
           Neighbors(0, Seq(1, 3, 5, 7)),
@@ -97,6 +102,7 @@ class WeightingTest extends TestBaseScala {
 
     it("return empty weights array when no neighbors") {
       val actualDf = Weighting.addDistanceBandColumn(getData(), .9)
+      hasOptimizationTurnedOn(actualDf)
 
       assert(actualDf.count() == 11)
       assert(actualDf.filter(f.size(f.col("weights")) > 0).count() == 0)
@@ -113,10 +119,12 @@ class WeightingTest extends TestBaseScala {
           f.col("id"),
          f.transform(f.col("weights"), w => w("neighbor")("id")).as("neighbor_ids"))
 
+      hasOptimizationTurnedOn(actualDfWithZeroDistanceNeighbors)
+
       assertDataFramesEqual(
         actualDfWithZeroDistanceNeighbors,
         sparkSession.createDataFrame(
-          Seq(Neighbors(0, Seq(1, 2)), Neighbors(1, Seq(0, 2)), Neighbors(2, Seq(0, 1)))))
+          Seq(Neighbors(0, Seq(2, 1)), Neighbors(1, Seq(2, 0)), Neighbors(2, Seq(1, 0)))))
 
       val actualDfWithoutZeroDistanceNeighbors = Weighting
         .addDistanceBandColumn(getDupedData(), 1.1)
@@ -127,15 +135,16 @@ class WeightingTest extends TestBaseScala {
       assertDataFramesEqual(
         actualDfWithoutZeroDistanceNeighbors,
         sparkSession.createDataFrame(
-          Seq(Neighbors(0, Seq(2)), Neighbors(1, Seq(2)), Neighbors(2, Seq(0, 1)))))
+          Seq(Neighbors(0, Seq(2)), Neighbors(1, Seq(2)), Neighbors(2, Seq(1, 0)))))
 
     }
 
     it("adds binary weights") {
-
       val result = Weighting.addDistanceBandColumn(getData(), 2.0, geometry = "geometry")
       val weights = result.select("weights").collect().map(_.getSeq[Row](0))
+      hasOptimizationTurnedOn(result)
+
       assert(weights.forall(_.forall(_.getAs[Double]("value") == 1.0)))
     }
 
     it("adds non-binary weights when binary is false") {
@@ -148,6 +157,8 @@ class WeightingTest extends TestBaseScala {
         geometry = "geometry")
       val weights = result.select("weights").collect().map(_.getSeq[Row](0))
       assert(weights.exists(_.exists(_.getAs[Double]("value") != 1.0)))
+
+      hasOptimizationTurnedOn(result)
     }
 
     it("throws IllegalArgumentException when threshold is negative") {
@@ -175,4 +186,27 @@ class WeightingTest extends TestBaseScala {
       }
     }
   }
+
+  private def hasOptimizationTurnedOn(result: DataFrame): Unit = {
+    val sparkPlan = captureStdOut(result.explain())
+
+    val distanceJoinOptimization = "DistanceJoin"
+
+    val occurrences =
+      sparkPlan.sliding(distanceJoinOptimization.length).count(_ == distanceJoinOptimization)
+
+    assert(occurrences == 1)
+  }
+
+  def captureStdOut(block: => Unit): String = {
+    val stream = new ByteArrayOutputStream()
+    val ps = new PrintStream(stream)
+
+    Console.withOut(ps) {
+      block
+    }
+
+    ps.flush()
+    stream.toString("UTF-8")
+  }
 }
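
The new hasOptimizationTurnedOn helper verifies the plan by capturing what DataFrame.explain() prints to stdout and counting occurrences of the DistanceJoin operator name; counting, rather than a simple contains check, also guards against a plan that runs the spatial join twice. The same check can be written without redirecting stdout, since Spark exposes the planned query directly; a sketch using the public queryExecution handle, where distanceJoinCount is a hypothetical name:

    import org.apache.spark.sql.DataFrame

    // Count DistanceJoin operators by reading the physical plan text directly
    // from df.queryExecution.executedPlan instead of capturing explain() output.
    def distanceJoinCount(df: DataFrame): Int =
      "DistanceJoin".r.findAllMatchIn(df.queryExecution.executedPlan.toString).size

    // Usage mirroring the test's assertion:
    //   assert(distanceJoinCount(result) == 1)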
