This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 6dd91bac6c [GLUTEN-7745][VL] Incorporate SQL Union operator into Velox
execution pipeline (#7842)
6dd91bac6c is described below
commit 6dd91bac6cc155810beb34415a6dee1a03069f6d
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue Dec 3 09:40:46 2024 +0800
[GLUTEN-7745][VL] Incorporate SQL Union operator into Velox execution
pipeline (#7842)
---
.../backendsapi/clickhouse/CHMetricsApi.scala | 8 ++
.../org/apache/gluten/utils/GlutenURLDecoder.java | 2 +-
.../backendsapi/velox/VeloxListenerApi.scala | 3 +-
.../gluten/backendsapi/velox/VeloxMetricsApi.scala | 11 ++
.../gluten/backendsapi/velox/VeloxRuleApi.scala | 2 +
.../org/apache/gluten/metrics/MetricsUtil.scala | 9 +-
.../gluten/metrics/UnionMetricsUpdater.scala | 35 +++++
.../gluten/execution/MiscOperatorSuite.scala | 28 +++-
.../VeloxOrcDataTypeValidationSuite.scala | 5 +-
.../VeloxParquetDataTypeValidationSuite.scala | 5 +-
cpp/velox/compute/WholeStageResultIterator.cc | 44 +++++-
cpp/velox/substrait/SubstraitToVeloxPlan.cc | 46 +++++++
cpp/velox/substrait/SubstraitToVeloxPlan.h | 3 +
.../substrait/SubstraitToVeloxPlanValidator.cc | 153 ++++++++++++++++-----
.../substrait/SubstraitToVeloxPlanValidator.h | 15 +-
.../apache/gluten/substrait/rel/RelBuilder.java | 17 +++
.../apache/gluten/substrait/rel/SetRelNode.java | 62 +++++++++
.../org/apache/gluten/backendsapi/MetricsApi.scala | 4 +
.../BasicPhysicalOperatorTransformer.scala | 3 +-
.../gluten/execution/UnionExecTransformer.scala | 99 +++++++++++++
.../extension/columnar/UnionTransformerRule.scala | 61 ++++++++
.../execution/WholeStageTransformerSuite.scala | 35 +++--
.../scala/org/apache/gluten/GlutenConfig.scala | 9 ++
23 files changed, 604 insertions(+), 55 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
index 73b2d0f211..a0576a807b 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
@@ -450,6 +450,14 @@ class CHMetricsApi extends MetricsApi with Logging with
LogLevelUtil {
s"SampleTransformer metrics update is not supported in CH backend")
}
+ override def genUnionTransformerMetrics(sparkContext: SparkContext):
Map[String, SQLMetric] =
+ throw new UnsupportedOperationException(
+ "UnionExecTransformer metrics update is not supported in CH backend")
+
+ override def genUnionTransformerMetricsUpdater(metrics: Map[String,
SQLMetric]): MetricsUpdater =
+ throw new UnsupportedOperationException(
+ "UnionExecTransformer metrics update is not supported in CH backend")
+
def genWriteFilesTransformerMetrics(sparkContext: SparkContext): Map[String,
SQLMetric] =
Map(
"physicalWrittenBytes" -> SQLMetrics.createMetric(sparkContext, "number
of written bytes"),
diff --git
a/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java
b/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java
index 9228a2f860..856ddf1597 100644
--- a/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java
+++ b/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java
@@ -31,7 +31,7 @@ public class GlutenURLDecoder {
* <p><em><strong>Note:</strong> The <a href=
* "http://www.w3.org/TR/html40/appendix/notes.html#non-ascii-chars"> World
Wide Web Consortium
* Recommendation</a> states that UTF-8 should be used. Not doing so may
introduce
- * incompatibilites.</em>
+ * incompatibilities.</em>
*
* @param s the <code>String</code> to decode
* @param enc The name of a supported <a
href="../lang/package-summary.html#charenc">character
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
index d29d302970..3a82abe618 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
@@ -31,6 +31,7 @@ import org.apache.spark.{HdfsConfGenerator, SparkConf,
SparkContext}
import org.apache.spark.api.plugin.PluginContext
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.sql.execution.ColumnarCachedBatchSerializer
import org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules
import
org.apache.spark.sql.execution.datasources.velox.{VeloxParquetWriterInjects,
VeloxRowSplitter}
import org.apache.spark.sql.expression.UDFResolver
@@ -75,7 +76,7 @@ class VeloxListenerApi extends ListenerApi with Logging {
if (conf.getBoolean(GlutenConfig.COLUMNAR_TABLE_CACHE_ENABLED.key,
defaultValue = false)) {
conf.set(
StaticSQLConf.SPARK_CACHE_SERIALIZER.key,
- "org.apache.spark.sql.execution.ColumnarCachedBatchSerializer")
+ classOf[ColumnarCachedBatchSerializer].getName)
}
// Static initializers for driver.
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
index e70e1d13bd..934b680382 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
@@ -582,4 +582,15 @@ class VeloxMetricsApi extends MetricsApi with Logging {
override def genSampleTransformerMetricsUpdater(metrics: Map[String,
SQLMetric]): MetricsUpdater =
new SampleMetricsUpdater(metrics)
+
+ override def genUnionTransformerMetrics(sparkContext: SparkContext):
Map[String, SQLMetric] = Map(
+ "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input
rows"),
+ "inputVectors" -> SQLMetrics.createMetric(sparkContext, "number of input
vectors"),
+ "inputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of input
bytes"),
+ "wallNanos" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time of
union"),
+ "cpuCount" -> SQLMetrics.createMetric(sparkContext, "cpu wall time count")
+ )
+
+ override def genUnionTransformerMetricsUpdater(metrics: Map[String,
SQLMetric]): MetricsUpdater =
+ new UnionMetricsUpdater(metrics)
}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index 7841e6cd94..7337be5737 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -92,6 +92,7 @@ object VeloxRuleApi {
c => HeuristicTransform.Single(validatorBuilder(c.glutenConf), rewrites,
offloads))
// Legacy: Post-transform rules.
+ injector.injectPostTransform(_ => UnionTransformerRule())
injector.injectPostTransform(c => PartialProjectRule.apply(c.session))
injector.injectPostTransform(_ => RemoveNativeWriteFilesSortAndProject())
injector.injectPostTransform(c => RewriteTransformer.apply(c.session))
@@ -178,6 +179,7 @@ object VeloxRuleApi {
// Gluten RAS: Post rules.
injector.injectPostTransform(_ => RemoveTransitions)
+ injector.injectPostTransform(_ => UnionTransformerRule())
injector.injectPostTransform(c => PartialProjectRule.apply(c.session))
injector.injectPostTransform(_ => RemoveNativeWriteFilesSortAndProject())
injector.injectPostTransform(c => RewriteTransformer.apply(c.session))
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala
b/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala
index cd50d0b8e2..b8ef1620f9 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala
@@ -58,7 +58,8 @@ object MetricsUtil extends Logging {
assert(t.children.size == 1, "MetricsUpdater.None can only be used
on unary operator")
treeifyMetricsUpdaters(t.children.head)
case t: TransformSupport =>
- MetricsUpdaterTree(t.metricsUpdater(),
t.children.map(treeifyMetricsUpdaters))
+ // Reversed children order to match the traversal code.
+ MetricsUpdaterTree(t.metricsUpdater(),
t.children.reverse.map(treeifyMetricsUpdaters))
case _ =>
MetricsUpdaterTree(MetricsUpdater.Terminate, Seq())
}
@@ -233,6 +234,12 @@ object MetricsUtil extends Logging {
operatorMetrics,
metrics.getSingleMetrics,
joinParamsMap.get(operatorIdx))
+ case u: UnionMetricsUpdater =>
+          // Union was interpreted as LocalExchange + LocalPartition, which outputs two suites of metrics.
+          // Therefore, fetch one more suite of metrics here.
+ operatorMetrics.add(metrics.getOperatorMetrics(curMetricsIdx))
+ curMetricsIdx -= 1
+ u.updateUnionMetrics(operatorMetrics)
case hau: HashAggregateMetricsUpdater =>
hau.updateAggregationMetrics(operatorMetrics,
aggParamsMap.get(operatorIdx))
case lu: LimitMetricsUpdater =>
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala
b/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala
new file mode 100644
index 0000000000..9e91cf368c
--- /dev/null
+++
b/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.metrics
+
+import org.apache.spark.sql.execution.metric.SQLMetric
+
+class UnionMetricsUpdater(val metrics: Map[String, SQLMetric]) extends
MetricsUpdater {
+ override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = {
+ throw new UnsupportedOperationException()
+ }
+
+ def updateUnionMetrics(unionMetrics: java.util.ArrayList[OperatorMetrics]):
Unit = {
+    // Union was interpreted as LocalExchange + LocalPartition. Use metrics from LocalExchange.
+ val localExchangeMetrics = unionMetrics.get(0)
+ metrics("numInputRows") += localExchangeMetrics.inputRows
+ metrics("inputVectors") += localExchangeMetrics.inputVectors
+ metrics("inputBytes") += localExchangeMetrics.inputBytes
+ metrics("cpuCount") += localExchangeMetrics.cpuCount
+ metrics("wallNanos") += localExchangeMetrics.wallNanos
+ }
+}
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
index 5cb2b65260..8063a5d122 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
@@ -537,11 +537,37 @@ class MiscOperatorSuite extends
VeloxWholeStageTransformerSuite with AdaptiveSpa
|""".stripMargin) {
df =>
{
- getExecutedPlan(df).exists(plan =>
plan.find(_.isInstanceOf[ColumnarUnionExec]).isDefined)
+ assert(
+ getExecutedPlan(df).exists(
+ plan => plan.find(_.isInstanceOf[ColumnarUnionExec]).isDefined))
}
}
}
+ test("union_all two tables with known partitioning") {
+ withSQLConf(GlutenConfig.NATIVE_UNION_ENABLED.key -> "true") {
+ compareDfResultsAgainstVanillaSpark(
+ () => {
+ val df1 = spark.sql("select l_orderkey as orderkey from lineitem")
+ val df2 = spark.sql("select o_orderkey as orderkey from orders")
+ df1.repartition(5).union(df2.repartition(5))
+ },
+ compareResult = true,
+ checkGlutenOperatorMatch[UnionExecTransformer]
+ )
+
+ compareDfResultsAgainstVanillaSpark(
+ () => {
+ val df1 = spark.sql("select l_orderkey as orderkey from lineitem")
+ val df2 = spark.sql("select o_orderkey as orderkey from orders")
+ df1.repartition(5).union(df2.repartition(6))
+ },
+ compareResult = true,
+ checkGlutenOperatorMatch[ColumnarUnionExec]
+ )
+ }
+ }
+
test("union_all three tables") {
runQueryAndCompare("""
|select count(orderkey) from (
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala
index 24e04f2dfc..6ac59ba4fa 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala
@@ -255,7 +255,10 @@ class VeloxOrcDataTypeValidationSuite extends
VeloxWholeStageTransformerSuite {
|""".stripMargin) {
df =>
{
- assert(getExecutedPlan(df).exists(plan =>
plan.isInstanceOf[ColumnarUnionExec]))
+ assert(
+ getExecutedPlan(df).exists(
+ plan =>
+ plan.isInstanceOf[ColumnarUnionExec] ||
plan.isInstanceOf[UnionExecTransformer]))
}
}
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
index 57ca448fec..cb5614f396 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala
@@ -254,7 +254,10 @@ class VeloxParquetDataTypeValidationSuite extends
VeloxWholeStageTransformerSuit
|""".stripMargin) {
df =>
{
- assert(getExecutedPlan(df).exists(plan =>
plan.isInstanceOf[ColumnarUnionExec]))
+ assert(
+ getExecutedPlan(df).exists(
+ plan =>
+ plan.isInstanceOf[ColumnarUnionExec] ||
plan.isInstanceOf[UnionExecTransformer]))
}
}
diff --git a/cpp/velox/compute/WholeStageResultIterator.cc
b/cpp/velox/compute/WholeStageResultIterator.cc
index b6ecbd959f..411c6c5636 100644
--- a/cpp/velox/compute/WholeStageResultIterator.cc
+++ b/cpp/velox/compute/WholeStageResultIterator.cc
@@ -91,7 +91,7 @@ WholeStageResultIterator::WholeStageResultIterator(
std::move(queryCtx),
velox::exec::Task::ExecutionMode::kSerial);
if (!task_->supportSerialExecutionMode()) {
- throw std::runtime_error("Task doesn't support single thread execution: "
+ planNode->toString());
+ throw std::runtime_error("Task doesn't support single threaded execution:
" + planNode->toString());
}
auto fileSystem = velox::filesystems::getFileSystem(spillDir, nullptr);
GLUTEN_CHECK(fileSystem != nullptr, "File System for spilling is null!");
@@ -248,15 +248,47 @@ void WholeStageResultIterator::getOrderedNodeIds(
const std::shared_ptr<const velox::core::PlanNode>& planNode,
std::vector<velox::core::PlanNodeId>& nodeIds) {
bool isProjectNode = (std::dynamic_pointer_cast<const
velox::core::ProjectNode>(planNode) != nullptr);
+ bool isLocalExchangeNode = (std::dynamic_pointer_cast<const
velox::core::LocalPartitionNode>(planNode) != nullptr);
+ bool isUnionNode = isLocalExchangeNode &&
+ std::dynamic_pointer_cast<const
velox::core::LocalPartitionNode>(planNode)->type() ==
+ velox::core::LocalPartitionNode::Type::kGather;
const auto& sourceNodes = planNode->sources();
- for (const auto& sourceNode : sourceNodes) {
+ if (isProjectNode) {
+ GLUTEN_CHECK(sourceNodes.size() == 1, "Illegal state");
+ const auto sourceNode = sourceNodes.at(0);
// Filter over Project are mapped into FilterProject operator in Velox.
// Metrics are all applied on Project node, and the metrics for Filter node
// do not exist.
- if (isProjectNode && std::dynamic_pointer_cast<const
velox::core::FilterNode>(sourceNode)) {
+ if (std::dynamic_pointer_cast<const velox::core::FilterNode>(sourceNode)) {
omittedNodeIds_.insert(sourceNode->id());
}
getOrderedNodeIds(sourceNode, nodeIds);
+ nodeIds.emplace_back(planNode->id());
+ return;
+ }
+
+ if (isUnionNode) {
+ // FIXME: The whole metrics system in gluten-substrait is magic. Passing
metrics trees through JNI with a trivial
+  // array is possible but requires a solid design. Apparently we haven't had it. All the code requires complete
+ // rework.
+ // Union was interpreted as LocalPartition + LocalExchange + 2 fake
projects as children in Velox. So we only fetch
+ // metrics from the root node.
+ std::vector<std::shared_ptr<const velox::core::PlanNode>> unionChildren{};
+ for (const auto& source : planNode->sources()) {
+ const auto projectedChild = std::dynamic_pointer_cast<const
velox::core::ProjectNode>(source);
+ GLUTEN_CHECK(projectedChild != nullptr, "Illegal state");
+ const auto projectSources = projectedChild->sources();
+ GLUTEN_CHECK(projectSources.size() == 1, "Illegal state");
+ const auto projectSource = projectSources.at(0);
+ getOrderedNodeIds(projectSource, nodeIds);
+ }
+ nodeIds.emplace_back(planNode->id());
+ return;
+ }
+
+ for (const auto& sourceNode : sourceNodes) {
+ // Post-order traversal.
+ getOrderedNodeIds(sourceNode, nodeIds);
}
nodeIds.emplace_back(planNode->id());
}
@@ -350,9 +382,9 @@ void WholeStageResultIterator::collectMetrics() {
continue;
}
- const auto& status = planStats.at(nodeId);
- // Add each operator status into metrics.
- for (const auto& entry : status.operatorStats) {
+ const auto& stats = planStats.at(nodeId);
+ // Add each operator stats into metrics.
+ for (const auto& entry : stats.operatorStats) {
const auto& second = entry.second;
metrics_->get(Metrics::kInputRows)[metricIndex] = second->inputRows;
metrics_->get(Metrics::kInputVectors)[metricIndex] =
second->inputVectors;
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index 1efa733879..3ceccca4a3 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -1043,6 +1043,50 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(
childNode);
}
+core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const
::substrait::SetRel& setRel) {
+ switch (setRel.op()) {
+ case ::substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_UNION_ALL: {
+ std::vector<core::PlanNodePtr> children;
+ for (int32_t i = 0; i < setRel.inputs_size(); ++i) {
+ const auto& input = setRel.inputs(i);
+ children.push_back(toVeloxPlan(input));
+ }
+ GLUTEN_CHECK(!children.empty(), "At least one source is required for
Velox LocalPartition");
+
+ // Velox doesn't allow different field names in schemas of
LocalPartitionNode's children.
+ // Add project nodes to unify the schemas.
+ const RowTypePtr outRowType = asRowType(children[0]->outputType());
+ std::vector<std::string> outNames;
+ for (int32_t colIdx = 0; colIdx < outRowType->size(); ++colIdx) {
+ const auto name = outRowType->childAt(colIdx)->name();
+ outNames.push_back(name);
+ }
+
+ std::vector<core::PlanNodePtr> projectedChildren;
+ for (int32_t i = 0; i < children.size(); ++i) {
+ const auto& child = children[i];
+ const RowTypePtr& childRowType = child->outputType();
+ std::vector<core::TypedExprPtr> expressions;
+ for (int32_t colIdx = 0; colIdx < outNames.size(); ++colIdx) {
+ const auto fa =
+
std::make_shared<core::FieldAccessTypedExpr>(childRowType->childAt(colIdx),
childRowType->nameOf(colIdx));
+ const auto cast =
std::make_shared<core::CastTypedExpr>(outRowType->childAt(colIdx), fa, false);
+ expressions.push_back(cast);
+ }
+ auto project = std::make_shared<core::ProjectNode>(nextPlanNodeId(),
outNames, expressions, child);
+ projectedChildren.push_back(project);
+ }
+ return std::make_shared<core::LocalPartitionNode>(
+ nextPlanNodeId(),
+ core::LocalPartitionNode::Type::kGather,
+ std::make_shared<core::GatherPartitionFunctionSpec>(),
+ projectedChildren);
+ }
+ default:
+ throw GlutenException("Unsupported SetRel op: " +
std::to_string(setRel.op()));
+ }
+}
+
core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const
::substrait::SortRel& sortRel) {
auto childNode = convertSingleInput<::substrait::SortRel>(sortRel);
auto [sortingKeys, sortingOrders] = processSortField(sortRel.sorts(),
childNode->outputType());
@@ -1298,6 +1342,8 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
return toVeloxPlan(rel.write());
} else if (rel.has_windowgrouplimit()) {
return toVeloxPlan(rel.windowgrouplimit());
+ } else if (rel.has_set()) {
+ return toVeloxPlan(rel.set());
} else {
VELOX_NYI("Substrait conversion not supported for Rel.");
}
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.h
b/cpp/velox/substrait/SubstraitToVeloxPlan.h
index 51e50ce347..6121923df7 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.h
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.h
@@ -84,6 +84,9 @@ class SubstraitToVeloxPlanConverter {
/// Used to convert Substrait WindowGroupLimitRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::WindowGroupLimitRel&
windowGroupLimitRel);
+ /// Used to convert Substrait SetRel into Velox PlanNode.
+ core::PlanNodePtr toVeloxPlan(const ::substrait::SetRel& setRel);
+
/// Used to convert Substrait JoinRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::JoinRel& joinRel);
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
index 3b74caf8ba..9325fed321 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -22,7 +22,6 @@
#include "TypeUtils.h"
#include "udf/UdfLoader.h"
#include "utils/Common.h"
-#include "velox/core/ExpressionEvaluator.h"
#include "velox/exec/Aggregate.h"
#include "velox/expression/Expr.h"
#include "velox/expression/SignatureBinder.h"
@@ -30,7 +29,7 @@
namespace gluten {
namespace {
-static const char* extractFileName(const char* file) {
+const char* extractFileName(const char* file) {
return strrchr(file, '/') ? strrchr(file, '/') + 1 : file;
}
@@ -53,13 +52,13 @@ static const char* extractFileName(const char* file) {
__FUNCTION__, \
reason))
-static const std::unordered_set<std::string> kRegexFunctions = {
+const std::unordered_set<std::string> kRegexFunctions = {
"regexp_extract",
"regexp_extract_all",
"regexp_replace",
"rlike"};
-static const std::unordered_set<std::string> kBlackList = {
+const std::unordered_set<std::string> kBlackList = {
"split_part",
"factorial",
"concat_ws",
@@ -70,32 +69,59 @@ static const std::unordered_set<std::string> kBlackList = {
"approx_percentile",
"get_array_struct_fields",
"map_from_arrays"};
-
} // namespace
-bool SubstraitToVeloxPlanValidator::validateInputTypes(
+bool SubstraitToVeloxPlanValidator::parseVeloxType(
const ::substrait::extensions::AdvancedExtension& extension,
- std::vector<TypePtr>& types) {
+ TypePtr& out) {
+ ::substrait::Type substraitType;
// The input type is wrapped in enhancement.
if (!extension.has_enhancement()) {
LOG_VALIDATION_MSG("Input type is not wrapped in enhancement.");
return false;
}
const auto& enhancement = extension.enhancement();
- ::substrait::Type inputType;
- if (!enhancement.UnpackTo(&inputType)) {
+ if (!enhancement.UnpackTo(&substraitType)) {
LOG_VALIDATION_MSG("Enhancement can't be unpacked to inputType.");
return false;
}
- if (!inputType.has_struct_()) {
- LOG_VALIDATION_MSG("Input type has no struct.");
+
+ out = SubstraitParser::parseType(substraitType);
+ return true;
+}
+
+bool SubstraitToVeloxPlanValidator::flattenVeloxType1(const TypePtr& type,
std::vector<TypePtr>& out) {
+ if (type->kind() != TypeKind::ROW) {
+ LOG_VALIDATION_MSG("Type is not a RowType.");
+ return false;
+ }
+ auto rowType = std::dynamic_pointer_cast<const RowType>(type);
+ if (!rowType) {
+ LOG_VALIDATION_MSG("Failed to cast to RowType.");
return false;
}
+ for (const auto& field : rowType->children()) {
+ out.emplace_back(field);
+ }
+ return true;
+}
- // Get the input types.
- const auto& sTypes = inputType.struct_().types();
- for (const auto& sType : sTypes) {
- types.emplace_back(SubstraitParser::parseType(sType));
+bool SubstraitToVeloxPlanValidator::flattenVeloxType2(const TypePtr& type,
std::vector<std::vector<TypePtr>>& out) {
+ if (type->kind() != TypeKind::ROW) {
+ LOG_VALIDATION_MSG("Type is not a RowType.");
+ return false;
+ }
+ auto rowType = std::dynamic_pointer_cast<const RowType>(type);
+ if (!rowType) {
+ LOG_VALIDATION_MSG("Failed to cast to RowType.");
+ return false;
+ }
+ for (const auto& field : rowType->children()) {
+ std::vector<TypePtr> inner;
+ if (!flattenVeloxType1(field, inner)) {
+ return false;
+ }
+ out.emplace_back(inner);
}
return true;
}
@@ -341,10 +367,11 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::WriteRel& writeR
}
// Validate input data type.
+ TypePtr inputRowType;
std::vector<TypePtr> types;
if (writeRel.has_named_table()) {
const auto& extension = writeRel.named_table().advanced_extension();
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) ||
!flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input type validation in
WriteRel.");
return false;
}
@@ -380,12 +407,12 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::WriteRel& writeR
}
bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FetchRel&
fetchRel) {
- RowTypePtr rowType = nullptr;
// Get and validate the input types from extension.
if (fetchRel.has_advanced_extension()) {
const auto& extension = fetchRel.advanced_extension();
+ TypePtr inputRowType;
std::vector<TypePtr> types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) ||
!flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Unsupported input types in FetchRel.");
return false;
}
@@ -396,7 +423,6 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::FetchRel& fetchR
for (auto colIdx = 0; colIdx < types.size(); colIdx++) {
names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId,
colIdx));
}
- rowType = std::make_shared<RowType>(std::move(names), std::move(types));
}
if (fetchRel.offset() < 0 || fetchRel.count() < 0) {
@@ -412,8 +438,9 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::TopNRel& topNRel
// Get and validate the input types from extension.
if (topNRel.has_advanced_extension()) {
const auto& extension = topNRel.advanced_extension();
+ TypePtr inputRowType;
std::vector<TypePtr> types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) ||
!flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Unsupported input types in TopNRel.");
return false;
}
@@ -457,8 +484,9 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::GenerateRel& gen
return false;
}
const auto& extension = generateRel.advanced_extension();
+ TypePtr inputRowType;
std::vector<TypePtr> types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) ||
!flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in GenerateRel.");
return false;
}
@@ -487,8 +515,9 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::ExpandRel& expan
// Get and validate the input types from extension.
if (expandRel.has_advanced_extension()) {
const auto& extension = expandRel.advanced_extension();
+ TypePtr inputRowType;
std::vector<TypePtr> types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) ||
!flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Unsupported input types in ExpandRel.");
return false;
}
@@ -571,8 +600,9 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::WindowRel& windo
return false;
}
const auto& extension = windowRel.advanced_extension();
+ TypePtr inputRowType;
std::vector<TypePtr> types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) ||
!flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in WindowRel.");
return false;
}
@@ -680,7 +710,7 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::WindowRel& windo
LOG_VALIDATION_MSG("in windowRel, the sorting key in Sort Operator
only support field.");
return false;
}
- exec::ExprSet exprSet({std::move(expression)}, execCtx_);
+ exec::ExprSet exprSet1({std::move(expression)}, execCtx_);
}
}
@@ -699,8 +729,9 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::WindowGroupLimit
return false;
}
const auto& extension = windowGroupLimitRel.advanced_extension();
+ TypePtr inputRowType;
std::vector<TypePtr> types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) ||
!flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in
WindowGroupLimitRel.");
return false;
}
@@ -750,13 +781,61 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::WindowGroupLimit
LOG_VALIDATION_MSG("in windowGroupLimitRel, the sorting key in Sort
Operator only support field.");
return false;
}
- exec::ExprSet exprSet({std::move(expression)}, execCtx_);
+ exec::ExprSet exprSet1({std::move(expression)}, execCtx_);
}
}
return true;
}
+bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SetRel&
setRel) {
+ switch (setRel.op()) {
+ case ::substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_UNION_ALL: {
+ for (int32_t i = 0; i < setRel.inputs_size(); ++i) {
+ const auto& input = setRel.inputs(i);
+ if (!validate(input)) {
+          LOG_VALIDATION_MSG("SetRel input");
+ return false;
+ }
+ }
+ if (!setRel.has_advanced_extension()) {
+ LOG_VALIDATION_MSG("Input types are expected in SetRel.");
+ return false;
+ }
+ const auto& extension = setRel.advanced_extension();
+ TypePtr inputRowType;
+ std::vector<std::vector<TypePtr>> childrenTypes;
+ if (!parseVeloxType(extension, inputRowType) ||
!flattenVeloxType2(inputRowType, childrenTypes)) {
+ LOG_VALIDATION_MSG("Validation failed for input types in SetRel.");
+ return false;
+ }
+ std::vector<RowTypePtr> childrenRowTypes;
+ for (auto i = 0; i < childrenTypes.size(); ++i) {
+ auto& types = childrenTypes.at(i);
+ std::vector<std::string> names;
+ names.reserve(types.size());
+ for (auto colIdx = 0; colIdx < types.size(); colIdx++) {
+ names.emplace_back(SubstraitParser::makeNodeName(i, colIdx));
+ }
+ childrenRowTypes.push_back(std::make_shared<RowType>(std::move(names),
std::move(types)));
+ }
+
+ for (auto i = 1; i < childrenRowTypes.size(); ++i) {
+ if (!(childrenRowTypes[i]->equivalent(*childrenRowTypes[0]))) {
+ LOG_VALIDATION_MSG(
+ "All sources of the Set operation must have the same output
type: " + childrenRowTypes[i]->toString() +
+ " vs. " + childrenRowTypes[0]->toString());
+ return false;
+ }
+ }
+ return true;
+ }
+ default:
+ LOG_VALIDATION_MSG("Unsupported SetRel op: " +
std::to_string(setRel.op()));
+ return false;
+ }
+}
+
bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SortRel&
sortRel) {
if (sortRel.has_input() && !validate(sortRel.input())) {
return false;
@@ -769,8 +848,9 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::SortRel& sortRel
}
const auto& extension = sortRel.advanced_extension();
+ TypePtr inputRowType;
std::vector<TypePtr> types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) ||
!flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in SortRel.");
return false;
}
@@ -822,8 +902,9 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::ProjectRel& proj
return false;
}
const auto& extension = projectRel.advanced_extension();
+ TypePtr inputRowType;
std::vector<TypePtr> types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) ||
!flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in ProjectRel.");
return false;
}
@@ -865,8 +946,9 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::FilterRel& filte
return false;
}
const auto& extension = filterRel.advanced_extension();
+ TypePtr inputRowType;
std::vector<TypePtr> types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) ||
!flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in FilterRel.");
return false;
}
@@ -938,8 +1020,9 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::JoinRel& joinRel
}
const auto& extension = joinRel.advanced_extension();
+ TypePtr inputRowType;
std::vector<TypePtr> types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) ||
!flattenVeloxType1(inputRowType, types)) {
LOG_VALIDATION_MSG("Validation failed for input types in JoinRel.");
return false;
}
@@ -991,8 +1074,9 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::CrossRel& crossR
}
const auto& extension = crossRel.advanced_extension();
+ TypePtr inputRowType;
std::vector<TypePtr> types;
- if (!validateInputTypes(extension, types)) {
+ if (!parseVeloxType(extension, inputRowType) ||
!flattenVeloxType1(inputRowType, types)) {
logValidateMsg("Native validation failed due to: Validation failed for
input types in CrossRel");
return false;
}
@@ -1070,11 +1154,13 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::AggregateRel& ag
// Validate input types.
if (aggRel.has_advanced_extension()) {
+ TypePtr inputRowType;
std::vector<TypePtr> types;
const auto& extension = aggRel.advanced_extension();
// Aggregate always has advanced extension for streaming aggregate
optimization,
// but only some of them have enhancement for validation.
- if (extension.has_enhancement() && !validateInputTypes(extension, types)) {
+ if (extension.has_enhancement() &&
+ (!parseVeloxType(extension, inputRowType) ||
!flattenVeloxType1(inputRowType, types))) {
LOG_VALIDATION_MSG("Validation failed for input types in AggregateRel.");
return false;
}
@@ -1266,7 +1352,10 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::Rel& rel) {
return validate(rel.write());
} else if (rel.has_windowgrouplimit()) {
return validate(rel.windowgrouplimit());
+ } else if (rel.has_set()) {
+ return validate(rel.set());
} else {
+ LOG_VALIDATION_MSG("Unsupported relation type: " + rel.GetTypeName());
return false;
}
}
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
index 1fe174928f..0c8d882ca0 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
+++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
@@ -61,6 +61,9 @@ class SubstraitToVeloxPlanValidator {
/// Used to validate whether the computing of this WindowGroupLimit is
supported.
bool validate(const ::substrait::WindowGroupLimitRel& windowGroupLimitRel);
+ /// Used to validate whether the computing of this Set is supported.
+ bool validate(const ::substrait::SetRel& setRel);
+
/// Used to validate whether the computing of this Aggregation is supported.
bool validate(const ::substrait::AggregateRel& aggRel);
@@ -103,9 +106,17 @@ class SubstraitToVeloxPlanValidator {
std::vector<std::string> validateLog_;
- /// Used to get types from advanced extension and validate them.
- bool validateInputTypes(const ::substrait::extensions::AdvancedExtension&
extension, std::vector<TypePtr>& types);
+ /// Used to get types from advanced extension and validate them, then
convert to a Velox type that has arbitrary
+ /// levels of nesting.
+ bool parseVeloxType(const ::substrait::extensions::AdvancedExtension&
extension, TypePtr& out);
+
+  /// Flattens a Velox type with a single level of nesting into a std::vector of
child types.
+ bool flattenVeloxType1(const TypePtr& type, std::vector<TypePtr>& out);
+
+  /// Flattens a Velox type with two levels of nesting into a dual-nested
std::vector of child types.
+ bool flattenVeloxType2(const TypePtr& type,
std::vector<std::vector<TypePtr>>& out);
+ /// Validate aggregate rel.
bool validateAggRelFunctionType(const ::substrait::AggregateRel&
substraitAgg);
/// Validate the round scalar function.
diff --git
a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
index def1dca0a0..7d19311808 100644
---
a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
+++
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
@@ -27,6 +27,7 @@ import org.apache.gluten.substrait.type.TypeNode;
import io.substrait.proto.CrossRel;
import io.substrait.proto.JoinRel;
+import io.substrait.proto.SetRel;
import io.substrait.proto.SortField;
import org.apache.spark.sql.catalyst.expressions.Attribute;
@@ -317,4 +318,20 @@ public class RelBuilder {
context.registerRelToOperator(operatorId);
return new GenerateRelNode(input, generator, childOutput, extensionNode,
outer);
}
+
+ public static RelNode makeSetRel(
+ List<RelNode> inputs, SetRel.SetOp setOp, SubstraitContext context, Long
operatorId) {
+ context.registerRelToOperator(operatorId);
+ return new SetRelNode(inputs, setOp);
+ }
+
+ public static RelNode makeSetRel(
+ List<RelNode> inputs,
+ SetRel.SetOp setOp,
+ AdvancedExtensionNode extensionNode,
+ SubstraitContext context,
+ Long operatorId) {
+ context.registerRelToOperator(operatorId);
+ return new SetRelNode(inputs, setOp, extensionNode);
+ }
}
diff --git
a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java
new file mode 100644
index 0000000000..ddcfb1701d
--- /dev/null
+++
b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.substrait.rel;
+
+import org.apache.gluten.substrait.extensions.AdvancedExtensionNode;
+
+import io.substrait.proto.Rel;
+import io.substrait.proto.RelCommon;
+import io.substrait.proto.SetRel;
+
+import java.io.Serializable;
+import java.util.List;
+
+public class SetRelNode implements RelNode, Serializable {
+ private final List<RelNode> inputs;
+ private final SetRel.SetOp setOp;
+ private final AdvancedExtensionNode extensionNode;
+
+ public SetRelNode(List<RelNode> inputs, SetRel.SetOp setOp,
AdvancedExtensionNode extensionNode) {
+ this.inputs = inputs;
+ this.setOp = setOp;
+ this.extensionNode = extensionNode;
+ }
+
+ public SetRelNode(List<RelNode> inputs, SetRel.SetOp setOp) {
+ this(inputs, setOp, null);
+ }
+
+ @Override
+ public Rel toProtobuf() {
+ final RelCommon.Builder relCommonBuilder = RelCommon.newBuilder();
+ relCommonBuilder.setDirect(RelCommon.Direct.newBuilder());
+ final SetRel.Builder setBuilder = SetRel.newBuilder();
+ setBuilder.setCommon(relCommonBuilder.build());
+ if (inputs != null) {
+ for (RelNode input : inputs) {
+ setBuilder.addInputs(input.toProtobuf());
+ }
+ }
+ setBuilder.setOp(setOp);
+ if (extensionNode != null) {
+ setBuilder.setAdvancedExtension(extensionNode.toProtobuf());
+ }
+ final Rel.Builder builder = Rel.newBuilder();
+ builder.setSet(setBuilder.build());
+ return builder.build();
+ }
+}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
index c67d4b5f88..453cfab4e4 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
@@ -126,6 +126,10 @@ trait MetricsApi extends Serializable {
def genSampleTransformerMetricsUpdater(metrics: Map[String, SQLMetric]):
MetricsUpdater
+ def genUnionTransformerMetrics(sparkContext: SparkContext): Map[String,
SQLMetric]
+
+ def genUnionTransformerMetricsUpdater(metrics: Map[String, SQLMetric]):
MetricsUpdater
+
def genColumnarInMemoryTableMetrics(sparkContext: SparkContext): Map[String,
SQLMetric] =
Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of
output rows"))
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
index f9755605ca..ac8e610956 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala
@@ -261,10 +261,11 @@ abstract class ProjectExecTransformerBase(val list:
Seq[NamedExpression], val in
}
}
-// An alternatives for UnionExec.
+// An alternative for UnionExec.
case class ColumnarUnionExec(children: Seq[SparkPlan]) extends ValidatablePlan
{
children.foreach {
case w: WholeStageTransformer =>
+ // FIXME: Avoid such practice for plan immutability.
w.setOutputSchemaForPlan(output)
case _ =>
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala
new file mode 100644
index 0000000000..d27558746a
--- /dev/null
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.expression.ConverterUtils
+import org.apache.gluten.extension.ValidationResult
+import org.apache.gluten.metrics.MetricsUpdater
+import org.apache.gluten.substrait.`type`.TypeBuilder
+import org.apache.gluten.substrait.SubstraitContext
+import org.apache.gluten.substrait.extensions.ExtensionBuilder
+import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.{SparkPlan, UnionExec}
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import io.substrait.proto.SetRel.SetOp
+
+import scala.collection.JavaConverters._
+
+/** Transformer for UnionExec. Note: Spark's UnionExec represents a SQL UNION
ALL. */
+case class UnionExecTransformer(children: Seq[SparkPlan]) extends
TransformSupport {
+ private val union = UnionExec(children)
+
+ // Note: "metrics" is made transient to avoid sending driver-side metrics to
tasks.
+ @transient override lazy val metrics: Map[String, SQLMetric] =
+
BackendsApiManager.getMetricsApiInstance.genUnionTransformerMetrics(sparkContext)
+
+ override def output: Seq[Attribute] = union.output
+
+ override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] =
children.flatMap(getColumnarInputRDDs)
+
+ override def metricsUpdater(): MetricsUpdater =
+
BackendsApiManager.getMetricsApiInstance.genUnionTransformerMetricsUpdater(metrics)
+
+ override protected def withNewChildrenInternal(newChildren:
IndexedSeq[SparkPlan]): SparkPlan =
+ copy(children = newChildren)
+
+ override protected def doValidateInternal(): ValidationResult = {
+ val context = new SubstraitContext
+ val operatorId = context.nextOperatorId(this.nodeName)
+ val relNode = getRelNode(context, operatorId, children.map(_.output),
null, true)
+ doNativeValidation(context, relNode)
+ }
+
+ override protected def doTransform(context: SubstraitContext):
TransformContext = {
+ val childrenCtx =
children.map(_.asInstanceOf[TransformSupport].transform(context))
+ val operatorId = context.nextOperatorId(this.nodeName)
+ val relNode =
+ getRelNode(context, operatorId, children.map(_.output),
childrenCtx.map(_.root), false)
+ TransformContext(output, relNode)
+ }
+
+ private def getRelNode(
+ context: SubstraitContext,
+ operatorId: Long,
+ inputAttributes: Seq[Seq[Attribute]],
+ inputs: Seq[RelNode],
+ validation: Boolean): RelNode = {
+ if (validation) {
+ // Use the second level of nesting to represent N way inputs.
+ val inputTypeNodes =
+ inputAttributes.map(
+ attributes =>
+ attributes.map(attr => ConverterUtils.getTypeNode(attr.dataType,
attr.nullable)).asJava)
+ val extensionNode = ExtensionBuilder.makeAdvancedExtension(
+ BackendsApiManager.getTransformerApiInstance.packPBMessage(
+ TypeBuilder
+ .makeStruct(
+ false,
+ inputTypeNodes.map(nodes => TypeBuilder.makeStruct(false,
nodes)).asJava)
+ .toProtobuf))
+ return RelBuilder.makeSetRel(
+ inputs.asJava,
+ SetOp.SET_OP_UNION_ALL,
+ extensionNode,
+ context,
+ operatorId)
+ }
+ RelBuilder.makeSetRel(inputs.asJava, SetOp.SET_OP_UNION_ALL, context,
operatorId)
+ }
+}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala
new file mode 100644
index 0000000000..f0eea08018
--- /dev/null
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar
+
+import org.apache.gluten.GlutenConfig
+import org.apache.gluten.execution.{ColumnarUnionExec, UnionExecTransformer}
+
+import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * Replace ColumnarUnionExec with UnionExecTransformer if possible.
+ *
+ * The rule is not included in
[[org.apache.gluten.extension.columnar.heuristic.HeuristicTransform]]
+ * or [[org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform]]
because it relies on
+ * children's output partitioning to be fully provided.
+ */
+case class UnionTransformerRule() extends Rule[SparkPlan] {
+ override def apply(plan: SparkPlan): SparkPlan = {
+ if (!GlutenConfig.getConf.enableNativeUnion) {
+ return plan
+ }
+ plan.transformUp {
+ case plan: ColumnarUnionExec =>
+ val transformer = UnionExecTransformer(plan.children)
+ if (sameNumPartitions(plan.children) && validate(transformer)) {
+ transformer
+ } else {
+ plan
+ }
+ }
+ }
+
+ private def sameNumPartitions(plans: Seq[SparkPlan]): Boolean = {
+ val partitioning = plans.map(_.outputPartitioning)
+ if (partitioning.exists(p => p.isInstanceOf[UnknownPartitioning])) {
+ return false
+ }
+ val numPartitions = plans.map(_.outputPartitioning.numPartitions)
+ numPartitions.forall(_ == numPartitions.head)
+ }
+
+ private def validate(union: UnionExecTransformer): Boolean = {
+ union.doValidate().ok()
+ }
+}
diff --git
a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
index fd250834d0..08081fadb5 100644
---
a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
+++
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
@@ -176,25 +176,39 @@ abstract class WholeStageTransformerSuite
result
}
+ protected def compareResultsAgainstVanillaSpark(
+ sql: String,
+ compareResult: Boolean = true,
+ customCheck: DataFrame => Unit,
+ noFallBack: Boolean = true,
+ cache: Boolean = false): DataFrame = {
+ compareDfResultsAgainstVanillaSpark(
+ () => spark.sql(sql),
+ compareResult,
+ customCheck,
+ noFallBack,
+ cache)
+ }
+
/**
* run a query with native engine as well as vanilla spark then compare the
result set for
* correctness check
*/
- protected def compareResultsAgainstVanillaSpark(
- sqlStr: String,
+ protected def compareDfResultsAgainstVanillaSpark(
+ dataframe: () => DataFrame,
compareResult: Boolean = true,
customCheck: DataFrame => Unit,
noFallBack: Boolean = true,
cache: Boolean = false): DataFrame = {
var expected: Seq[Row] = null
withSQLConf(vanillaSparkConfs(): _*) {
- val df = spark.sql(sqlStr)
+ val df = dataframe()
expected = df.collect()
}
- // By default we will fallabck complex type scan but here we should allow
+ // By default, we will fallback complex type scan but here we should allow
// to test support of complex type
spark.conf.set("spark.gluten.sql.complexType.scan.fallback.enabled",
"false");
- val df = spark.sql(sqlStr)
+ val df = dataframe()
if (cache) {
df.cache()
}
@@ -239,7 +253,12 @@ abstract class WholeStageTransformerSuite
noFallBack: Boolean = true,
cache: Boolean = false)(customCheck: DataFrame => Unit): DataFrame = {
- compareResultsAgainstVanillaSpark(sqlStr, compareResult, customCheck,
noFallBack, cache)
+ compareDfResultsAgainstVanillaSpark(
+ () => spark.sql(sqlStr),
+ compareResult,
+ customCheck,
+ noFallBack,
+ cache)
}
/**
@@ -256,8 +275,8 @@ abstract class WholeStageTransformerSuite
customCheck: DataFrame => Unit,
noFallBack: Boolean = true,
compareResult: Boolean = true): Unit =
- compareResultsAgainstVanillaSpark(
- tpchSQL(queryNum, tpchQueries),
+ compareDfResultsAgainstVanillaSpark(
+ () => spark.sql(tpchSQL(queryNum, tpchQueries)),
compareResult,
customCheck,
noFallBack)
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index f643ad7eed..9ae4c0ce90 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -86,6 +86,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
def enableColumnarUnion: Boolean = conf.getConf(COLUMNAR_UNION_ENABLED)
+ def enableNativeUnion: Boolean = conf.getConf(NATIVE_UNION_ENABLED)
+
def enableColumnarExpand: Boolean = conf.getConf(COLUMNAR_EXPAND_ENABLED)
def enableColumnarBroadcastExchange: Boolean =
conf.getConf(COLUMNAR_BROADCAST_EXCHANGE_ENABLED)
@@ -1022,6 +1024,13 @@ object GlutenConfig {
.booleanConf
.createWithDefault(true)
+ val NATIVE_UNION_ENABLED =
+ buildConf("spark.gluten.sql.native.union")
+ .internal()
+ .doc("Enable or disable native union where computation is completely
offloaded to backend.")
+ .booleanConf
+ .createWithDefault(false)
+
val COLUMNAR_EXPAND_ENABLED =
buildConf("spark.gluten.sql.columnar.expand")
.internal()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]