This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new d2b50ac54 [GLUTEN-4745][CH] support Sort Merge Join (#4812)
d2b50ac54 is described below
commit d2b50ac5413af9be07118952fc5abff5fe0f4c5d
Author: loudongfeng <[email protected]>
AuthorDate: Tue Apr 2 10:35:23 2024 +0800
[GLUTEN-4745][CH] support Sort Merge Join (#4812)
Support inner and outer joins
---
.../gluten/backendsapi/clickhouse/CHBackend.scala | 2 +-
.../backendsapi/clickhouse/CHListenerApi.scala | 16 ++
.../clickhouse/CHSparkPlanExecApi.scala | 20 ++
.../execution/CHSortMergeJoinExecTransformer.scala | 58 ++++++
.../apache/gluten/utils/CHJoinValidateUtil.scala | 11 +-
...nClickHouseTPCDSParquetSortMergeJoinSuite.scala | 221 +++++++++++++++++++++
.../backendsapi/velox/SparkPlanExecApiImpl.scala | 20 ++
cpp-ch/local-engine/Parser/JoinRelParser.cpp | 69 +++++--
.../gluten/backendsapi/SparkPlanExecApi.scala | 11 +
.../execution/SortMergeJoinExecTransformer.scala | 47 ++++-
.../extension/columnar/TransformHintRule.scala | 24 +--
.../columnar/transform/ImplementSingleNode.scala | 17 +-
.../utils/clickhouse/ClickHouseTestSettings.scala | 4 +
.../utils/clickhouse/ClickHouseTestSettings.scala | 11 +-
.../GlutenKeyGroupedPartitioningSuite.scala | 5 +-
.../ClickHouseAdaptiveQueryExecSuite.scala | 67 ++++++-
16 files changed, 545 insertions(+), 58 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index d85b6e269..7217a979b 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -216,7 +216,7 @@ object CHBackendSettings extends BackendSettingsApi with
Logging {
}
override def supportSortMergeJoinExec(): Boolean = {
- false
+ GlutenConfig.getConf.enableColumnarSortMergeJoin
}
override def supportWindowExec(windowFunctions: Seq[NamedExpression]):
Boolean = {
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
index 5901a11e4..17883cf24 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
@@ -25,6 +25,7 @@ import
org.apache.gluten.vectorized.{CHNativeExpressionEvaluator, JniLibLoader}
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.execution.datasources.v1._
import org.apache.commons.lang3.StringUtils
@@ -62,6 +63,21 @@ class CHListenerApi extends ListenerApi with Logging {
s".local_engine.settings.log_processors_profiles",
"true")
+ // add memory limit for external sort
+ val externalSortKey =
s"${CHBackendSettings.getBackendConfigPrefix}.runtime_settings" +
+ s".max_bytes_before_external_sort"
+ if (conf.getInt(externalSortKey, -1) < 0) {
+ if (conf.getBoolean("spark.memory.offHeap.enabled", false)) {
+ val memSize =
JavaUtils.byteStringAsBytes(conf.get("spark.memory.offHeap.size")).toInt
+ if (memSize > 0) {
+ val cores = conf.getInt("spark.executor.cores", 1)
+ val sortMemLimit = ((memSize / cores) * 0.8).toInt
+ logInfo(s"max memory for sorting: $sortMemLimit")
+ conf.set(externalSortKey, sortMemLimit.toString)
+ }
+ }
+ }
+
// Load supported hive/python/scala udfs
UDFMappings.loadFromSparkConf(conf)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index fdc303525..57fd81ba0 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -368,6 +368,26 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
right,
isNullAwareAntiJoin)
+ override def genSortMergeJoinExecTransformer(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan,
+ isSkewJoin: Boolean = false,
+ projectList: Seq[NamedExpression] = null):
SortMergeJoinExecTransformerBase =
+ CHSortMergeJoinExecTransformer(
+ leftKeys,
+ rightKeys,
+ joinType,
+ condition,
+ left,
+ right,
+ isSkewJoin,
+ projectList
+ )
+
/** Generate CartesianProductExecTransformer */
override def genCartesianProductExecTransformer(
left: SparkPlan,
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
new file mode 100644
index 000000000..a5ac5f658
--- /dev/null
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.extension.ValidationResult
+import org.apache.gluten.utils.CHJoinValidateUtil
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.execution._
+
+case class CHSortMergeJoinExecTransformer(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan,
+ isSkewJoin: Boolean = false,
+ projectList: Seq[NamedExpression] = null)
+ extends SortMergeJoinExecTransformerBase(
+ leftKeys,
+ rightKeys,
+ joinType,
+ condition,
+ left,
+ right,
+ isSkewJoin,
+ projectList) {
+
+ override protected def doValidateInternal(): ValidationResult = {
+ val shouldFallback =
+ CHJoinValidateUtil.shouldFallback(joinType, left.outputSet,
right.outputSet, condition, true)
+ if (shouldFallback) {
+ return ValidationResult.notOk("ch join validate fail")
+ }
+ super.doValidateInternal()
+ }
+
+ override protected def withNewChildrenInternal(
+ newLeft: SparkPlan,
+ newRight: SparkPlan): CHSortMergeJoinExecTransformer =
+ copy(left = newLeft, right = newRight)
+}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
index a2ecf5664..06b2445af 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
@@ -44,7 +44,8 @@ object CHJoinValidateUtil extends Logging {
joinType: JoinType,
leftOutputSet: AttributeSet,
rightOutputSet: AttributeSet,
- condition: Option[Expression]): Boolean = {
+ condition: Option[Expression],
+ isSMJ: Boolean = false): Boolean = {
var shouldFallback = false
if (joinType.toString.contains("ExistenceJoin")) {
return true
@@ -52,6 +53,14 @@ object CHJoinValidateUtil extends Logging {
if (joinType.sql.equals("INNER")) {
return shouldFallback
}
+ if (isSMJ) {
+ if (
+ joinType.sql.contains("SEMI")
+ || joinType.sql.contains("ANTI")
+ ) {
+ return true
+ }
+ }
if (condition.isDefined) {
condition.get.transform {
case Or(l, r) =>
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
new file mode 100644
index 000000000..bbedcda18
--- /dev/null
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.utils.FallbackUtil
+
+import org.apache.spark.SparkConf
+
+class GlutenClickHouseTPCDSParquetSortMergeJoinSuite extends
GlutenClickHouseTPCDSAbstractSuite {
+
+ override protected val tpcdsQueries: String =
+ rootPath +
"../../../../gluten-core/src/test/resources/tpcds-queries/tpcds.queries.original"
+ override protected val queriesResults: String = rootPath +
"tpcds-queries-output"
+
+ override protected def excludedTpcdsQueries: Set[String] = Set(
+ // fallback due to left semi/anti
+ "q8",
+ "q14a",
+ "q14b",
+ "q23a",
+ "q23b",
+ "q51",
+ "q69",
+ "q70",
+ "q78",
+ "q95",
+ "q97"
+ ) ++ super.excludedTpcdsQueries
+
+ /** Run Gluten + ClickHouse Backend with SortShuffleManager */
+ override protected def sparkConf: SparkConf = {
+ super.sparkConf
+ .set("spark.shuffle.manager", "sort")
+ .set("spark.io.compression.codec", "snappy")
+ .set("spark.sql.shuffle.partitions", "5")
+ .set("spark.sql.autoBroadcastJoinThreshold", "10MB")
+ .set("spark.memory.offHeap.size", "8g")
+ .set("spark.gluten.sql.columnar.forceShuffledHashJoin", "false")
+ }
+
+ executeTPCDSTest(false)
+
+ test("sort merge join: inner join") {
+ withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+ val testSql =
+ """SELECT count(*) cnt
+ |FROM item i join item j on j.i_category = i.i_category
+ |where
+ |i.i_current_price > 1.0 """.stripMargin
+ compareResultsAgainstVanillaSpark(
+ testSql,
+ true,
+ df => {
+ val smjTransformers = df.queryExecution.executedPlan.collect {
+ case f: CHSortMergeJoinExecTransformer => f
+ }
+ assert(smjTransformers.size == 1)
+ }
+ )
+ }
+ }
+
+ test("sort merge join: left outer join") {
+ withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+ val testSql =
+ """SELECT count(*) cnt
+ |FROM item i left outer join item j on j.i_category = i.i_category
+ """.stripMargin
+ compareResultsAgainstVanillaSpark(
+ testSql,
+ true,
+ df => {
+ val smjTransformers = df.queryExecution.executedPlan.collect {
+ case f: CHSortMergeJoinExecTransformer => f
+ }
+ assert(smjTransformers.size == 1)
+ }
+ )
+ }
+ }
+
+ test("sort merge join: right outer join") {
+ withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+ val testSql =
+ """SELECT count(*) cnt
+ |FROM item i right outer join item j on j.i_category = i.i_category
+ """.stripMargin
+ compareResultsAgainstVanillaSpark(
+ testSql,
+ true,
+ df => {
+ val smjTransformers = df.queryExecution.executedPlan.collect {
+ case f: CHSortMergeJoinExecTransformer => f
+ }
+ assert(smjTransformers.size == 1)
+ }
+ )
+ }
+ }
+
+ test("sort merge join: left semi join should fallback") {
+ withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+ val testSql =
+ """SELECT count(*) cnt
+ |FROM item i left semi join item j on j.i_category = i.i_category
+ |where
+ |i.i_current_price > 1.0 """.stripMargin
+ val df = spark.sql(testSql)
+ val smjTransformers = df.queryExecution.executedPlan.collect {
+ case f: CHSortMergeJoinExecTransformer => f
+ }
+ assert(smjTransformers.size == 0)
+ assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
+ }
+ }
+
+ test("sort merge join: left anti join should fallback") {
+ withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+ val testSql =
+ """SELECT count(*) cnt
+ |FROM item i left anti join item j on j.i_category = i.i_category
+ |where
+ |i.i_current_price > 1.0 """.stripMargin
+ val df = spark.sql(testSql)
+ val smjTransformers = df.queryExecution.executedPlan.collect {
+ case f: CHSortMergeJoinExecTransformer => f
+ }
+ assert(smjTransformers.size == 0)
+ assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
+ }
+ }
+
+ val createItem =
+ """CREATE TABLE myitem (
+ | i_current_price DECIMAL(7,2),
+ | i_category STRING)
+ |USING parquet""".stripMargin
+
+ val insertItem =
+ """insert into myitem values
+ |(null,null),
+ |(null,null),
+ |(0.63,null),
+ |(0.74,null),
+ |(null,null),
+ |(90.72,'Books'),
+ |(99.89,'Books'),
+ |(99.41,'Books')
+ |""".stripMargin
+
+ test("sort merge join: full outer join") {
+ withTable("myitem") {
+ withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+ spark.sql(createItem)
+ spark.sql(insertItem)
+ val testSql =
+ """SELECT count(*) cnt
+ |FROM myitem i full outer join myitem j on j.i_category =
i.i_category
+ """.stripMargin
+ compareResultsAgainstVanillaSpark(
+ testSql,
+ true,
+ df => {
+ val smjTransformers = df.queryExecution.executedPlan.collect {
+ case f: CHSortMergeJoinExecTransformer => f
+ }
+ assert(smjTransformers.size == 1)
+ }
+ )
+ }
+ }
+ }
+
+ test("sort merge join: nulls smallest") {
+ withTable("myitem") {
+ withSQLConf(
+ "spark.sql.autoBroadcastJoinThreshold" -> "-1",
+ "spark.sql.shuffle.partitions" -> "3") {
+ spark.sql(createItem)
+ spark.sql(insertItem)
+ val testSql =
+ """SELECT count(*) cnt
+ |FROM myitem i
+ |where
+ |i.i_current_price > 1.0 *
+ | (SELECT avg(j.i_current_price)
+ | FROM myitem j
+ | WHERE j.i_category = i.i_category
+ | ) """.stripMargin
+ spark.sql(testSql).explain()
+ spark.sql(testSql).show()
+ compareResultsAgainstVanillaSpark(
+ testSql,
+ true,
+ df => {
+ val smjTransformers = df.queryExecution.executedPlan.collect {
+ case f: CHSortMergeJoinExecTransformer => f
+ }
+ assert(smjTransformers.size == 1)
+ }
+ )
+ }
+ }
+
+ }
+
+}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala
index b723b865a..82c4bdde8 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala
@@ -313,6 +313,26 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
right,
isNullAwareAntiJoin)
+ override def genSortMergeJoinExecTransformer(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan,
+ isSkewJoin: Boolean = false,
+ projectList: Seq[NamedExpression] = null):
SortMergeJoinExecTransformerBase = {
+ SortMergeJoinExecTransformer(
+ leftKeys,
+ rightKeys,
+ joinType,
+ condition,
+ left,
+ right,
+ isSkewJoin,
+ projectList
+ )
+ }
override def genCartesianProductExecTransformer(
left: SparkPlan,
right: SparkPlan,
diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp
b/cpp-ch/local-engine/Parser/JoinRelParser.cpp
index 60cc87b74..f02a10b11 100644
--- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp
@@ -18,6 +18,7 @@
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <Interpreters/CollectJoinOnKeysVisitor.h>
+#include <Interpreters/FullSortingMergeJoin.h>
#include <Interpreters/GraceHashJoin.h>
#include <Interpreters/HashJoin.h>
#include <Interpreters/TableJoin.h>
@@ -42,8 +43,10 @@ namespace ErrorCodes
struct JoinOptimizationInfo
{
- bool is_broadcast;
- bool is_null_aware_anti_join;
+ bool is_broadcast = false;
+ bool is_smj = false;
+ bool is_null_aware_anti_join = false;
+ bool is_existence_join = false;
std::string storage_join_key;
};
@@ -53,20 +56,40 @@ JoinOptimizationInfo parseJoinOptimizationInfo(const
substrait::JoinRel & join)
{
google::protobuf::StringValue optimization;
optimization.ParseFromString(join.advanced_extension().optimization().value());
- ReadBufferFromString in(optimization.value());
- assertString("JoinParameters:", in);
- assertString("isBHJ=", in);
JoinOptimizationInfo info;
- readBoolText(info.is_broadcast, in);
- assertChar('\n', in);
- if (info.is_broadcast)
+ if (optimization.value().contains("isBHJ="))
{
- assertString("isNullAwareAntiJoin=", in);
- readBoolText(info.is_null_aware_anti_join, in);
+ ReadBufferFromString in(optimization.value());
+ assertString("JoinParameters:", in);
+ assertString("isBHJ=", in);
+ readBoolText(info.is_broadcast, in);
assertChar('\n', in);
- assertString("buildHashTableId=", in);
- readString(info.storage_join_key, in);
+ if (info.is_broadcast)
+ {
+ assertString("isNullAwareAntiJoin=", in);
+ readBoolText(info.is_null_aware_anti_join, in);
+ assertChar('\n', in);
+ assertString("buildHashTableId=", in);
+ readString(info.storage_join_key, in);
+ assertChar('\n', in);
+ }
+ }
+ else
+ {
+ ReadBufferFromString in(optimization.value());
+ assertString("JoinParameters:", in);
+ assertString("isSMJ=", in);
+ readBoolText(info.is_smj, in);
assertChar('\n', in);
+ if (info.is_smj)
+ {
+ assertString("isNullAwareAntiJoin=", in);
+ readBoolText(info.is_null_aware_anti_join, in);
+ assertChar('\n', in);
+ assertString("isExistenceJoin=", in);
+ readBoolText(info.is_existence_join, in);
+ assertChar('\n', in);
+ }
}
return info;
}
@@ -101,6 +124,8 @@ std::pair<DB::JoinKind, DB::JoinStrictness>
getJoinKindAndStrictness(substrait::
return {DB::JoinKind::Left, DB::JoinStrictness::Anti};
case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT:
return {DB::JoinKind::Left, DB::JoinStrictness::All};
+ case substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT:
+ return {DB::JoinKind::Right, DB::JoinStrictness::All};
case substrait::JoinRel_JoinType_JOIN_TYPE_OUTER:
return {DB::JoinKind::Full, DB::JoinStrictness::All};
default:
@@ -189,13 +214,29 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const
substrait::JoinRel & join, DB::Q
auto broadcast_hash_join = storage_join->getJoinLocked(table_join,
context);
QueryPlanStepPtr join_step =
std::make_unique<FilledJoinStep>(left->getCurrentDataStream(),
broadcast_hash_join, 8192);
- join_step->setStepDescription("JOIN");
+ join_step->setStepDescription("STORAGE_JOIN");
steps.emplace_back(join_step.get());
left->addStep(std::move(join_step));
query_plan = std::move(left);
/// hold right plan for profile
extra_plan_holder.emplace_back(std::move(right));
}
+ else if (join_opt_info.is_smj)
+ {
+ JoinPtr smj_join = std::make_shared<FullSortingMergeJoin>(table_join,
right->getCurrentDataStream().header.cloneEmpty(), -1);
+ MultiEnum<DB::JoinAlgorithm> join_algorithm =
context->getSettingsRef().join_algorithm;
+ QueryPlanStepPtr join_step
+ =
std::make_unique<DB::JoinStep>(left->getCurrentDataStream(),
right->getCurrentDataStream(), smj_join, 8192, 1, false);
+
+ join_step->setStepDescription("SORT_MERGE_JOIN");
+ steps.emplace_back(join_step.get());
+ std::vector<QueryPlanPtr> plans;
+ plans.emplace_back(std::move(left));
+ plans.emplace_back(std::move(right));
+
+ query_plan = std::make_unique<QueryPlan>();
+ query_plan->unitePlans(std::move(join_step), {std::move(plans)});
+ }
else
{
/// TODO: make grace hash join be the default hash join algorithm.
@@ -225,7 +266,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const
substrait::JoinRel & join, DB::Q
QueryPlanStepPtr join_step
= std::make_unique<DB::JoinStep>(left->getCurrentDataStream(),
right->getCurrentDataStream(), hash_join, 8192, 1, false);
- join_step->setStepDescription("JOIN");
+ join_step->setStepDescription("HASH_JOIN");
steps.emplace_back(join_step.get());
std::vector<QueryPlanPtr> plans;
plans.emplace_back(std::move(left));
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index c7d17c84a..59244c380 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -130,6 +130,17 @@ trait SparkPlanExecApi {
right: SparkPlan,
isNullAwareAntiJoin: Boolean = false):
BroadcastHashJoinExecTransformerBase
+ /** Generate SortMergeJoinExecTransformer. */
+ def genSortMergeJoinExecTransformer(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan,
+ isSkewJoin: Boolean = false,
+ projectList: Seq[NamedExpression] = null):
SortMergeJoinExecTransformerBase
+
/** Generate CartesianProductExecTransformer. */
def genCartesianProductExecTransformer(
left: SparkPlan,
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala
index ffc007326..5ca11a53c 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala
@@ -25,23 +25,26 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.joins.BaseJoinExec
import org.apache.spark.sql.vectorized.ColumnarBatch
import com.google.protobuf.{Any, StringValue}
import io.substrait.proto.JoinRel
-/** Performs a sort merge join of two child relations. */
-case class SortMergeJoinExecTransformer(
+trait MergeJoinLikeExecTransformer
+ extends BaseJoinExec
+ with TransformSupport
+ with ColumnarShuffledJoin {}
+abstract class SortMergeJoinExecTransformerBase(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan,
- isSkewJoin: Boolean = false)
- extends ColumnarShuffledJoin
- with TransformSupport {
-
+ isSkewJoin: Boolean = false,
+ projectList: Seq[NamedExpression] = null)
+ extends MergeJoinLikeExecTransformer {
// Note: "metrics" is made transient to avoid sending driver-side metrics to
tasks.
@transient override lazy val metrics =
BackendsApiManager.getMetricsApiInstance.genSortMergeJoinTransformerMetrics(sparkContext)
@@ -229,6 +232,38 @@ case class SortMergeJoinExecTransformer(
JoinUtils.createTransformContext(false, output, joinRel,
inputStreamedOutput, inputBuildOutput)
}
+}
+
+/** Performs a sort merge join of two child relations. */
+case class SortMergeJoinExecTransformer(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan,
+ isSkewJoin: Boolean = false,
+ projectList: Seq[NamedExpression] = null)
+ extends SortMergeJoinExecTransformerBase(
+ leftKeys,
+ rightKeys,
+ joinType,
+ condition,
+ left,
+ right,
+ isSkewJoin,
+ projectList) {
+
+ override protected def doValidateInternal(): ValidationResult = {
+ val substraitContext = new SubstraitContext
+ // Firstly, need to check if the Substrait plan for this operator can be
successfully generated.
+ if (substraitJoinType == JoinRel.JoinType.JOIN_TYPE_OUTER) {
+ return ValidationResult
+ .notOk(s"Found unsupported join type of $joinType for velox smj:
$substraitJoinType")
+ }
+ super.doValidateInternal()
+ }
+
override protected def withNewChildrenInternal(
newLeft: SparkPlan,
newRight: SparkPlan): SortMergeJoinExecTransformer =
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
index aa68ce72e..1dc56d254 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
@@ -29,7 +29,6 @@ import org.apache.spark.api.python.EvalPythonExecTransformer
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
Expression}
-import org.apache.spark.sql.catalyst.plans.FullOuter
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution._
@@ -619,19 +618,18 @@ case class AddTransformHintRule() extends Rule[SparkPlan]
{
transformer.doValidate().tagOnFallback(plan)
}
case plan: SortMergeJoinExec =>
- if (!enableColumnarSortMergeJoin || plan.joinType == FullOuter) {
- TransformHints.tagNotTransformable(
- plan,
- "columnar sort merge join is not enabled or join type is
FullOuter")
+ if (!enableColumnarSortMergeJoin) {
+ TransformHints.tagNotTransformable(plan, "columnar sort merge join
is not enabled")
} else {
- val transformer = SortMergeJoinExecTransformer(
- plan.leftKeys,
- plan.rightKeys,
- plan.joinType,
- plan.condition,
- plan.left,
- plan.right,
- plan.isSkewJoin)
+ val transformer = BackendsApiManager.getSparkPlanExecApiInstance
+ .genSortMergeJoinExecTransformer(
+ plan.leftKeys,
+ plan.rightKeys,
+ plan.joinType,
+ plan.condition,
+ plan.left,
+ plan.right,
+ plan.isSkewJoin)
transformer.doValidate().tagOnFallback(plan)
}
case plan: CartesianProductExec =>
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transform/ImplementSingleNode.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transform/ImplementSingleNode.scala
index 22edfaa5d..5820de270 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transform/ImplementSingleNode.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transform/ImplementSingleNode.scala
@@ -192,14 +192,15 @@ case class ImplementJoin() extends ImplementSingleNode
with LogLevelUtil {
val left = plan.left
val right = plan.right
logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
- SortMergeJoinExecTransformer(
- plan.leftKeys,
- plan.rightKeys,
- plan.joinType,
- plan.condition,
- left,
- right,
- plan.isSkewJoin)
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genSortMergeJoinExecTransformer(
+ plan.leftKeys,
+ plan.rightKeys,
+ plan.joinType,
+ plan.condition,
+ left,
+ right,
+ plan.isSkewJoin)
case plan: BroadcastHashJoinExec =>
val left = plan.left
val right = plan.right
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index bfecd2292..4c1f8e65e 100644
---
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -1187,6 +1187,10 @@ class ClickHouseTestSettings extends BackendTestSettings
{
.exclude("SPARK-36020: Check logical link in remove redundant projects")
.exclude("SPARK-36032: Use inputPlan instead of currentPhysicalPlan to
initialize logical link")
.exclude("SPARK-37742: AQE reads invalid InMemoryRelation stats and
mistakenly plans BHJ")
+ // SMJ Exec have changed to CH SMJ Transformer
+ .exclude("Change broadcast join to merge join")
+ .exclude("Avoid plan change if cost is greater")
+ .exclude("SPARK-30524: Do not optimize skew join if introduce additional
shuffle")
.excludeGlutenTest("Change broadcast join to merge join")
.excludeGlutenTest("Empty stage coalesced to 1-partition RDD")
.excludeGlutenTest(
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 385ef6381..be6273857 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -1232,12 +1232,17 @@ class ClickHouseTestSettings extends
BackendTestSettings {
.exclude("SPARK-37742: AQE reads invalid InMemoryRelation stats and
mistakenly plans BHJ")
.exclude("SPARK-37328: skew join with 3 tables")
.exclude("SPARK-39915: Dataset.repartition(N) may not create N partitions")
- .excludeGlutenTest("Change broadcast join to merge join")
- .excludeGlutenTest("Change broadcast join to merge join")
+ .exclude("Change broadcast join to merge join")
+ .exclude("Avoid plan change if cost is greater")
+ .exclude("SPARK-37652: optimize skewed join through union")
+ .exclude("SPARK-35455: Unify empty relation optimization between normal
and AQE optimizer " +
+ "- single join")
+ .exclude("SPARK-35455: Unify empty relation optimization between normal
and AQE optimizer " +
+ "- multi join")
.excludeGlutenTest("Empty stage coalesced to 1-partition RDD")
.excludeGlutenTest(
"Avoid changing merge join to broadcast join if too many empty
partitions on build plan")
- .excludeGlutenTest("SPARK-30524: Do not optimize skew join if introduce
additional shuffle")
+ .exclude("SPARK-30524: Do not optimize skew join if introduce additional
shuffle")
.excludeGlutenTest("SPARK-33551: Do not use AQE shuffle read for
repartition")
.excludeGlutenTest("SPARK-35264: Support AQE side broadcastJoin threshold")
.excludeGlutenTest("SPARK-35264: Support AQE side shuffled hash join
formula")
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala
index 12ce843cc..8bc67f6f8 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala
@@ -15,7 +15,8 @@
* limitations under the License.
*/
package org.apache.spark.sql.connector
-import org.apache.gluten.execution.SortMergeJoinExecTransformer
+
+import org.apache.gluten.execution.SortMergeJoinExecTransformerBase
import org.apache.spark.SparkConf
import org.apache.spark.sql.GlutenSQLTestsBaseTrait
@@ -109,7 +110,7 @@ class GlutenKeyGroupedPartitioningSuite
plan: SparkPlan): Seq[ColumnarShuffleExchangeExec] = {
// here we skip collecting shuffle operators that are not associated with
SMJ
collect(plan) {
- case s: SortMergeJoinExecTransformer => s
+ case s: SortMergeJoinExecTransformerBase => s
case s: SortMergeJoinExec => s
}.flatMap(smj => collect(smj) { case s: ColumnarShuffleExchangeExec => s })
}
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/clickhouse/ClickHouseAdaptiveQueryExecSuite.scala
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/clickhouse/ClickHouseAdaptiveQueryExecSuite.scala
index 091abe11d..9ddaac185 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/clickhouse/ClickHouseAdaptiveQueryExecSuite.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/clickhouse/ClickHouseAdaptiveQueryExecSuite.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution.adaptive.clickhouse
-import org.apache.gluten.execution.{BroadcastHashJoinExecTransformerBase, ShuffledHashJoinExecTransformerBase, SortExecTransformer, SortMergeJoinExecTransformer}
+import org.apache.gluten.execution.{BroadcastHashJoinExecTransformerBase, ShuffledHashJoinExecTransformerBase, SortExecTransformer, SortMergeJoinExecTransformerBase}
import org.apache.spark.SparkConf
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
@@ -102,8 +102,8 @@ class ClickHouseAdaptiveQueryExecSuite extends
AdaptiveQueryExecSuite with Glute
}
private def findTopLevelSortMergeJoinTransform(
- plan: SparkPlan): Seq[SortMergeJoinExecTransformer] = {
- collect(plan) { case j: SortMergeJoinExecTransformer => j }
+ plan: SparkPlan): Seq[SortMergeJoinExecTransformerBase] = {
+ collect(plan) { case j: SortMergeJoinExecTransformerBase => j }
}
private def sortMergeJoinSize(plan: SparkPlan): Int = {
@@ -260,7 +260,7 @@ class ClickHouseAdaptiveQueryExecSuite extends
AdaptiveQueryExecSuite with Glute
.count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
- assert(find(plan)(_.isInstanceOf[SortMergeJoinExecTransformer]).isDefined)
+ assert(find(plan)(_.isInstanceOf[SortMergeJoinExecTransformerBase]).isDefined)
}
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
@@ -671,7 +671,7 @@ class ClickHouseAdaptiveQueryExecSuite extends
AdaptiveQueryExecSuite with Glute
rightSkewNum: Int): Unit = {
assert(joins.size == 1)
joins.head match {
- case s: SortMergeJoinExecTransformer => assert(s.isSkewJoin)
+ case s: SortMergeJoinExecTransformerBase => assert(s.isSkewJoin)
case g: ShuffledHashJoinExecTransformerBase =>
assert(g.isSkewJoin)
case _ => assert(false)
}
@@ -804,11 +804,6 @@ class ClickHouseAdaptiveQueryExecSuite extends
AdaptiveQueryExecSuite with Glute
ignore(
GLUTEN_TEST + "SPARK-32573: Eliminate NAAJ when BuildSide is HashedRelationWithAllNullKeys") {}
- // EmptyRelation case
- ignore(
- GLUTEN_TEST + "SPARK-35455: Unify empty relation optimization " +
- "between normal and AQE optimizer - single join") {}
-
testGluten("SPARK-32753: Only copy tags to node with no tags") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
withTempView("v1") {
@@ -1511,4 +1506,56 @@ class ClickHouseAdaptiveQueryExecSuite extends
AdaptiveQueryExecSuite with Glute
}
}
}
+
+ testGluten("SPARK-37652: optimize skewed join through union") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100",
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100"
+ ) {
+ withTempView("skewData1", "skewData2") {
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 3 as key1", "id as value1")
+ .createOrReplaceTempView("skewData1")
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 1 as key2", "id as value2")
+ .createOrReplaceTempView("skewData2")
+
+ def checkSkewJoin(query: String, joinNums: Int, optimizeSkewJoinNums: Int): Unit = {
+ val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(query)
+ val joins = findTopLevelSortMergeJoinTransform(innerAdaptivePlan)
+ val optimizeSkewJoins = joins.filter(_.isSkewJoin)
+ assert(joins.size == joinNums && optimizeSkewJoins.size == optimizeSkewJoinNums)
+ }
+
+ // skewJoin union skewJoin
+ checkSkewJoin(
+ "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
+ "UNION ALL SELECT key2 FROM skewData1 JOIN skewData2 ON key1 = key2",
+ 2,
+ 2)
+
+ // skewJoin union aggregate
+ checkSkewJoin(
+ "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
+ "UNION ALL SELECT key2 FROM skewData2 GROUP BY key2",
+ 1,
+ 1)
+
+ // skewJoin1 union (skewJoin2 join aggregate)
+ // skewJoin2 will lead to extra shuffles, but skew1 cannot be optimized
+ checkSkewJoin(
+ "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 UNION ALL " +
+ "SELECT key1 from (SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2) tmp1 " +
+ "JOIN (SELECT key2 FROM skewData2 GROUP BY key2) tmp2 ON key1 = key2",
+ 3,
+ 0
+ )
+ }
+ }
+ }
+
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]