szehon-ho commented on code in PR #52443:
URL: https://github.com/apache/spark/pull/52443#discussion_r2450017594


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala:
##########
@@ -1196,6 +1197,191 @@ class EnsureRequirementsSuite extends 
SharedSparkSession {
     TransformExpression(BucketFunction, expr, Some(numBuckets))
   }
 
+  test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when 
children are compatible") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      val passThrough_a_5 = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
+
+      val leftPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+      val rightPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+      val join = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, 
leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(
+            leftKeys,
+            rightKeys,
+            _,
+            _,
+            SortExec(_, _, DummySparkPlan(_, _, _: 
ShufflePartitionIdPassThrough, _, _), _),
+            SortExec(_, _, DummySparkPlan(_, _, _: 
ShufflePartitionIdPassThrough, _, _), _),
+            _
+            ) =>
+          assert(leftKeys === Seq(exprA))
+          assert(rightKeys === Seq(exprA))
+        case other => fail(s"We don't expect shuffle on neither sides, but 
got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough incompatibility - different partitions") 
{
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Different number of partitions - should add shuffles
+      val leftPlan = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val rightPlan = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 8))
+      val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, 
leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), 
_),
+          SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), 
_), _) =>
+          // Both sides should be shuffled to default partitions
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+          assert(p1.expressions == Seq(exprA))
+          assert(p2.expressions == Seq(exprB))
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough incompatibility - key position 
mismatch") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Key position mismatch - should add shuffles
+      val leftPlan = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val rightPlan = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprC), 5))
+      // Join on different keys than partitioning keys
+      val join = SortMergeJoinExec(exprA :: exprB :: Nil, exprD :: exprC :: 
Nil, Inner, None,
+        leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), 
_), _) =>
+          // Both sides shuffled due to key mismatch
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough vs HashPartitioning - always shuffles") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // ShufflePartitionIdPassThrough vs HashPartitioning - always adds 
shuffles
+      val leftPlan = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val rightPlan = DummySparkPlan(
+        outputPartitioning = HashPartitioning(exprB :: Nil, 5))
+      val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, 
leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+          SortExec(_, _, _: DummySparkPlan, _), _) =>
+          // Left side shuffled, right side kept as-is
+        case other => fail(s"Expected shuffle on at least one side, but got: 
$other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough vs SinglePartition - shuffles added") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+      // Even when compatible (numPartitions=1), shuffles added due to 
canCreatePartitioning=false
+      val leftPlan = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 1))
+      val rightPlan = DummySparkPlan(outputPartitioning = SinglePartition)
+      val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, 
leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), 
_), _) =>
+          // Both sides shuffled due to canCreatePartitioning = false
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+
+  test("ShufflePartitionIdPassThrough - compatible with multiple clustering 
keys") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      val passThrough_a_5 = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
+      val passThrough_b_5 = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 5)
+
+      // Both partitioned by exprA, joined on (exprA, exprB)
+      // Should be compatible because exprA positions overlap
+      val leftPlanA = DummySparkPlan(outputPartitioning = passThrough_a_5)
+      val rightPlanA = DummySparkPlan(outputPartitioning = passThrough_a_5)
+      val joinA = SortMergeJoinExec(exprA :: exprB :: Nil, exprA :: exprB :: 
Nil, Inner, None,
+        leftPlanA, rightPlanA)
+
+      EnsureRequirements.apply(joinA) match {
+        case SortMergeJoinExec(
+            leftKeys,
+            rightKeys,
+            _,
+            _,
+            SortExec(_, _, DummySparkPlan(_, _, _: 
ShufflePartitionIdPassThrough, _, _), _),
+            SortExec(_, _, DummySparkPlan(_, _, _: 
ShufflePartitionIdPassThrough, _, _), _),
+            _
+            ) =>
+          assert(leftKeys === Seq(exprA, exprB))
+          assert(rightKeys === Seq(exprA, exprB))
+        case other => fail(s"We don't expect shuffle on neither sides with 
multiple " +

Review Comment:
   nit: same as the comment below — use 'on either side' instead of 'on neither sides'



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala:
##########
@@ -1196,6 +1197,191 @@ class EnsureRequirementsSuite extends 
SharedSparkSession {
     TransformExpression(BucketFunction, expr, Some(numBuckets))
   }
 
+  test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when 
children are compatible") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      val passThrough_a_5 = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
+
+      val leftPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+      val rightPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+      val join = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, 
leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(
+            leftKeys,
+            rightKeys,
+            _,
+            _,
+            SortExec(_, _, DummySparkPlan(_, _, _: 
ShufflePartitionIdPassThrough, _, _), _),
+            SortExec(_, _, DummySparkPlan(_, _, _: 
ShufflePartitionIdPassThrough, _, _), _),
+            _
+            ) =>
+          assert(leftKeys === Seq(exprA))
+          assert(rightKeys === Seq(exprA))
+        case other => fail(s"We don't expect shuffle on neither sides, but 
got: $other")

Review Comment:
   nit: 'on either side' (it's a double negative as written)



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala:
##########
@@ -1196,6 +1197,191 @@ class EnsureRequirementsSuite extends 
SharedSparkSession {
     TransformExpression(BucketFunction, expr, Some(numBuckets))
   }
 
+  test("ShufflePartitionIdPassThrough - avoid unnecessary shuffle when 
children are compatible") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      val passThrough_a_5 = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5)
+
+      val leftPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+      val rightPlan = DummySparkPlan(outputPartitioning = passThrough_a_5)
+      val join = SortMergeJoinExec(exprA :: Nil, exprA :: Nil, Inner, None, 
leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(
+            leftKeys,
+            rightKeys,
+            _,
+            _,
+            SortExec(_, _, DummySparkPlan(_, _, _: 
ShufflePartitionIdPassThrough, _, _), _),
+            SortExec(_, _, DummySparkPlan(_, _, _: 
ShufflePartitionIdPassThrough, _, _), _),
+            _
+            ) =>
+          assert(leftKeys === Seq(exprA))
+          assert(rightKeys === Seq(exprA))
+        case other => fail(s"We don't expect shuffle on neither sides, but 
got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough incompatibility - different partitions") 
{
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Different number of partitions - should add shuffles
+      val leftPlan = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val rightPlan = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprB), 8))
+      val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, 
leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(p1: HashPartitioning, _, _, _), 
_),
+          SortExec(_, _, ShuffleExchangeExec(p2: HashPartitioning, _, _, _), 
_), _) =>
+          // Both sides should be shuffled to default partitions
+          assert(p1.numPartitions == 10)
+          assert(p2.numPartitions == 10)
+          assert(p1.expressions == Seq(exprA))
+          assert(p2.expressions == Seq(exprB))
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough incompatibility - key position 
mismatch") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // Key position mismatch - should add shuffles
+      val leftPlan = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val rightPlan = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprC), 5))
+      // Join on different keys than partitioning keys
+      val join = SortMergeJoinExec(exprA :: exprB :: Nil, exprD :: exprC :: 
Nil, Inner, None,
+        leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), 
_), _) =>
+          // Both sides shuffled due to key mismatch
+        case other => fail(s"Expected shuffles on both sides, but got: $other")
+      }
+    }
+  }
+
+  test("ShufflePartitionIdPassThrough vs HashPartitioning - always shuffles") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      // ShufflePartitionIdPassThrough vs HashPartitioning - always adds 
shuffles
+      val leftPlan = DummySparkPlan(
+        outputPartitioning = 
ShufflePartitionIdPassThrough(DirectShufflePartitionID(exprA), 5))
+      val rightPlan = DummySparkPlan(
+        outputPartitioning = HashPartitioning(exprB :: Nil, 5))
+      val join = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, 
leftPlan, rightPlan)
+
+      EnsureRequirements.apply(join) match {
+        case SortMergeJoinExec(_, _, _, _,
+          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _, _), _),
+          SortExec(_, _, _: DummySparkPlan, _), _) =>
+          // Left side shuffled, right side kept as-is
+        case other => fail(s"Expected shuffle on at least one side, but got: 
$other")

Review Comment:
   nit: just say 'left side'



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to