This is an automated email from the ASF dual-hosted git repository.

liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new be57db84f [GLUTEN-6304][CH]Support array_join (#6305)
be57db84f is described below

commit be57db84f54ddf4f609733eab9bafeebd3ed8785
Author: kevinyhzou <[email protected]>
AuthorDate: Tue Jul 9 11:44:20 2024 +0800

    [GLUTEN-6304][CH]Support array_join (#6305)
    
    What changes were proposed in this pull request?
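    Add a ClickHouse-backend implementation of Spark's `array_join` (native function
    `sparkArrayJoin`), map `array_join` to it in the plan parser, and re-enable the
    `ArrayJoin` unit tests. Calls whose array argument is a literal are still not offloaded.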
    
    (Fixes: #6304)
    
    How was this patch tested?
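    Tested by the Spark unit tests: the `ArrayJoin` exclusions were removed from
    ClickHouseTestSettings for Spark 3.2, 3.3, 3.4 and 3.5.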
---
 .../org/apache/gluten/utils/CHExpressionUtil.scala |   9 +-
 .../Functions/SparkFunctionArrayJoin.cpp           | 204 +++++++++++++++++++++
 cpp-ch/local-engine/Parser/SerializedPlanParser.h  |   3 +-
 .../utils/clickhouse/ClickHouseTestSettings.scala  |   1 -
 .../utils/clickhouse/ClickHouseTestSettings.scala  |   1 -
 .../utils/clickhouse/ClickHouseTestSettings.scala  |   1 -
 .../utils/clickhouse/ClickHouseTestSettings.scala  |   1 -
 7 files changed, 214 insertions(+), 6 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
index 14f0ff489..ac03a7a5b 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
@@ -159,6 +159,13 @@ case class EncodeDecodeValidator() extends 
FunctionValidator {
   }
 }
 
+/**
+ * Validator for Spark's `array_join`: rejects (returns false for) an `ArrayJoin` whose first
+ * child, the array argument, is a `Literal`; every other expression is accepted.
+ */
+case class ArrayJoinValidator() extends FunctionValidator {
+  override def doValidate(expr: Expression): Boolean = expr match {
+    case arrayJoin: ArrayJoin =>
+      arrayJoin.children.head match {
+        case _: Literal => false
+        case _ => true
+      }
+    case _ => true
+  }
+}
+
 object CHExpressionUtil {
 
   final val CH_AGGREGATE_FUNC_BLACKLIST: Map[String, FunctionValidator] = Map(
@@ -167,7 +174,7 @@ object CHExpressionUtil {
   )
 
   final val CH_BLACKLIST_SCALAR_FUNCTION: Map[String, FunctionValidator] = Map(
-    ARRAY_JOIN -> DefaultValidator(),
+    ARRAY_JOIN -> ArrayJoinValidator(),
     SPLIT_PART -> DefaultValidator(),
     TO_UNIX_TIMESTAMP -> UnixTimeStampValidator(),
     UNIX_TIMESTAMP -> UnixTimeStampValidator(),
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp 
b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
new file mode 100644
index 000000000..ed99c0904
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <base/StringRef.h>
+#include <Interpreters/Context_fwd.h>
+#include <Columns/ColumnString.h>
+#include <Columns/ColumnNullable.h>
+#include <Functions/IFunction.h>
+#include <Functions/FunctionFactory.h>
+#include <DataTypes/DataTypeString.h>
+
+using namespace DB;
+
+namespace DB::ErrorCodes
+{
+/// Error codes thrown by sparkArrayJoin; declared extern here, defined elsewhere.
+extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+
+namespace local_engine
+{
+/// Spark-compatible `array_join(array, delimiter[, null_replacement])`, registered as `sparkArrayJoin`.
+///
+/// For every row the string elements of `array` are concatenated, separated by `delimiter`.
+/// NULL elements are skipped when no `null_replacement` is given; otherwise they are replaced by it.
+/// The result row is NULL when the array, the delimiter or (if given) the null replacement is NULL.
+/// The result type is always Nullable(String).
+class SparkFunctionArrayJoin : public IFunction
+{
+public:
+    static constexpr auto name = "sparkArrayJoin";
+    static FunctionPtr create(ContextPtr) { return std::make_shared<SparkFunctionArrayJoin>(); }
+    SparkFunctionArrayJoin() = default;
+    ~SparkFunctionArrayJoin() override = default;
+    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
+    size_t getNumberOfArguments() const override { return 0; }
+    String getName() const override { return name; }
+    bool isVariadic() const override { return true; }
+    /// NULL semantics differ per argument (see class comment), so they are handled in executeImpl.
+    bool useDefaultImplementationForNulls() const override { return false; }
+
+    DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
+    {
+        checkArgumentCount(arguments.size());
+        auto data_type = std::make_shared<DataTypeString>();
+        return makeNullable(data_type);
+    }
+
+    /// @param arguments  [Nullable] Array([Nullable] String), a [Nullable] String delimiter and an optional
+    ///                   [Nullable] String null replacement; any of them may be a constant column.
+    /// @return Nullable(String) column with one row per array row.
+    /// @throws Exception on a wrong argument count or on non-array / non-string argument columns.
+    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override
+    {
+        checkArgumentCount(arguments.size());
+
+        /// Materialize a constant array so the loop below only has to deal with a plain ColumnArray.
+        const ColumnPtr array_holder = arguments[0].column->convertToFullColumnIfConst();
+        const IColumn * array_arg = array_holder.get();
+        const PaddedPODArray<UInt8> * array_null_map = nullptr;
+        if (const auto * nullable_array = checkAndGetColumn<ColumnNullable>(array_arg))
+        {
+            array_null_map = &nullable_array->getNullMapData();
+            array_arg = &nullable_array->getNestedColumn();
+        }
+        const auto * array_col = checkAndGetColumn<ColumnArray>(array_arg);
+        if (!array_col)
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be array type", getName());
+
+        const IColumn * elements = &array_col->getData();
+        const PaddedPODArray<UInt8> * element_null_map = nullptr;
+        if (const auto * nullable_elements = checkAndGetColumn<ColumnNullable>(elements))
+        {
+            element_null_map = &nullable_elements->getNullMapData();
+            elements = &nullable_elements->getNestedColumn();
+        }
+        const auto * element_strings = checkAndGetColumn<ColumnString>(elements);
+        if (!element_strings)
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be array of strings", getName());
+
+        const StringArgument delimiter = getStringArgument(arguments[1].column, 2);
+        const bool has_null_replacement = arguments.size() == 3;
+        StringArgument null_replacement;
+        if (has_null_replacement)
+            null_replacement = getStringArgument(arguments[2].column, 3);
+
+        const size_t rows = array_col->size();
+        const ColumnArray::Offsets & array_offsets = array_col->getOffsets();
+        auto res_col = ColumnString::create();
+        res_col->reserve(rows);
+        auto null_col = ColumnUInt8::create(rows, 0);
+        PaddedPODArray<UInt8> & null_result = null_col->getData();
+
+        String joined; /// reused across rows to keep its capacity
+        for (size_t row = 0; row < rows; ++row)
+        {
+            StringRef delim;
+            StringRef replacement;
+            const bool row_is_null = (array_null_map && (*array_null_map)[row]) || !delimiter.tryGet(row, delim)
+                || (has_null_replacement && !null_replacement.tryGet(row, replacement));
+            if (row_is_null)
+            {
+                res_col->insertDefault();
+                null_result[row] = 1;
+                continue;
+            }
+
+            joined.clear();
+            bool first_item = true;
+            /// Element range comes from this row's own offsets, so NULL rows never shift later rows.
+            const size_t begin = row == 0 ? 0 : array_offsets[row - 1];
+            const size_t end = array_offsets[row];
+            for (size_t pos = begin; pos < end; ++pos)
+            {
+                const bool element_is_null = element_null_map && (*element_null_map)[pos];
+                /// Like Spark: without a replacement a NULL element is dropped together with its delimiter.
+                if (element_is_null && !has_null_replacement)
+                    continue;
+                if (!first_item)
+                    joined.append(delim.data, delim.size);
+                first_item = false;
+                const StringRef item = element_is_null ? replacement : element_strings->getDataAt(pos);
+                joined.append(item.data, item.size);
+            }
+            res_col->insertData(joined.data(), joined.size());
+        }
+        return ColumnNullable::create(std::move(res_col), std::move(null_col));
+    }
+
+private:
+    /// Row accessor for a String / Nullable(String) argument that may be a constant column.
+    struct StringArgument
+    {
+        const ColumnString * strings = nullptr;
+        const PaddedPODArray<UInt8> * null_map = nullptr; /// nullptr when the argument is not Nullable
+        bool is_const = false; /// constant columns hold a single row
+
+        /// Stores the value of `row` in `out`; returns false (leaving `out` untouched) if it is NULL.
+        bool tryGet(size_t row, StringRef & out) const
+        {
+            const size_t idx = is_const ? 0 : row;
+            if (null_map && (*null_map)[idx])
+                return false;
+            out = strings->getDataAt(idx);
+            return true;
+        }
+    };
+
+    /// Unwraps the Const and Nullable wrappers of a string argument.
+    /// @param arg_number  1-based argument position, used only in the error message.
+    /// @throws Exception if the underlying column is not a String column.
+    StringArgument getStringArgument(const ColumnPtr & column, size_t arg_number) const
+    {
+        StringArgument arg;
+        const IColumn * col = column.get();
+        if (const auto * const_col = checkAndGetColumn<ColumnConst>(col))
+        {
+            arg.is_const = true;
+            col = &const_col->getDataColumn();
+        }
+        if (const auto * nullable_col = checkAndGetColumn<ColumnNullable>(col))
+        {
+            arg.null_map = &nullable_col->getNullMapData();
+            col = &nullable_col->getNestedColumn();
+        }
+        arg.strings = checkAndGetColumn<ColumnString>(col);
+        if (!arg.strings)
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} argument {} must be string type", getName(), arg_number);
+        return arg;
+    }
+
+    /// Spark's array_join takes the array, the delimiter and an optional null replacement.
+    void checkArgumentCount(size_t count) const
+    {
+        if (count != 2 && count != 3)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 or 3 arguments", getName());
+    }
+};
+
+/// Registers sparkArrayJoin with the ClickHouse function factory.
+REGISTER_FUNCTION(SparkArrayJoin)
+{
+    factory.registerFunction<SparkFunctionArrayJoin>();
+}
+}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h 
b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index 90086ea28..477fdb1f6 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -177,7 +177,8 @@ static const std::map<std::string, std::string> 
SCALAR_FUNCTIONS
        {"array", "array"},
        {"shuffle", "arrayShuffle"},
        {"range", "range"}, /// dummy mapping
-        {"flatten", "sparkArrayFlatten"},
+       {"flatten", "sparkArrayFlatten"},
+       {"array_join", "sparkArrayJoin"},
 
        // map functions
        {"map", "map"},
diff --git 
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
 
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index a7ffbc9fa..8fd68d517 100644
--- 
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ 
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -665,7 +665,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("Map Concat")
     .exclude("MapFromEntries")
     .exclude("ArraysOverlap")
-    .exclude("ArrayJoin")
     .exclude("ArraysZip")
     .exclude("Sequence of numbers")
     .exclude("Sequence of timestamps")
diff --git 
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
 
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index ceb0d8a87..f69598adf 100644
--- 
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ 
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -656,7 +656,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("Map Concat")
     .exclude("MapFromEntries")
     .exclude("ArraysOverlap")
-    .exclude("ArrayJoin")
     .exclude("ArraysZip")
     .exclude("Sequence of numbers")
     .exclude("Sequence of timestamps")
diff --git 
a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
 
b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 66007a367..ab288e835 100644
--- 
a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ 
b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -544,7 +544,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("Map Concat")
     .exclude("MapFromEntries")
     .exclude("ArraysOverlap")
-    .exclude("ArrayJoin")
     .exclude("ArraysZip")
     .exclude("Sequence of numbers")
     .exclude("Sequence of timestamps")
diff --git 
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
 
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 66007a367..ab288e835 100644
--- 
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ 
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -544,7 +544,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("Map Concat")
     .exclude("MapFromEntries")
     .exclude("ArraysOverlap")
-    .exclude("ArrayJoin")
     .exclude("ArraysZip")
     .exclude("Sequence of numbers")
     .exclude("Sequence of timestamps")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to