vitaliili-db commented on code in PR #55887:
URL: https://github.com/apache/spark/pull/55887#discussion_r3250287337


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala:
##########
@@ -187,6 +193,90 @@ object PushDownUtils extends Logging {
     }
   }
 
+  /**
+   * Pushes runtime filters into `scan` and re-plans its input partitions. For 
scans whose
+   * `outputPartitioning` is a [[KeyedPartitioning]] (SPJ-active), validates 
that the data source
+   * preserved the original partitioning and pads with `None` to preserve key 
alignment with the
+   * pre-filter partition set.
+   *
+   * Must be called at execute time: runtime filters carry 
[[DynamicPruningExpression]] and
+   * scalar-subquery references whose values are only resolved after their 
broadcast/subquery
+   * side completes. Callers should wrap the result in a `lazy val` so the 
mutating
+   * [[pushRuntimeFilters]] call runs at most once per scan instance.
+   *
+   * @param scan                      the V2 scan to push filters into
+   * @param runtimeFilters            runtime filters to translate and push
+   * @param partitionPredicateSchema  by-name schema for iterative 
[[PartitionPredicate]] pushdown
+   * @param output                    scan output attributes
+   * @param outputPartitioning        Spark-side output partitioning (used for 
SPJ validation)
+   * @param inputPartitions           by-name original (unfiltered) 
partitions; consulted only when
+   *                                  no runtime filters fire, so callers can 
compute it lazily
+   * @return one entry per original input partition: `Some(part)` for 
surviving partitions and
+   *         `None` for partition keys whose splits were entirely pruned (SPJ 
alignment)
+   */
+  def filterAndPlanPartitions(
+      scan: Scan,
+      runtimeFilters: Seq[Expression],
+      partitionPredicateSchema: => Option[Seq[PartitionPredicateField]],
+      output: Seq[AttributeReference],
+      outputPartitioning: Partitioning,
+      inputPartitions: => Seq[InputPartition]): Seq[Option[InputPartition]] = {
+    val filtered = pushRuntimeFilters(scan, runtimeFilters, 
partitionPredicateSchema, output)
+    if (filtered) {
+      // call toBatch again to get filtered partitions
+      val newPartitions = scan.toBatch.planInputPartitions()
+
+      outputPartitioning match {
+        case k: KeyedPartitioning =>
+          if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
+            throw new SparkException("Data source must have preserved the 
original partitioning " +
+                "during runtime filtering: not all partitions implement 
HasPartitionKey after " +
+                "filtering")
+          }
+
+          val inputMap = 
k.partitionKeys.groupBy(identity).view.mapValues(_.size)
+          val comparableKeyWrapperFactory = InternalRowComparableWrapper
+            .getInternalRowComparableWrapperFactory(k.expressionDataTypes)
+          val filteredMap = newPartitions.groupBy(
+            p => 
comparableKeyWrapperFactory(p.asInstanceOf[HasPartitionKey].partitionKey())
+          )
+
+          if (!filteredMap.keySet.subsetOf(inputMap.keySet)) {
+            throw new SparkException("During runtime filtering, data source 
must not report new " +
+                "partition keys that are not present in the original 
partitioning.")
+          }
+
+          inputMap.toSeq
+            .sortBy(_._1)(k.keyOrdering)
+            .flatMap { case (key, size) =>
+              // We require the new number of partitions to be equal or less 
than the old number of
+              // partitions for a given key. In the case of less than, empty 
partitions are added.
+              val fps = filteredMap.getOrElse(key, Array.empty)
+
+              if (fps.size > size) {
+                throw new SparkException("During runtime filtering, data 
source must not report " +
+                  s"new partitions for a given key. Before: $size partitions. 
" +
+                  s"After: ${fps.size} partitions")
+              }
+
+              fps.map(Some).padTo(size, None)
+            }
+
+        case _ =>
+          // no validation is needed as the data source did not report any 
specific partitioning
+          newPartitions.toSeq.map(Some)
+      }
+
+    } else {
+      (outputPartitioning match {
+        case k: KeyedPartitioning =>
+          
inputPartitions.sortBy(_.asInstanceOf[HasPartitionKey].partitionKey())(k.keyRowOrdering)

Review Comment:
   Done — the condition is now enforced.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to