This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new aa1261dc129 [SPARK-44641][SQL] Incorrect result in certain scenarios 
when SPJ is not triggered
aa1261dc129 is described below

commit aa1261dc129618d27a1bdc743a5fdd54219f7c01
Author: Chao Sun <sunc...@apple.com>
AuthorDate: Mon Aug 7 19:16:38 2023 -0700

    [SPARK-44641][SQL] Incorrect result in certain scenarios when SPJ is not 
triggered
    
    ### What changes were proposed in this pull request?
    
    This PR makes sure we use unique partition values when calculating the 
final partitions in `BatchScanExec`, to make sure no duplicated partitions are 
generated.
    
    ### Why are the changes needed?
    
    When `spark.sql.sources.v2.bucketing.pushPartValues.enabled` and 
`spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` are 
enabled, and SPJ (storage-partitioned join) is not triggered, currently Spark will generate 
incorrect/duplicated results.
    
    This is because with both configs enabled, Spark will delay the partition 
grouping until the time it calculates the final partitions used by the input 
RDD. To calculate the partitions, it uses partition values from the 
`KeyGroupedPartitioning` to find out the right ordering for the partitions. 
However, since grouping is not done when the partition values are computed, 
there could be duplicated partition values. This means the result could contain 
duplicated partitions too.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, this is a bug fix.
    
    ### How was this patch tested?
    
    Added a new test case for this scenario.
    
    Closes #42324 from sunchao/SPARK-44641.
    
    Authored-by: Chao Sun <sunc...@apple.com>
    Signed-off-by: Chao Sun <sunc...@apple.com>
---
 .../sql/catalyst/plans/physical/partitioning.scala |  9 +++-
 .../execution/datasources/v2/BatchScanExec.scala   |  9 +++-
 .../connector/KeyGroupedPartitioningSuite.scala    | 56 ++++++++++++++++++++++
 3 files changed, 72 insertions(+), 2 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index bd8ba54ddd7..456005768bd 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -313,7 +313,7 @@ case class HashPartitioning(expressions: Seq[Expression], 
numPartitions: Int)
  * by `expressions`. `partitionValues`, if defined, should contain value of 
partition key(s) in
  * ascending order, after evaluated by the transforms in `expressions`, for 
each input partition.
  * In addition, its length must be the same as the number of input partitions 
(and thus is a 1-1
- * mapping), and each row in `partitionValues` must be unique.
+ * mapping). The `partitionValues` may contain duplicated partition values.
  *
  * For example, if `expressions` is `[years(ts_col)]`, then a valid value of 
`partitionValues` is
  * `[0, 1, 2]`, which represents 3 input partitions with distinct partition 
values. All rows
@@ -355,6 +355,13 @@ case class KeyGroupedPartitioning(
 
   override def createShuffleSpec(distribution: ClusteredDistribution): 
ShuffleSpec =
     KeyGroupedShuffleSpec(this, distribution)
+
+  lazy val uniquePartitionValues: Seq[InternalRow] = {
+    partitionValues
+        .map(InternalRowComparableWrapper(_, expressions))
+        .distinct
+        .map(_.row)
+  }
 }
 
 object KeyGroupedPartitioning {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index 4b538197392..eba3c71f871 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -190,10 +190,17 @@ case class BatchScanExec(
                     Seq.fill(numSplits)(Seq.empty))
               }
             } else {
+              // either `commonPartitionValues` is not defined, or it is 
defined but
+              // `applyPartialClustering` is false.
               val partitionMapping = groupedPartitions.map { case (row, parts) 
=>
                 InternalRowComparableWrapper(row, p.expressions) -> parts
               }.toMap
-              finalPartitions = p.partitionValues.map { partValue =>
+
+              // In case `commonPartitionValues` is not defined (e.g., SPJ is 
not used), there
+              // could exist duplicated partition values, as partition 
grouping is not done
+              // at the beginning and postponed to this method. It is 
important to use unique
+              // partition values here so that grouped partitions won't get 
duplicated.
+              finalPartitions = p.uniquePartitionValues.map { partValue =>
                 // Use empty partition for those partition values that are not 
present
                 partitionMapping.getOrElse(
                   InternalRowComparableWrapper(partValue, p.expressions), 
Seq.empty)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 880c30ba9f9..8461f528277 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -1039,4 +1039,60 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase {
       }
     }
   }
+
+  test("SPARK-44641: duplicated records when SPJ is not triggered") {
+    val items_partitions = Array(bucket(8, "id"))
+    createTable(items, items_schema, items_partitions)
+    sql(s"""
+        INSERT INTO testcat.ns.$items VALUES
+        (1, 'aa', 40.0, cast('2020-01-01' as timestamp)),
+        (1, 'aa', 41.0, cast('2020-01-15' as timestamp)),
+        (2, 'bb', 10.0, cast('2020-01-01' as timestamp)),
+        (2, 'bb', 10.5, cast('2020-01-01' as timestamp)),
+        (3, 'cc', 15.5, cast('2020-02-01' as timestamp))""")
+
+    val purchases_partitions = Array(bucket(8, "item_id"))
+    createTable(purchases, purchases_schema, purchases_partitions)
+    sql(s"""INSERT INTO testcat.ns.$purchases VALUES
+        (1, 42.0, cast('2020-01-01' as timestamp)),
+        (1, 44.0, cast('2020-01-15' as timestamp)),
+        (1, 45.0, cast('2020-01-15' as timestamp)),
+        (2, 11.0, cast('2020-01-01' as timestamp)),
+        (3, 19.5, cast('2020-02-01' as timestamp))""")
+
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(true, false).foreach { partiallyClusteredEnabled =>
+        withSQLConf(
+          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> 
pushDownValues.toString,
+          SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+              partiallyClusteredEnabled.toString) {
+
+          // join keys are not the same as the partition keys, therefore SPJ 
is not triggered.
+          val df = sql(
+            s"""
+               SELECT id, name, i.price as purchase_price, p.item_id, p.price 
as sale_price
+               FROM testcat.ns.$items i JOIN testcat.ns.$purchases p
+               ON i.arrive_time = p.time ORDER BY id, purchase_price, 
p.item_id, sale_price
+               """)
+
+          val shuffles = collectShuffles(df.queryExecution.executedPlan)
+          assert(shuffles.nonEmpty, "shuffle should exist when SPJ is not 
used")
+
+          checkAnswer(df,
+            Seq(
+              Row(1, "aa", 40.0, 1, 42.0),
+              Row(1, "aa", 40.0, 2, 11.0),
+              Row(1, "aa", 41.0, 1, 44.0),
+              Row(1, "aa", 41.0, 1, 45.0),
+              Row(2, "bb", 10.0, 1, 42.0),
+              Row(2, "bb", 10.0, 2, 11.0),
+              Row(2, "bb", 10.5, 1, 42.0),
+              Row(2, "bb", 10.5, 2, 11.0),
+              Row(3, "cc", 15.5, 3, 19.5)
+            )
+          )
+        }
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to