This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 03cb6709f [SEDONA-560] Properly handle dataframes containing 0
partitions when running spatial join (#1430)
03cb6709f is described below
commit 03cb6709fd2d2d1c72cec9d7661b2534bd282283
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Mon May 27 12:34:10 2024 +0800
[SEDONA-560] Properly handle dataframes containing 0 partitions when
running spatial join (#1430)
* Refactored spatial join tests
* Properly handle dataframes containing 0 partitions in spatial join
---
.../strategy/join/SpatialIndexExec.scala | 14 ++-
.../strategy/join/TraitJoinQueryBase.scala | 2 +-
.../org/apache/sedona/sql/RasterJoinSuite.scala | 69 ++++++-------
.../org/apache/sedona/sql/SpatialJoinSuite.scala | 108 +++++++++++++--------
.../sedona/sql/SphereDistanceJoinSuite.scala | 33 +++----
.../org/apache/sedona/sql/TestBaseScala.scala | 13 +++
6 files changed, 142 insertions(+), 97 deletions(-)
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
index 4955fb2ba..519af1d42 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
@@ -19,6 +19,7 @@
package org.apache.spark.sql.sedona_sql.strategy.join
import org.apache.sedona.core.enums.IndexType
+import org.apache.sedona.core.spatialRddTool.IndexBuilder
import scala.jdk.CollectionConverters._
import org.apache.spark.broadcast.Broadcast
@@ -28,6 +29,9 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences,
Expression, UnsafeRow}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.sedona_sql.execution.SedonaUnaryExecNode
+import org.locationtech.jts.geom.Geometry
+
+import java.util.Collections
case class SpatialIndexExec(child: SparkPlan,
@@ -60,7 +64,15 @@ case class SpatialIndexExec(child: SparkPlan,
}
spatialRDD.buildIndex(indexType, false)
-
sparkContext.broadcast(spatialRDD.indexedRawRDD.take(1).asScala.head).asInstanceOf[Broadcast[T]]
+ val spatialIndexes = spatialRDD.indexedRawRDD.take(1).asScala
+ val spatialIndex = if (spatialIndexes.nonEmpty) {
+ spatialIndexes.head
+ } else {
+      // The broadcasted dataframe contains 0 partitions. In this case, we
should provide an empty spatial index.
+ val indexBuilder = new IndexBuilder[Geometry](indexType)
+ indexBuilder.call(Collections.emptyIterator()).next()
+ }
+ sparkContext.broadcast(spatialIndex).asInstanceOf[Broadcast[T]]
}
protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = {
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
index 70868128e..98688915f 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
@@ -96,7 +96,7 @@ trait TraitJoinQueryBase {
def doSpatialPartitioning(dominantShapes: SpatialRDD[Geometry],
followerShapes: SpatialRDD[Geometry],
numPartitions: Integer, sedonaConf: SedonaConf):
Unit = {
- if (dominantShapes.approximateTotalCount > 0) {
+ if (dominantShapes.approximateTotalCount > 0 && numPartitions > 0) {
dominantShapes.spatialPartitioning(sedonaConf.getJoinGridType,
numPartitions)
followerShapes.spatialPartitioning(dominantShapes.getPartitioner)
}
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/RasterJoinSuite.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/RasterJoinSuite.scala
index 4c91aa28b..fb9de40e3 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/RasterJoinSuite.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/RasterJoinSuite.scala
@@ -29,7 +29,6 @@ import org.scalatest.prop.TableDrivenPropertyChecks
class RasterJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
private val spatialJoinPartitionSideConfKey = "sedona.join.spatitionside"
- private val spatialJoinPartitionSide =
sparkSession.sparkContext.getConf.get(spatialJoinPartitionSideConfKey, "left")
private val rasters: Seq[(GridCoverage2D, Int)] = Seq(
// Japan
@@ -178,60 +177,56 @@ class RasterJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
("df2 JOIN df1", "RS_Within(df2.geom, df1.rast)")
)
- try {
- forAll(joinConditions) { case (joinClause, joinCondition) =>
- val expected = buildExpectedResult(joinCondition)
- it(s"$joinClause ON $joinCondition, with left side as dominant side") {
-
sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, "left")
+ forAll(joinConditions) { case (joinClause, joinCondition) =>
+ val expected = buildExpectedResult(joinCondition)
+ it(s"$joinClause ON $joinCondition, with left side as dominant side") {
+ withConf(Map(spatialJoinPartitionSideConfKey -> "left")) {
val result = sparkSession.sql(s"SELECT df1.id, df2.id FROM
$joinClause ON $joinCondition")
verifyResult(expected, result)
}
- it(s"$joinClause ON $joinCondition, with right side as dominant side")
{
-
sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, "right")
+ }
+ it(s"$joinClause ON $joinCondition, with right side as dominant side") {
+ withConf(Map(spatialJoinPartitionSideConfKey -> "right")) {
val result = sparkSession.sql(s"SELECT df1.id, df2.id FROM
$joinClause ON $joinCondition")
verifyResult(expected, result)
}
- it(s"$joinClause ON $joinCondition, broadcast df1") {
- val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ df1.id,
df2.id FROM $joinClause ON $joinCondition")
- verifyResult(expected, result)
- }
- it(s"$joinClause ON $joinCondition, broadcast df2") {
- val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ df1.id,
df2.id FROM $joinClause ON $joinCondition")
- verifyResult(expected, result)
- }
}
- } finally {
- sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey,
spatialJoinPartitionSide)
+ it(s"$joinClause ON $joinCondition, broadcast df1") {
+ val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ df1.id,
df2.id FROM $joinClause ON $joinCondition")
+ verifyResult(expected, result)
+ }
+ it(s"$joinClause ON $joinCondition, broadcast df2") {
+ val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ df1.id,
df2.id FROM $joinClause ON $joinCondition")
+ verifyResult(expected, result)
+ }
}
}
describe("raster-raster join") {
- try {
- val expected = rasters.flatMap { case (rast, id1) =>
- rasters.flatMap { case (otherRast, id2) =>
- if (RasterPredicates.rsIntersects(rast, otherRast)) Some((id1, id2))
else None
- }
+ val expected = rasters.flatMap { case (rast, id1) =>
+ rasters.flatMap { case (otherRast, id2) =>
+ if (RasterPredicates.rsIntersects(rast, otherRast)) Some((id1, id2))
else None
}
- it("raster-raster join, with left side as dominant side") {
- sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey,
"left")
+ }
+ it("raster-raster join, with left side as dominant side") {
+ withConf(Map(spatialJoinPartitionSideConfKey -> "left")) {
val result = sparkSession.sql("SELECT df1.id, df3.id FROM df1 JOIN df3
ON RS_Intersects(df1.rast, df3.rast)")
verifyResult(expected, result)
}
- it("raster-raster join, with right side as dominant side") {
- sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey,
"right")
+ }
+ it("raster-raster join, with right side as dominant side") {
+ withConf(Map(spatialJoinPartitionSideConfKey -> "right")) {
val result = sparkSession.sql("SELECT df1.id, df3.id FROM df1 JOIN df3
ON RS_Intersects(df1.rast, df3.rast)")
verifyResult(expected, result)
}
- it("raster-raster join, broadcast left") {
- val result = sparkSession.sql("SELECT /*+ BROADCAST(df1) */ df1.id,
df3.id FROM df1 JOIN df3 ON RS_Intersects(df1.rast, df3.rast)")
- verifyResult(expected, result)
- }
- it("raster-raster join, broadcast right") {
- val result = sparkSession.sql("SELECT /*+ BROADCAST(df3) */ df1.id,
df3.id FROM df1 JOIN df3 ON RS_Intersects(df1.rast, df3.rast)")
- verifyResult(expected, result)
- }
- } finally {
- sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey,
spatialJoinPartitionSide)
+ }
+ it("raster-raster join, broadcast left") {
+ val result = sparkSession.sql("SELECT /*+ BROADCAST(df1) */ df1.id,
df3.id FROM df1 JOIN df3 ON RS_Intersects(df1.rast, df3.rast)")
+ verifyResult(expected, result)
+ }
+ it("raster-raster join, broadcast right") {
+ val result = sparkSession.sql("SELECT /*+ BROADCAST(df3) */ df1.id,
df3.id FROM df1 JOIN df3 ON RS_Intersects(df1.rast, df3.rast)")
+ verifyResult(expected, result)
}
}
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
index 952878ce9..e5867771e 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
@@ -19,12 +19,12 @@
package org.apache.sedona.sql
-import org.apache.spark.sql.Column
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.functions.{col, expr}
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import
org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_GeomFromText
import org.apache.spark.sql.sedona_sql.strategy.join.{BroadcastIndexJoinExec,
DistanceJoinExec, RangeJoinExec}
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.io.WKTReader
import org.scalatest.prop.TableDrivenPropertyChecks
@@ -34,6 +34,11 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
val testDataDelimiter = "\t"
val spatialJoinPartitionSideConfKey = "sedona.join.spatitionside"
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ prepareTempViewsForTestData()
+ }
+
describe("Sedona-SQL Spatial Join Test") {
val joinConditions = Table("join condition",
"ST_Contains(df1.geom, df2.geom)",
@@ -70,39 +75,31 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
"1.0 >= ST_Distance(df1.geom, df2.geom)"
)
- var spatialJoinPartitionSide = "left"
- try {
- spatialJoinPartitionSide =
sparkSession.sparkContext.getConf.get(spatialJoinPartitionSideConfKey, "left")
- forAll (joinConditions) { joinCondition =>
- it(s"should join two dataframes with $joinCondition") {
-
sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, "left")
- prepareTempViewsForTestData()
+ forAll (joinConditions) { joinCondition =>
+ it(s"should join two dataframes with $joinCondition") {
+ withConf(Map(spatialJoinPartitionSideConfKey -> "left")) {
val result = sparkSession.sql(s"SELECT df1.id, df2.id FROM df1 JOIN
df2 ON $joinCondition")
val expected = buildExpectedResult(joinCondition)
verifyResult(expected, result)
}
- it(s"should join two dataframes with $joinCondition, with right side
as dominant side") {
-
sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, "right")
- prepareTempViewsForTestData()
+ }
+ it(s"should join two dataframes with $joinCondition, with right side as
dominant side") {
+ withConf(Map(spatialJoinPartitionSideConfKey -> "right")) {
val result = sparkSession.sql(s"SELECT df1.id, df2.id FROM df1 JOIN
df2 ON $joinCondition")
val expected = buildExpectedResult(joinCondition)
verifyResult(expected, result)
}
- it(s"should join two dataframes with $joinCondition, broadcast the
left side") {
- prepareTempViewsForTestData()
- val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ df1.id,
df2.id FROM df1 JOIN df2 ON $joinCondition")
- val expected = buildExpectedResult(joinCondition)
- verifyResult(expected, result)
- }
- it(s"should join two dataframes with $joinCondition, broadcast the
right side") {
- prepareTempViewsForTestData()
- val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ df1.id,
df2.id FROM df1 JOIN df2 ON $joinCondition")
- val expected = buildExpectedResult(joinCondition)
- verifyResult(expected, result)
- }
}
- } finally {
- sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey,
spatialJoinPartitionSide)
+ it(s"should join two dataframes with $joinCondition, broadcast the left
side") {
+ val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ df1.id,
df2.id FROM df1 JOIN df2 ON $joinCondition")
+ val expected = buildExpectedResult(joinCondition)
+ verifyResult(expected, result)
+ }
+ it(s"should join two dataframes with $joinCondition, broadcast the right
side") {
+ val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ df1.id,
df2.id FROM df1 JOIN df2 ON $joinCondition")
+ val expected = buildExpectedResult(joinCondition)
+ verifyResult(expected, result)
+ }
}
}
@@ -116,7 +113,6 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
forAll (joinConditions) { joinCondition =>
it(s"should SELECT * in join query with $joinCondition produce correct
result") {
- prepareTempViewsForTestData()
val resultAll = sparkSession.sql(s"SELECT * FROM df1 JOIN df2 ON
$joinCondition").collect()
val result = resultAll.map(row => (row.getInt(0),
row.getInt(3))).sorted
val expected = buildExpectedResult(joinCondition)
@@ -125,7 +121,6 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
}
it(s"should SELECT * in join query with $joinCondition produce correct
result, broadcast the left side") {
- prepareTempViewsForTestData()
val resultAll = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ * FROM
df1 JOIN df2 ON $joinCondition").collect()
val result = resultAll.map(row => (row.getInt(0),
row.getInt(3))).sorted
val expected = buildExpectedResult(joinCondition)
@@ -134,7 +129,6 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
}
it(s"should SELECT * in join query with $joinCondition produce correct
result, broadcast the right side") {
- prepareTempViewsForTestData()
val resultAll = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ * FROM
df1 JOIN df2 ON $joinCondition").collect()
val result = resultAll.map(row => (row.getInt(0),
row.getInt(3))).sorted
val expected = buildExpectedResult(joinCondition)
@@ -147,7 +141,6 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
describe("Spatial join in Sedona SQL should be configurable using
sedona.join.optimizationmode") {
it("Optimize all spatial joins when sedona.join.optimizationmode = all") {
withOptimizationMode("all") {
- prepareTempViewsForTestData()
val df = sparkSession.sql("SELECT df1.id, df2.id FROM df1 JOIN df2 ON
df1.id = df2.id AND ST_Intersects(df1.geom, df2.geom)")
assert(isUsingOptimizedSpatialJoin(df))
val expectedResult = buildExpectedResult("ST_Intersects(df1.geom,
df2.geom)")
@@ -158,7 +151,6 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
it("Only optimize non-equi-joins when sedona.join.optimizationmode =
nonequi") {
withOptimizationMode("nonequi") {
- prepareTempViewsForTestData()
val df = sparkSession.sql("SELECT df1.id, df2.id FROM df1 JOIN df2 ON
ST_Intersects(df1.geom, df2.geom)")
assert(isUsingOptimizedSpatialJoin(df))
val df2 = sparkSession.sql("SELECT df1.id, df2.id FROM df1 JOIN df2 ON
df1.id = df2.id AND ST_Intersects(df1.geom, df2.geom)")
@@ -168,7 +160,6 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
it("Won't optimize spatial joins when sedona.join.optimizationmode =
none") {
withOptimizationMode("none") {
- prepareTempViewsForTestData()
val df = sparkSession.sql("SELECT df1.id, df2.id FROM df1 JOIN df2 ON
ST_Intersects(df1.geom, df2.geom)")
assert(!isUsingOptimizedSpatialJoin(df))
}
@@ -191,14 +182,46 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
}
}
- private def withOptimizationMode(mode: String)(body: => Unit) : Unit = {
- val oldOptimizationMode =
sparkSession.conf.get("sedona.join.optimizationmode", "nonequi")
- try {
- sparkSession.conf.set("sedona.join.optimizationmode", mode)
- body
- } finally {
- sparkSession.conf.set("sedona.join.optimizationmode",
oldOptimizationMode)
+ describe("Spatial join should work with dataframe containing 0 partitions") {
+ val queries = Table("join queries",
+ "SELECT * FROM df1 JOIN dfEmpty WHERE ST_Intersects(df1.geom,
dfEmpty.geom)",
+ "SELECT * FROM dfEmpty JOIN df1 WHERE ST_Intersects(df1.geom,
dfEmpty.geom)",
+ "SELECT /*+ BROADCAST(df1) */ * FROM df1 JOIN dfEmpty WHERE
ST_Intersects(df1.geom, dfEmpty.geom)",
+ "SELECT /*+ BROADCAST(dfEmpty) */ * FROM df1 JOIN dfEmpty WHERE
ST_Intersects(df1.geom, dfEmpty.geom)",
+ "SELECT /*+ BROADCAST(df1) */ * FROM dfEmpty JOIN df1 WHERE
ST_Intersects(df1.geom, dfEmpty.geom)",
+ "SELECT /*+ BROADCAST(dfEmpty) */ * FROM dfEmpty JOIN df1 WHERE
ST_Intersects(df1.geom, dfEmpty.geom)")
+
+ forAll (queries) { query =>
+ it(s"Legacy join: $query") {
+ withConf(Map(spatialJoinPartitionSideConfKey -> "left")) {
+ val resultRows = sparkSession.sql(query).collect()
+ assert(resultRows.isEmpty)
+ }
+ withConf(Map(spatialJoinPartitionSideConfKey -> "right")) {
+ val resultRows = sparkSession.sql(query).collect()
+ assert(resultRows.isEmpty)
+ }
+ }
}
+
+ it("non-empty dataframe has lots of partitions") {
+ val df = sparkSession.range(0, 4).toDF("id").withColumn("geom",
expr("ST_Point(id, id)")).repartition(10)
+ df.createOrReplaceTempView("df10parts")
+
+ val query = "SELECT * FROM df10parts JOIN dfEmpty WHERE
ST_Intersects(df10parts.geom, dfEmpty.geom)";
+ withConf(Map(spatialJoinPartitionSideConfKey -> "left")) {
+ val resultRows = sparkSession.sql(query).collect()
+ assert(resultRows.isEmpty)
+ }
+ withConf(Map(spatialJoinPartitionSideConfKey -> "right")) {
+ val resultRows = sparkSession.sql(query).collect()
+ assert(resultRows.isEmpty)
+ }
+ }
+ }
+
+ private def withOptimizationMode(mode: String)(body: => Unit) : Unit = {
+ withConf(Map("sedona.join.optimizationmode" -> mode))(body)
}
private def prepareTempViewsForTestData(): (DataFrame, DataFrame) = {
@@ -214,8 +237,13 @@ class SpatialJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
.withColumn("geom", ST_GeomFromText(new Column("_c2")))
.select("id", "geom")
.withColumn("dist", expr("ST_Area(geom)"))
+ val emptyRdd = sparkSession.sparkContext.emptyRDD[Row]
+ val emptyDf = sparkSession.createDataFrame(emptyRdd, StructType(Seq(
+ StructField("id", IntegerType), StructField("geom", GeometryUDT)
+ )))
df1.createOrReplaceTempView("df1")
df2.createOrReplaceTempView("df2")
+ emptyDf.createOrReplaceTempView("dfEmpty")
(df1, df2)
}
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/SphereDistanceJoinSuite.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/SphereDistanceJoinSuite.scala
index ca78c6f4f..335e89e6a 100644
---
a/spark/common/src/test/scala/org/apache/sedona/sql/SphereDistanceJoinSuite.scala
+++
b/spark/common/src/test/scala/org/apache/sedona/sql/SphereDistanceJoinSuite.scala
@@ -28,7 +28,6 @@ import scala.util.Random
class SphereDistanceJoinSuite extends TestBaseScala with
TableDrivenPropertyChecks {
private val spatialJoinPartitionSideConfKey = "sedona.join.spatitionside"
- private val spatialJoinPartitionSide =
sparkSession.sparkContext.getConf.get(spatialJoinPartitionSideConfKey, "left")
private val testData1: Seq[(Int, Double, Geometry)] = generateTestData()
private val testData2: Seq[(Int, Double, Geometry)] = generateTestData()
@@ -56,30 +55,28 @@ class SphereDistanceJoinSuite extends TestBaseScala with
TableDrivenPropertyChec
"ST_DistanceSpheroid(df2.geom, df1.geom) < df2.dist"
)
- try {
- forAll(joinConditions) { joinCondition =>
- val expected = buildExpectedResult(joinCondition)
- it(s"sphere distance join ON $joinCondition, with left side as
dominant side") {
-
sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, "left")
+ forAll(joinConditions) { joinCondition =>
+ val expected = buildExpectedResult(joinCondition)
+ it(s"sphere distance join ON $joinCondition, with left side as dominant
side") {
+ withConf(Map(spatialJoinPartitionSideConfKey -> "left")) {
val result = sparkSession.sql(s"SELECT df1.id, df2.id FROM df1 JOIN
df2 ON $joinCondition")
verifyResult(expected, result)
}
- it(s"sphere distance join ON $joinCondition, with right side as
dominant side") {
-
sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey, "right")
+ }
+ it(s"sphere distance join ON $joinCondition, with right side as dominant
side") {
+ withConf(Map(spatialJoinPartitionSideConfKey -> "right")) {
val result = sparkSession.sql(s"SELECT df1.id, df2.id FROM df1 JOIN
df2 ON $joinCondition")
verifyResult(expected, result)
}
- it(s"sphere distance ON $joinCondition, broadcast df1") {
- val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ df1.id,
df2.id FROM df1 JOIN df2 ON $joinCondition")
- verifyResult(expected, result)
- }
- it(s"sphere distance ON $joinCondition, broadcast df2") {
- val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ df1.id,
df2.id FROM df1 JOIN df2 ON $joinCondition")
- verifyResult(expected, result)
- }
}
- } finally {
- sparkSession.sparkContext.getConf.set(spatialJoinPartitionSideConfKey,
spatialJoinPartitionSide)
+ it(s"sphere distance ON $joinCondition, broadcast df1") {
+ val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ df1.id,
df2.id FROM df1 JOIN df2 ON $joinCondition")
+ verifyResult(expected, result)
+ }
+ it(s"sphere distance ON $joinCondition, broadcast df2") {
+ val result = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ df1.id,
df2.id FROM df1 JOIN df2 ON $joinCondition")
+ verifyResult(expected, result)
+ }
}
}
diff --git
a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 4c50bc3c0..eaace3a32 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -143,6 +143,19 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll
{
}
}
+ def withConf[T](conf: Map[String, String])(f: => T): T = {
+ val oldConf = conf.values.map(key => key ->
sparkSession.conf.getOption(key))
+ conf.foreach{ case (key, value) => sparkSession.conf.set(key, value) }
+ try {
+ f
+ } finally {
+ oldConf.foreach { case (key, value) => value match {
+ case Some(v) => sparkSession.conf.set(key, v)
+ case None => sparkSession.conf.unset(key)
+ }}
+ }
+ }
+
protected def bruteForceDistanceJoinCountSpheroid(sampleCount:Int, distance:
Double): Int = {
val input = buildPointLonLatDf.limit(sampleCount).collect()
input.map(row => {