This is an automated email from the ASF dual-hosted git repository.

zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new de26ed2da [CH] Support flatten (#6194)
de26ed2da is described below

commit de26ed2dad41d2d1e893c8d1b3ae806385d9972f
Author: LiuNeng <[email protected]>
AuthorDate: Tue Jun 25 16:10:05 2024 +0800

    [CH] Support flatten (#6194)
    
    [CH] Support flatten
    
    Co-authored-by: liuneng1994 <[email protected]>
---
 .../org/apache/gluten/utils/CHExpressionUtil.scala |   1 -
 cpp-ch/clickhouse.version                          |   3 +-
 .../local-engine/Functions/SparkArrayFlatten.cpp   | 160 +++++++++++++++++++++
 cpp-ch/local-engine/Parser/SerializedPlanParser.h  |   1 +
 .../utils/clickhouse/ClickHouseTestSettings.scala  |   2 +-
 .../spark/sql/GlutenDataFrameFunctionsSuite.scala  |  82 +++++++++++
 .../utils/clickhouse/ClickHouseTestSettings.scala  |   2 +-
 .../spark/sql/GlutenDataFrameFunctionsSuite.scala  |  82 +++++++++++
 8 files changed, 329 insertions(+), 4 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
index cf45c1118..e9bee8439 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
@@ -209,7 +209,6 @@ object CHExpressionUtil {
     UNIX_MICROS -> DefaultValidator(),
     TIMESTAMP_MILLIS -> DefaultValidator(),
     TIMESTAMP_MICROS -> DefaultValidator(),
-    FLATTEN -> DefaultValidator(),
     STACK -> DefaultValidator()
   )
 }
diff --git a/cpp-ch/clickhouse.version b/cpp-ch/clickhouse.version
index 4a3088e54..54d0a74c5 100644
--- a/cpp-ch/clickhouse.version
+++ b/cpp-ch/clickhouse.version
@@ -1,3 +1,4 @@
 CH_ORG=Kyligence
 CH_BRANCH=rebase_ch/20240621
-CH_COMMIT=acf666c1c4f
+CH_COMMIT=c811cbb985f
+
diff --git a/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp 
b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp
new file mode 100644
index 000000000..d39bca5ea
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <Columns/ColumnArray.h>
+#include <Columns/ColumnNullable.h>
+#include <Columns/ColumnsNumber.h>
+#include <DataTypes/DataTypeArray.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <Functions/FunctionFactory.h>
+#include <Functions/FunctionHelpers.h>
+#include <Functions/IFunction.h>
+
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+    /// Error codes thrown by SparkArrayFlatten. Only declared here (extern); the definitions live elsewhere.
+    extern const int ILLEGAL_COLUMN;
+    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+}
+
+/// sparkArrayFlatten(array_of_arrays) implements Spark's `flatten`: it removes exactly ONE level of nesting.
+///   sparkArrayFlatten([[1, 2, 3], [4, 5]]) = [1, 2, 3, 4, 5]
+///   sparkArrayFlatten([[[1]], [[2, 3]]])   = [[1], [2, 3]]
+/// If a row contains a NULL inner array, the result for that row (and only that row) is NULL.
+class SparkArrayFlatten : public IFunction
+{
+public:
+    static constexpr auto name = "sparkArrayFlatten";
+
+    static FunctionPtr create(ContextPtr) { return std::make_shared<SparkArrayFlatten>(); }
+
+    String getName() const override { return name; }
+    size_t getNumberOfArguments() const override { return 1; }
+    bool useDefaultImplementationForConstants() const override { return true; }
+    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
+
+    /// Returns the element type of the outer array, i.e. Array(T) or Nullable(Array(T)).
+    /// Throws ILLEGAL_TYPE_OF_ARGUMENT if the argument is not an array, or if its elements are
+    /// neither (possibly Nullable) arrays nor Nothing.
+    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
+    {
+        if (!isArray(arguments[0]))
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}, expected Array",
+                            arguments[0]->getName(), getName());
+
+        const auto * outer_array_type = checkAndGetDataType<DataTypeArray>(arguments[0].get());
+        DataTypePtr element_type = outer_array_type->getNestedType();
+
+        /// Only one level is removed, so the elements must themselves be arrays; otherwise executeImpl
+        /// would produce a column whose type differs from the one returned here.
+        /// Nothing is accepted because executeImpl handles it explicitly.
+        const DataTypePtr inner_type = removeNullable(element_type);
+        if (!isArray(inner_type) && !isNothing(inner_type))
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}, expected Array(Array)",
+                            arguments[0]->getName(), getName());
+
+        return element_type;
+    }
+
+    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
+    {
+        /** Flatten one level: the result's data is the inner arrays' data, and the result's offsets
+          * are the inner offsets taken at the end of every outer row.
+          *
+          * Example: Array(Array(UInt8))
+          *   rows:          [[1, 2, 3], [4, 5]], [[6], [7, 8]]
+          *   outer offsets: 2 4
+          *   inner offsets: 3 5 6 8
+          *   inner data:    1 2 3 4 5 6 7 8
+          *
+          *   result offsets[i] = inner_offsets[outer_offsets[i] - 1]  ->  5 8
+          *   result rows:   [1, 2, 3, 4, 5], [6, 7, 8]
+          *
+          * Deeper nesting loses only its outermost level, e.g. [[[1], [2]], [[3]]] -> [[1], [2], [3]].
+          */
+        const ColumnArray * src_col = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
+        if (!src_col)
+            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} in argument of function {}",
+                arguments[0].column->getName(), getName());
+
+        const IColumn::Offsets & outer_offsets = src_col->getOffsets();
+        const IColumn * inner_data = &src_col->getData();
+
+        /// Inner arrays may be NULL. Spark returns NULL only for the rows containing such a NULL,
+        /// so compute a per-row null map instead of nulling the whole block.
+        ColumnUInt8::MutablePtr result_null_map;
+        if (const auto * nullable_inner = checkAndGetColumn<ColumnNullable>(inner_data))
+        {
+            const auto & inner_null_map = nullable_inner->getNullMapData();
+            result_null_map = ColumnUInt8::create(input_rows_count, 0);
+            auto & row_is_null = result_null_map->getData();
+            for (size_t row = 0; row < input_rows_count; ++row)
+            {
+                /// outer_offsets[-1] is 0 for PaddedPODArray, so row 0 starts at index 0.
+                for (size_t i = outer_offsets[row - 1]; i < outer_offsets[row]; ++i)
+                {
+                    if (inner_null_map[i])
+                    {
+                        row_is_null[row] = 1;
+                        break;
+                    }
+                }
+            }
+            /// Entries under NULL still have valid offsets in the nested column, and the affected
+            /// rows are masked by result_null_map, so the nested column can be flattened as usual.
+            inner_data = &nullable_inner->getNestedColumn();
+        }
+
+        if (isNothing(inner_data->getDataType()))
+        {
+            if (result_null_map)
+                return ColumnNullable::create(inner_data->cloneResized(input_rows_count), std::move(result_null_map));
+            return inner_data->cloneResized(input_rows_count);
+        }
+
+        const auto * inner_array = checkAndGetColumn<ColumnArray>(inner_data);
+        if (!inner_array)
+            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} in argument of function {}, expected Array(Array)",
+                arguments[0].column->getName(), getName());
+
+        const IColumn::Offsets & inner_offsets = inner_array->getOffsets();
+        auto result_offsets_column = ColumnArray::ColumnOffsets::create(input_rows_count);
+        IColumn::Offsets & result_offsets = result_offsets_column->getData();
+        for (size_t row = 0; row < input_rows_count; ++row)
+            result_offsets[row] = inner_offsets[outer_offsets[row] - 1]; /// -1 array subscript is Ok, see PaddedPODArray
+
+        auto result = ColumnArray::create(inner_array->getDataPtr(), std::move(result_offsets_column));
+        if (result_null_map)
+            return ColumnNullable::create(std::move(result), std::move(result_null_map));
+        return result;
+    }
+};
+
+/// Registers sparkArrayFlatten with ClickHouse's FunctionFactory so it can be resolved by name.
+/// The Gluten plan parser maps Spark's `flatten` to this name (SCALAR_FUNCTIONS in SerializedPlanParser.h).
+REGISTER_FUNCTION(SparkArrayFlatten)
+{
+    factory.registerFunction<SparkArrayFlatten>();
+}
+
+}
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h 
b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index 82e8c4077..aa18197e5 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -180,6 +180,7 @@ static const std::map<std::string, std::string> 
SCALAR_FUNCTIONS
        {"array", "array"},
        {"shuffle", "arrayShuffle"},
        {"range", "range"}, /// dummy mapping
+        {"flatten", "sparkArrayFlatten"},
 
        // map functions
        {"map", "map"},
diff --git 
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
 
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 8572ef54d..162671680 100644
--- 
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ 
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -172,6 +172,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("shuffle function - array for primitive type not containing null")
     .exclude("shuffle function - array for primitive type containing null")
     .exclude("shuffle function - array for non-primitive type")
+    .exclude("flatten function")
   enableSuite[GlutenDataFrameHintSuite]
   enableSuite[GlutenDataFrameImplicitsSuite]
   enableSuite[GlutenDataFrameJoinSuite].exclude(
@@ -674,7 +675,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("Sequence with default step")
     .exclude("Reverse")
     .exclude("elementAt")
-    .exclude("Flatten")
     .exclude("ArrayRepeat")
     .exclude("Array remove")
     .exclude("Array Distinct")
diff --git 
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
 
b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
index 2b0b40790..e64f760ab 100644
--- 
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
+++ 
b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
@@ -49,4 +49,86 @@ class GlutenDataFrameFunctionsSuite extends 
DataFrameFunctionsSuite with GlutenS
       false
     )
   }
+
+  testGluten("flatten function") {
+    // Every dataset is checked twice: first as a local relation (the Project is evaluated
+    // without codegen), then again after `cache()` (the Project is evaluated with codegen).
+    def checkBeforeAndAfterCache(df: DataFrame)(checks: => Unit): Unit = {
+      checks
+      df.cache()
+      checks
+    }
+
+    // Runs `flatten` on `column` through both the Column API and SQL.
+    def checkFlatten(df: DataFrame, column: String, expected: Seq[Row]): Unit =
+      checkBeforeAndAfterCache(df) {
+        checkAnswer(df.select(flatten(new ColumnName(column))), expected)
+        checkAnswer(df.selectExpr(s"flatten($column)"), expected)
+      }
+
+    // Primitive element type.
+    val intDF = Seq(
+      Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)),
+      Seq(Seq(1, 2)),
+      Seq(Seq(1), Seq.empty),
+      Seq(Seq.empty, Seq(1))
+    ).toDF("i")
+    checkFlatten(
+      intDF,
+      "i",
+      Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)), Row(Seq(1)), Row(Seq(1))))
+
+    // Non-primitive element type, including NULL elements inside the inner arrays.
+    val strDF = Seq(
+      Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f")),
+      Seq(Seq("a", "b")),
+      Seq(Seq("a", null), Seq(null, "b"), Seq(null, null)),
+      Seq(Seq("a"), Seq.empty),
+      Seq(Seq.empty, Seq("a"))
+    ).toDF("s")
+    checkFlatten(
+      strDF,
+      "s",
+      Seq(
+        Row(Seq("a", "b", "c", "d", "e", "f")),
+        Row(Seq("a", "b")),
+        Row(Seq("a", null, null, "b", null, null)),
+        Row(Seq("a")),
+        Row(Seq("a"))))
+
+    // Arrays built in SQL; the second case shows only one nesting level is removed.
+    val arrDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr")
+    checkBeforeAndAfterCache(arrDF) {
+      checkAnswer(
+        arrDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"),
+        Seq(Row(Seq(1, 2, 3, null, 5, 6, null))))
+      checkAnswer(
+        arrDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"),
+        Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3)))))
+    }
+
+    // NOTE(review): no case above has a NULL inner array or a NULL input array; consider adding
+    // such rows once the backend handles them per row.
+
+    // Inputs that are not arrays of arrays must be rejected during analysis.
+    val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr")
+    Seq("arr", "i", "s").foreach {
+      name =>
+        intercept[AnalysisException] {
+          oneRowDF.select(flatten(new ColumnName(name)))
+        }
+    }
+    intercept[AnalysisException] {
+      oneRowDF.selectExpr("flatten(null)")
+    }
+  }
 }
diff --git 
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
 
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 50e7929e4..3147c7c3d 100644
--- 
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ 
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -190,6 +190,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("shuffle function - array for primitive type not containing null")
     .exclude("shuffle function - array for primitive type containing null")
     .exclude("shuffle function - array for non-primitive type")
+    .exclude("flatten function")
   enableSuite[GlutenDataFrameHintSuite]
   enableSuite[GlutenDataFrameImplicitsSuite]
   enableSuite[GlutenDataFrameJoinSuite].exclude(
@@ -714,7 +715,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("Sequence with default step")
     .exclude("Reverse")
     .exclude("elementAt")
-    .exclude("Flatten")
     .exclude("ArrayRepeat")
     .exclude("Array remove")
     .exclude("Array Distinct")
diff --git 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
index 2b0b40790..e64f760ab 100644
--- 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
+++ 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
@@ -49,4 +49,86 @@ class GlutenDataFrameFunctionsSuite extends 
DataFrameFunctionsSuite with GlutenS
       false
     )
   }
+
+  testGluten("flatten function") {
+    // Every dataset is checked twice: first as a local relation (the Project is evaluated
+    // without codegen), then again after `cache()` (the Project is evaluated with codegen).
+    def checkBeforeAndAfterCache(df: DataFrame)(checks: => Unit): Unit = {
+      checks
+      df.cache()
+      checks
+    }
+
+    // Runs `flatten` on `column` through both the Column API and SQL.
+    def checkFlatten(df: DataFrame, column: String, expected: Seq[Row]): Unit =
+      checkBeforeAndAfterCache(df) {
+        checkAnswer(df.select(flatten(new ColumnName(column))), expected)
+        checkAnswer(df.selectExpr(s"flatten($column)"), expected)
+      }
+
+    // Primitive element type.
+    val intDF = Seq(
+      Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)),
+      Seq(Seq(1, 2)),
+      Seq(Seq(1), Seq.empty),
+      Seq(Seq.empty, Seq(1))
+    ).toDF("i")
+    checkFlatten(
+      intDF,
+      "i",
+      Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)), Row(Seq(1)), Row(Seq(1))))
+
+    // Non-primitive element type, including NULL elements inside the inner arrays.
+    val strDF = Seq(
+      Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f")),
+      Seq(Seq("a", "b")),
+      Seq(Seq("a", null), Seq(null, "b"), Seq(null, null)),
+      Seq(Seq("a"), Seq.empty),
+      Seq(Seq.empty, Seq("a"))
+    ).toDF("s")
+    checkFlatten(
+      strDF,
+      "s",
+      Seq(
+        Row(Seq("a", "b", "c", "d", "e", "f")),
+        Row(Seq("a", "b")),
+        Row(Seq("a", null, null, "b", null, null)),
+        Row(Seq("a")),
+        Row(Seq("a"))))
+
+    // Arrays built in SQL; the second case shows only one nesting level is removed.
+    val arrDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr")
+    checkBeforeAndAfterCache(arrDF) {
+      checkAnswer(
+        arrDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"),
+        Seq(Row(Seq(1, 2, 3, null, 5, 6, null))))
+      checkAnswer(
+        arrDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"),
+        Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3)))))
+    }
+
+    // NOTE(review): no case above has a NULL inner array or a NULL input array; consider adding
+    // such rows once the backend handles them per row.
+
+    // Inputs that are not arrays of arrays must be rejected during analysis.
+    val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr")
+    Seq("arr", "i", "s").foreach {
+      name =>
+        intercept[AnalysisException] {
+          oneRowDF.select(flatten(new ColumnName(name)))
+        }
+    }
+    intercept[AnalysisException] {
+      oneRowDF.selectExpr("flatten(null)")
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to