This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new afccfed98 [GLUTEN-7087][CH] Support `WindowGroupLimitExec` (#7176)
afccfed98 is described below

commit afccfed98ef85590975997adf35468ad227a69a1
Author: lgbo <[email protected]>
AuthorDate: Thu Sep 12 09:55:18 2024 +0800

    [GLUTEN-7087][CH] Support `WindowGroupLimitExec` (#7176)
    
    * support WindowGroupLimit
    
    * 0903
    
    * implement window group limit
---
 .../gluten/backendsapi/clickhouse/CHBackend.scala  |   2 +
 .../clickhouse/CHSparkPlanExecApi.scala            |  16 +
 .../CHWindowGroupLimitExecTransformer.scala        | 187 +++++++++++
 .../GlutenClickHouseTPCDSAbstractSuite.scala       |   2 +-
 .../GlutenClickHouseTPCHSaltNullParquetSuite.scala |   6 +-
 cpp-ch/local-engine/Common/CHUtil.cpp              |  41 +--
 cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp |  20 +-
 .../local-engine/Operator/WindowGroupLimitStep.cpp | 365 +++++++++++++++++++++
 .../local-engine/Operator/WindowGroupLimitStep.h   |  51 +++
 .../Parser/AdvancedParametersParseUtil.cpp         |  31 +-
 .../Parser/AdvancedParametersParseUtil.h           |   9 +-
 cpp-ch/local-engine/Parser/RelParser.cpp           |  14 +-
 .../local-engine/Parser/SerializedPlanParser.cpp   |  29 +-
 .../Parser/WindowGroupLimitRelParser.cpp           | 112 +++++++
 .../Parser/WindowGroupLimitRelParser.h             |  52 +++
 .../gluten/backendsapi/SparkPlanExecApi.scala      |  10 +
 .../extension/columnar/OffloadSingleNode.scala     |   2 +-
 17 files changed, 876 insertions(+), 73 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index 69ea899c4..45aee4322 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -418,4 +418,6 @@ object CHBackendSettings extends BackendSettingsApi with 
Logging {
         }
       }
   }
+
+  /**
+   * Whether a `WindowGroupLimitExec` built on `rankLikeFunction` is supported by this backend.
+   * The ClickHouse backend answers `true` for every function.
+   */
+  override def supportWindowGroupLimitExec(rankLikeFunction: Expression): Boolean = {
+    true
+  }
 }
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 1108b8b3c..f765a75d2 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -53,6 +53,7 @@ import 
org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
 import org.apache.spark.sql.execution.joins.{BuildSideRelation, 
ClickHouseBuildSideRelation, HashedRelationBroadcastMode}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil}
+import org.apache.spark.sql.execution.window._
 import org.apache.spark.sql.types.{DecimalType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
@@ -909,4 +910,19 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with 
Logging {
       toScale: Int): DecimalType = {
     SparkShimLoader.getSparkShims.genDecimalRoundExpressionOutput(decimalType, 
toScale)
   }
+
+  /**
+   * Creates the ClickHouse-specific transformer that replaces Spark's `WindowGroupLimitExec`.
+   * All arguments are forwarded unchanged to [[CHWindowGroupLimitExecTransformer]].
+   */
+  override def genWindowGroupLimitTransformer(
+      partitionSpec: Seq[Expression],
+      orderSpec: Seq[SortOrder],
+      rankLikeFunction: Expression,
+      limit: Int,
+      mode: WindowGroupLimitMode,
+      child: SparkPlan): SparkPlan = {
+    CHWindowGroupLimitExecTransformer(
+      partitionSpec = partitionSpec,
+      orderSpec = orderSpec,
+      rankLikeFunction = rankLikeFunction,
+      limit = limit,
+      mode = mode,
+      child = child)
+  }
 }
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala
new file mode 100644
index 000000000..c2648f29e
--- /dev/null
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.expression._
+import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
+import org.apache.gluten.extension.ValidationResult
+import org.apache.gluten.metrics.MetricsUpdater
+import org.apache.gluten.substrait.`type`.TypeBuilder
+import org.apache.gluten.substrait.SubstraitContext
+import org.apache.gluten.substrait.extensions.ExtensionBuilder
+import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, 
ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.window.{Final, Partial, 
WindowGroupLimitMode}
+
+import com.google.protobuf.StringValue
+import io.substrait.proto.SortField
+
+import scala.collection.JavaConverters._
+
+/**
+ * ClickHouse-backend transformer for Spark's `WindowGroupLimitExec`.
+ *
+ * Within every partition defined by `partitionSpec`, rows are ranked by `orderSpec`. Only the
+ * rows whose `rankLikeFunction` value (row_number, rank or dense_rank) does not exceed `limit`
+ * are kept. The output schema is the child's output; no rank column is produced.
+ *
+ * `mode` only affects the required child distribution:
+ *  - `Partial` keeps the default requirement.
+ *  - `Final` requires the child to be clustered by `partitionSpec` (all tuples when it is empty).
+ */
+case class CHWindowGroupLimitExecTransformer(
+    partitionSpec: Seq[Expression],
+    orderSpec: Seq[SortOrder],
+    rankLikeFunction: Expression,
+    limit: Int,
+    mode: WindowGroupLimitMode,
+    child: SparkPlan)
+  extends UnaryTransformSupport {
+
+  @transient override lazy val metrics =
+    BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetrics(sparkContext)
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
+
+  override def metricsUpdater(): MetricsUpdater =
+    BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetricsUpdater(metrics)
+
+  // Rows are only filtered, never projected, so the schema is the child's.
+  override def output: Seq[Attribute] = child.output
+
+  override def requiredChildDistribution: Seq[Distribution] = mode match {
+    case Partial => super.requiredChildDistribution
+    case Final =>
+      if (partitionSpec.isEmpty) {
+        AllTuples :: Nil
+      } else {
+        ClusteredDistribution(partitionSpec) :: Nil
+      }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+    if (BackendsApiManager.getSettings.requiredChildOrderingForWindowGroupLimit()) {
+      // Sort by the partition keys first so the rows of one partition are contiguous, then by
+      // the window ordering.
+      Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
+    } else {
+      Seq(Nil)
+    }
+  }
+
+  override def outputOrdering: Seq[SortOrder] = {
+    if (requiredChildOrdering.forall(_.isEmpty)) {
+      // Without a required child ordering the input is not known to be sorted, so no output
+      // ordering can be promised. At best the rows are grouped by the partition keys.
+      // TODO: Make the framework aware of grouped data distribution
+      Nil
+    } else {
+      // The operator only drops rows, so the child's ordering is preserved.
+      child.outputOrdering
+    }
+  }
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  /** Native name of `rankLikeFunction`, or `None` if this transformer cannot serialize it. */
+  private def windowFunctionName: Option[String] = rankLikeFunction match {
+    case _: RowNumber => Some(ExpressionNames.ROW_NUMBER)
+    case _: Rank => Some(ExpressionNames.RANK)
+    case _: DenseRank => Some(ExpressionNames.DENSE_RANK)
+    case _ => None
+  }
+
+  /**
+   * Builds the Substrait `WindowGroupLimitRel` for this operator.
+   *
+   * @param originalInputAttributes attributes whose types are sent along for native validation
+   * @param input                   the child relation (`null` when called from validation)
+   * @param validation              when true, the advanced extension carries the input types;
+   *                                otherwise it carries the function name as a
+   *                                `WindowGroupLimitParameters:` string
+   * @throws GlutenNotSupportException if `validation` is false and `rankLikeFunction` is not
+   *                                   row_number, rank or dense_rank
+   */
+  def getWindowGroupLimitRel(
+      context: SubstraitContext,
+      originalInputAttributes: Seq[Attribute],
+      operatorId: Long,
+      input: RelNode,
+      validation: Boolean): RelNode = {
+    val args = context.registeredFunction
+    // Partition By Expressions
+    val partitionsExpressions = partitionSpec
+      .map(
+        ExpressionConverter
+          .replaceWithExpressionTransformer(_, attributeSeq = child.output)
+          .doTransform(args))
+      .asJava
+
+    // Sort By Expressions
+    val sortFieldList =
+      orderSpec.map {
+        order =>
+          val builder = SortField.newBuilder()
+          val exprNode = ExpressionConverter
+            .replaceWithExpressionTransformer(order.child, attributeSeq = child.output)
+            .doTransform(args)
+          builder.setExpr(exprNode.toProtobuf)
+          builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
+          builder.build()
+      }.asJava
+
+    val extensionNode = if (!validation) {
+      val windowFunction = windowFunctionName.getOrElse(
+        throw new GlutenNotSupportException(s"Unknown window function $rankLikeFunction"))
+      // The function name is passed to the native side as a parameter string.
+      val parameters = s"WindowGroupLimitParameters:window_function=$windowFunction\n"
+      val message = StringValue.newBuilder().setValue(parameters).build()
+      ExtensionBuilder.makeAdvancedExtension(
+        BackendsApiManager.getTransformerApiInstance.packPBMessage(message),
+        null)
+    } else {
+      // Use an extension node to send the input types through the Substrait plan for validation.
+      val inputTypeNodeList = originalInputAttributes
+        .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
+        .asJava
+      ExtensionBuilder.makeAdvancedExtension(
+        BackendsApiManager.getTransformerApiInstance.packPBMessage(
+          TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+    }
+
+    RelBuilder.makeWindowGroupLimitRel(
+      input,
+      partitionsExpressions,
+      sortFieldList,
+      limit,
+      extensionNode,
+      context,
+      operatorId)
+  }
+
+  override protected def doValidateInternal(): ValidationResult = {
+    // Besides the backend setting, reject functions that getWindowGroupLimitRel cannot
+    // serialize. The failure then surfaces during validation instead of in doTransform.
+    val supported =
+      BackendsApiManager.getSettings.supportWindowGroupLimitExec(rankLikeFunction) &&
+        windowFunctionName.isDefined
+    if (!supported) {
+      ValidationResult.failed(s"Found unsupported rank like function: $rankLikeFunction")
+    } else {
+      val substraitContext = new SubstraitContext
+      val operatorId = substraitContext.nextOperatorId(this.nodeName)
+      val relNode =
+        getWindowGroupLimitRel(substraitContext, child.output, operatorId, null, validation = true)
+      doNativeValidation(substraitContext, relNode)
+    }
+  }
+
+  override protected def doTransform(context: SubstraitContext): TransformContext = {
+    val childCtx = child.asInstanceOf[TransformSupport].transform(context)
+    val operatorId = context.nextOperatorId(this.nodeName)
+
+    val currRel =
+      getWindowGroupLimitRel(context, child.output, operatorId, childCtx.root, validation = false)
+    assert(currRel != null, "Window Group Limit Rel should be valid")
+    TransformContext(childCtx.outputAttributes, output, currRel)
+  }
+}
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
index 03b26fa98..abb7d27ff 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
@@ -62,7 +62,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
         })
 
   protected def fallbackSets(isAqe: Boolean): Set[Int] = {
-    if (isSparkVersionGE("3.5")) Set(44, 67, 70) else Set.empty[Int]
+    Set.empty[Int]
   }
   protected def excludedTpcdsQueries: Set[String] = Set(
     "q66" // inconsistent results
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index f7cf0de37..9ac35441a 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -1855,7 +1855,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends 
GlutenClickHouseTPCHAbstr
         | ) t1
         |) t2 where rank = 1
     """.stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => }, 
isSparkVersionLE("3.3"))
+    compareResultsAgainstVanillaSpark(sql, true, { _ => })
   }
 
   test("GLUTEN-1874 not null in both streams") {
@@ -1873,7 +1873,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends 
GlutenClickHouseTPCHAbstr
         | ) t1
         |) t2 where rank = 1
     """.stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => }, 
isSparkVersionLE("3.3"))
+    compareResultsAgainstVanillaSpark(sql, true, { _ => })
   }
 
   test("GLUTEN-2095: test cast(string as binary)") {
@@ -2456,7 +2456,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends 
GlutenClickHouseTPCHAbstr
         |  ) t1
         |) t2 where rank = 1 order by p_partkey limit 100
         |""".stripMargin
-    runQueryAndCompare(sql, noFallBack = isSparkVersionLE("3.3"))({ _ => })
+    runQueryAndCompare(sql, noFallBack = true)({ _ => })
   }
 
   test("GLUTEN-4190: crush on flattening a const null column") {
diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp 
b/cpp-ch/local-engine/Common/CHUtil.cpp
index b702082d6..94a214e5e 100644
--- a/cpp-ch/local-engine/Common/CHUtil.cpp
+++ b/cpp-ch/local-engine/Common/CHUtil.cpp
@@ -53,6 +53,7 @@
 #include <Parser/SubstraitParserUtils.h>
 #include <Planner/PlannerActionsVisitor.h>
 #include <Processors/Chunk.h>
+#include <Processors/Formats/IOutputFormat.h>
 #include <Processors/QueryPlan/ExpressionStep.h>
 #include <Processors/QueryPlan/QueryPlan.h>
 #include <QueryPipeline/QueryPipelineBuilder.h>
@@ -315,7 +316,6 @@ DB::Block 
BlockUtil::concatenateBlocksMemoryEfficiently(std::vector<DB::Block> &
     return out;
 }
 
-
 size_t PODArrayUtil::adjustMemoryEfficientSize(size_t n)
 {
     /// According to definition of DEFUALT_BLOCK_SIZE
@@ -560,9 +560,7 @@ std::map<std::string, std::string> 
BackendInitializerUtil::getBackendConfMap(std
 }
 
 std::vector<String> BackendInitializerUtil::wrapDiskPathConfig(
-    const String & path_prefix,
-    const String & path_suffix,
-    Poco::Util::AbstractConfiguration & config)
+    const String & path_prefix, const String & path_suffix, 
Poco::Util::AbstractConfiguration & config)
 {
     std::vector<String> changed_paths;
     if (path_prefix.empty() && path_suffix.empty())
@@ -657,9 +655,7 @@ DB::Context::ConfigurationPtr 
BackendInitializerUtil::initConfig(std::map<std::s
         auto path_need_clean = wrapDiskPathConfig("", "/" + pid, *config);
         std::lock_guard lock(BackendFinalizerUtil::paths_mutex);
         BackendFinalizerUtil::paths_need_to_clean.insert(
-            BackendFinalizerUtil::paths_need_to_clean.end(),
-            path_need_clean.begin(),
-            path_need_clean.end());
+            BackendFinalizerUtil::paths_need_to_clean.end(), 
path_need_clean.begin(), path_need_clean.end());
     }
     return config;
 }
@@ -683,7 +679,9 @@ void 
BackendInitializerUtil::initEnvs(DB::Context::ConfigurationPtr config)
     {
         const std::string config_timezone = config->getString("timezone");
         const String mapped_timezone = 
DateTimeUtil::convertTimeZone(config_timezone);
-        if (0 != setenv("TZ", mapped_timezone.data(), 1)) // 
NOLINT(concurrency-mt-unsafe) // ok if not called concurrently with other 
setenv/getenv
+        if (0
+            != setenv(
+                "TZ", mapped_timezone.data(), 1)) // 
NOLINT(concurrency-mt-unsafe) // ok if not called concurrently with other 
setenv/getenv
             throw Poco::Exception("Cannot setenv TZ variable");
 
         tzset();
@@ -807,8 +805,7 @@ void 
BackendInitializerUtil::initSettings(std::map<std::string, std::string> & b
         {
             auto mem_gb = task_memory / static_cast<double>(1_GiB);
             // 2.8x+5, Heuristics calculate the block size of external sort, 
[8,16]
-            settings.prefer_external_sort_block_bytes = std::max(std::min(
-                static_cast<size_t>(2.8*mem_gb + 5), 16ul), 8ul) * 1024 * 1024;
+            settings.prefer_external_sort_block_bytes = 
std::max(std::min(static_cast<size_t>(2.8 * mem_gb + 5), 16ul), 8ul) * 1024 * 
1024;
         }
     }
 }
@@ -848,10 +845,14 @@ void 
BackendInitializerUtil::initContexts(DB::Context::ConfigurationPtr config)
 
         global_context->setMarkCache(mark_cache_policy, mark_cache_size, 
mark_cache_size_ratio);
 
-        String index_uncompressed_cache_policy = 
config->getString("index_uncompressed_cache_policy", 
DEFAULT_INDEX_UNCOMPRESSED_CACHE_POLICY);
-        size_t index_uncompressed_cache_size = 
config->getUInt64("index_uncompressed_cache_size", 
DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE);
-        double index_uncompressed_cache_size_ratio = 
config->getDouble("index_uncompressed_cache_size_ratio", 
DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO);
-        
global_context->setIndexUncompressedCache(index_uncompressed_cache_policy, 
index_uncompressed_cache_size, index_uncompressed_cache_size_ratio);
+        String index_uncompressed_cache_policy
+            = config->getString("index_uncompressed_cache_policy", 
DEFAULT_INDEX_UNCOMPRESSED_CACHE_POLICY);
+        size_t index_uncompressed_cache_size
+            = config->getUInt64("index_uncompressed_cache_size", 
DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE);
+        double index_uncompressed_cache_size_ratio
+            = config->getDouble("index_uncompressed_cache_size_ratio", 
DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO);
+        global_context->setIndexUncompressedCache(
+            index_uncompressed_cache_policy, index_uncompressed_cache_size, 
index_uncompressed_cache_size_ratio);
 
         String index_mark_cache_policy = 
config->getString("index_mark_cache_policy", DEFAULT_INDEX_MARK_CACHE_POLICY);
         size_t index_mark_cache_size = 
config->getUInt64("index_mark_cache_size", DEFAULT_INDEX_MARK_CACHE_MAX_SIZE);
@@ -1023,11 +1024,13 @@ void BackendFinalizerUtil::finalizeGlobally()
     StorageMergeTreeFactory::clear();
     QueryContext::resetGlobal();
     std::lock_guard lock(paths_mutex);
-    std::ranges::for_each(paths_need_to_clean, [](const auto & path)
-    {
-        if (fs::exists(path))
-            fs::remove_all(path);
-    });
+    std::ranges::for_each(
+        paths_need_to_clean,
+        [](const auto & path)
+        {
+            if (fs::exists(path))
+                fs::remove_all(path);
+        });
     paths_need_to_clean.clear();
 }
 
diff --git a/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp 
b/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp
index f2d4bc8a8..ecb027c18 100644
--- a/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp
+++ b/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp
@@ -32,16 +32,14 @@ namespace local_engine
 {
 static DB::ITransformingStep::Traits getTraits()
 {
-    return DB::ITransformingStep::Traits
-    {
+    return DB::ITransformingStep::Traits{
         {
             .preserves_number_of_streams = true,
             .preserves_sorting = false,
         },
         {
             .preserves_number_of_rows = false,
-        }
-    };
+        }};
 }
 
 ReplicateRowsStep::ReplicateRowsStep(const DB::DataStream & input_stream)
@@ -49,7 +47,7 @@ ReplicateRowsStep::ReplicateRowsStep(const DB::DataStream & 
input_stream)
 {
 }
 
-DB::Block ReplicateRowsStep::transformHeader(const DB::Block& input)
+DB::Block ReplicateRowsStep::transformHeader(const DB::Block & input)
 {
     DB::Block output;
     for (int i = 1; i < input.columns(); i++)
@@ -59,15 +57,9 @@ DB::Block ReplicateRowsStep::transformHeader(const 
DB::Block& input)
     return output;
 }
 
-void ReplicateRowsStep::transformPipeline(
-    DB::QueryPipelineBuilder & pipeline,
-    const DB::BuildQueryPipelineSettings & /*settings*/)
+void ReplicateRowsStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, 
const DB::BuildQueryPipelineSettings & /*settings*/)
 {
-    pipeline.addSimpleTransform(
-        [&](const DB::Block & header)
-        {
-            return std::make_shared<ReplicateRowsTransform>(header);
-        });
+    pipeline.addSimpleTransform([&](const DB::Block & header) { return 
std::make_shared<ReplicateRowsTransform>(header); });
 }
 
 void ReplicateRowsStep::updateOutputStream()
@@ -105,4 +97,4 @@ void ReplicateRowsTransform::transform(DB::Chunk & chunk)
 
     chunk.setColumns(std::move(mutable_columns), total_rows);
 }
-}
\ No newline at end of file
+}
diff --git a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp 
b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp
new file mode 100644
index 000000000..af04ef579
--- /dev/null
+++ b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "WindowGroupLimitStep.h"
+#include <memory>
+#include <utility>
+#include <Processors/Chunk.h>
+#include <Processors/IProcessor.h>
+#include <QueryPipeline/QueryPipelineBuilder.h>
+#include <Poco/Logger.h>
+#include <Common/CHUtil.h>
+#include <Common/logger_useful.h>
+
+namespace DB::ErrorCodes
+{
+extern const int LOGICAL_ERROR;
+}
+
+namespace local_engine
+{
+
+/// Rank-like window functions handled by WindowGroupLimitTransform; used as its template argument.
+enum class WindowGroupLimitFunction
+{
+    RowNumber, // row_number
+    Rank, // rank
+    DenseRank // dense_rank
+};
+
+
+/// Keeps, for every partition of the input, only the rows whose rank value does not exceed `limit`.
+///
+/// Each row is compared only with the first row of the current partition and of the current peer
+/// group. The input is therefore expected to be grouped by `partition_columns` and ordered by
+/// `sort_columns` inside each partition.
+/// NOTE(review): presumably guaranteed by the required child ordering on the Spark side; confirm.
+///
+/// The rank value depends on `function`:
+///   - RowNumber: position of the row inside its partition, starting at 1;
+///   - Rank: 1 + number of rows preceding the row's peer group;
+///   - DenseRank: 1 + number of peer groups preceding the row's peer group.
+/// A peer group is a run of rows of one partition with equal values on `sort_columns`.
+/// Partitions and peer groups may span several chunks; their state is kept between work() calls.
+template <WindowGroupLimitFunction function>
+class WindowGroupLimitTransform : public DB::IProcessor
+{
+public:
+    using Status = DB::IProcessor::Status;
+    explicit WindowGroupLimitTransform(
+        const DB::Block & header_, const std::vector<size_t> & partition_columns_, const std::vector<size_t> & sort_columns_, size_t limit_)
+        : DB::IProcessor({header_}, {header_})
+        , header(header_)
+        , partition_columns(partition_columns_)
+        , sort_columns(sort_columns_)
+        , limit(limit_)
+    {
+    }
+    ~WindowGroupLimitTransform() override = default;
+    String getName() const override { return "WindowGroupLimitTransform"; }
+
+    /// One input, one output. Each call does the first applicable step:
+    ///   1. push a pending output chunk;
+    ///   2. report Ready while a pulled input chunk is unprocessed;
+    ///   3. finish when the input is finished;
+    ///   4. otherwise pull the next chunk.
+    Status prepare() override
+    {
+        auto & output_port = outputs.front();
+        auto & input_port = inputs.front();
+        if (output_port.isFinished())
+        {
+            input_port.close();
+            return Status::Finished;
+        }
+
+        if (has_output)
+        {
+            if (output_port.canPush())
+            {
+                output_port.push(std::move(output_chunk));
+                has_output = false;
+            }
+            return Status::PortFull;
+        }
+
+        if (has_input)
+            return Status::Ready;
+
+        if (input_port.isFinished())
+        {
+            output_port.finish();
+            return Status::Finished;
+        }
+        input_port.setNeeded();
+        if (!input_port.hasData())
+            return Status::NeedData;
+        input_chunk = input_port.pull(true);
+        has_input = true;
+        return Status::Ready;
+    }
+
+    /// Filters the pulled input chunk; the kept rows (if any) become output_chunk.
+    void work() override
+    {
+        if (!has_input) [[unlikely]]
+        {
+            return;
+        }
+        size_t partition_start_row = 0;
+        size_t chunk_rows = input_chunk.getNumRows();
+        while (partition_start_row < chunk_rows)
+        {
+            auto next_partition_start_row = advanceNextPartition(input_chunk, partition_start_row);
+            iteratePartition(input_chunk, partition_start_row, next_partition_start_row);
+            partition_start_row = next_partition_start_row;
+            // A partition reaching the end of the chunk may continue in the next chunk, so its state
+            // is kept. Otherwise a new partition starts here and the per-partition state is reset.
+            if (partition_start_row < chunk_rows)
+            {
+                current_row_rank_value = 1;
+                if constexpr (function == WindowGroupLimitFunction::Rank)
+                    current_peer_group_rows = 0;
+                partition_start_row_columns = extractOneRowColumns(input_chunk, partition_start_row);
+                // The first peer group of the new partition starts at the same row.
+                // iteratePartition() stops early once the limit is reached, which can leave a
+                // peer-group row of the previous partition here. For dense_rank that stale row made
+                // the first peer group of the new partition count as rank 2.
+                peer_group_start_row_columns = partition_start_row_columns;
+            }
+        }
+
+        if (!output_columns.empty() && output_columns[0]->size() > 0)
+        {
+            auto rows = output_columns[0]->size();
+            output_chunk = DB::Chunk(std::move(output_columns), rows);
+            output_columns.clear();
+            has_output = true;
+        }
+        has_input = false;
+    }
+
+private:
+    DB::Block header;
+    // Which columns are used as the partition keys
+    std::vector<size_t> partition_columns;
+    // which columns are used as the order by keys, excluding partition columns.
+    std::vector<size_t> sort_columns;
+    // Limitations for each partition.
+    size_t limit = 0;
+
+    bool has_input = false;
+    DB::Chunk input_chunk;
+    bool has_output = false;
+    DB::MutableColumns output_columns;
+    DB::Chunk output_chunk;
+
+    // We don't have window frame here. in fact all of frame are (unbounded preceding, current row]
+    // the start value is 1
+    size_t current_row_rank_value = 1;
+    // rank need this to record how many rows in current peer group.
+    // A peer group in a partition is defined as the rows have the same value on the sort columns.
+    size_t current_peer_group_rows = 0;
+
+    // Single-row copies of the first row of the current partition / peer group.
+    DB::Columns partition_start_row_columns;
+    DB::Columns peer_group_start_row_columns;
+
+    /// Returns the first row at or after `start_offset` that does not belong to the current
+    /// partition, or the chunk size if the partition runs to the end of the chunk.
+    size_t advanceNextPartition(const DB::Chunk & chunk, size_t start_offset)
+    {
+        if (partition_start_row_columns.empty())
+            partition_start_row_columns = extractOneRowColumns(chunk, start_offset);
+
+        size_t max_row = chunk.getNumRows();
+        for (size_t i = start_offset; i < max_row; ++i)
+        {
+            if (!isRowEqual(partition_columns, partition_start_row_columns, 0, chunk.getColumns(), i))
+            {
+                return i;
+            }
+        }
+        return max_row;
+    }
+
+    /// Copies row `offset` of every column of `chunk` into new single-row columns.
+    static DB::Columns extractOneRowColumns(const DB::Chunk & chunk, size_t offset)
+    {
+        DB::Columns row;
+        for (const auto & col : chunk.getColumns())
+        {
+            auto new_col = col->cloneEmpty();
+            new_col->insertFrom(*col, offset);
+            row.push_back(std::move(new_col));
+        }
+        return row;
+    }
+
+    /// True if the two rows are equal on every column listed in `fields`.
+    static bool isRowEqual(
+        const std::vector<size_t> & fields, const DB::Columns & left_cols, size_t loffset, const DB::Columns & right_cols, size_t roffset)
+    {
+        for (size_t i = 0; i < fields.size(); ++i)
+        {
+            const auto & field = fields[i];
+            /// don't care about nan_direction_hint
+            if (left_cols[field]->compareAt(loffset, roffset, *right_cols[field], 1))
+                return false;
+        }
+        return true;
+    }
+
+    /// Emits the rows in [start_offset, end_offset) of the current partition that are within the
+    /// limit and updates the rank state.
+    void iteratePartition(const DB::Chunk & chunk, size_t start_offset, size_t end_offset)
+    {
+        // Skip the rest rows in this partition.
+        if (current_row_rank_value > limit)
+            return;
+
+        size_t chunk_rows = chunk.getNumRows();
+        // A peer group is known to have ended if it stopped before the partition end, or if the
+        // partition itself ended inside this chunk.
+        auto has_peer_group_ended = [&](size_t offset, size_t partition_end_offset, size_t chunk_rows_)
+        { return offset < partition_end_offset || end_offset < chunk_rows_; };
+        auto try_end_peer_group
+            = [&](size_t peer_group_start_offset, size_t next_peer_group_start_offset, size_t partition_end_offset, size_t chunk_rows_)
+        {
+            if constexpr (function == WindowGroupLimitFunction::Rank)
+            {
+                current_peer_group_rows += next_peer_group_start_offset - peer_group_start_offset;
+                if (has_peer_group_ended(next_peer_group_start_offset, partition_end_offset, chunk_rows_))
+                {
+                    current_row_rank_value += current_peer_group_rows;
+                    current_peer_group_rows = 0;
+                    peer_group_start_row_columns = extractOneRowColumns(chunk, next_peer_group_start_offset);
+                }
+            }
+            else if constexpr (function == WindowGroupLimitFunction::DenseRank)
+            {
+                if (has_peer_group_ended(next_peer_group_start_offset, partition_end_offset, chunk_rows_))
+                {
+                    current_row_rank_value += 1;
+                    peer_group_start_row_columns = extractOneRowColumns(chunk, next_peer_group_start_offset);
+                }
+            }
+        };
+
+        // This is a corner case. prev partition's last row is the last row of a chunk.
+        if (start_offset >= end_offset)
+        {
+            assert(!start_offset);
+            try_end_peer_group(start_offset, end_offset, end_offset, chunk_rows);
+            return;
+        }
+
+        //  row_number is simple
+        if constexpr (function == WindowGroupLimitFunction::RowNumber)
+        {
+            size_t rows = end_offset - start_offset;
+            size_t limit_remained = limit - current_row_rank_value + 1;
+            rows = rows > limit_remained ? limit_remained : rows;
+            insertResultValue(chunk, start_offset, rows);
+            current_row_rank_value += rows;
+        }
+        else
+        {
+            size_t peer_group_start_offset = start_offset;
+            while (peer_group_start_offset < end_offset && current_row_rank_value <= limit)
+            {
+                auto next_peer_group_start_offset = advanceNextPeerGroup(chunk, peer_group_start_offset, end_offset);
+
+                insertResultValue(chunk, peer_group_start_offset, next_peer_group_start_offset - peer_group_start_offset);
+                try_end_peer_group(peer_group_start_offset, next_peer_group_start_offset, end_offset, chunk_rows);
+                peer_group_start_offset = next_peer_group_start_offset;
+            }
+        }
+    }
+
+    /// Appends `rows` rows of `chunk` starting at `start_offset` to output_columns.
+    void insertResultValue(const DB::Chunk & chunk, size_t start_offset, size_t rows)
+    {
+        if (!rows)
+            return;
+        if (output_columns.empty())
+        {
+            for (const auto & col : chunk.getColumns())
+            {
+                output_columns.push_back(col->cloneEmpty());
+            }
+        }
+        size_t i = 0;
+        for (const auto & col : chunk.getColumns())
+        {
+            output_columns[i]->insertRangeFrom(*col, start_offset, rows);
+            i += 1;
+        }
+    }
+
+    /// Returns the first row at or after `start_offset` (and before `partition_end_offset`) that
+    /// does not belong to the current peer group.
+    size_t advanceNextPeerGroup(const DB::Chunk & chunk, size_t start_offset, size_t partition_end_offset)
+    {
+        if (peer_group_start_row_columns.empty())
+            peer_group_start_row_columns = extractOneRowColumns(chunk, start_offset);
+        for (size_t i = start_offset; i < partition_end_offset; ++i)
+        {
+            if (!isRowEqual(sort_columns, peer_group_start_row_columns, 0, chunk.getColumns(), i))
+            {
+                return i;
+            }
+        }
+        return partition_end_offset;
+    }
+};
+
+/// Traits of WindowGroupLimitStep. Rows are dropped, so the row count is not preserved, but the
+/// surviving rows keep their input order.
+static DB::ITransformingStep::Traits getTraits()
+{
+    DB::ITransformingStep::Traits traits{};
+    traits.data_stream_traits.preserves_number_of_streams = false;
+    traits.data_stream_traits.preserves_sorting = true;
+    traits.transform_traits.preserves_number_of_rows = false;
+    return traits;
+}
+
+/// The output header is the input header: the step only filters rows.
+/// The column-index vectors are taken by value and moved into the members. The header declares
+/// them as `const` by value; a top-level const does not change the function signature, so this
+/// definition matches it while avoiding a second copy.
+WindowGroupLimitStep::WindowGroupLimitStep(
+    const DB::DataStream & input_stream_,
+    const String & function_name_,
+    std::vector<size_t> partition_columns_,
+    std::vector<size_t> sort_columns_,
+    size_t limit_)
+    : DB::ITransformingStep(input_stream_, input_stream_.header, getTraits())
+    , function_name(function_name_)
+    , partition_columns(std::move(partition_columns_))
+    , sort_columns(std::move(sort_columns_))
+    , limit(limit_)
+{
+}
+
+/// Describes the processors created by this step, if any have been collected.
+void WindowGroupLimitStep::describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const
+{
+    if (processors.empty())
+        return;
+    DB::IQueryPlanStep::describePipeline(processors, settings);
+}
+
+/// Rebuilds the output stream from the (possibly updated) input stream; the header is unchanged.
+void WindowGroupLimitStep::updateOutputStream()
+{
+    const auto & input_stream = input_streams.front();
+    output_stream = createOutputStream(input_stream, input_stream.header, getDataStreamTraits());
+}
+
+
+/// Appends a WindowGroupLimitTransform<function> to `pipeline` through addSimpleTransform.
+template <WindowGroupLimitFunction function>
+static void addWindowGroupLimitTransform(
+    DB::QueryPipelineBuilder & pipeline, const std::vector<size_t> & partition_columns, const std::vector<size_t> & sort_columns, size_t limit)
+{
+    pipeline.addSimpleTransform(
+        [&](const DB::Block & header)
+        { return std::make_shared<WindowGroupLimitTransform<function>>(header, partition_columns, sort_columns, limit); });
+}
+
+/// Adds the transform matching `function_name`; any other name is a LOGICAL_ERROR.
+void WindowGroupLimitStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & /*settings*/)
+{
+    if (function_name == "row_number")
+        addWindowGroupLimitTransform<WindowGroupLimitFunction::RowNumber>(pipeline, partition_columns, sort_columns, limit);
+    else if (function_name == "rank")
+        addWindowGroupLimitTransform<WindowGroupLimitFunction::Rank>(pipeline, partition_columns, sort_columns, limit);
+    else if (function_name == "dense_rank")
+        addWindowGroupLimitTransform<WindowGroupLimitFunction::DenseRank>(pipeline, partition_columns, sort_columns, limit);
+    else
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupport function {} in WindowGroupLimit", function_name);
+}
+}
diff --git a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h 
b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h
new file mode 100644
index 000000000..bbbbf42ab
--- /dev/null
+++ b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <vector>
+#include <Core/Block.h>
+#include <Processors/QueryPlan/IQueryPlanStep.h>
+#include <Processors/QueryPlan/ITransformingStep.h>
+
+namespace local_engine
+{
+/// Query plan step that limits rows per window partition for a rank-like window function
+/// (row_number, rank or dense_rank). The output header is the input header, i.e. the window
+/// function result itself is not added as a column.
+class WindowGroupLimitStep : public DB::ITransformingStep
+{
+public:
+    /// @param input_stream_      stream the step reads from; its header is reused as the output header.
+    /// @param function_name_     one of row_number, rank, dense_rank.
+    /// @param partition_columns_ positions of the partition-key columns in the input header.
+    /// @param sort_columns_      positions of the order-key columns in the input header.
+    /// @param limit_             per-partition limit handed to WindowGroupLimitTransform.
+    /// NOTE(review): the vectors are taken by value; switching to const & must be done together
+    /// with the out-of-line constructor definition.
+    explicit WindowGroupLimitStep(
+        const DB::DataStream & input_stream_,
+        const String & function_name_,
+        const std::vector<size_t> partition_columns_,
+        const std::vector<size_t> sort_columns_,
+        size_t limit_);
+    ~WindowGroupLimitStep() override = default;
+
+    String getName() const override { return "WindowGroupLimitStep"; }
+
+    void updateOutputStream() override;
+    void transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & settings) override;
+    void describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const override;
+
+private:
+    /// Window function name; one of row_number, rank and dense_rank.
+    String function_name;
+    /// Positions of the partition-key columns in the input header.
+    std::vector<size_t> partition_columns;
+    /// Positions of the order-key columns in the input header.
+    std::vector<size_t> sort_columns;
+    /// Per-partition limit forwarded to the transform.
+    size_t limit;
+};
+
+}
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp 
b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
index 42d4f4d4d..cc7738a15 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
+++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
@@ -14,25 +14,26 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <base/find_symbols.h>
+#include "AdvancedParametersParseUtil.h"
 #include <IO/ReadBufferFromString.h>
 #include <IO/ReadHelpers.h>
-#include <Common/Exception.h>
-#include <Parser/AdvancedParametersParseUtil.h>
+#include <base/find_symbols.h>
 #include <Poco/Logger.h>
+#include <Common/Exception.h>
 #include <Common/logger_useful.h>
+
 namespace DB::ErrorCodes
 {
-    extern const int BAD_ARGUMENTS;
+extern const int BAD_ARGUMENTS;
 }
 
 namespace local_engine
 {
 
-template<typename T>
+template <typename T>
 void tryAssign(const std::unordered_map<String, String> & kvs, const String & 
key, T & v);
 
-template<>
+template <>
 void tryAssign<String>(const std::unordered_map<String, String> & kvs, const 
String & key, String & v)
 {
     auto it = kvs.find(key);
@@ -40,7 +41,7 @@ void tryAssign<String>(const std::unordered_map<String, 
String> & kvs, const Str
         v = it->second;
 }
 
-template<>
+template <>
 void tryAssign<bool>(const std::unordered_map<String, String> & kvs, const 
String & key, bool & v)
 {
     auto it = kvs.find(key);
@@ -57,7 +58,7 @@ void tryAssign<bool>(const std::unordered_map<String, String> 
& kvs, const Strin
     }
 }
 
-template<>
+template <>
 void tryAssign<Int64>(const std::unordered_map<String, String> & kvs, const 
String & key, Int64 & v)
 {
     auto it = kvs.find(key);
@@ -94,9 +95,9 @@ void readStringUntilCharsInto(String & s, DB::ReadBuffer & 
buf)
 std::unordered_map<String, std::unordered_map<String, String>> 
convertToKVs(const String & advance)
 {
     std::unordered_map<String, std::unordered_map<String, String>> res;
-    std::unordered_map<String, String> *kvs;
+    std::unordered_map<String, String> * kvs;
     DB::ReadBufferFromString in(advance);
-    while(!in.eof())
+    while (!in.eof())
     {
         String key;
         readStringUntilCharsInto<'=', '\n', ':'>(key, in);
@@ -146,5 +147,13 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const 
String & advance)
     tryAssign(kvs, "numPartitions", info.partitions_num);
     return info;
 }
-}
 
+/// Reads `window_function` from the "WindowGroupLimitParameters" section of an advanced-parameters string.
+/// The field is left default (empty) when the section or the key is absent.
+WindowGroupOptimizationInfo WindowGroupOptimizationInfo::parse(const String & advance)
+{
+    WindowGroupOptimizationInfo result;
+    const auto sections = convertToKVs(advance);
+    if (const auto section = sections.find("WindowGroupLimitParameters"); section != sections.end())
+        tryAssign(section->second, "window_function", result.window_function);
+    return result;
+}
+}
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h 
b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
index 5f6fe6d25..fc478db33 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
+++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
@@ -16,10 +16,10 @@
  */
 #pragma once
 #include <unordered_map>
+#include <base/types.h>
 
 namespace local_engine
 {
-
 std::unordered_map<String, std::unordered_map<String, String>> 
convertToKVs(const String & advance);
 
 
@@ -38,5 +38,10 @@ struct JoinOptimizationInfo
 
     static JoinOptimizationInfo parse(const String & advance);
 };
-}
 
+/// Parameters of a WindowGroupLimit relation, decoded from its advanced-extension optimization string.
+struct WindowGroupOptimizationInfo
+{
+    /// Rank-like window function name (row_number, rank or dense_rank); empty if not supplied.
+    String window_function;
+
+    /// Parses the "WindowGroupLimitParameters" section of an advanced-parameters string.
+    static WindowGroupOptimizationInfo parse(const String & advance);
+};
+}
diff --git a/cpp-ch/local-engine/Parser/RelParser.cpp 
b/cpp-ch/local-engine/Parser/RelParser.cpp
index f651146a3..a7f6d0586 100644
--- a/cpp-ch/local-engine/Parser/RelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParser.cpp
@@ -30,8 +30,8 @@ namespace DB
 {
 namespace ErrorCodes
 {
-    extern const int BAD_ARGUMENTS;
-    extern const int LOGICAL_ERROR;
+extern const int BAD_ARGUMENTS;
+extern const int LOGICAL_ERROR;
 }
 }
 
@@ -89,14 +89,15 @@ DB::QueryPlanPtr RelParser::parseOp(const substrait::Rel & 
rel, std::list<const
     return parse(std::move(query_plan), rel, rel_stack);
 }
 
-std::map<std::string, std::string> 
RelParser::parseFormattedRelAdvancedOptimization(const 
substrait::extensions::AdvancedExtension &advanced_extension)
+std::map<std::string, std::string>
+RelParser::parseFormattedRelAdvancedOptimization(const 
substrait::extensions::AdvancedExtension & advanced_extension)
 {
     std::map<std::string, std::string> configs;
     if (advanced_extension.has_optimization())
     {
         google::protobuf::StringValue msg;
         advanced_extension.optimization().UnpackTo(&msg);
-        Poco::StringTokenizer kvs( msg.value(), "\n");
+        Poco::StringTokenizer kvs(msg.value(), "\n");
         for (auto & kv : kvs)
         {
             if (kv.empty())
@@ -114,7 +115,8 @@ std::map<std::string, std::string> 
RelParser::parseFormattedRelAdvancedOptimizat
     return configs;
 }
 
-std::string RelParser::getStringConfig(const std::map<std::string, 
std::string> & configs, const std::string & key, const std::string & 
default_value)
+std::string
+RelParser::getStringConfig(const std::map<std::string, std::string> & configs, 
const std::string & key, const std::string & default_value)
 {
     auto it = configs.find(key);
     if (it == configs.end())
@@ -150,6 +152,7 @@ RelParserFactory::RelParserBuilder 
RelParserFactory::getBuilder(UInt32 k)
 }
 
 void registerWindowRelParser(RelParserFactory & factory);
+void registerWindowGroupLimitRelParser(RelParserFactory & factory);
 void registerSortRelParser(RelParserFactory & factory);
 void registerExpandRelParser(RelParserFactory & factory);
 void registerAggregateParser(RelParserFactory & factory);
@@ -162,6 +165,7 @@ void registerRelParsers()
 {
     auto & factory = RelParserFactory::instance();
     registerWindowRelParser(factory);
+    registerWindowGroupLimitRelParser(factory);
     registerSortRelParser(factory);
     registerExpandRelParser(factory);
     registerAggregateParser(factory);
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp 
b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index 589f3826a..e893dd35b 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -391,9 +391,9 @@ void adjustOutput(const DB::QueryPlanPtr & query_plan, 
const substrait::PlanRel
         }
         if (need_final_project)
         {
-            ActionsDAG final_project
-                = ActionsDAG::makeConvertingActions(original_cols, final_cols, 
ActionsDAG::MatchColumnsMode::Position);
-            QueryPlanStepPtr final_project_step = 
std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(), 
std::move(final_project));
+            ActionsDAG final_project = 
ActionsDAG::makeConvertingActions(original_cols, final_cols, 
ActionsDAG::MatchColumnsMode::Position);
+            QueryPlanStepPtr final_project_step
+                = 
std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(), 
std::move(final_project));
             final_project_step->setStepDescription("Project for output 
schema");
             query_plan->addStep(std::move(final_project_step));
         }
@@ -504,6 +504,7 @@ QueryPlanPtr SerializedPlanParser::parseOp(const 
substrait::Rel & rel, std::list
         case substrait::Rel::RelTypeCase::kWindow:
         case substrait::Rel::RelTypeCase::kJoin:
         case substrait::Rel::RelTypeCase::kCross:
+        case substrait::Rel::RelTypeCase::kWindowGroupLimit:
         case substrait::Rel::RelTypeCase::kExpand: {
             auto op_parser = 
RelParserFactory::instance().getBuilder(rel.rel_type_case())(this);
             query_plan = op_parser->parseOp(rel, rel_stack);
@@ -606,7 +607,7 @@ void SerializedPlanParser::parseArrayJoinArguments(
 }
 
 ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG(
-    const substrait::Expression & rel, std::vector<String> & result_names, 
ActionsDAG& actions_dag, bool keep_result, bool position)
+    const substrait::Expression & rel, std::vector<String> & result_names, 
ActionsDAG & actions_dag, bool keep_result, bool position)
 {
     if (!rel.has_scalar_function())
         throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of expression 
should be a scalar function:\n {}", rel.DebugString());
@@ -723,7 +724,7 @@ ActionsDAG::NodeRawConstPtrs 
SerializedPlanParser::parseArrayJoinWithDAG(
 }
 
 const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG(
-    const substrait::Expression & rel, std::string & result_name, ActionsDAG& 
actions_dag, bool keep_result)
+    const substrait::Expression & rel, std::string & result_name, ActionsDAG & 
actions_dag, bool keep_result)
 {
     if (!rel.has_scalar_function())
         throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression 
should be a scalar function:\n {}", rel.DebugString());
@@ -785,9 +786,8 @@ bool 
SerializedPlanParser::isFunction(substrait::Expression_ScalarFunction rel,
 }
 
 void SerializedPlanParser::parseFunctionOrExpression(
-    const substrait::Expression & rel, std::string & result_name, ActionsDAG& 
actions_dag, bool keep_result)
+    const substrait::Expression & rel, std::string & result_name, ActionsDAG & 
actions_dag, bool keep_result)
 {
-
     if (rel.has_scalar_function())
         parseFunctionWithDAG(rel, result_name, actions_dag, keep_result);
     else
@@ -798,11 +798,7 @@ void SerializedPlanParser::parseFunctionOrExpression(
 }
 
 void SerializedPlanParser::parseJsonTuple(
-    const substrait::Expression & rel,
-    std::vector<String> & result_names,
-    ActionsDAG& actions_dag,
-    bool keep_result,
-    bool)
+    const substrait::Expression & rel, std::vector<String> & result_names, 
ActionsDAG & actions_dag, bool keep_result, bool)
 {
     const auto & scalar_function = rel.scalar_function();
     auto function_signature = 
function_mapping.at(std::to_string(rel.scalar_function().function_reference()));
@@ -861,7 +857,7 @@ void SerializedPlanParser::parseJsonTuple(
 }
 
 const ActionsDAG::Node *
-SerializedPlanParser::toFunctionNode(ActionsDAG& actions_dag, const String & 
function, const ActionsDAG::NodeRawConstPtrs & args)
+SerializedPlanParser::toFunctionNode(ActionsDAG & actions_dag, const String & 
function, const ActionsDAG::NodeRawConstPtrs & args)
 {
     auto function_builder = FunctionFactory::instance().get(function, context);
     std::string args_name = join(args, ',');
@@ -1073,7 +1069,7 @@ std::pair<DataTypePtr, Field> 
SerializedPlanParser::parseLiteral(const substrait
     return std::make_pair(std::move(type), std::move(field));
 }
 
-const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAG& 
actions_dag, const substrait::Expression & rel)
+const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAG & 
actions_dag, const substrait::Expression & rel)
 {
     switch (rel.rex_type_case())
     {
@@ -1534,8 +1530,7 @@ ASTPtr ASTParser::parseArgumentToAST(const Names & names, 
const substrait::Expre
     }
 }
 
-void SerializedPlanParser::removeNullableForRequiredColumns(
-    const std::set<String> & require_columns, ActionsDAG & actions_dag) const
+void SerializedPlanParser::removeNullableForRequiredColumns(const 
std::set<String> & require_columns, ActionsDAG & actions_dag) const
 {
     for (const auto & item : require_columns)
     {
@@ -1550,7 +1545,7 @@ void 
SerializedPlanParser::removeNullableForRequiredColumns(
 }
 
 void SerializedPlanParser::wrapNullable(
-    const std::vector<String> & columns, ActionsDAG& actions_dag, 
std::map<std::string, std::string> & nullable_measure_names)
+    const std::vector<String> & columns, ActionsDAG & actions_dag, 
std::map<std::string, std::string> & nullable_measure_names)
 {
     for (const auto & item : columns)
     {
diff --git a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp 
b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp
new file mode 100644
index 000000000..f6c10386f
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "WindowGroupLimitRelParser.h"
+#include <Interpreters/ActionsDAG.h>
+#include <Operator/WindowGroupLimitStep.h>
+#include <Parser/AdvancedParametersParseUtil.h>
+#include <Parser/SortRelParser.h>
+#include <Parser/WindowGroupLimitRelParser.h>
+#include <Processors/QueryPlan/ExpressionStep.h>
+#include <google/protobuf/repeated_field.h>
+#include <google/protobuf/wrappers.pb.h>
+#include "AdvancedParametersParseUtil.h"
+
+namespace DB::ErrorCodes
+{
+extern const int BAD_ARGUMENTS;
+}
+
+namespace local_engine
+{
+/// Only the RelParser base is initialised here; current_plan and window_function_name are filled in by parse().
+WindowGroupLimitRelParser::WindowGroupLimitRelParser(SerializedPlanParser * plan_parser_)
+    : RelParser(plan_parser_)
+{
+}
+
+/// Translates a substrait WindowGroupLimitRel into a WindowGroupLimitStep appended to current_plan_.
+/// The window function name comes from the rel's advanced_extension.optimization (a serialized StringValue);
+/// partition and sort expressions are reduced to input column positions.
+/// Throws BAD_ARGUMENTS when the optimization info is missing or cannot be decoded.
+DB::QueryPlanPtr
+WindowGroupLimitRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_)
+{
+    // Bind by reference: a copy would also deep-copy the rel's whole input subtree.
+    const auto & win_rel_def = rel.windowgrouplimit();
+
+    // Without the optimization payload the function name would be empty and the step would only fail later,
+    // at pipeline construction; reject the malformed rel here instead.
+    if (!win_rel_def.has_advanced_extension() || !win_rel_def.advanced_extension().has_optimization())
+        throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "WindowGroupLimitRel is missing advanced_extension.optimization");
+
+    google::protobuf::StringValue optimize_info_str;
+    if (!optimize_info_str.ParseFromString(win_rel_def.advanced_extension().optimization().value()))
+        throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Failed to parse the optimization info of WindowGroupLimitRel");
+    auto optimization_info = WindowGroupOptimizationInfo::parse(optimize_info_str.value());
+    window_function_name = optimization_info.window_function;
+
+    current_plan = std::move(current_plan_);
+
+    auto partition_fields = parsePartitoinFields(win_rel_def.partition_expressions());
+    auto sort_fields = parseSortFields(win_rel_def.sorts());
+    size_t limit = static_cast<size_t>(win_rel_def.limit());
+
+    auto window_group_limit_step = std::make_unique<WindowGroupLimitStep>(
+        current_plan->getCurrentDataStream(), window_function_name, partition_fields, sort_fields, limit);
+    window_group_limit_step->setStepDescription("Window group limit");
+    steps.emplace_back(window_group_limit_step.get());
+    current_plan->addStep(std::move(window_group_limit_step));
+
+    // current_plan is a data member, so it has to be moved out explicitly.
+    return std::move(current_plan);
+}
+
+/// Maps each partition-by expression to the position of the input column it references.
+/// Literal keys have the same value for every row, so they never split a partition and are skipped.
+/// Throws BAD_ARGUMENTS for any other expression kind, or for a selection that is not a direct struct-field reference.
+std::vector<size_t>
+WindowGroupLimitRelParser::parsePartitoinFields(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions)
+{
+    std::vector<size_t> fields;
+    for (const auto & expr : expressions)
+    {
+        if (expr.has_literal())
+            continue;
+        if (!expr.has_selection())
+            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown expression: {}", expr.DebugString());
+
+        const auto & selection = expr.selection();
+        // An unset reference would read as field index 0 (protobuf default) and silently pick the wrong column.
+        if (!selection.has_direct_reference() || !selection.direct_reference().has_struct_field())
+            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unsupported partition expression: {}", expr.DebugString());
+        fields.push_back(static_cast<size_t>(selection.direct_reference().struct_field().field()));
+    }
+    return fields;
+}
+
+/// Maps each order-by key to the position of the input column it references; literal keys are skipped.
+/// Only positions are returned, not sort directions. NOTE(review): presumably the input arrives already
+/// sorted and the transform only needs to detect changes in the key — confirm against WindowGroupLimitTransform.
+/// Throws BAD_ARGUMENTS for any other expression kind, or for a selection that is not a direct struct-field reference.
+std::vector<size_t> WindowGroupLimitRelParser::parseSortFields(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields)
+{
+    std::vector<size_t> fields;
+    // Iterate by reference; iterating by value would copy every SortField message.
+    for (const auto & sort_field : sort_fields)
+    {
+        const auto & expr = sort_field.expr();
+        if (expr.has_literal())
+            continue;
+        if (!expr.has_selection())
+            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown expression: {}", expr.DebugString());
+
+        const auto & selection = expr.selection();
+        // An unset reference would read as field index 0 (protobuf default) and silently pick the wrong column.
+        if (!selection.has_direct_reference() || !selection.direct_reference().has_struct_field())
+            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unsupported sort expression: {}", expr.DebugString());
+        fields.push_back(static_cast<size_t>(selection.direct_reference().struct_field().field()));
+    }
+    return fields;
+}
+
+/// Registers a builder that creates a new WindowGroupLimitRelParser for every kWindowGroupLimit rel.
+void registerWindowGroupLimitRelParser(RelParserFactory & factory)
+{
+    factory.registerBuilder(
+        substrait::Rel::RelTypeCase::kWindowGroupLimit,
+        [](SerializedPlanParser * plan_parser) { return std::make_shared<WindowGroupLimitRelParser>(plan_parser); });
+}
+}
diff --git a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h 
b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h
new file mode 100644
index 000000000..c9c503ed4
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <unordered_map>
+#include <Core/Field.h>
+#include <Core/SortDescription.h>
+#include <DataTypes/IDataType.h>
+#include <Interpreters/WindowDescription.h>
+#include <Parser/AggregateFunctionParser.h>
+#include <Parser/RelParser.h>
+#include <Parser/SerializedPlanParser.h>
+#include <Processors/QueryPlan/QueryPlan.h>
+#include <Poco/Logger.h>
+#include <Common/logger_useful.h>
+
+namespace local_engine
+{
+/// Parser for substrait WindowGroupLimitRel. It resembles WindowRelParser, with these differences:
+///  1. Aggregate functions are not supported; only the window functions row_number, rank and
+///     dense_rank are accepted.
+///  2. The function is not evaluated as a regular window function; its name selects the matching
+///     WindowGroupLimitTransform variant.
+///  3. The output columns do not include the window function result.
+class WindowGroupLimitRelParser : public RelParser
+{
+public:
+    explicit WindowGroupLimitRelParser(SerializedPlanParser * plan_parser_);
+    ~WindowGroupLimitRelParser() override = default;
+
+    /// Appends a WindowGroupLimitStep on top of current_plan_ and returns the extended plan.
+    DB::QueryPlanPtr
+    parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_) override;
+
+    /// The only child of a WindowGroupLimitRel is its `input` rel.
+    const substrait::Rel & getSingleInput(const substrait::Rel & rel) override { return rel.windowgrouplimit().input(); }
+
+private:
+    /// Input column positions referenced by the partition-by expressions.
+    std::vector<size_t> parsePartitoinFields(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions);
+    /// Input column positions referenced by the order-by keys.
+    std::vector<size_t> parseSortFields(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);
+
+    /// Plan being extended while parse() runs.
+    DB::QueryPlanPtr current_plan;
+    /// Window function name decoded from the rel's advanced extension.
+    String window_function_name;
+};
+}
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index a55926d76..dd4150806 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -41,6 +41,7 @@ import 
org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
+import org.apache.spark.sql.execution.window._
 import org.apache.spark.sql.hive.{HiveTableScanExecTransformer, 
HiveUDFTransformer}
 import org.apache.spark.sql.types.{DecimalType, LongType, NullType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -678,6 +679,15 @@ trait SparkPlanExecApi {
     }
   }
 
+  /**
+   * Creates the columnar transformer that replaces Spark's `WindowGroupLimitExec` during offload.
+   *
+   * The default returns the generic [[WindowGroupLimitExecTransformer]]; a backend can override
+   * this method to supply its own implementation.
+   */
+  def genWindowGroupLimitTransformer(
+      partitionSpec: Seq[Expression],
+      orderSpec: Seq[SortOrder],
+      rankLikeFunction: Expression,
+      limit: Int,
+      mode: WindowGroupLimitMode,
+      child: SparkPlan): SparkPlan = {
+    WindowGroupLimitExecTransformer(partitionSpec, orderSpec, rankLikeFunction, limit, mode, child)
+  }
+
   def genHiveUDFTransformer(
       expr: Expression,
       attributeSeq: Seq[Attribute]): ExpressionTransformer = {
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
index cdc71f447..9ca60177c 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
@@ -372,7 +372,7 @@ object OffloadOthers {
           val windowGroupLimitPlan = SparkShimLoader.getSparkShims
             .getWindowGroupLimitExecShim(plan)
             .asInstanceOf[WindowGroupLimitExecShim]
-          WindowGroupLimitExecTransformer(
+          
BackendsApiManager.getSparkPlanExecApiInstance.genWindowGroupLimitTransformer(
             windowGroupLimitPlan.partitionSpec,
             windowGroupLimitPlan.orderSpec,
             windowGroupLimitPlan.rankLikeFunction,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to