This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 68abdcd08 [SEDONA-637] Show spatial filters pushed to GeoParquet scans in the query plan; allow disabling spatial filter pushdown (#1540)
68abdcd08 is described below

commit 68abdcd08c7999f79e1c29314e28246b7e569597
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Tue Aug 6 00:45:49 2024 +0800

    [SEDONA-637] Show spatial filters pushed to GeoParquet scans in the query plan; allow disabling spatial filter pushdown (#1540)
    
    * Show pushed-down spatial filters in the query plan; allow turning off spatial filter push-down using spark.conf.set("spark.sedona.geoparquet.spatialFilterPushDown", "false")
    
    * Backport to Spark 3.0 and Spark 3.4
    
    * Document spark.sedona.geoparquet.spatialFilterPushDown option
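    
    A minimal end-to-end sketch of the new toggle (the dataset path and session setup are hypothetical; only the configuration key and its default of "true" come from this commit):
    
        // Assumes a Sedona-enabled SparkSession named `spark`.
        val df = spark.read.format("geoparquet").load("/tmp/example.geoparquet") // hypothetical path
        val filtered = df.where(
          "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))")
        filtered.explain() // with push-down enabled (the default), the scan node mentions the spatial filter
    
        spark.conf.set("spark.sedona.geoparquet.spatialFilterPushDown", "false")
        filtered.explain() // the spatial filter no longer appears on the GeoParquet scan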
---
 docs/api/sql/Optimizer.md                          |  2 ++
 .../parquet/GeoParquetSpatialFilter.scala          |  8 ++++-
 .../SpatialFilterPushDownForGeoParquet.scala       | 39 +++++++++++++---------
 .../datasources/parquet/GeoParquetFileFormat.scala |  8 +++++
 .../sql/GeoParquetSpatialFilterPushDownSuite.scala | 20 +++++++++++
 .../org/apache/sedona/sql/TestBaseScala.scala      | 15 +++++++++
 .../datasources/parquet/GeoParquetFileFormat.scala |  8 +++++
 .../sql/GeoParquetSpatialFilterPushDownSuite.scala | 20 +++++++++++
 .../org/apache/sedona/sql/TestBaseScala.scala      | 15 +++++++++
 .../datasources/parquet/GeoParquetFileFormat.scala |  8 +++++
 .../sql/GeoParquetSpatialFilterPushDownSuite.scala | 20 +++++++++++
 .../org/apache/sedona/sql/TestBaseScala.scala      | 15 +++++++++
 12 files changed, 161 insertions(+), 17 deletions(-)

diff --git a/docs/api/sql/Optimizer.md b/docs/api/sql/Optimizer.md
index 3a96718dc..f52258150 100644
--- a/docs/api/sql/Optimizer.md
+++ b/docs/api/sql/Optimizer.md
@@ -343,3 +343,5 @@ We can compare the metrics of querying the GeoParquet dataset with or without th
 | Without spatial predicate | With spatial predicate |
 | ----------- | ----------- |
 | ![](../../image/scan-parquet-without-spatial-pred.png) | ![](../../image/scan-parquet-with-spatial-pred.png) |
+
+Spatial predicate push-down to GeoParquet is enabled by default. Users can manually disable it by setting the Spark configuration `spark.sedona.geoparquet.spatialFilterPushDown` to `false`.
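
The documented toggle affects only the current session. As a hedged aside, both forms below are standard Spark conf-setting mechanisms, not APIs introduced by this commit:

    // Scala:
    spark.conf.set("spark.sedona.geoparquet.spatialFilterPushDown", "false")
    // or via Spark SQL:
    spark.sql("SET spark.sedona.geoparquet.spatialFilterPushDown=false")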
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSpatialFilter.scala b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSpatialFilter.scala
index 57ca2161d..5aa782e5b 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSpatialFilter.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSpatialFilter.scala
@@ -29,20 +29,25 @@ import org.locationtech.jts.geom.Geometry
  */
 trait GeoParquetSpatialFilter {
   def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean
+  def simpleString: String
 }
 
 object GeoParquetSpatialFilter {
 
   case class AndFilter(left: GeoParquetSpatialFilter, right: GeoParquetSpatialFilter)
       extends GeoParquetSpatialFilter {
-    override def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean =
+    override def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean = {
       left.evaluate(columns) && right.evaluate(columns)
+    }
+
+    override def simpleString: String = s"(${left.simpleString}) AND (${right.simpleString})"
   }
 
   case class OrFilter(left: GeoParquetSpatialFilter, right: GeoParquetSpatialFilter)
       extends GeoParquetSpatialFilter {
     override def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean =
       left.evaluate(columns) || right.evaluate(columns)
+    override def simpleString: String = s"(${left.simpleString}) OR (${right.simpleString})"
   }
 
   /**
@@ -77,5 +82,6 @@ object GeoParquetSpatialFilter {
         }
       }
     }
+    override def simpleString: String = s"$columnName ${predicateType.name} $queryWindow"
   }
 }
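
The simpleString representations compose recursively, so nested And/Or filters render with explicit parentheses. A standalone sketch of the same pattern (illustrative names, not the actual Sedona classes):

    trait Pred { def simpleString: String }
    case class Leaf(column: String, op: String, window: String) extends Pred {
      override def simpleString: String = s"$column $op $window"
    }
    case class And(left: Pred, right: Pred) extends Pred {
      override def simpleString: String = s"(${left.simpleString}) AND (${right.simpleString})"
    }

    // And(Leaf("geom", "INTERSECTS", "POLYGON (...)"), Leaf("geom", "COVERS", "POINT (1 2)")).simpleString
    // => "(geom INTERSECTS POLYGON (...)) AND (geom COVERS POINT (1 2))"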
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialFilterPushDownForGeoParquet.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialFilterPushDownForGeoParquet.scala
index c09f7947b..ba0ecf8a4 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialFilterPushDownForGeoParquet.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialFilterPushDownForGeoParquet.scala
@@ -62,23 +62,30 @@ import org.locationtech.jts.geom.Point
 
 class SpatialFilterPushDownForGeoParquet(sparkSession: SparkSession) extends Rule[LogicalPlan] {
 
-  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case filter @ Filter(condition, lr: LogicalRelation) if isGeoParquetRelation(lr) =>
-      val filters = splitConjunctivePredicates(condition)
-      val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, lr.output)
-      val (_, normalizedFiltersWithoutSubquery) =
-        normalizedFilters.partition(SubqueryExpression.hasSubquery)
-      val geoParquetSpatialFilters =
-        translateToGeoParquetSpatialFilters(normalizedFiltersWithoutSubquery)
-      val hadoopFsRelation = lr.relation.asInstanceOf[HadoopFsRelation]
-      val fileFormat = hadoopFsRelation.fileFormat.asInstanceOf[GeoParquetFileFormatBase]
-      if (geoParquetSpatialFilters.isEmpty) filter
-      else {
-        val combinedSpatialFilter = geoParquetSpatialFilters.reduce(AndFilter)
-        val newFileFormat = fileFormat.withSpatialPredicates(combinedSpatialFilter)
-        val newRelation = hadoopFsRelation.copy(fileFormat = newFileFormat)(sparkSession)
-        filter.copy(child = lr.copy(relation = newRelation))
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    val enableSpatialFilterPushDown =
+      sparkSession.conf.get("spark.sedona.geoparquet.spatialFilterPushDown", "true").toBoolean
+    if (!enableSpatialFilterPushDown) plan
+    else {
+      plan transform {
+        case filter @ Filter(condition, lr: LogicalRelation) if isGeoParquetRelation(lr) =>
+          val filters = splitConjunctivePredicates(condition)
+          val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, lr.output)
+          val (_, normalizedFiltersWithoutSubquery) =
+            normalizedFilters.partition(SubqueryExpression.hasSubquery)
+          val geoParquetSpatialFilters =
+            translateToGeoParquetSpatialFilters(normalizedFiltersWithoutSubquery)
+          val hadoopFsRelation = lr.relation.asInstanceOf[HadoopFsRelation]
+          val fileFormat = hadoopFsRelation.fileFormat.asInstanceOf[GeoParquetFileFormatBase]
+          if (geoParquetSpatialFilters.isEmpty) filter
+          else {
+            val combinedSpatialFilter = geoParquetSpatialFilters.reduce(AndFilter)
+            val newFileFormat = fileFormat.withSpatialPredicates(combinedSpatialFilter)
+            val newRelation = hadoopFsRelation.copy(fileFormat = newFileFormat)(sparkSession)
+            filter.copy(child = lr.copy(relation = newRelation))
+          }
       }
+    }
   }
 
   private def isGeoParquetRelation(lr: LogicalRelation): Boolean =
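
Because the rule reads the configuration inside apply rather than at construction time, the flag can be flipped mid-session and takes effect on the next query. A condensed sketch of the guard pattern (class name illustrative, rule body elided):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule

    class GuardedRule(session: SparkSession) extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = {
        // conf.get(key, default) falls back to the default string when the key is unset.
        val enabled =
          session.conf.get("spark.sedona.geoparquet.spatialFilterPushDown", "true").toBoolean
        if (!enabled) plan // feature off: return the plan untouched
        else plan transform { case other => other } // real transformation elided
      }
    }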
diff --git a/spark/spark-3.0/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala b/spark/spark-3.0/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
index 702c6f31f..1924bbfba 100644
--- a/spark/spark-3.0/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
+++ b/spark/spark-3.0/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
@@ -66,6 +66,14 @@ class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter])
 
   override def hashCode(): Int = getClass.hashCode()
 
+  override def toString(): String = {
+    // HACK: This is the only place we can inject spatial filter information into the described query plan.
+    // Please see org.apache.spark.sql.execution.DataSourceScanExec#simpleString for more details.
+    "GeoParquet" + spatialFilter
+      .map(filter => " with spatial filter [" + filter.simpleString + "]")
+      .getOrElse("")
+  }
+
   def withSpatialPredicates(spatialFilter: GeoParquetSpatialFilter): GeoParquetFileFormat =
     new GeoParquetFileFormat(Some(spatialFilter))
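
With the toString override in place, a described plan carries the filter text on the scan node. Illustrative fragment only (exact layout varies by Spark version and query; the tests below assert only on the "FileScan geoparquet" and "with spatial filter" substrings):

    ... FileScan geoparquet ... with spatial filter [geom INTERSECTS POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))] ...

The identical override is backported to the Spark 3.4 and 3.5 modules further below.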
 
diff --git a/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala b/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
index 8f3cc3f1e..a2a257e8f 100644
--- a/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
+++ b/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
+import org.apache.spark.sql.execution.SimpleMode
 import org.locationtech.jts.geom.Coordinate
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.geom.GeometryFactory
@@ -223,6 +224,25 @@ class GeoParquetSpatialFilterPushDownSuite extends TestBaseScala with TableDrive
         "id < 10 AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))",
         Seq(1, 3))
     }
+
+    it("Explain geoparquet scan with spatial filter push-down") {
+      val dfFiltered = geoParquetDf.where(
+        "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 
0, 0 0))'))")
+      val explainString = dfFiltered.queryExecution.explainString(SimpleMode)
+      assert(explainString.contains("FileScan geoparquet"))
+      assert(explainString.contains("with spatial filter"))
+    }
+
+    it("Manually disable spatial filter push-down") {
+      withConf(Map("spark.sedona.geoparquet.spatialFilterPushDown" -> 
"false")) {
+        val dfFiltered = geoParquetDf.where(
+          "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 
0, 0 0))'))")
+        val explainString = dfFiltered.queryExecution.explainString(SimpleMode)
+        assert(explainString.contains("FileScan geoparquet"))
+        assert(!explainString.contains("with spatial filter"))
+        assert(getPushedDownSpatialFilter(dfFiltered).isEmpty)
+      }
+    }
   }
 
   /**
diff --git a/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 2da12eceb..5dd5d9309 100644
--- a/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/spark/spark-3.0/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -54,4 +54,19 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   def loadCsv(path: String): DataFrame = {
     sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path)
   }
+
+  def withConf[T](conf: Map[String, String])(f: => T): T = {
+    val oldConf = conf.keys.map(key => key -> sparkSession.conf.getOption(key))
+    conf.foreach { case (key, value) => sparkSession.conf.set(key, value) }
+    try {
+      f
+    } finally {
+      oldConf.foreach { case (key, value) =>
+        value match {
+          case Some(v) => sparkSession.conf.set(key, v)
+          case None => sparkSession.conf.unset(key)
+        }
+      }
+    }
+  }
 }
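
withConf snapshots the previous value of every key before overriding it and restores (or unsets) each one in a finally block, so a failing test body cannot leak configuration into the shared session. A hedged usage sketch (the key below is arbitrary):

    // Runs the block with the override, then restores the prior value.
    withConf(Map("spark.sql.shuffle.partitions" -> "1")) {
      sparkSession.range(10).repartition(2).count()
    }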
diff --git a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
index dde566ba2..325a72098 100644
--- a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
+++ b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
@@ -65,6 +65,14 @@ class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter])
 
   override def hashCode(): Int = getClass.hashCode()
 
+  override def toString(): String = {
+    // HACK: This is the only place we can inject spatial filter information into the described query plan.
+    // Please see org.apache.spark.sql.execution.DataSourceScanExec#simpleString for more details.
+    "GeoParquet" + spatialFilter
+      .map(filter => " with spatial filter [" + filter.simpleString + "]")
+      .getOrElse("")
+  }
+
   def withSpatialPredicates(spatialFilter: GeoParquetSpatialFilter): GeoParquetFileFormat =
     new GeoParquetFileFormat(Some(spatialFilter))
 
diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
index 8f3cc3f1e..a2a257e8f 100644
--- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
+++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
+import org.apache.spark.sql.execution.SimpleMode
 import org.locationtech.jts.geom.Coordinate
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.geom.GeometryFactory
@@ -223,6 +224,25 @@ class GeoParquetSpatialFilterPushDownSuite extends TestBaseScala with TableDrive
         "id < 10 AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))",
         Seq(1, 3))
     }
+
+    it("Explain geoparquet scan with spatial filter push-down") {
+      val dfFiltered = geoParquetDf.where(
+        "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 
0, 0 0))'))")
+      val explainString = dfFiltered.queryExecution.explainString(SimpleMode)
+      assert(explainString.contains("FileScan geoparquet"))
+      assert(explainString.contains("with spatial filter"))
+    }
+
+    it("Manually disable spatial filter push-down") {
+      withConf(Map("spark.sedona.geoparquet.spatialFilterPushDown" -> 
"false")) {
+        val dfFiltered = geoParquetDf.where(
+          "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 
0, 0 0))'))")
+        val explainString = dfFiltered.queryExecution.explainString(SimpleMode)
+        assert(explainString.contains("FileScan geoparquet"))
+        assert(!explainString.contains("with spatial filter"))
+        assert(getPushedDownSpatialFilter(dfFiltered).isEmpty)
+      }
+    }
   }
 
   /**
diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 2da12eceb..5dd5d9309 100644
--- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -54,4 +54,19 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   def loadCsv(path: String): DataFrame = {
     sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path)
   }
+
+  def withConf[T](conf: Map[String, String])(f: => T): T = {
+    val oldConf = conf.keys.map(key => key -> sparkSession.conf.getOption(key))
+    conf.foreach { case (key, value) => sparkSession.conf.set(key, value) }
+    try {
+      f
+    } finally {
+      oldConf.foreach { case (key, value) =>
+        value match {
+          case Some(v) => sparkSession.conf.set(key, v)
+          case None => sparkSession.conf.unset(key)
+        }
+      }
+    }
+  }
 }
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
index b8d422dce..06c9683cd 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
@@ -64,6 +64,14 @@ class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter])
 
   override def hashCode(): Int = getClass.hashCode()
 
+  override def toString(): String = {
+    // HACK: This is the only place we can inject spatial filter information into the described query plan.
+    // Please see org.apache.spark.sql.execution.DataSourceScanExec#simpleString for more details.
+    "GeoParquet" + spatialFilter
+      .map(filter => " with spatial filter [" + filter.simpleString + "]")
+      .getOrElse("")
+  }
+
   def withSpatialPredicates(spatialFilter: GeoParquetSpatialFilter): GeoParquetFileFormat =
     new GeoParquetFileFormat(Some(spatialFilter))
 
diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
index 8f3cc3f1e..a2a257e8f 100644
--- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
+++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
+import org.apache.spark.sql.execution.SimpleMode
 import org.locationtech.jts.geom.Coordinate
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.geom.GeometryFactory
@@ -223,6 +224,25 @@ class GeoParquetSpatialFilterPushDownSuite extends TestBaseScala with TableDrive
         "id < 10 AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))",
         Seq(1, 3))
     }
+
+    it("Explain geoparquet scan with spatial filter push-down") {
+      val dfFiltered = geoParquetDf.where(
+        "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 
0, 0 0))'))")
+      val explainString = dfFiltered.queryExecution.explainString(SimpleMode)
+      assert(explainString.contains("FileScan geoparquet"))
+      assert(explainString.contains("with spatial filter"))
+    }
+
+    it("Manually disable spatial filter push-down") {
+      withConf(Map("spark.sedona.geoparquet.spatialFilterPushDown" -> 
"false")) {
+        val dfFiltered = geoParquetDf.where(
+          "ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 
0, 0 0))'))")
+        val explainString = dfFiltered.queryExecution.explainString(SimpleMode)
+        assert(explainString.contains("FileScan geoparquet"))
+        assert(!explainString.contains("with spatial filter"))
+        assert(getPushedDownSpatialFilter(dfFiltered).isEmpty)
+      }
+    }
   }
 
   /**
diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 2da12eceb..5dd5d9309 100644
--- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -54,4 +54,19 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   def loadCsv(path: String): DataFrame = {
     sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path)
   }
+
+  def withConf[T](conf: Map[String, String])(f: => T): T = {
+    val oldConf = conf.keys.map(key => key -> sparkSession.conf.getOption(key))
+    conf.foreach { case (key, value) => sparkSession.conf.set(key, value) }
+    try {
+      f
+    } finally {
+      oldConf.foreach { case (key, value) =>
+        value match {
+          case Some(v) => sparkSession.conf.set(key, v)
+          case None => sparkSession.conf.unset(key)
+        }
+      }
+    }
+  }
 }
