This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 212f00462f2 [SPARK-45592][SQL] Correctness issue in AQE with InMemoryTableScanExec
212f00462f2 is described below

commit 212f00462f2419039742a305cdc941a886e3a15a
Author: Emil Ejbyfeldt <[email protected]>
AuthorDate: Tue Oct 31 11:19:32 2023 +0800

    [SPARK-45592][SQL] Correctness issue in AQE with InMemoryTableScanExec
    
    Fixes a correctness issue in 3.5.0. The problem is that when AQEShuffleRead
    does a coalesced read, it can return a HashPartitioning with the coalesced
    number of partitions. This causes a correctness bug because that partitioning
    is not compatible with other HashPartitionings for joins, even when the number
    of partitions matches: each coalesced partition holds a contiguous range of the
    original hash partitions, so rows with equal join keys on the two sides are not
    guaranteed to land in partitions with the same index. This patch resolves the
    issue by introducing CoalescedHashPartitioning and making AQEShuffleRead return
    it instead.
    
    The fix was suggested by cloud-fan:
    
    > AQEShuffleRead should probably return a different partitioning, e.g. 
CoalescedHashPartitioning. It still satisfies ClusterDistribution, so Aggregate 
is fine and there will be no shuffle. For joins, two CoalescedHashPartitionings 
are compatible if they have the same original partition number and coalesce 
boundaries, and CoalescedHashPartitioning is not compatible with 
HashPartitioning.
    
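    Why are the changes needed? This is a correctness bug.
    
    Does this PR introduce any user-facing change? Yes, it fixes a correctness issue.
    
    How was this patch tested? New and existing unit tests.
    
    Was this patch authored or co-authored using generative AI tooling? No.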
    
    Closes #43435 from eejbyfeldt/SPARK-45592.
    
    Authored-by: Emil Ejbyfeldt <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit 2be03d81cea34ab08c44426837260c22c67e092e)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/plans/physical/partitioning.scala |  49 +++
 .../spark/sql/catalyst/DistributionSuite.scala     | 124 ++++---
 .../spark/sql/catalyst/ShuffleSpecSuite.scala      | 401 ++++++++++++---------
 .../execution/adaptive/AQEShuffleReadExec.scala    |  11 +-
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |  14 +
 .../WriteDistributionAndOrderingSuite.scala        |  53 ++-
 6 files changed, 386 insertions(+), 266 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index d2f9e9b5d5b..1eefe65859b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -306,6 +306,35 @@ case class HashPartitioning(expressions: Seq[Expression], 
numPartitions: Int)
 
   override protected def withNewChildrenInternal(
     newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions 
= newChildren)
+
+}
+
+case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int)
+
+/**
+ * Represents a partitioning where partitions have been coalesced from a 
HashPartitioning into a
+ * fewer number of partitions.
+ */
+case class CoalescedHashPartitioning(from: HashPartitioning, partitions: 
Seq[CoalescedBoundary])
+  extends Expression with Partitioning with Unevaluable {
+
+  override def children: Seq[Expression] = from.expressions
+  override def nullable: Boolean = from.nullable
+  override def dataType: DataType = from.dataType
+
+  override def satisfies0(required: Distribution): Boolean = 
from.satisfies0(required)
+
+  override def createShuffleSpec(distribution: ClusteredDistribution): 
ShuffleSpec =
+    CoalescedHashShuffleSpec(from.createShuffleSpec(distribution), partitions)
+
+  override protected def withNewChildrenInternal(
+    newChildren: IndexedSeq[Expression]): CoalescedHashPartitioning =
+      copy(from = from.copy(expressions = newChildren))
+
+  override val numPartitions: Int = partitions.length
+
+  override def toString: String = from.toString
+  override def sql: String = from.sql
 }
 
 /**
@@ -661,6 +690,26 @@ case class HashShuffleSpec(
   override def numPartitions: Int = partitioning.numPartitions
 }
 
+case class CoalescedHashShuffleSpec(
+    from: ShuffleSpec,
+    partitions: Seq[CoalescedBoundary]) extends ShuffleSpec {
+
+  override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
+    case SinglePartitionShuffleSpec =>
+      numPartitions == 1
+    case CoalescedHashShuffleSpec(otherParent, otherPartitions) =>
+      partitions == otherPartitions && from.isCompatibleWith(otherParent)
+    case ShuffleSpecCollection(specs) =>
+      specs.exists(isCompatibleWith)
+    case _ =>
+      false
+  }
+
+  override def canCreatePartitioning: Boolean = false
+
+  override def numPartitions: Int = partitions.length
+}
+
 case class KeyGroupedShuffleSpec(
     partitioning: KeyGroupedPartitioning,
     distribution: ClusteredDistribution) extends ShuffleSpec {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
index a924a9ed02e..7cb4d5f1232 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
 import org.apache.spark.SparkFunSuite
 /* Implicit conversions */
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Literal, Murmur3Hash, Pmod}
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, 
Murmur3Hash, Pmod}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.IntegerType
 
@@ -146,63 +146,75 @@ class DistributionSuite extends SparkFunSuite {
       false)
   }
 
-  test("HashPartitioning is the output partitioning") {
-    // HashPartitioning can satisfy ClusteredDistribution iff its hash 
expressions are a subset of
-    // the required clustering expressions.
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b", $"c"), 10),
-      ClusteredDistribution(Seq($"a", $"b", $"c")),
-      true)
-
-    checkSatisfied(
-      HashPartitioning(Seq($"b", $"c"), 10),
-      ClusteredDistribution(Seq($"a", $"b", $"c")),
-      true)
-
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b", $"c"), 10),
-      ClusteredDistribution(Seq($"b", $"c")),
-      false)
-
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b", $"c"), 10),
-      ClusteredDistribution(Seq($"d", $"e")),
-      false)
-
-    // When ClusteredDistribution.requireAllClusterKeys is set to true,
-    // HashPartitioning can only satisfy ClusteredDistribution iff its hash 
expressions are
-    // exactly same as the required clustering expressions.
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b", $"c"), 10),
-      ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = 
true),
-      true)
-
-    checkSatisfied(
-      HashPartitioning(Seq($"b", $"c"), 10),
-      ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = 
true),
-      false)
-
-    checkSatisfied(
-      HashPartitioning(Seq($"b", $"a", $"c"), 10),
-      ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = 
true),
-      false)
-
-    // HashPartitioning cannot satisfy OrderedDistribution
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b", $"c"), 10),
-      OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)),
-      false)
+  private def testHashPartitioningLike(
+      partitioningName: String,
+      create: (Seq[Expression], Int) => Partitioning): Unit = {
+
+    test(s"$partitioningName is the output partitioning") {
+      // HashPartitioning can satisfy ClusteredDistribution iff its hash 
expressions are a subset of
+      // the required clustering expressions.
+      checkSatisfied(
+        create(Seq($"a", $"b", $"c"), 10),
+        ClusteredDistribution(Seq($"a", $"b", $"c")),
+        true)
+
+      checkSatisfied(
+        create(Seq($"b", $"c"), 10),
+        ClusteredDistribution(Seq($"a", $"b", $"c")),
+        true)
+
+      checkSatisfied(
+        create(Seq($"a", $"b", $"c"), 10),
+        ClusteredDistribution(Seq($"b", $"c")),
+        false)
+
+      checkSatisfied(
+        create(Seq($"a", $"b", $"c"), 10),
+        ClusteredDistribution(Seq($"d", $"e")),
+        false)
+
+      // When ClusteredDistribution.requireAllClusterKeys is set to true,
+      // HashPartitioning can only satisfy ClusteredDistribution iff its hash 
expressions are
+      // exactly same as the required clustering expressions.
+      checkSatisfied(
+        create(Seq($"a", $"b", $"c"), 10),
+        ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = 
true),
+        true)
+
+      checkSatisfied(
+        create(Seq($"b", $"c"), 10),
+        ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = 
true),
+        false)
+
+      checkSatisfied(
+        create(Seq($"b", $"a", $"c"), 10),
+        ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = 
true),
+        false)
+
+      // HashPartitioning cannot satisfy OrderedDistribution
+      checkSatisfied(
+        create(Seq($"a", $"b", $"c"), 10),
+        OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)),
+        false)
+
+      checkSatisfied(
+        create(Seq($"a", $"b", $"c"), 1),
+        OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)),
+        false) // TODO: this can be relaxed.
+
+      checkSatisfied(
+        create(Seq($"b", $"c"), 10),
+        OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)),
+        false)
+    }
+  }
 
-    checkSatisfied(
-      HashPartitioning(Seq($"a", $"b", $"c"), 1),
-      OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)),
-      false) // TODO: this can be relaxed.
+  testHashPartitioningLike("HashPartitioning",
+    (expressions, numPartitions) => HashPartitioning(expressions, 
numPartitions))
 
-    checkSatisfied(
-      HashPartitioning(Seq($"b", $"c"), 10),
-      OrderedDistribution(Seq($"a".asc, $"b".asc, $"c".asc)),
-      false)
-  }
+  testHashPartitioningLike("CoalescedHashPartitioning", (expressions, 
numPartitions) =>
+      CoalescedHashPartitioning(
+        HashPartitioning(expressions, numPartitions), Seq(CoalescedBoundary(0, 
numPartitions))))
 
   test("RangePartitioning is the output partitioning") {
     // RangePartitioning can satisfy OrderedDistribution iff its ordering is a 
prefix
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
index 51e76887322..6b069d1c973 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala
@@ -62,211 +62,254 @@ class ShuffleSpecSuite extends SparkFunSuite with 
SQLHelper {
     }
   }
 
-  test("compatibility: HashShuffleSpec on both sides") {
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      expected = true
-    )
-
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a"), 10), 
ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a"), 10), 
ClusteredDistribution(Seq($"a", $"b"))),
-      expected = true
-    )
+  private def testHashShuffleSpecLike(
+      shuffleSpecName: String,
+      create: (HashPartitioning, ClusteredDistribution) => ShuffleSpec): Unit 
= {
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"b"), 10), 
ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"d"), 10), 
ClusteredDistribution(Seq($"c", $"d"))),
-      expected = true
-    )
+    test(s"compatibility: $shuffleSpecName on both sides") {
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        expected = true
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"c", $"c", $"d"), 10),
-        ClusteredDistribution(Seq($"c", $"d"))),
-      expected = true
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a"), 10), 
ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"a"), 10), 
ClusteredDistribution(Seq($"a", $"b"))),
+        expected = true
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"d"), 10),
-        ClusteredDistribution(Seq($"a", $"c", $"d"))),
-      expected = true
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"b"), 10), 
ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"d"), 10), 
ClusteredDistribution(Seq($"c", $"d"))),
+        expected = true
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"c", $"a"), 10),
-        ClusteredDistribution(Seq($"a", $"c", $"c"))),
-      expected = true
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"c", $"c", $"d"), 10),
+          ClusteredDistribution(Seq($"c", $"d"))),
+        expected = true
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"c", $"a"), 10),
-        ClusteredDistribution(Seq($"a", $"c", $"d"))),
-      expected = true
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b", $"b"))),
+        create(HashPartitioning(Seq($"a", $"d"), 10),
+          ClusteredDistribution(Seq($"a", $"c", $"d"))),
+        expected = true
+      )
 
-    // negative cases
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"c"), 5),
-        ClusteredDistribution(Seq($"c", $"d"))),
-      expected = false
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b", $"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b", $"b"))),
+        create(HashPartitioning(Seq($"a", $"c", $"a"), 10),
+          ClusteredDistribution(Seq($"a", $"c", $"c"))),
+        expected = true
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      expected = false
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b", $"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b", $"b"))),
+        create(HashPartitioning(Seq($"a", $"c", $"a"), 10),
+          ClusteredDistribution(Seq($"a", $"c", $"d"))),
+        expected = true
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      expected = false
-    )
+      // negative cases
+      checkCompatible(
+        create(HashPartitioning(Seq($"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"c"), 5),
+          ClusteredDistribution(Seq($"c", $"d"))),
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"d"), 10),
-        ClusteredDistribution(Seq($"c", $"d"))),
-      expected = false
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"d"), 10),
-        ClusteredDistribution(Seq($"c", $"d"))),
-      expected = false
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      expected = false
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"d"), 10),
+          ClusteredDistribution(Seq($"c", $"d"))),
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"a"), 10),
-        ClusteredDistribution(Seq($"a", $"b", $"b"))),
-      expected = false
-    )
-  }
+      checkCompatible(
+        create(HashPartitioning(Seq($"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"d"), 10),
+          ClusteredDistribution(Seq($"c", $"d"))),
+        expected = false
+      )
 
-  test("compatibility: Only one side is HashShuffleSpec") {
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      SinglePartitionShuffleSpec,
-      expected = false
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"a", $"b", $"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 1),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      SinglePartitionShuffleSpec,
-      expected = true
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b", $"b"))),
+        create(HashPartitioning(Seq($"a", $"b", $"a"), 10),
+          ClusteredDistribution(Seq($"a", $"b", $"b"))),
+        expected = false
+      )
+    }
 
-    checkCompatible(
-      SinglePartitionShuffleSpec,
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 1),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      expected = true
-    )
+    test(s"compatibility: Only one side is $shuffleSpecName") {
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        SinglePartitionShuffleSpec,
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))),
-      expected = false
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 1),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        SinglePartitionShuffleSpec,
+        expected = true
+      )
 
-    checkCompatible(
-      RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      expected = false
-    )
+      checkCompatible(
+        SinglePartitionShuffleSpec,
+        create(HashPartitioning(Seq($"a", $"b"), 1),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        expected = true
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      ShuffleSpecCollection(Seq(
-        HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-          ClusteredDistribution(Seq($"a", $"b"))))),
-      expected = true
-    )
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
+          ClusteredDistribution(Seq($"a", $"b"))),
+        RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))),
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      ShuffleSpecCollection(Seq(
-        HashShuffleSpec(HashPartitioning(Seq($"a"), 10),
+      checkCompatible(
+        RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))),
+        create(HashPartitioning(Seq($"a", $"b"), 10),
           ClusteredDistribution(Seq($"a", $"b"))),
-        HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-          ClusteredDistribution(Seq($"a", $"b"))))),
-      expected = true
-    )
+        expected = false
+      )
 
-    checkCompatible(
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-        ClusteredDistribution(Seq($"a", $"b"))),
-      ShuffleSpecCollection(Seq(
-        HashShuffleSpec(HashPartitioning(Seq($"a"), 10),
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
           ClusteredDistribution(Seq($"a", $"b"))),
-        HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"c"), 10),
-          ClusteredDistribution(Seq($"a", $"b", $"c"))))),
-      expected = false
-    )
+        ShuffleSpecCollection(Seq(
+          create(HashPartitioning(Seq($"a", $"b"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))))),
+        expected = true
+      )
 
-    checkCompatible(
-      ShuffleSpecCollection(Seq(
-        HashShuffleSpec(HashPartitioning(Seq($"b"), 10),
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
           ClusteredDistribution(Seq($"a", $"b"))),
-        HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-          ClusteredDistribution(Seq($"a", $"b"))))),
-      ShuffleSpecCollection(Seq(
-        HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"c"), 10),
-          ClusteredDistribution(Seq($"a", $"b", $"c"))),
-        HashShuffleSpec(HashPartitioning(Seq($"d"), 10),
-          ClusteredDistribution(Seq($"c", $"d"))))),
-      expected = true
-    )
+        ShuffleSpecCollection(Seq(
+          create(HashPartitioning(Seq($"a"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))),
+          create(HashPartitioning(Seq($"a", $"b"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))))),
+        expected = true
+      )
 
-    checkCompatible(
-      ShuffleSpecCollection(Seq(
-        HashShuffleSpec(HashPartitioning(Seq($"b"), 10),
+      checkCompatible(
+        create(HashPartitioning(Seq($"a", $"b"), 10),
           ClusteredDistribution(Seq($"a", $"b"))),
-        HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
-          ClusteredDistribution(Seq($"a", $"b"))))),
-      ShuffleSpecCollection(Seq(
-        HashShuffleSpec(HashPartitioning(Seq($"a", $"b", $"c"), 10),
-          ClusteredDistribution(Seq($"a", $"b", $"c"))),
-        HashShuffleSpec(HashPartitioning(Seq($"c"), 10),
-          ClusteredDistribution(Seq($"c", $"d"))))),
-      expected = false
-    )
+        ShuffleSpecCollection(Seq(
+          create(HashPartitioning(Seq($"a"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))),
+          create(HashPartitioning(Seq($"a", $"b", $"c"), 10),
+            ClusteredDistribution(Seq($"a", $"b", $"c"))))),
+        expected = false
+      )
+
+      checkCompatible(
+        ShuffleSpecCollection(Seq(
+          create(HashPartitioning(Seq($"b"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))),
+          create(HashPartitioning(Seq($"a", $"b"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))))),
+        ShuffleSpecCollection(Seq(
+          create(HashPartitioning(Seq($"a", $"b", $"c"), 10),
+            ClusteredDistribution(Seq($"a", $"b", $"c"))),
+          create(HashPartitioning(Seq($"d"), 10),
+            ClusteredDistribution(Seq($"c", $"d"))))),
+        expected = true
+      )
+
+      checkCompatible(
+        ShuffleSpecCollection(Seq(
+          create(HashPartitioning(Seq($"b"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))),
+          create(HashPartitioning(Seq($"a", $"b"), 10),
+            ClusteredDistribution(Seq($"a", $"b"))))),
+        ShuffleSpecCollection(Seq(
+          create(HashPartitioning(Seq($"a", $"b", $"c"), 10),
+            ClusteredDistribution(Seq($"a", $"b", $"c"))),
+          create(HashPartitioning(Seq($"c"), 10),
+            ClusteredDistribution(Seq($"c", $"d"))))),
+        expected = false
+      )
+    }
+  }
+
+  testHashShuffleSpecLike("HashShuffleSpec",
+    (partitioning, distribution) => HashShuffleSpec(partitioning, 
distribution))
+   testHashShuffleSpecLike("CoalescedHashShuffleSpec",
+    (partitioning, distribution) => {
+      val partitions = if (partitioning.numPartitions == 1) {
+        Seq(CoalescedBoundary(0, 1))
+      } else {
+        Seq(CoalescedBoundary(0, 1), CoalescedBoundary(0, 
partitioning.numPartitions))
+      }
+      CoalescedHashShuffleSpec(HashShuffleSpec(partitioning, distribution), 
partitions)
+  })
+
+  test("compatibility: CoalescedHashShuffleSpec other specs") {
+      val hashShuffleSpec = HashShuffleSpec(
+        HashPartitioning(Seq($"a", $"b"), 10), ClusteredDistribution(Seq($"a", 
$"b")))
+      checkCompatible(
+        hashShuffleSpec,
+        CoalescedHashShuffleSpec(hashShuffleSpec, Seq(CoalescedBoundary(0, 
10))),
+        expected = false
+      )
+
+      checkCompatible(
+        CoalescedHashShuffleSpec(hashShuffleSpec,
+          Seq(CoalescedBoundary(0, 5), CoalescedBoundary(5, 10))),
+        CoalescedHashShuffleSpec(hashShuffleSpec,
+          Seq(CoalescedBoundary(0, 5), CoalescedBoundary(5, 10))),
+        expected = true
+      )
+
+      checkCompatible(
+        CoalescedHashShuffleSpec(hashShuffleSpec,
+          Seq(CoalescedBoundary(0, 4), CoalescedBoundary(4, 10))),
+        CoalescedHashShuffleSpec(hashShuffleSpec,
+          Seq(CoalescedBoundary(0, 5), CoalescedBoundary(5, 10))),
+        expected = false
+      )
   }
 
   test("compatibility: other specs") {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
index 46ec91dcc0a..6b39ac70a62 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.adaptive
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, 
UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, 
CoalescedHashPartitioning, HashPartitioning, Partitioning, RangePartitioning, 
RoundRobinPartitioning, SinglePartition, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, 
ShuffleExchangeLike}
@@ -75,7 +76,13 @@ case class AQEShuffleReadExec private(
       // partitions is changed.
       child.outputPartitioning match {
         case h: HashPartitioning =>
-          CurrentOrigin.withOrigin(h.origin)(h.copy(numPartitions = 
partitionSpecs.length))
+          val partitions = partitionSpecs.map {
+            case CoalescedPartitionSpec(start, end, _) => 
CoalescedBoundary(start, end)
+            // Can not happend due to isCoalescedRead
+            case unexpected =>
+              throw SparkException.internalError(s"Unexpected 
ShufflePartitionSpec: $unexpected")
+          }
+          CurrentOrigin.withOrigin(h.origin)(CoalescedHashPartitioning(h, 
partitions))
         case r: RangePartitioning =>
           CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions = 
partitionSpecs.length))
         // This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 6d9c43f866a..207c66dc4d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2541,6 +2541,20 @@ class DatasetSuite extends QueryTest
     val ds = Seq(1, 2).toDS().persist(StorageLevel.NONE)
     assert(ds.count() == 2)
   }
+
+  test("SPARK-45592: Coaleasced shuffle read is not compatible with hash 
partitioning") {
+    val ee = spark.range(0, 1000000, 1, 5).map(l => (l, l)).toDF()
+      .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
+    ee.count()
+
+    val minNbrs1 = ee
+      .groupBy("_1").agg(min(col("_2")).as("min_number"))
+      .persist(org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK)
+
+    val join = ee.join(minNbrs1, "_1")
+    assert(join.count() == 1000000)
+  }
+
 }
 
 class DatasetLargeResultCollectingSuite extends QueryTest
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index 6cab0e0239d..40938eb6424 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.{catalyst, AnalysisException, 
DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, 
Cast, Literal}
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.physical
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
RangePartitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, 
CoalescedHashPartitioning, HashPartitioning, RangePartitioning, 
UnknownPartitioning}
 import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.connector.distributions.{Distribution, 
Distributions}
@@ -264,11 +264,8 @@ class WriteDistributionAndOrderingSuite extends 
DistributionAndOrderingSuiteBase
       )
     )
     val writePartitioningExprs = Seq(attr("data"), attr("id"))
-    val writePartitioning = if (!coalesce) {
-      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
-    } else {
-      clusteredWritePartitioning(writePartitioningExprs, Some(1))
-    }
+    val writePartitioning = clusteredWritePartitioning(
+      writePartitioningExprs, targetNumPartitions, coalesce)
 
     checkWriteRequirements(
       tableDistribution,
@@ -377,11 +374,8 @@ class WriteDistributionAndOrderingSuite extends 
DistributionAndOrderingSuiteBase
       )
     )
     val writePartitioningExprs = Seq(attr("data"))
-    val writePartitioning = if (!coalesce) {
-      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
-    } else {
-      clusteredWritePartitioning(writePartitioningExprs, Some(1))
-    }
+    val writePartitioning = clusteredWritePartitioning(
+      writePartitioningExprs, targetNumPartitions, coalesce)
 
     checkWriteRequirements(
       tableDistribution,
@@ -875,11 +869,8 @@ class WriteDistributionAndOrderingSuite extends 
DistributionAndOrderingSuiteBase
       )
     )
     val writePartitioningExprs = Seq(attr("data"))
-    val writePartitioning = if (!coalesce) {
-      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
-    } else {
-      clusteredWritePartitioning(writePartitioningExprs, Some(1))
-    }
+    val writePartitioning = clusteredWritePartitioning(
+      writePartitioningExprs, targetNumPartitions, coalesce)
 
     checkWriteRequirements(
       tableDistribution,
@@ -963,11 +954,8 @@ class WriteDistributionAndOrderingSuite extends 
DistributionAndOrderingSuiteBase
       )
     )
     val writePartitioningExprs = Seq(attr("data"))
-    val writePartitioning = if (!coalesce) {
-      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
-    } else {
-      clusteredWritePartitioning(writePartitioningExprs, Some(1))
-    }
+    val writePartitioning = clusteredWritePartitioning(
+      writePartitioningExprs, targetNumPartitions, coalesce)
 
     checkWriteRequirements(
       tableDistribution,
@@ -1154,11 +1142,8 @@ class WriteDistributionAndOrderingSuite extends 
DistributionAndOrderingSuiteBase
     )
 
     val writePartitioningExprs = Seq(truncateExpr)
-    val writePartitioning = if (!coalesce) {
-      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
-    } else {
-      clusteredWritePartitioning(writePartitioningExprs, Some(1))
-    }
+    val writePartitioning = clusteredWritePartitioning(
+      writePartitioningExprs, targetNumPartitions, coalesce)
 
     checkWriteRequirements(
       tableDistribution,
@@ -1422,6 +1407,9 @@ class WriteDistributionAndOrderingSuite extends 
DistributionAndOrderingSuiteBase
       case p: physical.HashPartitioning =>
         val resolvedExprs = p.expressions.map(resolveAttrs(_, plan))
         p.copy(expressions = resolvedExprs)
+      case c: physical.CoalescedHashPartitioning =>
+        val resolvedExprs = c.from.expressions.map(resolveAttrs(_, plan))
+        c.copy(from = c.from.copy(expressions = resolvedExprs))
       case _: UnknownPartitioning =>
         // don't check partitioning if no particular one is expected
         actualPartitioning
@@ -1480,9 +1468,24 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
 
+  /**
+   * Builds the expected partitioning for a clustered write: a [[HashPartitioning]] on
+   * `writePartitioningExprs` with `targetNumPartitions` partitions, or
+   * `conf.numShufflePartitions` partitions when no target is given. If `coalesce` is true,
+   * the hash partitioning is wrapped in a [[CoalescedHashPartitioning]] with the single
+   * boundary `CoalescedBoundary(0, numPartitions)`, presumably modelling an AQE shuffle read
+   * that coalesces every shuffle partition into one.
+   */
   private def clusteredWritePartitioning(
       writePartitioningExprs: Seq[catalyst.expressions.Expression],
-      targetNumPartitions: Option[Int]): physical.Partitioning = {
-    HashPartitioning(writePartitioningExprs,
-      targetNumPartitions.getOrElse(conf.numShufflePartitions))
+      targetNumPartitions: Option[Int],
+      coalesce: Boolean): physical.Partitioning = {
+    val partitioning = HashPartitioning(writePartitioningExprs,
+      targetNumPartitions.getOrElse(conf.numShufflePartitions))
+    if (coalesce) {
+      CoalescedHashPartitioning(
+        partitioning, Seq(CoalescedBoundary(0, partitioning.numPartitions)))
+    } else {
+      partitioning
+    }
   }
 
   private def partitionSizes(dataSkew: Boolean, coalesce: Boolean): 
Seq[Option[Long]] = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to