This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new d1aed8531a [SEDONA-748] Fix issue with no optimization for weighting function (#2490)
d1aed8531a is described below
commit d1aed8531a12724705fd62de4f967c54dd4375ba
Author: Paweł Tokaj <[email protected]>
AuthorDate: Thu Nov 13 01:22:11 2025 +0100
[SEDONA-748] Fix issue with no optimization for weighting function (#2490)
---
.../scala/org/apache/sedona/stats/Weighting.scala | 30 ++++++++++------
.../org/apache/sedona/stats/WeightingTest.scala | 42 ++++++++++++++++++++--
2 files changed, 59 insertions(+), 13 deletions(-)
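
In short: the old code expressed neighbor matching as a single left self-join whose condition combined the spatial predicate with an ID inequality, a shape that was not rewritten into Sedona's DistanceJoin. The fix runs the spatial predicate on a plain inner self-join first, then re-attaches rows that found no neighbors via an equality-only left join. A minimal sketch of the pattern (hypothetical df, joinCondition, and idCol stand in for the real values; this is not the committed code verbatim):

    import org.apache.spark.sql.{Column, DataFrame}
    import org.apache.spark.sql.functions.{col, struct}

    def pairWithNeighbors(df: DataFrame, joinCondition: Column, idCol: String): DataFrame = {
      // Step 1: inner join carrying the spatial predicate; this is the shape
      // the optimizer can plan as a DistanceJoin.
      val pairs = df.alias("l")
        .join(df.alias("r"), joinCondition && col(s"l.$idCol") =!= col(s"r.$idCol"), "inner")
        .select(struct("l.*").alias("left"), struct("r.*").alias("right"))

      // Step 2: equality-only left join to bring back rows with no neighbors.
      df.alias("f")
        .join(pairs.alias("s"), col(s"s.left.$idCol") === col(s"f.$idCol"), "left")
    }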
diff --git a/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala b/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
index d404f2c2db..bd2ac8ed4a 100644
--- a/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/stats/Weighting.scala
@@ -109,30 +109,40 @@ object Weighting {
val formattedDataFrame = dataframe.withColumn(ID_COLUMN, sha2(to_json(struct("*")), 256))
- formattedDataFrame
+ val spatiallyJoined = formattedDataFrame
.alias("l")
.join(
formattedDataFrame.alias("r"),
- joinCondition && col(s"l.$ID_COLUMN") =!= col(
- s"r.$ID_COLUMN"
- ), // we will add self back later if self.includeSelf
+ joinCondition && col(s"l.$ID_COLUMN") =!= col(s"r.$ID_COLUMN"),
+ "inner")
+ .select(struct("l.*").alias("left"), struct("r.*").alias("right"))
+
+ val mapped = formattedDataFrame
+ .alias("f")
+ .join(
+ spatiallyJoined.alias("s"),
+ col(s"s.left.$ID_COLUMN") === col(s"f.$ID_COLUMN"),
"left")
.select(
- col(s"l.$ID_COLUMN"),
- struct("l.*").alias("left_contents"),
+ col(ID_COLUMN),
+ struct("f.*").alias("left_contents"),
struct(
(
savedAttributesWithGeom match {
- case null => struct(col("r.*")).dropFields(ID_COLUMN)
+ case null => struct(col("s.right.*")).dropFields(ID_COLUMN)
case _ =>
- struct(savedAttributesWithGeom.map(c => col(s"r.$c")): _*)
+ struct(savedAttributesWithGeom.map(c => col(s"s.right.$c")): _*)
}
).alias("neighbor"),
if (!binary)
- pow(distanceFunction(col(s"l.$geometryColumn"), col(s"r.$geometryColumn")), alpha)
+ pow(
+ distanceFunction(col(s"s.left.$geometryColumn"), col(s"s.right.$geometryColumn")),
+ alpha)
.alias("value")
else lit(1.0).alias("value")).alias("weight"))
- .groupBy(s"l.$ID_COLUMN")
+
+ mapped
+ .groupBy(ID_COLUMN)
.agg(
first("left_contents").alias("left_contents"),
concat(
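
For orientation before the test diff: the public entry point being optimized is Weighting.addDistanceBandColumn. A hedged usage sketch follows (the points DataFrame is illustrative rather than the test fixture, and sparkSession is assumed to be a Sedona-registered session):

    import org.apache.sedona.stats.Weighting
    import org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_MakePoint
    import org.apache.spark.sql.functions.col

    // Illustrative input: five points on a diagonal, one per row.
    val points = sparkSession.range(5).select(
      col("id"),
      ST_MakePoint(col("id").cast("double"), col("id").cast("double")).alias("geometry"))

    // Adds a `weights` array column listing each row's neighbors within distance 2.0.
    val weighted = Weighting.addDistanceBandColumn(points, 2.0, geometry = "geometry")
    weighted.explain() // with this fix, the plan should contain a DistanceJoin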
diff --git a/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala b/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala
index a7a8865dda..4fcdaa654d 100644
--- a/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/stats/WeightingTest.scala
@@ -22,6 +22,8 @@ import org.apache.sedona.sql.TestBaseScala
import org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_MakePoint
import org.apache.spark.sql.{DataFrame, Row, functions => f}
+import java.io.{ByteArrayOutputStream, PrintStream}
+
class WeightingTest extends TestBaseScala {
case class Neighbors(id: Int, neighbor: Seq[Int])
@@ -78,6 +80,9 @@ class WeightingTest extends TestBaseScala {
f.col("id"),
f.array_sort(
f.transform(f.col("weights"), w => w("neighbor")("id")).as("neighbor_ids")))
+
+ hasOptimizationTurnedOn(actualDf)
+
val expectedDf = sparkSession.createDataFrame(
Seq(
Neighbors(0, Seq(1, 3, 5, 7)),
@@ -97,6 +102,7 @@ class WeightingTest extends TestBaseScala {
it("return empty weights array when no neighbors") {
val actualDf = Weighting.addDistanceBandColumn(getData(), .9)
+ hasOptimizationTurnedOn(actualDf)
assert(actualDf.count() == 11)
assert(actualDf.filter(f.size(f.col("weights")) > 0).count() == 0)
@@ -113,10 +119,12 @@ class WeightingTest extends TestBaseScala {
f.col("id"),
f.transform(f.col("weights"), w => w("neighbor")("id")).as("neighbor_ids"))
+ hasOptimizationTurnedOn(actualDfWithZeroDistanceNeighbors)
+
assertDataFramesEqual(
actualDfWithZeroDistanceNeighbors,
sparkSession.createDataFrame(
- Seq(Neighbors(0, Seq(1, 2)), Neighbors(1, Seq(0, 2)), Neighbors(2, Seq(0, 1)))))
+ Seq(Neighbors(0, Seq(2, 1)), Neighbors(1, Seq(2, 0)), Neighbors(2, Seq(1, 0)))))
val actualDfWithoutZeroDistanceNeighbors = Weighting
.addDistanceBandColumn(getDupedData(), 1.1)
@@ -127,15 +135,18 @@ class WeightingTest extends TestBaseScala {
assertDataFramesEqual(
actualDfWithoutZeroDistanceNeighbors,
sparkSession.createDataFrame(
- Seq(Neighbors(0, Seq(2)), Neighbors(1, Seq(2)), Neighbors(2, Seq(0, 1)))))
+ Seq(Neighbors(0, Seq(2)), Neighbors(1, Seq(2)), Neighbors(2, Seq(1, 0)))))
}
it("adds binary weights") {
-
val result = Weighting.addDistanceBandColumn(getData(), 2.0, geometry = "geometry")
val weights = result.select("weights").collect().map(_.getSeq[Row](0))
+ hasOptimizationTurnedOn(result)
+
assert(weights.forall(_.forall(_.getAs[Double]("value") == 1.0)))
+
+ hasOptimizationTurnedOn(result)
}
it("adds non-binary weights when binary is false") {
@@ -148,6 +159,8 @@ class WeightingTest extends TestBaseScala {
geometry = "geometry")
val weights = result.select("weights").collect().map(_.getSeq[Row](0))
assert(weights.exists(_.exists(_.getAs[Double]("value") != 1.0)))
+
+ hasOptimizationTurnedOn(result)
}
it("throws IllegalArgumentException when threshold is negative") {
@@ -175,4 +188,27 @@ class WeightingTest extends TestBaseScala {
}
}
}
+
+ private def hasOptimizationTurnedOn(result: DataFrame) = {
+ val sparkPlan = captureStdOut(result.explain())
+
+ val distanceJoinOptimization = "DistanceJoin"
+
+ val occurrences =
+ sparkPlan.sliding(distanceJoinOptimization.length).count(_ == distanceJoinOptimization)
+
+ assert(occurrences == 1)
+ }
+
+ def captureStdOut(block: => Unit): String = {
+ val stream = new ByteArrayOutputStream()
+ val ps = new PrintStream(stream)
+
+ Console.withOut(ps) {
+ block
+ }
+
+ ps.flush()
+ stream.toString("UTF-8")
+ }
}
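
A closing note on captureStdOut: redirecting Console to capture explain() output works, but the same assertion can be approximated without stdout redirection by reading the physical plan string directly (a sketch using standard Spark APIs, not part of this commit):

    // Count DistanceJoin occurrences straight from the physical plan.
    val plan = result.queryExecution.executedPlan.toString
    assert("DistanceJoin".r.findAllIn(plan).size == 1)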