This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dc7b5a9  [SPARK-35282][SQL] Support AQE side shuffled hash join 
formula using rule
dc7b5a9 is described below

commit dc7b5a99f07e172de2c7530e6a7c14dcf6e4f217
Author: ulysses-you <[email protected]>
AuthorDate: Wed May 26 14:16:04 2021 +0000

    [SPARK-35282][SQL] Support AQE side shuffled hash join formula using rule
    
    ### What changes were proposed in this pull request?
    
    The main code change is:
    * Change rule `DemoteBroadcastHashJoin` to `DynamicJoinSelection` and add 
shuffle hash join selection code.
    * Specify a join strategy hint `SHUFFLE_HASH` if AQE thinks a join can be 
converted to SHJ.
    * Skip `preferSortMerge` config check in AQE side if a join can be 
converted to SHJ.
    
    ### Why are the changes needed?
    
    Use AQE runtime statistics to decide if we can use shuffled hash join 
instead of sort merge join. Currently, the formula of shuffled hash join 
selection does not work due to the dynamic shuffle partition number.
    
    Add a new config spark.sql.adaptive.maxShuffledHashJoinLocalMapThreshold to 
decide if a join can be converted to shuffled hash join safely.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, add a new config.
    
    ### How was this patch tested?
    
    Add test.
    
    Closes #32550 from ulysses-you/SPARK-35282-2.
    
    Authored-by: ulysses-you <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/optimizer/joins.scala       | 28 ++++++-
 .../spark/sql/catalyst/plans/logical/hints.scala   |  8 ++
 .../org/apache/spark/sql/internal/SQLConf.scala    | 11 +++
 .../spark/sql/execution/SparkStrategies.scala      |  8 +-
 .../sql/execution/adaptive/AQEOptimizer.scala      |  2 +-
 .../adaptive/DemoteBroadcastHashJoin.scala         | 57 -------------
 .../execution/adaptive/DynamicJoinSelection.scala  | 94 ++++++++++++++++++++++
 .../adaptive/AdaptiveQueryExecSuite.scala          | 51 ++++++++++++
 8 files changed, 192 insertions(+), 67 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 8431b31..9d698c9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -253,12 +253,28 @@ trait JoinSelectionHelper {
     val buildLeft = if (hintOnly) {
       hintToShuffleHashJoinLeft(hint)
     } else {
-      canBuildLocalHashMapBySize(left, conf) && muchSmaller(left, right)
+      if (hintToPreferShuffleHashJoinLeft(hint)) {
+        true
+      } else {
+        if (!conf.preferSortMergeJoin) {
+          canBuildLocalHashMapBySize(left, conf) && muchSmaller(left, right)
+        } else {
+          false
+        }
+      }
     }
     val buildRight = if (hintOnly) {
       hintToShuffleHashJoinRight(hint)
     } else {
-      canBuildLocalHashMapBySize(right, conf) && muchSmaller(right, left)
+      if (hintToPreferShuffleHashJoinRight(hint)) {
+        true
+      } else {
+        if (!conf.preferSortMergeJoin) {
+          canBuildLocalHashMapBySize(right, conf) && muchSmaller(right, left)
+        } else {
+          false
+        }
+      }
     }
     getBuildSide(
       canBuildShuffledHashJoinLeft(joinType) && buildLeft,
@@ -345,6 +361,14 @@ trait JoinSelectionHelper {
     hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH))
   }
 
+  def hintToPreferShuffleHashJoinLeft(hint: JoinHint): Boolean = {
+    hint.leftHint.exists(_.strategy.contains(PREFER_SHUFFLE_HASH))
+  }
+
+  def hintToPreferShuffleHashJoinRight(hint: JoinHint): Boolean = {
+    hint.rightHint.exists(_.strategy.contains(PREFER_SHUFFLE_HASH))
+  }
+
   def hintToSortMergeJoin(hint: JoinHint): Boolean = {
     hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) ||
       hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE))
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index 5bda94c..0dfd98c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -173,6 +173,14 @@ case object NO_BROADCAST_HASH extends JoinStrategyHint {
 }
 
 /**
+ * An internal hint to encourage shuffle hash join, used by adaptive query 
execution.
+ */
+case object PREFER_SHUFFLE_HASH extends JoinStrategyHint {
+  override def displayName: String = "prefer_shuffle_hash"
+  override def hintAliases: Set[String] = Set.empty
+}
+
+/**
  * The callback for implementing customized strategies of handling hint errors.
  */
 trait HintErrorHandler {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 71ff082..3dfe4dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -598,6 +598,17 @@ object SQLConf {
       .bytesConf(ByteUnit.BYTE)
       .createOptional
 
+  val ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD =
+    buildConf("spark.sql.adaptive.maxShuffledHashJoinLocalMapThreshold")
+      .doc("Configures the maximum size in bytes per partition that can be 
allowed to build " +
+        "local hash map. If this value is not smaller than " +
+        s"${ADVISORY_PARTITION_SIZE_IN_BYTES.key} and all the partition size 
are not larger " +
+        "than this config, join selection prefer to use shuffled hash join 
instead of " +
+        s"sort merge join regardless of the value of 
${PREFER_SORTMERGEJOIN.key}.")
+      .version("3.2.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefault(0L)
+
   val SUBEXPRESSION_ELIMINATION_ENABLED =
     buildConf("spark.sql.subexpressionElimination.enabled")
       .internal()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index b9e44b3..45d4c7d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -211,13 +211,7 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
 
         def createJoinWithoutHint() = {
           createBroadcastHashJoin(false)
-            .orElse {
-              if (!conf.preferSortMergeJoin) {
-                createShuffleHashJoin(false)
-              } else {
-                None
-              }
-            }
+            .orElse(createShuffleHashJoin(false))
             .orElse(createSortMergeJoin())
             .orElse(createCartesianProduct())
             .getOrElse {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
index b28626b..95dc7cc 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
@@ -31,7 +31,7 @@ class AQEOptimizer(conf: SQLConf) extends 
RuleExecutor[LogicalPlan] {
     Batch("Propagate Empty Relations", Once,
       AQEPropagateEmptyRelation,
       UpdateAttributeNullability),
-    Batch("Demote BroadcastHashJoin", Once, DemoteBroadcastHashJoin)
+    Batch("Dynamic Join Selection", Once, DynamicJoinSelection)
   )
 
   final override protected def batches: Seq[Batch] = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala
deleted file mode 100644
index 3760782..0000000
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.adaptive
-
-import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Join, 
LogicalPlan, NO_BROADCAST_HASH}
-import org.apache.spark.sql.catalyst.rules.Rule
-
-/**
- * This optimization rule detects a join child that has a high ratio of empty 
partitions and
- * adds a no-broadcast-hash-join hint to avoid it being broadcast.
- */
-object DemoteBroadcastHashJoin extends Rule[LogicalPlan] {
-
-  private def shouldDemote(plan: LogicalPlan): Boolean = plan match {
-    case LogicalQueryStage(_, stage: ShuffleQueryStageExec) if 
stage.resultOption.get().isDefined
-      && stage.mapStats.isDefined =>
-      val mapStats = stage.mapStats.get
-      val partitionCnt = mapStats.bytesByPartitionId.length
-      val nonZeroCnt = mapStats.bytesByPartitionId.count(_ > 0)
-      partitionCnt > 0 && nonZeroCnt > 0 &&
-        (nonZeroCnt * 1.0 / partitionCnt) < 
conf.nonEmptyPartitionRatioForBroadcastJoin
-    case _ => false
-  }
-
-  def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
-    case j @ Join(left, right, _, _, hint) =>
-      var newHint = hint
-      if (!hint.leftHint.exists(_.strategy.isDefined) && shouldDemote(left)) {
-        newHint = newHint.copy(leftHint =
-          Some(hint.leftHint.getOrElse(HintInfo()).copy(strategy = 
Some(NO_BROADCAST_HASH))))
-      }
-      if (!hint.rightHint.exists(_.strategy.isDefined) && shouldDemote(right)) 
{
-        newHint = newHint.copy(rightHint =
-          Some(hint.rightHint.getOrElse(HintInfo()).copy(strategy = 
Some(NO_BROADCAST_HASH))))
-      }
-      if (newHint.ne(hint)) {
-        j.copy(hint = newHint)
-      } else {
-        j
-      }
-  }
-}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DynamicJoinSelection.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DynamicJoinSelection.scala
new file mode 100644
index 0000000..61124f0
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DynamicJoinSelection.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.MapOutputStatistics
+import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Join, 
JoinStrategyHint, LogicalPlan, NO_BROADCAST_HASH, PREFER_SHUFFLE_HASH, 
SHUFFLE_HASH}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * This optimization rule includes three join selection:
+ *   1. detects a join child that has a high ratio of empty partitions and 
adds a
+ *      NO_BROADCAST_HASH hint to avoid it being broadcast, as shuffle join is 
faster in this case:
+ *      many tasks complete immediately since one join side is empty.
+ *   2. detects a join child that every partition size is less than local map 
threshold and adds a
+ *      PREFER_SHUFFLE_HASH hint to encourage being shuffle hash join instead 
of sort merge join.
+ *   3. if a join satisfies both NO_BROADCAST_HASH and PREFER_SHUFFLE_HASH,
+ *      then add a SHUFFLE_HASH hint.
+ */
+object DynamicJoinSelection extends Rule[LogicalPlan] {
+
+  private def shouldDemoteBroadcastHashJoin(mapStats: MapOutputStatistics): 
Boolean = {
+    val partitionCnt = mapStats.bytesByPartitionId.length
+    val nonZeroCnt = mapStats.bytesByPartitionId.count(_ > 0)
+    partitionCnt > 0 && nonZeroCnt > 0 &&
+      (nonZeroCnt * 1.0 / partitionCnt) < 
conf.nonEmptyPartitionRatioForBroadcastJoin
+  }
+
+  private def preferShuffledHashJoin(mapStats: MapOutputStatistics): Boolean = 
{
+    val maxShuffledHashJoinLocalMapThreshold =
+      conf.getConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD)
+    val advisoryPartitionSize = 
conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES)
+    if (advisoryPartitionSize <= maxShuffledHashJoinLocalMapThreshold) {
+      mapStats.bytesByPartitionId.forall(_ <= 
maxShuffledHashJoinLocalMapThreshold)
+    } else {
+      false
+    }
+  }
+
+  private def selectJoinStrategy(plan: LogicalPlan): Option[JoinStrategyHint] 
= plan match {
+    case LogicalQueryStage(_, stage: ShuffleQueryStageExec) if 
stage.resultOption.get().isDefined
+      && stage.mapStats.isDefined =>
+      val demoteBroadcastHash = 
shouldDemoteBroadcastHashJoin(stage.mapStats.get)
+      val preferShuffleHash = preferShuffledHashJoin(stage.mapStats.get)
+      if (demoteBroadcastHash && preferShuffleHash) {
+        Some(SHUFFLE_HASH)
+      } else if (demoteBroadcastHash) {
+        Some(NO_BROADCAST_HASH)
+      } else if (preferShuffleHash) {
+        Some(PREFER_SHUFFLE_HASH)
+      } else {
+        None
+      }
+
+    case _ => None
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
+    case j @ Join(left, right, _, _, hint) =>
+      var newHint = hint
+      if (!hint.leftHint.exists(_.strategy.isDefined)) {
+        selectJoinStrategy(left).foreach { strategy =>
+          newHint = newHint.copy(leftHint =
+            Some(hint.leftHint.getOrElse(HintInfo()).copy(strategy = 
Some(strategy))))
+        }
+      }
+      if (!hint.rightHint.exists(_.strategy.isDefined)) {
+        selectJoinStrategy(right).foreach { strategy =>
+          newHint = newHint.copy(rightHint =
+            Some(hint.rightHint.getOrElse(HintInfo()).copy(strategy = 
Some(strategy))))
+        }
+      }
+      if (newHint.ne(hint)) {
+        j.copy(hint = newHint)
+      } else {
+        j
+      }
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index d67adee..7151c51 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1707,4 +1707,55 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-35264: Support AQE side shuffled hash join formula") {
+    withTempView("t1", "t2") {
+      def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = {
+        Seq("100", "100000").foreach { size =>
+          withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> size) {
+            val (origin1, adaptive1) = runAdaptiveAndVerifyResult(
+              "SELECT t1.c1, t2.c1 FROM t1 JOIN t2 ON t1.c1 = t2.c1")
+            assert(findTopLevelSortMergeJoin(origin1).size === 1)
+            if (shouldShuffleHashJoin && size.toInt < 100000) {
+              val shj = findTopLevelShuffledHashJoin(adaptive1)
+              assert(shj.size === 1)
+              assert(shj.head.buildSide == BuildRight)
+            } else {
+              assert(findTopLevelSortMergeJoin(adaptive1).size === 1)
+            }
+          }
+        }
+        // respect user specified join hint
+        val (origin2, adaptive2) = runAdaptiveAndVerifyResult(
+          "SELECT /*+ MERGE(t1) */ t1.c1, t2.c1 FROM t1 JOIN t2 ON t1.c1 = 
t2.c1")
+        assert(findTopLevelSortMergeJoin(origin2).size === 1)
+        assert(findTopLevelSortMergeJoin(adaptive2).size === 1)
+      }
+
+      spark.sparkContext.parallelize(
+        (1 to 100).map(i => TestData(i, i.toString)), 10)
+        .toDF("c1", "c2").createOrReplaceTempView("t1")
+      spark.sparkContext.parallelize(
+        (1 to 10).map(i => TestData(i, i.toString)), 5)
+        .toDF("c1", "c2").createOrReplaceTempView("t2")
+
+      // t1 partition size: [926, 729, 731]
+      // t2 partition size: [318, 120, 0]
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+        SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
+        // check default value
+        checkJoinStrategy(false)
+        
withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> 
"400") {
+          checkJoinStrategy(true)
+        }
+        
withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> 
"300") {
+          checkJoinStrategy(false)
+        }
+        
withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> 
"1000") {
+          checkJoinStrategy(true)
+        }
+      }
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to