This is an automated email from the ASF dual-hosted git repository.
lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new bc1dba2d94 support delimiter in regular expression for str_to_map
(#8998)
bc1dba2d94 is described below
commit bc1dba2d9409d7f61f89e36fb298a9551f51f6d6
Author: lgbo <[email protected]>
AuthorDate: Mon Mar 17 19:31:49 2025 +0800
support delimiter in regular expression for str_to_map (#8998)
---
.../execution/GlutenFunctionValidateSuite.scala | 13 +
.../Functions/SparkFunctionStrToMap.cpp | 371 +++++++++++++++++----
2 files changed, 312 insertions(+), 72 deletions(-)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index 22693e7c7e..80edd15783 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -507,6 +507,19 @@ class GlutenFunctionValidateSuite extends
GlutenClickHouseWholeStageTransformerS
runQueryAndCompare(sql1)(checkGlutenOperatorMatch[ProjectExecTransformer])
}
+ test("test str2map, regular expression") { // delimiters are regular expressions: '' splits into single characters, '\\|' matches a literal '|'
+ val sql1 =
+ """
+ |select str_to_map('ab', '', ''),
+ | str_to_map('a:b,c:d'),
+ | str_to_map('ab', '', ':'),
+ | str_to_map('a:,c:d,e', ',', ''),
+ | str_to_map('a,b', ',', ''),
+ | str_to_map('a:c|b:c', '\\|', ':')
+ |""".stripMargin
+ runQueryAndCompare(sql1, true,
false)(checkGlutenOperatorMatch[ProjectExecTransformer])
+ }
+
test("test parse_url") {
val sql1 =
"""
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionStrToMap.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionStrToMap.cpp
index 77a9bce7d5..2fa8a6ec58 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionStrToMap.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionStrToMap.cpp
@@ -14,20 +14,27 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+#include <memory>
#include <type_traits>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
+#include <Columns/ColumnVector.h>
#include <Core/Field.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
+#include <DataTypes/IDataType.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
+#include <Functions/IFunctionAdaptors.h>
+#include <Functions/Regexps.h>
+#include <base/map.h>
#include <Common/Exception.h>
+#include <Common/OptimizedRegularExpression.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
@@ -36,47 +43,191 @@ namespace DB
{
namespace ErrorCodes
{
- extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
- extern const int ILLEGAL_COLUMN;
+extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+extern const int ILLEGAL_COLUMN;
}
}
namespace local_engine
{
-class SparkFunctionStrToMap : public DB::IFunction
+
+/// Splits [str_begin, str_end) on every occurrence of a literal (non-regex) delimiter, located with memmem.
+/// The delimiter is held by reference, so the string passed to the constructor must outlive the splitter.
+class TrivialCharSplitter
{
public:
using Pos = const char *;
- static constexpr auto name = "spark_str_to_map";
- static DB::FunctionPtr create(const DB::ContextPtr) { return
std::make_shared<SparkFunctionStrToMap>(); }
+ explicit TrivialCharSplitter(const String & delimiter_) : delimiter(delimiter_) { }
- String getName() const override { return name; }
+ void reset(Pos str_begin_, Pos str_end_)
+ {
+ str_begin = str_begin_;
+ str_end = str_end_;
+ str_cursor = str_begin;
+ delimiter_begin = nullptr;
+ delimiter_end = nullptr;
+ }
- bool isVariadic() const override { return true; }
- size_t getNumberOfArguments() const override { return 3; }
+ /// Bounds of the delimiter that ended the last token; both nullptr when that token ran to the end of the input.
+ Pos getDelimiterBegin() const { return delimiter_begin; }
+ Pos getDelimiterEnd() const { return delimiter_end; }
- bool useDefaultImplementationForConstants() const override { return true; }
- DB::ColumnNumbers getArgumentsThatAreAlwaysConstant() const override {
return {1, 2}; }
+ /// Yields the next token as [token_begin, token_end); returns false once the input is exhausted.
+ bool next(Pos & token_begin, Pos & token_end)
+ {
+ if (str_cursor >= str_end)
+ return false;
+ token_begin = str_cursor;
+ auto next_token_pos = static_cast<Pos>(memmem(str_cursor, str_end -
str_cursor, delimiter.c_str(), delimiter.size()));
+ // If delimiter is not found, return the remaining string.
+ if (!next_token_pos)
+ {
+ token_end = str_end;
+ str_cursor = str_end;
+ delimiter_begin = nullptr;
+ delimiter_end = nullptr;
+ }
+ else
+ {
+ delimiter_begin = next_token_pos;
+ token_end = next_token_pos;
+ str_cursor = next_token_pos + delimiter.size();
+ delimiter_end = str_cursor;
+ }
+ return true;
+ }
- bool isSuitableForShortCircuitArgumentsExecution(const
DB::DataTypesWithConstInfo & /*arguments*/) const override { return true; }
+private:
+ const String & delimiter;
+ Pos str_begin;
+ Pos str_end;
+ Pos str_cursor;
+ Pos delimiter_begin;
+ Pos delimiter_end;
+};
- DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName &
arguments) const override
+/// Splits [str_begin, str_end) on matches of a regular-expression delimiter; an empty delimiter yields one UTF-8 character per token.
+/// The delimiter is held by reference, so the string passed to the constructor must outlive the splitter.
+struct RegularSplitter
+{
+public:
+ using Pos = const char *;
+ explicit RegularSplitter(const String & delimiter_) : delimiter(delimiter_)
+ {
+ if (!delimiter.empty())
+ re =
std::make_shared<OptimizedRegularExpression>(DB::Regexps::createRegexp<false,
false, false>(delimiter));
+ }
+
+ void reset(Pos str_begin_, Pos str_end_)
+ {
+ str_begin = str_begin_;
+ str_end = str_end_;
+ str_cursor = str_begin;
+ delimiter_begin = nullptr;
+ delimiter_end = nullptr;
+ }
+
+ Pos getDelimiterBegin() const { return delimiter_begin; }
+ Pos getDelimiterEnd() const { return delimiter_end; }
+
+ /// Yields the next token as [token_begin, token_end); returns false once the input is exhausted.
+ bool next(Pos & token_begin, Pos & token_end)
{
- if (arguments.size() != 3)
+ if (str_cursor >= str_end)
+ return false;
+ // If delimiter is empty, return each character as a token.
+ if (!re)
{
- throw DB::Exception(
- DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
- "Function {} requires 3 arguments, passed {}",
- getName(),
- arguments.size());
+ token_begin = str_cursor;
+ str_cursor += utf8CharLength(str_cursor, str_end);
+ delimiter_begin = str_cursor;
+ delimiter_end = str_cursor;
+ token_end = str_cursor;
}
-
- if
(!DB::WhichDataType(DB::removeNullable(arguments[0].type)).isString()
- ||
!DB::WhichDataType(DB::removeNullable(arguments[1].type)).isString()
- ||
!DB::WhichDataType(DB::removeNullable(arguments[2].type)).isString())
+ else
{
- throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "All arguments
for function {} must be String", getName());
+ if (!re->match(str_cursor, str_end - str_cursor, matches))
+ {
+ token_begin = str_cursor;
+ token_end = str_end;
+ str_cursor = str_end;
+ delimiter_begin = nullptr;
+ delimiter_end = nullptr;
+ return true;
+ }
+ token_begin = str_cursor;
+ token_end = str_cursor + matches[0].offset;
+ // A zero-length match at the cursor (e.g. pattern ",*") would never advance; consume one character so next() always makes progress.
+ if (matches[0].length == 0 && token_end == str_cursor)
+ token_end += utf8CharLength(str_cursor, str_end);
+ delimiter_begin = token_end;
+ str_cursor = token_end + matches[0].length;
+ delimiter_end = str_cursor;
}
+ return true;
+ }
+
+private:
+ /// Byte length of the UTF-8 sequence whose lead byte is at pos, clamped to end - pos; ASCII and stray continuation bytes count as 1.
+ static size_t utf8CharLength(Pos pos, Pos end)
+ {
+ const auto lead = static_cast<unsigned char>(*pos);
+ size_t len = 1;
+ if (lead >= 0xF0)
+ len = 4;
+ else if (lead >= 0xE0)
+ len = 3;
+ else if (lead >= 0xC0)
+ len = 2;
+ const auto remaining = static_cast<size_t>(end - pos);
+ return len < remaining ? len : remaining;
+ }
+
+ const String & delimiter;
+ DB::Regexps::RegexpPtr re;
+ OptimizedRegularExpression::MatchVec matches;
+ Pos str_begin;
+ Pos str_end;
+ Pos str_cursor;
+ Pos delimiter_begin;
+ Pos delimiter_end;
+};
+
+
+/// Column validator for validateFunctionArguments: the delimiter arguments must be constant columns.
+static bool isConstColumn(const DB::IColumn & col)
+{
+ return col.isConst();
+}
+
+template <typename PairGenerator, typename KVGenerator>
+class SparkFunctionStrToMap : public DB::IFunction
+{
+public:
+ using Pos = const char *;
+ static constexpr auto name = "spark_str_to_map";
+ static DB::FunctionPtr create(const DB::ContextPtr) { return
std::make_shared<SparkFunctionStrToMap<PairGenerator, KVGenerator>>(); }
+ String getName() const override { return name; }
+ bool isVariadic() const override { return false; }
+ size_t getNumberOfArguments() const override { return 3; }
+ bool useDefaultImplementationForConstants() const override { return true; }
+ DB::ColumnNumbers getArgumentsThatAreAlwaysConstant() const override {
return {1, 2}; }
+ bool isSuitableForShortCircuitArgumentsExecution(const
DB::DataTypesWithConstInfo & /*arguments*/) const override { return true; }
+ DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName &
arguments) const override
+ {
+ DB::FunctionArgumentDescriptors mandatory_args{
+ {"string",
static_cast<DB::FunctionArgumentDescriptor::TypeValidator>(&DB::isStringOrFixedString),
nullptr, "String"},
+ {"pair_delimiter",
+
static_cast<DB::FunctionArgumentDescriptor::TypeValidator>(&DB::isStringOrFixedString),
+ &isConstColumn,
+ "String"},
+ {"key_value_delimiter",
+
static_cast<DB::FunctionArgumentDescriptor::TypeValidator>(&DB::isStringOrFixedString),
+ &isConstColumn,
+ "String"},
+ };
+ validateFunctionArguments(*this, arguments, mandatory_args);
+
auto map_typ = std::make_shared<DB::DataTypeMap>(
std::make_shared<DB::DataTypeString>(),
makeNullable(std::make_shared<DB::DataTypeString>()));
if (arguments[0].type->isNullable())
@@ -88,76 +216,185 @@ public:
DB::ColumnPtr executeImpl(
const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr &
result_type, size_t /*input_rows_count*/) const override
{
- auto map_col = result_type->createColumn();
- auto pair_delim = (*arguments[1].column)[0].safeGet<String>();
- auto pair_delim_len = pair_delim.size();
- auto kv_delim = (*arguments[2].column)[0].safeGet<String>();
- auto kv_delim_len = kv_delim.size();
- const DB::IColumn * arg0 = arguments[0].column.get();
- bool is_nullable = false;
- if (arg0->isNullable())
+ String pair_delim = (*arguments[1].column)[0].safeGet<String>();
+ String kv_delim = (*arguments[2].column)[0].safeGet<String>();
+ PairGenerator pair_generator(pair_delim);
+ KVGenerator kv_generator(kv_delim);
+
+ auto col_map = result_type->createColumn();
+
+ const DB::ColumnString * col_str = nullptr;
+ const DB::ColumnUInt8 * null_map = nullptr;
+ if (arguments[0].column->isNullable())
{
- arg0 = DB::checkAndGetColumn<DB::ColumnNullable>(arg0);
- is_nullable = true;
+ const auto * col_null =
DB::checkAndGetColumn<DB::ColumnNullable>(arguments[0].column.get());
+ col_str =
DB::checkAndGetColumn<DB::ColumnString>(col_null->getNestedColumnPtr().get());
+ null_map = &(col_null->getNullMapColumn());
}
- const auto * str_col =
DB::checkAndGetColumn<DB::ColumnString>(arguments[0].column.get());
- if (!str_col) [[unlikely]]
+ else
{
- throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "argument 0
for function {} must be String", getName());
+ col_str =
DB::checkAndGetColumn<DB::ColumnString>(arguments[0].column.get());
}
+ // Argument validation also admits FixedString, which is not a ColumnString; fail cleanly instead of dereferencing null.
+ if (!col_str) [[unlikely]]
+ throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "argument 0 for function {} must be String", getName());
- const DB::ColumnString::Chars & str_vec = str_col->getChars();
- const DB::ColumnString::Offsets & str_offsets = str_col->getOffsets();
- map_col->reserve(str_offsets.size());
+
+ const auto & strs = col_str->getChars();
+ const auto & offsets = col_str->getOffsets();
+ // Look the logger up once instead of on every row and pair.
+ const auto logger = getLogger("SparkFunctionStrToMap");
DB::ColumnString::Offset prev_offset = 0;
- for (size_t i = 0, n = str_offsets.size(); i < n; ++i)
+
+ for (size_t i = 0, n = offsets.size(); i < n; ++i)
{
- if (is_nullable && str_col->isNullAt(i))
- {
- map_col->insertDefault();
- }
+ // Rows whose input is NULL get the result column's default value.
+ if (null_map && null_map->getData()[i])
+ col_map->insertDefault();
else
{
DB::Map map;
- Pos pair_begin = reinterpret_cast<const char
*>(&str_vec[prev_offset]);
- Pos str_end = reinterpret_cast<const char
*>(&str_vec[str_offsets[i]]);
- while (pair_begin < str_end)
+ // Each row's bytes are followed by a zero terminator that [str_begin, str_end) includes; tokens ending at str_end are shortened by 1 below to drop it.
+ Pos str_begin = reinterpret_cast<Pos>(&strs[prev_offset]);
+ Pos str_end = reinterpret_cast<Pos>(&strs[offsets[i]]);
+ LOG_TRACE(
+ logger,
+ "str_begin: {}, str_end: {}. {}",
+ 0,
+ str_end - str_begin,
+ std::string_view(str_begin, str_end - str_begin));
+ pair_generator.reset(str_begin, str_end);
+ Pos pair_begin;
+ Pos pair_end;
+ while (pair_generator.next(pair_begin, pair_end))
{
- // Get next pair.
- auto next_pair_begin
- = static_cast<const char *>(memmem(pair_begin, str_end
- pair_begin, pair_delim.c_str(), pair_delim_len));
- if (!next_pair_begin) [[unlikely]]
- next_pair_begin = str_end - 1;
- Pos value_begin
- = static_cast<const char *>(memmem(pair_begin,
next_pair_begin - pair_begin, kv_delim.c_str(), kv_delim_len));
- DB::Field key;
- DB::Field value;
- if (!value_begin)
+ LOG_TRACE(
+ logger,
+ "pair_begin: {}, pair_end: {}, {}",
+ pair_begin - str_begin,
+ pair_end - str_begin,
+ std::string_view(pair_begin, pair_end - pair_begin));
+ kv_generator.reset(pair_begin, pair_end);
+ Pos key_begin;
+ Pos key_end;
+ if (kv_generator.next(key_begin, key_end))
{
- key = std::string_view(pair_begin, next_pair_begin -
pair_begin);
- value = DB::Null();
+ DB::Tuple tuple(2);
+ size_t key_len = key_end - key_begin;
+ tuple[0] = key_end == str_end ?
std::string_view(key_begin, key_len - 1) : std::string_view(key_begin, key_len);
+ auto delimiter_begin =
kv_generator.getDelimiterBegin();
+ auto delimiter_end = kv_generator.getDelimiterEnd();
+ LOG_TRACE(
+ logger,
+ "key_begin: {}, key_end: {}, delim_begin: {},
delim_end: {}, key:{}",
+ key_begin - str_begin,
+ key_end - str_begin,
+ delimiter_begin - str_begin,
+ delimiter_end - str_begin,
+ std::string_view(key_begin, key_end - key_begin));
+ if (delimiter_begin && delimiter_begin != str_end)
+ {
+ DB::Field value = pair_end == str_end ?
std::string_view(delimiter_end, pair_end - delimiter_end - 1)
+ :
std::string_view(delimiter_end, pair_end - delimiter_end);
+ tuple[1] = std::move(value);
+ }
+ else
+ {
+ // No key/value delimiter found, so the value is NULL.
+ tuple[1] = DB::Null();
+ }
+ map.emplace_back(std::move(tuple));
+ }
+ else if (pair_begin == pair_end)
+ {
+ // Empty pair. key is empty string, but value is null.
+ DB::Tuple tuple(2);
+ tuple[0] = std::string();
+ tuple[1] = DB::Null();
+ map.emplace_back(std::move(tuple));
}
else
{
- key = std::string_view(pair_begin, value_begin -
pair_begin);
- value = std::string_view(value_begin + kv_delim_len,
next_pair_begin - value_begin - kv_delim_len);
+ LOG_WARNING(logger, "Split key value failed.");
}
- DB::Tuple tuple(2);
- tuple[0] = std::move(key);
- tuple[1] = std::move(value);
- map.emplace_back(std::move(tuple));
-
- pair_begin = next_pair_begin + pair_delim_len;
}
- map_col->insert(map);
+ col_map->insert(std::move(map));
}
- prev_offset = str_offsets[i];
+ prev_offset = offsets[i];
}
- return map_col;
+
+ return col_map;
+ }
+};
+
+/// Picks a SparkFunctionStrToMap specialization per call: delimiters that are plain literals are split with
+/// TrivialCharSplitter (memmem); anything else, including an empty delimiter, goes through RegularSplitter.
+class SparkFunctionStrToMapOverloadResolver : public
DB::IFunctionOverloadResolver
+{
+public:
+ static constexpr auto name = "spark_str_to_map";
+ static DB::FunctionOverloadResolverPtr create(const DB::ContextPtr context)
+ {
+ return
std::make_shared<SparkFunctionStrToMapOverloadResolver>(context);
+ }
+
+ explicit SparkFunctionStrToMapOverloadResolver(DB::ContextPtr context_)
+ : context(context_),
trivial_function(SparkFunctionStrToMap<TrivialCharSplitter,
TrivialCharSplitter>::create(context))
+ {
+ }
+
+ String getName() const override { return name; }
+ size_t getNumberOfArguments() const override { return 3; }
+ bool isVariadic() const override { return false; }
+ DB::FunctionBasePtr buildImpl(const DB::ColumnsWithTypeAndName &
arguments, const DB::DataTypePtr & return_type) const override
+ {
+ // The delimiter could be a regular expression.
+ bool is_trivial_pair_delim = patternIsTrivialChar(arguments[1]);
+ bool is_trivial_kv_delim = patternIsTrivialChar(arguments[2]);
+ DB::FunctionPtr function_ptr = nullptr;
+ if (is_trivial_pair_delim && is_trivial_kv_delim)
+ function_ptr = trivial_function;
+ else if (is_trivial_pair_delim && !is_trivial_kv_delim)
+ function_ptr = SparkFunctionStrToMap<TrivialCharSplitter,
RegularSplitter>::create(context);
+ else if (!is_trivial_pair_delim && is_trivial_kv_delim)
+ function_ptr = SparkFunctionStrToMap<RegularSplitter,
TrivialCharSplitter>::create(context);
+ else
+ function_ptr = SparkFunctionStrToMap<RegularSplitter,
RegularSplitter>::create(context);
+ return std::make_unique<DB::FunctionToFunctionBaseAdaptor>(
+ function_ptr, collections::map<DB::DataTypes>(arguments, [](const
auto & elem) { return elem.type; }), return_type);
+ }
+
+ DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName &
arguments) const override
+ {
+ return trivial_function->getReturnTypeImpl(arguments);
+ }
+
+private:
+ DB::ContextPtr context;
+ DB::FunctionPtr trivial_function;
+
+ /// True when the argument is a non-empty constant string whose regex analysis reduces to exactly that literal.
+ bool patternIsTrivialChar(const DB::ColumnWithTypeAndName & argument) const
+ {
+ const DB::ColumnConst * col =
checkAndGetColumnConstStringOrFixedString(argument.column.get());
+ if (!col)
+ return false;
+
+ String pattern = col->getValue<String>();
+ if (pattern.empty())
+ return false;
+ OptimizedRegularExpression re = DB::Regexps::createRegexp<false,
false, false>(pattern);
+
+ std::string required_substring;
+ bool is_trivial;
+ bool required_substring_is_prefix;
+ re.getAnalyzeResult(required_substring, is_trivial,
required_substring_is_prefix);
+ return is_trivial && required_substring == pattern;
}
};
-REGISTER_FUNCTION(SparkFunctionStrToMap)
+
+REGISTER_FUNCTION(SparkStrToMap)
{
- factory.registerFunction<SparkFunctionStrToMap>();
+ factory.registerFunction<SparkFunctionStrToMapOverloadResolver>();
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]