gengliangwang commented on code in PR #55887:
URL: https://github.com/apache/spark/pull/55887#discussion_r3246275170
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala:
##########
@@ -187,6 +193,90 @@ object PushDownUtils extends Logging {
     }
   }
+  /**
+   * Pushes runtime filters into `scan` and re-plans its input partitions. For scans whose
+   * `outputPartitioning` is a [[KeyedPartitioning]] (SPJ-active), validates that the data source
+   * preserved the original partitioning and pads with `None` to preserve key alignment with the
+   * pre-filter partition set.
+   *
+   * Must be called at execute time: runtime filters carry [[DynamicPruningExpression]] and
+   * scalar-subquery references whose values are only resolved after their broadcast/subquery
+   * side completes. Callers should wrap the result in a `lazy val` so the mutating
+   * [[pushRuntimeFilters]] call runs at most once per scan instance.
+   *
+   * @param scan the V2 scan to push filters into
+   * @param runtimeFilters runtime filters to translate and push
+   * @param partitionPredicateSchema by-name schema for iterative [[PartitionPredicate]] pushdown
+   * @param output scan output attributes
+   * @param outputPartitioning Spark-side output partitioning (used for SPJ validation)
+   * @param inputPartitions by-name original (unfiltered) partitions; consulted only when
+   *                        no runtime filters fire, so callers can compute it lazily
+   * @return one entry per original input partition: `Some(part)` for surviving partitions and
+   *         `None` for partition keys whose splits were entirely pruned (SPJ alignment)
+   */
+  def filterAndPlanPartitions(
Review Comment:
Could you link the follow-up PR/JIRA for the alternative scan operator that
motivates this extraction? Without seeing the second caller it's hard to
validate the parameter shape (by-name `partitionPredicateSchema`, by-name
`inputPartitions`, and the new transforms-based `getPartitionPredicateSchema`
overload) — they look defensive for `BatchScanExec` but are presumably
load-bearing for the upcoming caller.
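To make the by-name trade-off concrete: a by-name parameter's argument expression is evaluated only if, and each time, the method body reads it. The sketch below uses hypothetical names (`ByNameSketch`, `plan`, `expensiveInputPartitions`), not the PR's code, and only illustrates that the `inputPartitions` argument is never computed when the filtered branch is taken.

```scala
object ByNameSketch {
  // Counts how many times the (hypothetical) expensive planning call actually ran.
  var inputPartitionsEvaluations = 0

  def expensiveInputPartitions(): Seq[String] = {
    inputPartitionsEvaluations += 1
    Seq("p0", "p1")
  }

  // `inputPartitions` is by-name (`=> Seq[String]`): the argument expression is
  // evaluated only if, and each time, the body reads it.
  def plan(filtersFired: Boolean, inputPartitions: => Seq[String]): Seq[String] =
    if (filtersFired) Seq("re-planned") // argument never evaluated on this path
    else inputPartitions                // evaluated exactly once on this path
}
```

Because a by-name argument is re-evaluated on every read, a body that consults it more than once would want to bind it to a local `lazy val` first; as far as the hunk shows, the helper reads `inputPartitions` at most once per branch.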
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala:
##########
@@ -139,12 +141,16 @@ object PushDownUtils extends Logging {
    * the first pass are used to derive PartitionPredicates in the second pass, avoiding duplicate
    * pushdown.
    *
+   * The partition-predicate schema is passed by-name so callers that cannot supply one (no
+   * partition transforms available) or whose scan does not opt into iterative pushdown pay no
+   * derivation cost.
+   *
    * @return true if any filters were pushed to the data source
    */
   def pushRuntimeFilters(
       scan: Scan,
       runtimeFilters: Seq[Expression],
-      table: Table,
+      partitionPredicateSchema: => Option[Seq[PartitionPredicateField]],
Review Comment:
Previously this method took `table: Table` and computed the schema
internally; now every caller has to call `getPartitionPredicateSchema(table,
output)` themselves and pass the result. For callers that have a `Table` (the
only existing caller today), that's two calls where there used to be one.
Consider keeping a thin `Table`-accepting overload as the ergonomic default,
with this `partitionPredicateSchema`-accepting form for callers that don't have
a `Table`.
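A minimal sketch of the suggested shape, assuming simplified stand-in types: `Table`, `PartitionPredicateField`, and the one-argument `getPartitionPredicateSchema` here are hypothetical and do not match Spark's real signatures (the real helper also takes `output`). The `Table` overload derives the schema and delegates; because the general form keeps the schema by-name, the derivation still only happens when filters are actually pushed.

```scala
object OverloadSketch {
  // Hypothetical stand-ins, not Spark's classes.
  final case class PartitionPredicateField(name: String)
  final case class Table(partitionColumns: Seq[String])

  // Counts schema derivations so laziness is observable.
  var schemaDerivations = 0

  // Hypothetical stand-in for getPartitionPredicateSchema(table, output).
  def getPartitionPredicateSchema(table: Table): Option[Seq[PartitionPredicateField]] = {
    schemaDerivations += 1
    if (table.partitionColumns.isEmpty) None
    else Some(table.partitionColumns.map(PartitionPredicateField(_)))
  }

  // General form for callers without a Table. Stand-in pushdown logic: "pushed" means
  // there was at least one filter and a non-empty partition schema. The `&&`
  // short-circuits, so the by-name schema is not derived when there are no filters.
  def pushRuntimeFilters(
      runtimeFilters: Seq[String],
      partitionPredicateSchema: => Option[Seq[PartitionPredicateField]]): Boolean =
    runtimeFilters.nonEmpty && partitionPredicateSchema.exists(_.nonEmpty)

  // Thin ergonomic overload for callers that hold a Table: one call instead of two.
  def pushRuntimeFilters(runtimeFilters: Seq[String], table: Table): Boolean =
    pushRuntimeFilters(runtimeFilters, getPartitionPredicateSchema(table))
}
```

The two overloads stay unambiguous because a `Table` never conforms to `Option[Seq[PartitionPredicateField]]`.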
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala:
##########
@@ -187,6 +193,90 @@ object PushDownUtils extends Logging {
     }
   }
+  /**
+   * Pushes runtime filters into `scan` and re-plans its input partitions. For scans whose
+   * `outputPartitioning` is a [[KeyedPartitioning]] (SPJ-active), validates that the data source
+   * preserved the original partitioning and pads with `None` to preserve key alignment with the
+   * pre-filter partition set.
+   *
+   * Must be called at execute time: runtime filters carry [[DynamicPruningExpression]] and
+   * scalar-subquery references whose values are only resolved after their broadcast/subquery
+   * side completes. Callers should wrap the result in a `lazy val` so the mutating
+   * [[pushRuntimeFilters]] call runs at most once per scan instance.
+   *
+   * @param scan the V2 scan to push filters into
+   * @param runtimeFilters runtime filters to translate and push
+   * @param partitionPredicateSchema by-name schema for iterative [[PartitionPredicate]] pushdown
+   * @param output scan output attributes
+   * @param outputPartitioning Spark-side output partitioning (used for SPJ validation)
+   * @param inputPartitions by-name original (unfiltered) partitions; consulted only when
+   *                        no runtime filters fire, so callers can compute it lazily
+   * @return one entry per original input partition: `Some(part)` for surviving partitions and
+   *         `None` for partition keys whose splits were entirely pruned (SPJ alignment)
+   */
+  def filterAndPlanPartitions(
+      scan: Scan,
+      runtimeFilters: Seq[Expression],
+      partitionPredicateSchema: => Option[Seq[PartitionPredicateField]],
+      output: Seq[AttributeReference],
+      outputPartitioning: Partitioning,
+      inputPartitions: => Seq[InputPartition]): Seq[Option[InputPartition]] = {
+    val filtered = pushRuntimeFilters(scan, runtimeFilters, partitionPredicateSchema, output)
+    if (filtered) {
+      // call toBatch again to get filtered partitions
+      val newPartitions = scan.toBatch.planInputPartitions()
+
+      outputPartitioning match {
+        case k: KeyedPartitioning =>
+          if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
+            throw new SparkException("Data source must have preserved the original partitioning " +
+              "during runtime filtering: not all partitions implement HasPartitionKey after " +
+              "filtering")
+          }
+
+          val inputMap = k.partitionKeys.groupBy(identity).view.mapValues(_.size)
+          val comparableKeyWrapperFactory = InternalRowComparableWrapper
+            .getInternalRowComparableWrapperFactory(k.expressionDataTypes)
+          val filteredMap = newPartitions.groupBy(
+            p => comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey())
+          )
+
+          if (!filteredMap.keySet.subsetOf(inputMap.keySet)) {
+            throw new SparkException("During runtime filtering, data source must not report new " +
+              "partition keys that are not present in the original partitioning.")
+          }
+
+          inputMap.toSeq
+            .sortBy(_._1)(k.keyOrdering)
+            .flatMap { case (key, size) =>
+              // We require the new number of partitions to be equal or less than the old number of
+              // partitions for a given key. In the case of less than, empty partitions are added.
+              val fps = filteredMap.getOrElse(key, Array.empty)
+
+              if (fps.size > size) {
+                throw new SparkException("During runtime filtering, data source must not report " +
+                  s"new partitions for a given key. Before: $size partitions. " +
+                  s"After: ${fps.size} partitions")
+              }
+
+              fps.map(Some).padTo(size, None)
+            }
+
+        case _ =>
+          // no validation is needed as the data source did not report any specific partitioning
+          newPartitions.toSeq.map(Some)
+      }
+
+    } else {
+      (outputPartitioning match {
+        case k: KeyedPartitioning =>
+          inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering)
Review Comment:
When `outputPartitioning` is `KeyedPartitioning`, this branch
unconditionally casts each input partition to `HasPartitionKey`. For
`BatchScanExec` this is safe because
`DataSourceV2ScanExecBase.outputPartitioning` only produces `KeyedPartitioning`
when every input partition already implements `HasPartitionKey`. The helper
itself doesn't document or enforce this invariant, though — a future caller
pairing a `KeyedPartitioning` with non-`HasPartitionKey` partitions would hit a
cryptic `ClassCastException`. Worth a sentence in the Scaladoc, or an explicit
precondition check.
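One way the explicit check could look, sketched with hypothetical stand-in types (keys simplified to `Int`, no `KeyedPartitioning` or `keyRowOrdering`): a `require` up front turns a future caller's mismatch into an `IllegalArgumentException` that names the violated invariant, instead of a bare `ClassCastException` from the cast. Spark code may prefer a `SparkException` with an error class here; `require` is just the shortest illustration.

```scala
object PreconditionSketch {
  // Hypothetical stand-ins for Spark's InputPartition / HasPartitionKey (keys simplified to Int).
  trait InputPartition
  trait HasPartitionKey extends InputPartition { def partitionKey: Int }
  final case class KeyedPart(partitionKey: Int) extends HasPartitionKey
  final case class PlainPart(id: Int) extends InputPartition

  // Mirrors the unfiltered KeyedPartitioning branch, with the invariant checked before the cast.
  def sortByPartitionKey(inputPartitions: Seq[InputPartition]): Seq[Option[InputPartition]] = {
    require(
      inputPartitions.forall(_.isInstanceOf[HasPartitionKey]),
      "KeyedPartitioning requires every input partition to implement HasPartitionKey")
    inputPartitions
      .sortBy(_.asInstanceOf[HasPartitionKey].partitionKey)
      .map(p => Some(p))
  }
}
```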