This is an automated email from the ASF dual-hosted git repository.

liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new b0d836bac [GLUTEN-6159][CH] Support array functions with lambda 
functions (#6248)
b0d836bac is described below

commit b0d836bac9f9a5e4bba2dbd3de861d2abac80d8d
Author: lgbo <[email protected]>
AuthorDate: Tue Jul 2 14:31:16 2024 +0800

    [GLUTEN-6159][CH] Support array functions with lambda functions (#6248)
    
    What changes were proposed in this pull request?
    (Please fill in changes proposed in this fix)
    
    Fixes: #6159
    
    Support the following array functions:
    
    filter
    transform
    aggregate
    How was this patch tested?
    (Please explain how this patch was tested. E.g. unit tests, integration 
tests, manual tests)
    
    unit tests
    
    (If this patch involves UI changes, please attach a screenshot; otherwise, 
remove this)
---
 .../clickhouse/CHSparkPlanExecApi.scala            |  18 ++
 .../execution/GlutenFunctionValidateSuite.scala    |  17 ++
 .../local-engine/Parser/SerializedPlanParser.cpp   |  13 +-
 cpp-ch/local-engine/Parser/SerializedPlanParser.h  |   4 +-
 .../arrayHighOrderFunctions.cpp                    | 154 ++++++++++++++
 .../scalar_function_parser/lambdaFunction.cpp      | 222 +++++++++++++++++++++
 .../Parser/scalar_function_parser/lambdaFunction.h |  23 +++
 7 files changed, 448 insertions(+), 3 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 7ed333aec..c0dee707e 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -842,6 +842,24 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
     CHGenerateExecTransformer(generator, requiredChildOutput, outer, 
generatorOutput, child)
   }
 
+  /**
+   * Transform array filter to Substrait. The `argument` and `function` transformers are passed,
+   * in that order, as the children of a generic expression transformer named `substraitExprName`.
+   */
+  override def genArrayFilterTransformer(
+      substraitExprName: String,
+      argument: ExpressionTransformer,
+      function: ExpressionTransformer,
+      expr: ArrayFilter): ExpressionTransformer = {
+    GenericExpressionTransformer(substraitExprName, Seq(argument, function), 
expr)
+  }
+
+  /**
+   * Transform array transform to Substrait. The `argument` and `function` transformers are
+   * passed, in that order, as the children of a generic expression transformer named
+   * `substraitExprName`.
+   */
+  override def genArrayTransformTransformer(
+      substraitExprName: String,
+      argument: ExpressionTransformer,
+      function: ExpressionTransformer,
+      expr: ArrayTransform): ExpressionTransformer = {
+    GenericExpressionTransformer(substraitExprName, Seq(argument, function), 
expr)
+  }
+
   override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = 
generate
 
   override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = 
generate
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index 9327137fa..d3e3e9446 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -713,4 +713,21 @@ class GlutenFunctionValidateSuite extends 
GlutenClickHouseWholeStageTransformerS
     }
 
   }
+
+  // Runs transform, filter and aggregate with lambda arguments through runQueryAndCompare.
+  // The inserted rows cover null elements, a null array and an empty array.
+  // NOTE(review): checkGlutenOperatorMatch[ProjectExecTransformer] presumably asserts that the
+  // projection was offloaded - confirm against the helper's definition.
+  test("array functions with lambda") {
+    withTable("tb_array") {
+      sql("create table tb_array(ids array<int>) using parquet")
+      sql("""
+            |insert into tb_array values (array(1,5,2,null, 3)), 
(array(1,1,3,2)), (null), (array())
+            |""".stripMargin)
+      val transform_sql = "select transform(ids, x -> x + 1) from tb_array"
+      
runQueryAndCompare(transform_sql)(checkGlutenOperatorMatch[ProjectExecTransformer])
+
+      val filter_sql = "select filter(ids, x -> x % 2 == 1) from tb_array";
+      
runQueryAndCompare(filter_sql)(checkGlutenOperatorMatch[ProjectExecTransformer])
+
+      val aggregate_sql = "select ids, aggregate(ids, 3, (acc, x) -> acc + x) 
from tb_array";
+      
runQueryAndCompare(aggregate_sql)(checkGlutenOperatorMatch[ProjectExecTransformer])
+    }
+  }
 }
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp 
b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index 77819fd73..ea33dc210 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -564,6 +564,16 @@ NamesAndTypesList 
SerializedPlanParser::blockToNameAndTypeList(const Block & hea
     return types;
 }
 
+/// Looks up the function signature registered in `function_mapping` under `function_ref`
+/// and returns the part before the first ':' (the whole signature when it contains no ':').
+/// Returns an empty optional when `function_ref` is not present in the mapping.
+std::optional<String> SerializedPlanParser::getFunctionSignatureName(UInt32 function_ref) const
+{
+    const auto it = function_mapping.find(std::to_string(function_ref));
+    if (it == function_mapping.end())
+        return std::nullopt;
+    /// Bind by reference so that only the returned prefix is copied, not the whole signature.
+    const auto & function_signature = it->second;
+    return function_signature.substr(0, function_signature.find(':'));
+}
+
 std::string
 SerializedPlanParser::getFunctionName(const std::string & function_signature, 
const substrait::Expression_ScalarFunction & function)
 {
@@ -1122,8 +1132,7 @@ const ActionsDAG::Node * 
SerializedPlanParser::parseFunctionArgument(
     {
         std::string arg_name;
         bool keep_arg = FUNCTION_NEED_KEEP_ARGUMENTS.contains(function_name);
-        parseFunctionWithDAG(arg.value(), arg_name, actions_dag, keep_arg);
-        res = &actions_dag->getNodes().back();
+        res = parseFunctionWithDAG(arg.value(), arg_name, actions_dag, 
keep_arg);
     }
     else
     {
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h 
b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index ad2b0d50e..ffd414803 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -307,9 +307,12 @@ public:
     const std::unordered_map<std::string, std::string> & getFunctionMapping() 
{ return function_mapping; }
 
     static std::string getFunctionName(const std::string & function_sig, const 
substrait::Expression_ScalarFunction & function);
+    std::optional<std::string> getFunctionSignatureName(UInt32 function_ref) 
const;
 
     IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const 
std::set<String> & columns);
     IQueryPlanStep * addRollbackFilterHeaderStep(QueryPlanPtr & query_plan, 
const Block & input_header);
+    
+    static std::pair<DataTypePtr, Field> parseLiteral(const 
substrait::Expression_Literal & literal);
 
     static ContextMutablePtr global_context;
     static Context::ConfigurationPtr config;
@@ -384,7 +387,6 @@ private:
     // remove nullable after isNotNull
     void removeNullableForRequiredColumns(const std::set<String> & 
require_columns, const ActionsDAGPtr & actions_dag) const;
     std::string getUniqueName(const std::string & name) { return name + "_" + 
std::to_string(name_no++); }
-    static std::pair<DataTypePtr, Field> parseLiteral(const 
substrait::Expression_Literal & literal);
     void wrapNullable(
         const std::vector<String> & columns, ActionsDAGPtr actions_dag, 
std::map<std::string, std::string> & nullable_measure_names);
     static std::pair<DB::DataTypePtr, DB::Field> convertStructFieldType(const 
DB::DataTypePtr & type, const DB::Field & field);
diff --git 
a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp 
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
new file mode 100644
index 000000000..584bc0ef1
--- /dev/null
+++ 
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <Parser/FunctionParser.h>
+#include <Common/Exception.h>
+#include <Poco/Logger.h>
+#include <Common/logger_useful.h>
+#include <Common/CHUtil.h>
+#include <DataTypes/DataTypeFunction.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <Core/Types.h>
+#include <Parser/TypeParser.h>
+#include <Parser/scalar_function_parser/lambdaFunction.h>
+
+namespace DB::ErrorCodes
+{
+    extern const int LOGICAL_ERROR;
+}
+
+namespace local_engine
+{
+/// Parser for the Substrait `filter` function, built on ClickHouse `arrayFilter`.
+class ArrayFilter : public FunctionParser
+{
+public:
+    static constexpr auto name = "filter";
+    explicit ArrayFilter(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
+    ~ArrayFilter() override = default;
+
+    String getName() const override { return name; }
+
+    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
+    {
+        return "arrayFilter";
+    }
+
+    /// Builds `arrayFilter(lambda, array)` when the lambda has a single argument. Otherwise an
+    /// extra array `range(0, length(array))` is passed so the lambda also receives the element index.
+    /// Throws LOGICAL_ERROR unless exactly two arguments were parsed.
+    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func,
+        DB::ActionsDAGPtr & actions_dag) const override
+    {
+        auto ch_func_name = getCHFunctionName(substrait_func);
+        auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag);
+        /// Checked at runtime (not only via assert) because the arguments are indexed right below.
+        if (parsed_args.size() != 2)
+            throw DB::Exception(
+                DB::ErrorCodes::LOGICAL_ERROR, "Function {} requires exactly 2 arguments, but got {}", getName(), parsed_args.size());
+        if (collectLambdaArguments(*plan_parser, substrait_func.arguments()[1].value().scalar_function()).size() == 1)
+            return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]});
+
+        /// filter with index argument.
+        const auto * range_end_node = toFunctionNode(actions_dag, "length", {toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]})});
+        range_end_node = ActionsDAGUtil::convertNodeType(
+            actions_dag, range_end_node, "Nullable(Int32)", range_end_node->result_name);
+        const auto * index_array_node = toFunctionNode(
+            actions_dag,
+            "range",
+            {addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeInt32>(), 0), range_end_node});
+        return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0], index_array_node});
+    }
+};
+static FunctionParserRegister<ArrayFilter> register_array_filter;
+
+/// Parser for the Substrait `transform` function, built on ClickHouse `arrayMap`.
+class ArrayTransform : public FunctionParser
+{
+public:
+    static constexpr auto name = "transform";
+    explicit ArrayTransform(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
+    ~ArrayTransform() override = default;
+    String getName() const override { return name; }
+    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
+    {
+        return "arrayMap";
+    }
+
+    /// Builds `arrayMap(lambda, array)` when the lambda has a single argument. Otherwise an
+    /// extra array `range(0, length(array))` is passed so the lambda also receives the element index.
+    /// Throws LOGICAL_ERROR unless exactly two arguments were parsed.
+    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func,
+        DB::ActionsDAGPtr & actions_dag) const override
+    {
+        auto ch_func_name = getCHFunctionName(substrait_func);
+        auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag);
+        /// Validate the arity before `arguments()[1]` is indexed; an assert alone is a no-op in release builds.
+        if (parsed_args.size() != 2)
+            throw DB::Exception(
+                DB::ErrorCodes::LOGICAL_ERROR, "Function {} requires exactly 2 arguments, but got {}", getName(), parsed_args.size());
+        auto lambda_args = collectLambdaArguments(*plan_parser, substrait_func.arguments()[1].value().scalar_function());
+        if (lambda_args.size() == 1)
+        {
+            return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]});
+        }
+
+        /// transform with index argument.
+        const auto * range_end_node = toFunctionNode(actions_dag, "length", {toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]})});
+        range_end_node = ActionsDAGUtil::convertNodeType(
+            actions_dag, range_end_node, "Nullable(Int32)", range_end_node->result_name);
+        const auto * index_array_node = toFunctionNode(
+            actions_dag,
+            "range",
+            {addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeInt32>(), 0), range_end_node});
+        return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0], index_array_node});
+    }
+};
+static FunctionParserRegister<ArrayTransform> register_array_map;
+
+/// Parser for the Substrait `aggregate` function, built on ClickHouse `arrayFold`.
+class ArrayAggregate : public FunctionParser
+{
+public:
+    static constexpr auto name = "aggregate";
+    explicit ArrayAggregate(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
+    ~ArrayAggregate() override = default;
+    String getName() const override { return name; }
+    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
+    {
+        return "arrayFold";
+    }
+
+    /// Builds `if(isNull(array), NULL, arrayFold(lambda, assumeNotNull(array), init))`.
+    /// The initial value is converted to the lambda's return type when the two types differ.
+    /// Throws LOGICAL_ERROR unless exactly three arguments were parsed and the third is a lambda.
+    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func,
+        DB::ActionsDAGPtr & actions_dag) const override
+    {
+        auto ch_func_name = getCHFunctionName(substrait_func);
+        auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag);
+        /// Checked at runtime (not only via assert) because `parsed_args[2]` is dereferenced right below.
+        if (parsed_args.size() != 3)
+            throw DB::Exception(
+                DB::ErrorCodes::LOGICAL_ERROR, "Function {} requires exactly 3 arguments, but got {}", getName(), parsed_args.size());
+        const auto * function_type = typeid_cast<const DataTypeFunction *>(parsed_args[2]->result_type.get());
+        if (!function_type)
+            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The third argument of aggregate function must be a lambda function");
+        if (!parsed_args[1]->result_type->equals(*(function_type->getReturnType())))
+        {
+            parsed_args[1] = ActionsDAGUtil::convertNodeType(
+                actions_dag,
+                parsed_args[1],
+                function_type->getReturnType()->getName(),
+                parsed_args[1]->result_name);
+        }
+
+        /// arrayFold cannot accept nullable(array)
+        const auto * array_col_node = parsed_args[0];
+        if (parsed_args[0]->result_type->isNullable())
+        {
+            array_col_node = toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]});
+        }
+        const auto * func_node = toFunctionNode(actions_dag, ch_func_name, {parsed_args[2], array_col_node, parsed_args[1]});
+        /// For null array, result is null.
+        /// TODO: make a new version of arrayFold that can handle nullable array.
+        const auto * is_null_node = toFunctionNode(actions_dag, "isNull", {parsed_args[0]});
+        const auto * null_node = addColumnToActionsDAG(actions_dag, DB::makeNullable(func_node->result_type), DB::Null());
+        return toFunctionNode(actions_dag, "if", {is_null_node, null_node, func_node});
+    }
+};
+static FunctionParserRegister<ArrayAggregate> register_array_aggregate;
+
+}
diff --git 
a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp 
b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp
new file mode 100644
index 000000000..57c076ed2
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <Parser/FunctionParser.h>
+#include <Parser/TypeParser.h>
+#include <Common/Exception.h>
+#include <Poco/Logger.h>
+#include <Common/logger_useful.h>
+#include <Interpreters/ExpressionActionsSettings.h>
+#include <Interpreters/ExpressionActions.h>
+#include <Functions/FunctionsMiscellaneous.h>
+#include <Common/CHUtil.h>
+#include <unordered_set>
+
+namespace DB::ErrorCodes
+{
+    extern const int LOGICAL_ERROR;
+}
+
+namespace local_engine
+{
+/// Collects the declared arguments of a Substrait `lambdafunction` call as (name, type) pairs.
+///
+/// Argument 0 of `substrait_func` is the lambda body and the remaining ones are the lambda's
+/// variables, so the scan starts at index 1. Scanning the body as well would, when the body is a
+/// bare lambda variable, put that variable first (e.g. `(acc, x) -> x` would yield `[x, acc]`) or
+/// treat a variable of an enclosing lambda as an argument of this one.
+/// Only arguments that are `namedlambdavariable` calls are collected; duplicates by name are
+/// dropped and the order of first appearance is kept.
+DB::NamesAndTypesList collectLambdaArguments(const SerializedPlanParser & plan_parser_, const substrait::Expression_ScalarFunction & substrait_func)
+{
+    DB::NamesAndTypesList lambda_arguments;
+    std::unordered_set<String> collected_names;
+
+    const auto & func_args = substrait_func.arguments();
+    for (int i = 1; i < func_args.size(); ++i)
+    {
+        const auto & arg_value = func_args[i].value();
+        if (!arg_value.has_scalar_function())
+            continue;
+        const auto & variable_func = arg_value.scalar_function();
+        if (plan_parser_.getFunctionSignatureName(variable_func.function_reference()) != "namedlambdavariable")
+            continue;
+
+        /// The first argument of `namedlambdavariable` is a literal holding the variable name.
+        auto [_, col_name_field] = SerializedPlanParser::parseLiteral(variable_func.arguments()[0].value().literal());
+        String col_name = col_name_field.get<String>();
+        if (!collected_names.insert(col_name).second)
+            continue;
+        auto type = TypeParser::parseType(variable_func.output_type());
+        lambda_arguments.emplace_back(col_name, type);
+    }
+    return lambda_arguments;
+}
+
+/// Parser for the Substrait `lambdafunction` scalar function.
+/// Refer to `PlannerActionsVisitorImpl::visitLambda` for how to build a lambda function node.
+class LambdaFunction : public FunctionParser
+{
+public:
+    static constexpr auto name = "lambdafunction";
+    explicit LambdaFunction(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
+    ~LambdaFunction() override = default;
+
+    String getName() const override { return name; }
+protected:
+    /// Not applicable to a lambda; always throws LOGICAL_ERROR.
+    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for LambdaFunction");
+    }
+
+    /// Not applicable to a lambda; always throws LOGICAL_ERROR.
+    DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(
+        const substrait::Expression_ScalarFunction & substrait_func,
+        const String & ch_func_name,
+        DB::ActionsDAGPtr & actions_dag) const override
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for LambdaFunction");
+    }
+
+    /// Not applicable to a lambda; always throws LOGICAL_ERROR.
+    const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
+        const substrait::Expression_ScalarFunction & substrait_func,
+        const DB::ActionsDAG::Node * func_node,
+        DB::ActionsDAGPtr & actions_dag) const override
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for LambdaFunction");
+    }
+
+    /// Builds a separate actions DAG for the lambda body, wraps it into a function-capture node
+    /// and adds that node to `actions_dag`. Columns required by the body that are not lambda
+    /// arguments are captured from `actions_dag`; throws LOGICAL_ERROR when such a column cannot be found.
+    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override
+    {
+        /// Some special cases, for example, `transform(arr, x -> concat(arr, array(x)))` refers to
+        /// a column `arr` from outside the lambda directly. We need `arr` as an input column for `lambda_actions_dag`.
+        DB::NamesAndTypesList parent_header;
+        for (const auto * output_node : actions_dag->getOutputs())
+        {
+            parent_header.emplace_back(output_node->result_name, output_node->result_type);
+        }
+        auto lambda_actions_dag = std::make_shared<DB::ActionsDAG>(parent_header);
+
+        /// The first argument is the lambda function body, the following ones are the lambda arguments
+        /// needed by the lambda function body.
+        /// There could be a nested lambda function in the lambda function body which refers to a variable from
+        /// this outer lambda function's arguments. For example, transform(number, x -> transform(letter, y -> struct(x, y))).
+        /// Before parsing the lambda function body, we add the lambda function arguments into the actions dag first.
+        for (size_t i = 1; i < substrait_func.arguments().size(); ++i)
+        {
+            (void)parseExpression(lambda_actions_dag, substrait_func.arguments()[i].value());
+        }
+        const auto & substrait_lambda_body = substrait_func.arguments()[0].value();
+        const auto * lambda_body_node = parseExpression(lambda_actions_dag, substrait_lambda_body);
+        lambda_actions_dag->getOutputs().push_back(lambda_body_node);
+        lambda_actions_dag->removeUnusedActions(Names(1, lambda_body_node->result_name));
+
+        auto expression_actions_settings = DB::ExpressionActionsSettings::fromContext(getContext(), DB::CompileExpressions::yes);
+        auto lambda_actions = std::make_shared<DB::ExpressionActions>(lambda_actions_dag, expression_actions_settings);
+
+        DB::Names captured_column_names;
+        DB::Names required_column_names = lambda_actions->getRequiredColumns();
+        DB::ActionsDAG::NodeRawConstPtrs lambda_children;
+        auto lambda_function_args = collectLambdaArguments(*plan_parser, substrait_func);
+        const auto & lambda_actions_inputs = lambda_actions_dag->getInputs();
+
+        std::unordered_map<String, const DB::ActionsDAG::Node *> parent_nodes;
+        for (const auto & node : actions_dag->getNodes())
+        {
+            parent_nodes[node.result_name] = &node;
+        }
+        for (const auto & required_column_name : required_column_names)
+        {
+            /// A required column that is not one of the lambda arguments must be captured from the parent DAG.
+            if (std::find_if(
+                    lambda_function_args.begin(),
+                    lambda_function_args.end(),
+                    [&required_column_name](const DB::NameAndTypePair & name_type) { return name_type.name == required_column_name; })
+                == lambda_function_args.end())
+            {
+                auto it = std::find_if(
+                    lambda_actions_inputs.begin(),
+                    lambda_actions_inputs.end(),
+                    [&required_column_name](const auto & node) { return node->result_name == required_column_name; });
+                if (it == lambda_actions_inputs.end())
+                {
+                    throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Required column not found: {}", required_column_name);
+                }
+                auto parent_node_it = parent_nodes.find(required_column_name);
+                if (parent_node_it == parent_nodes.end())
+                {
+                    throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not found column {} in actions dag:\n{}",
+                        required_column_name,
+                        actions_dag->dumpDAG());
+                }
+                /// The nodes must be the ones in `actions_dag`, otherwise `ActionsDAG::evaluatePartialResult` will fail, because nodes may have the
+                /// same name but different addresses.
+                lambda_children.push_back(parent_node_it->second);
+                captured_column_names.push_back(required_column_name);
+            }
+        }
+
+        auto function_capture = std::make_shared<DB::FunctionCaptureOverloadResolver>(
+            lambda_actions,
+            captured_column_names,
+            lambda_function_args,
+            lambda_body_node->result_type,
+            lambda_body_node->result_name);
+
+        const auto * result = &actions_dag->addFunction(function_capture, lambda_children, lambda_body_node->result_name);
+        return result;
+    }
+};
+
+static FunctionParserRegister<LambdaFunction> register_lambda_function;
+
+
+/// Parser for the Substrait `namedlambdavariable` scalar function. A lambda variable is
+/// represented as an input column of the actions DAG it is parsed into.
+class NamedLambdaVariable : public FunctionParser
+{
+public:
+    static constexpr auto name = "namedlambdavariable";
+    explicit NamedLambdaVariable(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
+    ~NamedLambdaVariable() override = default;
+
+    String getName() const override { return name; }
+protected:
+    /// Not applicable to a lambda variable; always throws LOGICAL_ERROR.
+    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for NamedLambdaVariable");
+    }
+
+    /// Not applicable to a lambda variable; always throws LOGICAL_ERROR.
+    DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(
+        const substrait::Expression_ScalarFunction & substrait_func,
+        const String & ch_func_name,
+        DB::ActionsDAGPtr & actions_dag) const override
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for NamedLambdaVariable");
+    }
+
+    /// Not applicable to a lambda variable; always throws LOGICAL_ERROR.
+    const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
+        const substrait::Expression_ScalarFunction & substrait_func,
+        const DB::ActionsDAG::Node * func_node,
+        DB::ActionsDAGPtr & actions_dag) const override
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable");
+    }
+
+    /// Returns the input node of `actions_dag` whose name equals the variable name (the literal in
+    /// the first argument). When no such input exists, a new input with the function's output type is added.
+    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override
+    {
+        const auto & name_literal = substrait_func.arguments()[0].value().literal();
+        auto [ignored_type, variable_name_field] = parseLiteral(name_literal);
+        String variable_name = variable_name_field.get<String>();
+        auto variable_type = TypeParser::parseType(substrait_func.output_type());
+
+        /// Reuse an already registered input so the same variable maps to a single node.
+        for (const auto * input_node : actions_dag->getInputs())
+        {
+            if (input_node->result_name == variable_name)
+                return input_node;
+        }
+        return &(actions_dag->addInput(variable_name, variable_type));
+    }
+};
+
+static FunctionParserRegister<NamedLambdaVariable> register_named_lambda_variable;
+
+}
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.h 
b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.h
new file mode 100644
index 000000000..327c72ade
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.h
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <Parser/FunctionParser.h>
+#include <Core/NamesAndTypes.h>
+namespace local_engine
+{
+/// Returns the lambda variables found among the arguments of `substrait_func`, i.e. the
+/// arguments that are `namedlambdavariable` calls, as (name, type) pairs. Each name appears
+/// once, in order of first appearance. `plan_parser_` is used to resolve function references.
+DB::NamesAndTypesList collectLambdaArguments(const SerializedPlanParser & 
plan_parser_, const substrait::Expression_ScalarFunction & substrait_func);
+} 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to