This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 73100f49a [VL] Make conf option `s.g.s.c.shuffledHashJoin.optimizeBuildSide` work correctly with option `s.g.s.c.forceShuffledHashJoin` (#7186)
73100f49a is described below

commit 73100f49a705f838a3e3de655c265d8c83809fde
Author: Hongze Zhang <[email protected]>
AuthorDate: Wed Sep 11 18:51:48 2024 +0800

    [VL] Make conf option `s.g.s.c.shuffledHashJoin.optimizeBuildSide` work correctly with option `s.g.s.c.forceShuffledHashJoin` (#7186)
---
 .../gluten/extension/columnar/FallbackRules.scala  | 16 +++----
 .../extension/columnar/OffloadSingleNode.scala     | 54 ++++++++++++++--------
 .../extension/columnar/rewrite/RewriteJoin.scala   | 45 ++++++++++--------
 3 files changed, 69 insertions(+), 46 deletions(-)

diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
index 6b043fbce..a5bba46dc 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala
@@ -333,7 +333,7 @@ case class AddFallbackTagRule() extends Rule[SparkPlan] {
               plan.leftKeys,
               plan.rightKeys,
               plan.joinType,
-              OffloadJoin.getBuildSide(plan),
+              OffloadJoin.getShjBuildSide(plan),
               plan.condition,
               plan.left,
               plan.right,
@@ -443,13 +443,13 @@ case class AddFallbackTagRule() extends Rule[SparkPlan] {
             offset)
           transformer.doValidate().tagOnFallback(plan)
         case plan: SampleExec =>
-          val transformer = 
BackendsApiManager.getSparkPlanExecApiInstance.genSampleExecTransformer(
-            plan.lowerBound,
-            plan.upperBound,
-            plan.withReplacement,
-            plan.seed,
-            plan.child
-          )
+          val transformer =
+            
BackendsApiManager.getSparkPlanExecApiInstance.genSampleExecTransformer(
+              plan.lowerBound,
+              plan.upperBound,
+              plan.withReplacement,
+              plan.seed,
+              plan.child)
           transformer.doValidate().tagOnFallback(plan)
         case _ =>
         // Currently we assume a plan to be offload-able by default.
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
index 6047789e6..cdc71f447 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
@@ -118,7 +118,6 @@ case class OffloadExchange() extends OffloadSingleNode with LogLevelUtil {
 
 // Join transformation.
 case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil {
-
   override def offload(plan: SparkPlan): SparkPlan = {
     if (FallbackTags.nonEmpty(plan)) {
       logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
@@ -134,7 +133,7 @@ case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil {
             plan.leftKeys,
             plan.rightKeys,
             plan.joinType,
-            OffloadJoin.getBuildSide(plan),
+            OffloadJoin.getShjBuildSide(plan),
             plan.condition,
             left,
             right,
@@ -186,37 +185,53 @@ case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil {
 }
 
 object OffloadJoin {
-
-  def getBuildSide(shj: ShuffledHashJoinExec): BuildSide = {
+  def getShjBuildSide(shj: ShuffledHashJoinExec): BuildSide = {
     val leftBuildable =
       
BackendsApiManager.getSettings.supportHashBuildJoinTypeOnLeft(shj.joinType)
     val rightBuildable =
       
BackendsApiManager.getSettings.supportHashBuildJoinTypeOnRight(shj.joinType)
+
+    assert(leftBuildable || rightBuildable)
+
     if (!leftBuildable) {
       return BuildRight
     }
     if (!rightBuildable) {
       return BuildLeft
     }
+
     // Both left and right are buildable. Find out the better one.
     if (!GlutenConfig.getConf.shuffledHashJoinOptimizeBuildSide) {
+      // User disabled build side re-optimization. Return original build side 
from vanilla Spark.
       return shj.buildSide
     }
-    shj.logicalLink match {
-      case Some(join: Join) =>
-        val leftSize = join.left.stats.sizeInBytes
-        val rightSize = join.right.stats.sizeInBytes
-        val leftRowCount = join.left.stats.rowCount
-        val rightRowCount = join.right.stats.rowCount
-        if (rightSize == leftSize && rightRowCount.isDefined && 
leftRowCount.isDefined) {
-          if (rightRowCount.get <= leftRowCount.get) BuildRight
-          else BuildLeft
-        } else if (rightSize <= leftSize) BuildRight
-        else BuildLeft
-      // Only the ShuffledHashJoinExec generated directly in some spark tests 
is not link
-      // logical plan, such as OuterJoinSuite.
-      case _ => shj.buildSide
+    shj.logicalLink
+      .flatMap {
+        case join: Join => Some(getOptimalBuildSide(join))
+        case _ => None
+      }
+      .getOrElse {
+        // Some shj operators generated in certain Spark tests such as 
OuterJoinSuite,
+        // could possibly have no logical link set.
+        shj.buildSide
+      }
+  }
+
+  def getOptimalBuildSide(join: Join): BuildSide = {
+    val leftSize = join.left.stats.sizeInBytes
+    val rightSize = join.right.stats.sizeInBytes
+    val leftRowCount = join.left.stats.rowCount
+    val rightRowCount = join.right.stats.rowCount
+    if (leftSize == rightSize && rightRowCount.isDefined && 
leftRowCount.isDefined) {
+      if (rightRowCount.get <= leftRowCount.get) {
+        return BuildRight
+      }
+      return BuildLeft
     }
+    if (rightSize <= leftSize) {
+      return BuildRight
+    }
+    BuildLeft
   }
 }
 
@@ -332,8 +347,7 @@ object OffloadOthers {
             plan.partitionColumns,
             plan.bucketSpec,
             plan.options,
-            plan.staticPartitions
-          )
+            plan.staticPartitions)
         case plan: SortExec =>
           val child = plan.child
           logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteJoin.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteJoin.scala
index 4fd420b02..d0cac0b29 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteJoin.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteJoin.scala
@@ -17,33 +17,43 @@
 package org.apache.gluten.extension.columnar.rewrite
 
 import org.apache.gluten.GlutenConfig
+import org.apache.gluten.extension.columnar.OffloadJoin
 
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, 
BuildSide, JoinSelectionHelper}
-import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, 
SortMergeJoinExec}
 
-/**
- * If force ShuffledHashJoin, convert [[SortMergeJoinExec]] to 
[[ShuffledHashJoinExec]]. There is no
- * need to select a smaller table as buildSide here, it will be reselected 
when offloading.
- */
+/** If force ShuffledHashJoin, convert [[SortMergeJoinExec]] to 
[[ShuffledHashJoinExec]]. */
 object RewriteJoin extends RewriteSingleNode with JoinSelectionHelper {
-
-  private def getBuildSide(joinType: JoinType): Option[BuildSide] = {
-    val leftBuildable = canBuildShuffledHashJoinLeft(joinType)
-    val rightBuildable = canBuildShuffledHashJoinRight(joinType)
-    if (rightBuildable) {
-      Some(BuildRight)
-    } else if (leftBuildable) {
-      Some(BuildLeft)
-    } else {
-      None
+  private def getSmjBuildSide(join: SortMergeJoinExec): Option[BuildSide] = {
+    val leftBuildable = canBuildShuffledHashJoinLeft(join.joinType)
+    val rightBuildable = canBuildShuffledHashJoinRight(join.joinType)
+    if (!leftBuildable && !rightBuildable) {
+      return None
+    }
+    if (!leftBuildable) {
+      return Some(BuildRight)
     }
+    if (!rightBuildable) {
+      return Some(BuildLeft)
+    }
+    val side = join.logicalLink
+      .flatMap {
+        case join: Join => Some(OffloadJoin.getOptimalBuildSide(join))
+        case _ => None
+      }
+      .getOrElse {
+        // If smj has no logical link, or its logical link is not a join,
+        // then we always choose left as build side.
+        BuildLeft
+      }
+    Some(side)
   }
 
   override def rewrite(plan: SparkPlan): SparkPlan = plan match {
     case smj: SortMergeJoinExec if GlutenConfig.getConf.forceShuffledHashJoin 
=>
-      getBuildSide(smj.joinType) match {
+      getSmjBuildSide(smj) match {
         case Some(buildSide) =>
           ShuffledHashJoinExec(
             smj.leftKeys,
@@ -53,8 +63,7 @@ object RewriteJoin extends RewriteSingleNode with JoinSelectionHelper {
             smj.condition,
             smj.left,
             smj.right,
-            smj.isSkewJoin
-          )
+            smj.isSkewJoin)
         case _ => plan
       }
     case _ => plan


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to