This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new d2b50ac54 [GLUTEN-4745][CH] support Sort Merge Join (#4812)
d2b50ac54 is described below

commit d2b50ac5413af9be07118952fc5abff5fe0f4c5d
Author: loudongfeng <[email protected]>
AuthorDate: Tue Apr 2 10:35:23 2024 +0800

    [GLUTEN-4745][CH] support Sort Merge Join (#4812)
    
    Support inner and outer joins
---
 .../gluten/backendsapi/clickhouse/CHBackend.scala  |   2 +-
 .../backendsapi/clickhouse/CHListenerApi.scala     |  16 ++
 .../clickhouse/CHSparkPlanExecApi.scala            |  20 ++
 .../execution/CHSortMergeJoinExecTransformer.scala |  58 ++++++
 .../apache/gluten/utils/CHJoinValidateUtil.scala   |  11 +-
 ...nClickHouseTPCDSParquetSortMergeJoinSuite.scala | 221 +++++++++++++++++++++
 .../backendsapi/velox/SparkPlanExecApiImpl.scala   |  20 ++
 cpp-ch/local-engine/Parser/JoinRelParser.cpp       |  69 +++++--
 .../gluten/backendsapi/SparkPlanExecApi.scala      |  11 +
 .../execution/SortMergeJoinExecTransformer.scala   |  47 ++++-
 .../extension/columnar/TransformHintRule.scala     |  24 +--
 .../columnar/transform/ImplementSingleNode.scala   |  17 +-
 .../utils/clickhouse/ClickHouseTestSettings.scala  |   4 +
 .../utils/clickhouse/ClickHouseTestSettings.scala  |  11 +-
 .../GlutenKeyGroupedPartitioningSuite.scala        |   5 +-
 .../ClickHouseAdaptiveQueryExecSuite.scala         |  67 ++++++-
 16 files changed, 545 insertions(+), 58 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index d85b6e269..7217a979b 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -216,7 +216,7 @@ object CHBackendSettings extends BackendSettingsApi with 
Logging {
   }
 
   override def supportSortMergeJoinExec(): Boolean = {
-    false
+    GlutenConfig.getConf.enableColumnarSortMergeJoin // gated on the columnar SMJ switch
   }
 
   override def supportWindowExec(windowFunctions: Seq[NamedExpression]): 
Boolean = {
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
index 5901a11e4..17883cf24 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
@@ -25,6 +25,7 @@ import 
org.apache.gluten.vectorized.{CHNativeExpressionEvaluator, JniLibLoader}
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.sql.execution.datasources.v1._
 
 import org.apache.commons.lang3.StringUtils
@@ -62,6 +63,21 @@ class CHListenerApi extends ListenerApi with Logging {
         s".local_engine.settings.log_processors_profiles",
       "true")
 
+    // add memory limit for external sort; byte sizes stay Long (8g would overflow Int to 0)
+    val externalSortKey = s"${CHBackendSettings.getBackendConfigPrefix}.runtime_settings" +
+      s".max_bytes_before_external_sort"
+    if (conf.getLong(externalSortKey, -1L) < 0) {
+      if (conf.getBoolean("spark.memory.offHeap.enabled", false)) {
+        val memSize = JavaUtils.byteStringAsBytes(conf.get("spark.memory.offHeap.size", "0"))
+        if (memSize > 0) {
+          val cores = math.max(conf.getInt("spark.executor.cores", 1), 1)
+          val sortMemLimit = ((memSize / cores) * 0.8).toLong
+          logInfo(s"max memory for sorting: $sortMemLimit")
+          conf.set(externalSortKey, sortMemLimit.toString)
+        }
+      }
+    }
+
     // Load supported hive/python/scala udfs
     UDFMappings.loadFromSparkConf(conf)
 
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index fdc303525..57fd81ba0 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -368,6 +368,26 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
       right,
       isNullAwareAntiJoin)
 
+  /** ClickHouse backend: build the CH-specific transformer, which adds CH join validation. */
+  override def genSortMergeJoinExecTransformer(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      joinType: JoinType,
+      condition: Option[Expression],
+      left: SparkPlan,
+      right: SparkPlan,
+      isSkewJoin: Boolean = false,
+      projectList: Seq[NamedExpression] = null): SortMergeJoinExecTransformerBase =
+    CHSortMergeJoinExecTransformer(
+      leftKeys = leftKeys,
+      rightKeys = rightKeys,
+      joinType = joinType,
+      condition = condition,
+      left = left,
+      right = right,
+      isSkewJoin = isSkewJoin,
+      projectList = projectList)
+
   /** Generate CartesianProductExecTransformer */
   override def genCartesianProductExecTransformer(
       left: SparkPlan,
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
new file mode 100644
index 000000000..a5ac5f658
--- /dev/null
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHSortMergeJoinExecTransformer.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.extension.ValidationResult
+import org.apache.gluten.utils.CHJoinValidateUtil
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.execution._
+
+/** ClickHouse backend implementation of the columnar sort merge join transformer. */
+case class CHSortMergeJoinExecTransformer(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    condition: Option[Expression],
+    left: SparkPlan,
+    right: SparkPlan,
+    isSkewJoin: Boolean = false,
+    projectList: Seq[NamedExpression] = null)
+  extends SortMergeJoinExecTransformerBase(
+    leftKeys,
+    rightKeys,
+    joinType,
+    condition,
+    left,
+    right,
+    isSkewJoin,
+    projectList) {
+
+  /** Falls back if CHJoinValidateUtil rejects this SMJ (e.g. semi/anti), else runs base checks. */
+  override protected def doValidateInternal(): ValidationResult = {
+    val unsupportedByCH = CHJoinValidateUtil
+      .shouldFallback(joinType, left.outputSet, right.outputSet, condition, isSMJ = true)
+    if (unsupportedByCH) ValidationResult.notOk("ch join validate fail")
+    else super.doValidateInternal()
+  }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan,
+      newRight: SparkPlan): CHSortMergeJoinExecTransformer =
+    copy(left = newLeft, right = newRight)
+}
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
index a2ecf5664..06b2445af 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala
@@ -44,7 +44,8 @@ object CHJoinValidateUtil extends Logging {
       joinType: JoinType,
       leftOutputSet: AttributeSet,
       rightOutputSet: AttributeSet,
-      condition: Option[Expression]): Boolean = {
+      condition: Option[Expression],
+      isSMJ: Boolean = false): Boolean = {
     var shouldFallback = false
     if (joinType.toString.contains("ExistenceJoin")) {
       return true
@@ -52,6 +53,14 @@ object CHJoinValidateUtil extends Logging {
     if (joinType.sql.equals("INNER")) {
       return shouldFallback
     }
+    if (isSMJ) {
+      // The CH sort merge join cannot run semi/anti joins, so they always fall back.
+      val isSemiOrAnti =
+        joinType.sql.contains("SEMI") || joinType.sql.contains("ANTI")
+      if (isSemiOrAnti) {
+        return true
+      }
+    }
     if (condition.isDefined) {
       condition.get.transform {
         case Or(l, r) =>
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
new file mode 100644
index 000000000..bbedcda18
--- /dev/null
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSParquetSortMergeJoinSuite.scala
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.utils.FallbackUtil
+
+import org.apache.spark.SparkConf
+
+/** TPC-DS and targeted join tests for the ClickHouse sort merge join transformer. */
+class GlutenClickHouseTPCDSParquetSortMergeJoinSuite extends GlutenClickHouseTPCDSAbstractSuite {
+
+  override protected val tpcdsQueries: String =
+    rootPath + "../../../../gluten-core/src/test/resources/tpcds-queries/tpcds.queries.original"
+  override protected val queriesResults: String = rootPath + "tpcds-queries-output"
+
+  override protected def excludedTpcdsQueries: Set[String] = Set(
+    // fallback due to left semi/anti
+    "q8",
+    "q14a",
+    "q14b",
+    "q23a",
+    "q23b",
+    "q51",
+    "q69",
+    "q70",
+    "q78",
+    "q95",
+    "q97"
+  ) ++ super.excludedTpcdsQueries
+
+  /** Run Gluten + ClickHouse Backend with SortShuffleManager */
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf
+      .set("spark.shuffle.manager", "sort")
+      .set("spark.io.compression.codec", "snappy")
+      .set("spark.sql.shuffle.partitions", "5")
+      .set("spark.sql.autoBroadcastJoinThreshold", "10MB")
+      .set("spark.memory.offHeap.size", "8g")
+      .set("spark.gluten.sql.columnar.forceShuffledHashJoin", "false")
+  }
+
+  executeTPCDSTest(false)
+
+  test("sort merge join: inner join") {
+    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+      val testSql =
+        """SELECT  count(*) cnt
+          |FROM item i join item j on j.i_category = i.i_category
+          |where
+          |i.i_current_price > 1.0 """.stripMargin
+      compareResultsAgainstVanillaSpark(
+        testSql,
+        true,
+        df => {
+          val smjTransformers = df.queryExecution.executedPlan.collect {
+            case f: CHSortMergeJoinExecTransformer => f
+          }
+          assert(smjTransformers.size == 1)
+        }
+      )
+    }
+  }
+
+  test("sort merge join: left outer join") {
+    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+      val testSql =
+        """SELECT  count(*) cnt
+          |FROM item i left outer join item j on j.i_category = i.i_category
+        """.stripMargin
+      compareResultsAgainstVanillaSpark(
+        testSql,
+        true,
+        df => {
+          val smjTransformers = df.queryExecution.executedPlan.collect {
+            case f: CHSortMergeJoinExecTransformer => f
+          }
+          assert(smjTransformers.size == 1)
+        }
+      )
+    }
+  }
+
+  test("sort merge join: right outer join") {
+    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+      val testSql =
+        """SELECT  count(*) cnt
+          |FROM item i right outer join item j on j.i_category = i.i_category
+        """.stripMargin
+      compareResultsAgainstVanillaSpark(
+        testSql,
+        true,
+        df => {
+          val smjTransformers = df.queryExecution.executedPlan.collect {
+            case f: CHSortMergeJoinExecTransformer => f
+          }
+          assert(smjTransformers.size == 1)
+        }
+      )
+    }
+  }
+
+  test("sort merge join: left semi join should fallback") {
+    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+      val testSql =
+        """SELECT  count(*) cnt
+          |FROM item i left semi join item j on j.i_category = i.i_category
+          |where
+          |i.i_current_price > 1.0 """.stripMargin
+      val df = spark.sql(testSql)
+      val smjTransformers = df.queryExecution.executedPlan.collect {
+        case f: CHSortMergeJoinExecTransformer => f
+      }
+      assert(smjTransformers.size == 0)
+      assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
+    }
+  }
+
+  test("sort merge join: left anti join should fallback") {
+    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+      val testSql =
+        """SELECT  count(*) cnt
+          |FROM item i left anti join item j on j.i_category = i.i_category
+          |where
+          |i.i_current_price > 1.0 """.stripMargin
+      val df = spark.sql(testSql)
+      val smjTransformers = df.queryExecution.executedPlan.collect {
+        case f: CHSortMergeJoinExecTransformer => f
+      }
+      assert(smjTransformers.size == 0)
+      assert(FallbackUtil.hasFallback(df.queryExecution.executedPlan))
+    }
+  }
+
+  val createItem =
+    """CREATE TABLE myitem (
+      |  i_current_price DECIMAL(7,2),
+      |  i_category STRING)
+      |USING parquet""".stripMargin
+
+  val insertItem =
+    """insert into myitem values
+      |(null,null),
+      |(null,null),
+      |(0.63,null),
+      |(0.74,null),
+      |(null,null),
+      |(90.72,'Books'),
+      |(99.89,'Books'),
+      |(99.41,'Books')
+      |""".stripMargin
+
+  test("sort merge join: full outer join") {
+    withTable("myitem") {
+      withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+        spark.sql(createItem)
+        spark.sql(insertItem)
+        val testSql =
+          """SELECT  count(*) cnt
+            |FROM myitem i full outer join myitem j on j.i_category = i.i_category
+          """.stripMargin
+        compareResultsAgainstVanillaSpark(
+          testSql,
+          true,
+          df => {
+            val smjTransformers = df.queryExecution.executedPlan.collect {
+              case f: CHSortMergeJoinExecTransformer => f
+            }
+            assert(smjTransformers.size == 1)
+          }
+        )
+      }
+    }
+  }
+
+  test("sort merge join: nulls smallest") {
+    withTable("myitem") {
+      withSQLConf(
+        "spark.sql.autoBroadcastJoinThreshold" -> "-1",
+        "spark.sql.shuffle.partitions" -> "3") {
+        spark.sql(createItem)
+        spark.sql(insertItem)
+        val testSql =
+          """SELECT  count(*) cnt
+            |FROM myitem i
+            |where
+            |i.i_current_price > 1.0 *
+            |  (SELECT avg(j.i_current_price)
+            |  FROM myitem j
+            |  WHERE j.i_category = i.i_category
+            | ) """.stripMargin
+        // insertItem has NULL i_category rows, which must never match in the join; the
+        // correlated subquery is expected to plan as exactly one CH sort merge join.
+        compareResultsAgainstVanillaSpark(
+          testSql,
+          true,
+          df => {
+            val smjTransformers = df.queryExecution.executedPlan.collect {
+              case f: CHSortMergeJoinExecTransformer => f
+            }
+            assert(smjTransformers.size == 1)
+          }
+        )
+      }
+    }
+  }
+
+}
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala
index b723b865a..82c4bdde8 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala
@@ -313,6 +313,26 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
       right,
       isNullAwareAntiJoin)
 
+  /** Velox backend: the generic SortMergeJoinExecTransformer, which rejects full outer joins. */
+  override def genSortMergeJoinExecTransformer(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      joinType: JoinType,
+      condition: Option[Expression],
+      left: SparkPlan,
+      right: SparkPlan,
+      isSkewJoin: Boolean = false,
+      projectList: Seq[NamedExpression] = null): SortMergeJoinExecTransformerBase =
+    SortMergeJoinExecTransformer(
+      leftKeys = leftKeys,
+      rightKeys = rightKeys,
+      joinType = joinType,
+      condition = condition,
+      left = left,
+      right = right,
+      isSkewJoin = isSkewJoin,
+      projectList = projectList)
+
   override def genCartesianProductExecTransformer(
       left: SparkPlan,
       right: SparkPlan,
diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp 
b/cpp-ch/local-engine/Parser/JoinRelParser.cpp
index 60cc87b74..f02a10b11 100644
--- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp
@@ -18,6 +18,7 @@
 #include <IO/ReadBufferFromString.h>
 #include <IO/ReadHelpers.h>
 #include <Interpreters/CollectJoinOnKeysVisitor.h>
+#include <Interpreters/FullSortingMergeJoin.h>
 #include <Interpreters/GraceHashJoin.h>
 #include <Interpreters/HashJoin.h>
 #include <Interpreters/TableJoin.h>
@@ -42,8 +43,10 @@ namespace ErrorCodes
 
 struct JoinOptimizationInfo
 {
-    bool is_broadcast;
-    bool is_null_aware_anti_join;
+    bool is_broadcast = false; // parsed from "isBHJ="
+    bool is_smj = false; // parsed from "isSMJ=" (sort merge join parameters)
+    bool is_null_aware_anti_join = false; // parsed from "isNullAwareAntiJoin="
+    bool is_existence_join = false; // parsed from "isExistenceJoin=" (only in SMJ parameters)
     std::string storage_join_key;
 };
 
@@ -53,20 +56,40 @@ JoinOptimizationInfo parseJoinOptimizationInfo(const 
substrait::JoinRel & join)
 {
     google::protobuf::StringValue optimization;
     
optimization.ParseFromString(join.advanced_extension().optimization().value());
-    ReadBufferFromString in(optimization.value());
-    assertString("JoinParameters:", in);
-    assertString("isBHJ=", in);
     JoinOptimizationInfo info;
-    readBoolText(info.is_broadcast, in);
-    assertChar('\n', in);
-    if (info.is_broadcast)
+    if (optimization.value().contains("isBHJ=")) // layout "JoinParameters:isBHJ=<bool>\n..."
     {
-        assertString("isNullAwareAntiJoin=", in);
-        readBoolText(info.is_null_aware_anti_join, in);
+        ReadBufferFromString in(optimization.value());
+        assertString("JoinParameters:", in);
+        assertString("isBHJ=", in);
+        readBoolText(info.is_broadcast, in);
         assertChar('\n', in);
-        assertString("buildHashTableId=", in);
-        readString(info.storage_join_key, in);
+        if (info.is_broadcast) // further fields are only parsed when isBHJ is true
+        {
+            assertString("isNullAwareAntiJoin=", in);
+            readBoolText(info.is_null_aware_anti_join, in);
+            assertChar('\n', in);
+            assertString("buildHashTableId=", in);
+            readString(info.storage_join_key, in);
+            assertChar('\n', in);
+        }
+    }
+    else // without "isBHJ=" the parameters must start with "isSMJ=" (sort merge join)
+    {
+        ReadBufferFromString in(optimization.value());
+        assertString("JoinParameters:", in);
+        assertString("isSMJ=", in);
+        readBoolText(info.is_smj, in);
         assertChar('\n', in);
+        if (info.is_smj) // further fields are only parsed when isSMJ is true
+        {
+            assertString("isNullAwareAntiJoin=", in);
+            readBoolText(info.is_null_aware_anti_join, in);
+            assertChar('\n', in);
+            assertString("isExistenceJoin=", in);
+            readBoolText(info.is_existence_join, in);
+            assertChar('\n', in);
+        }
     }
     return info;
 }
@@ -101,6 +124,8 @@ std::pair<DB::JoinKind, DB::JoinStrictness> 
getJoinKindAndStrictness(substrait::
             return {DB::JoinKind::Left, DB::JoinStrictness::Anti};
         case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT:
             return {DB::JoinKind::Left, DB::JoinStrictness::All};
+        case substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT:
+            return {DB::JoinKind::Right, DB::JoinStrictness::All};
         case substrait::JoinRel_JoinType_JOIN_TYPE_OUTER:
             return {DB::JoinKind::Full, DB::JoinStrictness::All};
         default:
@@ -189,13 +214,29 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const 
substrait::JoinRel & join, DB::Q
         auto broadcast_hash_join = storage_join->getJoinLocked(table_join, 
context);
         QueryPlanStepPtr join_step = 
std::make_unique<FilledJoinStep>(left->getCurrentDataStream(), 
broadcast_hash_join, 8192);
 
-        join_step->setStepDescription("JOIN");
+        join_step->setStepDescription("STORAGE_JOIN");
         steps.emplace_back(join_step.get());
         left->addStep(std::move(join_step));
         query_plan = std::move(left);
         /// hold right plan for profile
         extra_plan_holder.emplace_back(std::move(right));
     }
+    else if (join_opt_info.is_smj)
+    {
+        /// null_direction = -1: NULL keys sort first, matching Spark's ascending NULLS FIRST SMJ input order.
+        JoinPtr smj_join = std::make_shared<FullSortingMergeJoin>(table_join, right->getCurrentDataStream().header.cloneEmpty(), -1);
+        QueryPlanStepPtr join_step
+            = std::make_unique<DB::JoinStep>(left->getCurrentDataStream(), right->getCurrentDataStream(), smj_join, 8192, 1, false);
+
+        join_step->setStepDescription("SORT_MERGE_JOIN");
+        steps.emplace_back(join_step.get());
+        std::vector<QueryPlanPtr> plans;
+        plans.emplace_back(std::move(left));
+        plans.emplace_back(std::move(right));
+
+        query_plan = std::make_unique<QueryPlan>();
+        query_plan->unitePlans(std::move(join_step), {std::move(plans)});
+    }
     else
     {
         /// TODO: make grace hash join be the default hash join algorithm.
@@ -225,7 +266,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const 
substrait::JoinRel & join, DB::Q
         QueryPlanStepPtr join_step
             = std::make_unique<DB::JoinStep>(left->getCurrentDataStream(), 
right->getCurrentDataStream(), hash_join, 8192, 1, false);
 
-        join_step->setStepDescription("JOIN");
+        join_step->setStepDescription("HASH_JOIN");
         steps.emplace_back(join_step.get());
         std::vector<QueryPlanPtr> plans;
         plans.emplace_back(std::move(left));
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index c7d17c84a..59244c380 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -130,6 +130,17 @@ trait SparkPlanExecApi {
       right: SparkPlan,
       isNullAwareAntiJoin: Boolean = false): 
BroadcastHashJoinExecTransformerBase
 
+  /** Generate SortMergeJoinExecTransformer. */
+  def genSortMergeJoinExecTransformer(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      joinType: JoinType,
+      condition: Option[Expression],
+      left: SparkPlan,
+      right: SparkPlan,
+      isSkewJoin: Boolean = false,
+      projectList: Seq[NamedExpression] = null): SortMergeJoinExecTransformerBase
+
   /** Generate CartesianProductExecTransformer. */
   def genCartesianProductExecTransformer(
       left: SparkPlan,
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala
index ffc007326..5ca11a53c 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/execution/SortMergeJoinExecTransformer.scala
@@ -25,23 +25,26 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.joins.BaseJoinExec
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import com.google.protobuf.{Any, StringValue}
 import io.substrait.proto.JoinRel
 
-/** Performs a sort merge join of two child relations. */
-case class SortMergeJoinExecTransformer(
+/** Mixes BaseJoinExec, TransformSupport and ColumnarShuffledJoin for merge-join transformers. */
+trait MergeJoinLikeExecTransformer
+  extends BaseJoinExec with TransformSupport with ColumnarShuffledJoin
+
+abstract class SortMergeJoinExecTransformerBase(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     joinType: JoinType,
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan,
-    isSkewJoin: Boolean = false)
-  extends ColumnarShuffledJoin
-  with TransformSupport {
-
+    isSkewJoin: Boolean = false,
+    projectList: Seq[NamedExpression] = null)
+  extends MergeJoinLikeExecTransformer {
   // Note: "metrics" is made transient to avoid sending driver-side metrics to 
tasks.
   @transient override lazy val metrics =
     
BackendsApiManager.getMetricsApiInstance.genSortMergeJoinTransformerMetrics(sparkContext)
@@ -229,6 +232,38 @@ case class SortMergeJoinExecTransformer(
     JoinUtils.createTransformContext(false, output, joinRel, 
inputStreamedOutput, inputBuildOutput)
   }
 
+}
+
+/** Performs a sort merge join of two child relations. */
+case class SortMergeJoinExecTransformer(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    condition: Option[Expression],
+    left: SparkPlan,
+    right: SparkPlan,
+    isSkewJoin: Boolean = false,
+    projectList: Seq[NamedExpression] = null)
+  extends SortMergeJoinExecTransformerBase(
+    leftKeys,
+    rightKeys,
+    joinType,
+    condition,
+    left,
+    right,
+    isSkewJoin,
+    projectList) {
+
+  override protected def doValidateInternal(): ValidationResult = {
+    // Velox SMJ does not support full outer joins (JOIN_TYPE_OUTER); fail validation for them.
+    if (substraitJoinType == JoinRel.JoinType.JOIN_TYPE_OUTER) {
+      ValidationResult
+        .notOk(s"Found unsupported join type of $joinType for velox smj: $substraitJoinType")
+    } else {
+      super.doValidateInternal()
+    }
+  }
+
   override protected def withNewChildrenInternal(
       newLeft: SparkPlan,
       newRight: SparkPlan): SortMergeJoinExecTransformer =
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
index aa68ce72e..1dc56d254 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
@@ -29,7 +29,6 @@ import org.apache.spark.api.python.EvalPythonExecTransformer
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, 
Expression}
-import org.apache.spark.sql.catalyst.plans.FullOuter
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
@@ -619,19 +618,18 @@ case class AddTransformHintRule() extends Rule[SparkPlan] 
{
             transformer.doValidate().tagOnFallback(plan)
           }
         case plan: SortMergeJoinExec =>
-          if (!enableColumnarSortMergeJoin || plan.joinType == FullOuter) {
-            TransformHints.tagNotTransformable(
-              plan,
-              "columnar sort merge join is not enabled or join type is 
FullOuter")
+          if (!enableColumnarSortMergeJoin) {
+            TransformHints.tagNotTransformable(plan, "columnar sort merge join 
is not enabled")
           } else {
-            val transformer = SortMergeJoinExecTransformer(
-              plan.leftKeys,
-              plan.rightKeys,
-              plan.joinType,
-              plan.condition,
-              plan.left,
-              plan.right,
-              plan.isSkewJoin)
+            val transformer = BackendsApiManager.getSparkPlanExecApiInstance
+              .genSortMergeJoinExecTransformer(
+                plan.leftKeys,
+                plan.rightKeys,
+                plan.joinType,
+                plan.condition,
+                plan.left,
+                plan.right,
+                plan.isSkewJoin)
             transformer.doValidate().tagOnFallback(plan)
           }
         case plan: CartesianProductExec =>
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transform/ImplementSingleNode.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transform/ImplementSingleNode.scala
index 22edfaa5d..5820de270 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transform/ImplementSingleNode.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transform/ImplementSingleNode.scala
@@ -192,14 +192,15 @@ case class ImplementJoin() extends ImplementSingleNode 
with LogLevelUtil {
         val left = plan.left
         val right = plan.right
         logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-        SortMergeJoinExecTransformer(
-          plan.leftKeys,
-          plan.rightKeys,
-          plan.joinType,
-          plan.condition,
-          left,
-          right,
-          plan.isSkewJoin)
+        BackendsApiManager.getSparkPlanExecApiInstance
+          .genSortMergeJoinExecTransformer(
+            plan.leftKeys,
+            plan.rightKeys,
+            plan.joinType,
+            plan.condition,
+            left,
+            right,
+            plan.isSkewJoin)
       case plan: BroadcastHashJoinExec =>
         val left = plan.left
         val right = plan.right
diff --git 
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
 
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index bfecd2292..4c1f8e65e 100644
--- 
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ 
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -1187,6 +1187,10 @@ class ClickHouseTestSettings extends BackendTestSettings 
{
     .exclude("SPARK-36020: Check logical link in remove redundant projects")
     .exclude("SPARK-36032: Use inputPlan instead of currentPhysicalPlan to 
initialize logical link")
     .exclude("SPARK-37742: AQE reads invalid InMemoryRelation stats and 
mistakenly plans BHJ")
+    // SortMergeJoinExec is now planned as CHSortMergeJoinExecTransformer
+    .exclude("Change broadcast join to merge join")
+    .exclude("Avoid plan change if cost is greater")
+    .exclude("SPARK-30524: Do not optimize skew join if introduce additional 
shuffle")
     .excludeGlutenTest("Change broadcast join to merge join")
     .excludeGlutenTest("Empty stage coalesced to 1-partition RDD")
     .excludeGlutenTest(
diff --git 
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
 
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 385ef6381..be6273857 100644
--- 
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ 
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -1232,12 +1232,17 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("SPARK-37742: AQE reads invalid InMemoryRelation stats and mistakenly plans BHJ")
     .exclude("SPARK-37328: skew join with 3 tables")
     .exclude("SPARK-39915: Dataset.repartition(N) may not create N partitions")
-    .excludeGlutenTest("Change broadcast join to merge join")
-    .excludeGlutenTest("Change broadcast join to merge join")
+    .exclude("Change broadcast join to merge join")
+    .exclude("Avoid plan change if cost is greater")
+    .exclude("SPARK-37652: optimize skewed join through union")
+    .exclude("SPARK-35455: Unify empty relation optimization between normal and AQE optimizer " +
+      "- single join")
+    .exclude("SPARK-35455: Unify empty relation optimization between normal and AQE optimizer " +
+      "- multi join")
     .excludeGlutenTest("Empty stage coalesced to 1-partition RDD")
     .excludeGlutenTest(
       "Avoid changing merge join to broadcast join if too many empty partitions on build plan")
-    .excludeGlutenTest("SPARK-30524: Do not optimize skew join if introduce additional shuffle")
+    .exclude("SPARK-30524: Do not optimize skew join if introduce additional shuffle")
     .excludeGlutenTest("SPARK-33551: Do not use AQE shuffle read for repartition")
     .excludeGlutenTest("SPARK-35264: Support AQE side broadcastJoin threshold")
     .excludeGlutenTest("SPARK-35264: Support AQE side shuffled hash join formula")
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala
index 12ce843cc..8bc67f6f8 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/connector/GlutenKeyGroupedPartitioningSuite.scala
@@ -15,7 +15,8 @@
  * limitations under the License.
  */
 package org.apache.spark.sql.connector
-import org.apache.gluten.execution.SortMergeJoinExecTransformer
+
+import org.apache.gluten.execution.SortMergeJoinExecTransformerBase
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.GlutenSQLTestsBaseTrait
@@ -109,7 +110,7 @@ class GlutenKeyGroupedPartitioningSuite
       plan: SparkPlan): Seq[ColumnarShuffleExchangeExec] = {
     // here we skip collecting shuffle operators that are not associated with SMJ
     collect(plan) {
-      case s: SortMergeJoinExecTransformer => s
+      case s: SortMergeJoinExecTransformerBase => s
       case s: SortMergeJoinExec => s
     }.flatMap(smj => collect(smj) { case s: ColumnarShuffleExchangeExec => s })
   }
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/clickhouse/ClickHouseAdaptiveQueryExecSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/clickhouse/ClickHouseAdaptiveQueryExecSuite.scala
index 091abe11d..9ddaac185 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/clickhouse/ClickHouseAdaptiveQueryExecSuite.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/adaptive/clickhouse/ClickHouseAdaptiveQueryExecSuite.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.spark.sql.execution.adaptive.clickhouse
 
-import org.apache.gluten.execution.{BroadcastHashJoinExecTransformerBase, ShuffledHashJoinExecTransformerBase, SortExecTransformer, SortMergeJoinExecTransformer}
+import org.apache.gluten.execution.{BroadcastHashJoinExecTransformerBase, ShuffledHashJoinExecTransformerBase, SortExecTransformer, SortMergeJoinExecTransformerBase}
 
 import org.apache.spark.SparkConf
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
@@ -102,8 +102,8 @@ class ClickHouseAdaptiveQueryExecSuite extends AdaptiveQueryExecSuite with Glute
   }
 
   private def findTopLevelSortMergeJoinTransform(
-      plan: SparkPlan): Seq[SortMergeJoinExecTransformer] = {
-    collect(plan) { case j: SortMergeJoinExecTransformer => j }
+      plan: SparkPlan): Seq[SortMergeJoinExecTransformerBase] = {
+    collect(plan) { case j: SortMergeJoinExecTransformerBase => j }
   }
 
   private def sortMergeJoinSize(plan: SparkPlan): Int = {
@@ -260,7 +260,7 @@ class ClickHouseAdaptiveQueryExecSuite extends AdaptiveQueryExecSuite with Glute
           .count()
         checkAnswer(testDf, Seq())
         val plan = testDf.queryExecution.executedPlan
-        assert(find(plan)(_.isInstanceOf[SortMergeJoinExecTransformer]).isDefined)
+        assert(find(plan)(_.isInstanceOf[SortMergeJoinExecTransformerBase]).isDefined)
       }
 
       withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
@@ -671,7 +671,7 @@ class ClickHouseAdaptiveQueryExecSuite extends AdaptiveQueryExecSuite with Glute
                 rightSkewNum: Int): Unit = {
               assert(joins.size == 1)
               joins.head match {
-                case s: SortMergeJoinExecTransformer => assert(s.isSkewJoin)
+                case s: SortMergeJoinExecTransformerBase => assert(s.isSkewJoin)
                 case g: ShuffledHashJoinExecTransformerBase => assert(g.isSkewJoin)
                 case _ => assert(false)
               }
@@ -804,11 +804,6 @@ class ClickHouseAdaptiveQueryExecSuite extends AdaptiveQueryExecSuite with Glute
   ignore(
     GLUTEN_TEST + "SPARK-32573: Eliminate NAAJ when BuildSide is HashedRelationWithAllNullKeys") {}
 
-  // EmptyRelation case
-  ignore(
-    GLUTEN_TEST + "SPARK-35455: Unify empty relation optimization " +
-      "between normal and AQE optimizer - single join") {}
-
   testGluten("SPARK-32753: Only copy tags to node with no tags") {
     withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
       withTempView("v1") {
@@ -1511,4 +1506,56 @@ class ClickHouseAdaptiveQueryExecSuite extends AdaptiveQueryExecSuite with Glute
         }
     }
   }
+
+  testGluten("SPARK-37652: optimize skewed join through union") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100",
+      SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100"
+    ) {
+      withTempView("skewData1", "skewData2") {
+        spark
+          .range(0, 1000, 1, 10)
+          .selectExpr("id % 3 as key1", "id as value1")
+          .createOrReplaceTempView("skewData1")
+        spark
+          .range(0, 1000, 1, 10)
+          .selectExpr("id % 1 as key2", "id as value2")
+          .createOrReplaceTempView("skewData2")
+
+        def checkSkewJoin(query: String, joinNums: Int, optimizeSkewJoinNums: Int): Unit = {
+          val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(query)
+          val joins = findTopLevelSortMergeJoinTransform(innerAdaptivePlan)
+          val optimizeSkewJoins = joins.filter(_.isSkewJoin)
+          assert(joins.size == joinNums && optimizeSkewJoins.size == optimizeSkewJoinNums)
+        }
+
+        // skewJoin union skewJoin
+        checkSkewJoin(
+          "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
+            "UNION ALL SELECT key2 FROM skewData1 JOIN skewData2 ON key1 = key2",
+          2,
+          2)
+
+        // skewJoin union aggregate
+        checkSkewJoin(
+          "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
+            "UNION ALL SELECT key2 FROM skewData2 GROUP BY key2",
+          1,
+          1)
+
+        // skewJoin1 union (skewJoin2 join aggregate)
+        // skewJoin2 will lead to extra shuffles, but skew1 cannot be optimized
+        checkSkewJoin(
+          "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 UNION ALL " +
+            "SELECT key1 from (SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2) tmp1 " +
+            "JOIN (SELECT key2 FROM skewData2 GROUP BY key2) tmp2 ON key1 = key2",
+          3,
+          0
+        )
+      }
+    }
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to