This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ce12f6dbad2 [SPARK-41471][SQL] Reduce Spark shuffle when only one side
of a join is KeyGroupedPartitioning
ce12f6dbad2 is described below
commit ce12f6dbad2d713c6a2a52ac36d1f7a910399ad4
Author: Jia Fan <[email protected]>
AuthorDate: Thu Aug 24 08:35:16 2023 -0700
[SPARK-41471][SQL] Reduce Spark shuffle when only one side of a join is
KeyGroupedPartitioning
### What changes were proposed in this pull request?
When only one side of a SPJ (Storage-Partitioned Join) is
KeyGroupedPartitioning, Spark currently needs to shuffle both sides using
HashPartitioning. However, we may just need to shuffle the other side according
to the partition transforms defined in KeyGroupedPartitioning. This is
especially useful when the other side is relatively small.
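For illustration only, a minimal usage sketch (catalog, table, and column names are hypothetical, and it assumes a SparkSession named `spark`, e.g. in spark-shell; `testcat.ns.items` is assumed to be a DS v2 table that reports KeyGroupedPartitioning on `id`, while `testcat.ns.purchases` reports no partitioning):

    // Hypothetical setup: the two configs below are real Spark SQL configs; the
    // tables and columns are placeholders for any catalog that supports SPJ.
    spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
    spark.conf.set("spark.sql.sources.v2.bucketing.shuffle.enabled", "true")

    // `items` reports KeyGroupedPartitioning on `id`; `purchases` does not report
    // partitioning. With the new config on, only the `purchases` side should be
    // shuffled to match the partitioning of `items`; with it off, both sides are
    // hash-shuffled.
    val df = spark.sql(
      """SELECT i.id, i.name, p.price
        |FROM testcat.ns.items i JOIN testcat.ns.purchases p
        |ON i.id = p.item_id""".stripMargin)
    df.explain()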
1. Add a new config `spark.sql.sources.v2.bucketing.shuffle.enabled` to control
whether this feature is enabled.
2. Add `KeyGroupedPartitioner`, used to partition one side of the join when the
transform values of the other side (the KeyGroupedPartitioning side, for now) are
known. In `EnsureRequirements`, Spark already knows the partition values and
partition ids of the KeyGroupedPartitioning side; they are saved into
`KeyGroupedPartitioner` and used to shuffle the other side, so that records with
the same key end up in the same partition (see the sketch after this list).
3. Only the `identity` transform works for now, because the same transform can
currently report different values between a DS v2 connector implementation and
its catalog function; until that is solved, we should only support `identity`.
E.g. in the test package, `YearFunction`
https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala#L47
and https://github.com/apache/spark/blob/master/sql/catalyst/src [...]
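As a rough standalone sketch of the mapping described in item 2 (not the actual class, just the same getOrElseUpdate lookup): partition values reported by the KeyGroupedPartitioning side map to fixed partition ids, and keys never seen on that side fall back to a non-negative hash modulo the partition count. The values below are illustrative.

    import scala.collection.mutable

    // Sketch of the lookup performed by the new KeyGroupedPartitioner.
    def partitionFor(key: Seq[Any], valueMap: mutable.Map[Seq[Any], Int], numPartitions: Int): Int =
      valueMap.getOrElseUpdate(key, ((key.hashCode % numPartitions) + numPartitions) % numPartitions)

    // Suppose the KeyGroupedPartitioning side reported value Seq(1) -> partition 0
    // and Seq(3) -> partition 1.
    val valueMap = mutable.Map[Seq[Any], Int](Seq(1) -> 0, Seq(3) -> 1)
    val p1 = partitionFor(Seq(3), valueMap, numPartitions = 4)  // 1: same key lands in the same partition
    val p2 = partitionFor(Seq(5), valueMap, numPartitions = 4)  // hashed fallback for an unseen key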
### Why are the changes needed?
Reduce data shuffle in specific SPJ scenarios
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added new tests.
Closes #42194 from Hisoka-X/SPARK-41471_one_side_keygroup.
Authored-by: Jia Fan <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
---
.../main/scala/org/apache/spark/Partitioner.scala | 16 ++
.../sql/catalyst/plans/physical/partitioning.scala | 8 +-
.../util/InternalRowComparableWrapper.scala | 7 +-
.../org/apache/spark/sql/internal/SQLConf.scala | 13 ++
.../execution/datasources/v2/BatchScanExec.scala | 2 +-
.../execution/exchange/ShuffleExchangeExec.scala | 9 +
.../connector/KeyGroupedPartitioningSuite.scala | 181 +++++++++++++++++++++
.../exchange/EnsureRequirementsSuite.scala | 32 +++-
8 files changed, 263 insertions(+), 5 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala
b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 5dffba2ee8e..ae39e2e183e 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -137,6 +137,22 @@ private[spark] class PartitionIdPassthrough(override val
numPartitions: Int) ext
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}
+/**
+ * A [[org.apache.spark.Partitioner]] that partitions all records using a partition value map.
+ * The `valueMap` is a map that contains tuples of (partition value, partition id). It is generated
+ * by [[org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning]], used to partition
+ * the other side of a join to make sure records with the same partition value are in the same
+ * partition.
+ */
+private[spark] class KeyGroupedPartitioner(
+    valueMap: mutable.Map[Seq[Any], Int],
+    override val numPartitions: Int) extends Partitioner {
+  override def getPartition(key: Any): Int = {
+    val keys = key.asInstanceOf[Seq[Any]]
+    valueMap.getOrElseUpdate(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions))
+  }
+}
+
/**
* A [[org.apache.spark.Partitioner]] that partitions all records into a
single partition.
*/
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 456005768bd..ce557422a08 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -735,7 +735,13 @@ case class KeyGroupedShuffleSpec(
case _ => false
}
-  override def canCreatePartitioning: Boolean = false
+  override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
+    // Only support partition expressions that are AttributeReference for now
+    partitioning.expressions.forall(_.isInstanceOf[AttributeReference])
+
+  override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
+    KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
+  }
}
case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
index b0e53090731..9a0bdc6bcfd 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala
@@ -79,6 +79,11 @@ object InternalRowComparableWrapper {
rightPartitioning.partitionValues
.map(new InternalRowComparableWrapper(_, partitionDataTypes))
.foreach(partition => partitionsSet.add(partition))
- partitionsSet.map(_.row).toSeq
+    // SPARK-41471: We keep the order of partitions to make sure it is deterministic
+    // in different cases.
+    val partitionOrdering: Ordering[InternalRow] = {
+      RowOrdering.createNaturalAscendingOrdering(partitionDataTypes)
+    }
+    partitionsSet.map(_.row).toSeq.sorted(partitionOrdering)
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 556d6b7c798..8bbc64a0aa7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1500,6 +1500,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+  val V2_BUCKETING_SHUFFLE_ENABLED =
+    buildConf("spark.sql.sources.v2.bucketing.shuffle.enabled")
+      .doc("During a storage-partitioned join, whether to allow shuffling only one side. " +
+        "When only one side is KeyGroupedPartitioning, if the conditions are met, Spark will " +
+        "only shuffle the other side. This optimization reduces the amount of data that " +
+        s"needs to be shuffled. This config requires ${V2_BUCKETING_ENABLED.key} to be enabled.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
val BUCKETING_MAX_BUCKETS =
buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
.version("2.4.0")
@@ -4899,6 +4909,9 @@ class SQLConf extends Serializable with Logging with
SqlApiConf {
def v2BucketingPartiallyClusteredDistributionEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED)
+  def v2BucketingShuffleEnabled: Boolean =
+    getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED)
+
def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index eba3c71f871..cc674961f8e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -153,7 +153,7 @@ case class BatchScanExec(
if (spjParams.commonPartitionValues.isDefined &&
spjParams.applyPartialClustering) {
// A mapping from the common partition values to how many splits
the partition
- // should contain. Note this no longer maintain the partition
key ordering.
+ // should contain.
val commonPartValuesMap = spjParams.commonPartitionValues
.get
.map(t => (InternalRowComparableWrapper(t._1, p.expressions),
t._2))
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 91f2099ce2d..750b96dc83d 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.exchange
import java.util.function.Supplier
+import scala.collection.mutable
import scala.concurrent.Future
import org.apache.spark._
@@ -29,6 +30,7 @@ import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter,
ShuffleWriteProces
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference,
UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import
org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
@@ -299,6 +301,11 @@ object ShuffleExchangeExec {
ascending = true,
samplePointsPerPartitionHint =
SQLConf.get.rangeExchangeSampleSizePerPartition)
case SinglePartition => new ConstantPartitioner
+      case k @ KeyGroupedPartitioning(expressions, n, _) =>
+        val valueMap = k.uniquePartitionValues.zipWithIndex.map {
+          case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index)
+        }.toMap
+        new KeyGroupedPartitioner(mutable.Map(valueMap.toSeq: _*), n)
case _ => throw new IllegalStateException(s"Exchange not implemented for
$newPartitioning")
// TODO: Handle BroadcastPartitioning.
}
@@ -325,6 +332,8 @@ object ShuffleExchangeExec {
val projection =
UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
row => projection(row)
case SinglePartition => identity
+ case KeyGroupedPartitioning(expressions, _, _) =>
+ row => bindReferences(expressions, outputAttributes).map(_.eval(row))
case _ => throw new IllegalStateException(s"Exchange not implemented for
$newPartitioning")
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 8461f528277..5b5e4021173 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -1040,6 +1040,187 @@ class KeyGroupedPartitioningSuite extends
DistributionAndOrderingSuiteBase {
}
}
+ test("SPARK-41471: shuffle one side: only one side reports partitioning") {
+ val items_partitions = Array(identity("id"))
+ createTable(items, items_schema, items_partitions)
+
+ sql(s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+ createTable(purchases, purchases_schema, Array.empty)
+ sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+ "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+ Seq(true, false).foreach { shuffle =>
+ withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key ->
shuffle.toString) {
+ val df = sql("SELECT id, name, i.price as purchase_price, p.price as
sale_price " +
+ s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+ "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ if (shuffle) {
+          assert(shuffles.size == 1,
+            "should shuffle only the side that does not report partitioning")
+        } else {
+          assert(shuffles.size == 2,
+            "should shuffle both sides when shuffling one side is not enabled")
+ }
+
+ checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0,
19.5)))
+ }
+ }
+ }
+
+ test("SPARK-41471: shuffle one side: shuffle side has more partition value")
{
+ val items_partitions = Array(identity("id"))
+ createTable(items, items_schema, items_partitions)
+
+ sql(s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+ createTable(purchases, purchases_schema, Array.empty)
+ sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+ "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 19.5, cast('2020-02-01' as timestamp)), " +
+ "(5, 26.0, cast('2023-01-01' as timestamp)), " +
+ "(6, 50.0, cast('2023-02-01' as timestamp))")
+
+ Seq(true, false).foreach { shuffle =>
+ withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key ->
shuffle.toString) {
+ Seq("JOIN", "LEFT OUTER JOIN", "RIGHT OUTER JOIN", "FULL OUTER
JOIN").foreach { joinType =>
+ val df = sql(s"SELECT id, name, i.price as purchase_price, p.price
as sale_price " +
+ s"FROM testcat.ns.$items i $joinType testcat.ns.$purchases p " +
+ "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ if (shuffle) {
+            assert(shuffles.size == 1,
+              "should shuffle only the side that does not report partitioning")
+          } else {
+            assert(shuffles.size == 2,
+              "should shuffle both sides when shuffling one side is not enabled")
+ }
+ joinType match {
+ case "JOIN" =>
+ checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0,
19.5)))
+ case "LEFT OUTER JOIN" =>
+ checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0,
19.5),
+ Row(4, "cc", 15.5, null)))
+ case "RIGHT OUTER JOIN" =>
+ checkAnswer(df, Seq(Row(null, null, null, 26.0), Row(null, null,
null, 50.0),
+ Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5)))
+ case "FULL OUTER JOIN" =>
+ checkAnswer(df, Seq(Row(null, null, null, 26.0), Row(null, null,
null, 50.0),
+ Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5),
+ Row(4, "cc", 15.5, null)))
+ }
+ }
+ }
+ }
+ }
+
+ test("SPARK-41471: shuffle one side: only one side reports partitioning with
two identity") {
+ val items_partitions = Array(identity("id"), identity("arrive_time"))
+ createTable(items, items_schema, items_partitions)
+
+ sql(s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+ createTable(purchases, purchases_schema, Array.empty)
+ sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+ "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+ Seq(true, false).foreach { shuffle =>
+ withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key ->
shuffle.toString) {
+ val df = sql("SELECT id, name, i.price as purchase_price, p.price as
sale_price " +
+ s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+ "ON i.id = p.item_id and i.arrive_time = p.time ORDER BY id,
purchase_price, sale_price")
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ if (shuffle) {
+          assert(shuffles.size == 1,
+            "should shuffle only the side that does not report partitioning")
+        } else {
+          assert(shuffles.size == 2,
+            "should shuffle both sides when shuffling one side is not enabled")
+ }
+
+ checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
+ }
+ }
+ }
+
+ test("SPARK-41471: shuffle one side: partitioning with transform") {
+ val items_partitions = Array(years("arrive_time"))
+ createTable(items, items_schema, items_partitions)
+
+ sql(s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")
+
+ createTable(purchases, purchases_schema, Array.empty)
+ sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+ "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 19.5, cast('2021-02-01' as timestamp))")
+
+ Seq(true, false).foreach { shuffle =>
+ withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key ->
shuffle.toString) {
+ val df = sql("SELECT id, name, i.price as purchase_price, p.price as
sale_price " +
+ s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+ "ON i.arrive_time = p.time ORDER BY id, purchase_price, sale_price")
+
+ val shuffles = collectShuffles(df.queryExecution.executedPlan)
+ if (shuffle) {
+          assert(shuffles.size == 2,
+            "should still shuffle both sides: partitioning with a transform does not work yet")
+        } else {
+          assert(shuffles.size == 2,
+            "should shuffle both sides when shuffling one side is not enabled")
+ }
+
+ checkAnswer(df, Seq(
+ Row(1, "aa", 40.0, 42.0),
+ Row(3, "bb", 10.0, 42.0),
+ Row(4, "cc", 15.5, 19.5)))
+ }
+ }
+ }
+
+ test("SPARK-41471: shuffle one side: work with group partition split") {
+ val items_partitions = Array(identity("id"))
+ createTable(items, items_schema, items_partitions)
+
+ sql(s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+ createTable(purchases, purchases_schema, Array.empty)
+ sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+ "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+ "(3, 19.5, cast('2020-02-01' as timestamp)), " +
+ "(5, 26.0, cast('2023-01-01' as timestamp)), " +
+ "(6, 50.0, cast('2023-02-01' as timestamp))")
+
+ Seq(true, false).foreach { shuffle =>
+ withSQLConf(
+ SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString,
+ SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+ SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
"true") {
+ val df = sql("SELECT id, name, i.price as purchase_price, p.price as
sale_price " +
+ s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
+ "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")
+
+ checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0,
19.5)))
+ }
+ }
+ }
+
test("SPARK-44641: duplicated records when SPJ is not triggered") {
val items_partitions = Array(bucket(8, "id"))
createTable(items, items_schema, items_partitions)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 09da1e1e7b0..3c9b92e5f66 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -18,15 +18,17 @@
package org.apache.spark.sql.execution.exchange
import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
+import org.apache.spark.sql.catalyst.optimizer.BuildRight
import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.plans.physical.{SinglePartition, _}
import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec,
SortMergeJoinExec}
import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
@@ -1109,6 +1111,32 @@ class EnsureRequirementsSuite extends SharedSparkSession
{
}
}
+ test("SPARK-41471: shuffle right side when" +
+ " spark.sql.sources.v2.bucketing.shuffle.enabled is true") {
+ withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
+
+ val a1 = AttributeReference("a1", IntegerType)()
+
+ val partitionValue = Seq(50, 51, 52).map(v =>
InternalRow.fromSeq(Seq(v)))
+ val plan1 = DummySparkPlan(outputPartitioning = KeyGroupedPartitioning(
+ identity(a1) :: Nil, 4, partitionValue))
+ val plan2 = DummySparkPlan(outputPartitioning = SinglePartition)
+
+ val smjExec = ShuffledHashJoinExec(
+ a1 :: Nil, a1 :: Nil, Inner, BuildRight, None, plan1, plan2)
+ EnsureRequirements.apply(smjExec) match {
+ case ShuffledHashJoinExec(_, _, _, _, _,
+ DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _),
+ ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv),
+ DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) =>
+ assert(left.expressions == a1 :: Nil)
+ assert(attrs == a1 :: Nil)
+ assert(partitionValue == pv)
+ case other => fail(other.toString)
+ }
+ }
+ }
+
test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing
key order") {
val lKey = AttributeReference("key", IntegerType)()
val lKey2 = AttributeReference("key2", IntegerType)()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]