This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new ef3305280 perf: Update RewriteJoin logic to choose optimal build side (#1424)
ef3305280 is described below
commit ef3305280762337c1029179f1251196dfd6bce45
Author: Andy Grove <[email protected]>
AuthorDate: Fri Feb 21 15:21:42 2025 -0700
perf: Update RewriteJoin logic to choose optimal build side (#1424)
---
.../scala/org/apache/comet/rules/RewriteJoin.scala | 51 ++++++++++++++++++----
1 file changed, 42 insertions(+), 9 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala b/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala
index 2f1f5ab74..6dd102352 100644
--- a/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala
@@ -20,7 +20,8 @@
package org.apache.comet.rules
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper}
-import org.apache.spark.sql.catalyst.plans.{JoinType, LeftSemi}
+import org.apache.spark.sql.catalyst.plans.LeftSemi
+import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution.{SortExec, SparkPlan}
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
@@ -31,14 +32,29 @@ import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoin
*/
object RewriteJoin extends JoinSelectionHelper {
- private def getBuildSide(joinType: JoinType): Option[BuildSide] = {
- if (canBuildShuffledHashJoinRight(joinType)) {
- Some(BuildRight)
- } else if (canBuildShuffledHashJoinLeft(joinType)) {
- Some(BuildLeft)
- } else {
- None
+ private def getSmjBuildSide(join: SortMergeJoinExec): Option[BuildSide] = {
+ val leftBuildable = canBuildShuffledHashJoinLeft(join.joinType)
+ val rightBuildable = canBuildShuffledHashJoinRight(join.joinType)
+ if (!leftBuildable && !rightBuildable) {
+ return None
}
+ if (!leftBuildable) {
+ return Some(BuildRight)
+ }
+ if (!rightBuildable) {
+ return Some(BuildLeft)
+ }
+ val side = join.logicalLink
+ .flatMap {
+ case join: Join => Some(getOptimalBuildSide(join))
+ case _ => None
+ }
+ .getOrElse {
+ // If smj has no logical link, or its logical link is not a join,
+ // then we always choose left as build side.
+ BuildLeft
+ }
+ Some(side)
}
private def removeSort(plan: SparkPlan) = plan match {
@@ -48,7 +64,7 @@ object RewriteJoin extends JoinSelectionHelper {
def rewrite(plan: SparkPlan): SparkPlan = plan match {
case smj: SortMergeJoinExec =>
- getBuildSide(smj.joinType) match {
+ getSmjBuildSide(smj) match {
case Some(BuildRight) if smj.joinType == LeftSemi =>
// TODO this was added as a workaround for TPC-DS q14 hanging and needs
// further investigation
@@ -67,4 +83,21 @@ object RewriteJoin extends JoinSelectionHelper {
}
case _ => plan
}
+
+ def getOptimalBuildSide(join: Join): BuildSide = {
+ val leftSize = join.left.stats.sizeInBytes
+ val rightSize = join.right.stats.sizeInBytes
+ val leftRowCount = join.left.stats.rowCount
+ val rightRowCount = join.right.stats.rowCount
+ if (leftSize == rightSize && rightRowCount.isDefined && leftRowCount.isDefined) {
+ if (rightRowCount.get <= leftRowCount.get) {
+ return BuildRight
+ }
+ return BuildLeft
+ }
+ if (rightSize <= leftSize) {
+ return BuildRight
+ }
+ BuildLeft
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]