This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 92fea74fda1 [SPARK-41398][SQL] Relax constraints on 
Storage-Partitioned Join when partition keys after runtime filtering do not 
match
92fea74fda1 is described below

commit 92fea74fda1f2f764f6fc38c82eb2dff7972ad87
Author: Chao Sun <[email protected]>
AuthorDate: Tue Dec 6 07:48:34 2022 -0800

    [SPARK-41398][SQL] Relax constraints on Storage-Partitioned Join when 
partition keys after runtime filtering do not match
    
    ### What changes were proposed in this pull request?
    
    This PR relaxes the current constraint of Storage-Partitioned Join, which 
requires the partition keys after runtime filtering to be exactly the same 
as the partition keys before the filtering.
    
    ### Why are the changes needed?
    
    At the moment, Spark requires that when Storage-Partitioned Join is used 
together with runtime filtering, the partition keys before and after the 
filtering must match exactly. If they do not, a `SparkException` is thrown.
    
    However, this is not strictly necessary in the case where the partition 
keys after the filtering are a subset of the original keys. In this scenario, we 
can use empty partitions for the original keys that are missing after filtering.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Modified an existing test case to match the change.
    
    Closes #38924 from sunchao/SPARK-41398.
    
    Authored-by: Chao Sun <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../execution/datasources/v2/BatchScanExec.scala   | 36 ++++++++++++++++------
 .../connector/KeyGroupedPartitioningSuite.scala    |  6 ++--
 2 files changed, 29 insertions(+), 13 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index 48569ddc07d..0f7bdd9e1fb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -81,18 +81,21 @@ case class BatchScanExec(
 
           val newRows = new InternalRowSet(p.expressions.map(_.dataType))
           newRows ++= 
newPartitions.map(_.asInstanceOf[HasPartitionKey].partitionKey())
-          val oldRows = p.partitionValuesOpt.get
 
-          if (oldRows.size != newRows.size) {
-            throw new SparkException("Data source must have preserved the 
original partitioning " +
-                "during runtime filtering: the number of unique partition 
values obtained " +
-                s"through HasPartitionKey changed: before ${oldRows.size}, 
after ${newRows.size}")
+          val oldRows = p.partitionValuesOpt.get.toSet
+          // We require the new number of partition keys to be equal or less 
than the old number
+          // of partition keys here. In the case of less than, empty 
partitions will be added for
+          // those missing keys that are not present in the new input 
partitions.
+          if (oldRows.size < newRows.size) {
+            throw new SparkException("During runtime filtering, data source 
must either report " +
+                "the same number of partition keys, or a subset of partition 
keys from the " +
+                s"original. Before: ${oldRows.size} partition keys. After: 
${newRows.size} " +
+                "partition keys")
           }
 
-          if (!oldRows.forall(newRows.contains)) {
-            throw new SparkException("Data source must have preserved the 
original partitioning " +
-                "during runtime filtering: the number of unique partition 
values obtained " +
-                s"through HasPartitionKey remain the same but do not exactly 
match")
+          if (!newRows.forall(oldRows.contains)) {
+            throw new SparkException("During runtime filtering, data source 
must not report new " +
+                "partition keys that are not present in the original 
partitioning.")
           }
 
           groupPartitions(newPartitions).get.map(_._2)
@@ -114,8 +117,21 @@ case class BatchScanExec(
       // return an empty RDD with 1 partition if dynamic filtering removed the 
only split
       sparkContext.parallelize(Array.empty[InternalRow], 1)
     } else {
+      var finalPartitions = filteredPartitions
+
+      outputPartitioning match {
+        case p: KeyGroupedPartitioning =>
+          val partitionMapping = finalPartitions.map(s =>
+            s.head.asInstanceOf[HasPartitionKey].partitionKey() -> s).toMap
+          finalPartitions = p.partitionValuesOpt.get.map { partKey =>
+            // Use empty partition for those partition keys that are not 
present
+            partitionMapping.getOrElse(partKey, Seq.empty)
+          }
+        case _ =>
+      }
+
       new DataSourceRDD(
-        sparkContext, filteredPartitions, readerFactory, supportsColumnar, 
customMetrics)
+        sparkContext, finalPartitions, readerFactory, supportsColumnar, 
customMetrics)
     }
     postDriverMetrics()
     rdd
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index c0dc3263616..b2b8951a979 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -433,11 +433,11 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase {
           s"(2, 11.0, cast('2020-01-01' as timestamp)), " +
           s"(3, 19.5, cast('2020-02-01' as timestamp))")
 
-      // number of unique partitions changed after dynamic filtering - should 
throw exception
+      // number of unique partitions changed after dynamic filtering - the gap 
should be filled
+      // with empty partitions and the job should still succeed
       var df = sql(s"SELECT sum(p.price) from testcat.ns.$items i, 
testcat.ns.$purchases p WHERE " +
           s"i.id = p.item_id AND i.price > 40.0")
-      val e = intercept[Exception](df.collect())
-      assert(e.getMessage.contains("number of unique partition values"))
+      checkAnswer(df, Seq(Row(131)))
 
       // dynamic filtering doesn't change partitioning so storage-partitioned 
join should kick in
       df = sql(s"SELECT sum(p.price) from testcat.ns.$items i, 
testcat.ns.$purchases p WHERE " +


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to