This is an automated email from the ASF dual-hosted git repository.

taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new c40735bcb3 [GLUTEN-7780][CH] Fix split diff (#7781)
c40735bcb3 is described below

commit c40735bcb34b7f7fbd9b8c5930eb947011cb611a
Author: 李扬 <[email protected]>
AuthorDate: Mon Nov 4 10:16:41 2024 +0800

    [GLUTEN-7780][CH] Fix split diff (#7781)
    
    * fix split diff
    
    * fix code style
    
    * fix code style
---
 .../execution/GlutenClickHouseTPCHSuite.scala      |   7 +
 .../Functions/SparkFunctionSplitByRegexp.cpp       | 239 +++++++++++++++++++++
 .../Parser/scalar_function_parser/split.cpp        |  12 +-
 .../utils/clickhouse/ClickHouseTestSettings.scala  |   4 -
 .../utils/clickhouse/ClickHouseTestSettings.scala  |   4 -
 .../utils/clickhouse/ClickHouseTestSettings.scala  |   4 -
 .../utils/clickhouse/ClickHouseTestSettings.scala  |   4 -
 7 files changed, 252 insertions(+), 22 deletions(-)

diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
index a56f45d1ba..8dc178e46c 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
@@ -563,5 +563,12 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
     compareResultsAgainstVanillaSpark(sql, true, { _ => })
     spark.sql("drop table t1")
   }
+
+  test("GLUTEN-7780 fix split diff") {
+    // Three spellings of the separator literal in SQL: '\|', '\\|' and '|'.
+    val separators = Seq("""'\|'""", """'\\|'""", "'|'")
+    val sql = separators
+      .map(sep => s"split(concat('a|b|c', cast(id as string)), $sep)")
+      .mkString("select ", ", ", " from range(10)")
+    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+  }
 }
 // scalastyle:off line.size.limit
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionSplitByRegexp.cpp b/cpp-ch/local-engine/Functions/SparkFunctionSplitByRegexp.cpp
new file mode 100644
index 0000000000..66f37c6203
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionSplitByRegexp.cpp
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <Columns/ColumnConst.h>
+#include <DataTypes/IDataType.h>
+#include <Functions/FunctionFactory.h>
+#include <Functions/FunctionHelpers.h>
+#include <Functions/FunctionTokens.h>
+#include <Functions/Regexps.h>
+#include <Common/StringUtils.h>
+#include <base/map.h>
+#include <Common/assert_cast.h>
+
+
+namespace DB
+{
+
+/// Error codes thrown from this file; declared extern here, defined in another translation unit.
+namespace ErrorCodes
+{
+    extern const int ILLEGAL_COLUMN;
+}
+
+
+/** Spark-compatible regexp split, used to implement Spark's `split(str, regex[, limit])`.
+  *
+  * splitByRegexpSpark(regexp, s[, max_substrings])
+  */
+namespace
+{
+
+using Pos = const char *;
+
+/// Length in bytes of the UTF-8 sequence whose lead byte is `*p`, clamped so it never
+/// reaches past `end`. Requires `p < end`. Continuation bytes and other invalid lead bytes
+/// count as a single byte, so callers always make progress even on malformed input.
+size_t utf8CharLength(Pos p, Pos end)
+{
+    const auto lead = static_cast<unsigned char>(*p);
+    size_t length = 1;
+    if (lead >= 0xC0 && lead <= 0xDF)
+        length = 2;
+    else if (lead >= 0xE0 && lead <= 0xEF)
+        length = 3;
+    else if (lead >= 0xF0 && lead <= 0xF7)
+        length = 4;
+
+    const auto remaining = static_cast<size_t>(end - p);
+    return length < remaining ? length : remaining;
+}
+
+/// Tokenizer plugged into FunctionTokens to implement Spark's `split(str, regex[, limit])`.
+///
+/// Behaviour aimed at Spark compatibility:
+///  - an empty pattern splits the string into UTF-8 characters (never into raw bytes), and
+///    an empty input string yields a single empty token, as Spark does for `split('', '')`;
+///  - an empty regex match advances by one UTF-8 character rather than one byte;
+///  - an empty match found further ahead (e.g. `$`) ends the current token at that match.
+class SparkSplitByRegexpImpl
+{
+private:
+    /// Compiled separator; stays null when the pattern is the empty string.
+    Regexps::RegexpPtr re;
+    OptimizedRegularExpression::MatchVec matches;
+
+    /// Start of the next token. nullptr (both paths) or a value past `end` (regexp path)
+    /// means the current string has been fully consumed.
+    Pos pos = nullptr;
+    Pos end = nullptr;
+
+    std::optional<size_t> max_splits;
+    size_t splits = 0;
+    bool max_substrings_includes_remaining_string = false;
+
+public:
+    static constexpr auto name = "splitByRegexpSpark";
+
+    static bool isVariadic() { return true; }
+    static size_t getNumberOfArguments() { return 0; }
+
+    static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {0, 2}; }
+
+    static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
+    {
+        checkArgumentsWithSeparatorAndOptionalMaxSubstrings(func, arguments);
+    }
+
+    static constexpr auto strings_argument_position = 1uz;
+
+    /// Reads the constant pattern (argument 0) and the optional limit (argument 2).
+    /// Throws ILLEGAL_COLUMN when the pattern is not a constant string.
+    void init(const ColumnsWithTypeAndName & arguments, bool max_substrings_includes_remaining_string_)
+    {
+        const ColumnConst * col = checkAndGetColumnConstStringOrFixedString(arguments[0].column.get());
+
+        if (!col)
+            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function {}. "
+                            "Must be constant string.", arguments[0].column->getName(), name);
+
+        const String pattern = col->getValue<String>();
+        if (!pattern.empty())
+            re = std::make_shared<OptimizedRegularExpression>(Regexps::createRegexp<false, false, false>(pattern));
+
+        max_substrings_includes_remaining_string = max_substrings_includes_remaining_string_;
+        max_splits = extractMaxSplits(arguments, 2);
+    }
+
+    /// Called for each next string.
+    void set(Pos pos_, Pos end_)
+    {
+        pos = pos_;
+        end = end_;
+        splits = 0;
+    }
+
+    /// Get the next token, if any, or return false.
+    bool get(Pos & token_begin, Pos & token_end)
+    {
+        if (!re)
+            return getByCharacter(token_begin, token_end);
+        return getByRegexp(token_begin, token_end);
+    }
+
+private:
+    /// Applies the max_substrings limit. Returns true when the caller must stop: `result`
+    /// then says whether a final token (the whole remainder) was produced.
+    bool applyLimit(Pos & token_end, bool & result)
+    {
+        if (!max_splits)
+            return false;
+
+        if (max_substrings_includes_remaining_string)
+        {
+            if (splits == *max_splits - 1)
+            {
+                /// Last allowed token: it swallows the rest of the string.
+                token_end = end;
+                pos = nullptr;
+                result = true;
+                return true;
+            }
+        }
+        else if (splits == *max_splits)
+        {
+            result = false;
+            return true;
+        }
+        return false;
+    }
+
+    /// Empty pattern: one token per UTF-8 character.
+    bool getByCharacter(Pos & token_begin, Pos & token_end)
+    {
+        if (!pos)
+            return false;
+
+        token_begin = pos;
+
+        if (pos == end)
+        {
+            /// Only reachable for an empty input string (after the last character `pos`
+            /// is cleared below), which Spark splits into a single empty token.
+            token_end = end;
+            pos = nullptr;
+            return true;
+        }
+
+        bool result = false;
+        if (applyLimit(token_end, result))
+            return result;
+
+        pos += utf8CharLength(pos, end);
+        token_end = pos;
+        if (pos == end)
+            pos = nullptr;
+        ++splits;
+        return true;
+    }
+
+    /// Non-empty pattern: tokens are separated by regexp matches.
+    bool getByRegexp(Pos & token_begin, Pos & token_end)
+    {
+        if (!pos || pos > end)
+            return false;
+
+        token_begin = pos;
+
+        bool result = false;
+        if (applyLimit(token_end, result))
+            return result;
+
+        if (!re->match(pos, end - pos, matches))
+        {
+            /// No more separators: the remainder is the last token.
+            token_end = end;
+            pos = end + 1;
+        }
+        else if (!matches[0].length && !matches[0].offset)
+        {
+            /// Empty match right at `pos`: emit one whole UTF-8 character so the
+            /// loop always advances (and never splits a multi-byte sequence).
+            if (pos == end)
+            {
+                token_end = end;
+                pos = end + 1;
+            }
+            else
+            {
+                pos += utf8CharLength(pos, end);
+                token_end = pos;
+            }
+            ++splits;
+        }
+        else
+        {
+            /// Regular separator. For an empty match further ahead (e.g. `$`) the token
+            /// ends at the match and the next call resumes there; offset > 0 guarantees progress.
+            token_end = pos + matches[0].offset;
+            pos = token_end + matches[0].length;
+            ++splits;
+        }
+
+        return true;
+    }
+};
+
+using SparkFunctionSplitByRegexp = FunctionTokens<SparkSplitByRegexpImpl>;
+
+/// Resolver for `splitByRegexpSpark`: when the pattern is a single character that the regexp
+/// analyzer reports as a trivial literal, delegate to the cheaper `splitByChar`; otherwise
+/// build the Spark-compatible regexp splitter.
+class SparkSplitByRegexpOverloadResolver : public IFunctionOverloadResolver
+{
+public:
+    static constexpr auto name = "splitByRegexpSpark";
+
+    static FunctionOverloadResolverPtr create(ContextPtr context)
+    {
+        return std::make_unique<SparkSplitByRegexpOverloadResolver>(context);
+    }
+
+    explicit SparkSplitByRegexpOverloadResolver(ContextPtr context_)
+        : context(context_), split_by_regexp(SparkFunctionSplitByRegexp::create(context))
+    {
+    }
+
+    String getName() const override { return name; }
+    size_t getNumberOfArguments() const override { return SparkSplitByRegexpImpl::getNumberOfArguments(); }
+    bool isVariadic() const override { return SparkSplitByRegexpImpl::isVariadic(); }
+
+    FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override
+    {
+        if (!patternIsTrivialChar(arguments))
+        {
+            return std::make_unique<FunctionToFunctionBaseAdaptor>(
+                split_by_regexp,
+                collections::map<DataTypes>(arguments, [](const auto & argument) { return argument.type; }),
+                return_type);
+        }
+        return FunctionFactory::instance().getImpl("splitByChar", context)->build(arguments);
+    }
+
+    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
+    {
+        return split_by_regexp->getReturnTypeImpl(arguments);
+    }
+
+private:
+    /// True when argument 0 is a constant one-character pattern whose regexp analysis
+    /// is trivial and whose required substring is the pattern itself.
+    bool patternIsTrivialChar(const ColumnsWithTypeAndName & arguments) const
+    {
+        const auto * separator_column = arguments[0].column.get();
+        if (!separator_column)
+            return false;
+
+        const ColumnConst * const_separator = checkAndGetColumnConstStringOrFixedString(separator_column);
+        if (!const_separator)
+            return false;
+
+        const String pattern = const_separator->getValue<String>();
+        if (pattern.size() != 1)
+            return false;
+
+        OptimizedRegularExpression regexp = Regexps::createRegexp<false, false, false>(pattern);
+        std::string required_substring;
+        bool is_trivial;
+        bool required_substring_is_prefix;
+        regexp.getAnalyzeResult(required_substring, is_trivial, required_substring_is_prefix);
+        return is_trivial && required_substring == pattern;
+    }
+
+    ContextPtr context;
+    FunctionPtr split_by_regexp;
+};
+}
+
+/// Registers SparkSplitByRegexpOverloadResolver (name `splitByRegexpSpark`) with the function factory.
+/// Gluten's Spark `split` parser (Parser/scalar_function_parser/split.cpp) maps to this name.
+REGISTER_FUNCTION(SparkSplitByRegexp)
+{
+    factory.registerFunction<SparkSplitByRegexpOverloadResolver>();
+}
+
+}
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/split.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/split.cpp
index ed17c27ead..3ffd64decb 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/split.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/split.cpp
@@ -19,14 +19,14 @@
 
 namespace local_engine
 {
-class SparkFunctionSplitParser : public FunctionParser
+class FunctionSplitParser : public FunctionParser
 {
 public:
-    SparkFunctionSplitParser(ParserContextPtr parser_context_) : 
FunctionParser(parser_context_) {}
-    ~SparkFunctionSplitParser() override = default;
+    FunctionSplitParser(ParserContextPtr parser_context_) : 
FunctionParser(parser_context_) {}
+    ~FunctionSplitParser() override = default;
     static constexpr auto name = "split";
     String getName() const override { return name; }
-    String getCHFunctionName(const substrait::Expression_ScalarFunction &) 
const override { return "splitByRegexp"; }
+    String getCHFunctionName(const substrait::Expression_ScalarFunction &) 
const override { return "splitByRegexpSpark"; }
 
     const DB::ActionsDAG::Node * parse(const 
substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & 
actions_dag) const override
     {
@@ -35,7 +35,7 @@ public:
         for (const auto & arg : args)
             parsed_args.emplace_back(parseExpression(actions_dag, 
arg.value()));
         /// In Spark: split(str, regex [, limit] )
-        /// In CH: splitByRegexp(regexp, str [, limit])
+        /// In CH: splitByRegexpSpark(regexp, str [, limit])
         if (parsed_args.size() >= 2)
             std::swap(parsed_args[0], parsed_args[1]);
         auto ch_function_name = getCHFunctionName(substrait_func);
@@ -43,6 +43,6 @@ public:
         return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag);
     }
 };
-static FunctionParserRegister<SparkFunctionSplitParser> register_split;
+static FunctionParserRegister<FunctionSplitParser> register_split;
 }
 
diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 27e26606f6..50110f15d4 100644
--- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -844,8 +844,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("SPARK-32110: compare special double/float values in struct")
   enableSuite[GlutenRandomSuite].exclude("random").exclude("SPARK-9127 codegen 
with long seed")
   enableSuite[GlutenRegexpExpressionsSuite]
-    .exclude("LIKE ALL")
-    .exclude("LIKE ANY")
     .exclude("LIKE Pattern")
     .exclude("LIKE Pattern ESCAPE '/'")
     .exclude("LIKE Pattern ESCAPE '#'")
@@ -854,8 +852,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("RegexReplace")
     .exclude("RegexExtract")
     .exclude("RegexExtractAll")
-    .exclude("SPLIT")
-    .exclude("SPARK-34814: LikeSimplification should handle NULL")
   enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix")
   enableSuite[GlutenStringExpressionsSuite]
     .exclude("StringComparison")
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index da950e2fc1..9b3b090e32 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -817,8 +817,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("SPARK-32110: compare special double/float values in struct")
   enableSuite[GlutenRandomSuite].exclude("random").exclude("SPARK-9127 codegen 
with long seed")
   enableSuite[GlutenRegexpExpressionsSuite]
-    .exclude("LIKE ALL")
-    .exclude("LIKE ANY")
     .exclude("LIKE Pattern")
     .exclude("LIKE Pattern ESCAPE '/'")
     .exclude("LIKE Pattern ESCAPE '#'")
@@ -827,8 +825,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("RegexReplace")
     .exclude("RegexExtract")
     .exclude("RegexExtractAll")
-    .exclude("SPLIT")
-    .exclude("SPARK - 34814: LikeSimplification should handleNULL")
   enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix")
   enableSuite[GlutenStringExpressionsSuite]
     .exclude("StringComparison")
diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index ac08fc5a80..e91f1495fb 100644
--- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -740,8 +740,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("SPARK-32110: compare special double/float values in struct")
   enableSuite[GlutenRandomSuite].exclude("random").exclude("SPARK-9127 codegen 
with long seed")
   enableSuite[GlutenRegexpExpressionsSuite]
-    .exclude("LIKE ALL")
-    .exclude("LIKE ANY")
     .exclude("LIKE Pattern")
     .exclude("LIKE Pattern ESCAPE '/'")
     .exclude("LIKE Pattern ESCAPE '#'")
@@ -750,8 +748,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("RegexReplace")
     .exclude("RegexExtract")
     .exclude("RegexExtractAll")
-    .exclude("SPLIT")
-    .exclude("SPARK - 34814: LikeSimplification should handleNULL")
   enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix")
   enableSuite[GlutenStringExpressionsSuite]
     .exclude("StringComparison")
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 9e4c81081d..f0637839a7 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -740,8 +740,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("SPARK-32110: compare special double/float values in struct")
   enableSuite[GlutenRandomSuite].exclude("random").exclude("SPARK-9127 codegen 
with long seed")
   enableSuite[GlutenRegexpExpressionsSuite]
-    .exclude("LIKE ALL")
-    .exclude("LIKE ANY")
     .exclude("LIKE Pattern")
     .exclude("LIKE Pattern ESCAPE '/'")
     .exclude("LIKE Pattern ESCAPE '#'")
@@ -750,8 +748,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("RegexReplace")
     .exclude("RegexExtract")
     .exclude("RegexExtractAll")
-    .exclude("SPLIT")
-    .exclude("SPARK - 34814: LikeSimplification should handleNULL")
   enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix")
   enableSuite[GlutenStringExpressionsSuite]
     .exclude("StringComparison")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to