This is an automated email from the ASF dual-hosted git repository.
liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new b0d836bac [GLUTEN-6159][CH] Support array functions with lambda
functions (#6248)
b0d836bac is described below
commit b0d836bac9f9a5e4bba2dbd3de861d2abac80d8d
Author: lgbo <[email protected]>
AuthorDate: Tue Jul 2 14:31:16 2024 +0800
[GLUTEN-6159][CH] Support array functions with lambda functions (#6248)
What changes were proposed in this pull request?
Add ClickHouse backend support for Spark array higher-order functions that take lambda arguments.
Fixes: #6159
Support the following array functions:
filter
transform
aggregate
How was this patch tested?
Unit tests: a new `array functions with lambda` case in GlutenFunctionValidateSuite.
No UI changes.
---
.../clickhouse/CHSparkPlanExecApi.scala | 18 ++
.../execution/GlutenFunctionValidateSuite.scala | 17 ++
.../local-engine/Parser/SerializedPlanParser.cpp | 13 +-
cpp-ch/local-engine/Parser/SerializedPlanParser.h | 4 +-
.../arrayHighOrderFunctions.cpp | 154 ++++++++++++++
.../scalar_function_parser/lambdaFunction.cpp | 222 +++++++++++++++++++++
.../Parser/scalar_function_parser/lambdaFunction.h | 23 +++
7 files changed, 448 insertions(+), 3 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 7ed333aec..c0dee707e 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -842,6 +842,24 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
CHGenerateExecTransformer(generator, requiredChildOutput, outer,
generatorOutput, child)
}
+  /**
+   * Transform array filter to Substrait.
+   *
+   * The array argument and the lambda function become, in that order, the children of a generic
+   * expression named `substraitExprName`.
+   */
+  override def genArrayFilterTransformer(
+      substraitExprName: String,
+      argument: ExpressionTransformer,
+      function: ExpressionTransformer,
+      expr: ArrayFilter): ExpressionTransformer = {
+    val children = Seq(argument, function)
+    GenericExpressionTransformer(substraitExprName, children, expr)
+  }
+
+  /**
+   * Transform array transform to Substrait.
+   *
+   * The array argument and the lambda function become, in that order, the children of a generic
+   * expression named `substraitExprName`.
+   */
+  override def genArrayTransformTransformer(
+      substraitExprName: String,
+      argument: ExpressionTransformer,
+      function: ExpressionTransformer,
+      expr: ArrayTransform): ExpressionTransformer = {
+    val children = Seq(argument, function)
+    GenericExpressionTransformer(substraitExprName, children, expr)
+  }
+
override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan =
generate
override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan =
generate
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index 9327137fa..d3e3e9446 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -713,4 +713,21 @@ class GlutenFunctionValidateSuite extends
GlutenClickHouseWholeStageTransformerS
}
}
+
+  test("array functions with lambda") {
+    withTable("tb_array") {
+      sql("create table tb_array(ids array<int>) using parquet")
+      sql("""
+            |insert into tb_array values (array(1,5,2,null, 3)), (array(1,1,3,2)), (null), (array())
+            |""".stripMargin)
+
+      // Each query goes through runQueryAndCompare and must be planned with a ProjectExecTransformer.
+      // The data covers a null element, duplicates, a null array and an empty array.
+      val lambdaQueries = Seq(
+        "select transform(ids, x -> x + 1) from tb_array",
+        "select filter(ids, x -> x % 2 == 1) from tb_array",
+        "select ids, aggregate(ids, 3, (acc, x) -> acc + x) from tb_array"
+      )
+      lambdaQueries.foreach {
+        query => runQueryAndCompare(query)(checkGlutenOperatorMatch[ProjectExecTransformer])
+      }
+    }
+  }
}
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index 77819fd73..ea33dc210 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -564,6 +564,16 @@ NamesAndTypesList
SerializedPlanParser::blockToNameAndTypeList(const Block & hea
return types;
}
+/// Returns the function-name part of the Substrait signature registered for `function_ref`:
+/// the text before the first ':' (or the whole signature when it contains no ':').
+/// Returns std::nullopt when `function_ref` has no entry in `function_mapping`.
+std::optional<String> SerializedPlanParser::getFunctionSignatureName(UInt32 function_ref) const
+{
+    auto it = function_mapping.find(std::to_string(function_ref));
+    if (it == function_mapping.end())
+        return std::nullopt;
+    /// Bind by reference to avoid copying the mapped signature string just to take a prefix of it.
+    const auto & function_signature = it->second;
+    return function_signature.substr(0, function_signature.find(':'));
+}
+
std::string
SerializedPlanParser::getFunctionName(const std::string & function_signature,
const substrait::Expression_ScalarFunction & function)
{
@@ -1122,8 +1132,7 @@ const ActionsDAG::Node *
SerializedPlanParser::parseFunctionArgument(
{
std::string arg_name;
bool keep_arg = FUNCTION_NEED_KEEP_ARGUMENTS.contains(function_name);
- parseFunctionWithDAG(arg.value(), arg_name, actions_dag, keep_arg);
- res = &actions_dag->getNodes().back();
+ res = parseFunctionWithDAG(arg.value(), arg_name, actions_dag,
keep_arg);
}
else
{
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index ad2b0d50e..ffd414803 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -307,9 +307,12 @@ public:
const std::unordered_map<std::string, std::string> & getFunctionMapping()
{ return function_mapping; }
static std::string getFunctionName(const std::string & function_sig, const
substrait::Expression_ScalarFunction & function);
+ std::optional<std::string> getFunctionSignatureName(UInt32 function_ref)
const;
IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const
std::set<String> & columns);
IQueryPlanStep * addRollbackFilterHeaderStep(QueryPlanPtr & query_plan,
const Block & input_header);
+
+ static std::pair<DataTypePtr, Field> parseLiteral(const
substrait::Expression_Literal & literal);
static ContextMutablePtr global_context;
static Context::ConfigurationPtr config;
@@ -384,7 +387,6 @@ private:
// remove nullable after isNotNull
void removeNullableForRequiredColumns(const std::set<String> &
require_columns, const ActionsDAGPtr & actions_dag) const;
std::string getUniqueName(const std::string & name) { return name + "_" +
std::to_string(name_no++); }
- static std::pair<DataTypePtr, Field> parseLiteral(const
substrait::Expression_Literal & literal);
void wrapNullable(
const std::vector<String> & columns, ActionsDAGPtr actions_dag,
std::map<std::string, std::string> & nullable_measure_names);
static std::pair<DB::DataTypePtr, DB::Field> convertStructFieldType(const
DB::DataTypePtr & type, const DB::Field & field);
diff --git
a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
new file mode 100644
index 000000000..584bc0ef1
--- /dev/null
+++
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <Parser/FunctionParser.h>
+#include <Common/Exception.h>
+#include <Poco/Logger.h>
+#include <Common/logger_useful.h>
+#include <Common/CHUtil.h>
+#include <DataTypes/DataTypeFunction.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <Core/Types.h>
+#include <Parser/TypeParser.h>
+#include <Parser/scalar_function_parser/lambdaFunction.h>
+
+namespace DB::ErrorCodes
+{
+ extern const int LOGICAL_ERROR;
+}
+
+namespace local_engine
+{
+/// Parses Spark `filter(arr, lambda)` into ClickHouse `arrayFilter`.
+/// A one-argument lambda `x -> ...` becomes `arrayFilter(lambda, arr)`. A two-argument lambda
+/// `(x, i) -> ...` also receives the element index, passed as a second array
+/// `range(0, length(assumeNotNull(arr)))` whose end is converted to Nullable(Int32).
+class ArrayFilter : public FunctionParser
+{
+public:
+    static constexpr auto name = "filter";
+    explicit ArrayFilter(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
+    ~ArrayFilter() override = default;
+
+    String getName() const override { return name; }
+
+    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
+    {
+        return "arrayFilter";
+    }
+
+    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func,
+        DB::ActionsDAGPtr & actions_dag) const override
+    {
+        auto ch_func_name = getCHFunctionName(substrait_func);
+        auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag);
+        /// Checked explicitly rather than with `assert` (compiled out in release builds), because the
+        /// indexing below would otherwise read out of range.
+        if (parsed_args.size() != 2)
+            throw DB::Exception(
+                DB::ErrorCodes::LOGICAL_ERROR, "Function {} requires 2 arguments, but got {}", getName(), parsed_args.size());
+        if (collectLambdaArguments(*plan_parser, substrait_func.arguments()[1].value().scalar_function()).size() == 1)
+            return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]});
+
+        /// filter with index argument: pass `range(0, length(arr))` as the index array.
+        const auto * range_end_node = toFunctionNode(actions_dag, "length", {toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]})});
+        range_end_node = ActionsDAGUtil::convertNodeType(
+            actions_dag, range_end_node, "Nullable(Int32)", range_end_node->result_name);
+        const auto * index_array_node = toFunctionNode(
+            actions_dag,
+            "range",
+            {addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeInt32>(), 0), range_end_node});
+        return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0], index_array_node});
+    }
+};
+static FunctionParserRegister<ArrayFilter> register_array_filter;
+
+/// Parses Spark `transform(arr, lambda)` into ClickHouse `arrayMap`.
+/// A one-argument lambda `x -> ...` becomes `arrayMap(lambda, arr)`. A two-argument lambda
+/// `(x, i) -> ...` also receives the element index, passed as a second array
+/// `range(0, length(assumeNotNull(arr)))` whose end is converted to Nullable(Int32).
+class ArrayTransform : public FunctionParser
+{
+public:
+    static constexpr auto name = "transform";
+    explicit ArrayTransform(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
+    ~ArrayTransform() override = default;
+    String getName() const override { return name; }
+    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
+    {
+        return "arrayMap";
+    }
+
+    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func,
+        DB::ActionsDAGPtr & actions_dag) const override
+    {
+        auto ch_func_name = getCHFunctionName(substrait_func);
+        auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag);
+        /// Checked explicitly rather than with `assert` (compiled out in release builds), because the
+        /// indexing below would otherwise read out of range.
+        if (parsed_args.size() != 2)
+            throw DB::Exception(
+                DB::ErrorCodes::LOGICAL_ERROR, "Function {} requires 2 arguments, but got {}", getName(), parsed_args.size());
+        /// Read the second substrait argument only after the argument count has been validated.
+        auto lambda_args = collectLambdaArguments(*plan_parser, substrait_func.arguments()[1].value().scalar_function());
+        if (lambda_args.size() == 1)
+        {
+            return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]});
+        }
+
+        /// transform with index argument: pass `range(0, length(arr))` as the index array.
+        const auto * range_end_node = toFunctionNode(actions_dag, "length", {toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]})});
+        range_end_node = ActionsDAGUtil::convertNodeType(
+            actions_dag, range_end_node, "Nullable(Int32)", range_end_node->result_name);
+        const auto * index_array_node = toFunctionNode(
+            actions_dag,
+            "range",
+            {addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeInt32>(), 0), range_end_node});
+        return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0], index_array_node});
+    }
+};
+static FunctionParserRegister<ArrayTransform> register_array_map;
+
+/// Parses Spark `aggregate(arr, start, merge)` into ClickHouse `arrayFold(merge, arr, start)`.
+/// `start` is converted to the lambda's return type when the two differ, and a null `arr` yields null.
+class ArrayAggregate : public FunctionParser
+{
+public:
+    static constexpr auto name = "aggregate";
+    explicit ArrayAggregate(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
+    ~ArrayAggregate() override = default;
+    String getName() const override { return name; }
+    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
+    {
+        return "arrayFold";
+    }
+    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func,
+        DB::ActionsDAGPtr & actions_dag) const override
+    {
+        auto ch_func_name = getCHFunctionName(substrait_func);
+        auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag);
+        /// Checked explicitly rather than with `assert` (compiled out in release builds), because the
+        /// indexing below would otherwise read out of range.
+        if (parsed_args.size() != 3)
+            throw DB::Exception(
+                DB::ErrorCodes::LOGICAL_ERROR, "Function {} requires 3 arguments, but got {}", getName(), parsed_args.size());
+        const auto * function_type = typeid_cast<const DataTypeFunction *>(parsed_args[2]->result_type.get());
+        if (!function_type)
+            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The third argument of aggregate function must be a lambda function");
+        /// Make the initial accumulator's type match the lambda's return type.
+        if (!parsed_args[1]->result_type->equals(*(function_type->getReturnType())))
+        {
+            parsed_args[1] = ActionsDAGUtil::convertNodeType(
+                actions_dag,
+                parsed_args[1],
+                function_type->getReturnType()->getName(),
+                parsed_args[1]->result_name);
+        }
+
+        /// arrayFold cannot accept nullable(array)
+        const auto * array_col_node = parsed_args[0];
+        if (parsed_args[0]->result_type->isNullable())
+        {
+            array_col_node = toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]});
+        }
+        const auto * func_node = toFunctionNode(actions_dag, ch_func_name, {parsed_args[2], array_col_node, parsed_args[1]});
+        /// For null array, result is null.
+        /// TODO: make a new version of arrayFold that can handle nullable array.
+        const auto * is_null_node = toFunctionNode(actions_dag, "isNull", {parsed_args[0]});
+        const auto * null_node = addColumnToActionsDAG(actions_dag, DB::makeNullable(func_node->result_type), DB::Null());
+        return toFunctionNode(actions_dag, "if", {is_null_node, null_node, func_node});
+    }
+};
+static FunctionParserRegister<ArrayAggregate> register_array_aggregate;
+
+}
diff --git
a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp
new file mode 100644
index 000000000..57c076ed2
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <Parser/FunctionParser.h>
+#include <Parser/TypeParser.h>
+#include <Common/Exception.h>
+#include <Poco/Logger.h>
+#include <Common/logger_useful.h>
+#include <Interpreters/ExpressionActionsSettings.h>
+#include <Interpreters/ExpressionActions.h>
+#include <Functions/FunctionsMiscellaneous.h>
+#include <Common/CHUtil.h>
+#include <unordered_set>
+
+namespace DB::ErrorCodes
+{
+ extern const int LOGICAL_ERROR;
+}
+
+namespace local_engine
+{
+/// Collects the distinct lambda variables that appear directly as arguments of `substrait_func`.
+/// An argument counts as a lambda variable when it is a scalar function whose signature name is
+/// `namedlambdavariable`. Its name comes from that function's first (literal) argument and its type
+/// from the function's output type. A name seen again later is skipped, so first-occurrence order is kept.
+DB::NamesAndTypesList collectLambdaArguments(const SerializedPlanParser & plan_parser_, const substrait::Expression_ScalarFunction & substrait_func)
+{
+    DB::NamesAndTypesList result;
+    std::unordered_set<String> seen_names;
+
+    for (const auto & argument : substrait_func.arguments())
+    {
+        const auto & value = argument.value();
+        if (!value.has_scalar_function())
+            continue;
+        const auto & variable = value.scalar_function();
+        if (plan_parser_.getFunctionSignatureName(variable.function_reference()) != "namedlambdavariable")
+            continue;
+
+        String variable_name = SerializedPlanParser::parseLiteral(variable.arguments()[0].value().literal()).second.get<String>();
+        if (!seen_names.insert(variable_name).second)
+            continue;
+        result.emplace_back(variable_name, TypeParser::parseType(variable.output_type()));
+    }
+    return result;
+}
+
+/// Parses Spark `lambdafunction(body, arg1, arg2, ...)` into a ClickHouse lambda node built on
+/// `FunctionCaptureOverloadResolver`, which higher-order functions such as `arrayMap` take as an argument.
+/// Refer to `PlannerActionsVisitorImpl::visitLambda` for how to build a lambda function node.
+class LambdaFunction : public FunctionParser
+{
+public:
+    static constexpr auto name = "lambdafunction";
+    explicit LambdaFunction(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
+    ~LambdaFunction() override = default;
+
+    String getName() const override { return name; }
+protected:
+    /// The generic FunctionParser hooks below are not supported for lambdas and throw LOGICAL_ERROR;
+    /// `parse` builds the lambda node directly.
+    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for LambdaFunction");
+    }
+
+    DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(
+        const substrait::Expression_ScalarFunction & substrait_func,
+        const String & ch_func_name,
+        DB::ActionsDAGPtr & actions_dag) const override
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for LambdaFunction");
+    }
+
+    const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
+        const substrait::Expression_ScalarFunction & substrait_func,
+        const DB::ActionsDAG::Node * func_node,
+        DB::ActionsDAGPtr & actions_dag) const override
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for LambdaFunction");
+    }
+
+    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override
+    {
+        /// A lambda body may refer directly to a column outside the lambda, e.g. `arr` in
+        /// `transform(arr, x -> concat(arr, array(x)))`. Every output of the parent DAG is therefore
+        /// made an input of `lambda_actions_dag`.
+        DB::NamesAndTypesList parent_header;
+        for (const auto * output_node : actions_dag->getOutputs())
+        {
+            parent_header.emplace_back(output_node->result_name, output_node->result_type);
+        }
+        auto lambda_actions_dag = std::make_shared<DB::ActionsDAG>(parent_header);
+
+        /// The first argument is the lambda body; the following ones are the lambda arguments the body needs.
+        /// The body may contain a nested lambda that refers to an argument of this outer lambda, e.g.
+        /// `transform(number, x -> transform(letter, y -> struct(x, y)))`. The lambda arguments are
+        /// therefore added into the actions DAG before the body is parsed.
+        for (size_t i = 1; i < substrait_func.arguments().size(); ++i)
+        {
+            (void)parseExpression(lambda_actions_dag, substrait_func.arguments()[i].value());
+        }
+        const auto & substrait_lambda_body = substrait_func.arguments()[0].value();
+        const auto * lambda_body_node = parseExpression(lambda_actions_dag, substrait_lambda_body);
+        lambda_actions_dag->getOutputs().push_back(lambda_body_node);
+        lambda_actions_dag->removeUnusedActions(Names(1, lambda_body_node->result_name));
+
+        auto expression_actions_settings = DB::ExpressionActionsSettings::fromContext(getContext(), DB::CompileExpressions::yes);
+        auto lambda_actions = std::make_shared<DB::ExpressionActions>(lambda_actions_dag, expression_actions_settings);
+
+        DB::Names captured_column_names;
+        DB::Names required_column_names = lambda_actions->getRequiredColumns();
+        DB::ActionsDAG::NodeRawConstPtrs lambda_children;
+        auto lambda_function_args = collectLambdaArguments(*plan_parser, substrait_func);
+        const auto & lambda_actions_inputs = lambda_actions_dag->getInputs();
+
+        std::unordered_map<String, const DB::ActionsDAG::Node *> parent_nodes;
+        for (const auto & node : actions_dag->getNodes())
+        {
+            parent_nodes[node.result_name] = &node;
+        }
+        for (const auto & required_column_name : required_column_names)
+        {
+            /// Lambda arguments become the lambda's own parameters. Every other required column
+            /// must be captured from `actions_dag`.
+            const bool is_lambda_argument = std::any_of(
+                lambda_function_args.begin(),
+                lambda_function_args.end(),
+                [&required_column_name](const DB::NameAndTypePair & name_type) { return name_type.name == required_column_name; });
+            if (is_lambda_argument)
+                continue;
+
+            const bool is_lambda_input = std::any_of(
+                lambda_actions_inputs.begin(),
+                lambda_actions_inputs.end(),
+                [&required_column_name](const auto & node) { return node->result_name == required_column_name; });
+            if (!is_lambda_input)
+            {
+                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Required column not found: {}", required_column_name);
+            }
+            auto parent_node_it = parent_nodes.find(required_column_name);
+            if (parent_node_it == parent_nodes.end())
+            {
+                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not found column {} in actions dag:\n{}",
+                    required_column_name,
+                    actions_dag->dumpDAG());
+            }
+            /// The nodes must be the ones in `actions_dag`, otherwise `ActionsDAG::evaluatePartialResult` will fail,
+            /// because nodes may have the same name but different addresses.
+            lambda_children.push_back(parent_node_it->second);
+            captured_column_names.push_back(required_column_name);
+        }
+
+        auto function_capture = std::make_shared<DB::FunctionCaptureOverloadResolver>(
+            lambda_actions,
+            captured_column_names,
+            lambda_function_args,
+            lambda_body_node->result_type,
+            lambda_body_node->result_name);
+
+        const auto * result = &actions_dag->addFunction(function_capture, lambda_children, lambda_body_node->result_name);
+        return result;
+    }
+};
+
+static FunctionParserRegister<LambdaFunction> register_lambda_function;
+
+
+/// Parses Spark `namedlambdavariable(name)` into an input node of the given actions DAG.
+/// The variable name is read from the first (literal) argument and its type from the function's
+/// output type. An existing input with the same name is reused; otherwise a new input is added.
+class NamedLambdaVariable : public FunctionParser
+{
+public:
+    static constexpr auto name = "namedlambdavariable";
+    explicit NamedLambdaVariable(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
+    ~NamedLambdaVariable() override = default;
+
+    String getName() const override { return name; }
+protected:
+    /// The generic FunctionParser hooks below are not supported for lambda variables and throw LOGICAL_ERROR.
+    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for NamedLambdaVariable");
+    }
+
+    DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(
+        const substrait::Expression_ScalarFunction & substrait_func,
+        const String & ch_func_name,
+        DB::ActionsDAGPtr & actions_dag) const override
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for NamedLambdaVariable");
+    }
+
+    const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
+        const substrait::Expression_ScalarFunction & substrait_func,
+        const DB::ActionsDAG::Node * func_node,
+        DB::ActionsDAGPtr & actions_dag) const override
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable");
+    }
+
+    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override
+    {
+        const auto & name_literal = substrait_func.arguments()[0].value().literal();
+        auto [_, name_field] = parseLiteral(name_literal);
+        const String variable_name = name_field.get<String>();
+        const auto variable_type = TypeParser::parseType(substrait_func.output_type());
+
+        for (const auto * input_node : actions_dag->getInputs())
+        {
+            if (input_node->result_name == variable_name)
+                return input_node;
+        }
+        return &actions_dag->addInput(variable_name, variable_type);
+    }
+};
+
+static FunctionParserRegister<NamedLambdaVariable> register_named_lambda_variable;
+
+}
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.h
b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.h
new file mode 100644
index 000000000..327c72ade
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.h
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <Parser/FunctionParser.h>
+#include <Core/NamesAndTypes.h>
+namespace local_engine
+{
+/// Returns the distinct lambda variables that appear directly as arguments of `substrait_func`.
+/// An argument counts as a lambda variable when it is a scalar function whose signature name (resolved
+/// through `plan_parser_`) is `namedlambdavariable`. The variable name comes from that function's first
+/// (literal) argument and its type from its output type. A name seen again later is skipped, so
+/// first-occurrence order is kept. Nested arguments are not searched.
+DB::NamesAndTypesList collectLambdaArguments(const SerializedPlanParser & plan_parser_, const substrait::Expression_ScalarFunction & substrait_func);
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]