This is an automated email from the ASF dual-hosted git repository.
zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new de26ed2da [CH] Support flatten (#6194)
de26ed2da is described below
commit de26ed2dad41d2d1e893c8d1b3ae806385d9972f
Author: LiuNeng <[email protected]>
AuthorDate: Tue Jun 25 16:10:05 2024 +0800
[CH] Support flatten (#6194)
[CH] Support flatten
Co-authored-by: liuneng1994 <[email protected]>
---
.../org/apache/gluten/utils/CHExpressionUtil.scala | 1 -
cpp-ch/clickhouse.version | 3 +-
.../local-engine/Functions/SparkArrayFlatten.cpp | 160 +++++++++++++++++++++
cpp-ch/local-engine/Parser/SerializedPlanParser.h | 1 +
.../utils/clickhouse/ClickHouseTestSettings.scala | 2 +-
.../spark/sql/GlutenDataFrameFunctionsSuite.scala | 82 +++++++++++
.../utils/clickhouse/ClickHouseTestSettings.scala | 2 +-
.../spark/sql/GlutenDataFrameFunctionsSuite.scala | 82 +++++++++++
8 files changed, 329 insertions(+), 4 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
index cf45c1118..e9bee8439 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
@@ -209,7 +209,6 @@ object CHExpressionUtil {
UNIX_MICROS -> DefaultValidator(),
TIMESTAMP_MILLIS -> DefaultValidator(),
TIMESTAMP_MICROS -> DefaultValidator(),
- FLATTEN -> DefaultValidator(),
STACK -> DefaultValidator()
)
}
diff --git a/cpp-ch/clickhouse.version b/cpp-ch/clickhouse.version
index 4a3088e54..54d0a74c5 100644
--- a/cpp-ch/clickhouse.version
+++ b/cpp-ch/clickhouse.version
@@ -1,3 +1,4 @@
CH_ORG=Kyligence
CH_BRANCH=rebase_ch/20240621
-CH_COMMIT=acf666c1c4f
+CH_COMMIT=c811cbb985f
+
diff --git a/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp
b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp
new file mode 100644
index 000000000..d39bca5ea
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <Columns/ColumnArray.h>
+#include <Columns/ColumnNullable.h>
+#include <Columns/ColumnsNumber.h>
+#include <DataTypes/DataTypeArray.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <Functions/FunctionFactory.h>
+#include <Functions/FunctionHelpers.h>
+#include <Functions/IFunction.h>
+
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+ extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+ extern const int ILLEGAL_COLUMN;
+}
+
+/// sparkArrayFlatten(arr) implements Spark's `flatten`: it removes exactly ONE level of nesting.
+///   sparkArrayFlatten([[1, 2, 3], [4, 5]]) = [1, 2, 3, 4, 5]
+///   sparkArrayFlatten([[[1]], [[2, 3]]])   = [[1], [2, 3]]
+/// If a row's outer array contains a NULL inner array, the result for THAT row is NULL;
+/// other rows are unaffected. NULL elements inside the inner arrays are kept as-is.
+class SparkArrayFlatten : public IFunction
+{
+public:
+    static constexpr auto name = "sparkArrayFlatten";
+
+    static FunctionPtr create(ContextPtr) { return std::make_shared<SparkArrayFlatten>(); }
+
+    size_t getNumberOfArguments() const override { return 1; }
+    bool useDefaultImplementationForConstants() const override { return true; }
+    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
+
+    /// Array(Array(T)) -> Array(T); Array(Nullable(Array(T))) -> Nullable(Array(T)).
+    /// Elements of type Nothing / Nullable(Nothing) are also accepted (executeImpl handles them separately).
+    /// @throws ILLEGAL_TYPE_OF_ARGUMENT if the argument is not an array, or its elements are not arrays.
+    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
+    {
+        const auto * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].get());
+        if (!array_type)
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}, expected Array",
+                arguments[0]->getName(), getName());
+
+        const DataTypePtr & element_type = array_type->getNestedType();
+        const DataTypePtr element_type_not_null = removeNullable(element_type);
+        /// Only one level is flattened, so the elements themselves must be arrays; otherwise executeImpl
+        /// could not produce a column of the declared return type.
+        if (!isArray(element_type_not_null) && !isNothing(element_type_not_null))
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}, expected Array(Array)",
+                arguments[0]->getName(), getName());
+
+        return element_type;
+    }
+
+    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
+    {
+        /** One level is flattened by re-pointing the outer offsets at the inner arrays' data.
+          *
+          * Source column: Array(Array(UInt8)):
+          * Row 1: [[1, 2, 3], [4, 5]], Row 2: [[6], [7, 8]]
+          * data: [1, 2, 3], [4, 5], [6], [7, 8]
+          * offsets: 2, 4
+          * data.data: 1 2 3 4 5 6 7 8
+          * data.offsets: 3 5 6 8
+          *
+          * Result column: Array(UInt8):
+          * Row 1: [1, 2, 3, 4, 5], Row 2: [6, 7, 8]
+          * data: 1 2 3 4 5 6 7 8
+          * offsets: 5 8
+          *
+          * The result offset of row i is data.offsets[offsets[i] - 1], i.e. the end of the row's last inner array:
+          * 3 5 6 8
+          *   ^   ^
+          * The inner data column (data.data) is shared with the result, not copied.
+          */
+        const ColumnArray * src_col = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
+        if (!src_col)
+            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} in argument of function '{}'",
+                arguments[0].column->getName(), getName());
+
+        const IColumn::Offsets & src_offsets = src_col->getOffsets();
+        const IColumn * elements = &src_col->getData();
+
+        /// Inner arrays may themselves be NULL (Nullable(Array(T)) elements). Only the rows whose outer array
+        /// contains such a NULL become NULL, so the null map is computed per row, not for the whole column.
+        ColumnUInt8::MutablePtr result_null_map;
+        if (const auto * nullable_elements = checkAndGetColumn<ColumnNullable>(elements))
+        {
+            const auto & element_is_null = nullable_elements->getNullMapData();
+            result_null_map = ColumnUInt8::create(input_rows_count, 0);
+            auto & row_is_null = result_null_map->getData();
+            for (size_t row = 0; row < input_rows_count; ++row)
+            {
+                /// src_offsets[-1] is 0 for PaddedPODArray, so row 0 needs no special case.
+                for (size_t i = src_offsets[row - 1]; i < src_offsets[row]; ++i)
+                {
+                    if (element_is_null[i])
+                    {
+                        row_is_null[row] = 1;
+                        break;
+                    }
+                }
+            }
+            elements = &nullable_elements->getNestedColumn();
+        }
+
+        /// Elements of type Nothing: there are no inner arrays to flatten.
+        if (isNothing(elements->getDataType()))
+        {
+            auto nothing_res = elements->cloneResized(input_rows_count);
+            if (result_null_map)
+                return ColumnNullable::create(std::move(nothing_res), std::move(result_null_map));
+            return nothing_res;
+        }
+
+        const ColumnArray * inner_col = checkAndGetColumn<ColumnArray>(elements);
+        if (!inner_col)
+            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} in argument of function '{}', expected Array(Array)",
+                arguments[0].column->getName(), getName());
+
+        const IColumn::Offsets & inner_offsets = inner_col->getOffsets();
+        auto result_offsets_column = ColumnArray::ColumnOffsets::create(input_rows_count);
+        IColumn::Offsets & result_offsets = result_offsets_column->getData();
+        for (size_t row = 0; row < input_rows_count; ++row)
+            result_offsets[row] = inner_offsets[src_offsets[row] - 1]; /// -1 array subscript is Ok, see PaddedPODArray
+
+        ColumnPtr res = ColumnArray::create(inner_col->getDataPtr(), std::move(result_offsets_column));
+        if (!result_null_map)
+            return res;
+        return ColumnNullable::create(res, std::move(result_null_map));
+    }
+
+private:
+    String getName() const override
+    {
+        return name;
+    }
+};
+
+/// Registers sparkArrayFlatten with the function factory. SerializedPlanParser's SCALAR_FUNCTIONS
+/// maps Spark's "flatten" to this name.
+REGISTER_FUNCTION(SparkArrayFlatten)
+{
+ factory.registerFunction<SparkArrayFlatten>();
+}
+
+}
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index 82e8c4077..aa18197e5 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -180,6 +180,7 @@ static const std::map<std::string, std::string>
SCALAR_FUNCTIONS
{"array", "array"},
{"shuffle", "arrayShuffle"},
{"range", "range"}, /// dummy mapping
+ {"flatten", "sparkArrayFlatten"},
// map functions
{"map", "map"},
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 8572ef54d..162671680 100644
---
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -172,6 +172,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("shuffle function - array for primitive type not containing null")
.exclude("shuffle function - array for primitive type containing null")
.exclude("shuffle function - array for non-primitive type")
+ .exclude("flatten function")
enableSuite[GlutenDataFrameHintSuite]
enableSuite[GlutenDataFrameImplicitsSuite]
enableSuite[GlutenDataFrameJoinSuite].exclude(
@@ -674,7 +675,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("Sequence with default step")
.exclude("Reverse")
.exclude("elementAt")
- .exclude("Flatten")
.exclude("ArrayRepeat")
.exclude("Array remove")
.exclude("Array Distinct")
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
index 2b0b40790..e64f760ab 100644
---
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
+++
b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
@@ -49,4 +49,86 @@ class GlutenDataFrameFunctionsSuite extends
DataFrameFunctionsSuite with GlutenS
false
)
}
+
+ testGluten("flatten function") {
+ // Gluten copy of the flatten test for the ClickHouse backend (the upstream "flatten function"
+ // test is excluded in ClickHouseTestSettings). Every inner array in the inputs below is non-NULL;
+ // only elements inside inner arrays are NULL.
+ // NOTE(review): rows with a NULL inner array (e.g. Seq(Seq(1), null), which Spark flattens to NULL)
+ // are not exercised here — consider adding them.
+ // Test cases with a primitive type
+ val intDF = Seq(
+ (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))),
+ (Seq(Seq(1, 2))),
+ (Seq(Seq(1), Seq.empty)),
+ (Seq(Seq.empty, Seq(1)))
+ ).toDF("i")
+
+ val intDFResult = Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)),
Row(Seq(1)), Row(Seq(1)))
+
+ def testInt(): Unit = {
+ checkAnswer(intDF.select(flatten($"i")), intDFResult)
+ checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult)
+ }
+
+ // Test with local relation, the Project will be evaluated without codegen
+ testInt()
+ // Test with cached relation, the Project will be evaluated with codegen
+ intDF.cache()
+ testInt()
+
+ // Test cases with non-primitive types
+ val strDF = Seq(
+ (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))),
+ (Seq(Seq("a", "b"))),
+ (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))),
+ (Seq(Seq("a"), Seq.empty)),
+ (Seq(Seq.empty, Seq("a")))
+ ).toDF("s")
+
+ val strDFResult = Seq(
+ Row(Seq("a", "b", "c", "d", "e", "f")),
+ Row(Seq("a", "b")),
+ Row(Seq("a", null, null, "b", null, null)),
+ Row(Seq("a")),
+ Row(Seq("a")))
+
+ def testString(): Unit = {
+ checkAnswer(strDF.select(flatten($"s")), strDFResult)
+ checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult)
+ }
+
+ // Test with local relation, the Project will be evaluated without codegen
+ testString()
+ // Test with cached relation, the Project will be evaluated with codegen
+ strDF.cache()
+ testString()
+
+ // Single-row input used to flatten literal arrays, including an array of arrays of arrays
+ // (only the outermost level is removed).
+ val arrDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr")
+
+ def testArray(): Unit = {
+ checkAnswer(
+ arrDF.selectExpr("flatten(array(arr, array(null, 5), array(6,
null)))"),
+ Seq(Row(Seq(1, 2, 3, null, 5, 6, null))))
+ checkAnswer(
+ arrDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"),
+ Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3)))))
+ }
+
+ // Test with local relation, the Project will be evaluated without codegen
+ testArray()
+ // Test with cached relation, the Project will be evaluated with codegen
+ arrDF.cache()
+ testArray()
+
+ // Error test cases: none of these arguments (array<int>, int, string, null) is an array of arrays,
+ // so each call is expected to throw AnalysisException.
+ val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr")
+ intercept[AnalysisException] {
+ oneRowDF.select(flatten($"arr"))
+ }
+ intercept[AnalysisException] {
+ oneRowDF.select(flatten($"i"))
+ }
+ intercept[AnalysisException] {
+ oneRowDF.select(flatten($"s"))
+ }
+ intercept[AnalysisException] {
+ oneRowDF.selectExpr("flatten(null)")
+ }
+ }
}
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 50e7929e4..3147c7c3d 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -190,6 +190,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("shuffle function - array for primitive type not containing null")
.exclude("shuffle function - array for primitive type containing null")
.exclude("shuffle function - array for non-primitive type")
+ .exclude("flatten function")
enableSuite[GlutenDataFrameHintSuite]
enableSuite[GlutenDataFrameImplicitsSuite]
enableSuite[GlutenDataFrameJoinSuite].exclude(
@@ -714,7 +715,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("Sequence with default step")
.exclude("Reverse")
.exclude("elementAt")
- .exclude("Flatten")
.exclude("ArrayRepeat")
.exclude("Array remove")
.exclude("Array Distinct")
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
index 2b0b40790..e64f760ab 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
@@ -49,4 +49,86 @@ class GlutenDataFrameFunctionsSuite extends
DataFrameFunctionsSuite with GlutenS
false
)
}
+
+ testGluten("flatten function") {
+ // Gluten copy of the flatten test for the ClickHouse backend (the upstream "flatten function"
+ // test is excluded in ClickHouseTestSettings). Every inner array in the inputs below is non-NULL;
+ // only elements inside inner arrays are NULL.
+ // NOTE(review): rows with a NULL inner array (e.g. Seq(Seq(1), null), which Spark flattens to NULL)
+ // are not exercised here — consider adding them.
+ // Test cases with a primitive type
+ val intDF = Seq(
+ (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))),
+ (Seq(Seq(1, 2))),
+ (Seq(Seq(1), Seq.empty)),
+ (Seq(Seq.empty, Seq(1)))
+ ).toDF("i")
+
+ val intDFResult = Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)),
Row(Seq(1)), Row(Seq(1)))
+
+ def testInt(): Unit = {
+ checkAnswer(intDF.select(flatten($"i")), intDFResult)
+ checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult)
+ }
+
+ // Test with local relation, the Project will be evaluated without codegen
+ testInt()
+ // Test with cached relation, the Project will be evaluated with codegen
+ intDF.cache()
+ testInt()
+
+ // Test cases with non-primitive types
+ val strDF = Seq(
+ (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))),
+ (Seq(Seq("a", "b"))),
+ (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))),
+ (Seq(Seq("a"), Seq.empty)),
+ (Seq(Seq.empty, Seq("a")))
+ ).toDF("s")
+
+ val strDFResult = Seq(
+ Row(Seq("a", "b", "c", "d", "e", "f")),
+ Row(Seq("a", "b")),
+ Row(Seq("a", null, null, "b", null, null)),
+ Row(Seq("a")),
+ Row(Seq("a")))
+
+ def testString(): Unit = {
+ checkAnswer(strDF.select(flatten($"s")), strDFResult)
+ checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult)
+ }
+
+ // Test with local relation, the Project will be evaluated without codegen
+ testString()
+ // Test with cached relation, the Project will be evaluated with codegen
+ strDF.cache()
+ testString()
+
+ // Single-row input used to flatten literal arrays, including an array of arrays of arrays
+ // (only the outermost level is removed).
+ val arrDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr")
+
+ def testArray(): Unit = {
+ checkAnswer(
+ arrDF.selectExpr("flatten(array(arr, array(null, 5), array(6,
null)))"),
+ Seq(Row(Seq(1, 2, 3, null, 5, 6, null))))
+ checkAnswer(
+ arrDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"),
+ Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3)))))
+ }
+
+ // Test with local relation, the Project will be evaluated without codegen
+ testArray()
+ // Test with cached relation, the Project will be evaluated with codegen
+ arrDF.cache()
+ testArray()
+
+ // Error test cases: none of these arguments (array<int>, int, string, null) is an array of arrays,
+ // so each call is expected to throw AnalysisException.
+ val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr")
+ intercept[AnalysisException] {
+ oneRowDF.select(flatten($"arr"))
+ }
+ intercept[AnalysisException] {
+ oneRowDF.select(flatten($"i"))
+ }
+ intercept[AnalysisException] {
+ oneRowDF.select(flatten($"s"))
+ }
+ intercept[AnalysisException] {
+ oneRowDF.selectExpr("flatten(null)")
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]