This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new bc1dba2d94 support delimiter in regular expression for str_to_map 
(#8998)
bc1dba2d94 is described below

commit bc1dba2d9409d7f61f89e36fb298a9551f51f6d6
Author: lgbo <[email protected]>
AuthorDate: Mon Mar 17 19:31:49 2025 +0800

    support delimiter in regular expression for str_to_map (#8998)
---
 .../execution/GlutenFunctionValidateSuite.scala    |  13 +
 .../Functions/SparkFunctionStrToMap.cpp            | 371 +++++++++++++++++----
 2 files changed, 312 insertions(+), 72 deletions(-)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index 22693e7c7e..80edd15783 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -507,6 +507,19 @@ class GlutenFunctionValidateSuite extends 
GlutenClickHouseWholeStageTransformerS
     runQueryAndCompare(sql1)(checkGlutenOperatorMatch[ProjectExecTransformer])
   }
 
+  test("test str2map, regular expression") {
+    // One projection covering empty delimiters, the default delimiters,
+    // plain-text delimiters and a regex pair delimiter (an escaped '|').
+    val query =
+      """
+        |select str_to_map('ab', '', ''),
+        | str_to_map('a:b,c:d'),
+        | str_to_map('ab', '', ':'),
+        | str_to_map('a:,c:d,e', ',', ''),
+        | str_to_map('a,b', ',', ''),
+        | str_to_map('a:c|b:c', '\\|', ':')
+        |""".stripMargin
+    runQueryAndCompare(query, true, false)(checkGlutenOperatorMatch[ProjectExecTransformer])
+  }
+
   test("test parse_url") {
     val sql1 =
       """
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionStrToMap.cpp 
b/cpp-ch/local-engine/Functions/SparkFunctionStrToMap.cpp
index 77a9bce7d5..2fa8a6ec58 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionStrToMap.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionStrToMap.cpp
@@ -14,20 +14,27 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <memory>
 #include <type_traits>
 #include <Columns/ColumnConst.h>
 #include <Columns/ColumnNullable.h>
 #include <Columns/ColumnString.h>
+#include <Columns/ColumnVector.h>
 #include <Core/Field.h>
 #include <DataTypes/DataTypeArray.h>
 #include <DataTypes/DataTypeMap.h>
 #include <DataTypes/DataTypeNullable.h>
 #include <DataTypes/DataTypeString.h>
 #include <DataTypes/DataTypesNumber.h>
+#include <DataTypes/IDataType.h>
 #include <Functions/FunctionFactory.h>
 #include <Functions/FunctionHelpers.h>
 #include <Functions/IFunction.h>
+#include <Functions/IFunctionAdaptors.h>
+#include <Functions/Regexps.h>
+#include <base/map.h>
 #include <Common/Exception.h>
+#include <Common/OptimizedRegularExpression.h>
 
 #include <Poco/Logger.h>
 #include <Common/logger_useful.h>
@@ -36,47 +43,168 @@ namespace DB
 {
 namespace ErrorCodes
 {
-    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
-    extern const int ILLEGAL_COLUMN;
+extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+extern const int ILLEGAL_COLUMN;
 }
 }
 
 namespace local_engine
 {
-class SparkFunctionStrToMap : public DB::IFunction
+
+class TrivialCharSplitter
 {
 public:
     using Pos = const char *;
-    static constexpr auto name = "spark_str_to_map";
-    static DB::FunctionPtr create(const DB::ContextPtr) { return 
std::make_shared<SparkFunctionStrToMap>(); }
+    // Binds to `delimiter_` by reference (the member is a `const String &`), so the
+    // referenced string must outlive this splitter. Taking a const reference matches
+    // RegularSplitter; `explicit` prevents accidental String -> splitter conversions.
+    explicit TrivialCharSplitter(const String & delimiter_) : delimiter(delimiter_) { }
 
-    String getName() const override { return name; }
+    // Points the splitter at the range [str_begin_, str_end_), rewinds the cursor to
+    // its start and clears the bounds of any previously found delimiter.
+    void reset(Pos str_begin_, Pos str_end_)
+    {
+        delimiter_begin = delimiter_end = nullptr;
+        str_begin = str_cursor = str_begin_;
+        str_end = str_end_;
+    }
 
-    bool isVariadic() const override { return true; }
-    size_t getNumberOfArguments() const override { return 3; }
+    // Bounds of the delimiter found by the most recent next() call. Both are nullptr
+    // right after reset() and when the last token ran to the end of the input because
+    // no further delimiter was found.
+    Pos getDelimiterBegin() const { return delimiter_begin; }
+    Pos getDelimiterEnd() const { return delimiter_end; }
 
-    bool useDefaultImplementationForConstants() const override { return true; }
-    DB::ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { 
return {1, 2}; }
+    // Yields the next token as [token_begin, token_end) and advances the cursor past
+    // the delimiter that terminated it. Returns false once the cursor has reached the
+    // end of the input. The delimiter bounds of this step are available through
+    // getDelimiterBegin()/getDelimiterEnd(); they are nullptr when the token was
+    // terminated by the end of the input instead of a delimiter.
+    bool next(Pos & token_begin, Pos & token_end)
+    {
+        if (str_cursor >= str_end)
+            return false;
+        token_begin = str_cursor;
+        // memmem() reports a match at the haystack start for an empty needle, which
+        // would leave the cursor where it is and make the caller loop forever. Treat
+        // an empty delimiter as "not found" so every call makes progress.
+        Pos next_token_pos = delimiter.empty()
+            ? nullptr
+            : static_cast<Pos>(memmem(str_cursor, str_end - str_cursor, delimiter.c_str(), delimiter.size()));
+        if (!next_token_pos)
+        {
+            // If delimiter is not found, return the remaining string.
+            token_end = str_end;
+            str_cursor = str_end;
+            delimiter_begin = nullptr;
+            delimiter_end = nullptr;
+        }
+        else
+        {
+            delimiter_begin = next_token_pos;
+            token_end = next_token_pos;
+            str_cursor = next_token_pos + delimiter.size();
+            delimiter_end = str_cursor;
+        }
+        return true;
+    }
 
-    bool isSuitableForShortCircuitArgumentsExecution(const 
DB::DataTypesWithConstInfo & /*arguments*/) const override { return true; }
+private:
+    const String & delimiter;
+    Pos str_begin;
+    Pos str_end;
+    Pos str_cursor;
+    Pos delimiter_begin;
+    Pos delimiter_end;
+};
 
-    DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName & 
arguments) const override
+struct RegularSplitter
+{
+public:
+    using Pos = const char *;
+    // Binds to `delimiter_` by reference (the member is a `const String &`), so the
+    // referenced string must outlive this splitter. `explicit` prevents accidental
+    // String -> splitter conversions.
+    explicit RegularSplitter(const String & delimiter_) : delimiter(delimiter_)
+    {
+        // An empty pattern leaves `re` null; next() then emits one character per token.
+        if (!delimiter.empty())
+            re = std::make_shared<OptimizedRegularExpression>(DB::Regexps::createRegexp<false, false, false>(delimiter));
+    }
+
+    // Points the splitter at the range [str_begin_, str_end_), rewinds the cursor to
+    // its start and clears the bounds of any previously matched delimiter.
+    void reset(Pos str_begin_, Pos str_end_)
+    {
+        delimiter_begin = delimiter_end = nullptr;
+        str_begin = str_cursor = str_begin_;
+        str_end = str_end_;
+    }
+
+    // Bounds of the delimiter produced by the most recent next() call. Both are nullptr
+    // right after reset() and when the regex found no further match; with an empty
+    // pattern both point just past the single character that was returned.
+    Pos getDelimiterBegin() const { return delimiter_begin; }
+    Pos getDelimiterEnd() const { return delimiter_end; }
+
+    // Yields the next token as [token_begin, token_end) and advances the cursor past
+    // the regex match that terminated it. Returns false once the cursor has reached
+    // the end of the input. Every successful call moves the cursor forward.
+    bool next(Pos & token_begin, Pos & token_end)
     {
-        if (arguments.size() != 3)
+        if (str_cursor >= str_end)
+            return false;
+        // If delimiter is empty, return each character as a token.
+        if (!re)
         {
-            throw DB::Exception(
-                DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
-                "Function {} requires 3 arguments, passed {}",
-                getName(),
-                arguments.size());
+            token_begin = str_cursor;
+            ++str_cursor;
+            delimiter_begin = str_cursor;
+            delimiter_end = str_cursor;
+            token_end = str_cursor;
         }
-
-        if (!DB::WhichDataType(DB::removeNullable(arguments[0].type)).isString()
-            || !DB::WhichDataType(DB::removeNullable(arguments[1].type)).isString()
-            || !DB::WhichDataType(DB::removeNullable(arguments[2].type)).isString())
+        else
         {
-            throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "All arguments for function {} must be String", getName());
+            if (!re->match(str_cursor, str_end - str_cursor, matches))
+            {
+                // No further delimiter: the rest of the input is the last token.
+                token_begin = str_cursor;
+                token_end = str_end;
+                str_cursor = str_end;
+                delimiter_begin = nullptr;
+                delimiter_end = nullptr;
+                return true;
+            }
+            if (matches[0].offset == 0 && matches[0].length == 0)
+            {
+                // A pattern such as `x*` can match zero characters right at the cursor.
+                // Using that match would leave the cursor in place and the caller would
+                // loop forever, so handle it like the empty pattern: emit one character.
+                token_begin = str_cursor;
+                ++str_cursor;
+                delimiter_begin = str_cursor;
+                delimiter_end = str_cursor;
+                token_end = str_cursor;
+                return true;
+            }
+            token_begin = str_cursor;
+            token_end = str_cursor + matches[0].offset;
+            delimiter_begin = token_end;
+            str_cursor = token_end + matches[0].length;
+            delimiter_end = str_cursor;
         }
+        return true;
+    }
+
+private:
+    const String & delimiter;
+    DB::Regexps::RegexpPtr re;
+    OptimizedRegularExpression::MatchVec matches;
+    Pos str_begin;
+    Pos str_end;
+    Pos str_cursor;
+    Pos delimiter_begin;
+    Pos delimiter_end;
+};
+
+
+// Column predicate: true only when `col` is a constant column.
+static bool isConstColumn(const DB::IColumn & col) { return col.isConst(); }
+
+template <typename PairGenerator, typename KVGenerator>
+class SparkFunctionStrToMap : public DB::IFunction
+{
+public:
+    using Pos = const char *;
+    static constexpr auto name = "spark_str_to_map";
+    static DB::FunctionPtr create(const DB::ContextPtr) { return 
std::make_shared<SparkFunctionStrToMap<PairGenerator, KVGenerator>>(); }
+    String getName() const override { return name; }
+    bool isVariadic() const override { return false; }
+    size_t getNumberOfArguments() const override { return 3; }
+    bool useDefaultImplementationForConstants() const override { return true; }
+    DB::ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { 
return {1, 2}; }
+    bool isSuitableForShortCircuitArgumentsExecution(const 
DB::DataTypesWithConstInfo & /*arguments*/) const override { return true; }
+    DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName & 
arguments) const override
+    {
+        DB::FunctionArgumentDescriptors mandatory_args{
+            {"string", 
static_cast<DB::FunctionArgumentDescriptor::TypeValidator>(&DB::isStringOrFixedString),
 nullptr, "String"},
+            {"pair_delimiter",
+             
static_cast<DB::FunctionArgumentDescriptor::TypeValidator>(&DB::isStringOrFixedString),
+             &isConstColumn,
+             "String"},
+            {"key_value_delimiter",
+             
static_cast<DB::FunctionArgumentDescriptor::TypeValidator>(&DB::isStringOrFixedString),
+             &isConstColumn,
+             "String"},
+        };
+        validateFunctionArguments(*this, arguments, mandatory_args);
+
         auto map_typ = std::make_shared<DB::DataTypeMap>(
             std::make_shared<DB::DataTypeString>(), 
makeNullable(std::make_shared<DB::DataTypeString>()));
         if (arguments[0].type->isNullable())
@@ -88,76 +216,175 @@ public:
     DB::ColumnPtr executeImpl(
         const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t /*input_rows_count*/) const override
     {
-        auto map_col = result_type->createColumn();
-        auto pair_delim = (*arguments[1].column)[0].safeGet<String>();
-        auto pair_delim_len = pair_delim.size();
-        auto kv_delim = (*arguments[2].column)[0].safeGet<String>();
-        auto kv_delim_len = kv_delim.size();
-        const DB::IColumn * arg0 = arguments[0].column.get();
-        bool is_nullable = false;
-        if (arg0->isNullable())
+        // Splits every input string into pairs with PairGenerator, then each pair into
+        // key and value with KVGenerator, and appends one map per row to the result.
+        // The generators keep references to these two locals, so they must stay alive
+        // for the whole function.
+        String pair_delim = (*arguments[1].column)[0].safeGet<String>();
+        String kv_delim = (*arguments[2].column)[0].safeGet<String>();
+        PairGenerator pair_generator(pair_delim);
+        KVGenerator kv_generator(kv_delim);
+
+        auto col_map = result_type->createColumn();
+
+        const DB::ColumnString * col_str = nullptr;
+        const DB::ColumnUInt8 * null_map = nullptr;
+        if (arguments[0].column->isNullable())
         {
-            arg0 = DB::checkAndGetColumn<DB::ColumnNullable>(arg0);
-            is_nullable = true;
+            const auto * col_null = DB::checkAndGetColumn<DB::ColumnNullable>(arguments[0].column.get());
+            col_str = DB::checkAndGetColumn<DB::ColumnString>(col_null->getNestedColumnPtr().get());
+            null_map = &(col_null->getNullMapColumn());
         }
-        const auto * str_col = DB::checkAndGetColumn<DB::ColumnString>(arguments[0].column.get());
-        if (!str_col) [[unlikely]]
+        else
         {
-            throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "argument 0 for function {} must be String", getName());
+            col_str = DB::checkAndGetColumn<DB::ColumnString>(arguments[0].column.get());
         }
-        const DB::ColumnString::Chars & str_vec = str_col->getChars();
-        const DB::ColumnString::Offsets & str_offsets = str_col->getOffsets();
-        map_col->reserve(str_offsets.size());
+        // The argument validator also accepts FixedString, for which the casts above
+        // yield nullptr; fail with a clear error instead of dereferencing it.
+        if (!col_str) [[unlikely]]
+            throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "argument 0 for function {} must be String", getName());
+
+        const auto & strs = col_str->getChars();
+        const auto & offsets = col_str->getOffsets();
 
         DB::ColumnString::Offset prev_offset = 0;
-        for (size_t i = 0, n = str_offsets.size(); i < n; ++i)
+
+        for (size_t i = 0, n = offsets.size(); i < n; ++i)
         {
-            if (is_nullable && str_col->isNullAt(i))
-            {
-                map_col->insertDefault();
-            }
+            // Check the null flag of the current row `i` (not the row count `n`, which
+            // is one past the last entry of the null map).
+            if (null_map && null_map->getData()[i] != 0)
+                col_map->insertDefault();
             else
             {
                 DB::Map map;
-                Pos pair_begin = reinterpret_cast<const char *>(&str_vec[prev_offset]);
-                Pos str_end = reinterpret_cast<const char *>(&str_vec[str_offsets[i]]);
-                while (pair_begin < str_end)
+                Pos str_begin = reinterpret_cast<Pos>(&strs[prev_offset]);
+                Pos str_end = reinterpret_cast<Pos>(&strs[offsets[i]]);
+                LOG_TRACE(
+                    getLogger("SparkFunctionStrToMap"),
+                    "str_begin: {}, str_end: {}. {}",
+                    0,
+                    str_end - str_begin,
+                    std::string_view(str_begin, str_end - str_begin));
+                pair_generator.reset(str_begin, str_end);
+                Pos pair_begin;
+                Pos pair_end;
+                while (pair_generator.next(pair_begin, pair_end))
                 {
-                    // Get next pair.
-                    auto next_pair_begin
-                        = static_cast<const char *>(memmem(pair_begin, str_end - pair_begin, pair_delim.c_str(), pair_delim_len));
-                    if (!next_pair_begin) [[unlikely]]
-                        next_pair_begin = str_end - 1;
-                    Pos value_begin
-                        = static_cast<const char *>(memmem(pair_begin, next_pair_begin - pair_begin, kv_delim.c_str(), kv_delim_len));
-                    DB::Field key;
-                    DB::Field value;
-                    if (!value_begin)
+                    LOG_TRACE(
+                        getLogger("SparkFunctionStrToMap"),
+                        "pair_begin: {}, pair_end: {}, {}",
+                        pair_begin - str_begin,
+                        pair_end - str_begin,
+                        std::string_view(pair_begin, pair_end - pair_begin));
+                    kv_generator.reset(pair_begin, pair_end);
+                    Pos key_begin;
+                    Pos key_end;
+                    if (kv_generator.next(key_begin, key_end))
                     {
-                        key = std::string_view(pair_begin, next_pair_begin - pair_begin);
-                        value = DB::Null();
+                        DB::Tuple tuple(2);
+                        size_t key_len = key_end - key_begin;
+                        // A key that reaches str_end is shortened by one byte.
+                        tuple[0] = key_end == str_end ? std::string_view(key_begin, key_len - 1) : std::string_view(key_begin, key_len);
+                        auto delimiter_begin = kv_generator.getDelimiterBegin();
+                        auto delimiter_end = kv_generator.getDelimiterEnd();
+                        LOG_TRACE(
+                            getLogger("SparkFunctionStrToMap"),
+                            "key_begin: {}, key_end: {}, delim_begin: {}, delim_end: {}, key:{}",
+                            key_begin - str_begin,
+                            key_end - str_begin,
+                            delimiter_begin - str_begin,
+                            delimiter_end - str_begin,
+                            std::string_view(key_begin, key_end - key_begin));
+                        if (delimiter_begin && delimiter_begin != str_end)
+                        {
+                            // Everything after the first key/value delimiter is the value;
+                            // a value that reaches str_end is shortened by one byte as well.
+                            DB::Field value = pair_end == str_end ? std::string_view(delimiter_end, pair_end - delimiter_end - 1)
+                                                                 : std::string_view(delimiter_end, pair_end - delimiter_end);
+                            tuple[1] = std::move(value);
+                        }
+                        else
+                        {
+                            // Not found delimiter, the value should be null
+                            tuple[1] = DB::Null();
+                        }
+                        map.emplace_back(std::move(tuple));
+                    }
+                    else if (pair_begin == pair_end)
+                    {
+                        // Empty pair. key is empty string, but value is null.
+                        DB::Tuple tuple(2);
+                        tuple[0] = std::string();
+                        tuple[1] = DB::Null();
+                        map.emplace_back(std::move(tuple));
                     }
                     else
                     {
-                        key = std::string_view(pair_begin, value_begin - pair_begin);
-                        value = std::string_view(value_begin + kv_delim_len, next_pair_begin - value_begin - kv_delim_len);
+                        LOG_WARNING(getLogger("SparkFunctionStrToMap"), "Split key value failed.");
                     }
-                    DB::Tuple tuple(2);
-                    tuple[0] = std::move(key);
-                    tuple[1] = std::move(value);
-                    map.emplace_back(std::move(tuple));
-
-                    pair_begin = next_pair_begin + pair_delim_len;
                 }
-                map_col->insert(map);
+                col_map->insert(std::move(map));
             }
-            prev_offset = str_offsets[i];
+            prev_offset = offsets[i];
         }
-        return map_col;
+
+        return col_map;
+    }
+};
+
+// Overload resolver for `spark_str_to_map`. For each of the two delimiter arguments it
+// decides whether the pattern is plain text or needs the regex splitter, and builds
+// the matching SparkFunctionStrToMap instantiation.
+class SparkFunctionStrToMapOverloadResolver : public DB::IFunctionOverloadResolver
+{
+public:
+    static constexpr auto name = "spark_str_to_map";
+
+    explicit SparkFunctionStrToMapOverloadResolver(DB::ContextPtr context_)
+        : context(context_), trivial_function(SparkFunctionStrToMap<TrivialCharSplitter, TrivialCharSplitter>::create(context))
+    {
+    }
+
+    static DB::FunctionOverloadResolverPtr create(const DB::ContextPtr context)
+    {
+        return std::make_shared<SparkFunctionStrToMapOverloadResolver>(context);
+    }
+
+    String getName() const override { return name; }
+    bool isVariadic() const override { return false; }
+    size_t getNumberOfArguments() const override { return 3; }
+
+    DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName & arguments) const override
+    {
+        // The return type does not depend on the splitter types, so reuse the
+        // pre-built plain-text function to compute it.
+        return trivial_function->getReturnTypeImpl(arguments);
+    }
+
+    DB::FunctionBasePtr buildImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & return_type) const override
+    {
+        // Either delimiter may be a regular expression; choose a splitter per argument.
+        const bool pair_is_plain = patternIsTrivialChar(arguments[1]);
+        const bool kv_is_plain = patternIsTrivialChar(arguments[2]);
+
+        DB::FunctionPtr impl;
+        if (pair_is_plain)
+        {
+            if (kv_is_plain)
+                impl = trivial_function;
+            else
+                impl = SparkFunctionStrToMap<TrivialCharSplitter, RegularSplitter>::create(context);
+        }
+        else
+        {
+            if (kv_is_plain)
+                impl = SparkFunctionStrToMap<RegularSplitter, TrivialCharSplitter>::create(context);
+            else
+                impl = SparkFunctionStrToMap<RegularSplitter, RegularSplitter>::create(context);
+        }
+
+        DB::DataTypes argument_types;
+        argument_types.reserve(arguments.size());
+        for (const auto & argument : arguments)
+            argument_types.push_back(argument.type);
+
+        return std::make_unique<DB::FunctionToFunctionBaseAdaptor>(impl, std::move(argument_types), return_type);
+    }
+
+private:
+    DB::ContextPtr context;
+    DB::FunctionPtr trivial_function;
+
+    // True when `argument` is a constant, non-empty string whose regex analysis reports
+    // it as trivial with the whole pattern as the required substring, i.e. it can be
+    // searched for as plain text. Empty and non-constant patterns return false.
+    bool patternIsTrivialChar(const DB::ColumnWithTypeAndName & argument) const
+    {
+        const DB::ColumnConst * const_col = checkAndGetColumnConstStringOrFixedString(argument.column.get());
+        if (!const_col)
+            return false;
+
+        String pattern = const_col->getValue<String>();
+        if (pattern.empty())
+            return false;
+
+        OptimizedRegularExpression regex = DB::Regexps::createRegexp<false, false, false>(pattern);
+        std::string literal_part;
+        bool trivial = false;
+        bool literal_is_prefix = false;
+        regex.getAnalyzeResult(literal_part, trivial, literal_is_prefix);
+        if (!trivial)
+            return false;
+        return literal_part == pattern;
     }
 };
-REGISTER_FUNCTION(SparkFunctionStrToMap)
+
+// Registers SparkFunctionStrToMapOverloadResolver with the function factory.
+REGISTER_FUNCTION(SparkStrToMap)
 {
-    factory.registerFunction<SparkFunctionStrToMap>();
+    factory.registerFunction<SparkFunctionStrToMapOverloadResolver>();
 }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to