This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ce12f6dbad2 [SPARK-41471][SQL] Reduce Spark shuffle when only one side 
of a join is KeyGroupedPartitioning
ce12f6dbad2 is described below

commit ce12f6dbad2d713c6a2a52ac36d1f7a910399ad4
Author: Jia Fan <[email protected]>
AuthorDate: Thu Aug 24 08:35:16 2023 -0700

    [SPARK-41471][SQL] Reduce Spark shuffle when only one side of a join is 
KeyGroupedPartitioning
    
    ### What changes were proposed in this pull request?
    When only one side of an SPJ (Storage-Partitioned Join) reports
    KeyGroupedPartitioning, Spark currently shuffles both sides using
    HashPartitioning. However, it is enough to shuffle only the other side
    according to the partition transforms defined in the KeyGroupedPartitioning.
    This is especially useful when the other side is relatively small.
    1. Add a new config `spark.sql.sources.v2.bucketing.shuffle.enabled` to
    control whether this feature is enabled.
    2. Add `KeyGroupedPartitioner`, used to partition the other side when the
    transform values of one side are known (currently only for
    KeyGroupedPartitioning). In `EnsureRequirements`, Spark already knows the
    partition values and the corresponding partition ids of the
    KeyGroupedPartitioning side. These are saved in `KeyGroupedPartitioner` and
    used to shuffle the other side, so that rows with the same key are shuffled
    into the same partition.
    3. Only the `identity` transform is supported for now. There is currently
    another problem: the same transform can report different values in a DS V2
    connector implementation and in the corresponding catalog function. Until
    that problem is solved, only `identity` should be supported. E.g., in the
    test package, compare `YearFunction`
https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala#L47
 and https://github.com/apache/spark/blob/master/sql/catalyst/src [...]
    
    ### Why are the changes needed?
    Reduce data shuffle in specific SPJ scenarios
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    add new test
    
    Closes #42194 from Hisoka-X/SPARK-41471_one_side_keygroup.
    
    Authored-by: Jia Fan <[email protected]>
    Signed-off-by: Chao Sun <[email protected]>
---
 .../main/scala/org/apache/spark/Partitioner.scala  |  16 ++
 .../sql/catalyst/plans/physical/partitioning.scala |   8 +-
 .../util/InternalRowComparableWrapper.scala        |   7 +-
 .../org/apache/spark/sql/internal/SQLConf.scala    |  13 ++
 .../execution/datasources/v2/BatchScanExec.scala   |   2 +-
 .../execution/exchange/ShuffleExchangeExec.scala   |   9 +
 .../connector/KeyGroupedPartitioningSuite.scala    | 181 +++++++++++++++++++++
 .../exchange/EnsureRequirementsSuite.scala         |  32 +++-
 8 files changed, 263 insertions(+), 5 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala 
b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 5dffba2ee8e..ae39e2e183e 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -137,6 +137,22 @@ private[spark] class PartitionIdPassthrough(override val 
numPartitions: Int) ext
   override def getPartition(key: Any): Int = key.asInstanceOf[Int]
 }
 
+/**
+ * A [[org.apache.spark.Partitioner]] that partitions records by looking up their partition
+ * value in a fixed map.
+ *
+ * The `valueMap` maps each partition value (the key tuple, as a `Seq[Any]`) to the id of the
+ * partition it must be written to. It is generated from
+ * [[org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning]] and used to partition
+ * the other side of a join, so that records with the same partition value end up in the same
+ * partition.
+ *
+ * Keys that are absent from `valueMap` are spread by `hashCode` modulo `numPartitions`. That
+ * fallback is deterministic, so it is recomputed on each miss rather than cached: caching it in
+ * `valueMap` would only grow the map by one entry per distinct unmatched key.
+ *
+ * @param valueMap mapping from partition value to partition id; only read, never modified.
+ * @param numPartitions total number of output partitions.
+ */
+private[spark] class KeyGroupedPartitioner(
+    valueMap: mutable.Map[Seq[Any], Int],
+    override val numPartitions: Int) extends Partitioner {
+  override def getPartition(key: Any): Int = {
+    val keys = key.asInstanceOf[Seq[Any]]
+    // `getOrElse` takes its default by name, so the hash is only computed on a miss.
+    valueMap.getOrElse(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions))
+  }
+}
+
 /**
  * A [[org.apache.spark.Partitioner]] that partitions all records into a 
single partition.
  */
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 456005768bd..ce557422a08 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -735,7 +735,13 @@ case class KeyGroupedShuffleSpec(
       case _ => false
     }
 
-  override def canCreatePartitioning: Boolean = false
+  // Creating a matching partitioning for the other side is gated by
+  // spark.sql.sources.v2.bucketing.shuffle.enabled and, for now, only allowed when every
+  // partition expression is a plain attribute reference.
+  override def canCreatePartitioning: Boolean =
+    SQLConf.get.v2BucketingShuffleEnabled && partitioning.expressions.forall {
+      case _: AttributeReference => true
+      case _ => false
+    }
+
+  /**
+   * Builds a [[KeyGroupedPartitioning]] over `clustering` that reuses this spec's partition
+   * count and partition values.
+   */
+  override def createPartitioning(clustering: Seq[Expression]): Partitioning =
+    KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
 }
 
 case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
index b0e53090731..9a0bdc6bcfd 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
@@ -79,6 +79,11 @@ object InternalRowComparableWrapper {
     rightPartitioning.partitionValues
       .map(new InternalRowComparableWrapper(_, partitionDataTypes))
       .foreach(partition => partitionsSet.add(partition))
-    partitionsSet.map(_.row).toSeq
+    // SPARK-41471: sort the merged partition values so that their order is deterministic
+    // instead of depending on the iteration order of `partitionsSet`.
+    val partitionOrdering: Ordering[InternalRow] =
+      RowOrdering.createNaturalAscendingOrdering(partitionDataTypes)
+    partitionsSet.map(_.row).toSeq.sorted(partitionOrdering)
   }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 556d6b7c798..8bbc64a0aa7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1500,6 +1500,16 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val V2_BUCKETING_SHUFFLE_ENABLED =
+    buildConf("spark.sql.sources.v2.bucketing.shuffle.enabled")
+      .doc("During a storage-partitioned join, whether to allow shuffling only one side. " +
+        "When only one side is KeyGroupedPartitioning, if the conditions are met, Spark will " +
+        "only shuffle the other side. This optimization will reduce the amount of data that " +
+        s"needs to be shuffled. This config requires ${V2_BUCKETING_ENABLED.key} to be enabled")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val BUCKETING_MAX_BUCKETS = 
buildConf("spark.sql.sources.bucketing.maxBuckets")
     .doc("The maximum number of buckets allowed.")
     .version("2.4.0")
@@ -4899,6 +4909,9 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
   def v2BucketingPartiallyClusteredDistributionEnabled: Boolean =
     getConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED)
 
+  /** Value of [[SQLConf.V2_BUCKETING_SHUFFLE_ENABLED]]. */
+  def v2BucketingShuffleEnabled: Boolean = getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED)
+
   def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index eba3c71f871..cc674961f8e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -153,7 +153,7 @@ case class BatchScanExec(
             if (spjParams.commonPartitionValues.isDefined &&
               spjParams.applyPartialClustering) {
               // A mapping from the common partition values to how many splits 
the partition
-              // should contain. Note this no longer maintain the partition 
key ordering.
+              // should contain.
               val commonPartValuesMap = spjParams.commonPartitionValues
                 .get
                 .map(t => (InternalRowComparableWrapper(t._1, p.expressions), 
t._2))
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 91f2099ce2d..750b96dc83d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.exchange
 
 import java.util.function.Supplier
 
+import scala.collection.mutable
 import scala.concurrent.Future
 
 import org.apache.spark._
@@ -29,6 +30,7 @@ import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, 
ShuffleWriteProces
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, 
UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import 
org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -299,6 +301,11 @@ object ShuffleExchangeExec {
           ascending = true,
           samplePointsPerPartitionHint = 
SQLConf.get.rangeExchangeSampleSizePerPartition)
       case SinglePartition => new ConstantPartitioner
+      case k @ KeyGroupedPartitioning(expressions, n, _) =>
+        // Each unique partition value is mapped to its index in `uniquePartitionValues`,
+        // which becomes its target partition id in the KeyGroupedPartitioner.
+        val keyTypes = expressions.map(_.dataType)
+        val valueMap = mutable.Map.empty[Seq[Any], Int]
+        k.uniquePartitionValues.zipWithIndex.foreach { case (partition, index) =>
+          valueMap(partition.toSeq(keyTypes)) = index
+        }
+        new KeyGroupedPartitioner(valueMap, n)
       case _ => throw new IllegalStateException(s"Exchange not implemented for 
$newPartitioning")
       // TODO: Handle BroadcastPartitioning.
     }
@@ -325,6 +332,8 @@ object ShuffleExchangeExec {
         val projection = 
UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
         row => projection(row)
       case SinglePartition => identity
+      case KeyGroupedPartitioning(expressions, _, _) =>
+        // Bind the partition expressions once: the returned function runs for every input
+        // row, so binding inside it would redo the attribute resolution per record.
+        val boundExpressions = bindReferences(expressions, outputAttributes)
+        row => boundExpressions.map(_.eval(row))
       case _ => throw new IllegalStateException(s"Exchange not implemented for 
$newPartitioning")
     }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 8461f528277..5b5e4021173 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -1040,6 +1040,187 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase {
     }
   }
 
+  // SPARK-41471: `items` is partitioned by identity(id) while `purchases` is created with no
+  // partition transforms, so only one side of the join reports key-grouped partitioning.
+  test("SPARK-41471: shuffle one side: only one side reports partitioning") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, items_schema, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+      "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    createTable(purchases, purchases_schema, Array.empty)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+    // With the flag on, exactly one shuffle is expected; with it off, both sides are shuffled.
+    // The join result must be identical either way.
+    Seq(true, false).foreach { shuffle =>
+      withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> 
shuffle.toString) {
+        val df = sql("SELECT id, name, i.price as purchase_price, p.price as 
sale_price " +
+          s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+          "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")
+
+        val shuffles = collectShuffles(df.queryExecution.executedPlan)
+        if (shuffle) {
+          assert(shuffles.size == 1, "only shuffle one side not report 
partitioning")
+        } else {
+          assert(shuffles.size == 2, "should add two side shuffle when 
bucketing shuffle one side" +
+            " is not enabled")
+        }
+
+        checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 
19.5)))
+      }
+    }
+  }
+
+  // SPARK-41471: `purchases` (the side to be shuffled) contains item ids 5 and 6, which do not
+  // appear among `items`' partition values. Every join type must still return the right rows.
+  test("SPARK-41471: shuffle one side: shuffle side has more partition value") 
{
+    val items_partitions = Array(identity("id"))
+    createTable(items, items_schema, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+      "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    createTable(purchases, purchases_schema, Array.empty)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 19.5, cast('2020-02-01' as timestamp)), " +
+      "(5, 26.0, cast('2023-01-01' as timestamp)), " +
+      "(6, 50.0, cast('2023-02-01' as timestamp))")
+
+    // For each flag value and each join type: check the number of shuffles, then the rows.
+    Seq(true, false).foreach { shuffle =>
+      withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> 
shuffle.toString) {
+        Seq("JOIN", "LEFT OUTER JOIN", "RIGHT OUTER JOIN", "FULL OUTER 
JOIN").foreach { joinType =>
+          val df = sql(s"SELECT id, name, i.price as purchase_price, p.price 
as sale_price " +
+            s"FROM testcat.ns.$items i $joinType testcat.ns.$purchases p " +
+            "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")
+
+          val shuffles = collectShuffles(df.queryExecution.executedPlan)
+          if (shuffle) {
+            assert(shuffles.size == 1, "only shuffle one side not report 
partitioning")
+          } else {
+            assert(shuffles.size == 2, "should add two side shuffle when 
bucketing shuffle one " +
+              "side is not enabled")
+          }
+          // Unmatched `purchases` rows (ids 5, 6) surface with null item columns in the
+          // right/full outer joins; the unmatched `items` row (id 4) in the left/full ones.
+          joinType match {
+            case "JOIN" =>
+              checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 
19.5)))
+            case "LEFT OUTER JOIN" =>
+              checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 
19.5),
+                Row(4, "cc", 15.5, null)))
+            case "RIGHT OUTER JOIN" =>
+              checkAnswer(df, Seq(Row(null, null, null, 26.0), Row(null, null, 
null, 50.0),
+                Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5)))
+            case "FULL OUTER JOIN" =>
+              checkAnswer(df, Seq(Row(null, null, null, 26.0), Row(null, null, 
null, 50.0),
+                Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5),
+                Row(4, "cc", 15.5, null)))
+          }
+        }
+      }
+    }
+  }
+
+  // SPARK-41471: `items` is partitioned by two identity transforms (id, arrive_time), and the
+  // join condition uses both keys, so the shuffled side is partitioned on a two-value key.
+  test("SPARK-41471: shuffle one side: only one side reports partitioning with 
two identity") {
+    val items_partitions = Array(identity("id"), identity("arrive_time"))
+    createTable(items, items_schema, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+      "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    createTable(purchases, purchases_schema, Array.empty)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+    // Only (1, 2020-01-01) matches on both keys, so a single row is expected in both modes.
+    Seq(true, false).foreach { shuffle =>
+      withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> 
shuffle.toString) {
+        val df = sql("SELECT id, name, i.price as purchase_price, p.price as 
sale_price " +
+          s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+          "ON i.id = p.item_id and i.arrive_time = p.time ORDER BY id, 
purchase_price, sale_price")
+
+        val shuffles = collectShuffles(df.queryExecution.executedPlan)
+        if (shuffle) {
+          assert(shuffles.size == 1, "only shuffle one side not report 
partitioning")
+        } else {
+          assert(shuffles.size == 2, "should add two side shuffle when 
bucketing shuffle one side" +
+            " is not enabled")
+        }
+
+        checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
+      }
+    }
+  }
+
+  // SPARK-41471: `items` is partitioned by the non-identity transform years(arrive_time).
+  // One-side shuffling does not apply to it yet, so two shuffles are expected even with the
+  // flag on; only the join result is expected to be unaffected.
+  test("SPARK-41471: shuffle one side: partitioning with transform") {
+    val items_partitions = Array(years("arrive_time"))
+    createTable(items, items_schema, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+      "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")
+
+    createTable(purchases, purchases_schema, Array.empty)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 19.5, cast('2021-02-01' as timestamp))")
+
+    Seq(true, false).foreach { shuffle =>
+      withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> 
shuffle.toString) {
+        val df = sql("SELECT id, name, i.price as purchase_price, p.price as 
sale_price " +
+          s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+          "ON i.arrive_time = p.time ORDER BY id, purchase_price, sale_price")
+
+        val shuffles = collectShuffles(df.queryExecution.executedPlan)
+        // Both branches expect two shuffles; they differ only in the failure message.
+        if (shuffle) {
+          assert(shuffles.size == 2, "partitioning with transform not work 
now")
+        } else {
+          assert(shuffles.size == 2, "should add two side shuffle when 
bucketing shuffle one side" +
+            " is not enabled")
+        }
+
+        checkAnswer(df, Seq(
+          Row(1, "aa", 40.0, 42.0),
+          Row(3, "bb", 10.0, 42.0),
+          Row(4, "cc", 15.5, 19.5)))
+      }
+    }
+  }
+
+  // SPARK-41471: one-side shuffling combined with pushed partition values and partially
+  // clustered distribution. This test only checks the join result, not the number of shuffles.
+  test("SPARK-41471: shuffle one side: work with group partition split") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, items_schema, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+      "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    createTable(purchases, purchases_schema, Array.empty)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 19.5, cast('2020-02-01' as timestamp)), " +
+      "(5, 26.0, cast('2023-01-01' as timestamp)), " +
+      "(6, 50.0, cast('2023-02-01' as timestamp))")
+
+    Seq(true, false).foreach { shuffle =>
+      withSQLConf(
+        SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString,
+        SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+        SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> 
"true") {
+        val df = sql("SELECT id, name, i.price as purchase_price, p.price as 
sale_price " +
+          s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+          "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")
+
+        checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 
19.5)))
+      }
+    }
+  }
+
   test("SPARK-44641: duplicated records when SPJ is not triggered") {
     val items_partitions = Array(bucket(8, "id"))
     createTable(items, items_schema, items_partitions)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 09da1e1e7b0..3c9b92e5f66 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -18,15 +18,17 @@
 package org.apache.spark.sql.execution.exchange
 
 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
+import org.apache.spark.sql.catalyst.optimizer.BuildRight
 import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.plans.physical.{SinglePartition, _}
 import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan
 import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, 
SortMergeJoinExec}
 import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec
 import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.internal.SQLConf
@@ -1109,6 +1111,32 @@ class EnsureRequirementsSuite extends SharedSparkSession 
{
     }
   }
 
+  test("SPARK-41471: shuffle right side when" +
+    " spark.sql.sources.v2.bucketing.shuffle.enabled is true") {
+    withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
+      val a1 = AttributeReference("a1", IntegerType)()
+      val partitionValue = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v)))
+
+      // Left child is key-grouped on `a1`; right child reports a single partition. Only the
+      // right child is expected to be shuffled, using the left side's key-grouped partitioning.
+      val keyGroupedChild = DummySparkPlan(
+        outputPartitioning = KeyGroupedPartitioning(identity(a1) :: Nil, 4, partitionValue))
+      val singlePartitionChild = DummySparkPlan(outputPartitioning = SinglePartition)
+      val shjExec = ShuffledHashJoinExec(
+        a1 :: Nil, a1 :: Nil, Inner, BuildRight, None, keyGroupedChild, singlePartitionChild)
+
+      EnsureRequirements.apply(shjExec) match {
+        case ShuffledHashJoinExec(_, _, _, _, _,
+            DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _),
+            ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv),
+              DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) =>
+          assert(left.expressions == a1 :: Nil)
+          assert(attrs == a1 :: Nil)
+          assert(partitionValue == pv)
+        case other => fail(other.toString)
+      }
+    }
+  }
+
   test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing 
key order") {
     val lKey = AttributeReference("key", IntegerType)()
     val lKey2 = AttributeReference("key2", IntegerType)()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to