This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 03cb6709f [SEDONA-560] Properly handle dataframes containing 0 partitions when running spatial join (#1430)
03cb6709f is described below

commit 03cb6709fd2d2d1c72cec9d7661b2534bd282283
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Mon May 27 12:34:10 2024 +0800

    [SEDONA-560] Properly handle dataframes containing 0 partitions when running spatial join (#1430)
    
    * Refactored spatial join tests
    
    * Properly handle dataframes containing 0 partitions in spatial join
---
 .../strategy/join/SpatialIndexExec.scala           |  14 ++-
 .../strategy/join/TraitJoinQueryBase.scala         |   2 +-
 .../org/apache/sedona/sql/RasterJoinSuite.scala    |  69 ++++++-------
 .../org/apache/sedona/sql/SpatialJoinSuite.scala   | 108 +++++++++++++--------
 .../sedona/sql/SphereDistanceJoinSuite.scala       |  33 +++----
 .../org/apache/sedona/sql/TestBaseScala.scala      |  13 +++
 6 files changed, 142 insertions(+), 97 deletions(-)

diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
index 4955fb2ba..519af1d42 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
@@ -19,6 +19,7 @@
 package org.apache.spark.sql.sedona_sql.strategy.join
 
 import org.apache.sedona.core.enums.IndexType
+import org.apache.sedona.core.spatialRddTool.IndexBuilder
 
 import scala.jdk.CollectionConverters._
 import org.apache.spark.broadcast.Broadcast
@@ -28,6 +29,9 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, 
Expression, UnsafeRow}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.sedona_sql.execution.SedonaUnaryExecNode
+import org.locationtech.jts.geom.Geometry
+
+import java.util.Collections
 
 
 case class SpatialIndexExec(child: SparkPlan,
@@ -60,7 +64,15 @@ case class SpatialIndexExec(child: SparkPlan,
     }
 
     spatialRDD.buildIndex(indexType, false)
-    
sparkContext.broadcast(spatialRDD.indexedRawRDD.take(1).asScala.head).asInstanceOf[Broadcast[T]]
+    val spatialIndexes = spatialRDD.indexedRawRDD.take(1).asScala
+    val spatialIndex = if (spatialIndexes.nonEmpty) {
+      spatialIndexes.head
+    } else {
+      // The broadcasted dataframe contains 0 partition. In this case, we 
should provide an empty spatial index.
+      val indexBuilder = new IndexBuilder[Geometry](indexType)
+      indexBuilder.call(Collections.emptyIterator()).next()
+    }
+    sparkContext.broadcast(spatialIndex).asInstanceOf[Broadcast[T]]
   }
 
   protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = {
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
index 70868128e..98688915f 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
@@ -96,7 +96,7 @@ trait TraitJoinQueryBase {
 
   def doSpatialPartitioning(dominantShapes: SpatialRDD[Geometry], 
followerShapes: SpatialRDD[Geometry],
                             numPartitions: Integer, sedonaConf: SedonaConf): 
Unit = {
-    if (dominantShapes.approximateTotalCount > 0) {
+    if (dominantShapes.approximateTotalCount > 0 && numPartitions > 0) {
       dominantShapes.spatialPartitioning(sedonaConf.getJoinGridType, 
numPartitions)
       followerShapes.spatialPartitioning(dominantShapes.getPartitioner)
     }
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/RasterJoinSuite.scala 
b/spark/common/src/test/scala/org/apache/sedona/sql/RasterJoinSuite.scala
index 4c91aa28b..fb9de40e3 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/RasterJoinSuite.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/RasterJoinSuite.scala
@@ -29,7 +29,6 @@ import org.scalatest.prop.TableDrivenPropertyChecks
 class RasterJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
 
   private val spatialJoinPartitionSideConfKey = "sedona.join.spatitionside"
-  private val spatialJoinPartitionSide = 
sparkSession.sparkContext.getConf.get(spatialJoinPartitionSideConfKey, "left")
 
   private val rasters: Seq[(GridCoverage2D, Int)] = Seq(
     // Japan
@@ -178,60 +177,56 @@ class RasterJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
       ("df2 JOIN df1", "RS_Within(df2.geom, df1.rast)")
     )
 
-    try {
-      forAll(joinConditions) { case (joinClause, joinCondition) =>
-        val expected = buildExpectedResult(joinCondition)
-        it(s"$joinClause ON $joinCondition, with left side as dominant side") {
-          
sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, "left")
+    forAll(joinConditions) { case (joinClause, joinCondition) =>
+      val expected = buildExpectedResult(joinCondition)
+      it(s"$joinClause ON $joinCondition, with left side as dominant side") {
+        withConf(Map(spatialJoinPartitionSideConfKey -> "left")) {
           val result = sparkSession.sql(s"SELECT df1.id, df2.id FROM 
$joinClause ON $joinCondition")
           verifyResult(expected, result)
         }
-        it(s"$joinClause ON $joinCondition, with right side as dominant side") 
{
-          
sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, "right")
+      }
+      it(s"$joinClause ON $joinCondition, with right side as dominant side") {
+        withConf(Map(spatialJoinPartitionSideConfKey -> "right")) {
           val result = sparkSession.sql(s"SELECT df1.id, df2.id FROM 
$joinClause ON $joinCondition")
           verifyResult(expected, result)
         }
-        it(s"$joinClause ON $joinCondition, broadcast df1") {
-          val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ df1.id, 
df2.id FROM $joinClause ON $joinCondition")
-          verifyResult(expected, result)
-        }
-        it(s"$joinClause ON $joinCondition, broadcast df2") {
-          val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ df1.id, 
df2.id FROM $joinClause ON $joinCondition")
-          verifyResult(expected, result)
-        }
       }
-    } finally {
-      sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, 
spatialJoinPartitionSide)
+      it(s"$joinClause ON $joinCondition, broadcast df1") {
+        val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ df1.id, 
df2.id FROM $joinClause ON $joinCondition")
+        verifyResult(expected, result)
+      }
+      it(s"$joinClause ON $joinCondition, broadcast df2") {
+        val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ df1.id, 
df2.id FROM $joinClause ON $joinCondition")
+        verifyResult(expected, result)
+      }
     }
   }
 
   describe("raster-raster join") {
-    try {
-      val expected = rasters.flatMap { case (rast, id1) =>
-        rasters.flatMap { case (otherRast, id2) =>
-          if (RasterPredicates.rsIntersects(rast, otherRast)) Some((id1, id2)) 
else None
-        }
+    val expected = rasters.flatMap { case (rast, id1) =>
+      rasters.flatMap { case (otherRast, id2) =>
+        if (RasterPredicates.rsIntersects(rast, otherRast)) Some((id1, id2)) 
else None
       }
-      it("raster-raster join, with left side as dominant side") {
-        sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, 
"left")
+    }
+    it("raster-raster join, with left side as dominant side") {
+      withConf(Map(spatialJoinPartitionSideConfKey -> "left")) {
         val result = sparkSession.sql("SELECT df1.id, df3.id FROM df1 JOIN df3 
ON RS_Intersects(df1.rast, df3.rast)")
         verifyResult(expected, result)
       }
-      it("raster-raster join, with right side as dominant side") {
-        sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, 
"right")
+    }
+    it("raster-raster join, with right side as dominant side") {
+      withConf(Map(spatialJoinPartitionSideConfKey -> "right")) {
         val result = sparkSession.sql("SELECT df1.id, df3.id FROM df1 JOIN df3 
ON RS_Intersects(df1.rast, df3.rast)")
         verifyResult(expected, result)
       }
-      it("raster-raster join, broadcast left") {
-        val result = sparkSession.sql("SELECT /*+ BROADCAST(df1) */ df1.id, 
df3.id FROM df1 JOIN df3 ON RS_Intersects(df1.rast, df3.rast)")
-        verifyResult(expected, result)
-      }
-      it("raster-raster join, broadcast right") {
-        val result = sparkSession.sql("SELECT /*+ BROADCAST(df3) */ df1.id, 
df3.id FROM df1 JOIN df3 ON RS_Intersects(df1.rast, df3.rast)")
-        verifyResult(expected, result)
-      }
-    } finally {
-      sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, 
spatialJoinPartitionSide)
+    }
+    it("raster-raster join, broadcast left") {
+      val result = sparkSession.sql("SELECT /*+ BROADCAST(df1) */ df1.id, 
df3.id FROM df1 JOIN df3 ON RS_Intersects(df1.rast, df3.rast)")
+      verifyResult(expected, result)
+    }
+    it("raster-raster join, broadcast right") {
+      val result = sparkSession.sql("SELECT /*+ BROADCAST(df3) */ df1.id, 
df3.id FROM df1 JOIN df3 ON RS_Intersects(df1.rast, df3.rast)")
+      verifyResult(expected, result)
     }
   }
 
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala 
b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
index 952878ce9..e5867771e 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
@@ -19,12 +19,12 @@
 
 package org.apache.sedona.sql
 
-import org.apache.spark.sql.Column
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{Column, DataFrame, Row}
 import org.apache.spark.sql.functions.{col, expr}
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import 
org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_GeomFromText
 import org.apache.spark.sql.sedona_sql.strategy.join.{BroadcastIndexJoinExec, 
DistanceJoinExec, RangeJoinExec}
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.io.WKTReader
 import org.scalatest.prop.TableDrivenPropertyChecks
@@ -34,6 +34,11 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
   val testDataDelimiter = "\t"
   val spatialJoinPartitionSideConfKey = "sedona.join.spatitionside"
 
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    prepareTempViewsForTestData() // registers the df1, df2 and dfEmpty temp views once, before any test in this suite runs
+  }
+
   describe("Sedona-SQL Spatial Join Test") {
     val joinConditions = Table("join condition",
       "ST_Contains(df1.geom, df2.geom)",
@@ -70,39 +75,31 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
       "1.0 >= ST_Distance(df1.geom, df2.geom)"
     )
 
-    var spatialJoinPartitionSide = "left"
-    try {
-      spatialJoinPartitionSide = 
sparkSession.sparkContext.getConf.get(spatialJoinPartitionSideConfKey, "left")
-      forAll (joinConditions) { joinCondition =>
-        it(s"should join two dataframes with $joinCondition") {
-          
sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, "left")
-          prepareTempViewsForTestData()
+    forAll (joinConditions) { joinCondition =>
+      it(s"should join two dataframes with $joinCondition") {
+        withConf(Map(spatialJoinPartitionSideConfKey -> "left")) {
           val result = sparkSession.sql(s"SELECT df1.id, df2.id FROM df1 JOIN 
df2 ON $joinCondition")
           val expected = buildExpectedResult(joinCondition)
           verifyResult(expected, result)
         }
-        it(s"should join two dataframes with $joinCondition, with right side 
as dominant side") {
-          
sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, "right")
-          prepareTempViewsForTestData()
+      }
+      it(s"should join two dataframes with $joinCondition, with right side as 
dominant side") {
+        withConf(Map(spatialJoinPartitionSideConfKey -> "right")) {
           val result = sparkSession.sql(s"SELECT df1.id, df2.id FROM df1 JOIN 
df2 ON $joinCondition")
           val expected = buildExpectedResult(joinCondition)
           verifyResult(expected, result)
         }
-        it(s"should join two dataframes with $joinCondition, broadcast the 
left side") {
-          prepareTempViewsForTestData()
-          val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ df1.id, 
df2.id FROM df1 JOIN df2 ON $joinCondition")
-          val expected = buildExpectedResult(joinCondition)
-          verifyResult(expected, result)
-        }
-        it(s"should join two dataframes with $joinCondition, broadcast the 
right side") {
-          prepareTempViewsForTestData()
-          val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ df1.id, 
df2.id FROM df1 JOIN df2 ON $joinCondition")
-          val expected = buildExpectedResult(joinCondition)
-          verifyResult(expected, result)
-        }
       }
-    } finally {
-      sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, 
spatialJoinPartitionSide)
+      it(s"should join two dataframes with $joinCondition, broadcast the left 
side") {
+        val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ df1.id, 
df2.id FROM df1 JOIN df2 ON $joinCondition")
+        val expected = buildExpectedResult(joinCondition)
+        verifyResult(expected, result)
+      }
+      it(s"should join two dataframes with $joinCondition, broadcast the right 
side") {
+        val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ df1.id, 
df2.id FROM df1 JOIN df2 ON $joinCondition")
+        val expected = buildExpectedResult(joinCondition)
+        verifyResult(expected, result)
+      }
     }
   }
 
@@ -116,7 +113,6 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
 
     forAll (joinConditions) { joinCondition =>
       it(s"should SELECT * in join query with $joinCondition produce correct 
result") {
-        prepareTempViewsForTestData()
         val resultAll = sparkSession.sql(s"SELECT * FROM df1 JOIN df2 ON 
$joinCondition").collect()
         val result = resultAll.map(row => (row.getInt(0), 
row.getInt(3))).sorted
         val expected = buildExpectedResult(joinCondition)
@@ -125,7 +121,6 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
       }
 
       it(s"should SELECT * in join query with $joinCondition produce correct 
result, broadcast the left side") {
-        prepareTempViewsForTestData()
         val resultAll = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ * FROM 
df1 JOIN df2 ON $joinCondition").collect()
         val result = resultAll.map(row => (row.getInt(0), 
row.getInt(3))).sorted
         val expected = buildExpectedResult(joinCondition)
@@ -134,7 +129,6 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
       }
 
       it(s"should SELECT * in join query with $joinCondition produce correct 
result, broadcast the right side") {
-        prepareTempViewsForTestData()
         val resultAll = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ * FROM 
df1 JOIN df2 ON $joinCondition").collect()
         val result = resultAll.map(row => (row.getInt(0), 
row.getInt(3))).sorted
         val expected = buildExpectedResult(joinCondition)
@@ -147,7 +141,6 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
   describe("Spatial join in Sedona SQL should be configurable using 
sedona.join.optimizationmode") {
     it("Optimize all spatial joins when sedona.join.optimizationmode = all") {
       withOptimizationMode("all") {
-        prepareTempViewsForTestData()
         val df = sparkSession.sql("SELECT df1.id, df2.id FROM df1 JOIN df2 ON 
df1.id = df2.id AND ST_Intersects(df1.geom, df2.geom)")
         assert(isUsingOptimizedSpatialJoin(df))
         val expectedResult = buildExpectedResult("ST_Intersects(df1.geom, 
df2.geom)")
@@ -158,7 +151,6 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
 
     it("Only optimize non-equi-joins when sedona.join.optimizationmode = 
nonequi") {
       withOptimizationMode("nonequi") {
-        prepareTempViewsForTestData()
         val df = sparkSession.sql("SELECT df1.id, df2.id FROM df1 JOIN df2 ON 
ST_Intersects(df1.geom, df2.geom)")
         assert(isUsingOptimizedSpatialJoin(df))
         val df2 = sparkSession.sql("SELECT df1.id, df2.id FROM df1 JOIN df2 ON 
df1.id = df2.id AND ST_Intersects(df1.geom, df2.geom)")
@@ -168,7 +160,6 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
 
     it("Won't optimize spatial joins when sedona.join.optimizationmode = 
none") {
       withOptimizationMode("none") {
-        prepareTempViewsForTestData()
         val df = sparkSession.sql("SELECT df1.id, df2.id FROM df1 JOIN df2 ON 
ST_Intersects(df1.geom, df2.geom)")
         assert(!isUsingOptimizedSpatialJoin(df))
       }
@@ -191,14 +182,46 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
     }
   }
 
-  private def withOptimizationMode(mode: String)(body: => Unit) : Unit = {
-    val oldOptimizationMode = 
sparkSession.conf.get("sedona.join.optimizationmode", "nonequi")
-    try {
-      sparkSession.conf.set("sedona.join.optimizationmode", mode)
-      body
-    } finally {
-      sparkSession.conf.set("sedona.join.optimizationmode", 
oldOptimizationMode)
+  describe("Spatial join should work with dataframe containing 0 partitions") {
+    // Each query joins against dfEmpty (built from an empty RDD, i.e. 0 partitions), so the result must be empty.
+    // Plain and BROADCAST-hinted plans are covered, with dfEmpty on either side of the join.
+    val queries = Table("join queries",
+      "SELECT * FROM df1 JOIN dfEmpty WHERE ST_Intersects(df1.geom, dfEmpty.geom)",
+      "SELECT * FROM dfEmpty JOIN df1 WHERE ST_Intersects(df1.geom, dfEmpty.geom)",
+      "SELECT /*+ BROADCAST(df1) */ * FROM df1 JOIN dfEmpty WHERE ST_Intersects(df1.geom, dfEmpty.geom)",
+      "SELECT /*+ BROADCAST(dfEmpty) */ * FROM df1 JOIN dfEmpty WHERE ST_Intersects(df1.geom, dfEmpty.geom)",
+      "SELECT /*+ BROADCAST(df1) */ * FROM dfEmpty JOIN df1 WHERE ST_Intersects(df1.geom, dfEmpty.geom)",
+      "SELECT /*+ BROADCAST(dfEmpty) */ * FROM dfEmpty JOIN df1 WHERE ST_Intersects(df1.geom, dfEmpty.geom)")
+
+    forAll (queries) { query =>
+      it(s"Legacy join: $query") {
+        // Run under both values of the partition-side conf (which side is the join's dominant side).
+        Seq("left", "right").foreach { side =>
+          withConf(Map(spatialJoinPartitionSideConfKey -> side)) {
+            assert(sparkSession.sql(query).collect().isEmpty)
+          }
+        }
+      }
     }
+
+    it("non-empty dataframe has lots of partitions") {
+      // 4 rows repartitioned into 10 partitions, so several partitions of df10parts are empty as well.
+      val df = sparkSession.range(0, 4).toDF("id").withColumn("geom", expr("ST_Point(id, id)")).repartition(10)
+      df.createOrReplaceTempView("df10parts")
+
+      val query = "SELECT * FROM df10parts JOIN dfEmpty WHERE ST_Intersects(df10parts.geom, dfEmpty.geom)"
+      try {
+        Seq("left", "right").foreach { side =>
+          withConf(Map(spatialJoinPartitionSideConfKey -> side)) {
+            assert(sparkSession.sql(query).collect().isEmpty)
+          }
+        }
+      } finally sparkSession.catalog.dropTempView("df10parts") // don't leak the view into later tests
+    }
+  }
+
+  private def withOptimizationMode(mode: String)(body: => Unit) : Unit = {
+    withConf(Map("sedona.join.optimizationmode" -> mode))(body) // evaluate body with sedona.join.optimizationmode set to mode
   }
 
   private def prepareTempViewsForTestData(): (DataFrame, DataFrame) = {
@@ -214,8 +237,13 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
       .withColumn("geom", ST_GeomFromText(new Column("_c2")))
       .select("id", "geom")
       .withColumn("dist", expr("ST_Area(geom)"))
+    val emptyRdd = sparkSession.sparkContext.emptyRDD[Row]
+    val emptyDf = sparkSession.createDataFrame(emptyRdd, StructType(Seq(
+      StructField("id", IntegerType), StructField("geom", GeometryUDT)
+    )))
     df1.createOrReplaceTempView("df1")
     df2.createOrReplaceTempView("df2")
+    emptyDf.createOrReplaceTempView("dfEmpty")
     (df1, df2)
   }
 
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/SphereDistanceJoinSuite.scala
 
b/spark/common/src/test/scala/org/apache/sedona/sql/SphereDistanceJoinSuite.scala
index ca78c6f4f..335e89e6a 100644
--- 
a/spark/common/src/test/scala/org/apache/sedona/sql/SphereDistanceJoinSuite.scala
+++ 
b/spark/common/src/test/scala/org/apache/sedona/sql/SphereDistanceJoinSuite.scala
@@ -28,7 +28,6 @@ import scala.util.Random
 
 class SphereDistanceJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
   private val spatialJoinPartitionSideConfKey = "sedona.join.spatitionside"
-  private val spatialJoinPartitionSide = 
sparkSession.sparkContext.getConf.get(spatialJoinPartitionSideConfKey, "left")
 
   private val testData1: Seq[(Int, Double, Geometry)] = generateTestData()
   private val testData2: Seq[(Int, Double, Geometry)] = generateTestData()
@@ -56,30 +55,28 @@ class SphereDistanceJoinSuite extends TestBaseScala with 
TableDrivenPropertyChec
       "ST_DistanceSpheroid(df2.geom, df1.geom) < df2.dist"
     )
 
-    try {
-      forAll(joinConditions) { joinCondition =>
-        val expected = buildExpectedResult(joinCondition)
-        it(s"sphere distance join ON $joinCondition, with left side as 
dominant side") {
-          
sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, "left")
+    forAll(joinConditions) { joinCondition =>
+      val expected = buildExpectedResult(joinCondition)
+      it(s"sphere distance join ON $joinCondition, with left side as dominant 
side") {
+        withConf(Map(spatialJoinPartitionSideConfKey -> "left")) {
           val result = sparkSession.sql(s"SELECT df1.id, df2.id FROM df1 JOIN 
df2 ON $joinCondition")
           verifyResult(expected, result)
         }
-        it(s"sphere distance join ON $joinCondition, with right side as 
dominant side") {
-          
sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, "right")
+      }
+      it(s"sphere distance join ON $joinCondition, with right side as dominant 
side") {
+        withConf(Map(spatialJoinPartitionSideConfKey -> "right")) {
           val result = sparkSession.sql(s"SELECT df1.id, df2.id FROM df1 JOIN 
df2 ON $joinCondition")
           verifyResult(expected, result)
         }
-        it(s"sphere distance ON $joinCondition, broadcast df1") {
-          val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ df1.id, 
df2.id FROM df1 JOIN df2 ON $joinCondition")
-          verifyResult(expected, result)
-        }
-        it(s"sphere distance ON $joinCondition, broadcast df2") {
-          val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ df1.id, 
df2.id FROM df1 JOIN df2 ON $joinCondition")
-          verifyResult(expected, result)
-        }
       }
-    } finally {
-      sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, 
spatialJoinPartitionSide)
+      it(s"sphere distance ON $joinCondition, broadcast df1") {
+        val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ df1.id, 
df2.id FROM df1 JOIN df2 ON $joinCondition")
+        verifyResult(expected, result)
+      }
+      it(s"sphere distance ON $joinCondition, broadcast df2") {
+        val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ df1.id, 
df2.id FROM df1 JOIN df2 ON $joinCondition")
+        verifyResult(expected, result)
+      }
     }
   }
 
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala 
b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 4c50bc3c0..eaace3a32 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -143,6 +143,19 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll 
{
     }
   }
 
+  /** Runs `f` with the given Spark SQL confs set, then restores every key in `conf` to its
+   *  previous value (unsetting keys that were not set before), even if `f` throws. */
+  def withConf[T](conf: Map[String, String])(f: => T): T = {
+    val oldConf = conf.keys.map(key => key -> sparkSession.conf.getOption(key))
+    try {
+      conf.foreach { case (key, value) => sparkSession.conf.set(key, value) }
+      f
+    } finally oldConf.foreach {
+      case (key, Some(v)) => sparkSession.conf.set(key, v)
+      case (key, None) => sparkSession.conf.unset(key)
+    }
+  }
+
   protected def bruteForceDistanceJoinCountSpheroid(sampleCount:Int, distance: 
Double): Int = {
     val input = buildPointLonLatDf.limit(sampleCount).collect()
     input.map(row => {

Reply via email to