This is an automated email from the ASF dual-hosted git repository.

gengliangwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a855fdce5d2 [SPARK-56868][SQL] Extract V2 runtime-filter + partition 
planning into a shared helper
6a855fdce5d2 is described below

commit 6a855fdce5d2da770646b937c498bd2523e021c8
Author: Vitalii Li <[email protected]>
AuthorDate: Fri May 15 16:38:10 2026 -0700

    [SPARK-56868][SQL] Extract V2 runtime-filter + partition planning into a 
shared helper
    
    ### What changes were proposed in this pull request?
    
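    Lift the body of `BatchScanExec.filteredPartitions` (runtime filter 
pushdown, re-planning, and `KeyedPartitioning` validation + `None`-padding) 
into a new `PushDownUtils.replanWithRuntimeFilters` helper so the logic can be 
reused by alternative DataSourceV2 physical scan operators. Also add a 
`getPartitionPredicateSchema` overload that takes the partition transforms 
directly, for callers that do not have the full `Table`.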
    
    ### Why are the changes needed?
    
    The runtime-filter pushdown pipeline in `BatchScanExec.filteredPartitions` 
contains non-trivial logic: V2 predicate translation (DPP + scalar subqueries 
via SPARK-56467), iterative `PartitionPredicate` pushdown (SPARK-55596), 
`KeyedPartitioning` validation, and `None`-padding to preserve SPJ key 
alignment. This refactoring allows for reuse of this logic by alternative scan 
operators.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
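    Existing test coverage. The only new logic is a `HasPartitionKey` check on 
the unfiltered path, which raises a `SparkException` instead of a 
`ClassCastException`.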
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude (Anthropic), via Claude Code
    
    Closes #55887 from vitaliili-db/spark-unify-v2-runtime-filter-helper.
    
    Authored-by: Vitalii Li <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../execution/datasources/v2/BatchScanExec.scala   |  71 ++-----------
 .../execution/datasources/v2/PushDownUtils.scala   | 117 ++++++++++++++++++++-
 2 files changed, 123 insertions(+), 65 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index e9a18833ed9a..ea25f3b1c85f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -19,13 +19,12 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import java.util.Objects
 
-import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, 
SinglePartition}
-import org.apache.spark.sql.catalyst.util.{truncatedString, 
InternalRowComparableWrapper}
+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
+import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.util.ArrayImplicits._
@@ -60,64 +59,14 @@ case class BatchScanExec(
     batch.planInputPartitions().toImmutableArraySeq
 
   // Visible for testing
-  @transient private[sql] lazy val filteredPartitions: 
Seq[Option[InputPartition]] = {
-    val originalPartitioning = outputPartitioning
-
-    val filtered = PushDownUtils.pushRuntimeFilters(scan, runtimeFilters, 
table, output)
-    if (filtered) {
-      // call toBatch again to get filtered partitions
-      val newPartitions = scan.toBatch.planInputPartitions()
-
-      originalPartitioning match {
-        case k: KeyedPartitioning =>
-          if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
-            throw new SparkException("Data source must have preserved the 
original partitioning " +
-                "during runtime filtering: not all partitions implement 
HasPartitionKey after " +
-                "filtering")
-          }
-
-          val inputMap = 
k.partitionKeys.groupBy(identity).view.mapValues(_.size)
-          val comparableKeyWrapperFactory = InternalRowComparableWrapper
-            .getInternalRowComparableWrapperFactory(k.expressionDataTypes)
-          val filteredMap = newPartitions.groupBy(
-            p => 
comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey())
-          )
-
-          if (!filteredMap.keySet.subsetOf(inputMap.keySet)) {
-            throw new SparkException("During runtime filtering, data source 
must not report new " +
-                "partition keys that are not present in the original 
partitioning.")
-          }
-
-          inputMap.toSeq
-            .sortBy(_._1)(k.keyOrdering)
-            .flatMap { case (key, size) =>
-              // We require the new number of partitions to be equal or less 
than the old number of
-              // partitions for a given key. In the case of less than, empty 
partitions are added.
-              val fps = filteredMap.getOrElse(key, Array.empty)
-
-              if (fps.size > size) {
-                throw new SparkException("During runtime filtering, data 
source must not report " +
-                  s"new partitions for a given key. Before: $size partitions. 
" +
-                  s"After: ${fps.size} partitions")
-              }
-
-              fps.map(Some).padTo(size, None)
-            }
-
-        case _ =>
-          // no validation is needed as the data source did not report any 
specific partitioning
-          newPartitions.toSeq.map(Some)
-      }
-
-    } else {
-      (originalPartitioning match {
-        case k: KeyedPartitioning =>
-          
inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering)
-
-        case _ => inputPartitions
-      }).map(Some)
-    }
-  }
+  @transient private[sql] lazy val filteredPartitions: 
Seq[Option[InputPartition]] =
+    PushDownUtils.replanWithRuntimeFilters(
+      scan,
+      runtimeFilters,
+      table,
+      output,
+      outputPartitioning,
+      inputPartitions)
 
   override lazy val readerFactory: PartitionReaderFactory = 
batch.createReaderFactory()
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index e31e81fc1fa9..dc6de6f29af9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -19,17 +19,19 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
+import org.apache.spark.SparkException
 import org.apache.spark.internal.{Logging, LogKeys}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, 
AttributeSet, DynamicPruning, DynamicPruningExpression, Expression, 
ExpressionSet, GetStructField, NamedExpression, PythonUDF, SchemaPruning, 
SubqueryExpression, V2ExpressionUtils}
 import org.apache.spark.sql.catalyst.plans.logical.SampleMethod
+import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, 
Partitioning}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
-import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, 
InternalRowComparableWrapper}
 import org.apache.spark.sql.connector.catalog.Table
-import org.apache.spark.sql.connector.expressions.{IdentityTransform, 
SortOrder}
+import org.apache.spark.sql.connector.expressions.{IdentityTransform, 
SortOrder, Transform}
 import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.connector.read.{SampleMethod => SampleMethodV2, 
Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, 
SupportsPushDownOffset, SupportsPushDownRequiredColumns, 
SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters, 
SupportsRuntimeV2Filtering}
+import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, 
SampleMethod => SampleMethodV2, Scan, ScanBuilder, SupportsPushDownFilters, 
SupportsPushDownLimit, SupportsPushDownOffset, SupportsPushDownRequiredColumns, 
SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters, 
SupportsRuntimeV2Filtering}
 import org.apache.spark.sql.execution.{ScalarSubquery => ExecScalarSubquery}
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, 
DataSourceUtils}
 import org.apache.spark.sql.internal.SQLConf
@@ -187,6 +189,101 @@ object PushDownUtils extends Logging {
     }
   }
 
+  /**
+   * Pushes runtime filters into `scan` and re-plans its input partitions. For 
scans whose
+   * `outputPartitioning` is a [[KeyedPartitioning]] (SPJ-active), validates 
that the data source
+   * preserved the original partitioning and pads with `None` to preserve key 
alignment with the
+   * pre-filter partition set.
+   *
+   * Must be called at execute time: runtime filters carry 
[[DynamicPruningExpression]] and
+   * scalar-subquery references whose values are only resolved after their 
broadcast/subquery
+   * side completes. The mutating [[pushRuntimeFilters]] call must run at most 
once per scan
+   * instance; callers are responsible for caching the result.
+   *
+   * Precondition: when `outputPartitioning` is a [[KeyedPartitioning]], every 
element of
+   * `originalPartitions` (and every partition re-planned by the data source) 
must implement
+   * [[HasPartitionKey]].
+   *
+   * @param scan                the V2 scan to push filters into
+   * @param runtimeFilters      runtime filters to translate and push
+   * @param table               the table backing the scan, used to derive the 
partition-predicate
+   *                            schema for iterative [[PartitionPredicate]] 
pushdown
+   * @param output              scan output attributes
+   * @param outputPartitioning  Spark-side output partitioning (used for SPJ 
validation)
+   * @param originalPartitions  unfiltered partitions, consulted only when no 
runtime filters fire
+   * @return one entry per original input partition: `Some(part)` for 
surviving partitions and
+   *         `None` for partition keys whose splits were entirely pruned (SPJ 
alignment)
+   */
+  def replanWithRuntimeFilters(
+      scan: Scan,
+      runtimeFilters: Seq[Expression],
+      table: Table,
+      output: Seq[AttributeReference],
+      outputPartitioning: Partitioning,
+      originalPartitions: => Seq[InputPartition]): Seq[Option[InputPartition]] 
= {
+    val filtered = pushRuntimeFilters(scan, runtimeFilters, table, output)
+    if (filtered) {
+      // call toBatch again to get filtered partitions
+      val newPartitions = scan.toBatch.planInputPartitions()
+
+      outputPartitioning match {
+        case k: KeyedPartitioning =>
+          if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
+            throw new SparkException("Data source must have preserved the 
original partitioning " +
+                "during runtime filtering: not all partitions implement 
HasPartitionKey after " +
+                "filtering")
+          }
+
+          val inputMap = 
k.partitionKeys.groupBy(identity).view.mapValues(_.size)
+          val comparableKeyWrapperFactory = InternalRowComparableWrapper
+            .getInternalRowComparableWrapperFactory(k.expressionDataTypes)
+          val filteredMap = newPartitions.groupBy(
+            p => 
comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey())
+          )
+
+          if (!filteredMap.keySet.subsetOf(inputMap.keySet)) {
+            throw new SparkException("During runtime filtering, data source 
must not report new " +
+                "partition keys that are not present in the original 
partitioning.")
+          }
+
+          // Pad the post-filter partitions with `None` per original key so 
SPJ key alignment with
+          // the other side of the join is preserved when splits are entirely 
pruned.
+          inputMap.toSeq
+            .sortBy(_._1)(k.keyOrdering)
+            .flatMap { case (key, size) =>
+              // We require the new number of partitions to be equal or less 
than the old number of
+              // partitions for a given key. In the case of less than, empty 
partitions are added.
+              val fps = filteredMap.getOrElse(key, Array.empty)
+
+              if (fps.size > size) {
+                throw new SparkException("During runtime filtering, data 
source must not report " +
+                  s"new partitions for a given key. Before: $size partitions. 
" +
+                  s"After: ${fps.size} partitions")
+              }
+
+              fps.map(Some).padTo(size, None)
+            }
+
+        case _ =>
+          // no validation is needed as the data source did not report any 
specific partitioning
+          newPartitions.toSeq.map(Some)
+      }
+
+    } else {
+      val parts = originalPartitions
+      (outputPartitioning match {
+        case k: KeyedPartitioning =>
+          if (parts.exists(!_.isInstanceOf[HasPartitionKey])) {
+            throw new SparkException("Original partitions must implement 
HasPartitionKey when " +
+                "outputPartitioning is KeyedPartitioning.")
+          }
+          
parts.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering)
+
+        case _ => parts
+      }).map(Some)
+    }
+  }
+
   /**
    * Returns a Seq of [[PartitionPredicateField]] representing partition 
transform expression types,
    * if schema is supported for [[PartitionPredicate]] push down. None if not 
supported.
@@ -202,7 +299,19 @@ object PushDownUtils extends Logging {
    */
   def getPartitionPredicateSchema(table: Table, output: 
Seq[AttributeReference])
   : Option[Seq[PartitionPredicateField]] = {
-    val transforms = table.partitioning
+    getPartitionPredicateSchema(table.partitioning, output)
+  }
+
+  /**
+   * Returns a Seq of [[PartitionPredicateField]] representing partition 
transform expression types,
+   * if schema is supported for [[PartitionPredicate]] push down. None if not 
supported.
+   *
+   * Use this overload when the caller has access to the partition transforms 
but not the
+   * full [[Table]].
+   */
+  def getPartitionPredicateSchema(
+      transforms: Array[Transform],
+      output: Seq[AttributeReference]): Option[Seq[PartitionPredicateField]] = 
{
     if (transforms.isEmpty) {
       None
     } else {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to