This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new ef3305280 perf: Update RewriteJoin logic to choose optimal build side (#1424)
ef3305280 is described below

commit ef3305280762337c1029179f1251196dfd6bce45
Author: Andy Grove <[email protected]>
AuthorDate: Fri Feb 21 15:21:42 2025 -0700

    perf: Update RewriteJoin logic to choose optimal build side (#1424)
---
 .../scala/org/apache/comet/rules/RewriteJoin.scala | 51 ++++++++++++++++++----
 1 file changed, 42 insertions(+), 9 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala b/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala
index 2f1f5ab74..6dd102352 100644
--- a/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala
@@ -20,7 +20,8 @@
 package org.apache.comet.rules
 
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper}
-import org.apache.spark.sql.catalyst.plans.{JoinType, LeftSemi}
+import org.apache.spark.sql.catalyst.plans.LeftSemi
+import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution.{SortExec, SparkPlan}
 import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
 
@@ -31,14 +32,29 @@ import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoin
  */
 object RewriteJoin extends JoinSelectionHelper {
 
-  private def getBuildSide(joinType: JoinType): Option[BuildSide] = {
-    if (canBuildShuffledHashJoinRight(joinType)) {
-      Some(BuildRight)
-    } else if (canBuildShuffledHashJoinLeft(joinType)) {
-      Some(BuildLeft)
-    } else {
-      None
+  private def getSmjBuildSide(join: SortMergeJoinExec): Option[BuildSide] = {
+    val leftBuildable = canBuildShuffledHashJoinLeft(join.joinType)
+    val rightBuildable = canBuildShuffledHashJoinRight(join.joinType)
+    if (!leftBuildable && !rightBuildable) {
+      return None
     }
+    if (!leftBuildable) {
+      return Some(BuildRight)
+    }
+    if (!rightBuildable) {
+      return Some(BuildLeft)
+    }
+    val side = join.logicalLink
+      .flatMap {
+        case join: Join => Some(getOptimalBuildSide(join))
+        case _ => None
+      }
+      .getOrElse {
+        // If smj has no logical link, or its logical link is not a join,
+        // then we always choose left as build side.
+        BuildLeft
+      }
+    Some(side)
   }
 
   private def removeSort(plan: SparkPlan) = plan match {
@@ -48,7 +64,7 @@ object RewriteJoin extends JoinSelectionHelper {
 
   def rewrite(plan: SparkPlan): SparkPlan = plan match {
     case smj: SortMergeJoinExec =>
-      getBuildSide(smj.joinType) match {
+      getSmjBuildSide(smj) match {
         case Some(BuildRight) if smj.joinType == LeftSemi =>
           // TODO this was added as a workaround for TPC-DS q14 hanging and needs
           // further investigation
@@ -67,4 +83,21 @@ object RewriteJoin extends JoinSelectionHelper {
       }
     case _ => plan
   }
+
+  def getOptimalBuildSide(join: Join): BuildSide = {
+    val leftSize = join.left.stats.sizeInBytes
+    val rightSize = join.right.stats.sizeInBytes
+    val leftRowCount = join.left.stats.rowCount
+    val rightRowCount = join.right.stats.rowCount
+    if (leftSize == rightSize && rightRowCount.isDefined && leftRowCount.isDefined) {
+      if (rightRowCount.get <= leftRowCount.get) {
+        return BuildRight
+      }
+      return BuildLeft
+    }
+    if (rightSize <= leftSize) {
+      return BuildRight
+    }
+    BuildLeft
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to