This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2be03d81cea3 [SPARK-45592][SQL] Correctness issue in AQE with InMemoryTableScanExec
2be03d81cea3 is described below

commit 2be03d81cea34ab08c44426837260c22c67e092e
Author: Emil Ejbyfeldt <[email protected]>
AuthorDate: Tue Oct 31 11:19:32 2023 +0800

    [SPARK-45592][SQL] Correctness issue in AQE with InMemoryTableScanExec
    
    ### What changes were proposed in this pull request?
    Fixes a correctness issue in 3.5.0. The problem is that when AQEShuffleRead
    does a coalesced read, it can return a HashPartitioning with the coalesced
    number of partitions. Such a partitioning is not compatible for joins with
    another HashPartitioning even when the number of partitions matches, because
    each coalesced partition holds a range of the original hash partitions rather
    than the rows that hash to its own index. This patch resolves the issue by
    introducing CoalescedHashPartitioning and making AQEShuffleRead return it
    instead.
    
    The fix was suggested by cloud-fan
    
    > AQEShuffleRead should probably return a different partitioning, e.g. 
CoalescedHashPartitioning. It still satisfies ClusterDistribution, so Aggregate 
is fine and there will be no shuffle. For joins, two CoalescedHashPartitionings 
are compatible if they have the same original partition number and coalesce 
boundaries, and CoalescedHashPartitioning is not compatible with 
HashPartitioning.
    
    ### Why are the changes needed?
    Correctness bug.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, it fixes a correctness issue.
    
    ### How was this patch tested?
    New and existing unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43435 from eejbyfeldt/SPARK-45592.
    
    Authored-by: Emil Ejbyfeldt <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/plans/physical/partitioning.scala |  49 +++
 .../spark/sql/catalyst/DistributionSuite.scala     | 124 ++++---
 .../spark/sql/catalyst/ShuffleSpecSuite.scala      | 401 ++++++++++++---------
 .../execution/adaptive/AQEShuffleReadExec.scala    |  11 +-
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |  14 +
 .../WriteDistributionAndOrderingSuite.scala        |  53 ++-
 6 files changed, 386 insertions(+), 266 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index a61bd3b7324b..0ae2857161c8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -306,6 +306,35 @@ case class HashPartitioning(expressions: Seq[Expression], 
numPartitions: Int)
 
   override protected def withNewChildrenInternal(
     newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions 
= newChildren)
+
+}
+// Reducer index range read as one coalesced partition (end presumably exclusive; confirm).
+case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int)
+
+/**
+ * A `from` HashPartitioning coalesced into fewer partitions, one per entry of `partitions`.
+ * satisfies0 delegates to `from`, but it is not join-compatible with a plain HashPartitioning.
+ */
+case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[CoalescedBoundary])
+  extends Expression with Partitioning with Unevaluable {
+
+  override def children: Seq[Expression] = from.expressions // the original hash keys
+  override def nullable: Boolean = from.nullable
+  override def dataType: DataType = from.dataType
+
+  override def satisfies0(required: Distribution): Boolean = from.satisfies0(required)
+
+  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
+    CoalescedHashShuffleSpec(from.createShuffleSpec(distribution), partitions)
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): CoalescedHashPartitioning =
+      copy(from = from.copy(expressions = newChildren)) // children are from's expressions
+
+  override val numPartitions: Int = partitions.length // one per CoalescedBoundary
+
+  override def toString: String = from.toString // boundaries are not shown
+  override def sql: String = from.sql
 }
 
 /**
@@ -708,6 +737,26 @@ case class HashShuffleSpec(
   override def numPartitions: Int = partitioning.numPartitions
 }
 
+case class CoalescedHashShuffleSpec( // shuffle spec of a CoalescedHashPartitioning
+    from: ShuffleSpec, // spec of the partitioning before coalescing
+    partitions: Seq[CoalescedBoundary]) extends ShuffleSpec { // one partition per boundary
+
+  override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
+    case SinglePartitionShuffleSpec =>
+      numPartitions == 1
+    case CoalescedHashShuffleSpec(otherParent, otherPartitions) => // needs identical boundaries
+      partitions == otherPartitions && from.isCompatibleWith(otherParent)
+    case ShuffleSpecCollection(specs) =>
+      specs.exists(isCompatibleWith)
+    case _ => // includes a plain HashShuffleSpec, whatever its numPartitions
+      false
+  }
+  // NOTE(review): presumably false since a fresh shuffle cannot reproduce these boundaries.
+  override def canCreatePartitioning: Boolean = false
+
+  override def numPartitions: Int = partitions.length
+}
+
 /**
  * [[ShuffleSpec]] created by [[KeyGroupedPartitioning]].
  *
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
index a924a9ed02e5..7cb4d5f12325 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
 import org.apache.spark.SparkFunSuite
 /* Implicit conversions */
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Literal, Murmur3Hash, Pmod}
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, 
Murmur3Hash, Pmod}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.IntegerType
 
@@ -146,63 +146,75 @@ class DistributionSuite extends SparkFunSuite {
       false)
   }
 
-  test("HashPartitioning is the output partitioning") {
-    // HashPartitioning can satisfy ClusteredDistribution iff its hash 
expressions are a subset of
-    // the required clustering expressions.
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b", $"c"), 10),
-      ClusteredDistribution(Seq($"a", $"b", $"c")),
-      true)
-
-    checkSatisfied(
-      HashPartitioning(Seq($"b", $"c"), 10),
-      ClusteredDistribution(Seq($"a", $"b", $"c")),
-      true)
-
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b", $"c"), 10),
-      ClusteredDistribution(Seq($"b", $"c")),
-      false)
-
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b", $"c"), 10),
-      ClusteredDistribution(Seq($"d", $"e")),
-      false)
-
-    // When ClusteredDistribution.requireAllClusterKeys is set to true,
-    // HashPartitioning can only satisfy ClusteredDistribution iff its hash 
expressions are
-    // exactly same as the required clustering expressions.
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b", $"c"), 10),
-      ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = 
true),
-      true)
-
-    checkSatisfied(
-      HashPartitioning(Seq($"b", $"c"), 10),
-      ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = 
true),
-      false)
-
-    checkSatisfied(
-      HashPartitioning(Seq($"b", $"a", $"c"), 10),
-      ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = 
true),
-      false)
-
-    // HashPartitioning cannot satisfy OrderedDistribution
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b", $"c"), 10),
-      OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)),
-      false)
+  private def testHashPartitioningLike(
+      partitioningName: String,
+      create: (Seq[Expression], Int) => Partitioning): Unit = {
+    // Runs the shared satisfaction checks against whatever partitioning `create` builds.
+    test(s"$partitioningName is the output partitioning") {
+      // A hash-like partitioning satisfies ClusteredDistribution iff its hash expressions are a
+      // subset of the required clustering expressions.
+      checkSatisfied(
+        create(Seq($"a", $"b", $"c"), 10),
+        ClusteredDistribution(Seq($"a", $"b", $"c")),
+        true)
+
+      checkSatisfied(
+        create(Seq($"b", $"c"), 10),
+        ClusteredDistribution(Seq($"a", $"b", $"c")),
+        true)
+
+      checkSatisfied(
+        create(Seq($"a", $"b", $"c"), 10),
+        ClusteredDistribution(Seq($"b", $"c")),
+        false)
+
+      checkSatisfied(
+        create(Seq($"a", $"b", $"c"), 10),
+        ClusteredDistribution(Seq($"d", $"e")),
+        false)
+
+      // When ClusteredDistribution.requireAllClusterKeys is set to true, a hash-like
+      // partitioning satisfies ClusteredDistribution iff its hash expressions are exactly the
+      // same as the required clustering expressions.
+      checkSatisfied(
+        create(Seq($"a", $"b", $"c"), 10),
+        ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
+        true)
+
+      checkSatisfied(
+        create(Seq($"b", $"c"), 10),
+        ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
+        false)
+
+      checkSatisfied(
+        create(Seq($"b", $"a", $"c"), 10),
+        ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
+        false)
+
+      // A hash-like partitioning cannot satisfy OrderedDistribution
+      checkSatisfied(
+        create(Seq($"a", $"b", $"c"), 10),
+        OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)),
+        false)
+
+      checkSatisfied(
+        create(Seq($"a", $"b", $"c"), 1),
+        OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)),
+        false) // TODO: this can be relaxed.
+
+      checkSatisfied(
+        create(Seq($"b", $"c"), 10),
+        OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)),
+        false)
+    }
+  }
 
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b", $"c"), 1),
-      OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)),
-      false) // TODO: this can be relaxed.
+  testHashPartitioningLike("HashPartitioning",
+    (expressions, numPartitions) => HashPartitioning(expressions, 
numPartitions))
 
-    checkSatisfied(
-      HashPartitioning(Seq($"b", $"c"), 10),
-      OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)),
-      false)
-  }
+  testHashPartitioningLike("CoalescedHashPartitioning", (expressions, 
numPartitions) =>
+      CoalescedHashPartitioning(
+        HashPartitioning(expressions, numPartitions), Seq(CoalescedBoundary(0, 
numPartitions))))
 
   test("RangePartitioning is the output partitioning") {
     // RangePartitioning can satisfy OrderedDistribution iff its ordering is a 
prefix
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
index 51e768873226..6b069d1c9736 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
@@ -62,211 +62,254 @@ class ShuffleSpecSuite extends SparkFunSuite with 
SQLHelper {
     }
   }
 
-  test("compatibility: HashShuffleSpec on both sides") {
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      expected = true
-    )
-
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a"), 10), 
ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a"), 10), 
ClusteredDistribution(Seq($"a", $"b"))),
-      expected = true
-    )
+  private def testHashShuffleSpecLike(
+      shuffleSpecName: String,
+      create: (HashPartitioning, ClusteredDistribution) => ShuffleSpec): Unit 
= {
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"b"), 10), 
ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"d"), 10), 
ClusteredDistribution(Seq($"c", $"d"))),
-      expected = true
-    )
+    test(s"compatibility: $shuffleSpecName on both sides") {
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        expected = true
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"c", $"c", $"d"), 10),
-        ClusteredDistribution(Seq($"c", $"d"))),
-      expected = true
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a"), 10), 
ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"a"), 10), 
ClusteredDistribution(Seq($"a", $"b"))),
+        expected = true
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"d"), 10),
-        ClusteredDistribution(Seq($"a", $"c", $"d"))),
-      expected = true
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"b"), 10), 
ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"d"), 10), 
ClusteredDistribution(Seq($"c", $"d"))),
+        expected = true
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"c", $"a"), 10),
-        ClusteredDistribution(Seq($"a", $"c", $"c"))),
-      expected = true
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"c", $"c", $"d"), 10),
+          ClusteredDistribution(Seq($"c", $"d"))),
+        expected = true
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"c", $"a"), 10),
-        ClusteredDistribution(Seq($"a", $"c", $"d"))),
-      expected = true
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b", $"b"))),
+        create(HashPartitioning(Seq($"a", $"d"), 10),
+          ClusteredDistribution(Seq($"a", $"c", $"d"))),
+        expected = true
+      )
 
-    // negative cases
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"c"), 5),
-        ClusteredDistribution(Seq($"c", $"d"))),
-      expected = false
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b", $"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b", $"b"))),
+        create(HashPartitioning(Seq($"a", $"c", $"a"), 10),
+          ClusteredDistribution(Seq($"a", $"c", $"c"))),
+        expected = true
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      expected = false
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b", $"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b", $"b"))),
+        create(HashPartitioning(Seq($"a", $"c", $"a"), 10),
+          ClusteredDistribution(Seq($"a", $"c", $"d"))),
+        expected = true
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      expected = false
-    )
+      // negative cases
+      checkCompatible(
+        create(HashPartitioning(Seq($"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"c"), 5),
+          ClusteredDistribution(Seq($"c", $"d"))),
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"d"), 10),
-        ClusteredDistribution(Seq($"c", $"d"))),
-      expected = false
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"d"), 10),
-        ClusteredDistribution(Seq($"c", $"d"))),
-      expected = false
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      expected = false
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"d"), 10),
+          ClusteredDistribution(Seq($"c", $"d"))),
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b", $"b"))),
-      expected = false
-    )
-  }
+      checkCompatible(
+        create(HashPartitioning(Seq($"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"d"), 10),
+          ClusteredDistribution(Seq($"c", $"d"))),
+        expected = false
+      )
 
-  test("compatibility: Only one side is HashShuffleSpec") {
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      SinglePartitionShuffleSpec,
-      expected = false
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"a", $"b", $"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 1),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      SinglePartitionShuffleSpec,
-      expected = true
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b", $"b"))),
+        create(HashPartitioning(Seq($"a", $"b", $"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b", $"b"))),
+        expected = false
+      )
+    }
 
-    checkCompatible(
-      SinglePartitionShuffleSpec,
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 1),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      expected = true
-    )
+    test(s"compatibility: Only one side is $shuffleSpecName") {
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        SinglePartitionShuffleSpec,
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))),
-      expected = false
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 1),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        SinglePartitionShuffleSpec,
+        expected = true
+      )
 
-    checkCompatible(
-      RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      expected = false
-    )
+      checkCompatible(
+        SinglePartitionShuffleSpec,
+        create(HashPartitioning(Seq($"a", $"b"), 1),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        expected = true
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      ShuffleSpecCollection(Seq(
-        HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-          ClusteredDistribution(Seq($"a", $"b"))))),
-      expected = true
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))),
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      ShuffleSpecCollection(Seq(
-        HashShuffleSpec(HashPartitioning(Seq($"a"), 10),
+      checkCompatible(
+        RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"a", $"b"), 10),
           ClusteredDistribution(Seq($"a", $"b"))),
-        HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-          ClusteredDistribution(Seq($"a", $"b"))))),
-      expected = true
-    )
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      ShuffleSpecCollection(Seq(
-        HashShuffleSpec(HashPartitioning(Seq($"a"), 10),
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
           ClusteredDistribution(Seq($"a", $"b"))),
-        HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"c"), 10),
-          ClusteredDistribution(Seq($"a", $"b", $"c"))))),
-      expected = false
-    )
+        ShuffleSpecCollection(Seq(
+          create(HashPartitioning(Seq($"a", $"b"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))))),
+        expected = true
+      )
 
-    checkCompatible(
-      ShuffleSpecCollection(Seq(
-        HashShuffleSpec(HashPartitioning(Seq($"b"), 10),
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
           ClusteredDistribution(Seq($"a", $"b"))),
-        HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-          ClusteredDistribution(Seq($"a", $"b"))))),
-      ShuffleSpecCollection(Seq(
-        HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"c"), 10),
-          ClusteredDistribution(Seq($"a", $"b", $"c"))),
-        HashShuffleSpec(HashPartitioning(Seq($"d"), 10),
-          ClusteredDistribution(Seq($"c", $"d"))))),
-      expected = true
-    )
+        ShuffleSpecCollection(Seq(
+          create(HashPartitioning(Seq($"a"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))),
+          create(HashPartitioning(Seq($"a", $"b"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))))),
+        expected = true
+      )
 
-    checkCompatible(
-      ShuffleSpecCollection(Seq(
-        HashShuffleSpec(HashPartitioning(Seq($"b"), 10),
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
           ClusteredDistribution(Seq($"a", $"b"))),
-        HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-          ClusteredDistribution(Seq($"a", $"b"))))),
-      ShuffleSpecCollection(Seq(
-        HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"c"), 10),
-          ClusteredDistribution(Seq($"a", $"b", $"c"))),
-        HashShuffleSpec(HashPartitioning(Seq($"c"), 10),
-          ClusteredDistribution(Seq($"c", $"d"))))),
-      expected = false
-    )
+        ShuffleSpecCollection(Seq(
+          create(HashPartitioning(Seq($"a"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))),
+          create(HashPartitioning(Seq($"a", $"b", $"c"), 10),
+            ClusteredDistribution(Seq($"a", $"b", $"c"))))),
+        expected = false
+      )
+
+      checkCompatible(
+        ShuffleSpecCollection(Seq(
+          create(HashPartitioning(Seq($"b"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))),
+          create(HashPartitioning(Seq($"a", $"b"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))))),
+        ShuffleSpecCollection(Seq(
+          create(HashPartitioning(Seq($"a", $"b", $"c"), 10),
+            ClusteredDistribution(Seq($"a", $"b", $"c"))),
+          create(HashPartitioning(Seq($"d"), 10),
+            ClusteredDistribution(Seq($"c", $"d"))))),
+        expected = true
+      )
+
+      checkCompatible(
+        ShuffleSpecCollection(Seq(
+          create(HashPartitioning(Seq($"b"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))),
+          create(HashPartitioning(Seq($"a", $"b"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))))),
+        ShuffleSpecCollection(Seq(
+          create(HashPartitioning(Seq($"a", $"b", $"c"), 10),
+            ClusteredDistribution(Seq($"a", $"b", $"c"))),
+          create(HashPartitioning(Seq($"c"), 10),
+            ClusteredDistribution(Seq($"c", $"d"))))),
+        expected = false
+      )
+    }
+  }
+
+  testHashShuffleSpecLike("HashShuffleSpec",
+    (partitioning, distribution) => HashShuffleSpec(partitioning, distribution))
+  testHashShuffleSpecLike("CoalescedHashShuffleSpec",
+    (partitioning, distribution) => {
+      val partitions = if (partitioning.numPartitions == 1) {
+        Seq(CoalescedBoundary(0, 1))
+      } else { // two adjacent, non-overlapping boundaries covering every original reducer
+        Seq(CoalescedBoundary(0, 1), CoalescedBoundary(1, partitioning.numPartitions))
+      }
+      CoalescedHashShuffleSpec(HashShuffleSpec(partitioning, distribution), partitions)
+    })
+
+  test("compatibility: CoalescedHashShuffleSpec other specs") {
+    val hashShuffleSpec = HashShuffleSpec(
+      HashPartitioning(Seq($"a", $"b"), 10), ClusteredDistribution(Seq($"a", $"b")))
+    checkCompatible( // coalesced read vs. plain HashShuffleSpec
+      hashShuffleSpec,
+      CoalescedHashShuffleSpec(hashShuffleSpec, Seq(CoalescedBoundary(0, 10))),
+      expected = false
+    )
+
+    checkCompatible( // identical boundaries over compatible originals
+      CoalescedHashShuffleSpec(hashShuffleSpec,
+        Seq(CoalescedBoundary(0, 5), CoalescedBoundary(5, 10))),
+      CoalescedHashShuffleSpec(hashShuffleSpec,
+        Seq(CoalescedBoundary(0, 5), CoalescedBoundary(5, 10))),
+      expected = true
+    )
+
+    checkCompatible( // same originals, different boundaries
+      CoalescedHashShuffleSpec(hashShuffleSpec,
+        Seq(CoalescedBoundary(0, 4), CoalescedBoundary(4, 10))),
+      CoalescedHashShuffleSpec(hashShuffleSpec,
+        Seq(CoalescedBoundary(0, 5), CoalescedBoundary(5, 10))),
+      expected = false
+    )
   }
 
   test("compatibility: other specs") {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
index 46ec91dcc0ab..6b39ac70a62e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.adaptive
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, 
UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, 
CoalescedHashPartitioning, HashPartitioning, Partitioning, RangePartitioning, 
RoundRobinPartitioning, SinglePartition, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, 
ShuffleExchangeLike}
@@ -75,7 +76,13 @@ case class AQEShuffleReadExec private(
       // partitions is changed.
       child.outputPartitioning match {
         case h: HashPartitioning =>
-          CurrentOrigin.withOrigin(h.origin)(h.copy(numPartitions = 
partitionSpecs.length))
+          val partitions = partitionSpecs.map {
+            case CoalescedPartitionSpec(start, end, _) => 
CoalescedBoundary(start, end)
+            // Can not happend due to isCoalescedRead
+            case unexpected =>
+              throw SparkException.internalError(s"Unexpected 
ShufflePartitionSpec: $unexpected")
+          }
+          CurrentOrigin.withOrigin(h.origin)(CoalescedHashPartitioning(h, 
partitions))
         case r: RangePartitioning =>
           CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions = 
partitionSpecs.length))
         // This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 6b00799cabd1..bf78e6e11fe9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2645,6 +2645,20 @@ class DatasetSuite extends QueryTest
     val ds = Seq(1, 2).toDS().persist(StorageLevel.NONE)
     assert(ds.count() == 2)
   }
+
+  test("SPARK-45592: Coalesced shuffle read is not compatible with hash partitioning") {
+    val ee = spark.range(0, 1000000, 1, 5).map(l => (l, l)).toDF()
+      .persist(StorageLevel.MEMORY_AND_DISK)
+    ee.count()
+    val minNbrs1 = ee
+      .groupBy("_1").agg(min(col("_2")).as("min_number"))
+      .persist(StorageLevel.MEMORY_AND_DISK)
+    try {
+      assert(ee.join(minNbrs1, "_1").count() == 1000000)
+    } finally {
+      Seq(minNbrs1, ee).foreach(_.unpersist())
+    }
+  }
 }
 
 class DatasetLargeResultCollectingSuite extends QueryTest
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index 6cab0e0239dc..40938eb64247 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.{catalyst, AnalysisException, 
DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, 
Cast, Literal}
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.physical
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
RangePartitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, 
CoalescedHashPartitioning, HashPartitioning, RangePartitioning, 
UnknownPartitioning}
 import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.connector.distributions.{Distribution, 
Distributions}
@@ -264,11 +264,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
       )
     )
     val writePartitioningExprs = Seq(attr("data"), attr("id"))
-    val writePartitioning = if (!coalesce) {
-      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
-    } else {
-      clusteredWritePartitioning(writePartitioningExprs, Some(1))
-    }
+    val writePartitioning = clusteredWritePartitioning(
+      writePartitioningExprs, targetNumPartitions, coalesce)
 
     checkWriteRequirements(
       tableDistribution,
@@ -377,11 +374,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
       )
     )
     val writePartitioningExprs = Seq(attr("data"))
-    val writePartitioning = if (!coalesce) {
-      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
-    } else {
-      clusteredWritePartitioning(writePartitioningExprs, Some(1))
-    }
+    val writePartitioning = clusteredWritePartitioning(
+      writePartitioningExprs, targetNumPartitions, coalesce)
 
     checkWriteRequirements(
       tableDistribution,
@@ -875,11 +869,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
       )
     )
     val writePartitioningExprs = Seq(attr("data"))
-    val writePartitioning = if (!coalesce) {
-      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
-    } else {
-      clusteredWritePartitioning(writePartitioningExprs, Some(1))
-    }
+    val writePartitioning = clusteredWritePartitioning(
+      writePartitioningExprs, targetNumPartitions, coalesce)
 
     checkWriteRequirements(
       tableDistribution,
@@ -963,11 +954,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
       )
     )
     val writePartitioningExprs = Seq(attr("data"))
-    val writePartitioning = if (!coalesce) {
-      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
-    } else {
-      clusteredWritePartitioning(writePartitioningExprs, Some(1))
-    }
+    val writePartitioning = clusteredWritePartitioning(
+      writePartitioningExprs, targetNumPartitions, coalesce)
 
     checkWriteRequirements(
       tableDistribution,
@@ -1154,11 +1142,8 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
     )
 
     val writePartitioningExprs = Seq(truncateExpr)
-    val writePartitioning = if (!coalesce) {
-      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
-    } else {
-      clusteredWritePartitioning(writePartitioningExprs, Some(1))
-    }
+    val writePartitioning = clusteredWritePartitioning(
+      writePartitioningExprs, targetNumPartitions, coalesce)
 
     checkWriteRequirements(
       tableDistribution,
@@ -1422,6 +1407,9 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
       case p: physical.HashPartitioning =>
         val resolvedExprs = p.expressions.map(resolveAttrs(_, plan))
         p.copy(expressions = resolvedExprs)
+      case c: physical.CoalescedHashPartitioning =>
+        val resolvedExprs = c.from.expressions.map(resolveAttrs(_, plan))
+        c.copy(from = c.from.copy(expressions = resolvedExprs))
       case _: UnknownPartitioning =>
         // don't check partitioning if no particular one is expected
         actualPartitioning
@@ -1480,9 +1468,25 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
 
+  /**
+   * Builds the expected write partitioning for a clustered distribution.
+   *
+   * @param writePartitioningExprs expressions the expected hash partitioning is built on
+   * @param targetNumPartitions explicit number of partitions, if any; falls back to
+   *                            `conf.numShufflePartitions` when empty
+   * @param coalesce when true, wrap the hash partitioning in a CoalescedHashPartitioning
+   *                 with a single CoalescedBoundary from 0 to its `numPartitions`
+   */
   private def clusteredWritePartitioning(
       writePartitioningExprs: Seq[catalyst.expressions.Expression],
-      targetNumPartitions: Option[Int]): physical.Partitioning = {
-    HashPartitioning(writePartitioningExprs,
-      targetNumPartitions.getOrElse(conf.numShufflePartitions))
+      targetNumPartitions: Option[Int],
+      coalesce: Boolean): physical.Partitioning = {
+    val partitioning = HashPartitioning(writePartitioningExprs,
+      targetNumPartitions.getOrElse(conf.numShufflePartitions))
+    if (coalesce) {
+      CoalescedHashPartitioning(
+        partitioning, Seq(CoalescedBoundary(0, partitioning.numPartitions)))
+    } else {
+      partitioning
+    }
   }
 
   private def partitionSizes(dataSkew: Boolean, coalesce: Boolean): 
Seq[Option[Long]] = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to