This is an automated email from the ASF dual-hosted git repository.

liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new de51f7cb2 [GLUTEN-4724][CH] Support function array_except (#7039)
de51f7cb2 is described below

commit de51f7cb2651120567c5d2ce71fdac1665983011
Author: 李扬 <[email protected]>
AuthorDate: Wed Aug 28 15:48:44 2024 +0800

    [GLUTEN-4724][CH] Support function array_except (#7039)
    
    What changes were proposed in this pull request?
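    Add a ClickHouse-backend function parser (FunctionParserArrayExcept) for Spark's array_except, and remove the ARRAY_EXCEPT -> DefaultValidator() entry from CHExpressionUtil.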
    
    (Fixes: #4724)
    
    How was this patch tested?
    Newly added unit tests (GlutenFunctionValidateSuite: "test function array_except").
---
 .../org/apache/gluten/utils/CHExpressionUtil.scala |   1 -
 .../execution/GlutenFunctionValidateSuite.scala    |   8 ++
 .../Parser/scalar_function_parser/arrayExcept.cpp  | 108 +++++++++++++++++++++
 3 files changed, 116 insertions(+), 1 deletion(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
index ae072b0fb..bb4710ef2 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
@@ -195,7 +195,6 @@ object CHExpressionUtil {
     DATE_FORMAT -> DateFormatClassValidator(),
     DECODE -> EncodeDecodeValidator(),
     ENCODE -> EncodeDecodeValidator(),
-    ARRAY_EXCEPT -> DefaultValidator(),
     ARRAY_REPEAT -> DefaultValidator(),
     ARRAY_REMOVE -> DefaultValidator(),
     ARRAYS_ZIP -> DefaultValidator(),
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index 45485ac90..1278264b4 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -755,4 +755,12 @@ class GlutenFunctionValidateSuite extends 
GlutenClickHouseWholeStageTransformerS
                 |""".stripMargin
     runQueryAndCompare(sql)(checkGlutenOperatorMatch[ProjectExecTransformer])
   }
+
+  test("test function array_except") { // Runs array_except on two arrays derived from RANGE(10) ids; both arrays contain id+2.
+    val sql = """
+                |SELECT array_except(array(id, id+1, id+2), array(id+2, id+3))
+                |FROM RANGE(10)
+                |""".stripMargin
+    runQueryAndCompare(sql)(checkGlutenOperatorMatch[ProjectExecTransformer]) // NOTE(review): presumably compares against vanilla Spark results and checks that the plan contains a ProjectExecTransformer — confirm against the helper definitions.
+  }
 }
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp 
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp
new file mode 100644
index 000000000..e90fd4070
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <DataTypes/DataTypeArray.h>
+#include <Functions/FunctionsMiscellaneous.h>
+#include <Parser/FunctionParser.h>
+#include <Common/Exception.h>
+#include <Common/assert_cast.h>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+extern const int SIZES_OF_COLUMNS_DOESNT_MATCH; // Declared only (extern); passed to Exception by the array_except parser when the argument count is not 2.
+};
+};
+
+namespace local_engine
+{
+class FunctionParserArrayExcept : public FunctionParser // Translates Spark's array_except(arr1, arr2) from a Substrait scalar function into ActionsDAG nodes.
+{
+public:
+    FunctionParserArrayExcept(SerializedPlanParser * plan_parser_) : 
FunctionParser(plan_parser_) { }
+    ~FunctionParserArrayExcept() override = default;
+
+    static constexpr auto name = "array_except"; // Value returned by getName().
+    String getName() const override { return name; }
+    /// Adds the nodes implementing array_except to actions_dag and returns the resulting node; throws unless exactly two arguments are parsed.
+    const DB::ActionsDAG::Node *
+    parse(const substrait::Expression_ScalarFunction & substrait_func, 
DB::ActionsDAG & actions_dag) const override
+    {
+        auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
+        if (parsed_args.size() != 2) // NOTE(review): reported with SIZES_OF_COLUMNS_DOESNT_MATCH; an argument-count error code may fit better — confirm.
+            throw Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, 
"Function {} requires exactly two arguments", getName());
+
+        /// Spark array_except(arr1, arr2) is expressed with ClickHouse functions as:
+        /// if (arr1 == null || arr2 == null)
+        ///    return null
+        /// else
+        ///    return arrayDistinct(arrayFilter(x -> !has(assumeNotNull(arr2), 
x), assumeNotNull(arr1)))
+        const auto * arr1_arg = parsed_args[0];
+        const auto * arr2_arg = parsed_args[1];
+        const auto * arr1_not_null = toFunctionNode(actions_dag, 
"assumeNotNull", {arr1_arg});
+        const auto * arr2_not_null = toFunctionNode(actions_dag, 
"assumeNotNull", {arr2_arg});
+        // The assumeNotNull results feed the filter below; NULL inputs are handled separately by the multiIf at the end.
+
+        // Build the lambda x -> not(has(arr2, x)) in its own ActionsDAG: arr2 is an input to be captured, x has arr1's element type.
+        ActionsDAG lambda_actions_dag;
+        const auto * arr2_in_lambda = 
&lambda_actions_dag.addInput(arr2_not_null->result_name, 
arr2_not_null->result_type);
+        const auto & nested_type = assert_cast<const DataTypeArray 
&>(*removeNullable(arr1_not_null->result_type)).getNestedType();
+        const auto * x_in_lambda = &lambda_actions_dag.addInput("x", 
nested_type);
+        const auto * has_in_lambda = toFunctionNode(lambda_actions_dag, "has", 
{arr2_in_lambda, x_in_lambda});
+        const auto * lambda_output = toFunctionNode(lambda_actions_dag, "not", 
{has_in_lambda});
+        lambda_actions_dag.getOutputs().push_back(lambda_output);
+        lambda_actions_dag.removeUnusedActions(Names(1, 
lambda_output->result_name));
+
+        auto expression_actions_settings = 
DB::ExpressionActionsSettings::fromContext(getContext(), 
DB::CompileExpressions::yes);
+        auto lambda_actions = 
std::make_shared<DB::ExpressionActions>(std::move(lambda_actions_dag), 
expression_actions_settings);
+        // Wrap the lambda in a function-capture node of the outer DAG: arr2 is the captured column, x the lambda argument.
+        DB::Names captured_column_names{arr2_in_lambda->result_name};
+        NamesAndTypesList lambda_arguments_names_and_types;
+        
lambda_arguments_names_and_types.emplace_back(x_in_lambda->result_name, 
x_in_lambda->result_type);
+        DB::Names required_column_names = lambda_actions->getRequiredColumns(); // NOTE(review): never read afterwards in this function — looks unused.
+        auto function_capture = 
std::make_shared<FunctionCaptureOverloadResolver>(
+            lambda_actions,
+            captured_column_names,
+            lambda_arguments_names_and_types,
+            lambda_output->result_type,
+            lambda_output->result_name);
+        const auto * lambda_function = 
&actions_dag.addFunction(function_capture, {arr2_not_null}, 
lambda_output->result_name);
+
+        // arrayFilter(lambda, arr1): keep the elements of arr1 for which the lambda holds, i.e. those not found in arr2
+        const auto * array_filter_node = toFunctionNode(actions_dag, 
"arrayFilter", {lambda_function, arr1_not_null});
+
+        // arrayDistinct on the filtered array, so the result carries no duplicate elements
+        const auto * array_distinct_node = toFunctionNode(actions_dag, 
"arrayDistinct", {array_filter_node});
+
+        /// Final result: NULL when arr1 or arr2 is NULL, otherwise the distinct array; null_array_node is a Nullable-typed column built from `{}` (presumably a NULL value — confirm).
+        const auto * arr1_is_null_node = toFunctionNode(actions_dag, "isNull", 
{arr1_arg});
+        const auto * arr2_is_null_node = toFunctionNode(actions_dag, "isNull", 
{arr2_arg});
+        const auto * null_array_node
+            = addColumnToActionsDAG(actions_dag, 
std::make_shared<DataTypeNullable>(array_distinct_node->result_type), {});
+        const auto * multi_if_node = toFunctionNode(actions_dag, "multiIf", {
+            arr1_is_null_node, // if arr1 is NULL
+            null_array_node, //     -> NULL
+            arr2_is_null_node, // else if arr2 is NULL
+            null_array_node, //     -> NULL
+            array_distinct_node, // else -> filtered, de-duplicated array
+        });
+        return convertNodeTypeIfNeeded(substrait_func, multi_if_node, 
actions_dag);
+    }
+};
+
+static FunctionParserRegister<FunctionParserArrayExcept> register_array_except; // File-local static object; presumably its construction registers the parser — confirm in FunctionParserRegister.
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to