This is an automated email from the ASF dual-hosted git repository.
yuanzhou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new f085450fcf [GLUTEN-4836][VL] Enable `Rank` and `DenseRank` for
`WindowGroupLimitExec` (#11567)
f085450fcf is described below
commit f085450fcfff9f20448d5716d0a205eec5767854
Author: Yuan <[email protected]>
AuthorDate: Fri Feb 13 23:06:46 2026 +0800
[GLUTEN-4836][VL] Enable `Rank` and `DenseRank` for `WindowGroupLimitExec`
(#11567)
Enable Rank/DenseRank for the WindowGroupLimit operator. These two functions
are supported by Velox's TopNRowNumber.
---------
Signed-off-by: Yuan <[email protected]>
---
.../gluten/backendsapi/velox/VeloxBackend.scala | 2 +-
cpp/velox/substrait/SubstraitParser.cc | 16 +++++++
cpp/velox/substrait/SubstraitParser.h | 7 +++
cpp/velox/substrait/SubstraitToVeloxPlan.cc | 11 ++++-
.../WindowGroupLimitExecTransformer.scala | 21 ++++++++-
.../utils/clickhouse/ClickHouseTestSettings.scala | 3 ++
.../gluten/utils/velox/VeloxTestSettings.scala | 2 +
...utenRemoveRedundantWindowGroupLimitsSuite.scala | 51 +++++++++++++++++++++-
.../execution/GlutenSQLWindowFunctionSuite.scala | 46 ++++++++++++++++++-
.../utils/clickhouse/ClickHouseTestSettings.scala | 3 ++
.../gluten/utils/velox/VeloxTestSettings.scala | 2 +
...utenRemoveRedundantWindowGroupLimitsSuite.scala | 51 +++++++++++++++++++++-
.../execution/GlutenSQLWindowFunctionSuite.scala | 46 ++++++++++++++++++-
.../utils/clickhouse/ClickHouseTestSettings.scala | 3 ++
.../gluten/utils/velox/VeloxTestSettings.scala | 2 +
...utenRemoveRedundantWindowGroupLimitsSuite.scala | 51 +++++++++++++++++++++-
.../execution/GlutenSQLWindowFunctionSuite.scala | 46 ++++++++++++++++++-
17 files changed, 354 insertions(+), 9 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index 8a1a343087..24d08a5792 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -419,7 +419,7 @@ object VeloxBackendSettings extends BackendSettingsApi {
override def supportWindowGroupLimitExec(rankLikeFunction: Expression):
Boolean = {
rankLikeFunction match {
- case _: RowNumber => true
+ case _: RowNumber | _: Rank | _: DenseRank => true
case _ => false
}
}
diff --git a/cpp/velox/substrait/SubstraitParser.cc
b/cpp/velox/substrait/SubstraitParser.cc
index 2bc1dd71c3..c67ad56f09 100644
--- a/cpp/velox/substrait/SubstraitParser.cc
+++ b/cpp/velox/substrait/SubstraitParser.cc
@@ -289,6 +289,22 @@ bool SubstraitParser::configSetInOptimization(
return false;
}
+bool SubstraitParser::checkWindowFunction(
+ const ::substrait::extensions::AdvancedExtension& extension,
+ const std::string& targetFunction) {
+ const std::string config = "window_function=";
+ if (extension.has_optimization()) {
+ google::protobuf::StringValue msg;
+ extension.optimization().UnpackTo(&msg);
+ std::size_t pos = msg.value().find(config);
+ if ((pos != std::string::npos) && (msg.value().size() >=
targetFunction.size()) &&
+ (msg.value().substr(pos + config.size(), targetFunction.size()) ==
targetFunction)) {
+ return true;
+ }
+ }
+ return false;
+}
+
std::vector<TypePtr> SubstraitParser::sigToTypes(const std::string& signature)
{
std::vector<std::string> typeStrs =
SubstraitParser::getSubFunctionTypes(signature);
std::vector<TypePtr> types;
diff --git a/cpp/velox/substrait/SubstraitParser.h
b/cpp/velox/substrait/SubstraitParser.h
index f42d05b4a2..8131851ed0 100644
--- a/cpp/velox/substrait/SubstraitParser.h
+++ b/cpp/velox/substrait/SubstraitParser.h
@@ -93,6 +93,13 @@ class SubstraitParser {
/// @return Whether the config is set as true.
static bool configSetInOptimization(const
::substrait::extensions::AdvancedExtension&, const std::string& config);
+ /// @brief Return whether the target window function is set in
+ /// AdvancedExtension optimization.
+ /// @param extension Substrait advanced extension.
+ /// @param targetFunction The target window function name.
+ /// @return Whether the target function matches.
+ static bool checkWindowFunction(const
::substrait::extensions::AdvancedExtension&, const std::string& targetFunction);
+
/// Extract input types from Substrait function signature.
static std::vector<facebook::velox::TypePtr> sigToTypes(const std::string&
functionSig);
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index b543dfa8ba..83099efdde 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -1169,9 +1169,18 @@ core::PlanNodePtr
SubstraitToVeloxPlanConverter::toVeloxPlan(
childNode);
}
+ auto windowFunc = core::TopNRowNumberNode::RankFunction::kRowNumber;
+ if (windowGroupLimitRel.has_advanced_extension()) {
+ if
(SubstraitParser::checkWindowFunction(windowGroupLimitRel.advanced_extension(),
"rank")){
+ windowFunc = core::TopNRowNumberNode::RankFunction::kRank;
+ } else if
(SubstraitParser::checkWindowFunction(windowGroupLimitRel.advanced_extension(),
"dense_rank")) {
+ windowFunc = core::TopNRowNumberNode::RankFunction::kDenseRank;
+ }
+ }
+
return std::make_shared<core::TopNRowNumberNode>(
nextPlanNodeId(),
- core::TopNRowNumberNode::RankFunction::kRowNumber,
+ windowFunc,
partitionKeys,
sortingKeys,
sortingOrders,
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala
index 282e1b8e71..27bc765047 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WindowGroupLimitExecTransformer.scala
@@ -17,16 +17,19 @@
package org.apache.gluten.execution
import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression.ExpressionConverter
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.substrait.SubstraitContext
+import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute,
Expression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute,
DenseRank, Expression, Rank, RowNumber, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples,
ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.window.{GlutenFinal, GlutenPartial,
GlutenWindowGroupLimitMode}
+import com.google.protobuf.StringValue
import io.substrait.proto.SortField
import scala.collection.JavaConverters._
@@ -111,11 +114,27 @@ case class WindowGroupLimitExecTransformer(
builder.build()
}.asJava
if (!validation) {
+ val windowFunction = rankLikeFunction match {
+ case _: RowNumber => "row_number"
+ case _: Rank => "rank"
+ case _: DenseRank => "dense_rank"
+ case _ => throw new GlutenNotSupportException(s"Unknown window function
$rankLikeFunction")
+ }
+ val parametersStr = new StringBuffer("WindowGroupLimitParameters:")
+ parametersStr
+ .append("window_function=")
+ .append(windowFunction)
+ .append("\n")
+ val message =
StringValue.newBuilder().setValue(parametersStr.toString).build()
+ val extensionNode = ExtensionBuilder.makeAdvancedExtension(
+ BackendsApiManager.getTransformerApiInstance.packPBMessage(message),
+ null)
RelBuilder.makeWindowGroupLimitRel(
input,
partitionsExpressions,
sortFieldList,
limit,
+ extensionNode,
context,
operatorId)
} else {
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index d325d8a6b9..29d7534e8f 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -2006,6 +2006,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeCH("SPLIT")
enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite]
.excludeCH("remove redundant WindowGroupLimits")
+ .excludeCH("Gluten - remove redundant WindowGroupLimits")
enableSuite[GlutenReplaceHashWithSortAggSuite]
.exclude("replace partial hash aggregate with sort aggregate")
.exclude("replace partial and final hash aggregate together with sort
aggregate")
@@ -2060,6 +2061,8 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeCH(
"window function: multiple window expressions specified by range in a
single expression")
.excludeCH("Gluten - Filter on row number")
+ .excludeCH("Gluten - Filter on rank")
+ .excludeCH("Gluten - Filter on dense_rank")
enableSuite[GlutenSameResultSuite]
enableSuite[GlutenSaveLoadSuite]
enableSuite[GlutenScalaReflectionRelationSuite]
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index f4427d7d43..b76a717e42 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -960,6 +960,8 @@ class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenParquetFileMetadataStructRowIndexSuite]
enableSuite[GlutenTableLocationSuite]
enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite]
+ // rewrite with Gluten test
+ .exclude("remove redundant WindowGroupLimits")
enableSuite[GlutenSQLCollectLimitExecSuite]
enableSuite[GlutenBatchEvalPythonExecSuite]
// Replaced with other tests that check for native operations
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala
index 9d819d2bd9..455fa283b1 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala
@@ -16,8 +16,57 @@
*/
package org.apache.spark.sql.execution
+import org.apache.gluten.execution.WindowGroupLimitExecTransformer
+
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.GlutenSQLTestsBaseTrait
+import org.apache.spark.sql.functions.lit
class GlutenRemoveRedundantWindowGroupLimitsSuite
extends RemoveRedundantWindowGroupLimitsSuite
- with GlutenSQLTestsBaseTrait {}
+ with GlutenSQLTestsBaseTrait {
+ private def checkNumWindowGroupLimits(df: DataFrame, count: Int): Unit = {
+ val plan = df.queryExecution.executedPlan
+ assert(collectWithSubqueries(plan) {
+ case exec: WindowGroupLimitExecTransformer => exec
+ }.length == count)
+ }
+
+ private def checkWindowGroupLimits(query: String, count: Int): Unit = {
+ val df = sql(query)
+ checkNumWindowGroupLimits(df, count)
+ val result = df.collect()
+ checkAnswer(df, result)
+ }
+
+ testGluten("remove redundant WindowGroupLimits") {
+ withTempView("t") {
+ spark.range(0, 100).withColumn("value",
lit(1)).createOrReplaceTempView("t")
+ val query1 =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT id, rank() OVER w AS rn
+ | FROM t
+ | GROUP BY id
+ | WINDOW w AS (PARTITION BY id ORDER BY max(value))
+ |)
+ |WHERE rn < 3
+ |""".stripMargin
+ checkWindowGroupLimits(query1, 1)
+
+ val query2 =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT id, rank() OVER w AS rn
+ | FROM t
+ | GROUP BY id
+ | WINDOW w AS (ORDER BY max(value))
+ |)
+ |WHERE rn < 3
+ |""".stripMargin
+ checkWindowGroupLimits(query2, 2)
+ }
+ }
+}
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
index 4a87bac690..32e7e2c717 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
@@ -176,7 +176,51 @@ class GlutenSQLWindowFunctionSuite extends
SQLWindowFunctionSuite with GlutenSQL
)
)
assert(
- !getExecutedPlan(df).exists {
+ getExecutedPlan(df).exists {
+ case _: WindowGroupLimitExecTransformer => true
+ case _ => false
+ }
+ )
+ }
+ }
+
+ testGluten("Filter on dense_rank") {
+ withTable("customer") {
+ val rdd = spark.sparkContext.parallelize(customerData)
+ val customerDF = spark.createDataFrame(rdd, customerSchema)
+ customerDF.createOrReplaceTempView("customer")
+ val query =
+ """
+ |SELECT * from (SELECT
+ | c_custkey,
+ | c_acctbal,
+ | dense_rank() OVER (
+ | PARTITION BY c_nationkey,
+ | "a"
+ | ORDER BY
+ | c_custkey,
+ | "a"
+ | ) AS rank
+ |FROM
+ | customer ORDER BY 1, 2) where rank <=2
+ |""".stripMargin
+ val df = sql(query)
+ checkAnswer(
+ df,
+ Seq(
+ Row(4553, BigDecimal(638841L, 2), 1),
+ Row(4953, BigDecimal(603728L, 2), 1),
+ Row(9954, BigDecimal(758725L, 2), 1),
+ Row(35403, BigDecimal(603470L, 2), 2),
+ Row(35803, BigDecimal(528487L, 2), 1),
+ Row(61065, BigDecimal(728477L, 2), 1),
+ Row(95337, BigDecimal(91561L, 2), 2),
+ Row(127412, BigDecimal(462141L, 2), 2),
+ Row(148303, BigDecimal(430230L, 2), 2)
+ )
+ )
+ assert(
+ getExecutedPlan(df).exists {
case _: WindowGroupLimitExecTransformer => true
case _ => false
}
diff --git
a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index abccf9fe91..ec99089c32 100644
---
a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -1982,6 +1982,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeCH("SPLIT")
enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite]
.excludeCH("remove redundant WindowGroupLimits")
+ .excludeCH("Gluten - remove redundant WindowGroupLimits")
enableSuite[GlutenReplaceHashWithSortAggSuite]
.exclude("replace partial hash aggregate with sort aggregate")
.exclude("replace partial and final hash aggregate together with sort
aggregate")
@@ -2036,6 +2037,8 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeCH(
"window function: multiple window expressions specified by range in a
single expression")
.excludeCH("Gluten - Filter on row number")
+ .excludeCH("Gluten - Filter on rank")
+ .excludeCH("Gluten - Filter on dense_rank")
enableSuite[GlutenSameResultSuite]
enableSuite[GlutenSaveLoadSuite]
enableSuite[GlutenScalaReflectionRelationSuite]
diff --git
a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index 202705b6d1..f5c9d22db6 100644
---
a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++
b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -1122,6 +1122,8 @@ class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenParquetFileMetadataStructRowIndexSuite]
enableSuite[GlutenTableLocationSuite]
enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite]
+ // rewrite with Gluten test
+ .exclude("remove redundant WindowGroupLimits")
enableSuite[GlutenSQLCollectLimitExecSuite]
// Generated suites for org.apache.spark.sql.execution.python
// TODO: 4.x enableSuite[GlutenPythonDataSourceSuite] // 1 failure
diff --git
a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala
b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala
index 9d819d2bd9..455fa283b1 100644
---
a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala
+++
b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala
@@ -16,8 +16,57 @@
*/
package org.apache.spark.sql.execution
+import org.apache.gluten.execution.WindowGroupLimitExecTransformer
+
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.GlutenSQLTestsBaseTrait
+import org.apache.spark.sql.functions.lit
class GlutenRemoveRedundantWindowGroupLimitsSuite
extends RemoveRedundantWindowGroupLimitsSuite
- with GlutenSQLTestsBaseTrait {}
+ with GlutenSQLTestsBaseTrait {
+ private def checkNumWindowGroupLimits(df: DataFrame, count: Int): Unit = {
+ val plan = df.queryExecution.executedPlan
+ assert(collectWithSubqueries(plan) {
+ case exec: WindowGroupLimitExecTransformer => exec
+ }.length == count)
+ }
+
+ private def checkWindowGroupLimits(query: String, count: Int): Unit = {
+ val df = sql(query)
+ checkNumWindowGroupLimits(df, count)
+ val result = df.collect()
+ checkAnswer(df, result)
+ }
+
+ testGluten("remove redundant WindowGroupLimits") {
+ withTempView("t") {
+ spark.range(0, 100).withColumn("value",
lit(1)).createOrReplaceTempView("t")
+ val query1 =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT id, rank() OVER w AS rn
+ | FROM t
+ | GROUP BY id
+ | WINDOW w AS (PARTITION BY id ORDER BY max(value))
+ |)
+ |WHERE rn < 3
+ |""".stripMargin
+ checkWindowGroupLimits(query1, 1)
+
+ val query2 =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT id, rank() OVER w AS rn
+ | FROM t
+ | GROUP BY id
+ | WINDOW w AS (ORDER BY max(value))
+ |)
+ |WHERE rn < 3
+ |""".stripMargin
+ checkWindowGroupLimits(query2, 2)
+ }
+ }
+}
diff --git
a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
index 7c803dd78d..7515d45fca 100644
---
a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
+++
b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
@@ -178,7 +178,51 @@ class GlutenSQLWindowFunctionSuite extends
SQLWindowFunctionSuite with GlutenSQL
)
)
assert(
- !getExecutedPlan(df).exists {
+ getExecutedPlan(df).exists {
+ case _: WindowGroupLimitExecTransformer => true
+ case _ => false
+ }
+ )
+ }
+ }
+
+ testGluten("Filter on dense_rank") {
+ withTable("customer") {
+ val rdd = spark.sparkContext.parallelize(customerData)
+ val customerDF = spark.createDataFrame(rdd, customerSchema)
+ customerDF.createOrReplaceTempView("customer")
+ val query =
+ """
+ |SELECT * from (SELECT
+ | c_custkey,
+ | c_acctbal,
+ | dense_rank() OVER (
+ | PARTITION BY c_nationkey,
+ | "a"
+ | ORDER BY
+ | c_custkey,
+ | "a"
+ | ) AS rank
+ |FROM
+ | customer ORDER BY 1, 2) where rank <=2
+ |""".stripMargin
+ val df = sql(query)
+ checkAnswer(
+ df,
+ Seq(
+ Row(4553, BigDecimal(638841L, 2), 1),
+ Row(4953, BigDecimal(603728L, 2), 1),
+ Row(9954, BigDecimal(758725L, 2), 1),
+ Row(35403, BigDecimal(603470L, 2), 2),
+ Row(35803, BigDecimal(528487L, 2), 1),
+ Row(61065, BigDecimal(728477L, 2), 1),
+ Row(95337, BigDecimal(91561L, 2), 2),
+ Row(127412, BigDecimal(462141L, 2), 2),
+ Row(148303, BigDecimal(430230L, 2), 2)
+ )
+ )
+ assert(
+ getExecutedPlan(df).exists {
case _: WindowGroupLimitExecTransformer => true
case _ => false
}
diff --git
a/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index abccf9fe91..ec99089c32 100644
---
a/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -1982,6 +1982,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeCH("SPLIT")
enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite]
.excludeCH("remove redundant WindowGroupLimits")
+ .excludeCH("Gluten - remove redundant WindowGroupLimits")
enableSuite[GlutenReplaceHashWithSortAggSuite]
.exclude("replace partial hash aggregate with sort aggregate")
.exclude("replace partial and final hash aggregate together with sort
aggregate")
@@ -2036,6 +2037,8 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeCH(
"window function: multiple window expressions specified by range in a
single expression")
.excludeCH("Gluten - Filter on row number")
+ .excludeCH("Gluten - Filter on rank")
+ .excludeCH("Gluten - Filter on dense_rank")
enableSuite[GlutenSameResultSuite]
enableSuite[GlutenSaveLoadSuite]
enableSuite[GlutenScalaReflectionRelationSuite]
diff --git
a/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
b/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index a74142c95d..e8f8dfa762 100644
---
a/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++
b/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -1108,6 +1108,8 @@ class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenParquetFileMetadataStructRowIndexSuite]
enableSuite[GlutenTableLocationSuite]
enableSuite[GlutenRemoveRedundantWindowGroupLimitsSuite]
+ // rewrite with Gluten test
+ .exclude("remove redundant WindowGroupLimits")
enableSuite[GlutenSQLCollectLimitExecSuite]
// Generated suites for org.apache.spark.sql.execution.python
// TODO: 4.x enableSuite[GlutenPythonDataSourceSuite]
diff --git
a/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala
b/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala
index 9d819d2bd9..455fa283b1 100644
---
a/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala
+++
b/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenRemoveRedundantWindowGroupLimitsSuite.scala
@@ -16,8 +16,57 @@
*/
package org.apache.spark.sql.execution
+import org.apache.gluten.execution.WindowGroupLimitExecTransformer
+
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.GlutenSQLTestsBaseTrait
+import org.apache.spark.sql.functions.lit
class GlutenRemoveRedundantWindowGroupLimitsSuite
extends RemoveRedundantWindowGroupLimitsSuite
- with GlutenSQLTestsBaseTrait {}
+ with GlutenSQLTestsBaseTrait {
+ private def checkNumWindowGroupLimits(df: DataFrame, count: Int): Unit = {
+ val plan = df.queryExecution.executedPlan
+ assert(collectWithSubqueries(plan) {
+ case exec: WindowGroupLimitExecTransformer => exec
+ }.length == count)
+ }
+
+ private def checkWindowGroupLimits(query: String, count: Int): Unit = {
+ val df = sql(query)
+ checkNumWindowGroupLimits(df, count)
+ val result = df.collect()
+ checkAnswer(df, result)
+ }
+
+ testGluten("remove redundant WindowGroupLimits") {
+ withTempView("t") {
+ spark.range(0, 100).withColumn("value",
lit(1)).createOrReplaceTempView("t")
+ val query1 =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT id, rank() OVER w AS rn
+ | FROM t
+ | GROUP BY id
+ | WINDOW w AS (PARTITION BY id ORDER BY max(value))
+ |)
+ |WHERE rn < 3
+ |""".stripMargin
+ checkWindowGroupLimits(query1, 1)
+
+ val query2 =
+ """
+ |SELECT *
+ |FROM (
+ | SELECT id, rank() OVER w AS rn
+ | FROM t
+ | GROUP BY id
+ | WINDOW w AS (ORDER BY max(value))
+ |)
+ |WHERE rn < 3
+ |""".stripMargin
+ checkWindowGroupLimits(query2, 2)
+ }
+ }
+}
diff --git
a/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
b/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
index 7c803dd78d..7515d45fca 100644
---
a/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
+++
b/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenSQLWindowFunctionSuite.scala
@@ -178,7 +178,51 @@ class GlutenSQLWindowFunctionSuite extends
SQLWindowFunctionSuite with GlutenSQL
)
)
assert(
- !getExecutedPlan(df).exists {
+ getExecutedPlan(df).exists {
+ case _: WindowGroupLimitExecTransformer => true
+ case _ => false
+ }
+ )
+ }
+ }
+
+ testGluten("Filter on dense_rank") {
+ withTable("customer") {
+ val rdd = spark.sparkContext.parallelize(customerData)
+ val customerDF = spark.createDataFrame(rdd, customerSchema)
+ customerDF.createOrReplaceTempView("customer")
+ val query =
+ """
+ |SELECT * from (SELECT
+ | c_custkey,
+ | c_acctbal,
+ | dense_rank() OVER (
+ | PARTITION BY c_nationkey,
+ | "a"
+ | ORDER BY
+ | c_custkey,
+ | "a"
+ | ) AS rank
+ |FROM
+ | customer ORDER BY 1, 2) where rank <=2
+ |""".stripMargin
+ val df = sql(query)
+ checkAnswer(
+ df,
+ Seq(
+ Row(4553, BigDecimal(638841L, 2), 1),
+ Row(4953, BigDecimal(603728L, 2), 1),
+ Row(9954, BigDecimal(758725L, 2), 1),
+ Row(35403, BigDecimal(603470L, 2), 2),
+ Row(35803, BigDecimal(528487L, 2), 1),
+ Row(61065, BigDecimal(728477L, 2), 1),
+ Row(95337, BigDecimal(91561L, 2), 2),
+ Row(127412, BigDecimal(462141L, 2), 2),
+ Row(148303, BigDecimal(430230L, 2), 2)
+ )
+ )
+ assert(
+ getExecutedPlan(df).exists {
case _: WindowGroupLimitExecTransformer => true
case _ => false
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]