This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 995145e93 support sort_array (#6323)
995145e93 is described below

commit 995145e93cb8c930501e69156ca57211d9932d2a
Author: lgbo <[email protected]>
AuthorDate: Fri Jul 5 08:34:20 2024 +0800

    support sort_array (#6323)
---
 .../clickhouse/CHSparkPlanExecApi.scala            |   9 +
 .../Functions/SparkFunctionArraySort.cpp           | 223 +++++++++++++++++----
 ...ionArraySort.cpp => SparkFunctionSortArray.cpp} |  10 +-
 ...unctionArraySort.h => SparkFunctionSortArray.h} |  14 +-
 .../arrayHighOrderFunctions.cpp                    | 144 +++++++++++++
 .../Parser/scalar_function_parser/sortArray.cpp    |   4 +-
 .../gluten/backendsapi/SparkPlanExecApi.scala      |   9 +
 .../gluten/expression/ExpressionConverter.scala    |  13 ++
 .../gluten/expression/ExpressionMappings.scala     |   1 +
 .../apache/gluten/expression/ExpressionNames.scala |   1 +
 10 files changed, 371 insertions(+), 57 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index add82cbb5..f5feade88 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -876,6 +876,15 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
     GenericExpressionTransformer(substraitExprName, Seq(argument, function), 
expr)
   }
 
+  /** Transform array sort to Substrait. */
+  override def genArraySortTransformer(
+      substraitExprName: String,
+      argument: ExpressionTransformer,
+      function: ExpressionTransformer,
+      expr: ArraySort): ExpressionTransformer = {
+    GenericExpressionTransformer(substraitExprName, Seq(argument, function), 
expr)
+  }
+
   override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = 
generate
 
   override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = 
generate
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp 
b/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
index 126b84eaa..1371ec60e 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
@@ -14,75 +14,212 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <Functions/SparkFunctionArraySort.h>
+#include <Functions/FunctionHelpers.h>
 #include <Functions/FunctionFactory.h>
+#include <Columns/ColumnArray.h>
+#include <Columns/ColumnFunction.h>
+#include <Columns/ColumnNullable.h>
+#include <Common/Exception.h>
+#include <DataTypes/DataTypeArray.h>
+#include <DataTypes/DataTypeFunction.h>
+#include <DataTypes/DataTypeLowCardinality.h>
+#include <Poco/Logger.h>
+#include <Common/logger_useful.h>
+#include <base/sort.h>
 
-namespace DB
+namespace DB::ErrorCodes
 {
+    extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
+    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+    extern const int TYPE_MISMATCH;
+    extern const int ILLEGAL_COLUMN;
+}
 
-namespace ErrorCodes
+/// The usage of `arraySort` in CH is different from Spark's `sort_array` 
function.
+/// We need to implement a custom function to sort arrays.
+namespace local_engine
 {
-    extern const int LOGICAL_ERROR;
-}
 
-namespace
+struct LambdaLess
 {
+    const DB::IColumn & column;
+    DB::DataTypePtr type;
+    const DB::ColumnFunction & lambda;
+    explicit LambdaLess(const DB::IColumn & column_, DB::DataTypePtr type_, 
const DB::ColumnFunction & lambda_)
+        : column(column_), type(type_), lambda(lambda_) {}
+
+    /// May not be efficient
+    bool operator()(size_t lhs, size_t rhs) const
+    {
+        /// The column name seems not matter.
+        auto left_value_col = DB::ColumnWithTypeAndName(oneRowColumn(lhs), 
type, "left");
+        auto right_value_col = DB::ColumnWithTypeAndName(oneRowColumn(rhs), 
type, "right");
+        auto cloned_lambda = lambda.cloneResized(1);
+        auto * lambda_ = typeid_cast<DB::ColumnFunction 
*>(cloned_lambda.get());
+        lambda_->appendArguments({std::move(left_value_col), 
std::move(right_value_col)});
+        auto compare_res_col = lambda_->reduce();
+        DB::Field field;
+        compare_res_col.column->get(0, field);
+        return field.get<Int32>() < 0;
+    }
+private:
+    ALWAYS_INLINE DB::ColumnPtr oneRowColumn(size_t i) const
+    {
+        auto res = column.cloneEmpty();
+        res->insertFrom(column, i);
+        return std::move(res);
+    }
+};
 
-template <bool positive>
 struct Less
 {
-    const IColumn & column;
+    const DB::IColumn & column;
 
-    explicit Less(const IColumn & column_) : column(column_) { }
+    explicit Less(const DB::IColumn & column_) : column(column_) { }
 
     bool operator()(size_t lhs, size_t rhs) const
     {
-        if constexpr (positive)
-            /*
-                Note: We use nan_direction_hint=-1 for ascending sort to make 
NULL the least value.
-                However, NaN is also considered the least value, 
-                which results in different sorting results compared to Spark 
since Spark treats NaN as the greatest value.
-                For now, we are temporarily ignoring this issue because cases 
with NaN are rare,
-                and aligning with Spark would require tricky modifications to 
the CH underlying code.
-            */
-            return column.compareAt(lhs, rhs, column, -1) < 0;
-        else
-            return column.compareAt(lhs, rhs, column, -1) > 0;
+        return column.compareAt(lhs, rhs, column, 1) < 0;
     }
 };
 
-}
-
-template <bool positive>
-ColumnPtr SparkArraySortImpl<positive>::execute(
-    const ColumnArray & array,
-    ColumnPtr mapped,
-    const ColumnWithTypeAndName * fixed_arguments [[maybe_unused]])
+class FunctionSparkArraySort : public DB::IFunction
 {
-    const ColumnArray::Offsets & offsets = array.getOffsets();
+public:
+    static constexpr auto name = "arraySortSpark";
+    static DB::FunctionPtr create(DB::ContextPtr /*context*/) { return 
std::make_shared<FunctionSparkArraySort>(); }
 
-    size_t size = offsets.size();
-    size_t nested_size = array.getData().size();
-    IColumn::Permutation permutation(nested_size);
+    bool isVariadic() const override { return true; }
+    size_t getNumberOfArguments() const override { return 0; }
+    bool isSuitableForShortCircuitArgumentsExecution(const 
DB::DataTypesWithConstInfo &) const override { return true; }
 
-    for (size_t i = 0; i < nested_size; ++i)
-        permutation[i] = i;
+    bool useDefaultImplementationForNulls() const override { return false; }
+    bool useDefaultImplementationForLowCardinalityColumns() const override { 
return false; }
 
-    ColumnArray::Offset current_offset = 0;
-    for (size_t i = 0; i < size; ++i)
+    void getLambdaArgumentTypes(DB::DataTypes & arguments) const override
     {
-        auto next_offset = offsets[i];
-        ::sort(&permutation[current_offset], &permutation[next_offset], 
Less<positive>(*mapped));
-        current_offset = next_offset;
+        if (arguments.size() < 2)
+            throw 
DB::Exception(DB::ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, "Function {} 
requires as arguments a lambda function and an array", getName());
+
+        if (arguments.size() > 1)
+        {
+            const auto * lambda_function_type = 
DB::checkAndGetDataType<DB::DataTypeFunction>(arguments[0].get());
+            if (!lambda_function_type || 
lambda_function_type->getArgumentTypes().size() != 2)
+                throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, 
"First argument of function {} must be a lambda function with 2 arguments, 
found {} instead.",
+                        getName(), arguments[0]->getName());
+            auto array_nesteed_type = 
DB::checkAndGetDataType<DB::DataTypeArray>(arguments.back().get())->getNestedType();
+            DB::DataTypes lambda_args = {array_nesteed_type, 
array_nesteed_type};
+            arguments[0] = std::make_shared<DB::DataTypeFunction>(lambda_args);
+        }
     }
 
-    return ColumnArray::create(array.getData().permute(permutation, 0), 
array.getOffsetsPtr());
-}
+    DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName & 
arguments) const override
+    {
+        if (arguments.size() > 1)
+        {
+            const auto * lambda_function_type = 
checkAndGetDataType<DB::DataTypeFunction>(arguments[0].type.get());
+            if (!lambda_function_type)
+                throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, 
"First argument for function {} must be a function", getName());
+        }
+
+        return arguments.back().type;
+    }
+
+    DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, 
const DB::DataTypePtr &, size_t input_rows_count) const override
+    {
+        auto array_col = arguments.back().column;
+        auto array_type = arguments.back().type;
+        DB::ColumnPtr null_map = nullptr;
+        if (const auto * null_col = typeid_cast<const DB::ColumnNullable 
*>(array_col.get()))
+        {
+            null_map = null_col->getNullMapColumnPtr();
+            array_col = null_col->getNestedColumnPtr();
+            array_type = typeid_cast<const DB::DataTypeNullable 
*>(array_type.get())->getNestedType();
+        }
+
+        const auto * array_col_concrete = 
DB::checkAndGetColumn<DB::ColumnArray>(array_col.get());
+        if (!array_col_concrete)
+        {
+            const auto * aray_col_concrete_const = 
DB::checkAndGetColumnConst<DB::ColumnArray>(array_col.get());
+            if (!aray_col_concrete_const)
+            {
+                throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "Expected 
array column, found {}", array_col->getName());
+            }
+            array_col = 
DB::recursiveRemoveLowCardinality(aray_col_concrete_const->convertToFullColumn());
+            array_col_concrete = 
DB::checkAndGetColumn<DB::ColumnArray>(array_col.get());
+        }
+        auto array_nested_type = 
DB::checkAndGetDataType<DB::DataTypeArray>(array_type.get())->getNestedType();
+
+        DB::ColumnPtr sorted_array_col = nullptr;
+        if (arguments.size() > 1)
+            sorted_array_col = executeWithLambda(*array_col_concrete, 
array_nested_type, 
*checkAndGetColumn<DB::ColumnFunction>(arguments[0].column.get()));
+        else
+            sorted_array_col = executeWithoutLambda(*array_col_concrete);
+
+        if (null_map)
+        {
+            sorted_array_col = DB::ColumnNullable::create(sorted_array_col, 
null_map);
+        }
+        return sorted_array_col;
+    }
+private:
+    static DB::ColumnPtr executeWithLambda(const DB::ColumnArray & array_col, 
DB::DataTypePtr array_nested_type, const DB::ColumnFunction & lambda)
+    {
+        const auto & offsets = array_col.getOffsets();
+        auto rows = array_col.size();
+
+        size_t nested_size = array_col.getData().size();
+        DB::IColumn::Permutation permutation(nested_size);
+        for (size_t i = 0; i < nested_size; ++i)
+            permutation[i] = i;
+
+        DB::ColumnArray::Offset current_offset = 0;
+        for (size_t i = 0; i < rows; ++i)
+        {
+            auto next_offset = offsets[i];
+            ::sort(&permutation[current_offset],
+                &permutation[next_offset],
+                LambdaLess(array_col.getData(),
+                    array_nested_type,
+                    lambda));
+            current_offset = next_offset;
+        }
+        auto res = 
DB::ColumnArray::create(array_col.getData().permute(permutation, 0), 
array_col.getOffsetsPtr());
+        return res;
+    }
+
+    static DB::ColumnPtr executeWithoutLambda(const DB::ColumnArray & 
array_col)
+    {
+        const auto & offsets = array_col.getOffsets();
+        auto rows = array_col.size();
+
+        size_t nested_size = array_col.getData().size();
+        DB::IColumn::Permutation permutation(nested_size);
+        for (size_t i = 0; i < nested_size; ++i)
+            permutation[i] = i;
+
+        DB::ColumnArray::Offset current_offset = 0;
+        for (size_t i = 0; i < rows; ++i)
+        {
+            auto next_offset = offsets[i];
+            ::sort(&permutation[current_offset],
+                    &permutation[next_offset],
+                    Less(array_col.getData()));
+            current_offset = next_offset;
+        }
+        auto res = 
DB::ColumnArray::create(array_col.getData().permute(permutation, 0), 
array_col.getOffsetsPtr());
+        return res;
+    }
+
+    String getName() const override
+    {
+        return name;
+    }
+
+};
 
 REGISTER_FUNCTION(ArraySortSpark)
 {
-    factory.registerFunction<SparkFunctionArraySort>();
-    factory.registerFunction<SparkFunctionArrayReverseSort>();
+    factory.registerFunction<FunctionSparkArraySort>();
 }
-
 }
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp 
b/cpp-ch/local-engine/Functions/SparkFunctionSortArray.cpp
similarity index 91%
copy from cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
copy to cpp-ch/local-engine/Functions/SparkFunctionSortArray.cpp
index 126b84eaa..42b88fbce 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionSortArray.cpp
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <Functions/SparkFunctionArraySort.h>
+#include <Functions/SparkFunctionSortArray.h>
 #include <Functions/FunctionFactory.h>
 
 namespace DB
@@ -54,7 +54,7 @@ struct Less
 }
 
 template <bool positive>
-ColumnPtr SparkArraySortImpl<positive>::execute(
+ColumnPtr SparkSortArrayImpl<positive>::execute(
     const ColumnArray & array,
     ColumnPtr mapped,
     const ColumnWithTypeAndName * fixed_arguments [[maybe_unused]])
@@ -79,10 +79,10 @@ ColumnPtr SparkArraySortImpl<positive>::execute(
     return ColumnArray::create(array.getData().permute(permutation, 0), 
array.getOffsetsPtr());
 }
 
-REGISTER_FUNCTION(ArraySortSpark)
+REGISTER_FUNCTION(SortArraySpark)
 {
-    factory.registerFunction<SparkFunctionArraySort>();
-    factory.registerFunction<SparkFunctionArrayReverseSort>();
+    factory.registerFunction<SparkFunctionSortArray>();
+    factory.registerFunction<SparkFunctionReverseSortArray>();
 }
 
 }
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.h 
b/cpp-ch/local-engine/Functions/SparkFunctionSortArray.h
similarity index 86%
rename from cpp-ch/local-engine/Functions/SparkFunctionArraySort.h
rename to cpp-ch/local-engine/Functions/SparkFunctionSortArray.h
index 9ce48f9c0..18c2128c0 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionSortArray.h
@@ -32,7 +32,7 @@ namespace ErrorCodes
 /** Sort arrays, by values of its elements, or by values of corresponding 
elements of calculated expression (known as "schwartzsort").
   */
 template <bool positive>
-struct SparkArraySortImpl
+struct SparkSortArrayImpl
 {
     static bool needBoolean() { return false; }
     static bool needExpression() { return false; }
@@ -67,16 +67,16 @@ struct SparkArraySortImpl
         const ColumnWithTypeAndName * fixed_arguments [[maybe_unused]] = 
nullptr);
 };
 
-struct NameArraySort
+struct NameSortArray
 {
-    static constexpr auto name = "arraySortSpark";
+    static constexpr auto name = "sortArraySpark";
 };
-struct NameArrayReverseSort
+struct NameReverseSortArray
 {
-    static constexpr auto name = "arrayReverseSortSpark";
+    static constexpr auto name = "reverseSortArraySpark";
 };
 
-using SparkFunctionArraySort = FunctionArrayMapped<SparkArraySortImpl<true>, 
NameArraySort>;
-using SparkFunctionArrayReverseSort = 
FunctionArrayMapped<SparkArraySortImpl<false>, NameArrayReverseSort>;
+using SparkFunctionSortArray = FunctionArrayMapped<SparkSortArrayImpl<true>, 
NameSortArray>;
+using SparkFunctionReverseSortArray = 
FunctionArrayMapped<SparkSortArrayImpl<false>, NameReverseSortArray>;
 
 }
diff --git 
a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp 
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
index 584bc0ef1..3811880ae 100644
--- 
a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
+++ 
b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp
@@ -151,4 +151,148 @@ public:
 };
 static FunctionParserRegister<ArrayAggregate> register_array_aggregate;
 
+class ArraySort : public FunctionParser
+{
+public:
+    static constexpr auto name = "array_sort";
+    explicit ArraySort(SerializedPlanParser * plan_parser_) : 
FunctionParser(plan_parser_) {}
+    ~ArraySort() override = default;
+    String getName() const override { return name; }
+    String getCHFunctionName(const substrait::Expression_ScalarFunction & 
scalar_function) const override
+    {
+        return "arraySortSpark";
+    }
+    const DB::ActionsDAG::Node * parse(const 
substrait::Expression_ScalarFunction & substrait_func,
+        DB::ActionsDAGPtr & actions_dag) const
+    {
+        auto ch_func_name = getCHFunctionName(substrait_func);
+        auto parsed_args = parseFunctionArguments(substrait_func, 
ch_func_name, actions_dag);
+
+        if (parsed_args.size() != 2)
+            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "array_sort 
function must have two arguments");
+        if 
(isDefaultCompare(substrait_func.arguments()[1].value().scalar_function()))
+        {
+            return toFunctionNode(actions_dag, ch_func_name, {parsed_args[0]});
+        }
+
+        return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], 
parsed_args[0]});
+    }
+private:
+
+    /// The default lambda compare function for array_sort, `array_sort(x)`.
+    bool isDefaultCompare(const substrait::Expression_ScalarFunction & 
scalar_function) const
+    {
+        String left_variable_name, right_variable_name;
+        auto names_types = collectLambdaArguments(*plan_parser, 
scalar_function);
+        {
+            auto it = names_types.begin();
+            left_variable_name = it->name;
+            it++;
+            right_variable_name = it->name;
+        }
+
+        auto is_function = [&](const substrait::Expression & expr, const 
String & function_name) {
+            return expr.has_scalar_function()
+                && 
*(plan_parser->getFunctionSignatureName(expr.scalar_function().function_reference()))
 == function_name;
+        };
+
+        auto is_variable = [&](const substrait::Expression & expr, const 
String & var) {
+            if (!is_function(expr, "namedlambdavariable"))
+            {
+                return false;
+            }
+            const auto var_expr = 
expr.scalar_function().arguments()[0].value();
+            if (!var_expr.has_literal())
+                return false;
+            auto [_, name] = plan_parser->parseLiteral(var_expr.literal());
+            return var == name.get<String>();
+        };
+
+        auto is_int_value = [&](const substrait::Expression & expr, Int32 val) 
{
+            if (!expr.has_literal())
+                return false;
+            auto [_, x] = plan_parser->parseLiteral(expr.literal());
+            return val == x.get<Int32>();
+        };
+
+        auto is_variable_null = [&](const substrait::Expression & expr, const 
String & var) {
+            return is_function(expr, "is_null") && 
is_variable(expr.scalar_function().arguments(0).value(), var);
+        };
+
+        auto is_both_null = [&](const substrait::Expression & expr) {
+            return is_function(expr, "and")
+                && 
is_variable_null(expr.scalar_function().arguments(0).value(), 
left_variable_name)
+                && 
is_variable_null(expr.scalar_function().arguments(1).value(), 
right_variable_name);
+        };
+
+        auto is_left_greater_right = [&](const substrait::Expression & expr) {
+            if (!expr.has_if_then())
+                return false;
+
+            const auto & if_ = expr.if_then().ifs(0);
+            if (!is_function(if_.if_(), "gt"))
+                return false;
+
+            const auto & less_args = if_.if_().scalar_function().arguments();
+            return is_variable(less_args[0].value(), left_variable_name)
+                && is_variable(less_args[1].value(), right_variable_name)
+                && is_int_value(if_.then(), 1)
+                && is_int_value(expr.if_then().else_(), 0);
+        };
+
+        auto is_left_less_right = [&](const substrait::Expression & expr) {
+            if (!expr.has_if_then())
+                return false;
+
+            const auto & if_ = expr.if_then().ifs(0);
+            if (!is_function(if_.if_(), "lt"))
+                return false;
+
+            const auto & less_args = if_.if_().scalar_function().arguments();
+            return is_variable(less_args[0].value(), left_variable_name)
+                && is_variable(less_args[1].value(), right_variable_name)
+                && is_int_value(if_.then(), -1)
+                && is_left_greater_right(expr.if_then().else_());
+        };
+
+        auto is_right_null_else = [&](const substrait::Expression & expr) {
+            if (!expr.has_if_then())
+                return false;
+
+            /// if right arg is null, return -1
+            const auto & if_then = expr.if_then();
+            return is_variable_null(if_then.ifs(0).if_(), right_variable_name)
+                && is_int_value(if_then.ifs(0).then(), -1)
+                && is_left_less_right(if_then.else_());
+
+        };
+
+        auto is_left_null_else = [&](const substrait::Expression & expr) {
+            if (!expr.has_if_then())
+                return false;
+
+            /// if left arg is null, return 1
+            const auto & if_then = expr.if_then();
+            return is_variable_null(if_then.ifs(0).if_(), left_variable_name)
+                && is_int_value(if_then.ifs(0).then(), 1)
+                && is_right_null_else(if_then.else_());
+        };
+
+        auto is_if_both_null_else = [&](const substrait::Expression & expr) {
+            if (!expr.has_if_then())
+            {
+                return false;
+            }
+            const auto & if_ = expr.if_then().ifs(0);
+            return is_both_null(if_.if_())
+                && is_int_value(if_.then(), 0)
+                && is_left_null_else(expr.if_then().else_());
+        };
+
+        const auto & lambda_body = scalar_function.arguments()[0].value();
+        return is_if_both_null_else(lambda_body);
+    }
+};
+static FunctionParserRegister<ArraySort> register_array_sort;
+
 }
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp 
b/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp
index 85416bd71..4fd2fd4f6 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp
@@ -52,8 +52,8 @@ public:
         const auto * array_arg = parsed_args[0];
         const auto * order_arg = parsed_args[1];
 
-        const auto * sort_node = toFunctionNode(actions_dag, "arraySortSpark", 
{array_arg});
-        const auto * reverse_sort_node = toFunctionNode(actions_dag, 
"arrayReverseSortSpark", {array_arg});
+        const auto * sort_node = toFunctionNode(actions_dag, "sortArraySpark", 
{array_arg});
+        const auto * reverse_sort_node = toFunctionNode(actions_dag, 
"reverseSortArraySpark", {array_arg});
 
         const auto * result_node = toFunctionNode(actions_dag, "if", 
{order_arg, sort_node, reverse_sort_node});
         return result_node;
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index ff7449e2d..a69d41d00 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -258,6 +258,15 @@ trait SparkPlanExecApi {
     throw new GlutenNotSupportException("all_match is not supported")
   }
 
+  /** Transform array_sort to Substrait. */
+  def genArraySortTransformer(
+      substraitExprName: String,
+      argument: ExpressionTransformer,
+      function: ExpressionTransformer,
+      expr: ArraySort): ExpressionTransformer = {
+    throw new GlutenNotSupportException("array_sort(on array) is not 
supported")
+  }
+
   /** Transform array exists to Substrait */
   def genArrayExistsTransformer(
       substraitExprName: String,
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index b5bcb6876..805ff9490 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -556,6 +556,19 @@ object ExpressionConverter extends SQLConfHelper with 
Logging {
             expressionsMap),
           arrayTransform
         )
+      case arraySort: ArraySort =>
+        BackendsApiManager.getSparkPlanExecApiInstance.genArraySortTransformer(
+          substraitExprName,
+          replaceWithExpressionTransformerInternal(
+            arraySort.argument,
+            attributeSeq,
+            expressionsMap),
+          replaceWithExpressionTransformerInternal(
+            arraySort.function,
+            attributeSeq,
+            expressionsMap),
+          arraySort
+        )
       case tryEval @ TryEval(a: Add) =>
         
BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer(
           substraitExprName,
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
index e7e9c7ffe..51e78a97e 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
@@ -248,6 +248,7 @@ object ExpressionMappings {
     Sig[ArrayFilter](FILTER),
     Sig[ArrayForAll](FORALL),
     Sig[ArrayExists](EXISTS),
+    Sig[ArraySort](ARRAY_SORT),
     Sig[Shuffle](SHUFFLE),
     Sig[ZipWith](ZIP_WITH),
     Sig[Flatten](FLATTEN),
diff --git 
a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
 
b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
index 278f11922..e3dc3a8ab 100644
--- 
a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
+++ 
b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
@@ -261,6 +261,7 @@ object ExpressionNames {
   final val ARRAY_EXCEPT = "array_except"
   final val ARRAY_REPEAT = "array_repeat"
   final val ARRAY_REMOVE = "array_remove"
+  final val ARRAY_SORT = "array_sort"
   final val ARRAYS_ZIP = "arrays_zip"
   final val FILTER = "filter"
   final val FORALL = "forall"


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to