This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 206cc1a554ef [SPARK-48613][SQL] SPJ: Support auto-shuffle one side + less join keys than partition keys
206cc1a554ef is described below

commit 206cc1a554ef130cfa0aceb1d4ee330e5c83de5f
Author: Szehon Ho <[email protected]>
AuthorDate: Sun Jul 14 16:53:32 2024 -0700

    [SPARK-48613][SQL] SPJ: Support auto-shuffle one side + less join keys than partition keys
    
    ### What changes were proposed in this pull request?
    
    This is the final planned SPJ scenario: auto-shuffle one side + less join keys than partition keys.  Background:
    
    - Auto-shuffle works by creating a ShuffleExchange for the non-partitioned side, with a clone of the partitioned side's KeyGroupedPartitioning.
    - "Less join keys than partition keys" works by 'projecting' all partition values by join keys (i.e., keeping only the partition columns that are join columns).  It makes a target KeyGroupedShuffleSpec with the 'projected' partition values, and then pushes this down to BatchScanExec.  The BatchScanExec then 'groups' its projected partition values (except in the skew case, but that's a different story).
    
    This combination is hard because the SPJ planning calls are spread across several places in this scenario.  Given two sides, a non-partitioned side and a partitioned side, where the join keys are only a subset of the partition keys:
    
    1.  EnsureRequirements creates the target KeyGroupedShuffleSpec from the join's required distribution (i.e., using only the join keys, not all partition keys).
    2.  EnsureRequirements copies this to the non-partitioned side's KeyGroupedPartitioning (for the auto-shuffle case).
    3.  BatchScanExec groups the partitions (for the partitioned side), including by join keys (if they differ from partition keys).
    
    Take the example partition columns (id, name), with partition values (1, "bob"), (2, "alice"), (2, "sam").
    Projecting onto the join key id leaves (1, 2, 2), and the final grouped partition values are (1, 2).
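    
    The projection and grouping above can be sketched in plain Scala.  This is a minimal illustration, not Spark's implementation: the `ProjectAndGroup` object, the `PartitionValue` alias and both helper names are hypothetical, and plain `Seq` values stand in for Spark's `InternalRow`.
    
    ```scala
    object ProjectAndGroup {
      // Hypothetical stand-in for a partition value: one entry per partition column.
      type PartitionValue = Seq[Any]
    
      // Keep only the partition columns at the join key positions.
      def project(values: Seq[PartitionValue], joinKeyPositions: Seq[Int]): Seq[PartitionValue] =
        values.map(row => joinKeyPositions.map(i => row(i)))
    
      // Group the projected values: duplicates collapse, first-seen order is kept.
      def group(projected: Seq[PartitionValue]): Seq[PartitionValue] =
        projected.distinct
    
      def main(args: Array[String]): Unit = {
        // Partition columns (id, name); the join key is id, at position 0.
        val partitionValues: Seq[PartitionValue] =
          Seq(Seq(1, "bob"), Seq(2, "alice"), Seq(2, "sam"))
        val projected = project(partitionValues, Seq(0))
        println(s"projected: $projected")
        println(s"grouped:   ${group(projected)}")
      }
    }
    ```
    
    In the sketch, `distinct` relies on the structural equality of `Seq`.  The actual change wraps rows in InternalRowComparableWrapper before calling `distinct`, as the diff below shows.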
    
    The problem is that the two sides of the join do not always match.  After steps 1 and 2, the partitioned side has the 'projected' partition values (1, 2, 2), and the non-partitioned side creates a matching KeyGroupedPartitioning (1, 2, 2) for the ShuffleExchange.  But in step 3, the BatchScanExec for the partitioned side groups the partitions to become (1, 2), while the non-partitioned side does not group and still retains the (1, 2, 2) partitions.  This leads to the following assertion error  [...]
    
    ```
    requirement failed: PartitioningCollection requires all of its partitionings have the same numPartitions.
    java.lang.IllegalArgumentException: requirement failed: PartitioningCollection requires all of its partitionings have the same numPartitions.
            at scala.Predef$.require(Predef.scala:337)
            at org.apache.spark.sql.catalyst.plans.physical.PartitioningCollection.<init>(partitioning.scala:550)
            at org.apache.spark.sql.execution.joins.ShuffledJoin.outputPartitioning(ShuffledJoin.scala:49)
            at org.apache.spark.sql.execution.joins.ShuffledJoin.outputPartitioning$(ShuffledJoin.scala:47)
            at org.apache.spark.sql.execution.joins.SortMergeJoinExec.outputPartitioning(SortMergeJoinExec.scala:39)
            at org.apache.spark.sql.execution.exchange.EnsureRequirements.$anonfun$ensureDistributionAndOrdering$1(EnsureRequirements.scala:66)
            at scala.collection.immutable.Vector1.map(Vector.scala:2140)
            at scala.collection.immutable.Vector1.map(Vector.scala:385)
            at org.apache.spark.sql.execution.exchange.EnsureRequirements.org$apache$spark$sql$execution$exchange$EnsureRequirements$$ensureDistributionAndOrdering(EnsureRequirements.scala:65)
            at org.apache.spark.sql.execution.exchange.EnsureRequirements$$anonfun$1.applyOrElse(EnsureRequirements.scala:657)
            at org.apache.spark.sql.execution.exchange.EnsureRequirements$$anonfun$1.applyOrElse(EnsureRequirements.scala:632)
    ```
    
    The fix is to do the de-duplication in the first pass, by:
    
    1. Pushing down join keys to the BatchScanExec so that it returns a de-duped outputPartitioning (partitioned side).
    2. Creating the non-partitioned side's KeyGroupedPartitioning with de-duped partition keys (non-partitioned side).
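    
    As a rough sketch of why de-duplicating on both sides up front removes the mismatch, the toy model below compares partition counts with and without the fix.  Everything in it is hypothetical: `Partitioning` is a simplified stand-in for KeyGroupedPartitioning, and `planWithoutFix` and `planWithFix` only mimic the order of operations described above, not the real planner code.
    
    ```scala
    object DedupInFirstPass {
      type PartitionValue = Seq[Any]
    
      // Simplified stand-in for KeyGroupedPartitioning: only the partition values matter here.
      final case class Partitioning(partitionValues: Seq[PartitionValue]) {
        def numPartitions: Int = partitionValues.length
      }
    
      // Keep only the partition columns at the join key positions.
      def project(values: Seq[PartitionValue], positions: Seq[Int]): Seq[PartitionValue] =
        values.map(row => positions.map(i => row(i)))
    
      // Before the fix: the shuffle side clones the projected values as-is,
      // while the scan side groups (de-duplicates) them in a later step.
      def planWithoutFix(
          values: Seq[PartitionValue],
          positions: Seq[Int]): (Partitioning, Partitioning) = {
        val projected = project(values, positions)
        (Partitioning(projected.distinct), Partitioning(projected))
      }
    
      // After the fix: both sides are built from the same de-duplicated values.
      def planWithFix(
          values: Seq[PartitionValue],
          positions: Seq[Int]): (Partitioning, Partitioning) = {
        val deduped = project(values, positions).distinct
        (Partitioning(deduped), Partitioning(deduped))
      }
    
      def main(args: Array[String]): Unit = {
        val values: Seq[PartitionValue] =
          Seq(Seq(1, "bob"), Seq(2, "alice"), Seq(2, "sam"))
        val (scanOld, shuffleOld) = planWithoutFix(values, Seq(0))
        val (scanNew, shuffleNew) = planWithFix(values, Seq(0))
        println(s"without fix: ${scanOld.numPartitions} vs ${shuffleOld.numPartitions}")
        println(s"with fix:    ${scanNew.numPartitions} vs ${shuffleNew.numPartitions}")
      }
    }
    ```
    
    With unequal counts, a check like the `require` in PartitioningCollection fails; with equal counts it passes.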
    
    ### Why are the changes needed?
    
    This is the last planned SPJ scenario that was not yet supported.
    
    ### How was this patch tested?
    
    Updated the existing unit test in KeyGroupedPartitioningSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #47064 from szehon-ho/spj_less_join_key_auto_shuffle.
    
    Authored-by: Szehon Ho <[email protected]>
    Signed-off-by: Chao Sun <[email protected]>
---
 .../sql/catalyst/plans/physical/partitioning.scala | 25 +++++++++----------
 .../execution/datasources/v2/BatchScanExec.scala   | 28 ++++++++++++++++------
 .../execution/exchange/EnsureRequirements.scala    | 26 +++++++++++++++++++-
 .../connector/KeyGroupedPartitioningSuite.scala    |  3 +--
 4 files changed, 58 insertions(+), 24 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 19595eef10b3..f8e980747bf2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -434,8 +434,13 @@ object KeyGroupedPartitioning {
     val projectedOriginalPartitionValues =
       originalPartitionValues.map(project(expressions, projectionPositions, _))
 
-    KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length,
-      projectedPartitionValues, projectedOriginalPartitionValues)
+    val finalPartitionValues = projectedPartitionValues
+        .map(InternalRowComparableWrapper(_, projectedExpressions))
+        .distinct
+        .map(_.row)
+
+    KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
+      finalPartitionValues, projectedOriginalPartitionValues)
   }
 
   def project(
@@ -871,20 +876,12 @@ case class KeyGroupedShuffleSpec(
     if (results.forall(p => p.isEmpty)) None else Some(results)
   }
 
-  override def canCreatePartitioning: Boolean = {
-    // Allow one side shuffle for SPJ for now only if partially-clustered is not enabled
-    // and for join keys less than partition keys only if transforms are not enabled.
-    val checkExprType = if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
-      e: Expression => e.isInstanceOf[AttributeReference]
-    } else {
-      e: Expression => e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
-    }
-    }
+  override def canCreatePartitioning: Boolean =
     SQLConf.get.v2BucketingShuffleEnabled &&
       !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
-      partitioning.expressions.forall(checkExprType)
-  }
-
-
+      partitioning.expressions.forall { e =>
+        e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
+      }
 
   override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
     val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index f949dbf71a37..997576a396d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -118,16 +118,29 @@ case class BatchScanExec(
 
   override def outputPartitioning: Partitioning = {
     super.outputPartitioning match {
-      case k: KeyGroupedPartitioning if spjParams.commonPartitionValues.isDefined =>
-        // We allow duplicated partition values if
-        // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
-        val newPartValues = spjParams.commonPartitionValues.get.flatMap {
-          case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
-        }
+      case k: KeyGroupedPartitioning =>
         val expressions = spjParams.joinKeyPositions match {
           case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i))
           case _ => k.expressions
         }
+
+        val newPartValues = spjParams.commonPartitionValues match {
+          case Some(commonPartValues) =>
+            // We allow duplicated partition values if
+            // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
+             commonPartValues.flatMap {
+               case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
+             }
+          case None =>
+            spjParams.joinKeyPositions match {
+              case Some(projectionPositions) => k.partitionValues.map{r =>
+                val projectedRow = KeyGroupedPartitioning.project(expressions,
+                  projectionPositions, r)
+                InternalRowComparableWrapper(projectedRow, expressions)
+              }.distinct.map(_.row)
+              case _ => k.partitionValues
+            }
+        }
         k.copy(expressions = expressions, numPartitions = newPartValues.length,
           partitionValues = newPartValues)
       case p => p
@@ -279,7 +292,8 @@ case class StoragePartitionJoinParams(
     case other: StoragePartitionJoinParams =>
       this.commonPartitionValues == other.commonPartitionValues &&
       this.replicatePartitions == other.replicatePartitions &&
-      this.applyPartialClustering == other.applyPartialClustering
+      this.applyPartialClustering == other.applyPartialClustering &&
+      this.joinKeyPositions == other.joinKeyPositions
     case _ =>
       false
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 67d879bdd8bf..0470aacd4f82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -175,7 +175,16 @@ case class EnsureRequirements(
           child
         case ((child, dist), idx) =>
           if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) {
-            child
+            bestSpecOpt match {
+              // If keyGroupCompatible = false, we can still perform SPJ
+              // by shuffling the other side based on join keys (see the else case below).
+              // Hence we need to ensure that after this call, the outputPartitioning of the
+              // partitioned side's BatchScanExec is grouped by join keys to match,
+              // and we do that by pushing down the join keys
+              case Some(KeyGroupedShuffleSpec(_, _, Some(joinKeyPositions))) =>
+                populateJoinKeyPositions(child, Some(joinKeyPositions))
+              case _ => child
+            }
           } else {
             val newPartitioning = bestSpecOpt.map { bestSpec =>
               // Use the best spec to create a new partitioning to re-shuffle this child
@@ -578,6 +587,21 @@ case class EnsureRequirements(
         child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions))
   }
 
+
+  private def populateJoinKeyPositions(
+      plan: SparkPlan,
+      joinKeyPositions: Option[Seq[Int]]): SparkPlan = plan match {
+    case scan: BatchScanExec =>
+      scan.copy(
+        spjParams = scan.spjParams.copy(
+          joinKeyPositions = joinKeyPositions
+        )
+      )
+    case node =>
+      node.mapChildren(child => populateJoinKeyPositions(
+        child, joinKeyPositions))
+  }
+
   private def reduceCommonPartValues(
       commonPartValues: Seq[(InternalRow, Int)],
       expressions: Seq[Expression],
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index d77a6e8b8ac1..5e5453b4cd50 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -2168,8 +2168,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
      SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
      val df = createJoinTestDF(Seq("id" -> "item_id"))
      val shuffles = collectShuffles(df.queryExecution.executedPlan)
-     assert(shuffles.size == 2, "SPJ should not be triggered for transform expression with" +
-       "less join keys than partition keys for now.")
+     assert(shuffles.size == 1, "SPJ should be triggered")
      checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
        Row(1, "aa", 30.0, 89.0),
        Row(1, "aa", 40.0, 42.0),

