This is an automated email from the ASF dual-hosted git repository.

taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 92d3793273 [GLUTEN-6856][CH]Support arrays_overlap and fix array_join diff (#6857)
92d3793273 is described below

commit 92d3793273b6da051be6fb6bffa9329ecb439b30
Author: kevinyhzou <[email protected]>
AuthorDate: Wed Oct 9 10:25:46 2024 +0800

    [GLUTEN-6856][CH]Support arrays_overlap and fix array_join diff (#6857)
    
    * Fix orc timezone read
    
    * on test
    
    * fix ci
    
    * fix ci
    
    * ci build error
    
    * support arrays_overlap
    
    * remove useless code
    
    * fix ci
    
    * fix ci
    
    * fix array_join diff
    
    * remove useless code
    
    * remove useless code
    
    * solve conflict
    
    * ci fix
    
    * remove useless code
    
    * use default impl for array join
    
    * remove useless code
---
 .../org/apache/gluten/utils/CHExpressionUtil.scala |   9 --
 .../Functions/SparkFunctionArrayJoin.cpp           | 135 ++++++--------------
 .../Functions/SparkFunctionArraysOverlap.cpp       | 139 +++++++++++++++++++++
 .../CommonScalarFunctionParser.cpp                 |   1 +
 4 files changed, 181 insertions(+), 103 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
index e58db43b26..bb23612b13 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
@@ -159,13 +159,6 @@ case class EncodeDecodeValidator() extends FunctionValidator {
   }
 }
 
-case class ArrayJoinValidator() extends FunctionValidator {
-  override def doValidate(expr: Expression): Boolean = expr match {
-    case t: ArrayJoin => !t.children.head.isInstanceOf[Literal]
-    case _ => true
-  }
-}
-
 case class FormatStringValidator() extends FunctionValidator {
   override def doValidate(expr: Expression): Boolean = {
     val formatString = expr.asInstanceOf[FormatString]
@@ -181,13 +174,11 @@ object CHExpressionUtil {
   )
 
   final val CH_BLACKLIST_SCALAR_FUNCTION: Map[String, FunctionValidator] = Map(
-    ARRAY_JOIN -> ArrayJoinValidator(),
     SPLIT_PART -> DefaultValidator(),
     TO_UNIX_TIMESTAMP -> UnixTimeStampValidator(),
     UNIX_TIMESTAMP -> UnixTimeStampValidator(),
     SEQUENCE -> SequenceValidator(),
     GET_JSON_OBJECT -> GetJsonObjectValidator(),
-    ARRAYS_OVERLAP -> DefaultValidator(),
     SPLIT -> StringSplitValidator(),
     SUBSTRING_INDEX -> SubstringIndexValidator(),
     LPAD -> StringLPadValidator(),
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
index ed99c09042..4c2847d9f9 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
@@ -46,7 +46,7 @@ public:
     size_t getNumberOfArguments() const override { return 0; }
     String getName() const override { return name; }
     bool isVariadic() const override { return true; }
-    bool useDefaultImplementationForNulls() const override { return false; }
+    bool useDefaultImplementationForConstants() const override { return true; }
 
     DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const 
override
     {
@@ -54,61 +54,20 @@ public:
         return makeNullable(data_type);
     }
 
-     ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const 
DataTypePtr &, size_t) const override
-     {
+    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const 
DataTypePtr &, size_t input_rows_count) const override
+    {
         if (arguments.size() != 2 && arguments.size() != 3)
             throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, 
"Function {} must have 2 or 3 arguments", getName());
-
-        const auto * arg_null_col = 
checkAndGetColumn<ColumnNullable>(arguments[0].column.get());
-        const ColumnArray * array_col;
-        if (!arg_null_col)
-            array_col = 
checkAndGetColumn<ColumnArray>(arguments[0].column.get());
-        else
-            array_col = 
checkAndGetColumn<ColumnArray>(arg_null_col->getNestedColumnPtr().get());
-        if (!array_col)
-            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 
1st argument must be array type", getName());
-
         auto res_col = ColumnString::create();
-        auto null_col = ColumnUInt8::create(array_col->size(), 0);
+        auto null_col = ColumnUInt8::create(input_rows_count, 0);
         PaddedPODArray<UInt8> & null_result = null_col->getData();
-        std::pair<bool, StringRef> delim_p, null_replacement_p;
-        bool return_result = false;
-        auto checkAndGetConstString = [&](const ColumnPtr & col) -> 
std::pair<bool, StringRef>
-        {
-            StringRef res;
-            const auto * str_null_col = 
checkAndGetColumnConstData<ColumnNullable>(col.get());
-            if (str_null_col)
-            {
-                if (str_null_col->isNullAt(0))
-                {
-                    for (size_t i = 0; i < array_col->size(); ++i)
-                    {
-                        res_col->insertDefault();
-                        null_result[i] = 1;
-                    }
-                    return_result = true;
-                    return std::pair<bool, StringRef>(false, res);
-                }
-            }
-            else
-            {
-                const auto * string_col = 
checkAndGetColumnConstData<ColumnString>(col.get());
-                if (!string_col)
-                    return std::pair<bool, StringRef>(false, res);
-                else
-                    return std::pair<bool, StringRef>(true, 
string_col->getDataAt(0));
-            }
-        };
-        delim_p = checkAndGetConstString(arguments[1].column);
-        if (return_result)
+        if (input_rows_count == 0)
             return ColumnNullable::create(std::move(res_col), 
std::move(null_col));
         
-        if (arguments.size() == 3)
-        {
-            null_replacement_p = checkAndGetConstString(arguments[2].column);
-            if (return_result)
-                return ColumnNullable::create(std::move(res_col), 
std::move(null_col));
-        }
+        const ColumnArray * array_col = array_col = 
checkAndGetColumn<ColumnArray>(arguments[0].column.get());;
+        if (!array_col)
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 
1st argument must be array type", getName());
+        
         const ColumnNullable * array_nested_col = 
checkAndGetColumn<ColumnNullable>(&array_col->getData());
         const ColumnString * string_col;
         if (array_nested_col)
@@ -118,57 +77,42 @@ public:
         const ColumnArray::Offsets & array_offsets = array_col->getOffsets();
         const ColumnString::Offsets & string_offsets = 
string_col->getOffsets();
         const ColumnString::Chars & string_data = string_col->getChars();
-        const ColumnNullable * delim_col = 
checkAndGetColumn<ColumnNullable>(arguments[1].column.get());
-        const ColumnNullable * null_replacement_col = arguments.size() == 3 ? 
checkAndGetColumn<ColumnNullable>(arguments[2].column.get()) : nullptr;
+
+        auto extractColumnString = [&](const ColumnPtr & col) -> const 
ColumnString *
+        {
+            const ColumnString * res = nullptr;
+            if (col->isConst())
+            {
+                const ColumnConst * const_col = 
checkAndGetColumn<ColumnConst>(col.get());
+                if (const_col)
+                    res = 
checkAndGetColumn<ColumnString>(const_col->getDataColumnPtr().get());
+            }
+            else
+                res = checkAndGetColumn<ColumnString>(col.get());
+            return res;
+        };
+        bool const_delim_col = arguments[1].column->isConst();
+        bool const_null_replacement_col = false;
+        const ColumnString * delim_col = 
extractColumnString(arguments[1].column);
+        const ColumnString * null_replacement_col = nullptr;
+        if (arguments.size() == 3)
+        {
+            const_null_replacement_col = arguments[2].column->isConst();
+            null_replacement_col = extractColumnString(arguments[2].column);
+        }
         size_t current_offset = 0, array_pos = 0;
         for (size_t i = 0; i < array_col->size(); ++i)
         {
             String res;
-            auto setResultNull = [&]() -> void
+            const StringRef delim = const_delim_col ? delim_col->getDataAt(0) 
: delim_col->getDataAt(i);
+            StringRef null_replacement = StringRef(nullptr, 0);
+            if (null_replacement_col)
             {
-                res_col->insertDefault();
-                null_result[i] = 1;
-                current_offset = array_offsets[i];
-            };
-            auto getDelimiterOrNullReplacement = [&](const std::pair<bool, 
StringRef> & s, const ColumnNullable * col) -> StringRef
-            {
-                if (s.first)
-                    return s.second;
-                else
-                {
-                    if (col->isNullAt(i))
-                        return StringRef(nullptr, 0);
-                    else
-                    {
-                        const ColumnString * col_string = 
checkAndGetColumn<ColumnString>(col->getNestedColumnPtr().get());
-                        return col_string->getDataAt(i);
-                    }
-                }
-            };
-            if (arg_null_col->isNullAt(i))
-            {
-                setResultNull();
-                continue;
-            }
-            const StringRef delim = getDelimiterOrNullReplacement(delim_p, 
delim_col);
-            if (!delim.data)
-            {
-                setResultNull();
-                continue;
+                null_replacement = const_null_replacement_col ? 
null_replacement_col->getDataAt(0) : null_replacement_col->getDataAt(i);
             }
-            StringRef null_replacement;
-            if (arguments.size() == 3)
-            {
-                null_replacement = 
getDelimiterOrNullReplacement(null_replacement_p, null_replacement_col);
-                if (!null_replacement.data)
-                {
-                    setResultNull();
-                    continue;
-                }
-            }
-            
             size_t array_size = array_offsets[i] - current_offset;
             size_t data_pos = array_pos == 0 ? 0 : string_offsets[array_pos - 
1];
+            size_t last_not_null_pos = 0;
             for (size_t j = 0; j < array_size; ++j)
             {
                 if (array_nested_col && array_nested_col->isNullAt(j + 
array_pos))
@@ -179,11 +123,14 @@ public:
                         if (j != array_size - 1)
                             res += delim.toString();
                     }
+                    else if (j == array_size - 1)
+                        res = res.substr(0, last_not_null_pos);
                 }
                 else
                 {
                     const StringRef s(&string_data[data_pos], string_offsets[j 
+ array_pos] - data_pos - 1);
                     res += s.toString();
+                    last_not_null_pos = res.size();
                     if (j != array_size - 1)
                         res += delim.toString();
                 }
@@ -194,7 +141,7 @@ public:
             current_offset = array_offsets[i];
         }
         return ColumnNullable::create(std::move(res_col), std::move(null_col));
-     }
+    }
 };
 
 REGISTER_FUNCTION(SparkArrayJoin)
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp
new file mode 100644
index 0000000000..e43b528231
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <Columns/ColumnString.h>
+#include <Columns/ColumnNullable.h>
+#include <Functions/IFunction.h>
+#include <Functions/FunctionFactory.h>
+#include <DataTypes/DataTypeString.h>
+#include <DataTypes/DataTypesNumber.h>
+
+using namespace DB;
+
+namespace DB
+{
+namespace ErrorCodes
+{
+    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+}
+
+namespace local_engine
+{
+class SparkFunctionArraysOverlap : public IFunction
+{
+public:
+    static constexpr auto name = "sparkArraysOverlap";
+    static FunctionPtr create(ContextPtr) { return 
std::make_shared<SparkFunctionArraysOverlap>(); }
+    SparkFunctionArraysOverlap() = default;
+    ~SparkFunctionArraysOverlap() override = default;
+    bool isSuitableForShortCircuitArgumentsExecution(const 
DataTypesWithConstInfo &) const override { return true; }
+    size_t getNumberOfArguments() const override { return 2; }
+    String getName() const override { return name; }
+    bool useDefaultImplementationForConstants() const override { return true; }
+
+    DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const 
override
+    {
+        auto data_type = std::make_shared<DataTypeUInt8>();
+        return makeNullable(data_type);
+    }
+    
+    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const 
DataTypePtr &, size_t input_rows_count) const override
+    {
+        if (arguments.size() != 2)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, 
"Function {} must have 2 arguments", getName());
+        
+        auto res = ColumnUInt8::create(input_rows_count, 0);
+        auto null_map = ColumnUInt8::create(input_rows_count, 0);
+        PaddedPODArray<UInt8> & res_data = res->getData();
+        PaddedPODArray<UInt8> & null_map_data = null_map->getData();
+        if (input_rows_count == 0)
+            return ColumnNullable::create(std::move(res), std::move(null_map));
+        
+        const ColumnArray * array_col_1 = 
checkAndGetColumn<ColumnArray>(arguments[0].column.get());  
+        const ColumnArray * array_col_2 = 
checkAndGetColumn<ColumnArray>(arguments[1].column.get());
+        if (!array_col_1 || !array_col_2)
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 
1st/2nd argument must be array type", getName());
+
+        const ColumnArray::Offsets & array_offsets_1 = 
array_col_1->getOffsets();
+        const ColumnArray::Offsets & array_offsets_2 = 
array_col_2->getOffsets();
+
+        size_t current_offset_1 = 0, current_offset_2 = 0;
+        size_t array_pos_1 = 0, array_pos_2 = 0;
+        for (size_t i = 0; i < array_col_1->size(); ++i)
+        {
+            size_t array_size_1 = array_offsets_1[i] - current_offset_1;
+            size_t array_size_2 = array_offsets_2[i] - current_offset_2;
+            auto executeCompare = [&](const IColumn & col1, const IColumn & 
col2, const ColumnUInt8 * null_map1, const ColumnUInt8 * null_map2) -> void
+            {   
+                for (size_t j = 0; j < array_size_1 && !res_data[i]; ++j)
+                {
+                    for (size_t k = 0; k < array_size_2; ++k)
+                    {
+                        if ((null_map1 && null_map1->getElement(j + 
array_pos_1)) || (null_map2 && null_map2->getElement(k + array_pos_2)))
+                        {
+                            null_map_data[i] = 1;
+                        }
+                        else if (col1.compareAt(j + array_pos_1, k + 
array_pos_2, col2, -1) == 0)
+                        {
+                            res_data[i] = 1;
+                            null_map_data[i] = 0;
+                            break;  
+                        }
+                    }
+                }
+            };
+            if (array_col_1->getData().isNullable() || 
array_col_2->getData().isNullable())
+            {
+                if (array_col_1->getData().isNullable() && 
array_col_2->getData().isNullable())
+                {
+                    const ColumnNullable * array_null_col_1 = 
assert_cast<const ColumnNullable *>(&array_col_1->getData());
+                    const ColumnNullable * array_null_col_2 = 
assert_cast<const ColumnNullable *>(&array_col_2->getData());
+                    executeCompare(array_null_col_1->getNestedColumn(), 
array_null_col_2->getNestedColumn(),
+                        &array_null_col_1->getNullMapColumn(), 
&array_null_col_2->getNullMapColumn());
+                }
+                else if (array_col_1->getData().isNullable())
+                {
+                    const ColumnNullable * array_null_col_1 = 
assert_cast<const ColumnNullable *>(&array_col_1->getData());
+                    executeCompare(array_null_col_1->getNestedColumn(), 
array_col_2->getData(), &array_null_col_1->getNullMapColumn(), nullptr);
+                }
+                else if (array_col_2->getData().isNullable())
+                {
+                    const ColumnNullable * array_null_col_2 = 
assert_cast<const ColumnNullable *>(&array_col_2->getData());
+                    executeCompare(array_col_1->getData(), 
array_null_col_2->getNestedColumn(), nullptr, 
&array_null_col_2->getNullMapColumn());
+                }
+            }
+            else if (array_col_1->getData().getDataType() == 
array_col_2->getData().getDataType())
+            {
+                executeCompare(array_col_1->getData(), array_col_2->getData(), 
nullptr, nullptr);
+            }
+
+            current_offset_1 = array_offsets_1[i];
+            current_offset_2 = array_offsets_2[i];
+            array_pos_1 += array_size_1;
+            array_pos_2 += array_size_2;
+        }
+        return ColumnNullable::create(std::move(res), std::move(null_map));
+    }
+};
+
+REGISTER_FUNCTION(SparkArraysOverlap)
+{
+    factory.registerFunction<SparkFunctionArraysOverlap>();
+}
+
+}
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
index 88e5d7ea8a..695166a894 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
@@ -166,6 +166,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Shuffle, shuffle, arrayShuffle);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Range, range, range);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Flatten, flatten, sparkArrayFlatten);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArrayJoin, array_join, sparkArrayJoin);
+REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysOverlap, arrays_overlap, 
sparkArraysOverlap);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysZip, arrays_zip, 
arrayZipUnaligned);
 
 // map functions


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to