This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 6dd91bac6c [GLUTEN-7745][VL] Incorporate SQL Union operator into Velox 
execution pipeline (#7842)
6dd91bac6c is described below

commit 6dd91bac6cc155810beb34415a6dee1a03069f6d
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue Dec 3 09:40:46 2024 +0800

    [GLUTEN-7745][VL] Incorporate SQL Union operator into Velox execution 
pipeline (#7842)
---
 .../backendsapi/clickhouse/CHMetricsApi.scala      |   8 ++
 .../org/apache/gluten/utils/GlutenURLDecoder.java  |   2 +-
 .../backendsapi/velox/VeloxListenerApi.scala       |   3 +-
 .../gluten/backendsapi/velox/VeloxMetricsApi.scala |  11 ++
 .../gluten/backendsapi/velox/VeloxRuleApi.scala    |   2 +
 .../org/apache/gluten/metrics/MetricsUtil.scala    |   9 +-
 .../gluten/metrics/UnionMetricsUpdater.scala       |  35 +++++
 .../gluten/execution/MiscOperatorSuite.scala       |  28 +++-
 .../VeloxOrcDataTypeValidationSuite.scala          |   5 +-
 .../VeloxParquetDataTypeValidationSuite.scala      |   5 +-
 cpp/velox/compute/WholeStageResultIterator.cc      |  44 +++++-
 cpp/velox/substrait/SubstraitToVeloxPlan.cc        |  46 +++++++
 cpp/velox/substrait/SubstraitToVeloxPlan.h         |   3 +
 .../substrait/SubstraitToVeloxPlanValidator.cc     | 153 ++++++++++++++++-----
 .../substrait/SubstraitToVeloxPlanValidator.h      |  15 +-
 .../apache/gluten/substrait/rel/RelBuilder.java    |  17 +++
 .../apache/gluten/substrait/rel/SetRelNode.java    |  62 +++++++++
 .../org/apache/gluten/backendsapi/MetricsApi.scala |   4 +
 .../BasicPhysicalOperatorTransformer.scala         |   3 +-
 .../gluten/execution/UnionExecTransformer.scala    |  99 +++++++++++++
 .../extension/columnar/UnionTransformerRule.scala  |  61 ++++++++
 .../execution/WholeStageTransformerSuite.scala     |  35 +++--
 .../scala/org/apache/gluten/GlutenConfig.scala     |   9 ++
 23 files changed, 604 insertions(+), 55 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
index 73b2d0f211..a0576a807b 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
@@ -450,6 +450,14 @@ class CHMetricsApi extends MetricsApi with Logging with 
LogLevelUtil {
       s"SampleTransformer metrics update is not supported in CH backend")
   }
 
+  override def genUnionTransformerMetrics(sparkContext: SparkContext): 
Map[String, SQLMetric] =
+    throw new UnsupportedOperationException(
+      "UnionExecTransformer metrics update is not supported in CH backend")
+
+  override def genUnionTransformerMetricsUpdater(metrics: Map[String, 
SQLMetric]): MetricsUpdater =
+    throw new UnsupportedOperationException(
+      "UnionExecTransformer metrics update is not supported in CH backend")
+
   def genWriteFilesTransformerMetrics(sparkContext: SparkContext): Map[String, 
SQLMetric] =
     Map(
       "physicalWrittenBytes" -> SQLMetrics.createMetric(sparkContext, "number 
of written bytes"),
diff --git 
a/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java 
b/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java
index 9228a2f860..856ddf1597 100644
--- a/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java
+++ b/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java
@@ -31,7 +31,7 @@ public class GlutenURLDecoder {
    * <p><em><strong>Note:</strong> The <a href=
    * "http://www.w3.org/TR/html40/appendix/notes.html#non-ascii-chars";> World 
Wide Web Consortium
    * Recommendation</a> states that UTF-8 should be used. Not doing so may 
introduce
-   * incompatibilites.</em>
+   * incompatibilities.</em>
    *
    * @param s the <code>String</code> to decode
    * @param enc The name of a supported <a 
href="../lang/package-summary.html#charenc">character
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
index d29d302970..3a82abe618 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
@@ -31,6 +31,7 @@ import org.apache.spark.{HdfsConfGenerator, SparkConf, 
SparkContext}
 import org.apache.spark.api.plugin.PluginContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.sql.execution.ColumnarCachedBatchSerializer
 import org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules
 import 
org.apache.spark.sql.execution.datasources.velox.{VeloxParquetWriterInjects, 
VeloxRowSplitter}
 import org.apache.spark.sql.expression.UDFResolver
@@ -75,7 +76,7 @@ class VeloxListenerApi extends ListenerApi with Logging {
     if (conf.getBoolean(GlutenConfig.COLUMNAR_TABLE_CACHE_ENABLED.key, 
defaultValue = false)) {
       conf.set(
         StaticSQLConf.SPARK_CACHE_SERIALIZER.key,
-        "org.apache.spark.sql.execution.ColumnarCachedBatchSerializer")
+        classOf[ColumnarCachedBatchSerializer].getName)
     }
 
     // Static initializers for driver.
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
index e70e1d13bd..934b680382 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
@@ -582,4 +582,15 @@ class VeloxMetricsApi extends MetricsApi with Logging {
 
   override def genSampleTransformerMetricsUpdater(metrics: Map[String, 
SQLMetric]): MetricsUpdater =
     new SampleMetricsUpdater(metrics)
+
+  override def genUnionTransformerMetrics(sparkContext: SparkContext): 
Map[String, SQLMetric] = Map(
+    "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input 
rows"),
+    "inputVectors" -> SQLMetrics.createMetric(sparkContext, "number of input 
vectors"),
+    "inputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of input 
bytes"),
+    "wallNanos" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time of 
union"),
+    "cpuCount" -> SQLMetrics.createMetric(sparkContext, "cpu wall time count")
+  )
+
+  override def genUnionTransformerMetricsUpdater(metrics: Map[String, 
SQLMetric]): MetricsUpdater =
+    new UnionMetricsUpdater(metrics)
 }
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index 7841e6cd94..7337be5737 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -92,6 +92,7 @@ object VeloxRuleApi {
       c => HeuristicTransform.Single(validatorBuilder(c.glutenConf), rewrites, 
offloads))
 
     // Legacy: Post-transform rules.
+    injector.injectPostTransform(_ => UnionTransformerRule())
     injector.injectPostTransform(c => PartialProjectRule.apply(c.session))
     injector.injectPostTransform(_ => RemoveNativeWriteFilesSortAndProject())
     injector.injectPostTransform(c => RewriteTransformer.apply(c.session))
@@ -178,6 +179,7 @@ object VeloxRuleApi {
 
     // Gluten RAS: Post rules.
     injector.injectPostTransform(_ => RemoveTransitions)
+    injector.injectPostTransform(_ => UnionTransformerRule())
     injector.injectPostTransform(c => PartialProjectRule.apply(c.session))
     injector.injectPostTransform(_ => RemoveNativeWriteFilesSortAndProject())
     injector.injectPostTransform(c => RewriteTransformer.apply(c.session))
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala 
b/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala
index cd50d0b8e2..b8ef1620f9 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala
@@ -58,7 +58,8 @@ object MetricsUtil extends Logging {
           assert(t.children.size == 1, "MetricsUpdater.None can only be used 
on unary operator")
           treeifyMetricsUpdaters(t.children.head)
         case t: TransformSupport =>
-          MetricsUpdaterTree(t.metricsUpdater(), 
t.children.map(treeifyMetricsUpdaters))
+          // Reversed children order to match the traversal code.
+          MetricsUpdaterTree(t.metricsUpdater(), 
t.children.reverse.map(treeifyMetricsUpdaters))
         case _ =>
           MetricsUpdaterTree(MetricsUpdater.Terminate, Seq())
       }
@@ -233,6 +234,12 @@ object MetricsUtil extends Logging {
           operatorMetrics,
           metrics.getSingleMetrics,
           joinParamsMap.get(operatorIdx))
+      case u: UnionMetricsUpdater =>
+        // Union is translated into LocalExchange + LocalPartition in Velox, which output two suites of metrics.
+        // Therefore, fetch one more suite of metrics here.
+        operatorMetrics.add(metrics.getOperatorMetrics(curMetricsIdx))
+        curMetricsIdx -= 1
+        u.updateUnionMetrics(operatorMetrics)
       case hau: HashAggregateMetricsUpdater =>
         hau.updateAggregationMetrics(operatorMetrics, 
aggParamsMap.get(operatorIdx))
       case lu: LimitMetricsUpdater =>
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala
new file mode 100644
index 0000000000..9e91cf368c
--- /dev/null
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.metrics
+
+import org.apache.spark.sql.execution.metric.SQLMetric
+
+class UnionMetricsUpdater(val metrics: Map[String, SQLMetric]) extends 
MetricsUpdater {
+  override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = {
+    throw new UnsupportedOperationException()
+  }
+
+  def updateUnionMetrics(unionMetrics: java.util.ArrayList[OperatorMetrics]): 
Unit = {
+    // Union is translated into LocalExchange + LocalPartition. Use metrics from LocalExchange.
+    val localExchangeMetrics = unionMetrics.get(0)
+    metrics("numInputRows") += localExchangeMetrics.inputRows
+    metrics("inputVectors") += localExchangeMetrics.inputVectors
+    metrics("inputBytes") += localExchangeMetrics.inputBytes
+    metrics("cpuCount") += localExchangeMetrics.cpuCount
+    metrics("wallNanos") += localExchangeMetrics.wallNanos
+  }
+}
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
index 5cb2b65260..8063a5d122 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
@@ -537,11 +537,37 @@ class MiscOperatorSuite extends 
VeloxWholeStageTransformerSuite with AdaptiveSpa
                          |""".stripMargin) {
       df =>
         {
-          getExecutedPlan(df).exists(plan => 
plan.find(_.isInstanceOf[ColumnarUnionExec]).isDefined)
+          assert(
+            getExecutedPlan(df).exists(
+              plan => plan.find(_.isInstanceOf[ColumnarUnionExec]).isDefined))
         }
     }
   }
 
+  test("union_all two tables with known partitioning") {
+    withSQLConf(GlutenConfig.NATIVE_UNION_ENABLED.key -> "true") {
+      compareDfResultsAgainstVanillaSpark(
+        () => {
+          val df1 = spark.sql("select l_orderkey as orderkey from lineitem")
+          val df2 = spark.sql("select o_orderkey as orderkey from orders")
+          df1.repartition(5).union(df2.repartition(5))
+        },
+        compareResult = true,
+        checkGlutenOperatorMatch[UnionExecTransformer]
+      )
+
+      compareDfResultsAgainstVanillaSpark(
+        () => {
+          val df1 = spark.sql("select l_orderkey as orderkey from lineitem")
+          val df2 = spark.sql("select o_orderkey as orderkey from orders")
+          df1.repartition(5).union(df2.repartition(6))
+        },
+        compareResult = true,
+        checkGlutenOperatorMatch[ColumnarUnionExec]
+      )
+    }
+  }
+
   test("union_all three tables") {
     runQueryAndCompare("""
                          |select count(orderkey) from (
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala
index 24e04f2dfc..6ac59ba4fa 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala
@@ -255,7 +255,10 @@ class VeloxOrcDataTypeValidationSuite extends 
VeloxWholeStageTransformerSuite {
                          |""".stripMargin) {
       df =>
         {
-          assert(getExecutedPlan(df).exists(plan => 
plan.isInstanceOf[ColumnarUnionExec]))
+          assert(
+            getExecutedPlan(df).exists(
+              plan =>
+                plan.isInstanceOf[ColumnarUnionExec] || 
plan.isInstanceOf[UnionExecTransformer]))
         }
     }
 
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
index 57ca448fec..cb5614f396 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
@@ -254,7 +254,10 @@ class VeloxParquetDataTypeValidationSuite extends 
VeloxWholeStageTransformerSuit
                          |""".stripMargin) {
       df =>
         {
-          assert(getExecutedPlan(df).exists(plan => 
plan.isInstanceOf[ColumnarUnionExec]))
+          assert(
+            getExecutedPlan(df).exists(
+              plan =>
+                plan.isInstanceOf[ColumnarUnionExec] || 
plan.isInstanceOf[UnionExecTransformer]))
         }
     }
 
diff --git a/cpp/velox/compute/WholeStageResultIterator.cc 
b/cpp/velox/compute/WholeStageResultIterator.cc
index b6ecbd959f..411c6c5636 100644
--- a/cpp/velox/compute/WholeStageResultIterator.cc
+++ b/cpp/velox/compute/WholeStageResultIterator.cc
@@ -91,7 +91,7 @@ WholeStageResultIterator::WholeStageResultIterator(
       std::move(queryCtx),
       velox::exec::Task::ExecutionMode::kSerial);
   if (!task_->supportSerialExecutionMode()) {
-    throw std::runtime_error("Task doesn't support single thread execution: " 
+ planNode->toString());
+    throw std::runtime_error("Task doesn't support single threaded execution: 
" + planNode->toString());
   }
   auto fileSystem = velox::filesystems::getFileSystem(spillDir, nullptr);
   GLUTEN_CHECK(fileSystem != nullptr, "File System for spilling is null!");
@@ -248,15 +248,47 @@ void WholeStageResultIterator::getOrderedNodeIds(
     const std::shared_ptr<const velox::core::PlanNode>& planNode,
     std::vector<velox::core::PlanNodeId>& nodeIds) {
   bool isProjectNode = (std::dynamic_pointer_cast<const 
velox::core::ProjectNode>(planNode) != nullptr);
+  bool isLocalExchangeNode = (std::dynamic_pointer_cast<const 
velox::core::LocalPartitionNode>(planNode) != nullptr);
+  bool isUnionNode = isLocalExchangeNode &&
+      std::dynamic_pointer_cast<const 
velox::core::LocalPartitionNode>(planNode)->type() ==
+          velox::core::LocalPartitionNode::Type::kGather;
   const auto& sourceNodes = planNode->sources();
-  for (const auto& sourceNode : sourceNodes) {
+  if (isProjectNode) {
+    GLUTEN_CHECK(sourceNodes.size() == 1, "Illegal state");
+    const auto sourceNode = sourceNodes.at(0);
     // Filter over Project are mapped into FilterProject operator in Velox.
     // Metrics are all applied on Project node, and the metrics for Filter node
     // do not exist.
-    if (isProjectNode && std::dynamic_pointer_cast<const 
velox::core::FilterNode>(sourceNode)) {
+    if (std::dynamic_pointer_cast<const velox::core::FilterNode>(sourceNode)) {
       omittedNodeIds_.insert(sourceNode->id());
     }
     getOrderedNodeIds(sourceNode, nodeIds);
+    nodeIds.emplace_back(planNode->id());
+    return;
+  }
+
+  if (isUnionNode) {
+    // FIXME: The whole metrics system in gluten-substrait is magic. Passing metrics trees through JNI with a trivial
+    //  array is possible but requires a solid design, which we don't have yet. All of this code needs a complete
+    //  rework.
+    // Union is translated into LocalPartition + LocalExchange in Velox, with one fake project per child. So we skip
+    // the fake projects and only fetch the union metrics from the root node.
+    std::vector<std::shared_ptr<const velox::core::PlanNode>> unionChildren{};
+    for (const auto& source : planNode->sources()) {
+      const auto projectedChild = std::dynamic_pointer_cast<const 
velox::core::ProjectNode>(source);
+      GLUTEN_CHECK(projectedChild != nullptr, "Illegal state");
+      const auto projectSources = projectedChild->sources();
+      GLUTEN_CHECK(projectSources.size() == 1, "Illegal state");
+      const auto projectSource = projectSources.at(0);
+      getOrderedNodeIds(projectSource, nodeIds);
+    }
+    nodeIds.emplace_back(planNode->id());
+    return;
+  }
+
+  for (const auto& sourceNode : sourceNodes) {
+    // Post-order traversal.
+    getOrderedNodeIds(sourceNode, nodeIds);
   }
   nodeIds.emplace_back(planNode->id());
 }
@@ -350,9 +382,9 @@ void WholeStageResultIterator::collectMetrics() {
       continue;
     }
 
-    const auto& status = planStats.at(nodeId);
-    // Add each operator status into metrics.
-    for (const auto& entry : status.operatorStats) {
+    const auto& stats = planStats.at(nodeId);
+    // Add each operator stats into metrics.
+    for (const auto& entry : stats.operatorStats) {
       const auto& second = entry.second;
       metrics_->get(Metrics::kInputRows)[metricIndex] = second->inputRows;
       metrics_->get(Metrics::kInputVectors)[metricIndex] = 
second->inputVectors;
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc 
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index 1efa733879..3ceccca4a3 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -1043,6 +1043,50 @@ core::PlanNodePtr 
SubstraitToVeloxPlanConverter::toVeloxPlan(
       childNode);
 }
 
+core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const 
::substrait::SetRel& setRel) {
+  switch (setRel.op()) {
+    case ::substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_UNION_ALL: {
+      std::vector<core::PlanNodePtr> children;
+      for (int32_t i = 0; i < setRel.inputs_size(); ++i) {
+        const auto& input = setRel.inputs(i);
+        children.push_back(toVeloxPlan(input));
+      }
+      GLUTEN_CHECK(!children.empty(), "At least one source is required for 
Velox LocalPartition");
+
+      // Velox doesn't allow different field names in schemas of 
LocalPartitionNode's children.
+      // Add project nodes to unify the schemas.
+      const RowTypePtr outRowType = asRowType(children[0]->outputType());
+      std::vector<std::string> outNames;
+      for (int32_t colIdx = 0; colIdx < outRowType->size(); ++colIdx) {
+        const auto name = outRowType->childAt(colIdx)->name();
+        outNames.push_back(name);
+      }
+
+      std::vector<core::PlanNodePtr> projectedChildren;
+      for (int32_t i = 0; i < children.size(); ++i) {
+        const auto& child = children[i];
+        const RowTypePtr& childRowType = child->outputType();
+        std::vector<core::TypedExprPtr> expressions;
+        for (int32_t colIdx = 0; colIdx < outNames.size(); ++colIdx) {
+          const auto fa =
+              
std::make_shared<core::FieldAccessTypedExpr>(childRowType->childAt(colIdx), 
childRowType->nameOf(colIdx));
+          const auto cast = 
std::make_shared<core::CastTypedExpr>(outRowType->childAt(colIdx), fa, false);
+          expressions.push_back(cast);
+        }
+        auto project = std::make_shared<core::ProjectNode>(nextPlanNodeId(), 
outNames, expressions, child);
+        projectedChildren.push_back(project);
+      }
+      return std::make_shared<core::LocalPartitionNode>(
+          nextPlanNodeId(),
+          core::LocalPartitionNode::Type::kGather,
+          std::make_shared<core::GatherPartitionFunctionSpec>(),
+          projectedChildren);
+    }
+    default:
+      throw GlutenException("Unsupported SetRel op: " + 
std::to_string(setRel.op()));
+  }
+}
+
 core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const 
::substrait::SortRel& sortRel) {
   auto childNode = convertSingleInput<::substrait::SortRel>(sortRel);
   auto [sortingKeys, sortingOrders] = processSortField(sortRel.sorts(), 
childNode->outputType());
@@ -1298,6 +1342,8 @@ core::PlanNodePtr 
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
     return toVeloxPlan(rel.write());
   } else if (rel.has_windowgrouplimit()) {
     return toVeloxPlan(rel.windowgrouplimit());
+  } else if (rel.has_set()) {
+    return toVeloxPlan(rel.set());
   } else {
     VELOX_NYI("Substrait conversion not supported for Rel.");
   }
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.h 
b/cpp/velox/substrait/SubstraitToVeloxPlan.h
index 51e50ce347..6121923df7 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.h
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.h
@@ -84,6 +84,9 @@ class SubstraitToVeloxPlanConverter {
   /// Used to convert Substrait WindowGroupLimitRel into Velox PlanNode.
   core::PlanNodePtr toVeloxPlan(const ::substrait::WindowGroupLimitRel& 
windowGroupLimitRel);
 
+  /// Used to convert Substrait SetRel into Velox PlanNode.
+  core::PlanNodePtr toVeloxPlan(const ::substrait::SetRel& setRel);
+
   /// Used to convert Substrait JoinRel into Velox PlanNode.
   core::PlanNodePtr toVeloxPlan(const ::substrait::JoinRel& joinRel);
 
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc 
b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
index 3b74caf8ba..9325fed321 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -22,7 +22,6 @@
 #include "TypeUtils.h"
 #include "udf/UdfLoader.h"
 #include "utils/Common.h"
-#include "velox/core/ExpressionEvaluator.h"
 #include "velox/exec/Aggregate.h"
 #include "velox/expression/Expr.h"
 #include "velox/expression/SignatureBinder.h"
@@ -30,7 +29,7 @@
 namespace gluten {
 
 namespace {
-static const char* extractFileName(const char* file) {
+const char* extractFileName(const char* file) {
   return strrchr(file, '/') ? strrchr(file, '/') + 1 : file;
 }
 
@@ -53,13 +52,13 @@ static const char* extractFileName(const char* file) {
       __FUNCTION__,                                                    \
       reason))
 
-static const std::unordered_set<std::string> kRegexFunctions = {
+const std::unordered_set<std::string> kRegexFunctions = {
     "regexp_extract",
     "regexp_extract_all",
     "regexp_replace",
     "rlike"};
 
-static const std::unordered_set<std::string> kBlackList = {
+const std::unordered_set<std::string> kBlackList = {
     "split_part",
     "factorial",
     "concat_ws",
@@ -70,32 +69,59 @@ static const std::unordered_set<std::string> kBlackList = {
     "approx_percentile",
     "get_array_struct_fields",
     "map_from_arrays"};
-
 } // namespace
 
-bool SubstraitToVeloxPlanValidator::validateInputTypes(
+bool SubstraitToVeloxPlanValidator::parseVeloxType(
     const ::substrait::extensions::AdvancedExtension& extension,
-    std::vector<TypePtr>& types) {
+    TypePtr& out) {
+  ::substrait::Type substraitType;
   // The input type is wrapped in enhancement.
   if (!extension.has_enhancement()) {
     LOG_VALIDATION_MSG("Input type is not wrapped in enhancement.");
     return false;
   }
   const auto& enhancement = extension.enhancement();
-  ::substrait::Type inputType;
-  if (!enhancement.UnpackTo(&inputType)) {
+  if (!enhancement.UnpackTo(&substraitType)) {
     LOG_VALIDATION_MSG("Enhancement can't be unpacked to inputType.");
     return false;
   }
-  if (!inputType.has_struct_()) {
-    LOG_VALIDATION_MSG("Input type has no struct.");
+
+  out = SubstraitParser::parseType(substraitType);
+  return true;
+}
+
+bool SubstraitToVeloxPlanValidator::flattenVeloxType1(const TypePtr& type, 
std::vector<TypePtr>& out) {
+  if (type->kind() != TypeKind::ROW) {
+    LOG_VALIDATION_MSG("Type is not a RowType.");
+    return false;
+  }
+  auto rowType = std::dynamic_pointer_cast<const RowType>(type);
+  if (!rowType) {
+    LOG_VALIDATION_MSG("Failed to cast to RowType.");
     return false;
   }
+  for (const auto& field : rowType->children()) {
+    out.emplace_back(field);
+  }
+  return true;
+}
 
-  // Get the input types.
-  const auto& sTypes = inputType.struct_().types();
-  for (const auto& sType : sTypes) {
-    types.emplace_back(SubstraitParser::parseType(sType));
+bool SubstraitToVeloxPlanValidator::flattenVeloxType2(const TypePtr& type, 
std::vector<std::vector<TypePtr>>& out) {
+  if (type->kind() != TypeKind::ROW) {
+    LOG_VALIDATION_MSG("Type is not a RowType.");
+    return false;
+  }
+  auto rowType = std::dynamic_pointer_cast<const RowType>(type);
+  if (!rowType) {
+    LOG_VALIDATION_MSG("Failed to cast to RowType.");
+    return false;
+  }
+  for (const auto& field : rowType->children()) {
+    std::vector<TypePtr> inner;
+    if (!flattenVeloxType1(field, inner)) {
+      return false;
+    }
+    out.emplace_back(inner);
   }
   return true;
 }
@@ -341,10 +367,11 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::WriteRel& writeR
   }
 
   // Validate input data type.
+  TypePtr inputRowType;
   std::vector<TypePtr> types;
   if (writeRel.has_named_table()) {
     const auto& extension = writeRel.named_table().advanced_extension();
-    if (!validateInputTypes(extension, types)) {
+    if (!parseVeloxType(extension, inputRowType) || 
!flattenVeloxType1(inputRowType, types)) {
       LOG_VALIDATION_MSG("Validation failed for input type validation in 
WriteRel.");
       return false;
     }
@@ -380,12 +407,12 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::WriteRel& writeR
 }
 
 bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FetchRel& 
fetchRel) {
-  RowTypePtr rowType = nullptr;
   // Get and validate the input types from extension.
   if (fetchRel.has_advanced_extension()) {
     const auto& extension = fetchRel.advanced_extension();
+    TypePtr inputRowType;
     std::vector<TypePtr> types;
-    if (!validateInputTypes(extension, types)) {
+    if (!parseVeloxType(extension, inputRowType) || 
!flattenVeloxType1(inputRowType, types)) {
       LOG_VALIDATION_MSG("Unsupported input types in FetchRel.");
       return false;
     }
@@ -396,7 +423,6 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::FetchRel& fetchR
     for (auto colIdx = 0; colIdx < types.size(); colIdx++) {
       names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, 
colIdx));
     }
-    rowType = std::make_shared<RowType>(std::move(names), std::move(types));
   }
 
   if (fetchRel.offset() < 0 || fetchRel.count() < 0) {
@@ -412,8 +438,9 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::TopNRel& topNRel
   // Get and validate the input types from extension.
   if (topNRel.has_advanced_extension()) {
     const auto& extension = topNRel.advanced_extension();
+    TypePtr inputRowType;
     std::vector<TypePtr> types;
-    if (!validateInputTypes(extension, types)) {
+    if (!parseVeloxType(extension, inputRowType) || 
!flattenVeloxType1(inputRowType, types)) {
       LOG_VALIDATION_MSG("Unsupported input types in TopNRel.");
       return false;
     }
@@ -457,8 +484,9 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::GenerateRel& gen
     return false;
   }
   const auto& extension = generateRel.advanced_extension();
+  TypePtr inputRowType;
   std::vector<TypePtr> types;
-  if (!validateInputTypes(extension, types)) {
+  if (!parseVeloxType(extension, inputRowType) || 
!flattenVeloxType1(inputRowType, types)) {
     LOG_VALIDATION_MSG("Validation failed for input types in GenerateRel.");
     return false;
   }
@@ -487,8 +515,9 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::ExpandRel& expan
   // Get and validate the input types from extension.
   if (expandRel.has_advanced_extension()) {
     const auto& extension = expandRel.advanced_extension();
+    TypePtr inputRowType;
     std::vector<TypePtr> types;
-    if (!validateInputTypes(extension, types)) {
+    if (!parseVeloxType(extension, inputRowType) || 
!flattenVeloxType1(inputRowType, types)) {
       LOG_VALIDATION_MSG("Unsupported input types in ExpandRel.");
       return false;
     }
@@ -571,8 +600,9 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::WindowRel& windo
     return false;
   }
   const auto& extension = windowRel.advanced_extension();
+  TypePtr inputRowType;
   std::vector<TypePtr> types;
-  if (!validateInputTypes(extension, types)) {
+  if (!parseVeloxType(extension, inputRowType) || 
!flattenVeloxType1(inputRowType, types)) {
     LOG_VALIDATION_MSG("Validation failed for input types in WindowRel.");
     return false;
   }
@@ -680,7 +710,7 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::WindowRel& windo
         LOG_VALIDATION_MSG("in windowRel, the sorting key in Sort Operator 
only support field.");
         return false;
       }
-      exec::ExprSet exprSet({std::move(expression)}, execCtx_);
+      exec::ExprSet exprSet1({std::move(expression)}, execCtx_);
     }
   }
 
@@ -699,8 +729,9 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::WindowGroupLimit
     return false;
   }
   const auto& extension = windowGroupLimitRel.advanced_extension();
+  TypePtr inputRowType;
   std::vector<TypePtr> types;
-  if (!validateInputTypes(extension, types)) {
+  if (!parseVeloxType(extension, inputRowType) || 
!flattenVeloxType1(inputRowType, types)) {
     LOG_VALIDATION_MSG("Validation failed for input types in 
WindowGroupLimitRel.");
     return false;
   }
@@ -750,13 +781,61 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::WindowGroupLimit
         LOG_VALIDATION_MSG("in windowGroupLimitRel, the sorting key in Sort 
Operator only support field.");
         return false;
       }
-      exec::ExprSet exprSet({std::move(expression)}, execCtx_);
+      exec::ExprSet exprSet1({std::move(expression)}, execCtx_);
     }
   }
 
   return true;
 }
 
+bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SetRel& setRel) {
+  // Only UNION ALL is supported: each input rel must validate and all inputs must share one row type.
+  switch (setRel.op()) {
+    case ::substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_UNION_ALL: {
+      for (const auto& input : setRel.inputs()) {
+        if (!validate(input)) {
+          LOG_VALIDATION_MSG("SetRel input");
+          return false;
+        }
+      }
+      if (!setRel.has_advanced_extension()) {
+        LOG_VALIDATION_MSG("Input types are expected in SetRel.");
+        return false;
+      }
+      const auto& extension = setRel.advanced_extension();
+      TypePtr inputRowType;
+      std::vector<std::vector<TypePtr>> childrenTypes;
+      if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType2(inputRowType, childrenTypes)) {
+        LOG_VALIDATION_MSG("Validation failed for input types in SetRel.");
+        return false;
+      }
+      // The extension holds one struct per input; rebuild a RowType for each so they can be compared.
+      std::vector<RowTypePtr> childrenRowTypes;
+      for (auto i = 0; i < childrenTypes.size(); ++i) {
+        auto& types = childrenTypes.at(i);
+        std::vector<std::string> names;
+        names.reserve(types.size());
+        for (auto colIdx = 0; colIdx < types.size(); colIdx++) {
+          names.emplace_back(SubstraitParser::makeNodeName(i, colIdx));
+        }
+        childrenRowTypes.push_back(std::make_shared<RowType>(std::move(names), std::move(types)));
+      }
+      for (auto i = 1; i < childrenRowTypes.size(); ++i) {
+        if (!(childrenRowTypes[i]->equivalent(*childrenRowTypes[0]))) {
+          LOG_VALIDATION_MSG(
+              "All sources of the Set operation must have the same output type: " + childrenRowTypes[i]->toString() +
+              " vs. " + childrenRowTypes[0]->toString());
+          return false;
+        }
+      }
+      return true;
+    }
+    default:
+      LOG_VALIDATION_MSG("Unsupported SetRel op: " + std::to_string(setRel.op()));
+      return false;
+  }
+}
+
 bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SortRel& 
sortRel) {
   if (sortRel.has_input() && !validate(sortRel.input())) {
     return false;
@@ -769,8 +848,9 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::SortRel& sortRel
   }
 
   const auto& extension = sortRel.advanced_extension();
+  TypePtr inputRowType;
   std::vector<TypePtr> types;
-  if (!validateInputTypes(extension, types)) {
+  if (!parseVeloxType(extension, inputRowType) || 
!flattenVeloxType1(inputRowType, types)) {
     LOG_VALIDATION_MSG("Validation failed for input types in SortRel.");
     return false;
   }
@@ -822,8 +902,9 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::ProjectRel& proj
     return false;
   }
   const auto& extension = projectRel.advanced_extension();
+  TypePtr inputRowType;
   std::vector<TypePtr> types;
-  if (!validateInputTypes(extension, types)) {
+  if (!parseVeloxType(extension, inputRowType) || 
!flattenVeloxType1(inputRowType, types)) {
     LOG_VALIDATION_MSG("Validation failed for input types in ProjectRel.");
     return false;
   }
@@ -865,8 +946,9 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::FilterRel& filte
     return false;
   }
   const auto& extension = filterRel.advanced_extension();
+  TypePtr inputRowType;
   std::vector<TypePtr> types;
-  if (!validateInputTypes(extension, types)) {
+  if (!parseVeloxType(extension, inputRowType) || 
!flattenVeloxType1(inputRowType, types)) {
     LOG_VALIDATION_MSG("Validation failed for input types in FilterRel.");
     return false;
   }
@@ -938,8 +1020,9 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::JoinRel& joinRel
   }
 
   const auto& extension = joinRel.advanced_extension();
+  TypePtr inputRowType;
   std::vector<TypePtr> types;
-  if (!validateInputTypes(extension, types)) {
+  if (!parseVeloxType(extension, inputRowType) || 
!flattenVeloxType1(inputRowType, types)) {
     LOG_VALIDATION_MSG("Validation failed for input types in JoinRel.");
     return false;
   }
@@ -991,8 +1074,9 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::CrossRel& crossR
   }
 
   const auto& extension = crossRel.advanced_extension();
+  TypePtr inputRowType;
   std::vector<TypePtr> types;
-  if (!validateInputTypes(extension, types)) {
+  if (!parseVeloxType(extension, inputRowType) || 
!flattenVeloxType1(inputRowType, types)) {
     logValidateMsg("Native validation failed due to: Validation failed for 
input types in CrossRel");
     return false;
   }
@@ -1070,11 +1154,13 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::AggregateRel& ag
 
   // Validate input types.
   if (aggRel.has_advanced_extension()) {
+    TypePtr inputRowType;
     std::vector<TypePtr> types;
     const auto& extension = aggRel.advanced_extension();
     // Aggregate always has advanced extension for streaming aggregate 
optimization,
     // but only some of them have enhancement for validation.
-    if (extension.has_enhancement() && !validateInputTypes(extension, types)) {
+    if (extension.has_enhancement() &&
+        (!parseVeloxType(extension, inputRowType) || 
!flattenVeloxType1(inputRowType, types))) {
       LOG_VALIDATION_MSG("Validation failed for input types in AggregateRel.");
       return false;
     }
@@ -1266,7 +1352,10 @@ bool SubstraitToVeloxPlanValidator::validate(const 
::substrait::Rel& rel) {
     return validate(rel.write());
   } else if (rel.has_windowgrouplimit()) {
     return validate(rel.windowgrouplimit());
+  } else if (rel.has_set()) {
+    return validate(rel.set());
   } else {
+    LOG_VALIDATION_MSG("Unsupported relation type: " + rel.GetTypeName());
     return false;
   }
 }
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h 
b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
index 1fe174928f..0c8d882ca0 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
+++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
@@ -61,6 +61,9 @@ class SubstraitToVeloxPlanValidator {
   /// Used to validate whether the computing of this WindowGroupLimit is 
supported.
   bool validate(const ::substrait::WindowGroupLimitRel& windowGroupLimitRel);
 
+  /// Used to validate whether the computing of this Set is supported.
+  bool validate(const ::substrait::SetRel& setRel);
+
   /// Used to validate whether the computing of this Aggregation is supported.
   bool validate(const ::substrait::AggregateRel& aggRel);
 
@@ -103,9 +106,17 @@ class SubstraitToVeloxPlanValidator {
 
   std::vector<std::string> validateLog_;
 
-  /// Used to get types from advanced extension and validate them.
-  bool validateInputTypes(const ::substrait::extensions::AdvancedExtension& 
extension, std::vector<TypePtr>& types);
+  /// Used to get types from advanced extension and validate them, then 
convert to a Velox type that has arbitrary
+  /// levels of nesting.
+  bool parseVeloxType(const ::substrait::extensions::AdvancedExtension& 
extension, TypePtr& out);
+
+  /// Flattens a Velox type with a single level of nesting into a std::vector of child types.
+  bool flattenVeloxType1(const TypePtr& type, std::vector<TypePtr>& out);
+
+  /// Flattens a Velox type with two levels of nesting into a dual-nested std::vector of child types.
+  bool flattenVeloxType2(const TypePtr& type, 
std::vector<std::vector<TypePtr>>& out);
 
+  /// Validate aggregate rel.
   bool validateAggRelFunctionType(const ::substrait::AggregateRel& 
substraitAgg);
 
   /// Validate the round scalar function.
diff --git 
a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
 
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
index def1dca0a0..7d19311808 100644
--- 
a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
+++ 
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
@@ -27,6 +27,7 @@ import org.apache.gluten.substrait.type.TypeNode;
 
 import io.substrait.proto.CrossRel;
 import io.substrait.proto.JoinRel;
+import io.substrait.proto.SetRel;
 import io.substrait.proto.SortField;
 import org.apache.spark.sql.catalyst.expressions.Attribute;
 
@@ -317,4 +318,20 @@ public class RelBuilder {
     context.registerRelToOperator(operatorId);
     return new GenerateRelNode(input, generator, childOutput, extensionNode, 
outer);
   }
+
+  /** Builds a SetRel node without an advanced extension; delegates to the overload below. */
+  public static RelNode makeSetRel(
+      List<RelNode> inputs, SetRel.SetOp setOp, SubstraitContext context, Long operatorId) {
+    return makeSetRel(inputs, setOp, null, context, operatorId);
+  }
+
+  public static RelNode makeSetRel(
+      List<RelNode> inputs,
+      SetRel.SetOp setOp,
+      AdvancedExtensionNode extensionNode,
+      SubstraitContext context,
+      Long operatorId) {
+    context.registerRelToOperator(operatorId);
+    return new SetRelNode(inputs, setOp, extensionNode);
+  }
 }
diff --git 
a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java
 
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java
new file mode 100644
index 0000000000..ddcfb1701d
--- /dev/null
+++ 
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.substrait.rel;
+
+import org.apache.gluten.substrait.extensions.AdvancedExtensionNode;
+
+import io.substrait.proto.Rel;
+import io.substrait.proto.RelCommon;
+import io.substrait.proto.SetRel;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/** Substrait {@link SetRel} node that combines its input relations with one set operation. */
+public class SetRelNode implements RelNode, Serializable {
+  private final List<RelNode> inputs;
+  private final SetRel.SetOp setOp;
+  private final AdvancedExtensionNode extensionNode;
+
+  /**
+   * Creates a node applying {@code setOp} to {@code inputs}.
+   *
+   * @param inputs input relations in order, or null for none; copied into an {@link ArrayList} so
+   *     later changes to the caller's list cannot alter this node
+   * @param setOp set operation applied to the inputs
+   * @param extensionNode advanced extension to attach, or null to leave it unset
+   */
+  public SetRelNode(List<RelNode> inputs, SetRel.SetOp setOp, AdvancedExtensionNode extensionNode) {
+    this.inputs = inputs == null ? null : new ArrayList<>(inputs);
+    this.setOp = setOp;
+    this.extensionNode = extensionNode;
+  }
+
+  /** Creates a node without an advanced extension. */
+  public SetRelNode(List<RelNode> inputs, SetRel.SetOp setOp) {
+    this(inputs, setOp, null);
+  }
+
+  @Override
+  public Rel toProtobuf() {
+    final RelCommon.Builder relCommonBuilder = RelCommon.newBuilder();
+    relCommonBuilder.setDirect(RelCommon.Direct.newBuilder());
+    final SetRel.Builder setBuilder = SetRel.newBuilder();
+    setBuilder.setCommon(relCommonBuilder.build());
+    if (inputs != null) {
+      for (RelNode input : inputs) {
+        setBuilder.addInputs(input.toProtobuf());
+      }
+    }
+    setBuilder.setOp(setOp);
+    if (extensionNode != null) {
+      setBuilder.setAdvancedExtension(extensionNode.toProtobuf());
+    }
+    final Rel.Builder builder = Rel.newBuilder();
+    builder.setSet(setBuilder.build());
+    return builder.build();
+  }
+}
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
index c67d4b5f88..453cfab4e4 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
@@ -126,6 +126,10 @@ trait MetricsApi extends Serializable {
 
   def genSampleTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): 
MetricsUpdater
 
+  def genUnionTransformerMetrics(sparkContext: SparkContext): Map[String, 
SQLMetric]
+
+  def genUnionTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): 
MetricsUpdater
+
   def genColumnarInMemoryTableMetrics(sparkContext: SparkContext): Map[String, 
SQLMetric] =
     Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of 
output rows"))
 }
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
index f9755605ca..ac8e610956 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
@@ -261,10 +261,11 @@ abstract class ProjectExecTransformerBase(val list: 
Seq[NamedExpression], val in
   }
 }
 
-// An alternatives for UnionExec.
+// An alternative for UnionExec.
 case class ColumnarUnionExec(children: Seq[SparkPlan]) extends ValidatablePlan 
{
   children.foreach {
     case w: WholeStageTransformer =>
+      // FIXME: Setting state on child plans here breaks plan immutability; avoid it.
       w.setOutputSchemaForPlan(output)
     case _ =>
   }
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala
new file mode 100644
index 0000000000..d27558746a
--- /dev/null
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.expression.ConverterUtils
+import org.apache.gluten.extension.ValidationResult
+import org.apache.gluten.metrics.MetricsUpdater
+import org.apache.gluten.substrait.`type`.TypeBuilder
+import org.apache.gluten.substrait.SubstraitContext
+import org.apache.gluten.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder}
+import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.{SparkPlan, UnionExec}
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import io.substrait.proto.SetRel.SetOp
+
+import scala.collection.JavaConverters._
+
+/**
+ * Transformer for UnionExec. Note: Spark's UnionExec represents a SQL UNION ALL.
+ *
+ * All children become inputs of a single Substrait SetRel with op SET_OP_UNION_ALL.
+ */
+case class UnionExecTransformer(children: Seq[SparkPlan]) extends TransformSupport {
+  // Vanilla UnionExec is kept only to derive this node's output attributes.
+  private val union = UnionExec(children)
+
+  // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
+  @transient override lazy val metrics: Map[String, SQLMetric] =
+    BackendsApiManager.getMetricsApiInstance.genUnionTransformerMetrics(sparkContext)
+
+  override def output: Seq[Attribute] = union.output
+
+  override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = children.flatMap(getColumnarInputRDDs)
+
+  override def metricsUpdater(): MetricsUpdater =
+    BackendsApiManager.getMetricsApiInstance.genUnionTransformerMetricsUpdater(metrics)
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan =
+    copy(children = newChildren)
+
+  override protected def doValidateInternal(): ValidationResult = {
+    val context = new SubstraitContext
+    val operatorId = context.nextOperatorId(this.nodeName)
+    // Children are not transformed during validation, so the SetRel has no inputs; their column
+    // types are carried by the advanced extension instead.
+    val relNode =
+      getRelNode(context, operatorId, children.map(_.output), Seq.empty, validation = true)
+    doNativeValidation(context, relNode)
+  }
+
+  override protected def doTransform(context: SubstraitContext): TransformContext = {
+    val childrenCtx = children.map(_.asInstanceOf[TransformSupport].transform(context))
+    val operatorId = context.nextOperatorId(this.nodeName)
+    val relNode = getRelNode(
+      context,
+      operatorId,
+      children.map(_.output),
+      childrenCtx.map(_.root),
+      validation = false)
+    TransformContext(output, relNode)
+  }
+
+  /**
+   * Builds the UNION ALL SetRel over `inputs`.
+   *
+   * @param inputAttributes output attributes of each child, in child order
+   * @param inputs Substrait rels of the children; empty during validation
+   * @param validation whether to attach the children's column types as an advanced extension,
+   *   which the native validator uses to check that all children share one row type
+   */
+  private def getRelNode(
+      context: SubstraitContext,
+      operatorId: Long,
+      inputAttributes: Seq[Seq[Attribute]],
+      inputs: Seq[RelNode],
+      validation: Boolean): RelNode =
+    if (validation) {
+      RelBuilder.makeSetRel(
+        inputs.asJava,
+        SetOp.SET_OP_UNION_ALL,
+        inputTypesExtension(inputAttributes),
+        context,
+        operatorId)
+    } else {
+      RelBuilder.makeSetRel(inputs.asJava, SetOp.SET_OP_UNION_ALL, context, operatorId)
+    }
+
+  /** Encodes each child's column types as a struct, nested inside one outer struct. */
+  private def inputTypesExtension(inputAttributes: Seq[Seq[Attribute]]): AdvancedExtensionNode = {
+    val perChildStructs = inputAttributes.map {
+      attributes =>
+        val typeNodes =
+          attributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
+        TypeBuilder.makeStruct(false, typeNodes.asJava)
+    }
+    ExtensionBuilder.makeAdvancedExtension(
+      BackendsApiManager.getTransformerApiInstance.packPBMessage(
+        TypeBuilder.makeStruct(false, perChildStructs.asJava).toProtobuf))
+  }
+}
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala
new file mode 100644
index 0000000000..f0eea08018
--- /dev/null
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar
+
+import org.apache.gluten.GlutenConfig
+import org.apache.gluten.execution.{ColumnarUnionExec, UnionExecTransformer}
+
+import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * Replace ColumnarUnionExec with UnionExecTransformer if possible.
+ *
+ * The rule is not included in [[org.apache.gluten.extension.columnar.heuristic.HeuristicTransform]]
+ * or [[org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform]] because it relies on
+ * children's output partitioning to be fully provided.
+ *
+ * It only applies when `spark.gluten.sql.native.union` is enabled, no child reports
+ * UnknownPartitioning, all children have the same number of partitions, and validation passes.
+ */
+case class UnionTransformerRule() extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan =
+    if (!GlutenConfig.getConf.enableNativeUnion) {
+      plan
+    } else {
+      plan.transformUp {
+        case columnarUnion: ColumnarUnionExec =>
+          val transformer = UnionExecTransformer(columnarUnion.children)
+          // Short-circuit: skip native validation when partition counts already disqualify it.
+          if (sameNumPartitions(columnarUnion.children) && validate(transformer)) transformer
+          else columnarUnion
+      }
+    }
+
+  /** True when no child has UnknownPartitioning and all children have equal partition counts. */
+  private def sameNumPartitions(plans: Seq[SparkPlan]): Boolean = {
+    val partitionings = plans.map(_.outputPartitioning)
+    val allKnown = !partitionings.exists(_.isInstanceOf[UnknownPartitioning])
+    allKnown && partitionings.map(_.numPartitions).distinct.size <= 1
+  }
+
+  private def validate(union: UnionExecTransformer): Boolean = union.doValidate().ok()
+}
diff --git 
a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
 
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
index fd250834d0..08081fadb5 100644
--- 
a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
+++ 
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
@@ -176,25 +176,39 @@ abstract class WholeStageTransformerSuite
       result
   }
 
+  /** Variant of [[compareDfResultsAgainstVanillaSpark]] that takes the query as a SQL string. */
+  protected def compareResultsAgainstVanillaSpark(
+      sql: String,
+      compareResult: Boolean = true,
+      customCheck: DataFrame => Unit,
+      noFallBack: Boolean = true,
+      cache: Boolean = false): DataFrame =
+    compareDfResultsAgainstVanillaSpark(
+      dataframe = () => spark.sql(sql),
+      compareResult = compareResult,
+      customCheck = customCheck,
+      noFallBack = noFallBack,
+      cache = cache)
+
   /**
    * run a query with native engine as well as vanilla spark then compare the 
result set for
    * correctness check
    */
-  protected def compareResultsAgainstVanillaSpark(
-      sqlStr: String,
+  protected def compareDfResultsAgainstVanillaSpark(
+      dataframe: () => DataFrame,
       compareResult: Boolean = true,
       customCheck: DataFrame => Unit,
       noFallBack: Boolean = true,
       cache: Boolean = false): DataFrame = {
     var expected: Seq[Row] = null
     withSQLConf(vanillaSparkConfs(): _*) {
-      val df = spark.sql(sqlStr)
+      val df = dataframe()
       expected = df.collect()
     }
-    // By default we will fallabck complex type scan but here we should allow
+    // By default, complex type scans fall back, but here we should allow them
     // to test support of complex type
     spark.conf.set("spark.gluten.sql.complexType.scan.fallback.enabled", 
"false");
-    val df = spark.sql(sqlStr)
+    val df = dataframe()
     if (cache) {
       df.cache()
     }
@@ -239,7 +253,12 @@ abstract class WholeStageTransformerSuite
       noFallBack: Boolean = true,
       cache: Boolean = false)(customCheck: DataFrame => Unit): DataFrame = {
 
-    compareResultsAgainstVanillaSpark(sqlStr, compareResult, customCheck, 
noFallBack, cache)
+    compareDfResultsAgainstVanillaSpark(
+      () => spark.sql(sqlStr),
+      compareResult,
+      customCheck,
+      noFallBack,
+      cache)
   }
 
   /**
@@ -256,8 +275,8 @@ abstract class WholeStageTransformerSuite
       customCheck: DataFrame => Unit,
       noFallBack: Boolean = true,
       compareResult: Boolean = true): Unit =
-    compareResultsAgainstVanillaSpark(
-      tpchSQL(queryNum, tpchQueries),
+    compareDfResultsAgainstVanillaSpark(
+      () => spark.sql(tpchSQL(queryNum, tpchQueries)),
       compareResult,
       customCheck,
       noFallBack)
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala 
b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index f643ad7eed..9ae4c0ce90 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -86,6 +86,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
 
   def enableColumnarUnion: Boolean = conf.getConf(COLUMNAR_UNION_ENABLED)
 
+  def enableNativeUnion: Boolean = conf.getConf(NATIVE_UNION_ENABLED)
+
   def enableColumnarExpand: Boolean = conf.getConf(COLUMNAR_EXPAND_ENABLED)
 
   def enableColumnarBroadcastExchange: Boolean = 
conf.getConf(COLUMNAR_BROADCAST_EXCHANGE_ENABLED)
@@ -1022,6 +1024,13 @@ object GlutenConfig {
       .booleanConf
       .createWithDefault(true)
 
+  val NATIVE_UNION_ENABLED =
+    buildConf("spark.gluten.sql.native.union")
+      .internal()
+      .doc("Enable or disable native union where computation is completely 
offloaded to backend.")
+      .booleanConf
+      .createWithDefault(false) // Off by default; see UnionTransformerRule.
+
   val COLUMNAR_EXPAND_ENABLED =
     buildConf("spark.gluten.sql.columnar.expand")
       .internal()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to