This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 206cc1a554ef [SPARK-48613][SQL] SPJ: Support auto-shuffle one side +
less join keys than partition keys
206cc1a554ef is described below
commit 206cc1a554ef130cfa0aceb1d4ee330e5c83de5f
Author: Szehon Ho <[email protected]>
AuthorDate: Sun Jul 14 16:53:32 2024 -0700
[SPARK-48613][SQL] SPJ: Support auto-shuffle one side + less join keys than
partition keys
### What changes were proposed in this pull request?
This is the final planned SPJ scenario: auto-shuffle one side + fewer join
keys than partition keys. Background:
- Auto-shuffle works by creating a ShuffleExchange for the non-partitioned
side, with a clone of the partitioned side's KeyGroupedPartitioning.
- "Fewer join keys than partition keys" works by 'projecting' all partition
values onto the join keys (i.e., keeping only the partition columns that are
also join columns). It builds a target KeyGroupedShuffleSpec with the
'projected' partition values and then pushes this down to BatchScanExec.
BatchScanExec then 'groups' its projected partition values (except in the
skew case, but that's a different story).
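A minimal Scala sketch of the auto-shuffle idea, assuming a toy model in which a partition value is a plain `Seq[Any]` rather than Spark's `InternalRow`, and assuming that routing a row on the shuffled side is a simple lookup from partition value to partition index. `AutoShuffleSketch` and `partitionIndexFor` are hypothetical names, not Spark APIs, and how Spark handles keys with no matching partition value is not modeled (they map to `None` here).

```scala
object AutoShuffleSketch {
  // Hypothetical stand-in for one partition value (Spark uses InternalRow).
  type PartValue = Seq[Any]

  // Builds a lookup from partition value to partition index, modeling how the
  // shuffled side reuses the partitioned side's partition values.
  def partitionIndexFor(partitionValues: Seq[PartValue]): PartValue => Option[Int] = {
    val index: Map[PartValue, Int] = partitionValues.zipWithIndex.toMap
    key => index.get(key)
  }

  def main(args: Array[String]): Unit = {
    // Partition values reported by the storage-partitioned side.
    val partitionedSide = Seq(Seq(1), Seq(2), Seq(3))
    val route = partitionIndexFor(partitionedSide)
    // Join-key values of rows on the non-partitioned side.
    val rows = Seq(Seq(2), Seq(1), Seq(2), Seq(4))
    rows.foreach(r => println(s"$r -> ${route(r)}"))
  }
}
```

Note that with duplicated partition values, `toMap` would silently keep only one index per value, which hints at why the two sides must agree on a de-duplicated list.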
This combination is hard because, in this scenario, the SPJ planning logic is
spread across several places. Given two sides, a non-partitioned side and a
partitioned side, where the join keys are only a subset of the partition keys:
1. EnsureRequirements creates the target KeyGroupedShuffleSpec from the
join's required distribution (i.e., using only the join keys, not all partition
keys).
2. EnsureRequirements copies this to the non-partitioned side's
KeyGroupedPartitioning (for the auto-shuffle case).
3. BatchScanExec groups the partitions (for the partitioned side),
including by join keys (if they differ from partition keys).
Take as an example partition columns (id, name), where only id is a join key,
and partition values (1, "bob"), (2, "alice"), (2, "sam").
Projection leaves us with (1, 2, 2), and the final grouped partition values are
(1, 2).
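The example can be reproduced with a small Scala sketch, assuming partition values are modeled as `Seq[Any]` and that "grouping" means keeping one copy of each projected value in first-seen order. `ProjectAndGroupSketch`, `project` and `projectAndGroup` here are hypothetical helpers, not Spark's `KeyGroupedPartitioning.project`.

```scala
object ProjectAndGroupSketch {
  // Hypothetical stand-in for a partition value; Spark uses InternalRow.
  type PartValue = Seq[Any]

  // Keeps only the partition columns at the given join-key positions.
  def project(value: PartValue, positions: Seq[Int]): PartValue =
    positions.map(value(_))

  // Projects every partition value, then groups duplicates (first-seen order).
  def projectAndGroup(
      values: Seq[PartValue],
      positions: Seq[Int]): (Seq[PartValue], Seq[PartValue]) = {
    val projected = values.map(project(_, positions))
    (projected, projected.distinct)
  }

  def main(args: Array[String]): Unit = {
    // Partition columns (id, name); the join key is only `id` (position 0).
    val partitionValues = Seq(Seq(1, "bob"), Seq(2, "alice"), Seq(2, "sam"))
    val (projected, grouped) = projectAndGroup(partitionValues, Seq(0))
    println(s"projected = $projected")
    println(s"grouped   = $grouped")
  }
}
```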
The problem is that the two sides of the join do not always match.
After steps 1 and 2, the partitioned side has the 'projected' partition
values (1, 2, 2), and the non-partitioned side creates a matching
KeyGroupedPartitioning (1, 2, 2) for ShuffleExchange. But in step 3, the
BatchScanExec for the partitioned side groups the partitions into (1, 2), while
the non-partitioned side does not group and still retains the (1, 2, 2) partitions.
This leads to the following requirement failure [...]
```
requirement failed: PartitioningCollection requires all of its partitionings have the same numPartitions.
java.lang.IllegalArgumentException: requirement failed: PartitioningCollection requires all of its partitionings have the same numPartitions.
    at scala.Predef$.require(Predef.scala:337)
    at org.apache.spark.sql.catalyst.plans.physical.PartitioningCollection.<init>(partitioning.scala:550)
    at org.apache.spark.sql.execution.joins.ShuffledJoin.outputPartitioning(ShuffledJoin.scala:49)
    at org.apache.spark.sql.execution.joins.ShuffledJoin.outputPartitioning$(ShuffledJoin.scala:47)
    at org.apache.spark.sql.execution.joins.SortMergeJoinExec.outputPartitioning(SortMergeJoinExec.scala:39)
    at org.apache.spark.sql.execution.exchange.EnsureRequirements.$anonfun$ensureDistributionAndOrdering$1(EnsureRequirements.scala:66)
    at scala.collection.immutable.Vector1.map(Vector.scala:2140)
    at scala.collection.immutable.Vector1.map(Vector.scala:385)
    at org.apache.spark.sql.execution.exchange.EnsureRequirements.org$apache$spark$sql$execution$exchange$EnsureRequirements$$ensureDistributionAndOrdering(EnsureRequirements.scala:65)
    at org.apache.spark.sql.execution.exchange.EnsureRequirements$$anonfun$1.applyOrElse(EnsureRequirements.scala:657)
    at org.apache.spark.sql.execution.exchange.EnsureRequirements$$anonfun$1.applyOrElse(EnsureRequirements.scala:632)
```
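The invariant behind this failure can be illustrated with a short Scala sketch. It assumes a simplified check that only compares partition counts; the real check lives in `PartitioningCollection`'s constructor (partitioning.scala in the trace above), and `MismatchSketch` and `requireSameNumPartitions` are hypothetical names. Because it uses Scala's `require`, the exception type and the "requirement failed: " prefix match the trace.

```scala
object MismatchSketch {
  type PartValue = Seq[Any]

  // Hypothetical, simplified version of the invariant PartitioningCollection
  // enforces: all partitionings must report the same number of partitions.
  def requireSameNumPartitions(numPartitions: Seq[Int]): Unit =
    require(numPartitions.distinct.size <= 1,
      "PartitioningCollection requires all of its partitionings have the same numPartitions.")

  def main(args: Array[String]): Unit = {
    val projected: Seq[PartValue] = Seq(Seq(1), Seq(2), Seq(2))
    // Before the fix (modeled): the scan side groups, the shuffled side does not.
    val scanSide = projected.distinct.size // 2
    val shuffledSide = projected.size      // 3
    try requireSameNumPartitions(Seq(scanSide, shuffledSide))
    catch { case e: IllegalArgumentException => println(e.getMessage) }
  }
}
```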
The fix is to do the de-duplication in the first pass:
1. Push down the join keys to BatchScanExec so that it returns a de-duplicated
outputPartitioning (partitioned side).
2. Create the non-partitioned side's KeyGroupedPartitioning with de-duplicated
partition values (non-partitioned side).
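A minimal Scala sketch of the fixed flow, assuming the same `Seq[Any]` model: projected values are de-duplicated once and both sides are built from that single list, so their partition counts agree. In the actual diff, rows are wrapped in `InternalRowComparableWrapper` before calling `distinct`; plain `Seq` equality stands in for that here. `DedupFixSketch`, `dedupedPartitionValues` and `SideModel` are hypothetical names.

```scala
object DedupFixSketch {
  type PartValue = Seq[Any]

  // Keeps only the partition columns at the join-key positions.
  def project(value: PartValue, positions: Seq[Int]): PartValue =
    positions.map(value(_))

  // Modeled fix: project, then de-duplicate once, up front.
  def dedupedPartitionValues(
      values: Seq[PartValue],
      joinKeyPositions: Seq[Int]): Seq[PartValue] =
    values.map(project(_, joinKeyPositions)).distinct

  // Hypothetical model of one join side's key-grouped partitioning.
  final case class SideModel(name: String, partitionValues: Seq[PartValue]) {
    def numPartitions: Int = partitionValues.length
  }

  def main(args: Array[String]): Unit = {
    val original = Seq(Seq(1, "bob"), Seq(2, "alice"), Seq(2, "sam"))
    val deduped = dedupedPartitionValues(original, Seq(0))
    // Both sides derive from the same de-duplicated list.
    val scanSide = SideModel("partitioned (BatchScanExec)", deduped)
    val shuffledSide = SideModel("non-partitioned (ShuffleExchange)", deduped)
    println(s"${scanSide.name}: ${scanSide.numPartitions}")
    println(s"${shuffledSide.name}: ${shuffledSide.numPartitions}")
  }
}
```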
### Why are the changes needed?
This is the last planned SPJ scenario that was not yet supported.
### How was this patch tested?
Updated an existing unit test in KeyGroupedPartitioningSuite.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47064 from szehon-ho/spj_less_join_key_auto_shuffle.
Authored-by: Szehon Ho <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
---
.../sql/catalyst/plans/physical/partitioning.scala | 25 +++++++++----------
.../execution/datasources/v2/BatchScanExec.scala | 28 ++++++++++++++++------
.../execution/exchange/EnsureRequirements.scala | 26 +++++++++++++++++++-
.../connector/KeyGroupedPartitioningSuite.scala | 3 +--
4 files changed, 58 insertions(+), 24 deletions(-)
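The EnsureRequirements change below adds `populateJoinKeyPositions`, which copies the `BatchScanExec` with the join key positions set and otherwise recurses into children. The following Scala sketch mirrors only that shape on a hypothetical toy plan tree (`Plan`, `Scan`, `Filter`, `Project` are illustrative stand-ins, not Spark's `SparkPlan` classes).

```scala
object PushDownSketch {
  // Hypothetical toy plan tree; not Spark's SparkPlan hierarchy.
  sealed trait Plan {
    def mapChildren(f: Plan => Plan): Plan
  }
  final case class Scan(table: String, joinKeyPositions: Option[Seq[Int]] = None) extends Plan {
    def mapChildren(f: Plan => Plan): Plan = this // leaf: no children
  }
  final case class Filter(condition: String, child: Plan) extends Plan {
    def mapChildren(f: Plan => Plan): Plan = copy(child = f(child))
  }
  final case class Project(columns: Seq[String], child: Plan) extends Plan {
    def mapChildren(f: Plan => Plan): Plan = copy(child = f(child))
  }

  // Same shape as populateJoinKeyPositions in the diff: set the positions on
  // the scan, otherwise rebuild the node with rewritten children.
  def populateJoinKeyPositions(plan: Plan, positions: Option[Seq[Int]]): Plan = plan match {
    case scan: Scan => scan.copy(joinKeyPositions = positions)
    case node => node.mapChildren(child => populateJoinKeyPositions(child, positions))
  }

  def main(args: Array[String]): Unit = {
    val plan = Project(Seq("id"), Filter("id > 0", Scan("items")))
    println(populateJoinKeyPositions(plan, Some(Seq(0))))
  }
}
```

Because the rewritten scan differs only in its pushed-down positions, the diff also adds `joinKeyPositions` to `StoragePartitionJoinParams` equality so such scans are not treated as equal.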
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 19595eef10b3..f8e980747bf2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -434,8 +434,13 @@ object KeyGroupedPartitioning {
val projectedOriginalPartitionValues =
originalPartitionValues.map(project(expressions, projectionPositions, _))
- KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length,
- projectedPartitionValues, projectedOriginalPartitionValues)
+ val finalPartitionValues = projectedPartitionValues
+ .map(InternalRowComparableWrapper(_, projectedExpressions))
+ .distinct
+ .map(_.row)
+
+ KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
+ finalPartitionValues, projectedOriginalPartitionValues)
}
def project(
@@ -871,20 +876,12 @@ case class KeyGroupedShuffleSpec(
if (results.forall(p => p.isEmpty)) None else Some(results)
}
- override def canCreatePartitioning: Boolean = {
- // Allow one side shuffle for SPJ for now only if partially-clustered is not enabled
- // and for join keys less than partition keys only if transforms are not enabled.
- val checkExprType = if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
- e: Expression => e.isInstanceOf[AttributeReference]
- } else {
- e: Expression => e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
- }
+ override def canCreatePartitioning: Boolean =
SQLConf.get.v2BucketingShuffleEnabled &&
!SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
- partitioning.expressions.forall(checkExprType)
- }
-
-
+ partitioning.expressions.forall { e =>
+ e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
+ }
override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index f949dbf71a37..997576a396d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -118,16 +118,29 @@ case class BatchScanExec(
override def outputPartitioning: Partitioning = {
super.outputPartitioning match {
- case k: KeyGroupedPartitioning if spjParams.commonPartitionValues.isDefined =>
- // We allow duplicated partition values if
- // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
- val newPartValues = spjParams.commonPartitionValues.get.flatMap {
- case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
- }
+ case k: KeyGroupedPartitioning =>
val expressions = spjParams.joinKeyPositions match {
case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i))
case _ => k.expressions
}
+
+ val newPartValues = spjParams.commonPartitionValues match {
+ case Some(commonPartValues) =>
+ // We allow duplicated partition values if
+ // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
+ commonPartValues.flatMap {
+ case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
+ }
+ case None =>
+ spjParams.joinKeyPositions match {
+ case Some(projectionPositions) => k.partitionValues.map{r =>
+ val projectedRow = KeyGroupedPartitioning.project(expressions,
+ projectionPositions, r)
+ InternalRowComparableWrapper(projectedRow, expressions)
+ }.distinct.map(_.row)
+ case _ => k.partitionValues
+ }
+ }
k.copy(expressions = expressions, numPartitions = newPartValues.length,
partitionValues = newPartValues)
case p => p
@@ -279,7 +292,8 @@ case class StoragePartitionJoinParams(
case other: StoragePartitionJoinParams =>
this.commonPartitionValues == other.commonPartitionValues &&
this.replicatePartitions == other.replicatePartitions &&
- this.applyPartialClustering == other.applyPartialClustering
+ this.applyPartialClustering == other.applyPartialClustering &&
+ this.joinKeyPositions == other.joinKeyPositions
case _ =>
false
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 67d879bdd8bf..0470aacd4f82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -175,7 +175,16 @@ case class EnsureRequirements(
child
case ((child, dist), idx) =>
if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) {
- child
+ bestSpecOpt match {
+ // If keyGroupCompatible = false, we can still perform SPJ
+ // by shuffling the other side based on join keys (see the else case below).
+ // Hence we need to ensure that after this call, the outputPartitioning of the
+ // partitioned side's BatchScanExec is grouped by join keys to match,
+ // and we do that by pushing down the join keys
+ case Some(KeyGroupedShuffleSpec(_, _, Some(joinKeyPositions))) =>
+ populateJoinKeyPositions(child, Some(joinKeyPositions))
+ case _ => child
+ }
} else {
val newPartitioning = bestSpecOpt.map { bestSpec =>
// Use the best spec to create a new partitioning to re-shuffle this child
@@ -578,6 +587,21 @@ case class EnsureRequirements(
child, values, joinKeyPositions, reducers, applyPartialClustering,
replicatePartitions))
}
+
+ private def populateJoinKeyPositions(
+ plan: SparkPlan,
+ joinKeyPositions: Option[Seq[Int]]): SparkPlan = plan match {
+ case scan: BatchScanExec =>
+ scan.copy(
+ spjParams = scan.spjParams.copy(
+ joinKeyPositions = joinKeyPositions
+ )
+ )
+ case node =>
+ node.mapChildren(child => populateJoinKeyPositions(
+ child, joinKeyPositions))
+ }
+
private def reduceCommonPartValues(
commonPartValues: Seq[(InternalRow, Int)],
expressions: Seq[Expression],
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index d77a6e8b8ac1..5e5453b4cd50 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -2168,8 +2168,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
val df = createJoinTestDF(Seq("id" -> "item_id"))
val shuffles = collectShuffles(df.queryExecution.executedPlan)
- assert(shuffles.size == 2, "SPJ should not be triggered for transform expression with" +
- "less join keys than partition keys for now.")
+ assert(shuffles.size == 1, "SPJ should be triggered")
checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
Row(1, "aa", 30.0, 89.0),
Row(1, "aa", 40.0, 42.0),