This is an automated email from the ASF dual-hosted git repository.
lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 995145e93 support sort_array (#6323)
995145e93 is described below
commit 995145e93cb8c930501e69156ca57211d9932d2a
Author: lgbo <[email protected]>
AuthorDate: Fri Jul 5 08:34:20 2024 +0800
support sort_array (#6323)
---
.../clickhouse/CHSparkPlanExecApi.scala | 9 +
.../Functions/SparkFunctionArraySort.cpp | 223 +++++++++++++++++----
...ionArraySort.cpp => SparkFunctionSortArray.cpp} | 10 +-
...unctionArraySort.h => SparkFunctionSortArray.h} | 14 +-
.../arrayHighOrderFunctions.cpp | 144 +++++++++++++
.../Parser/scalar_function_parser/sortArray.cpp | 4 +-
.../gluten/backendsapi/SparkPlanExecApi.scala | 9 +
.../gluten/expression/ExpressionConverter.scala | 13 ++
.../gluten/expression/ExpressionMappings.scala | 1 +
.../apache/gluten/expression/ExpressionNames.scala | 1 +
10 files changed, 371 insertions(+), 57 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index add82cbb5..f5feade88 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -876,6 +876,15 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, Seq(argument, function),
expr)
}
+ /** Transform array sort to Substrait. */
+ override def genArraySortTransformer(
+ substraitExprName: String,
+ argument: ExpressionTransformer,
+ function: ExpressionTransformer,
+ expr: ArraySort): ExpressionTransformer = {
+ GenericExpressionTransformer(substraitExprName, Seq(argument, function),
expr)
+ }
+
override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan =
generate
override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan =
generate
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
index 126b84eaa..1371ec60e 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
@@ -14,75 +14,212 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include <Functions/SparkFunctionArraySort.h>
+#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionFactory.h>
+#include <Columns/ColumnArray.h>
+#include <Columns/ColumnFunction.h>
+#include <Columns/ColumnNullable.h>
+#include <Common/Exception.h>
+#include <DataTypes/DataTypeArray.h>
+#include <DataTypes/DataTypeFunction.h>
+#include <DataTypes/DataTypeLowCardinality.h>
+#include <Poco/Logger.h>
+#include <Common/logger_useful.h>
+#include <base/sort.h>
-namespace DB
+namespace DB::ErrorCodes
{
+ extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
+ extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+ extern const int TYPE_MISMATCH;
+ extern const int ILLEGAL_COLUMN;
+}
-namespace ErrorCodes
+/// The usage of `arraySort` in CH is different from Spark's `sort_array`
function.
+/// We need to implement a custom function to sort arrays.
+namespace local_engine
{
- extern const int LOGICAL_ERROR;
-}
-namespace
+struct LambdaLess
{
+ const DB::IColumn & column;
+ DB::DataTypePtr type;
+ const DB::ColumnFunction & lambda;
+ explicit LambdaLess(const DB::IColumn & column_, DB::DataTypePtr type_,
const DB::ColumnFunction & lambda_)
+ : column(column_), type(type_), lambda(lambda_) {}
+
+ /// May not be efficient: every comparison clones the lambda and builds two one-row columns.
+ bool operator()(size_t lhs, size_t rhs) const
+ {
+ /// The column names do not seem to matter.
+ auto left_value_col = DB::ColumnWithTypeAndName(oneRowColumn(lhs),
type, "left");
+ auto right_value_col = DB::ColumnWithTypeAndName(oneRowColumn(rhs),
type, "right");
+ auto cloned_lambda = lambda.cloneResized(1);
+ auto * lambda_ = typeid_cast<DB::ColumnFunction
*>(cloned_lambda.get());
+ lambda_->appendArguments({std::move(left_value_col),
std::move(right_value_col)});
+ auto compare_res_col = lambda_->reduce();
+ DB::Field field;
+ compare_res_col.column->get(0, field);
+ return field.get<Int32>() < 0;
+ }
+private:
+ ALWAYS_INLINE DB::ColumnPtr oneRowColumn(size_t i) const
+ {
+ auto res = column.cloneEmpty();
+ res->insertFrom(column, i);
+ return std::move(res);
+ }
+};
-template <bool positive>
struct Less
{
- const IColumn & column;
+ const DB::IColumn & column;
- explicit Less(const IColumn & column_) : column(column_) { }
+ explicit Less(const DB::IColumn & column_) : column(column_) { }
bool operator()(size_t lhs, size_t rhs) const
{
- if constexpr (positive)
- /*
- Note: We use nan_direction_hint=-1 for ascending sort to make
NULL the least value.
- However, NaN is also considered the least value,
- which results in different sorting results compared to Spark
since Spark treats NaN as the greatest value.
- For now, we are temporarily ignoring this issue because cases
with NaN are rare,
- and aligning with Spark would require tricky modifications to
the CH underlying code.
- */
- return column.compareAt(lhs, rhs, column, -1) < 0;
- else
- return column.compareAt(lhs, rhs, column, -1) > 0;
+ return column.compareAt(lhs, rhs, column, 1) < 0;
}
};
-}
-
-template <bool positive>
-ColumnPtr SparkArraySortImpl<positive>::execute(
- const ColumnArray & array,
- ColumnPtr mapped,
- const ColumnWithTypeAndName * fixed_arguments [[maybe_unused]])
+class FunctionSparkArraySort : public DB::IFunction
{
- const ColumnArray::Offsets & offsets = array.getOffsets();
+public:
+ static constexpr auto name = "arraySortSpark";
+ static DB::FunctionPtr create(DB::ContextPtr /*context*/) { return
std::make_shared<FunctionSparkArraySort>(); }
- size_t size = offsets.size();
- size_t nested_size = array.getData().size();
- IColumn::Permutation permutation(nested_size);
+ bool isVariadic() const override { return true; }
+ size_t getNumberOfArguments() const override { return 0; }
+ bool isSuitableForShortCircuitArgumentsExecution(const
DB::DataTypesWithConstInfo &) const override { return true; }
- for (size_t i = 0; i < nested_size; ++i)
- permutation[i] = i;
+ bool useDefaultImplementationForNulls() const override { return false; }
+ bool useDefaultImplementationForLowCardinalityColumns() const override {
return false; }
- ColumnArray::Offset current_offset = 0;
- for (size_t i = 0; i < size; ++i)
+ void getLambdaArgumentTypes(DB::DataTypes & arguments) const override
{
- auto next_offset = offsets[i];
- ::sort(&permutation[current_offset], &permutation[next_offset],
Less<positive>(*mapped));
- current_offset = next_offset;
+ if (arguments.size() < 2)
+ throw
DB::Exception(DB::ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, "Function {}
requires as arguments a lambda function and an array", getName());
+
+ if (arguments.size() > 1)
+ {
+ const auto * lambda_function_type =
DB::checkAndGetDataType<DB::DataTypeFunction>(arguments[0].get());
+ if (!lambda_function_type ||
lambda_function_type->getArgumentTypes().size() != 2)
+ throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"First argument of function {} must be a lambda function with 2 arguments,
found {} instead.",
+ getName(), arguments[0]->getName());
+ auto array_nesteed_type =
DB::checkAndGetDataType<DB::DataTypeArray>(arguments.back().get())->getNestedType();
+ DB::DataTypes lambda_args = {array_nesteed_type,
array_nesteed_type};
+ arguments[0] = std::make_shared<DB::DataTypeFunction>(lambda_args);
+ }
}
- return ColumnArray::create(array.getData().permute(permutation, 0),
array.getOffsetsPtr());
-}
+ DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName &
arguments) const override
+ {
+ if (arguments.size() > 1)
+ {
+ const auto * lambda_function_type =
checkAndGetDataType<DB::DataTypeFunction>(arguments[0].type.get());
+ if (!lambda_function_type)
+ throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"First argument for function {} must be a function", getName());
+ }
+
+ return arguments.back().type;
+ }
+
+ DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments,
const DB::DataTypePtr &, size_t input_rows_count) const override
+ {
+ auto array_col = arguments.back().column;
+ auto array_type = arguments.back().type;
+ DB::ColumnPtr null_map = nullptr;
+ if (const auto * null_col = typeid_cast<const DB::ColumnNullable
*>(array_col.get()))
+ {
+ null_map = null_col->getNullMapColumnPtr();
+ array_col = null_col->getNestedColumnPtr();
+ array_type = typeid_cast<const DB::DataTypeNullable
*>(array_type.get())->getNestedType();
+ }
+
+ const auto * array_col_concrete =
DB::checkAndGetColumn<DB::ColumnArray>(array_col.get());
+ if (!array_col_concrete)
+ {
+ const auto * aray_col_concrete_const =
DB::checkAndGetColumnConst<DB::ColumnArray>(array_col.get());
+ if (!aray_col_concrete_const)
+ {
+ throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "Expected
array column, found {}", array_col->getName());
+ }
+ array_col =
DB::recursiveRemoveLowCardinality(aray_col_concrete_const->convertToFullColumn());
+ array_col_concrete =
DB::checkAndGetColumn<DB::ColumnArray>(array_col.get());
+ }
+ auto array_nested_type =
DB::checkAndGetDataType<DB::DataTypeArray>(array_type.get())->getNestedType();
+
+ DB::ColumnPtr sorted_array_col = nullptr;
+ if (arguments.size() > 1)
+ sorted_array_col = executeWithLambda(*array_col_concrete,
array_nested_type,
*checkAndGetColumn<DB::ColumnFunction>(arguments[0].column.get()));
+ else
+ sorted_array_col = executeWithoutLambda(*array_col_concrete);
+
+ if (null_map)
+ {
+ sorted_array_col = DB::ColumnNullable::create(sorted_array_col,
null_map);
+ }
+ return sorted_array_col;
+ }
+private:
+ static DB::ColumnPtr executeWithLambda(const DB::ColumnArray & array_col,
DB::DataTypePtr array_nested_type, const DB::ColumnFunction & lambda)
+ {
+ const auto & offsets = array_col.getOffsets();
+ auto rows = array_col.size();
+
+ size_t nested_size = array_col.getData().size();
+ DB::IColumn::Permutation permutation(nested_size);
+ for (size_t i = 0; i < nested_size; ++i)
+ permutation[i] = i;
+
+ DB::ColumnArray::Offset current_offset = 0;
+ for (size_t i = 0; i < rows; ++i)
+ {
+ auto next_offset = offsets[i];
+ ::sort(&permutation[current_offset],
+ &permutation[next_offset],
+ LambdaLess(array_col.getData(),
+ array_nested_type,
+ lambda));
+ current_offset = next_offset;
+ }
+ auto res =
DB::ColumnArray::create(array_col.getData().permute(permutation, 0),
array_col.getOffsetsPtr());
+ return res;
+ }
+
+ static DB::ColumnPtr executeWithoutLambda(const DB::ColumnArray &
array_col)
+ {
+ const auto & offsets = array_col.getOffsets();
+ auto rows = array_col.size();
+
+ size_t nested_size = array_col.getData().size();
+ DB::IColumn::Permutation permutation(nested_size);
+ for (size_t i = 0; i < nested_size; ++i)
+ permutation[i] = i;
+
+ DB::ColumnArray::Offset current_offset = 0;
+ for (size_t i = 0; i < rows; ++i)
+ {
+ auto next_offset = offsets[i];
+ ::sort(&permutation[current_offset],
+ &permutation[next_offset],
+ Less(array_col.getData()));
+ current_offset = next_offset;
+ }
+ auto res =
DB::ColumnArray::create(array_col.getData().permute(permutation, 0),
array_col.getOffsetsPtr());
+ return res;
+ }
+
+ String getName() const override
+ {
+ return name;
+ }
+
+};
REGISTER_FUNCTION(ArraySortSpark)
{
- factory.registerFunction<SparkFunctionArraySort>();
- factory.registerFunction<SparkFunctionArrayReverseSort>();
+ factory.registerFunction<FunctionSparkArraySort>();
}
-
}
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionSortArray.cpp
similarity index 91%
copy from cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
copy to cpp-ch/local-engine/Functions/SparkFunctionSortArray.cpp
index 126b84eaa..42b88fbce 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionSortArray.cpp
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include <Functions/SparkFunctionArraySort.h>
+#include <Functions/SparkFunctionSortArray.h>
#include <Functions/FunctionFactory.h>
namespace DB
@@ -54,7 +54,7 @@ struct Less
}
template <bool positive>
-ColumnPtr SparkArraySortImpl<positive>::execute(
+ColumnPtr SparkSortArrayImpl<positive>::execute(
const ColumnArray & array,
ColumnPtr mapped,
const ColumnWithTypeAndName * fixed_arguments [[maybe_unused]])
@@ -79,10 +79,10 @@ ColumnPtr SparkArraySortImpl<positive>::execute(
return ColumnArray::create(array.getData().permute(permutation, 0),
array.getOffsetsPtr());
}
-REGISTER_FUNCTION(ArraySortSpark)
+REGISTER_FUNCTION(SortArraySpark)
{
- factory.registerFunction<SparkFunctionArraySort>();
- factory.registerFunction<SparkFunctionArrayReverseSort>();
+ factory.registerFunction<SparkFunctionSortArray>();
+ factory.registerFunction<SparkFunctionReverseSortArray>();
}
}
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.h
b/cpp-ch/local-engine/Functions/SparkFunctionSortArray.h
similarity index 86%
rename from cpp-ch/local-engine/Functions/SparkFunctionArraySort.h
rename to cpp-ch/local-engine/Functions/SparkFunctionSortArray.h
index 9ce48f9c0..18c2128c0 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionSortArray.h
@@ -32,7 +32,7 @@ namespace ErrorCodes
/** Sort arrays, by values of its elements, or by values of corresponding
elements of calculated expression (known as "schwartzsort").
*/
template <bool positive>
-struct SparkArraySortImpl
+struct SparkSortArrayImpl
{
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
@@ -67,16 +67,16 @@ struct SparkArraySortImpl
const ColumnWithTypeAndName * fixed_arguments [[maybe_unused]] =
nullptr);
};
-struct NameArraySort
+struct NameSortArray
{
- static constexpr auto name = "arraySortSpark";
+ static constexpr auto name = "sortArraySpark";
};
-struct NameArrayReverseSort
+struct NameReverseSortArray
{
- static constexpr auto name = "arrayReverseSortSpark";
+ static constexpr auto name = "reverseSortArraySpark";
};
-using SparkFunctionArraySort = FunctionArrayMapped<SparkArraySortImpl<true>,
NameArraySort>;
-using SparkFunctionArrayReverseSort =
FunctionArrayMapped<SparkArraySortImpl<false>, NameArrayReverseSort>;
+using SparkFunctionSortArray = FunctionArrayMapped<SparkSortArrayImpl<true>,
NameSortArray>;
+using SparkFunctionReverseSortArray =
FunctionArrayMapped<SparkSortArrayImpl<false>, NameReverseSortArray>;
}
diff --git
a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
index 584bc0ef1..3811880ae 100644
---
a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
+++
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
@@ -151,4 +151,148 @@ public:
};
static FunctionParserRegister<ArrayAggregate> register_array_aggregate;
+class ArraySort : public FunctionParser
+{
+public:
+ static constexpr auto name = "array_sort";
+ explicit ArraySort(SerializedPlanParser * plan_parser_) :
FunctionParser(plan_parser_) {}
+ ~ArraySort() override = default;
+ String getName() const override { return name; }
+ String getCHFunctionName(const substrait::Expression_ScalarFunction &
scalar_function) const override
+ {
+ return "arraySortSpark";
+ }
+ const DB::ActionsDAG::Node * parse(const
substrait::Expression_ScalarFunction & substrait_func,
+ DB::ActionsDAGPtr & actions_dag) const
+ {
+ auto ch_func_name = getCHFunctionName(substrait_func);
+ auto parsed_args = parseFunctionArguments(substrait_func,
ch_func_name, actions_dag);
+
+ if (parsed_args.size() != 2)
+ throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "array_sort
function must have two arguments");
+ if
(isDefaultCompare(substrait_func.arguments()[1].value().scalar_function()))
+ {
+ return toFunctionNode(actions_dag, ch_func_name, {parsed_args[0]});
+ }
+
+ return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1],
parsed_args[0]});
+ }
+private:
+
+ /// The default lambda compare function for array_sort, `array_sort(x)`.
+ bool isDefaultCompare(const substrait::Expression_ScalarFunction &
scalar_function) const
+ {
+ String left_variable_name, right_variable_name;
+ auto names_types = collectLambdaArguments(*plan_parser,
scalar_function);
+ {
+ auto it = names_types.begin();
+ left_variable_name = it->name;
+ it++;
+ right_variable_name = it->name;
+ }
+
+ auto is_function = [&](const substrait::Expression & expr, const
String & function_name) {
+ return expr.has_scalar_function()
+ &&
*(plan_parser->getFunctionSignatureName(expr.scalar_function().function_reference()))
== function_name;
+ };
+
+ auto is_variable = [&](const substrait::Expression & expr, const
String & var) {
+ if (!is_function(expr, "namedlambdavariable"))
+ {
+ return false;
+ }
+ const auto var_expr =
expr.scalar_function().arguments()[0].value();
+ if (!var_expr.has_literal())
+ return false;
+ auto [_, name] = plan_parser->parseLiteral(var_expr.literal());
+ return var == name.get<String>();
+ };
+
+ auto is_int_value = [&](const substrait::Expression & expr, Int32 val)
{
+ if (!expr.has_literal())
+ return false;
+ auto [_, x] = plan_parser->parseLiteral(expr.literal());
+ return val == x.get<Int32>();
+ };
+
+ auto is_variable_null = [&](const substrait::Expression & expr, const
String & var) {
+ return is_function(expr, "is_null") &&
is_variable(expr.scalar_function().arguments(0).value(), var);
+ };
+
+ auto is_both_null = [&](const substrait::Expression & expr) {
+ return is_function(expr, "and")
+ &&
is_variable_null(expr.scalar_function().arguments(0).value(),
left_variable_name)
+ &&
is_variable_null(expr.scalar_function().arguments(1).value(),
right_variable_name);
+ };
+
+ auto is_left_greater_right = [&](const substrait::Expression & expr) {
+ if (!expr.has_if_then())
+ return false;
+
+ const auto & if_ = expr.if_then().ifs(0);
+ if (!is_function(if_.if_(), "gt"))
+ return false;
+
+ const auto & less_args = if_.if_().scalar_function().arguments();
+ return is_variable(less_args[0].value(), left_variable_name)
+ && is_variable(less_args[1].value(), right_variable_name)
+ && is_int_value(if_.then(), 1)
+ && is_int_value(expr.if_then().else_(), 0);
+ };
+
+ auto is_left_less_right = [&](const substrait::Expression & expr) {
+ if (!expr.has_if_then())
+ return false;
+
+ const auto & if_ = expr.if_then().ifs(0);
+ if (!is_function(if_.if_(), "lt"))
+ return false;
+
+ const auto & less_args = if_.if_().scalar_function().arguments();
+ return is_variable(less_args[0].value(), left_variable_name)
+ && is_variable(less_args[1].value(), right_variable_name)
+ && is_int_value(if_.then(), -1)
+ && is_left_greater_right(expr.if_then().else_());
+ };
+
+ auto is_right_null_else = [&](const substrait::Expression & expr) {
+ if (!expr.has_if_then())
+ return false;
+
+ /// if right arg is null, return -1
+ const auto & if_then = expr.if_then();
+ return is_variable_null(if_then.ifs(0).if_(), right_variable_name)
+ && is_int_value(if_then.ifs(0).then(), -1)
+ && is_left_less_right(if_then.else_());
+
+ };
+
+ auto is_left_null_else = [&](const substrait::Expression & expr) {
+ if (!expr.has_if_then())
+ return false;
+
+ /// if left arg is null, return 1
+ const auto & if_then = expr.if_then();
+ return is_variable_null(if_then.ifs(0).if_(), left_variable_name)
+ && is_int_value(if_then.ifs(0).then(), 1)
+ && is_right_null_else(if_then.else_());
+ };
+
+ auto is_if_both_null_else = [&](const substrait::Expression & expr) {
+ if (!expr.has_if_then())
+ {
+ return false;
+ }
+ const auto & if_ = expr.if_then().ifs(0);
+ return is_both_null(if_.if_())
+ && is_int_value(if_.then(), 0)
+ && is_left_null_else(expr.if_then().else_());
+ };
+
+ const auto & lambda_body = scalar_function.arguments()[0].value();
+ return is_if_both_null_else(lambda_body);
+ }
+};
+static FunctionParserRegister<ArraySort> register_array_sort;
+
}
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp
index 85416bd71..4fd2fd4f6 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp
@@ -52,8 +52,8 @@ public:
const auto * array_arg = parsed_args[0];
const auto * order_arg = parsed_args[1];
- const auto * sort_node = toFunctionNode(actions_dag, "arraySortSpark",
{array_arg});
- const auto * reverse_sort_node = toFunctionNode(actions_dag,
"arrayReverseSortSpark", {array_arg});
+ const auto * sort_node = toFunctionNode(actions_dag, "sortArraySpark",
{array_arg});
+ const auto * reverse_sort_node = toFunctionNode(actions_dag,
"reverseSortArraySpark", {array_arg});
const auto * result_node = toFunctionNode(actions_dag, "if",
{order_arg, sort_node, reverse_sort_node});
return result_node;
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index ff7449e2d..a69d41d00 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -258,6 +258,15 @@ trait SparkPlanExecApi {
throw new GlutenNotSupportException("all_match is not supported")
}
+ /** Transform array_sort to Substrait. */
+ def genArraySortTransformer(
+ substraitExprName: String,
+ argument: ExpressionTransformer,
+ function: ExpressionTransformer,
+ expr: ArraySort): ExpressionTransformer = {
+ throw new GlutenNotSupportException("array_sort(on array) is not
supported")
+ }
+
/** Transform array exists to Substrait */
def genArrayExistsTransformer(
substraitExprName: String,
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index b5bcb6876..805ff9490 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -556,6 +556,19 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
expressionsMap),
arrayTransform
)
+ case arraySort: ArraySort =>
+ BackendsApiManager.getSparkPlanExecApiInstance.genArraySortTransformer(
+ substraitExprName,
+ replaceWithExpressionTransformerInternal(
+ arraySort.argument,
+ attributeSeq,
+ expressionsMap),
+ replaceWithExpressionTransformerInternal(
+ arraySort.function,
+ attributeSeq,
+ expressionsMap),
+ arraySort
+ )
case tryEval @ TryEval(a: Add) =>
BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer(
substraitExprName,
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
index e7e9c7ffe..51e78a97e 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
@@ -248,6 +248,7 @@ object ExpressionMappings {
Sig[ArrayFilter](FILTER),
Sig[ArrayForAll](FORALL),
Sig[ArrayExists](EXISTS),
+ Sig[ArraySort](ARRAY_SORT),
Sig[Shuffle](SHUFFLE),
Sig[ZipWith](ZIP_WITH),
Sig[Flatten](FLATTEN),
diff --git
a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
index 278f11922..e3dc3a8ab 100644
---
a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
+++
b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
@@ -261,6 +261,7 @@ object ExpressionNames {
final val ARRAY_EXCEPT = "array_except"
final val ARRAY_REPEAT = "array_repeat"
final val ARRAY_REMOVE = "array_remove"
+ final val ARRAY_SORT = "array_sort"
final val ARRAYS_ZIP = "arrays_zip"
final val FILTER = "filter"
final val FORALL = "forall"
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]