This is an automated email from the ASF dual-hosted git repository.

zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new cd624d711 [GLUTEN-6557][CH] Try to replace sort merge join with hash 
join when cannot offload it (#6570)
cd624d711 is described below

commit cd624d711eaa55595535834ed6e63f2945757b2b
Author: lgbo <[email protected]>
AuthorDate: Fri Aug 2 10:44:28 2024 +0800

    [GLUTEN-6557][CH] Try to replace sort merge join with hash join when cannot 
offload it (#6570)
    
    [CH] Try to replace sort merge join with hash join when cannot offload it
---
 .../clickhouse/CHSparkPlanExecApi.scala            |   7 +-
 .../execution/CHHashJoinExecTransformer.scala      |  60 ++++++----
 .../RewriteSortMergeJoinToHashJoinRule.scala       | 122 +++++++++++++++++++++
 .../apache/gluten/utils/CHJoinValidateUtil.scala   |  12 +-
 .../GlutenClickHouseTPCDSAbstractSuite.scala       |   7 +-
 ...nClickHouseTPCDSParquetSortMergeJoinSuite.scala |  22 ++--
 6 files changed, 183 insertions(+), 47 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index b8a76b421..3069c4a3f 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.backendsapi.{BackendsApiManager, 
SparkPlanExecApi}
 import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution._
 import org.apache.gluten.expression._
-import org.apache.gluten.extension.{CountDistinctWithoutExpand, 
FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, 
RewriteToDateExpresstionRule}
+import org.apache.gluten.extension.{CountDistinctWithoutExpand, 
FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, 
RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
 import org.apache.gluten.extension.columnar.AddFallbackTagRule
 import 
org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
 import org.apache.gluten.extension.columnar.transition.Convention
@@ -555,8 +555,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
    *
    * @return
    */
-  override def genExtendedQueryStagePrepRules(): List[SparkSession => 
Rule[SparkPlan]] =
+  override def genExtendedQueryStagePrepRules(): List[SparkSession => 
Rule[SparkPlan]] = {
     List(spark => FallbackBroadcastHashJoinPrepQueryStage(spark))
+  }
 
   /**
    * Generate extended Analyzers. Currently only for ClickHouse backend.
@@ -597,7 +598,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
    * @return
    */
   override def genExtendedColumnarTransformRules(): List[SparkSession => 
Rule[SparkPlan]] =
-    List()
+    List(spark => RewriteSortMergeJoinToHashJoinRule(spark))
 
   override def genInjectPostHocResolutionRules(): List[SparkSession => 
Rule[LogicalPlan]] = {
     List()
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index c44156373..ed946e1d2 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -31,6 +31,35 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import io.substrait.proto.JoinRel
 
+object JoinTypeTransform {
+  def toNativeJoinType(joinType: JoinType): JoinType = {
+    joinType match {
+      case ExistenceJoin(_) =>
+        LeftSemi
+      case _ =>
+        joinType
+    }
+  }
+
+  def toSubstraitType(joinType: JoinType): JoinRel.JoinType = {
+    joinType match {
+      case _: InnerLike =>
+        JoinRel.JoinType.JOIN_TYPE_INNER
+      case FullOuter =>
+        JoinRel.JoinType.JOIN_TYPE_OUTER
+      case LeftOuter | RightOuter =>
+        JoinRel.JoinType.JOIN_TYPE_LEFT
+      case LeftSemi | ExistenceJoin(_) =>
+        JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
+      case LeftAnti =>
+        JoinRel.JoinType.JOIN_TYPE_ANTI
+      case _ =>
+        // TODO: Support cross join with Cross Rel
+        JoinRel.JoinType.UNRECOGNIZED
+    }
+  }
+}
+
 case class CHShuffledHashJoinExecTransformer(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
@@ -57,7 +86,7 @@ case class CHShuffledHashJoinExecTransformer(
   override protected def doValidateInternal(): ValidationResult = {
     val shouldFallback =
       CHJoinValidateUtil.shouldFallback(
-        ShuffleHashJoinStrategy(joinType),
+        ShuffleHashJoinStrategy(finalJoinType),
         left.outputSet,
         right.outputSet,
         condition)
@@ -66,6 +95,9 @@ case class CHShuffledHashJoinExecTransformer(
     }
     super.doValidateInternal()
   }
+  private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
+  override protected lazy val substraitJoinType: JoinRel.JoinType =
+    JoinTypeTransform.toSubstraitType(joinType)
 }
 
 case class CHBroadcastBuildSideRDD(
@@ -171,27 +203,7 @@ case class CHBroadcastHashJoinExecTransformer(
   // Indeed, the ExistenceJoin is transformed into left any join in CH.
   // We don't have left any join in substrait, so use left semi join instead.
   // and isExistenceJoin is set to true to indicate that it is an existence 
join.
-  private val finalJoinType = joinType match {
-    case ExistenceJoin(_) =>
-      LeftSemi
-    case _ =>
-      joinType
-  }
-  override protected lazy val substraitJoinType: JoinRel.JoinType = {
-    joinType match {
-      case _: InnerLike =>
-        JoinRel.JoinType.JOIN_TYPE_INNER
-      case FullOuter =>
-        JoinRel.JoinType.JOIN_TYPE_OUTER
-      case LeftOuter | RightOuter =>
-        JoinRel.JoinType.JOIN_TYPE_LEFT
-      case LeftSemi | ExistenceJoin(_) =>
-        JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
-      case LeftAnti =>
-        JoinRel.JoinType.JOIN_TYPE_ANTI
-      case _ =>
-        // TODO: Support cross join with Cross Rel
-        JoinRel.JoinType.UNRECOGNIZED
-    }
-  }
+  private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
+  override protected lazy val substraitJoinType: JoinRel.JoinType =
+    JoinTypeTransform.toSubstraitType(joinType)
 }
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
new file mode 100644
index 000000000..8c5ada043
--- /dev/null
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.execution._
+import org.apache.gluten.utils.{CHJoinValidateUtil, ShuffleHashJoinStrategy}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.optimizer._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
+import org.apache.spark.sql.execution.joins._
+
+// import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+// If a SortMergeJoinExec cannot be offloaded, try to replace it with 
ShuffledHashJoinExec
+// instead.
+// This rule is applied after spark plan nodes are transformed into 
columnar ones.
+case class RewriteSortMergeJoinToHashJoinRule(session: SparkSession)
+  extends Rule[SparkPlan]
+  with Logging {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    visitPlan(plan)
+  }
+
+  private def visitPlan(plan: SparkPlan): SparkPlan = {
+    plan match {
+      case smj: SortMergeJoinExec =>
+        tryReplaceSortMergeJoin(smj)
+      case other =>
+        other.withNewChildren(other.children.map(visitPlan))
+    }
+  }
+
+  private def tryReplaceSortMergeJoin(smj: SortMergeJoinExec): SparkPlan = {
+    // cannot offload SortMergeJoin, try to replace it with ShuffledHashJoin
+    val needFallback = CHJoinValidateUtil.shouldFallback(
+      ShuffleHashJoinStrategy(smj.joinType),
+      smj.left.outputSet,
+      smj.right.outputSet,
+      smj.condition)
+    // also cannot offload HashJoin, don't replace it.
+    if (needFallback) {
+      logInfo(s"Cannot offload this join by hash join algorithm")
+      return smj
+    } else {
+      replaceSortMergeJoinWithHashJoin(smj)
+    }
+  }
+
+  private def replaceSortMergeJoinWithHashJoin(smj: SortMergeJoinExec): 
SparkPlan = {
+    val newLeft = replaceSortMergeJoinChild(smj.left)
+    val newRight = replaceSortMergeJoinChild(smj.right)
+    // Some cases that we cannot handle.
+    if (newLeft == null || newRight == null) {
+      logInfo("Apply on sort merge children failed")
+      return smj
+    }
+
+    var hashJoin = CHShuffledHashJoinExecTransformer(
+      smj.leftKeys,
+      smj.rightKeys,
+      smj.joinType,
+      BuildRight,
+      smj.condition,
+      newLeft,
+      newRight,
+      smj.isSkewJoin)
+    val validateResult = hashJoin.doValidate()
+    if (!validateResult.ok()) {
+      logError(s"Validation failed for ShuffledHashJoinExec: 
${validateResult.reason()}")
+      return smj
+    }
+    hashJoin
+  }
+
+  private def replaceSortMergeJoinChild(plan: SparkPlan): SparkPlan = {
+    plan match {
+      case sort: SortExecTransformer =>
+        sort.child match {
+          case hashShuffle: ColumnarShuffleExchangeExec =>
+            // drop sort node, return the shuffle node directly
+            hashShuffle.withNewChildren(hashShuffle.children.map(visitPlan))
+          case aqeShuffle: AQEShuffleReadExec =>
+            // drop sort node, return the shuffle node directly
+            aqeShuffle.withNewChildren(aqeShuffle.children.map(visitPlan))
+          case columnarPlan: TransformSupport =>
+            visitPlan(columnarPlan)
+          case _ =>
+            // other cases that we don't know
+            logInfo(s"Expected ColumnarShuffleExchangeExec, got 
${sort.child.getClass}")
+            null
+        }
+      case smj: SortMergeJoinExec =>
+        val newChild = replaceSortMergeJoinWithHashJoin(smj)
+        if (newChild.isInstanceOf[SortMergeJoinExec]) {
+          null
+        } else {
+          newChild
+        }
+      case _: TransformSupport => visitPlan(plan)
+      case _ =>
+        logInfo(s"Expected Columnar node, got ${plan.getClass}")
+        null
+    }
+  }
+}
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
index 08b5ef5b2..0f5b5e2c4 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
@@ -18,7 +18,7 @@ package org.apache.gluten.utils
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression}
-import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans._
 
 trait JoinStrategy {
   val joinType: JoinType
@@ -54,11 +54,8 @@ object CHJoinValidateUtil extends Logging {
       condition: Option[Expression]): Boolean = {
     var shouldFallback = false
     val joinType = joinStrategy.joinType
-    if (joinType.toString.contains("ExistenceJoin")) {
-      logError("Fallback for join type ExistenceJoin")
-      return true
-    }
-    if (joinType.sql.contains("INNER")) {
+
+    if (!joinType.isInstanceOf[ExistenceJoin] && 
joinType.sql.contains("INNER")) {
       shouldFallback = false;
     } else if (
       condition.isDefined && hasTwoTableColumn(leftOutputSet, rightOutputSet, 
condition.get)
@@ -75,7 +72,8 @@ object CHJoinValidateUtil extends Logging {
     } else {
       shouldFallback = joinStrategy match {
         case SortMergeJoinStrategy(joinTy) =>
-          joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI")
+          joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") || 
joinTy.toString.contains(
+            "ExistenceJoin")
         case _ => false
       }
     }
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
index 6ca587beb..f2a1e5a71 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
@@ -66,12 +66,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
 
     // q16 smj + left semi + not condition
     // Q94 BroadcastHashJoin, LeftSemi, NOT condition
-    if (isAqe) {
-      Set(16, 94) | more
-    } else {
-      // q10, q35 smj + existence join
-      Set(10, 16, 35, 94) | more
-    }
+    Set(16, 94) | more
   }
   protected def excludedTpcdsQueries: Set[String] = Set(
     "q66" // inconsistent results
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
index 7e480361b..509c83054 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.gluten.execution.tpcds
 
-import org.apache.gluten.execution.{CHSortMergeJoinExecTransformer, 
GlutenClickHouseTPCDSAbstractSuite}
+import org.apache.gluten.execution.{CHShuffledHashJoinExecTransformer, 
CHSortMergeJoinExecTransformer, GlutenClickHouseTPCDSAbstractSuite}
 import org.apache.gluten.test.FallbackUtil
 
 import org.apache.spark.SparkConf
@@ -114,7 +114,7 @@ class GlutenClickHouseTPCDSParquetSortMergeJoinSuite 
extends GlutenClickHouseTPC
     }
   }
 
-  test("sort merge join: left semi join should fallback") {
+  test("sort merge join: left semi join should be replaced with hash join") {
     withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
       val testSql =
         """SELECT  count(*) cnt
@@ -125,12 +125,16 @@ class GlutenClickHouseTPCDSParquetSortMergeJoinSuite 
extends GlutenClickHouseTPC
       val smjTransformers = df.queryExecution.executedPlan.collect {
         case f: CHSortMergeJoinExecTransformer => f
       }
-      assert(smjTransformers.isEmpty)
-      assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
+      val hashJoinTransformers = df.queryExecution.executedPlan.collect {
+        case f: CHShuffledHashJoinExecTransformer => f
+      }
+      assert(smjTransformers.size == 0)
+      assert(hashJoinTransformers.size > 0)
+      assert(!FallbackUtil.hasFallback(df.queryExecution.executedPlan))
     }
   }
 
-  test("sort merge join: left anti join should fallback") {
+  test("sort merge join: left anti join should be replaced with hash join") {
     withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
       val testSql =
         """SELECT  count(*) cnt
@@ -141,8 +145,12 @@ class GlutenClickHouseTPCDSParquetSortMergeJoinSuite 
extends GlutenClickHouseTPC
       val smjTransformers = df.queryExecution.executedPlan.collect {
         case f: CHSortMergeJoinExecTransformer => f
       }
-      assert(smjTransformers.isEmpty)
-      assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
+      val hashJoinTransformers = df.queryExecution.executedPlan.collect {
+        case f: CHShuffledHashJoinExecTransformer => f
+      }
+      assert(smjTransformers.size == 0)
+      assert(hashJoinTransformers.size > 0)
+      assert(!FallbackUtil.hasFallback(df.queryExecution.executedPlan))
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to