This is an automated email from the ASF dual-hosted git repository.
zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new cd624d711 [GLUTEN-6557][CH] Try to replace sort merge join with hash
join when cannot offload it (#6570)
cd624d711 is described below
commit cd624d711eaa55595535834ed6e63f2945757b2b
Author: lgbo <[email protected]>
AuthorDate: Fri Aug 2 10:44:28 2024 +0800
[GLUTEN-6557][CH] Try to replace sort merge join with hash join when cannot
offload it (#6570)
[CH] Try to replace sort merge join with hash join when cannot offload it
---
.../clickhouse/CHSparkPlanExecApi.scala | 7 +-
.../execution/CHHashJoinExecTransformer.scala | 60 ++++++----
.../RewriteSortMergeJoinToHashJoinRule.scala | 122 +++++++++++++++++++++
.../apache/gluten/utils/CHJoinValidateUtil.scala | 12 +-
.../GlutenClickHouseTPCDSAbstractSuite.scala | 7 +-
...nClickHouseTPCDSParquetSortMergeJoinSuite.scala | 22 ++--
6 files changed, 183 insertions(+), 47 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index b8a76b421..3069c4a3f 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.backendsapi.{BackendsApiManager,
SparkPlanExecApi}
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.expression._
-import org.apache.gluten.extension.{CountDistinctWithoutExpand,
FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage,
RewriteToDateExpresstionRule}
+import org.apache.gluten.extension.{CountDistinctWithoutExpand,
FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage,
RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
import org.apache.gluten.extension.columnar.AddFallbackTagRule
import
org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
import org.apache.gluten.extension.columnar.transition.Convention
@@ -555,8 +555,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
*
* @return
*/
- override def genExtendedQueryStagePrepRules(): List[SparkSession =>
Rule[SparkPlan]] =
+ override def genExtendedQueryStagePrepRules(): List[SparkSession =>
Rule[SparkPlan]] = {
List(spark => FallbackBroadcastHashJoinPrepQueryStage(spark))
+ }
/**
* Generate extended Analyzers. Currently only for ClickHouse backend.
@@ -597,7 +598,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
* @return
*/
override def genExtendedColumnarTransformRules(): List[SparkSession =>
Rule[SparkPlan]] =
- List()
+ List(spark => RewriteSortMergeJoinToHashJoinRule(spark))
override def genInjectPostHocResolutionRules(): List[SparkSession =>
Rule[LogicalPlan]] = {
List()
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index c44156373..ed946e1d2 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -31,6 +31,35 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
import io.substrait.proto.JoinRel
+object JoinTypeTransform {
+ def toNativeJoinType(joinType: JoinType): JoinType = {
+ joinType match {
+ case ExistenceJoin(_) =>
+ LeftSemi
+ case _ =>
+ joinType
+ }
+ }
+
+ def toSubstraitType(joinType: JoinType): JoinRel.JoinType = {
+ joinType match {
+ case _: InnerLike =>
+ JoinRel.JoinType.JOIN_TYPE_INNER
+ case FullOuter =>
+ JoinRel.JoinType.JOIN_TYPE_OUTER
+ case LeftOuter | RightOuter =>
+ JoinRel.JoinType.JOIN_TYPE_LEFT
+ case LeftSemi | ExistenceJoin(_) =>
+ JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
+ case LeftAnti =>
+ JoinRel.JoinType.JOIN_TYPE_ANTI
+ case _ =>
+ // TODO: Support cross join with Cross Rel
+ JoinRel.JoinType.UNRECOGNIZED
+ }
+ }
+}
+
case class CHShuffledHashJoinExecTransformer(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
@@ -57,7 +86,7 @@ case class CHShuffledHashJoinExecTransformer(
override protected def doValidateInternal(): ValidationResult = {
val shouldFallback =
CHJoinValidateUtil.shouldFallback(
- ShuffleHashJoinStrategy(joinType),
+ ShuffleHashJoinStrategy(finalJoinType),
left.outputSet,
right.outputSet,
condition)
@@ -66,6 +95,9 @@ case class CHShuffledHashJoinExecTransformer(
}
super.doValidateInternal()
}
+ private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
+ override protected lazy val substraitJoinType: JoinRel.JoinType =
+ JoinTypeTransform.toSubstraitType(joinType)
}
case class CHBroadcastBuildSideRDD(
@@ -171,27 +203,7 @@ case class CHBroadcastHashJoinExecTransformer(
// Indeed, the ExistenceJoin is transformed into left any join in CH.
// We don't have left any join in substrait, so use left semi join instead.
// and isExistenceJoin is set to true to indicate that it is an existence
join.
- private val finalJoinType = joinType match {
- case ExistenceJoin(_) =>
- LeftSemi
- case _ =>
- joinType
- }
- override protected lazy val substraitJoinType: JoinRel.JoinType = {
- joinType match {
- case _: InnerLike =>
- JoinRel.JoinType.JOIN_TYPE_INNER
- case FullOuter =>
- JoinRel.JoinType.JOIN_TYPE_OUTER
- case LeftOuter | RightOuter =>
- JoinRel.JoinType.JOIN_TYPE_LEFT
- case LeftSemi | ExistenceJoin(_) =>
- JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
- case LeftAnti =>
- JoinRel.JoinType.JOIN_TYPE_ANTI
- case _ =>
- // TODO: Support cross join with Cross Rel
- JoinRel.JoinType.UNRECOGNIZED
- }
- }
+ private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
+ override protected lazy val substraitJoinType: JoinRel.JoinType =
+ JoinTypeTransform.toSubstraitType(joinType)
}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
new file mode 100644
index 000000000..8c5ada043
--- /dev/null
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RewriteSortMergeJoinToHashJoinRule.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.execution._
+import org.apache.gluten.utils.{CHJoinValidateUtil, ShuffleHashJoinStrategy}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.optimizer._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
+import org.apache.spark.sql.execution.joins._
+
+// import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+// If a SortMergeJoinExec cannot be offloaded, try to replace it with
ShuffledHashJoinExec
+// instead.
+// This rule is applied after spark plan nodes are transformed into
columnar ones.
+case class RewriteSortMergeJoinToHashJoinRule(session: SparkSession)
+ extends Rule[SparkPlan]
+ with Logging {
+ override def apply(plan: SparkPlan): SparkPlan = {
+ visitPlan(plan)
+ }
+
+ private def visitPlan(plan: SparkPlan): SparkPlan = {
+ plan match {
+ case smj: SortMergeJoinExec =>
+ tryReplaceSortMergeJoin(smj)
+ case other =>
+ other.withNewChildren(other.children.map(visitPlan))
+ }
+ }
+
+ private def tryReplaceSortMergeJoin(smj: SortMergeJoinExec): SparkPlan = {
+ // cannot offload SortMergeJoin, try to replace it with ShuffledHashJoin
+ val needFallback = CHJoinValidateUtil.shouldFallback(
+ ShuffleHashJoinStrategy(smj.joinType),
+ smj.left.outputSet,
+ smj.right.outputSet,
+ smj.condition)
+ // also cannot offload HashJoin, don't replace it.
+ if (needFallback) {
+ logInfo(s"Cannot offload this join by hash join algorithm")
+ return smj
+ } else {
+ replaceSortMergeJoinWithHashJoin(smj)
+ }
+ }
+
+ private def replaceSortMergeJoinWithHashJoin(smj: SortMergeJoinExec):
SparkPlan = {
+ val newLeft = replaceSortMergeJoinChild(smj.left)
+ val newRight = replaceSortMergeJoinChild(smj.right)
+ // Some cases that we cannot handle.
+ if (newLeft == null || newRight == null) {
+ logInfo("Apply on sort merge children failed")
+ return smj
+ }
+
+ var hashJoin = CHShuffledHashJoinExecTransformer(
+ smj.leftKeys,
+ smj.rightKeys,
+ smj.joinType,
+ BuildRight,
+ smj.condition,
+ newLeft,
+ newRight,
+ smj.isSkewJoin)
+ val validateResult = hashJoin.doValidate()
+ if (!validateResult.ok()) {
+ logError(s"Validation failed for ShuffledHashJoinExec:
${validateResult.reason()}")
+ return smj
+ }
+ hashJoin
+ }
+
+ private def replaceSortMergeJoinChild(plan: SparkPlan): SparkPlan = {
+ plan match {
+ case sort: SortExecTransformer =>
+ sort.child match {
+ case hashShuffle: ColumnarShuffleExchangeExec =>
+          // drop sort node, return the shuffle node directly
+ hashShuffle.withNewChildren(hashShuffle.children.map(visitPlan))
+ case aqeShuffle: AQEShuffleReadExec =>
+          // drop sort node, return the shuffle node directly
+ aqeShuffle.withNewChildren(aqeShuffle.children.map(visitPlan))
+ case columnarPlan: TransformSupport =>
+ visitPlan(columnarPlan)
+ case _ =>
+ // other cases that we don't know
+ logInfo(s"Expected ColumnarShuffleExchangeExec, got
${sort.child.getClass}")
+ null
+ }
+ case smj: SortMergeJoinExec =>
+ val newChild = replaceSortMergeJoinWithHashJoin(smj)
+ if (newChild.isInstanceOf[SortMergeJoinExec]) {
+ null
+ } else {
+ newChild
+ }
+ case _: TransformSupport => visitPlan(plan)
+ case _ =>
+ logInfo(s"Expected Columnar node, got ${plan.getClass}")
+ null
+ }
+ }
+}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
index 08b5ef5b2..0f5b5e2c4 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
@@ -18,7 +18,7 @@ package org.apache.gluten.utils
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression}
-import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans._
trait JoinStrategy {
val joinType: JoinType
@@ -54,11 +54,8 @@ object CHJoinValidateUtil extends Logging {
condition: Option[Expression]): Boolean = {
var shouldFallback = false
val joinType = joinStrategy.joinType
- if (joinType.toString.contains("ExistenceJoin")) {
- logError("Fallback for join type ExistenceJoin")
- return true
- }
- if (joinType.sql.contains("INNER")) {
+
+ if (!joinType.isInstanceOf[ExistenceJoin] &&
joinType.sql.contains("INNER")) {
shouldFallback = false;
} else if (
condition.isDefined && hasTwoTableColumn(leftOutputSet, rightOutputSet,
condition.get)
@@ -75,7 +72,8 @@ object CHJoinValidateUtil extends Logging {
} else {
shouldFallback = joinStrategy match {
case SortMergeJoinStrategy(joinTy) =>
- joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI")
+ joinTy.sql.contains("SEMI") || joinTy.sql.contains("ANTI") ||
joinTy.toString.contains(
+ "ExistenceJoin")
case _ => false
}
}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
index 6ca587beb..f2a1e5a71 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
@@ -66,12 +66,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
// q16 smj + left semi + not condition
// Q94 BroadcastHashJoin, LeftSemi, NOT condition
- if (isAqe) {
- Set(16, 94) | more
- } else {
- // q10, q35 smj + existence join
- Set(10, 16, 35, 94) | more
- }
+ Set(16, 94) | more
}
protected def excludedTpcdsQueries: Set[String] = Set(
"q66" // inconsistent results
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
index 7e480361b..509c83054 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpcds/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
@@ -16,7 +16,7 @@
*/
package org.apache.gluten.execution.tpcds
-import org.apache.gluten.execution.{CHSortMergeJoinExecTransformer,
GlutenClickHouseTPCDSAbstractSuite}
+import org.apache.gluten.execution.{CHShuffledHashJoinExecTransformer,
CHSortMergeJoinExecTransformer, GlutenClickHouseTPCDSAbstractSuite}
import org.apache.gluten.test.FallbackUtil
import org.apache.spark.SparkConf
@@ -114,7 +114,7 @@ class GlutenClickHouseTPCDSParquetSortMergeJoinSuite
extends GlutenClickHouseTPC
}
}
- test("sort merge join: left semi join should fallback") {
+ test("sort merge join: left semi join should be replaced with hash join") {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
val testSql =
"""SELECT count(*) cnt
@@ -125,12 +125,16 @@ class GlutenClickHouseTPCDSParquetSortMergeJoinSuite
extends GlutenClickHouseTPC
val smjTransformers = df.queryExecution.executedPlan.collect {
case f: CHSortMergeJoinExecTransformer => f
}
- assert(smjTransformers.isEmpty)
- assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
+ val hashJoinTransformers = df.queryExecution.executedPlan.collect {
+ case f: CHShuffledHashJoinExecTransformer => f
+ }
+ assert(smjTransformers.size == 0)
+ assert(hashJoinTransformers.size > 0)
+ assert(!FallbackUtil.hasFallback(df.queryExecution.executedPlan))
}
}
- test("sort merge join: left anti join should fallback") {
+ test("sort merge join: left anti join should be replace with hash join") {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
val testSql =
"""SELECT count(*) cnt
@@ -141,8 +145,12 @@ class GlutenClickHouseTPCDSParquetSortMergeJoinSuite
extends GlutenClickHouseTPC
val smjTransformers = df.queryExecution.executedPlan.collect {
case f: CHSortMergeJoinExecTransformer => f
}
- assert(smjTransformers.isEmpty)
- assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
+ val hashJoinTransformers = df.queryExecution.executedPlan.collect {
+ case f: CHShuffledHashJoinExecTransformer => f
+ }
+ assert(smjTransformers.size == 0)
+ assert(hashJoinTransformers.size > 0)
+ assert(!FallbackUtil.hasFallback(df.queryExecution.executedPlan))
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]