This is an automated email from the ASF dual-hosted git repository.
lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new afccfed98 [GLUTEN-7087][CH] Support `WindowGroupLimitExec` (#7176)
afccfed98 is described below
commit afccfed98ef85590975997adf35468ad227a69a1
Author: lgbo <[email protected]>
AuthorDate: Thu Sep 12 09:55:18 2024 +0800
[GLUTEN-7087][CH] Support `WindowGroupLimitExec` (#7176)
* support WindowGroupLimit
* 0903
* implement window group limit
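
For context: starting with Spark 3.5, a rank-like filter over a window (rank/row_number/dense_rank <= k) is planned as a WindowGroupLimitExec node, and with this change the ClickHouse backend offloads that node instead of falling back (see the TPCDS fallbackSets and TPC-H suite updates below). A minimal sketch of the query shape involved; the table and column names are illustrative only, not from this commit:

    // Illustrative only: top-N-per-group query that Spark 3.5 plans with WindowGroupLimitExec.
    val df = spark.sql(
      """
        |select * from (
        |  select n_regionkey, n_nationkey,
        |         rank() over (partition by n_regionkey order by n_nationkey) as rank
        |  from nation
        |) t where rank = 1
        |""".stripMargin)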
---
.../gluten/backendsapi/clickhouse/CHBackend.scala | 2 +
.../clickhouse/CHSparkPlanExecApi.scala | 16 +
.../CHWindowGroupLimitExecTransformer.scala | 187 +++++++++++
.../GlutenClickHouseTPCDSAbstractSuite.scala | 2 +-
.../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 6 +-
cpp-ch/local-engine/Common/CHUtil.cpp | 41 +--
cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp | 20 +-
.../local-engine/Operator/WindowGroupLimitStep.cpp | 365 +++++++++++++++++++++
.../local-engine/Operator/WindowGroupLimitStep.h | 51 +++
.../Parser/AdvancedParametersParseUtil.cpp | 31 +-
.../Parser/AdvancedParametersParseUtil.h | 9 +-
cpp-ch/local-engine/Parser/RelParser.cpp | 14 +-
.../local-engine/Parser/SerializedPlanParser.cpp | 29 +-
.../Parser/WindowGroupLimitRelParser.cpp | 112 +++++++
.../Parser/WindowGroupLimitRelParser.h | 52 +++
.../gluten/backendsapi/SparkPlanExecApi.scala | 10 +
.../extension/columnar/OffloadSingleNode.scala | 2 +-
17 files changed, 876 insertions(+), 73 deletions(-)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index 69ea899c4..45aee4322 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -418,4 +418,6 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
}
}
}
+
+  override def supportWindowGroupLimitExec(rankLikeFunction: Expression): Boolean = true
}
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 1108b8b3c..f765a75d2 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildSideRelation, HashedRelationBroadcastMode}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil}
+import org.apache.spark.sql.execution.window._
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -909,4 +910,19 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
toScale: Int): DecimalType = {
    SparkShimLoader.getSparkShims.genDecimalRoundExpressionOutput(decimalType, toScale)
}
+
+ override def genWindowGroupLimitTransformer(
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ rankLikeFunction: Expression,
+ limit: Int,
+ mode: WindowGroupLimitMode,
+ child: SparkPlan): SparkPlan =
+ CHWindowGroupLimitExecTransformer(
+ partitionSpec,
+ orderSpec,
+ rankLikeFunction,
+ limit,
+ mode,
+ child)
}
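
The new genWindowGroupLimitTransformer hook is what the offload rule uses to build the transformer (see the OffloadSingleNode hunk at the end of this mail). A minimal sketch of that call site, where windowGroupLimitPlan stands for the shimmed WindowGroupLimitExec node:

    // Sketch only: with the CH override above this returns CHWindowGroupLimitExecTransformer;
    // the default in SparkPlanExecApi returns the generic WindowGroupLimitExecTransformer.
    val transformer = BackendsApiManager.getSparkPlanExecApiInstance
      .genWindowGroupLimitTransformer(
        windowGroupLimitPlan.partitionSpec,
        windowGroupLimitPlan.orderSpec,
        windowGroupLimitPlan.rankLikeFunction,
        windowGroupLimitPlan.limit,
        windowGroupLimitPlan.mode,
        windowGroupLimitPlan.child)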
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala
new file mode 100644
index 000000000..c2648f29e
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.expression._
+import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
+import org.apache.gluten.extension.ValidationResult
+import org.apache.gluten.metrics.MetricsUpdater
+import org.apache.gluten.substrait.`type`.TypeBuilder
+import org.apache.gluten.substrait.SubstraitContext
+import org.apache.gluten.substrait.extensions.ExtensionBuilder
+import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.window.{Final, Partial, WindowGroupLimitMode}
+
+import com.google.protobuf.StringValue
+import io.substrait.proto.SortField
+
+import scala.collection.JavaConverters._
+
+case class CHWindowGroupLimitExecTransformer(
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ rankLikeFunction: Expression,
+ limit: Int,
+ mode: WindowGroupLimitMode,
+ child: SparkPlan)
+ extends UnaryTransformSupport {
+
+ @transient override lazy val metrics =
+    BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetrics(sparkContext)
+
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+
+ override def metricsUpdater(): MetricsUpdater =
+    BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetricsUpdater(metrics)
+
+ override def output: Seq[Attribute] = child.output
+
+ override def requiredChildDistribution: Seq[Distribution] = mode match {
+ case Partial => super.requiredChildDistribution
+ case Final =>
+ if (partitionSpec.isEmpty) {
+ AllTuples :: Nil
+ } else {
+ ClusteredDistribution(partitionSpec) :: Nil
+ }
+ }
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+    if (BackendsApiManager.getSettings.requiredChildOrderingForWindowGroupLimit()) {
+ Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
+ } else {
+ Seq(Nil)
+ }
+ }
+
+ override def outputOrdering: Seq[SortOrder] = {
+ if (requiredChildOrdering.forall(_.isEmpty)) {
+      // The Velox backend `TopNRowNumber` does not require child ordering, because it
+      // uses a hash table to store partitions and a priority queue to track the top `limit` rows.
+      // Ideally, the output of `TopNRowNumber` is unordered but it is grouped by partition keys.
+      // To be safe, here we do not propagate the ordering.
+      // TODO: Make the framework aware of grouped data distribution
+ Nil
+ } else {
+ child.outputOrdering
+ }
+ }
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
+ def getWindowGroupLimitRel(
+ context: SubstraitContext,
+ originalInputAttributes: Seq[Attribute],
+ operatorId: Long,
+ input: RelNode,
+ validation: Boolean): RelNode = {
+ val args = context.registeredFunction
+ // Partition By Expressions
+ val partitionsExpressions = partitionSpec
+ .map(
+ ExpressionConverter
+ .replaceWithExpressionTransformer(_, attributeSeq = child.output)
+ .doTransform(args))
+ .asJava
+
+ // Sort By Expressions
+ val sortFieldList =
+ orderSpec.map {
+ order =>
+ val builder = SortField.newBuilder()
+ val exprNode = ExpressionConverter
+            .replaceWithExpressionTransformer(order.child, attributeSeq = child.output)
+ .doTransform(args)
+ builder.setExpr(exprNode.toProtobuf)
+          builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
+ builder.build()
+ }.asJava
+ if (!validation) {
+ val windowFunction = rankLikeFunction match {
+ case _: RowNumber => ExpressionNames.ROW_NUMBER
+ case _: Rank => ExpressionNames.RANK
+ case _: DenseRank => ExpressionNames.DENSE_RANK
+        case _ => throw new GlutenNotSupportException(s"Unknown window function $rankLikeFunction")
+ }
+ val parametersStr = new StringBuffer("WindowGroupLimitParameters:")
+ parametersStr
+ .append("window_function=")
+ .append(windowFunction)
+ .append("\n")
+      val message = StringValue.newBuilder().setValue(parametersStr.toString).build()
+ val extensionNode = ExtensionBuilder.makeAdvancedExtension(
+ BackendsApiManager.getTransformerApiInstance.packPBMessage(message),
+ null)
+ RelBuilder.makeWindowGroupLimitRel(
+ input,
+ partitionsExpressions,
+ sortFieldList,
+ limit,
+ extensionNode,
+ context,
+ operatorId)
+ } else {
+      // Use an extension node to send the input types through the Substrait plan for validation.
+ val inputTypeNodeList = originalInputAttributes
+ .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
+ .asJava
+ val extensionNode = ExtensionBuilder.makeAdvancedExtension(
+ BackendsApiManager.getTransformerApiInstance.packPBMessage(
+ TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+
+ RelBuilder.makeWindowGroupLimitRel(
+ input,
+ partitionsExpressions,
+ sortFieldList,
+ limit,
+ extensionNode,
+ context,
+ operatorId)
+ }
+ }
+
+ override protected def doValidateInternal(): ValidationResult = {
+    if (!BackendsApiManager.getSettings.supportWindowGroupLimitExec(rankLikeFunction)) {
+ return ValidationResult
+ .failed(s"Found unsupported rank like function: $rankLikeFunction")
+ }
+ val substraitContext = new SubstraitContext
+ val operatorId = substraitContext.nextOperatorId(this.nodeName)
+
+ val relNode =
+      getWindowGroupLimitRel(substraitContext, child.output, operatorId, null, validation = true)
+
+ doNativeValidation(substraitContext, relNode)
+ }
+
+  override protected def doTransform(context: SubstraitContext): TransformContext = {
+ val childCtx = child.asInstanceOf[TransformSupport].transform(context)
+ val operatorId = context.nextOperatorId(this.nodeName)
+
+ val currRel =
+      getWindowGroupLimitRel(context, child.output, operatorId, childCtx.root, validation = false)
+ assert(currRel != null, "Window Group Limit Rel should be valid")
+ TransformContext(childCtx.outputAttributes, output, currRel)
+ }
+}
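
One detail worth calling out: the only extra information this transformer ships to the native side is the rank-like function name, packed into the Substrait advanced extension as a plain string. A small sketch of the round trip, assuming a row_number query:

    // The non-validation path of getWindowGroupLimitRel above builds:
    val parametersStr = new StringBuffer("WindowGroupLimitParameters:")
    parametersStr.append("window_function=").append("row_number").append("\n")
    // i.e. "WindowGroupLimitParameters:window_function=row_number\n"; on the native side,
    // WindowGroupOptimizationInfo::parse (AdvancedParametersParseUtil.cpp below) reads it back as
    // kkvs["WindowGroupLimitParameters"]["window_function"] == "row_number", which
    // WindowGroupLimitStep uses to pick the concrete transform.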
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
index 03b26fa98..abb7d27ff 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
@@ -62,7 +62,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
})
protected def fallbackSets(isAqe: Boolean): Set[Int] = {
- if (isSparkVersionGE("3.5")) Set(44, 67, 70) else Set.empty[Int]
+ Set.empty[Int]
}
protected def excludedTpcdsQueries: Set[String] = Set(
"q66" // inconsistent results
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index f7cf0de37..9ac35441a 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -1855,7 +1855,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
| ) t1
|) t2 where rank = 1
""".stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => }, isSparkVersionLE("3.3"))
+ compareResultsAgainstVanillaSpark(sql, true, { _ => })
}
test("GLUTEN-1874 not null in both streams") {
@@ -1873,7 +1873,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
| ) t1
|) t2 where rank = 1
""".stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => }, isSparkVersionLE("3.3"))
+ compareResultsAgainstVanillaSpark(sql, true, { _ => })
}
test("GLUTEN-2095: test cast(string as binary)") {
@@ -2456,7 +2456,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
| ) t1
|) t2 where rank = 1 order by p_partkey limit 100
|""".stripMargin
- runQueryAndCompare(sql, noFallBack = isSparkVersionLE("3.3"))({ _ => })
+ runQueryAndCompare(sql, noFallBack = true)({ _ => })
}
test("GLUTEN-4190: crush on flattening a const null column") {
diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp
index b702082d6..94a214e5e 100644
--- a/cpp-ch/local-engine/Common/CHUtil.cpp
+++ b/cpp-ch/local-engine/Common/CHUtil.cpp
@@ -53,6 +53,7 @@
#include <Parser/SubstraitParserUtils.h>
#include <Planner/PlannerActionsVisitor.h>
#include <Processors/Chunk.h>
+#include <Processors/Formats/IOutputFormat.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <QueryPipeline/QueryPipelineBuilder.h>
@@ -315,7 +316,6 @@ DB::Block
BlockUtil::concatenateBlocksMemoryEfficiently(std::vector<DB::Block> &
return out;
}
-
size_t PODArrayUtil::adjustMemoryEfficientSize(size_t n)
{
/// According to definition of DEFUALT_BLOCK_SIZE
@@ -560,9 +560,7 @@ std::map<std::string, std::string>
BackendInitializerUtil::getBackendConfMap(std
}
std::vector<String> BackendInitializerUtil::wrapDiskPathConfig(
- const String & path_prefix,
- const String & path_suffix,
- Poco::Util::AbstractConfiguration & config)
+ const String & path_prefix, const String & path_suffix,
Poco::Util::AbstractConfiguration & config)
{
std::vector<String> changed_paths;
if (path_prefix.empty() && path_suffix.empty())
@@ -657,9 +655,7 @@ DB::Context::ConfigurationPtr
BackendInitializerUtil::initConfig(std::map<std::s
auto path_need_clean = wrapDiskPathConfig("", "/" + pid, *config);
std::lock_guard lock(BackendFinalizerUtil::paths_mutex);
BackendFinalizerUtil::paths_need_to_clean.insert(
- BackendFinalizerUtil::paths_need_to_clean.end(),
- path_need_clean.begin(),
- path_need_clean.end());
+ BackendFinalizerUtil::paths_need_to_clean.end(),
path_need_clean.begin(), path_need_clean.end());
}
return config;
}
@@ -683,7 +679,9 @@ void
BackendInitializerUtil::initEnvs(DB::Context::ConfigurationPtr config)
{
const std::string config_timezone = config->getString("timezone");
const String mapped_timezone =
DateTimeUtil::convertTimeZone(config_timezone);
- if (0 != setenv("TZ", mapped_timezone.data(), 1)) //
NOLINT(concurrency-mt-unsafe) // ok if not called concurrently with other
setenv/getenv
+ if (0
+ != setenv(
+ "TZ", mapped_timezone.data(), 1)) //
NOLINT(concurrency-mt-unsafe) // ok if not called concurrently with other
setenv/getenv
throw Poco::Exception("Cannot setenv TZ variable");
tzset();
@@ -807,8 +805,7 @@ void
BackendInitializerUtil::initSettings(std::map<std::string, std::string> & b
{
auto mem_gb = task_memory / static_cast<double>(1_GiB);
// 2.8x+5, Heuristics calculate the block size of external sort,
[8,16]
- settings.prefer_external_sort_block_bytes = std::max(std::min(
- static_cast<size_t>(2.8*mem_gb + 5), 16ul), 8ul) * 1024 * 1024;
+ settings.prefer_external_sort_block_bytes =
std::max(std::min(static_cast<size_t>(2.8 * mem_gb + 5), 16ul), 8ul) * 1024 *
1024;
}
}
}
@@ -848,10 +845,14 @@ void
BackendInitializerUtil::initContexts(DB::Context::ConfigurationPtr config)
global_context->setMarkCache(mark_cache_policy, mark_cache_size,
mark_cache_size_ratio);
- String index_uncompressed_cache_policy =
config->getString("index_uncompressed_cache_policy",
DEFAULT_INDEX_UNCOMPRESSED_CACHE_POLICY);
- size_t index_uncompressed_cache_size =
config->getUInt64("index_uncompressed_cache_size",
DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE);
- double index_uncompressed_cache_size_ratio =
config->getDouble("index_uncompressed_cache_size_ratio",
DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO);
-
global_context->setIndexUncompressedCache(index_uncompressed_cache_policy,
index_uncompressed_cache_size, index_uncompressed_cache_size_ratio);
+ String index_uncompressed_cache_policy
+ = config->getString("index_uncompressed_cache_policy",
DEFAULT_INDEX_UNCOMPRESSED_CACHE_POLICY);
+ size_t index_uncompressed_cache_size
+ = config->getUInt64("index_uncompressed_cache_size",
DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE);
+ double index_uncompressed_cache_size_ratio
+ = config->getDouble("index_uncompressed_cache_size_ratio",
DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO);
+ global_context->setIndexUncompressedCache(
+ index_uncompressed_cache_policy, index_uncompressed_cache_size,
index_uncompressed_cache_size_ratio);
String index_mark_cache_policy =
config->getString("index_mark_cache_policy", DEFAULT_INDEX_MARK_CACHE_POLICY);
size_t index_mark_cache_size =
config->getUInt64("index_mark_cache_size", DEFAULT_INDEX_MARK_CACHE_MAX_SIZE);
@@ -1023,11 +1024,13 @@ void BackendFinalizerUtil::finalizeGlobally()
StorageMergeTreeFactory::clear();
QueryContext::resetGlobal();
std::lock_guard lock(paths_mutex);
- std::ranges::for_each(paths_need_to_clean, [](const auto & path)
- {
- if (fs::exists(path))
- fs::remove_all(path);
- });
+ std::ranges::for_each(
+ paths_need_to_clean,
+ [](const auto & path)
+ {
+ if (fs::exists(path))
+ fs::remove_all(path);
+ });
paths_need_to_clean.clear();
}
diff --git a/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp b/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp
index f2d4bc8a8..ecb027c18 100644
--- a/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp
+++ b/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp
@@ -32,16 +32,14 @@ namespace local_engine
{
static DB::ITransformingStep::Traits getTraits()
{
- return DB::ITransformingStep::Traits
- {
+ return DB::ITransformingStep::Traits{
{
.preserves_number_of_streams = true,
.preserves_sorting = false,
},
{
.preserves_number_of_rows = false,
- }
- };
+ }};
}
ReplicateRowsStep::ReplicateRowsStep(const DB::DataStream & input_stream)
@@ -49,7 +47,7 @@ ReplicateRowsStep::ReplicateRowsStep(const DB::DataStream &
input_stream)
{
}
-DB::Block ReplicateRowsStep::transformHeader(const DB::Block& input)
+DB::Block ReplicateRowsStep::transformHeader(const DB::Block & input)
{
DB::Block output;
for (int i = 1; i < input.columns(); i++)
@@ -59,15 +57,9 @@ DB::Block ReplicateRowsStep::transformHeader(const
DB::Block& input)
return output;
}
-void ReplicateRowsStep::transformPipeline(
- DB::QueryPipelineBuilder & pipeline,
- const DB::BuildQueryPipelineSettings & /*settings*/)
+void ReplicateRowsStep::transformPipeline(DB::QueryPipelineBuilder & pipeline,
const DB::BuildQueryPipelineSettings & /*settings*/)
{
- pipeline.addSimpleTransform(
- [&](const DB::Block & header)
- {
- return std::make_shared<ReplicateRowsTransform>(header);
- });
+ pipeline.addSimpleTransform([&](const DB::Block & header) { return
std::make_shared<ReplicateRowsTransform>(header); });
}
void ReplicateRowsStep::updateOutputStream()
@@ -105,4 +97,4 @@ void ReplicateRowsTransform::transform(DB::Chunk & chunk)
chunk.setColumns(std::move(mutable_columns), total_rows);
}
-}
\ No newline at end of file
+}
diff --git a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp
new file mode 100644
index 000000000..af04ef579
--- /dev/null
+++ b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "WindowGroupLimitStep.h"
+#include <memory>
+#include <Processors/Chunk.h>
+#include <Processors/IProcessor.h>
+#include <QueryPipeline/QueryPipelineBuilder.h>
+#include <Poco/Logger.h>
+#include <Common/CHUtil.h>
+#include <Common/logger_useful.h>
+
+namespace DB::ErrorCodes
+{
+extern const int LOGICAL_ERROR;
+}
+
+namespace local_engine
+{
+
+enum class WindowGroupLimitFunction
+{
+ RowNumber,
+ Rank,
+ DenseRank
+};
+
+
+template <WindowGroupLimitFunction function>
+class WindowGroupLimitTransform : public DB::IProcessor
+{
+public:
+ using Status = DB::IProcessor::Status;
+ explicit WindowGroupLimitTransform(
+        const DB::Block & header_, const std::vector<size_t> & partition_columns_, const std::vector<size_t> & sort_columns_, size_t limit_)
+ : DB::IProcessor({header_}, {header_})
+ , header(header_)
+ , partition_columns(partition_columns_)
+ , sort_columns(sort_columns_)
+ , limit(limit_)
+
+ {
+ }
+ ~WindowGroupLimitTransform() override = default;
+ String getName() const override { return "WindowGroupLimitTransform"; }
+
+ Status prepare() override
+ {
+ auto & output_port = outputs.front();
+ auto & input_port = inputs.front();
+ if (output_port.isFinished())
+ {
+ input_port.close();
+ return Status::Finished;
+ }
+
+ if (has_output)
+ {
+ if (output_port.canPush())
+ {
+ output_port.push(std::move(output_chunk));
+ has_output = false;
+ }
+ return Status::PortFull;
+ }
+
+ if (has_input)
+ return Status::Ready;
+
+ if (input_port.isFinished())
+ {
+ output_port.finish();
+ return Status::Finished;
+ }
+ input_port.setNeeded();
+ if (!input_port.hasData())
+ return Status::NeedData;
+ input_chunk = input_port.pull(true);
+ has_input = true;
+ return Status::Ready;
+ }
+
+ void work() override
+ {
+ if (!has_input) [[unlikely]]
+ {
+ return;
+ }
+ DB::Block block = header.cloneWithColumns(input_chunk.getColumns());
+ size_t partition_start_row = 0;
+ size_t chunk_rows = input_chunk.getNumRows();
+ while (partition_start_row < chunk_rows)
+ {
+            auto next_partition_start_row = advanceNextPartition(input_chunk, partition_start_row);
+            iteratePartition(input_chunk, partition_start_row, next_partition_start_row);
+ partition_start_row = next_partition_start_row;
+ // corner case, the partition end row is the last row of chunk.
+ if (partition_start_row < chunk_rows)
+ {
+ current_row_rank_value = 1;
+ if constexpr (function == WindowGroupLimitFunction::Rank)
+ current_peer_group_rows = 0;
+                partition_start_row_columns = extractOneRowColumns(input_chunk, partition_start_row);
+ }
+ }
+
+ if (!output_columns.empty() && output_columns[0]->size() > 0)
+ {
+ auto rows = output_columns[0]->size();
+ output_chunk = DB::Chunk(std::move(output_columns), rows);
+ output_columns.clear();
+ has_output = true;
+ }
+ has_input = false;
+ }
+
+private:
+ DB::Block header;
+    // Which columns are used as the partition keys.
+    std::vector<size_t> partition_columns;
+    // Which columns are used as the order by keys, excluding partition columns.
+    std::vector<size_t> sort_columns;
+    // The row limit for each partition.
+    size_t limit = 0;
+
+ bool has_input = false;
+ DB::Chunk input_chunk;
+ bool has_output = false;
+ DB::MutableColumns output_columns;
+ DB::Chunk output_chunk;
+
+    // We don't have a window frame here; in fact all frames are (unbounded preceding, current row].
+    // The start value is 1.
+    size_t current_row_rank_value = 1;
+    // rank needs this to record how many rows are in the current peer group.
+    // A peer group in a partition is defined as the rows that have the same values on the sort columns.
+    size_t current_peer_group_rows = 0;
+
+ DB::Columns partition_start_row_columns;
+ DB::Columns peer_group_start_row_columns;
+
+
+ size_t advanceNextPartition(const DB::Chunk & chunk, size_t start_offset)
+ {
+ if (partition_start_row_columns.empty())
+            partition_start_row_columns = extractOneRowColumns(chunk, start_offset);
+
+ size_t max_row = chunk.getNumRows();
+ for (size_t i = start_offset; i < max_row; ++i)
+ {
+            if (!isRowEqual(partition_columns, partition_start_row_columns, 0, chunk.getColumns(), i))
+ {
+ return i;
+ }
+ }
+ return max_row;
+ }
+
+    static DB::Columns extractOneRowColumns(const DB::Chunk & chunk, size_t offset)
+ {
+ DB::Columns row;
+ for (const auto & col : chunk.getColumns())
+ {
+ auto new_col = col->cloneEmpty();
+ new_col->insertFrom(*col, offset);
+ row.push_back(std::move(new_col));
+ }
+ return row;
+ }
+
+ static bool isRowEqual(
+        const std::vector<size_t> & fields, const DB::Columns & left_cols, size_t loffset, const DB::Columns & right_cols, size_t roffset)
+ {
+ for (size_t i = 0; i < fields.size(); ++i)
+ {
+ const auto & field = fields[i];
+ /// don't care about nan_direction_hint
+            if (left_cols[field]->compareAt(loffset, roffset, *right_cols[field], 1))
+ return false;
+ }
+ return true;
+ }
+
+    void iteratePartition(const DB::Chunk & chunk, size_t start_offset, size_t end_offset)
+    {
+        // Skip the remaining rows in this partition.
+ if (current_row_rank_value > limit)
+ return;
+
+
+ size_t chunk_rows = chunk.getNumRows();
+        auto has_peer_group_ended = [&](size_t offset, size_t partition_end_offset, size_t chunk_rows_)
+        { return offset < partition_end_offset || end_offset < chunk_rows_; };
+        auto try_end_peer_group
+            = [&](size_t peer_group_start_offset, size_t next_peer_group_start_offset, size_t partition_end_offset, size_t chunk_rows_)
+ {
+ if constexpr (function == WindowGroupLimitFunction::Rank)
+ {
+                current_peer_group_rows += next_peer_group_start_offset - peer_group_start_offset;
+                if (has_peer_group_ended(next_peer_group_start_offset, partition_end_offset, chunk_rows_))
+                {
+                    current_row_rank_value += current_peer_group_rows;
+                    current_peer_group_rows = 0;
+                    peer_group_start_row_columns = extractOneRowColumns(chunk, next_peer_group_start_offset);
+ }
+ }
+ else if constexpr (function == WindowGroupLimitFunction::DenseRank)
+ {
+                if (has_peer_group_ended(next_peer_group_start_offset, partition_end_offset, chunk_rows_))
+                {
+                    current_row_rank_value += 1;
+                    peer_group_start_row_columns = extractOneRowColumns(chunk, next_peer_group_start_offset);
+ }
+ }
+ };
+
+        // This is a corner case: the previous partition's last row is the last row of a chunk.
+ if (start_offset >= end_offset)
+ {
+ assert(!start_offset);
+            try_end_peer_group(start_offset, end_offset, end_offset, chunk_rows);
+ return;
+ }
+
+ // row_number is simple
+ if constexpr (function == WindowGroupLimitFunction::RowNumber)
+ {
+ size_t rows = end_offset - start_offset;
+ size_t limit_remained = limit - current_row_rank_value + 1;
+ rows = rows > limit_remained ? limit_remained : rows;
+ insertResultValue(chunk, start_offset, rows);
+ current_row_rank_value += rows;
+ }
+ else
+ {
+ size_t peer_group_start_offset = start_offset;
+            while (peer_group_start_offset < end_offset && current_row_rank_value <= limit)
+            {
+                auto next_peer_group_start_offset = advanceNextPeerGroup(chunk, peer_group_start_offset, end_offset);
+
+                insertResultValue(chunk, peer_group_start_offset, next_peer_group_start_offset - peer_group_start_offset);
+                try_end_peer_group(peer_group_start_offset, next_peer_group_start_offset, end_offset, chunk_rows);
+ peer_group_start_offset = next_peer_group_start_offset;
+ }
+ }
+ }
+    void insertResultValue(const DB::Chunk & chunk, size_t start_offset, size_t rows)
+ {
+ if (!rows)
+ return;
+ if (output_columns.empty())
+ {
+ for (const auto & col : chunk.getColumns())
+ {
+ output_columns.push_back(col->cloneEmpty());
+ }
+ }
+ size_t i = 0;
+ for (const auto & col : chunk.getColumns())
+ {
+ output_columns[i]->insertRangeFrom(*col, start_offset, rows);
+ i += 1;
+ }
+ }
+    size_t advanceNextPeerGroup(const DB::Chunk & chunk, size_t start_offset, size_t partition_end_offset)
+ {
+ if (peer_group_start_row_columns.empty())
+            peer_group_start_row_columns = extractOneRowColumns(chunk, start_offset);
+ for (size_t i = start_offset; i < partition_end_offset; ++i)
+ {
+            if (!isRowEqual(sort_columns, peer_group_start_row_columns, 0, chunk.getColumns(), i))
+ {
+ return i;
+ }
+ }
+ return partition_end_offset;
+ }
+};
+
+static DB::ITransformingStep::Traits getTraits()
+{
+ return DB::ITransformingStep::Traits{
+ {
+ .preserves_number_of_streams = false,
+ .preserves_sorting = true,
+ },
+ {
+ .preserves_number_of_rows = false,
+ }};
+}
+
+WindowGroupLimitStep::WindowGroupLimitStep(
+ const DB::DataStream & input_stream_,
+ const String & function_name_,
+ const std::vector<size_t> partition_columns_,
+ const std::vector<size_t> sort_columns_,
+ size_t limit_)
+ : DB::ITransformingStep(input_stream_, input_stream_.header, getTraits())
+ , function_name(function_name_)
+ , partition_columns(partition_columns_)
+ , sort_columns(sort_columns_)
+ , limit(limit_)
+{
+}
+
+void WindowGroupLimitStep::describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const
+{
+ if (!processors.empty())
+ DB::IQueryPlanStep::describePipeline(processors, settings);
+}
+
+void WindowGroupLimitStep::updateOutputStream()
+{
+    output_stream = createOutputStream(input_streams.front(), input_streams.front().header, getDataStreamTraits());
+}
+
+
+void WindowGroupLimitStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & /*settings*/)
+{
+ if (function_name == "row_number")
+ {
+ pipeline.addSimpleTransform(
+ [&](const DB::Block & header)
+ {
+                return std::make_shared<WindowGroupLimitTransform<WindowGroupLimitFunction::RowNumber>>(
+ header, partition_columns, sort_columns, limit);
+ });
+ }
+ else if (function_name == "rank")
+ {
+ pipeline.addSimpleTransform(
+ [&](const DB::Block & header) {
+                return std::make_shared<WindowGroupLimitTransform<WindowGroupLimitFunction::Rank>>(
+ header, partition_columns, sort_columns, limit);
+ });
+ }
+ else if (function_name == "dense_rank")
+ {
+ pipeline.addSimpleTransform(
+ [&](const DB::Block & header)
+ {
+                return std::make_shared<WindowGroupLimitTransform<WindowGroupLimitFunction::DenseRank>>(
+ header, partition_columns, sort_columns, limit);
+ });
+ }
+ else
+ {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupported function {} in WindowGroupLimit", function_name);
+ }
+}
+}
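
To make the three rank-like variants concrete, a small illustration (plain Scala, not part of the patch) of the values the transform effectively assigns inside one already-sorted partition, and what a limit of 2 keeps:

    // Illustration only: rank-like values over one sorted partition.
    val sortKeys = Seq(10, 10, 20, 30, 30, 30)                          // order-by key values, already sorted
    val rowNumber = sortKeys.indices.map(_ + 1)                         // 1, 2, 3, 4, 5, 6
    val rank = sortKeys.map(k => sortKeys.indexOf(k) + 1)               // 1, 1, 3, 4, 4, 4
    val denseRank = sortKeys.map(k => sortKeys.distinct.indexOf(k) + 1) // 1, 1, 2, 3, 3, 3
    // With limit = 2 the operator emits rows whose value is <= 2:
    // row_number and rank keep rows 1-2, dense_rank keeps rows 1-3.

current_row_rank_value in the transform above tracks exactly these values, and current_peer_group_rows carries the peer-group size across chunk boundaries so that rank stays correct when a partition spans several chunks.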
diff --git a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h
new file mode 100644
index 000000000..bbbbf42ab
--- /dev/null
+++ b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <vector>
+#include <Core/Block.h>
+#include <Processors/QueryPlan/IQueryPlanStep.h>
+#include <Processors/QueryPlan/ITransformingStep.h>
+
+namespace local_engine
+{
+class WindowGroupLimitStep : public DB::ITransformingStep
+{
+public:
+ explicit WindowGroupLimitStep(
+ const DB::DataStream & input_stream_,
+ const String & function_name_,
+ const std::vector<size_t> partition_columns_,
+ const std::vector<size_t> sort_columns_,
+ size_t limit_);
+ ~WindowGroupLimitStep() override = default;
+
+ String getName() const override { return "WindowGroupLimitStep"; }
+
+    void transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & settings) override;
+    void describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const override;
+ void updateOutputStream() override;
+
+private:
+ // window function name, one of row_number, rank and dense_rank
+ String function_name;
+ std::vector<size_t> partition_columns;
+ std::vector<size_t> sort_columns;
+ size_t limit;
+};
+
+}
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
index 42d4f4d4d..cc7738a15 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
+++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
@@ -14,25 +14,26 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include <base/find_symbols.h>
+#include "AdvancedParametersParseUtil.h"
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
-#include <Common/Exception.h>
-#include <Parser/AdvancedParametersParseUtil.h>
+#include <base/find_symbols.h>
#include <Poco/Logger.h>
+#include <Common/Exception.h>
#include <Common/logger_useful.h>
+
namespace DB::ErrorCodes
{
- extern const int BAD_ARGUMENTS;
+extern const int BAD_ARGUMENTS;
}
namespace local_engine
{
-template<typename T>
+template <typename T>
void tryAssign(const std::unordered_map<String, String> & kvs, const String &
key, T & v);
-template<>
+template <>
void tryAssign<String>(const std::unordered_map<String, String> & kvs, const
String & key, String & v)
{
auto it = kvs.find(key);
@@ -40,7 +41,7 @@ void tryAssign<String>(const std::unordered_map<String,
String> & kvs, const Str
v = it->second;
}
-template<>
+template <>
void tryAssign<bool>(const std::unordered_map<String, String> & kvs, const
String & key, bool & v)
{
auto it = kvs.find(key);
@@ -57,7 +58,7 @@ void tryAssign<bool>(const std::unordered_map<String, String>
& kvs, const Strin
}
}
-template<>
+template <>
void tryAssign<Int64>(const std::unordered_map<String, String> & kvs, const
String & key, Int64 & v)
{
auto it = kvs.find(key);
@@ -94,9 +95,9 @@ void readStringUntilCharsInto(String & s, DB::ReadBuffer &
buf)
std::unordered_map<String, std::unordered_map<String, String>>
convertToKVs(const String & advance)
{
std::unordered_map<String, std::unordered_map<String, String>> res;
- std::unordered_map<String, String> *kvs;
+ std::unordered_map<String, String> * kvs;
DB::ReadBufferFromString in(advance);
- while(!in.eof())
+ while (!in.eof())
{
String key;
readStringUntilCharsInto<'=', '\n', ':'>(key, in);
@@ -146,5 +147,13 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const
String & advance)
tryAssign(kvs, "numPartitions", info.partitions_num);
return info;
}
-}
+WindowGroupOptimizationInfo WindowGroupOptimizationInfo::parse(const String & advance)
+{
+ WindowGroupOptimizationInfo info;
+ auto kkvs = convertToKVs(advance);
+ auto & kvs = kkvs["WindowGroupLimitParameters"];
+ tryAssign(kvs, "window_function", info.window_function);
+ return info;
+}
+}
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
index 5f6fe6d25..fc478db33 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
+++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
@@ -16,10 +16,10 @@
*/
#pragma once
#include <unordered_map>
+#include <base/types.h>
namespace local_engine
{
-
std::unordered_map<String, std::unordered_map<String, String>>
convertToKVs(const String & advance);
@@ -38,5 +38,10 @@ struct JoinOptimizationInfo
static JoinOptimizationInfo parse(const String & advance);
};
-}
+struct WindowGroupOptimizationInfo
+{
+ String window_function;
+ static WindowGroupOptimizationInfo parse(const String & advnace);
+};
+}
diff --git a/cpp-ch/local-engine/Parser/RelParser.cpp b/cpp-ch/local-engine/Parser/RelParser.cpp
index f651146a3..a7f6d0586 100644
--- a/cpp-ch/local-engine/Parser/RelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParser.cpp
@@ -30,8 +30,8 @@ namespace DB
{
namespace ErrorCodes
{
- extern const int BAD_ARGUMENTS;
- extern const int LOGICAL_ERROR;
+extern const int BAD_ARGUMENTS;
+extern const int LOGICAL_ERROR;
}
}
@@ -89,14 +89,15 @@ DB::QueryPlanPtr RelParser::parseOp(const substrait::Rel &
rel, std::list<const
return parse(std::move(query_plan), rel, rel_stack);
}
-std::map<std::string, std::string>
RelParser::parseFormattedRelAdvancedOptimization(const
substrait::extensions::AdvancedExtension &advanced_extension)
+std::map<std::string, std::string>
+RelParser::parseFormattedRelAdvancedOptimization(const
substrait::extensions::AdvancedExtension & advanced_extension)
{
std::map<std::string, std::string> configs;
if (advanced_extension.has_optimization())
{
google::protobuf::StringValue msg;
advanced_extension.optimization().UnpackTo(&msg);
- Poco::StringTokenizer kvs( msg.value(), "\n");
+ Poco::StringTokenizer kvs(msg.value(), "\n");
for (auto & kv : kvs)
{
if (kv.empty())
@@ -114,7 +115,8 @@ std::map<std::string, std::string>
RelParser::parseFormattedRelAdvancedOptimizat
return configs;
}
-std::string RelParser::getStringConfig(const std::map<std::string,
std::string> & configs, const std::string & key, const std::string &
default_value)
+std::string
+RelParser::getStringConfig(const std::map<std::string, std::string> & configs,
const std::string & key, const std::string & default_value)
{
auto it = configs.find(key);
if (it == configs.end())
@@ -150,6 +152,7 @@ RelParserFactory::RelParserBuilder
RelParserFactory::getBuilder(UInt32 k)
}
void registerWindowRelParser(RelParserFactory & factory);
+void registerWindowGroupLimitRelParser(RelParserFactory & factory);
void registerSortRelParser(RelParserFactory & factory);
void registerExpandRelParser(RelParserFactory & factory);
void registerAggregateParser(RelParserFactory & factory);
@@ -162,6 +165,7 @@ void registerRelParsers()
{
auto & factory = RelParserFactory::instance();
registerWindowRelParser(factory);
+ registerWindowGroupLimitRelParser(factory);
registerSortRelParser(factory);
registerExpandRelParser(factory);
registerAggregateParser(factory);
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index 589f3826a..e893dd35b 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -391,9 +391,9 @@ void adjustOutput(const DB::QueryPlanPtr & query_plan,
const substrait::PlanRel
}
if (need_final_project)
{
- ActionsDAG final_project
- = ActionsDAG::makeConvertingActions(original_cols, final_cols,
ActionsDAG::MatchColumnsMode::Position);
- QueryPlanStepPtr final_project_step =
std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(),
std::move(final_project));
+ ActionsDAG final_project =
ActionsDAG::makeConvertingActions(original_cols, final_cols,
ActionsDAG::MatchColumnsMode::Position);
+ QueryPlanStepPtr final_project_step
+ =
std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(),
std::move(final_project));
final_project_step->setStepDescription("Project for output
schema");
query_plan->addStep(std::move(final_project_step));
}
@@ -504,6 +504,7 @@ QueryPlanPtr SerializedPlanParser::parseOp(const
substrait::Rel & rel, std::list
case substrait::Rel::RelTypeCase::kWindow:
case substrait::Rel::RelTypeCase::kJoin:
case substrait::Rel::RelTypeCase::kCross:
+ case substrait::Rel::RelTypeCase::kWindowGroupLimit:
case substrait::Rel::RelTypeCase::kExpand: {
auto op_parser =
RelParserFactory::instance().getBuilder(rel.rel_type_case())(this);
query_plan = op_parser->parseOp(rel, rel_stack);
@@ -606,7 +607,7 @@ void SerializedPlanParser::parseArrayJoinArguments(
}
ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG(
- const substrait::Expression & rel, std::vector<String> & result_names,
ActionsDAG& actions_dag, bool keep_result, bool position)
+ const substrait::Expression & rel, std::vector<String> & result_names,
ActionsDAG & actions_dag, bool keep_result, bool position)
{
if (!rel.has_scalar_function())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of expression
should be a scalar function:\n {}", rel.DebugString());
@@ -723,7 +724,7 @@ ActionsDAG::NodeRawConstPtrs
SerializedPlanParser::parseArrayJoinWithDAG(
}
const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG(
- const substrait::Expression & rel, std::string & result_name, ActionsDAG&
actions_dag, bool keep_result)
+ const substrait::Expression & rel, std::string & result_name, ActionsDAG &
actions_dag, bool keep_result)
{
if (!rel.has_scalar_function())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression
should be a scalar function:\n {}", rel.DebugString());
@@ -785,9 +786,8 @@ bool
SerializedPlanParser::isFunction(substrait::Expression_ScalarFunction rel,
}
void SerializedPlanParser::parseFunctionOrExpression(
- const substrait::Expression & rel, std::string & result_name, ActionsDAG&
actions_dag, bool keep_result)
+ const substrait::Expression & rel, std::string & result_name, ActionsDAG &
actions_dag, bool keep_result)
{
-
if (rel.has_scalar_function())
parseFunctionWithDAG(rel, result_name, actions_dag, keep_result);
else
@@ -798,11 +798,7 @@ void SerializedPlanParser::parseFunctionOrExpression(
}
void SerializedPlanParser::parseJsonTuple(
- const substrait::Expression & rel,
- std::vector<String> & result_names,
- ActionsDAG& actions_dag,
- bool keep_result,
- bool)
+ const substrait::Expression & rel, std::vector<String> & result_names,
ActionsDAG & actions_dag, bool keep_result, bool)
{
const auto & scalar_function = rel.scalar_function();
auto function_signature =
function_mapping.at(std::to_string(rel.scalar_function().function_reference()));
@@ -861,7 +857,7 @@ void SerializedPlanParser::parseJsonTuple(
}
const ActionsDAG::Node *
-SerializedPlanParser::toFunctionNode(ActionsDAG& actions_dag, const String &
function, const ActionsDAG::NodeRawConstPtrs & args)
+SerializedPlanParser::toFunctionNode(ActionsDAG & actions_dag, const String &
function, const ActionsDAG::NodeRawConstPtrs & args)
{
auto function_builder = FunctionFactory::instance().get(function, context);
std::string args_name = join(args, ',');
@@ -1073,7 +1069,7 @@ std::pair<DataTypePtr, Field>
SerializedPlanParser::parseLiteral(const substrait
return std::make_pair(std::move(type), std::move(field));
}
-const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAG&
actions_dag, const substrait::Expression & rel)
+const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAG &
actions_dag, const substrait::Expression & rel)
{
switch (rel.rex_type_case())
{
@@ -1534,8 +1530,7 @@ ASTPtr ASTParser::parseArgumentToAST(const Names & names,
const substrait::Expre
}
}
-void SerializedPlanParser::removeNullableForRequiredColumns(
- const std::set<String> & require_columns, ActionsDAG & actions_dag) const
+void SerializedPlanParser::removeNullableForRequiredColumns(const
std::set<String> & require_columns, ActionsDAG & actions_dag) const
{
for (const auto & item : require_columns)
{
@@ -1550,7 +1545,7 @@ void
SerializedPlanParser::removeNullableForRequiredColumns(
}
void SerializedPlanParser::wrapNullable(
- const std::vector<String> & columns, ActionsDAG& actions_dag,
std::map<std::string, std::string> & nullable_measure_names)
+ const std::vector<String> & columns, ActionsDAG & actions_dag,
std::map<std::string, std::string> & nullable_measure_names)
{
for (const auto & item : columns)
{
diff --git a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp
new file mode 100644
index 000000000..f6c10386f
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "WindowGroupLimitRelParser.h"
+#include <Interpreters/ActionsDAG.h>
+#include <Operator/WindowGroupLimitStep.h>
+#include <Parser/AdvancedParametersParseUtil.h>
+#include <Parser/SortRelParser.h>
+#include <Parser/WindowGroupLimitRelParser.h>
+#include <Processors/QueryPlan/ExpressionStep.h>
+#include <google/protobuf/repeated_field.h>
+#include <google/protobuf/wrappers.pb.h>
+#include "AdvancedParametersParseUtil.h"
+
+namespace DB::ErrorCodes
+{
+extern const int BAD_ARGUMENTS;
+}
+
+namespace local_engine
+{
+WindowGroupLimitRelParser::WindowGroupLimitRelParser(SerializedPlanParser * plan_parser_) : RelParser(plan_parser_)
+{
+}
+
+DB::QueryPlanPtr
+WindowGroupLimitRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_)
+{
+ const auto win_rel_def = rel.windowgrouplimit();
+ google::protobuf::StringValue optimize_info_str;
+    optimize_info_str.ParseFromString(win_rel_def.advanced_extension().optimization().value());
+    auto optimization_info = WindowGroupOptimizationInfo::parse(optimize_info_str.value());
+ window_function_name = optimization_info.window_function;
+
+ current_plan = std::move(current_plan_);
+
+    auto partition_fields = parsePartitoinFields(win_rel_def.partition_expressions());
+ auto sort_fields = parseSortFields(win_rel_def.sorts());
+ size_t limit = static_cast<size_t>(win_rel_def.limit());
+
+ auto window_group_limit_step = std::make_unique<WindowGroupLimitStep>(
+        current_plan->getCurrentDataStream(), window_function_name, partition_fields, sort_fields, limit);
+ window_group_limit_step->setStepDescription("Window group limit");
+ steps.emplace_back(window_group_limit_step.get());
+ current_plan->addStep(std::move(window_group_limit_step));
+
+ return std::move(current_plan);
+}
+
+std::vector<size_t>
+WindowGroupLimitRelParser::parsePartitoinFields(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions)
+{
+ std::vector<size_t> fields;
+ for (const auto & expr : expressions)
+ {
+ if (expr.has_selection())
+ {
+            fields.push_back(static_cast<size_t>(expr.selection().direct_reference().struct_field().field()));
+ }
+ else if (expr.has_literal())
+ {
+ continue;
+ }
+ else
+ {
+            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown expression: {}", expr.DebugString());
+ }
+ }
+ return fields;
+}
+
+std::vector<size_t> WindowGroupLimitRelParser::parseSortFields(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields)
+{
+ std::vector<size_t> fields;
+ for (const auto sort_field : sort_fields)
+ {
+ if (sort_field.expr().has_literal())
+ {
+ continue;
+ }
+ else if (sort_field.expr().has_selection())
+ {
+            fields.push_back(static_cast<size_t>(sort_field.expr().selection().direct_reference().struct_field().field()));
+ }
+ else
+ {
+            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown expression: {}", sort_field.expr().DebugString());
+ }
+ }
+ return fields;
+}
+
+void registerWindowGroupLimitRelParser(RelParserFactory & factory)
+{
+    auto builder = [](SerializedPlanParser * plan_parser) { return std::make_shared<WindowGroupLimitRelParser>(plan_parser); };
+    factory.registerBuilder(substrait::Rel::RelTypeCase::kWindowGroupLimit, builder);
+}
+}
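
For orientation, the field mapping done here is purely positional. A hedged example (the indices are made up): if the windowGroupLimit rel carries partition_expressions = [selection(struct_field = 0)] and sorts = [sort_field(expr = selection(struct_field = 2))], the parser yields partition_columns = {0} and sort_columns = {2}, i.e. column indices into the input block that WindowGroupLimitTransform later compares row-by-row with compareAt; literal expressions are skipped.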
diff --git a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h
new file mode 100644
index 000000000..c9c503ed4
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <unordered_map>
+#include <Core/Field.h>
+#include <Core/SortDescription.h>
+#include <DataTypes/IDataType.h>
+#include <Interpreters/WindowDescription.h>
+#include <Parser/AggregateFunctionParser.h>
+#include <Parser/RelParser.h>
+#include <Parser/SerializedPlanParser.h>
+#include <Processors/QueryPlan/QueryPlan.h>
+#include <Poco/Logger.h>
+#include <Common/logger_useful.h>
+
+namespace local_engine
+{
+/// Similar to WindowRelParser, with some differences:
+/// 1. aggregate functions are not supported; only the window functions row_number, rank and dense_rank are
+/// 2. row_number, rank and dense_rank are mapped to new variants
+/// 3. the output columns don't contain window function results
+class WindowGroupLimitRelParser : public RelParser
+{
+public:
+ explicit WindowGroupLimitRelParser(SerializedPlanParser * plan_parser_);
+ ~WindowGroupLimitRelParser() override = default;
+ DB::QueryPlanPtr
+    parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_) override;
+    const substrait::Rel & getSingleInput(const substrait::Rel & rel) override { return rel.windowgrouplimit().input(); }
+
+private:
+ DB::QueryPlanPtr current_plan;
+ String window_function_name;
+
+    std::vector<size_t> parsePartitoinFields(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions);
+    std::vector<size_t> parseSortFields(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);
+};
+}
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index a55926d76..dd4150806 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
+import org.apache.spark.sql.execution.window._
import org.apache.spark.sql.hive.{HiveTableScanExecTransformer, HiveUDFTransformer}
import org.apache.spark.sql.types.{DecimalType, LongType, NullType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -678,6 +679,15 @@ trait SparkPlanExecApi {
}
}
+ def genWindowGroupLimitTransformer(
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ rankLikeFunction: Expression,
+ limit: Int,
+ mode: WindowGroupLimitMode,
+ child: SparkPlan): SparkPlan =
+    WindowGroupLimitExecTransformer(partitionSpec, orderSpec, rankLikeFunction, limit, mode, child)
+
def genHiveUDFTransformer(
expr: Expression,
attributeSeq: Seq[Attribute]): ExpressionTransformer = {
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
index cdc71f447..9ca60177c 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
@@ -372,7 +372,7 @@ object OffloadOthers {
val windowGroupLimitPlan = SparkShimLoader.getSparkShims
.getWindowGroupLimitExecShim(plan)
.asInstanceOf[WindowGroupLimitExecShim]
- WindowGroupLimitExecTransformer(
+      BackendsApiManager.getSparkPlanExecApiInstance.genWindowGroupLimitTransformer(
windowGroupLimitPlan.partitionSpec,
windowGroupLimitPlan.orderSpec,
windowGroupLimitPlan.rankLikeFunction,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]