This is an automated email from the ASF dual-hosted git repository.
kejia pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 4e04d31d9 [GLUTEN-4836][VL]Add support for WindowGroupLimitExec in
gluten (#5398)
4e04d31d9 is described below
commit 4e04d31d95d2edf9ca150126a349887749ea0897
Author: ayushi-agarwal <[email protected]>
AuthorDate: Wed Apr 24 11:05:15 2024 +0530
[GLUTEN-4836][VL]Add support for WindowGroupLimitExec in gluten (#5398)
* Add support for WindowGroupLimitExec in gluten
---------
Co-authored-by: ayushi agarwal <[email protected]>
---
.../gluten/backendsapi/velox/VeloxBackend.scala | 7 +
cpp/velox/substrait/SubstraitToVeloxPlan.cc | 47 ++++++
cpp/velox/substrait/SubstraitToVeloxPlan.h | 5 +-
.../substrait/SubstraitToVeloxPlanValidator.cc | 72 ++++++++++
.../substrait/SubstraitToVeloxPlanValidator.h | 3 +
.../apache/gluten/substrait/rel/RelBuilder.java | 23 +++
.../substrait/rel/WindowGroupLimitRelNode.java | 91 ++++++++++++
.../substrait/proto/substrait/algebra.proto | 10 ++
.../gluten/backendsapi/BackendSettingsApi.scala | 7 +-
.../WindowGroupLimitExecTransformer.scala | 158 +++++++++++++++++++++
.../extension/columnar/PullOutPreProject.scala | 35 ++++-
.../columnar/RewriteSparkPlanRulesManager.scala | 2 +
.../extension/columnar/TransformHintRule.scala | 23 ++-
.../extension/columnar/TransformSingleNode.scala | 14 +-
.../execution/GlutenSQLWindowFunctionSuite.scala | 90 +++++++++++-
.../scala/org/apache/gluten/GlutenConfig.scala | 9 ++
.../org/apache/gluten/sql/shims/SparkShims.scala | 6 +
.../window/WindowGroupLimitExecShim.scala | 47 ++++++
.../window/WindowGroupLimitExecShim.scala | 47 ++++++
.../window/WindowGroupLimitExecShim.scala | 47 ++++++
.../gluten/sql/shims/spark35/Spark35Shims.scala | 30 ++++
.../window/WindowGroupLimitExecShim.scala | 41 ++++++
22 files changed, 808 insertions(+), 6 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index c9e0dc694..9ac7fed97 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -274,6 +274,13 @@ object VeloxBackendSettings extends BackendSettingsApi {
GlutenConfig.getConf.enableColumnarSortMergeJoin
}
+ override def supportWindowGroupLimitExec(rankLikeFunction: Expression):
Boolean = {
+ rankLikeFunction match {
+ case _: RowNumber => true
+ case _ => false
+ }
+ }
+
override def supportWindowExec(windowFunctions: Seq[NamedExpression]):
Boolean = {
var allSupported = true
breakable {
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index 9bf183c16..4b1a2543e 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -913,6 +913,51 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
}
}
+core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(
+ const ::substrait::WindowGroupLimitRel& windowGroupLimitRel) {
+ core::PlanNodePtr childNode;
+ if (windowGroupLimitRel.has_input()) {
+ childNode = toVeloxPlan(windowGroupLimitRel.input());
+ } else {
+ VELOX_FAIL("Child Rel is expected in WindowGroupLimitRel.");
+ }
+ const auto& inputType = childNode->outputType();
+ // Construct partitionKeys
+ std::vector<core::FieldAccessTypedExprPtr> partitionKeys;
+ std::unordered_set<std::string> keyNames;
+ const auto& partitions = windowGroupLimitRel.partition_expressions();
+ partitionKeys.reserve(partitions.size());
+ for (const auto& partition : partitions) {
+ auto expression = exprConverter_->toVeloxExpr(partition, inputType);
+ core::FieldAccessTypedExprPtr veloxPartitionKey =
+ std::dynamic_pointer_cast<const
core::FieldAccessTypedExpr>(expression);
+ VELOX_USER_CHECK_NOT_NULL(veloxPartitionKey, "Window Group Limit Operator
only supports field partition key.");
+ // Constructs unique partition keys.
+ if (keyNames.insert(veloxPartitionKey->name()).second) {
+ partitionKeys.emplace_back(veloxPartitionKey);
+ }
+ }
+ std::vector<core::FieldAccessTypedExprPtr> sortingKeys;
+ std::vector<core::SortOrder> sortingOrders;
+ const auto& [rawSortingKeys, rawSortingOrders] =
processSortField(windowGroupLimitRel.sorts(), inputType);
+ for (vector_size_t i = 0; i < rawSortingKeys.size(); ++i) {
+ // Constructs unique sort keys and excludes keys overlapped with partition
keys.
+ if (keyNames.insert(rawSortingKeys[i]->name()).second) {
+ sortingKeys.emplace_back(rawSortingKeys[i]);
+ sortingOrders.emplace_back(rawSortingOrders[i]);
+ }
+ }
+ const std::optional<std::string> rowNumberColumnName = std::nullopt;
+ return std::make_shared<core::TopNRowNumberNode>(
+ nextPlanNodeId(),
+ partitionKeys,
+ sortingKeys,
+ sortingOrders,
+ rowNumberColumnName,
+ (int32_t)windowGroupLimitRel.limit(),
+ childNode);
+}
+
core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const
::substrait::SortRel& sortRel) {
auto childNode = convertSingleInput<::substrait::SortRel>(sortRel);
auto [sortingKeys, sortingOrders] = processSortField(sortRel.sorts(),
childNode->outputType());
@@ -1197,6 +1242,8 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
return toVeloxPlan(rel.window());
} else if (rel.has_write()) {
return toVeloxPlan(rel.write());
+ } else if (rel.has_windowgrouplimit()) {
+ return toVeloxPlan(rel.windowgrouplimit());
} else {
VELOX_NYI("Substrait conversion not supported for Rel.");
}
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.h
b/cpp/velox/substrait/SubstraitToVeloxPlan.h
index f11856395..dc97d2a4c 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.h
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.h
@@ -74,9 +74,12 @@ class SubstraitToVeloxPlanConverter {
/// Used to convert Substrait GenerateRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::GenerateRel& generateRel);
- /// Used to convert Substrait SortRel into Velox PlanNode.
+ /// Used to convert Substrait WindowRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::WindowRel& windowRel);
+ /// Used to convert Substrait WindowGroupLimitRel into Velox PlanNode.
+ core::PlanNodePtr toVeloxPlan(const ::substrait::WindowGroupLimitRel&
windowGroupLimitRel);
+
/// Used to convert Substrait JoinRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::JoinRel& joinRel);
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
index 2a5857ae9..ba711e774 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -698,6 +698,76 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::WindowRel& windo
return true;
}
+bool SubstraitToVeloxPlanValidator::validate(const
::substrait::WindowGroupLimitRel& windowGroupLimitRel) {
+ if (windowGroupLimitRel.has_input() &&
!validate(windowGroupLimitRel.input())) {
+ LOG_VALIDATION_MSG("WindowGroupLimitRel input fails to validate.");
+ return false;
+ }
+
+ // Get and validate the input types from extension.
+ if (!windowGroupLimitRel.has_advanced_extension()) {
+ LOG_VALIDATION_MSG("Input types are expected in WindowGroupLimitRel.");
+ return false;
+ }
+ const auto& extension = windowGroupLimitRel.advanced_extension();
+ std::vector<TypePtr> types;
+ if (!validateInputTypes(extension, types)) {
+ LOG_VALIDATION_MSG("Validation failed for input types in
WindowGroupLimitRel.");
+ return false;
+ }
+
+ int32_t inputPlanNodeId = 0;
+ std::vector<std::string> names;
+ names.reserve(types.size());
+ for (auto colIdx = 0; colIdx < types.size(); colIdx++) {
+ names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx));
+ }
+ auto rowType = std::make_shared<RowType>(std::move(names), std::move(types));
+ // Validate partition-by expressions
+ const auto& groupByExprs = windowGroupLimitRel.partition_expressions();
+ std::vector<core::TypedExprPtr> expressions;
+ expressions.reserve(groupByExprs.size());
+ for (const auto& expr : groupByExprs) {
+ auto expression = exprConverter_->toVeloxExpr(expr, rowType);
+ auto exprField = dynamic_cast<const
core::FieldAccessTypedExpr*>(expression.get());
+ if (exprField == nullptr) {
+ LOG_VALIDATION_MSG("Only field is supported for partition key in Window
Group Limit Operator!");
+ return false;
+ } else {
+ expressions.emplace_back(expression);
+ }
+ }
+ // Try to compile the expressions. If there is any unregistered function or
+ // mismatched type, exception will be thrown.
+ exec::ExprSet exprSet(std::move(expressions), execCtx_);
+ // Validate Sort expression
+ const auto& sorts = windowGroupLimitRel.sorts();
+ for (const auto& sort : sorts) {
+ switch (sort.direction()) {
+ case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST:
+ case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST:
+ case
::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST:
+ case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST:
+ break;
+ default:
+ LOG_VALIDATION_MSG("in windowGroupLimitRel, unsupported Sort direction
" + std::to_string(sort.direction()));
+ return false;
+ }
+
+ if (sort.has_expr()) {
+ auto expression = exprConverter_->toVeloxExpr(sort.expr(), rowType);
+ auto exprField = dynamic_cast<const
core::FieldAccessTypedExpr*>(expression.get());
+ if (!exprField) {
+ LOG_VALIDATION_MSG("in windowGroupLimitRel, the sorting key in Sort
Operator only support field.");
+ return false;
+ }
+ exec::ExprSet exprSet({std::move(expression)}, execCtx_);
+ }
+ }
+
+ return true;
+}
+
bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SortRel&
sortRel) {
if (sortRel.has_input() && !validate(sortRel.input())) {
return false;
@@ -1200,6 +1270,8 @@ bool SubstraitToVeloxPlanValidator::validate(const
::substrait::Rel& rel) {
return validate(rel.window());
} else if (rel.has_write()) {
return validate(rel.write());
+ } else if (rel.has_windowgrouplimit()) {
+ return validate(rel.windowgrouplimit());
} else {
return false;
}
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
index ab9d4445c..1fe174928 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
+++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
@@ -58,6 +58,9 @@ class SubstraitToVeloxPlanValidator {
/// Used to validate whether the computing of this Window is supported.
bool validate(const ::substrait::WindowRel& windowRel);
+ /// Used to validate whether the computing of this WindowGroupLimit is
supported.
+ bool validate(const ::substrait::WindowGroupLimitRel& windowGroupLimitRel);
+
/// Used to validate whether the computing of this Aggregation is supported.
bool validate(const ::substrait::AggregateRel& aggRel);
diff --git
a/gluten-core/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
b/gluten-core/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
index 717fffe44..b784e3e7f 100644
--- a/gluten-core/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
+++ b/gluten-core/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java
@@ -272,6 +272,29 @@ public class RelBuilder {
partitionExpressions, sorts);
}
+ public static RelNode makeWindowGroupLimitRel(
+ RelNode input,
+ List<ExpressionNode> partitionExpressions,
+ List<SortField> sorts,
+ Integer limit,
+ AdvancedExtensionNode extensionNode,
+ SubstraitContext context,
+ Long operatorId) {
+ context.registerRelToOperator(operatorId);
+ return new WindowGroupLimitRelNode(input, partitionExpressions, sorts,
limit, extensionNode);
+ }
+
+ public static RelNode makeWindowGroupLimitRel(
+ RelNode input,
+ List<ExpressionNode> partitionExpressions,
+ List<SortField> sorts,
+ Integer limit,
+ SubstraitContext context,
+ Long operatorId) {
+ context.registerRelToOperator(operatorId);
+ return new WindowGroupLimitRelNode(input, partitionExpressions, sorts,
limit);
+ }
+
public static RelNode makeGenerateRel(
RelNode input,
ExpressionNode generator,
diff --git
a/gluten-core/src/main/java/org/apache/gluten/substrait/rel/WindowGroupLimitRelNode.java
b/gluten-core/src/main/java/org/apache/gluten/substrait/rel/WindowGroupLimitRelNode.java
new file mode 100644
index 000000000..a7a0ea62a
--- /dev/null
+++
b/gluten-core/src/main/java/org/apache/gluten/substrait/rel/WindowGroupLimitRelNode.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.substrait.rel;
+
+import org.apache.gluten.substrait.expression.ExpressionNode;
+import org.apache.gluten.substrait.extensions.AdvancedExtensionNode;
+
+import io.substrait.proto.Rel;
+import io.substrait.proto.RelCommon;
+import io.substrait.proto.SortField;
+import io.substrait.proto.WindowGroupLimitRel;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+public class WindowGroupLimitRelNode implements RelNode, Serializable {
+ private final RelNode input;
+ private final List<ExpressionNode> partitionExpressions = new ArrayList<>();
+ private final List<SortField> sorts = new ArrayList<>();
+ private final AdvancedExtensionNode extensionNode;
+ private final Integer limit;
+
+ public WindowGroupLimitRelNode(
+ RelNode input,
+ List<ExpressionNode> partitionExpressions,
+ List<SortField> sorts,
+ Integer limit) {
+ this.input = input;
+ this.partitionExpressions.addAll(partitionExpressions);
+ this.sorts.addAll(sorts);
+ this.limit = limit;
+ this.extensionNode = null;
+ }
+
+ public WindowGroupLimitRelNode(
+ RelNode input,
+ List<ExpressionNode> partitionExpressions,
+ List<SortField> sorts,
+ Integer limit,
+ AdvancedExtensionNode extensionNode) {
+ this.input = input;
+ this.partitionExpressions.addAll(partitionExpressions);
+ this.sorts.addAll(sorts);
+ this.limit = limit;
+ this.extensionNode = extensionNode;
+ }
+
+ @Override
+ public Rel toProtobuf() {
+ RelCommon.Builder relCommonBuilder = RelCommon.newBuilder();
+ relCommonBuilder.setDirect(RelCommon.Direct.newBuilder());
+
+ WindowGroupLimitRel.Builder windowBuilder =
WindowGroupLimitRel.newBuilder();
+ windowBuilder.setCommon(relCommonBuilder.build());
+ if (input != null) {
+ windowBuilder.setInput(input.toProtobuf());
+ }
+
+ for (int i = 0; i < partitionExpressions.size(); i++) {
+ windowBuilder.addPartitionExpressions(i,
partitionExpressions.get(i).toProtobuf());
+ }
+
+ for (int i = 0; i < sorts.size(); i++) {
+ windowBuilder.addSorts(i, sorts.get(i));
+ }
+
+ windowBuilder.setLimit(limit);
+
+ if (extensionNode != null) {
+ windowBuilder.setAdvancedExtension(extensionNode.toProtobuf());
+ }
+ Rel.Builder builder = Rel.newBuilder();
+ builder.setWindowGroupLimit(windowBuilder.build());
+ return builder.build();
+ }
+}
diff --git
a/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto
b/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto
index d76a85827..266aba4b0 100644
--- a/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto
+++ b/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto
@@ -328,6 +328,15 @@ message WindowRel {
}
}
+message WindowGroupLimitRel {
+ RelCommon common = 1;
+ Rel input = 2;
+ repeated Expression partition_expressions = 3;
+ repeated SortField sorts = 4;
+ int32 limit = 5;
+ substrait.extensions.AdvancedExtension advanced_extension = 10;
+}
+
// The relational operator capturing simple FILTERs (as in the WHERE clause of
SQL)
message FilterRel {
RelCommon common = 1;
@@ -495,6 +504,7 @@ message Rel {
GenerateRel generate = 17;
WriteRel write = 18;
TopNRel top_n = 19;
+ WindowGroupLimitRel windowGroupLimit = 20;
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
index bb05275dc..5e8b347b3 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
@@ -22,7 +22,7 @@ import
org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.NamedExpression
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.SparkPlan
@@ -49,6 +49,9 @@ trait BackendSettingsApi {
def supportWindowExec(windowFunctions: Seq[NamedExpression]): Boolean = {
false
}
+ def supportWindowGroupLimitExec(rankLikeFunction: Expression): Boolean = {
+ false
+ }
def supportColumnarShuffleExec(): Boolean = {
GlutenConfig.getConf.enableColumnarShuffle
}
@@ -114,6 +117,8 @@ trait BackendSettingsApi {
def requiredChildOrderingForWindow(): Boolean = false
+ def requiredChildOrderingForWindowGroupLimit(): Boolean = false
+
def staticPartitionWriteOnly(): Boolean = false
def supportTransformWriteFiles: Boolean = false
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala
b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala
new file mode 100644
index 000000000..bba79fa76
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
+import org.apache.gluten.extension.ValidationResult
+import org.apache.gluten.metrics.MetricsUpdater
+import org.apache.gluten.substrait.`type`.TypeBuilder
+import org.apache.gluten.substrait.SubstraitContext
+import org.apache.gluten.substrait.extensions.ExtensionBuilder
+import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
+
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute,
Expression, SortOrder}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples,
ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.window.{Final, Partial,
WindowGroupLimitMode}
+
+import io.substrait.proto.SortField
+
+import scala.collection.JavaConverters._
+
+case class WindowGroupLimitExecTransformer(
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ rankLikeFunction: Expression,
+ limit: Int,
+ mode: WindowGroupLimitMode,
+ child: SparkPlan)
+ extends UnaryTransformSupport {
+
+ @transient override lazy val metrics =
+
BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetrics(sparkContext)
+
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+
+ override def metricsUpdater(): MetricsUpdater =
+
BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetricsUpdater(metrics)
+
+ override def output: Seq[Attribute] = child.output
+
+ override def requiredChildDistribution: Seq[Distribution] = mode match {
+ case Partial => super.requiredChildDistribution
+ case Final =>
+ if (partitionSpec.isEmpty) {
+ AllTuples :: Nil
+ } else {
+ ClusteredDistribution(partitionSpec) :: Nil
+ }
+ }
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+ if
(BackendsApiManager.getSettings.requiredChildOrderingForWindowGroupLimit()) {
+ // Velox StreamingTopNRowNumber requires its child to be ordered by the partition keys, then the sort keys.
+ Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
+ } else {
+ Seq(Nil)
+ }
+ }
+
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
+ def getWindowGroupLimitRel(
+ context: SubstraitContext,
+ originalInputAttributes: Seq[Attribute],
+ operatorId: Long,
+ input: RelNode,
+ validation: Boolean): RelNode = {
+ val args = context.registeredFunction
+ // Partition By Expressions
+ val partitionsExpressions = partitionSpec
+ .map(
+ ExpressionConverter
+ .replaceWithExpressionTransformer(_, attributeSeq = child.output)
+ .doTransform(args))
+ .asJava
+
+ // Sort By Expressions
+ val sortFieldList =
+ orderSpec.map {
+ order =>
+ val builder = SortField.newBuilder()
+ val exprNode = ExpressionConverter
+ .replaceWithExpressionTransformer(order.child, attributeSeq =
child.output)
+ .doTransform(args)
+ builder.setExpr(exprNode.toProtobuf)
+
builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
+ builder.build()
+ }.asJava
+ if (!validation) {
+ RelBuilder.makeWindowGroupLimitRel(
+ input,
+ partitionsExpressions,
+ sortFieldList,
+ limit,
+ context,
+ operatorId)
+ } else {
+ // Use an extension node to send the input types through Substrait plan
for validation.
+ val inputTypeNodeList = originalInputAttributes
+ .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
+ .asJava
+ val extensionNode = ExtensionBuilder.makeAdvancedExtension(
+ BackendsApiManager.getTransformerApiInstance.packPBMessage(
+ TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+
+ RelBuilder.makeWindowGroupLimitRel(
+ input,
+ partitionsExpressions,
+ sortFieldList,
+ limit,
+ extensionNode,
+ context,
+ operatorId)
+ }
+ }
+
+ override protected def doValidateInternal(): ValidationResult = {
+ if
(!BackendsApiManager.getSettings.supportWindowGroupLimitExec(rankLikeFunction))
{
+ return ValidationResult
+ .notOk(s"Found unsupported rank like function: $rankLikeFunction")
+ }
+ val substraitContext = new SubstraitContext
+ val operatorId = substraitContext.nextOperatorId(this.nodeName)
+
+ val relNode =
+ getWindowGroupLimitRel(substraitContext, child.output, operatorId, null,
validation = true)
+
+ doNativeValidation(substraitContext, relNode)
+ }
+
+ override def doTransform(context: SubstraitContext): TransformContext = {
+ val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
+ val operatorId = context.nextOperatorId(this.nodeName)
+
+ val currRel =
+ getWindowGroupLimitRel(context, child.output, operatorId, childCtx.root,
validation = false)
+ assert(currRel != null, "Window Group Limit Rel should be valid")
+ TransformContext(childCtx.outputAttributes, output, currRel)
+ }
+}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPreProject.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPreProject.scala
index 9925b8269..48a9a7687 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPreProject.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPreProject.scala
@@ -17,6 +17,7 @@
package org.apache.gluten.extension.columnar
import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.utils.PullOutProjectHelper
import org.apache.spark.sql.catalyst.expressions._
@@ -24,7 +25,7 @@ import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ExpandExec, GenerateExec, ProjectExec,
SortExec, SparkPlan, TakeOrderedAndProjectExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec,
TypedAggregateExpression}
-import org.apache.spark.sql.execution.window.WindowExec
+import org.apache.spark.sql.execution.window.{WindowExec,
WindowGroupLimitExecShim}
import scala.collection.mutable
@@ -75,6 +76,12 @@ object PullOutPreProject extends Rule[SparkPlan] with
PullOutProjectHelper {
}
case _ => false
}.isDefined)
+ case plan if SparkShimLoader.getSparkShims.isWindowGroupLimitExec(plan)
=>
+ val window = SparkShimLoader.getSparkShims
+ .getWindowGroupLimitExecShim(plan)
+ .asInstanceOf[WindowGroupLimitExecShim]
+ window.orderSpec.exists(o => isNotAttribute(o.child)) ||
+ window.partitionSpec.exists(isNotAttribute)
case expand: ExpandExec =>
expand.projections.flatten.exists(isNotAttributeAndLiteral)
case _ => false
}
@@ -181,6 +188,32 @@ object PullOutPreProject extends Rule[SparkPlan] with
PullOutProjectHelper {
ProjectExec(window.output, newWindow)
+ case plan
+ if SparkShimLoader.getSparkShims.isWindowGroupLimitExec(plan) &&
needsPreProject(plan) =>
+ val windowLimit = SparkShimLoader.getSparkShims
+ .getWindowGroupLimitExecShim(plan)
+ .asInstanceOf[WindowGroupLimitExecShim]
+ val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
+ // Handle orderSpec.
+ val newOrderSpec = getNewSortOrder(windowLimit.orderSpec, expressionMap)
+
+ // Handle partitionSpec.
+ val newPartitionSpec =
+ windowLimit.partitionSpec.map(replaceExpressionWithAttribute(_,
expressionMap))
+
+ val newWindowLimitShim = windowLimit.copy(
+ orderSpec = newOrderSpec,
+ partitionSpec = newPartitionSpec,
+ child = ProjectExec(
+ eliminateProjectList(windowLimit.child.outputSet,
expressionMap.values.toSeq),
+ windowLimit.child)
+ )
+
+ val newWindowLimit = SparkShimLoader.getSparkShims
+ .getWindowGroupLimitExec(newWindowLimitShim)
+
+ ProjectExec(plan.output, newWindowLimit)
+
case expand: ExpandExec if needsPreProject(expand) =>
val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
val newProjections =
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteSparkPlanRulesManager.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteSparkPlanRulesManager.scala
index e694c23a7..6070613c1 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteSparkPlanRulesManager.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteSparkPlanRulesManager.scala
@@ -17,6 +17,7 @@
package org.apache.gluten.extension.columnar
import org.apache.gluten.extension.{RewriteCollect, RewriteIn}
+import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -58,6 +59,7 @@ class RewriteSparkPlanRulesManager private (rewriteRules:
Seq[Rule[SparkPlan]])
case _: FileSourceScanExec => true
case _: ExpandExec => true
case _: GenerateExec => true
+ case plan if
SparkShimLoader.getSparkShims.isWindowGroupLimitExec(plan) => true
case _ => false
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
index 084f91514..8f2607a97 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
@@ -39,7 +39,7 @@ import
org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.python.EvalPythonExec
-import org.apache.spark.sql.execution.window.WindowExec
+import org.apache.spark.sql.execution.window.{WindowExec,
WindowGroupLimitExecShim}
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
import org.apache.spark.sql.types.StringType
@@ -301,6 +301,8 @@ case class AddTransformHintRule() extends Rule[SparkPlan] {
!scanOnly && BackendsApiManager.getSettings.supportColumnarShuffleExec()
val enableColumnarSort: Boolean = !scanOnly &&
columnarConf.enableColumnarSort
val enableColumnarWindow: Boolean = !scanOnly &&
columnarConf.enableColumnarWindow
+ val enableColumnarWindowGroupLimit: Boolean = !scanOnly &&
+ columnarConf.enableColumnarWindowGroupLimit
val enableColumnarSortMergeJoin: Boolean = !scanOnly &&
BackendsApiManager.getSettings.supportSortMergeJoinExec()
val enableColumnarBatchScan: Boolean = columnarConf.enableColumnarBatchScan
@@ -622,6 +624,25 @@ case class AddTransformHintRule() extends Rule[SparkPlan] {
plan.child)
transformer.doValidate().tagOnFallback(plan)
}
+ case plan if
SparkShimLoader.getSparkShims.isWindowGroupLimitExec(plan) =>
+ if (!enableColumnarWindowGroupLimit) {
+ TransformHints.tagNotTransformable(
+ plan,
+ "columnar window group limit is not enabled in
WindowGroupLimitExec")
+ } else {
+ val windowGroupLimitPlan = SparkShimLoader.getSparkShims
+ .getWindowGroupLimitExecShim(plan)
+ .asInstanceOf[WindowGroupLimitExecShim]
+ val transformer = WindowGroupLimitExecTransformer(
+ windowGroupLimitPlan.partitionSpec,
+ windowGroupLimitPlan.orderSpec,
+ windowGroupLimitPlan.rankLikeFunction,
+ windowGroupLimitPlan.limit,
+ windowGroupLimitPlan.mode,
+ windowGroupLimitPlan.child
+ )
+ transformer.doValidate().tagOnFallback(plan)
+ }
case plan: CoalesceExec =>
if (!enableColumnarCoalesce) {
TransformHints.tagNotTransformable(
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
index d005baa30..bc1276d63 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
@@ -37,7 +37,7 @@ import
org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.python.EvalPythonExec
-import org.apache.spark.sql.execution.window.WindowExec
+import org.apache.spark.sql.execution.window.{WindowExec,
WindowGroupLimitExecShim}
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
sealed trait TransformSingleNode extends Logging {
@@ -408,6 +408,18 @@ object TransformOthers {
plan.partitionSpec,
plan.orderSpec,
plan.child)
+ case plan if SparkShimLoader.getSparkShims.isWindowGroupLimitExec(plan) =>
+ val windowGroupLimitPlan = SparkShimLoader.getSparkShims
+ .getWindowGroupLimitExecShim(plan)
+ .asInstanceOf[WindowGroupLimitExecShim]
+ WindowGroupLimitExecTransformer(
+ windowGroupLimitPlan.partitionSpec,
+ windowGroupLimitPlan.orderSpec,
+ windowGroupLimitPlan.rankLikeFunction,
+ windowGroupLimitPlan.limit,
+ windowGroupLimitPlan.mode,
+ windowGroupLimitPlan.child
+ )
case plan: GlobalLimitExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
val child = plan.child
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
index 666517420..89a435174 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution
-import org.apache.gluten.execution.WindowExecTransformer
+import org.apache.gluten.execution.{WindowExecTransformer, WindowGroupLimitExecTransformer}
import org.apache.spark.sql.GlutenSQLTestsTrait
import org.apache.spark.sql.Row
@@ -93,6 +93,94 @@ class GlutenSQLWindowFunctionSuite extends
SQLWindowFunctionSuite with GlutenSQL
}
}
+ testGluten("Filter on row number") {
+ withTable("customer") {
+ val rdd = spark.sparkContext.parallelize(customerData)
+ val customerDF = spark.createDataFrame(rdd, customerSchema)
+ customerDF.createOrReplaceTempView("customer")
+ val query =
+ """
+ |SELECT * from (SELECT
+ | c_custkey,
+ | c_acctbal,
+ | row_number() OVER (
+ | PARTITION BY c_nationkey,
+ | "a"
+ | ORDER BY
+ | c_custkey,
+ | "a"
+ | ) AS row_num
+ |FROM
+ | customer ORDER BY 1, 2) where row_num <=2
+ |""".stripMargin
+ val df = sql(query)
+ checkAnswer(
+ df,
+ Seq(
+ Row(4553, BigDecimal(638841L, 2), 1),
+ Row(4953, BigDecimal(603728L, 2), 1),
+ Row(9954, BigDecimal(758725L, 2), 1),
+ Row(35403, BigDecimal(603470L, 2), 2),
+ Row(35803, BigDecimal(528487L, 2), 1),
+ Row(61065, BigDecimal(728477L, 2), 1),
+ Row(95337, BigDecimal(91561L, 2), 2),
+ Row(127412, BigDecimal(462141L, 2), 2),
+ Row(148303, BigDecimal(430230L, 2), 2)
+ )
+ )
+ assert(
+ getExecutedPlan(df).exists {
+ case _: WindowGroupLimitExecTransformer => true
+ case _ => false
+ }
+ )
+ }
+ }
+
+ testGluten("Filter on rank") {
+ withTable("customer") {
+ val rdd = spark.sparkContext.parallelize(customerData)
+ val customerDF = spark.createDataFrame(rdd, customerSchema)
+ customerDF.createOrReplaceTempView("customer")
+ val query =
+ """
+ |SELECT * from (SELECT
+ | c_custkey,
+ | c_acctbal,
+ | rank() OVER (
+ | PARTITION BY c_nationkey,
+ | "a"
+ | ORDER BY
+ | c_custkey,
+ | "a"
+ | ) AS rank
+ |FROM
+ | customer ORDER BY 1, 2) where rank <=2
+ |""".stripMargin
+ val df = sql(query)
+ checkAnswer(
+ df,
+ Seq(
+ Row(4553, BigDecimal(638841L, 2), 1),
+ Row(4953, BigDecimal(603728L, 2), 1),
+ Row(9954, BigDecimal(758725L, 2), 1),
+ Row(35403, BigDecimal(603470L, 2), 2),
+ Row(35803, BigDecimal(528487L, 2), 1),
+ Row(61065, BigDecimal(728477L, 2), 1),
+ Row(95337, BigDecimal(91561L, 2), 2),
+ Row(127412, BigDecimal(462141L, 2), 2),
+ Row(148303, BigDecimal(430230L, 2), 2)
+ )
+ )
+ assert(
+ !getExecutedPlan(df).exists {
+ case _: WindowGroupLimitExecTransformer => true
+ case _ => false
+ }
+ )
+ }
+ }
+
testGluten("Expression in WindowExpression") {
withTable("customer") {
val rdd = spark.sparkContext.parallelize(customerData)
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index 49541f534..a55fdc7f2 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -67,6 +67,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
def enableColumnarWindow: Boolean = conf.getConf(COLUMNAR_WINDOW_ENABLED)
+ def enableColumnarWindowGroupLimit: Boolean = conf.getConf(COLUMNAR_WINDOW_GROUP_LIMIT_ENABLED)
+
def veloxColumnarWindowType: String =
conf.getConfString(COLUMNAR_VELOX_WINDOW_TYPE.key)
def enableColumnarShuffledHashJoin: Boolean =
conf.getConf(COLUMNAR_SHUFFLED_HASH_JOIN_ENABLED)
@@ -771,6 +773,13 @@ object GlutenConfig {
.booleanConf
.createWithDefault(true)
+ val COLUMNAR_WINDOW_GROUP_LIMIT_ENABLED =
+ buildConf("spark.gluten.sql.columnar.window.group.limit")
+ .internal()
+ .doc("Enable or disable columnar window group limit.")
+ .booleanConf
+ .createWithDefault(true)
+
val COLUMNAR_VELOX_WINDOW_TYPE =
buildConf("spark.gluten.sql.columnar.backend.velox.window.type")
.internal()
diff --git
a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
index 699f1c102..0895c7e9a 100644
--- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
@@ -129,6 +129,12 @@ trait SparkShims {
expr: Expression,
mightContainReplacer: (Expression, Expression) => BinaryExpression):
Expression
+ def isWindowGroupLimitExec(plan: SparkPlan): Boolean = false
+
+ def getWindowGroupLimitExecShim(plan: SparkPlan): SparkPlan = null
+
+ def getWindowGroupLimitExec(windowGroupLimitPlan: SparkPlan): SparkPlan = null
+
def getLimitAndOffsetFromGlobalLimit(plan: GlobalLimitExec): (Int, Int) =
(plan.limit, 0)
def getLimitAndOffsetFromTopK(plan: TakeOrderedAndProjectExec): (Int, Int) =
(plan.limit, 0)
diff --git
a/shims/spark32/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExecShim.scala
b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExecShim.scala
new file mode 100644
index 000000000..6aa2e5fb8
--- /dev/null
+++
b/shims/spark32/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExecShim.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.window
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+
+sealed trait WindowGroupLimitMode
+
+case object Partial extends WindowGroupLimitMode
+
+case object Final extends WindowGroupLimitMode
+
+case class WindowGroupLimitExecShim(
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ rankLikeFunction: Expression,
+ limit: Int,
+ mode: WindowGroupLimitMode,
+ child: SparkPlan)
+ extends UnaryExecNode {
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException(
+ s"${this.getClass.getSimpleName} doesn't support doExecute")
+ }
+
+ override def output: Seq[Attribute] = child.output
+}
diff --git
a/shims/spark33/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExecShim.scala
b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExecShim.scala
new file mode 100644
index 000000000..6aa2e5fb8
--- /dev/null
+++
b/shims/spark33/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExecShim.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.window
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+
+sealed trait WindowGroupLimitMode
+
+case object Partial extends WindowGroupLimitMode
+
+case object Final extends WindowGroupLimitMode
+
+case class WindowGroupLimitExecShim(
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ rankLikeFunction: Expression,
+ limit: Int,
+ mode: WindowGroupLimitMode,
+ child: SparkPlan)
+ extends UnaryExecNode {
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException(
+ s"${this.getClass.getSimpleName} doesn't support doExecute")
+ }
+
+ override def output: Seq[Attribute] = child.output
+}
diff --git
a/shims/spark34/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExecShim.scala
b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExecShim.scala
new file mode 100644
index 000000000..6aa2e5fb8
--- /dev/null
+++
b/shims/spark34/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExecShim.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.window
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+
+sealed trait WindowGroupLimitMode
+
+case object Partial extends WindowGroupLimitMode
+
+case object Final extends WindowGroupLimitMode
+
+case class WindowGroupLimitExecShim(
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ rankLikeFunction: Expression,
+ limit: Int,
+ mode: WindowGroupLimitMode,
+ child: SparkPlan)
+ extends UnaryExecNode {
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException(
+ s"${this.getClass.getSimpleName} doesn't support doExecute")
+ }
+
+ override def output: Seq[Attribute] = child.output
+}
diff --git
a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
index 1e23b49f9..1be269b6c 100644
---
a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
+++
b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
@@ -43,6 +43,7 @@ import
org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.text.TextScan
import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike,
ShuffleExchangeLike}
+import org.apache.spark.sql.execution.window.{WindowGroupLimitExec, WindowGroupLimitExecShim}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.storage.{BlockId, BlockManagerId}
@@ -250,6 +251,35 @@ class Spark35Shims extends SparkShims {
(getLimit(plan.limit, plan.offset), plan.offset)
}
+ override def isWindowGroupLimitExec(plan: SparkPlan): Boolean = plan match {
+ case _: WindowGroupLimitExec => true
+ case _ => false
+ }
+
+ override def getWindowGroupLimitExecShim(plan: SparkPlan): WindowGroupLimitExecShim = {
+ val windowGroupLimitPlan = plan.asInstanceOf[WindowGroupLimitExec]
+ WindowGroupLimitExecShim(
+ windowGroupLimitPlan.partitionSpec,
+ windowGroupLimitPlan.orderSpec,
+ windowGroupLimitPlan.rankLikeFunction,
+ windowGroupLimitPlan.limit,
+ windowGroupLimitPlan.mode,
+ windowGroupLimitPlan.child
+ )
+ }
+
+ override def getWindowGroupLimitExec(windowGroupLimitPlan: SparkPlan): SparkPlan = {
+ val windowGroupLimitExecShim = windowGroupLimitPlan.asInstanceOf[WindowGroupLimitExecShim]
+ WindowGroupLimitExec(
+ windowGroupLimitExecShim.partitionSpec,
+ windowGroupLimitExecShim.orderSpec,
+ windowGroupLimitExecShim.rankLikeFunction,
+ windowGroupLimitExecShim.limit,
+ windowGroupLimitExecShim.mode,
+ windowGroupLimitExecShim.child
+ )
+ }
+
override def getLimitAndOffsetFromTopK(plan: TakeOrderedAndProjectExec):
(Int, Int) = {
(getLimit(plan.limit, plan.offset), plan.offset)
}
diff --git
a/shims/spark35/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExecShim.scala
b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExecShim.scala
new file mode 100644
index 000000000..16166e817
--- /dev/null
+++
b/shims/spark35/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExecShim.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.window
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+
+case class WindowGroupLimitExecShim(
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ rankLikeFunction: Expression,
+ limit: Int,
+ mode: WindowGroupLimitMode,
+ child: SparkPlan)
+ extends UnaryExecNode {
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException(
+ s"${this.getClass.getSimpleName} doesn't support doExecute")
+ }
+
+ override def output: Seq[Attribute] = child.output
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]