This is an automated email from the ASF dual-hosted git repository.
liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new de51f7cb2 [GLUTEN-4724][CH] Support function array_except (#7039)
de51f7cb2 is described below
commit de51f7cb2651120567c5d2ce71fdac1665983011
Author: 李扬 <[email protected]>
AuthorDate: Wed Aug 28 15:48:44 2024 +0800
[GLUTEN-4724][CH] Support function array_except (#7039)
What changes were proposed in this pull request?
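Add a ClickHouse-backend function parser for Spark's array_except (new file arrayExcept.cpp), remove the ARRAY_EXCEPT -> DefaultValidator() entry from CHExpressionUtil, and add a test to GlutenFunctionValidateSuite.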
(Fixes: #4724)
How was this patch tested?
Newly added unit tests.
---
.../org/apache/gluten/utils/CHExpressionUtil.scala | 1 -
.../execution/GlutenFunctionValidateSuite.scala | 8 ++
.../Parser/scalar_function_parser/arrayExcept.cpp | 108 +++++++++++++++++++++
3 files changed, 116 insertions(+), 1 deletion(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
index ae072b0fb..bb4710ef2 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
@@ -195,7 +195,6 @@ object CHExpressionUtil {
DATE_FORMAT -> DateFormatClassValidator(),
DECODE -> EncodeDecodeValidator(),
ENCODE -> EncodeDecodeValidator(),
- ARRAY_EXCEPT -> DefaultValidator(),
ARRAY_REPEAT -> DefaultValidator(),
ARRAY_REMOVE -> DefaultValidator(),
ARRAYS_ZIP -> DefaultValidator(),
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index 45485ac90..1278264b4 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -755,4 +755,12 @@ class GlutenFunctionValidateSuite extends
GlutenClickHouseWholeStageTransformerS
|""".stripMargin
runQueryAndCompare(sql)(checkGlutenOperatorMatch[ProjectExecTransformer])
}
+
+ test("test function array_except") { // one array_except query over RANGE(10); the two arrays overlap at id+2
+ val sql = """
+ |SELECT array_except(array(id, id+1, id+2), array(id+2, id+3))
+ |FROM RANGE(10)
+ |""".stripMargin
+ runQueryAndCompare(sql)(checkGlutenOperatorMatch[ProjectExecTransformer]) // presumably compares results against vanilla Spark and checks the plan contains ProjectExecTransformer — confirm
+ }
}
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp
new file mode 100644
index 000000000..e90fd4070
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayExcept.cpp
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <DataTypes/DataTypeArray.h>
+#include <Functions/FunctionsMiscellaneous.h>
+#include <Parser/FunctionParser.h>
+#include <Common/Exception.h>
+#include <Common/assert_cast.h>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+extern const int SIZES_OF_COLUMNS_DOESNT_MATCH; // declared here, defined elsewhere; thrown by FunctionParserArrayExcept::parse on a wrong argument count
+}; // NOTE(review): the ';' after a namespace's closing brace is an empty declaration and can be dropped
+};
+
+namespace local_engine
+{
+class FunctionParserArrayExcept : public FunctionParser // Function parser whose parse() builds ActionsDAG nodes for Spark's array_except(arr1, arr2)
+{
+public:
+ FunctionParserArrayExcept(SerializedPlanParser * plan_parser_) :
FunctionParser(plan_parser_) { } // forwards the plan parser to the FunctionParser base
+ ~FunctionParserArrayExcept() override = default;
+
+ static constexpr auto name = "array_except"; // returned by getName()
+ String getName() const override { return name; }
+
+ const DB::ActionsDAG::Node *
+ parse(const substrait::Expression_ScalarFunction & substrait_func,
DB::ActionsDAG & actions_dag) const override // adds nodes to actions_dag and returns the node holding the result
+ {
+ auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); // presumably one DAG node per function argument — confirm against FunctionParser
+ if (parsed_args.size() != 2)
+ throw Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH,
"Function {} requires exactly two arguments", getName()); // NOTE(review): the error code name refers to column sizes, but the failure is a wrong argument count — consider a more specific code
+
+ /// Parse spark array_except(arr1, arr2)
+ /// if (arr1 == null || arr2 == null)
+ /// return null
+ /// else
+ /// return arrayDistinct(arrayFilter(x -> !has(assumeNotNull(arr2),
x), assumeNotNull(arr1)))
+ const auto * arr1_arg = parsed_args[0]; // array whose elements are kept
+ const auto * arr2_arg = parsed_args[1]; // array whose elements are excluded
+ const auto * arr1_not_null = toFunctionNode(actions_dag,
"assumeNotNull", {arr1_arg}); // arr1 wrapped in assumeNotNull; the NULL case is handled by the multiIf at the end
+ const auto * arr2_not_null = toFunctionNode(actions_dag,
"assumeNotNull", {arr2_arg}); // same for arr2
+ // std::cout << "actions_dag:" << actions_dag.dumpDAG() << std::endl;
+
+ // Create lambda function x -> !has(arr2, x)
+ ActionsDAG lambda_actions_dag; // separate DAG that holds only the lambda body
+ const auto * arr2_in_lambda =
&lambda_actions_dag.addInput(arr2_not_null->result_name,
arr2_not_null->result_type); // lambda-DAG input with the same name and type as the outer arr2_not_null node
+ const auto & nested_type = assert_cast<const DataTypeArray
&>(*removeNullable(arr1_not_null->result_type)).getNestedType(); // element type of arr1's array type (Nullable wrapper removed first)
+ const auto * x_in_lambda = &lambda_actions_dag.addInput("x",
nested_type); // lambda parameter x, typed as an element of arr1
+ const auto * has_in_lambda = toFunctionNode(lambda_actions_dag, "has",
{arr2_in_lambda, x_in_lambda}); // has(arr2, x)
+ const auto * lambda_output = toFunctionNode(lambda_actions_dag, "not",
{has_in_lambda}); // not(has(arr2, x))
+ lambda_actions_dag.getOutputs().push_back(lambda_output); // the negated membership test is the lambda's only output
+ lambda_actions_dag.removeUnusedActions(Names(1,
lambda_output->result_name)); // called with that single output name; presumably prunes nodes not needed for it — confirm
+
+ auto expression_actions_settings =
DB::ExpressionActionsSettings::fromContext(getContext(),
DB::CompileExpressions::yes);
+ auto lambda_actions =
std::make_shared<DB::ExpressionActions>(std::move(lambda_actions_dag),
expression_actions_settings); // ExpressionActions built from the moved lambda DAG; lambda_actions_dag must not be used after this point
+
+ DB::Names captured_column_names{arr2_in_lambda->result_name}; // columns captured from the outer DAG: just arr2
+ NamesAndTypesList lambda_arguments_names_and_types;
+
lambda_arguments_names_and_types.emplace_back(x_in_lambda->result_name,
x_in_lambda->result_type); // the lambda's explicit argument list: just x
+ DB::Names required_column_names = lambda_actions->getRequiredColumns(); // NOTE(review): required_column_names is not referenced again in this block
+ auto function_capture =
std::make_shared<FunctionCaptureOverloadResolver>(
+ lambda_actions,
+ captured_column_names,
+ lambda_arguments_names_and_types,
+ lambda_output->result_type,
+ lambda_output->result_name); // capture resolver built from the lambda actions, captures, arguments, and result type/name
+ const auto * lambda_function =
&actions_dag.addFunction(function_capture, {arr2_not_null},
lambda_output->result_name); // capture node in the outer DAG; arr2_not_null supplies the captured column
+
+ // Apply arrayFilter with the lambda function
+ const auto * array_filter_node = toFunctionNode(actions_dag,
"arrayFilter", {lambda_function, arr1_not_null});
+
+ // Apply arrayDistinct to the result of arrayFilter
+ const auto * array_distinct_node = toFunctionNode(actions_dag,
"arrayDistinct", {array_filter_node});
+
+ /// Return null if any of arr1 or arr2 is null
+ const auto * arr1_is_null_node = toFunctionNode(actions_dag, "isNull",
{arr1_arg}); // isNull is applied to the original arguments, not the assumeNotNull-wrapped ones
+ const auto * arr2_is_null_node = toFunctionNode(actions_dag, "isNull",
{arr2_arg});
+ const auto * null_array_node
+ = addColumnToActionsDAG(actions_dag,
std::make_shared<DataTypeNullable>(array_distinct_node->result_type), {}); // column node of type Nullable(<arrayDistinct result type>) built from an empty field — presumably a NULL constant; confirm against addColumnToActionsDAG
+ const auto * multi_if_node = toFunctionNode(actions_dag, "multiIf", {
+ arr1_is_null_node,
+ null_array_node,
+ arr2_is_null_node,
+ null_array_node,
+ array_distinct_node,
+ }); // arr1 IS NULL -> null, arr2 IS NULL -> null, otherwise the distinct filtered array
+ return convertNodeTypeIfNeeded(substrait_func, multi_if_node,
actions_dag); // presumably casts to the output type declared in substrait_func when it differs — confirm
+ }
+};
+
+static FunctionParserRegister<FunctionParserArrayExcept> register_array_except; // file-scope static object; presumably its constructor registers the parser at load time — confirm against FunctionParserRegister
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]