This is an automated email from the ASF dual-hosted git repository.
morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push:
new b26e7e3 [feature](function)(vec) support locate function (#7988)
b26e7e3 is described below
commit b26e7e3c284bdae98ac69d0e786b3e623e543e38
Author: Pxl <[email protected]>
AuthorDate: Sat Feb 12 16:00:37 2022 +0800
[feature](function)(vec) support locate function (#7988)
* support function locate in vectorized engine
* add unit tests (ut) and fix some bugs
---
be/src/vec/common/string_ref.h | 7 +-
be/src/vec/functions/function_string.cpp | 32 ++++----
be/src/vec/functions/function_string.h | 110 +++++++++++++++++++++++++-
be/src/vec/functions/function_totype.h | 8 ++
be/test/vec/function/function_string_test.cpp | 46 +++++++++--
5 files changed, 174 insertions(+), 29 deletions(-)
diff --git a/be/src/vec/common/string_ref.h b/be/src/vec/common/string_ref.h
index bd81342..727996e 100644
--- a/be/src/vec/common/string_ref.h
+++ b/be/src/vec/common/string_ref.h
@@ -27,6 +27,7 @@
#include "gutil/hash/city.h"
#include "gutil/hash/hash128to64.h"
+#include "udf/udf.h"
#include "vec/common/unaligned.h"
#include "vec/core/types.h"
@@ -53,6 +54,10 @@ struct StringRef {
std::string to_string() const { return std::string(data, size); }
explicit operator std::string() const { return to_string(); }
+
+    // Builds a udf StringVal from this ref's own data pointer and size. Assuming StringVal's
+    // (ptr, len) constructor does not copy (TODO confirm), the result aliases `data`: it must
+    // not outlive the referenced bytes, and it must not be written through (the const is cast away).
+    StringVal to_string_val() const {
+        auto* raw_bytes = reinterpret_cast<uint8_t*>(const_cast<char*>(data));
+        return StringVal(raw_bytes, size);
+    }
};
using StringRefs = std::vector<StringRef>;
@@ -291,7 +296,7 @@ struct StringRefHash : CRC32Hash {};
struct CRC32Hash {
size_t operator()(StringRef /* x */) const {
- throw std::logic_error{"Not implemented CRC32Hash without SSE"};
+ // Placeholder hasher: CRC32 hashing of a StringRef is unsupported in this build, so any call throws std::logic_error.
+ throw std::logic_error {"Not implemented CRC32Hash without SSE"};
}
};
diff --git a/be/src/vec/functions/function_string.cpp
b/be/src/vec/functions/function_string.cpp
index a89edf5..94ed9c9 100644
--- a/be/src/vec/functions/function_string.cpp
+++ b/be/src/vec/functions/function_string.cpp
@@ -95,14 +95,7 @@ struct StringUtf8LengthImpl {
for (int i = 0; i < size; ++i) {
const char* raw_str = reinterpret_cast<const
char*>(&data[offsets[i - 1]]);
int str_size = offsets[i] - offsets[i - 1] - 1;
-
- size_t char_len = 0;
- for (size_t i = 0, char_size = 0; i < str_size; i += char_size) {
- char_size = get_utf8_byte_length((unsigned)(raw_str)[i]);
- ++char_len;
- }
-
- res[i] = char_len;
+ res[i] = get_char_len(StringValue(const_cast<char*>(raw_str),
str_size), str_size);
}
return Status::OK();
}
@@ -201,17 +194,19 @@ struct InStrOP {
// Hive returns positions starting from 1.
int loc = search.search(&str_sv);
if (loc > 0) {
- size_t char_len = 0;
- for (size_t i = 0, char_size = 0; i < loc; i += char_size) {
- char_size = get_utf8_byte_length((unsigned)(strl.data())[i]);
- ++char_len;
- }
- loc = char_len;
+ loc = get_char_len(str_sv, loc);
}
res = loc + 1;
}
};
+// LOCATE(substr, str) takes its two strings in the opposite order of INSTR(str, substr),
+// so this operator simply forwards to InStrOP with the arguments swapped.
+struct LocateOP {
+    using ResultDataType = DataTypeInt32;
+    using ResultPaddedPODArray = PaddedPODArray<Int32>;
+
+    static void execute(const std::string_view& strl, const std::string_view& strr,
+                        int32_t& res) {
+        const std::string_view& haystack = strr;
+        const std::string_view& needle = strl;
+        InStrOP::execute(haystack, needle, res);
+    }
+};
// LeftDataType and RightDataType are DataTypeString
template <typename LeftDataType, typename RightDataType, typename OP>
@@ -706,6 +701,9 @@ template <typename LeftDataType, typename RightDataType>
using StringInstrImpl = StringFunctionImpl<LeftDataType, RightDataType,
InStrOP>;
template <typename LeftDataType, typename RightDataType>
+using StringLocateImpl = StringFunctionImpl<LeftDataType, RightDataType,
LocateOP>;
+
+template <typename LeftDataType, typename RightDataType>
using StringFindInSetImpl = StringFunctionImpl<LeftDataType, RightDataType,
FindInSetOp>;
// ready for regist function
@@ -720,7 +718,7 @@ using FunctionStringEndsWith =
using FunctionStringInstr =
FunctionBinaryToType<DataTypeString, DataTypeString, StringInstrImpl,
NameInstr>;
using FunctionStringLocate =
- FunctionBinaryToType<DataTypeString, DataTypeString, StringInstrImpl,
NameLocate>;
+ FunctionBinaryToType<DataTypeString, DataTypeString, StringLocateImpl,
NameLocate>;
using FunctionStringFindInSet =
FunctionBinaryToType<DataTypeString, DataTypeString,
StringFindInSetImpl, NameFindInSet>;
@@ -755,7 +753,6 @@ using FunctionStringLPad = FunctionStringPad<StringLPad>;
using FunctionStringRPad = FunctionStringPad<StringRPad>;
void register_function_string(SimpleFunctionFactory& factory) {
- // factory.register_function<>();
factory.register_function<FunctionStringASCII>();
factory.register_function<FunctionStringLength>();
factory.register_function<FunctionStringUTF8Length>();
@@ -764,7 +761,8 @@ void register_function_string(SimpleFunctionFactory&
factory) {
factory.register_function<FunctionStringEndsWith>();
factory.register_function<FunctionStringInstr>();
factory.register_function<FunctionStringFindInSet>();
- // factory.register_function<FunctionStringLocate>();
+ factory.register_function<FunctionStringLocate>();
+ factory.register_function<FunctionStringLocatePos>();
factory.register_function<FunctionReverse>();
factory.register_function<FunctionHexString>();
factory.register_function<FunctionUnHex>();
diff --git a/be/src/vec/functions/function_string.h
b/be/src/vec/functions/function_string.h
index af58062..934053e 100644
--- a/be/src/vec/functions/function_string.h
+++ b/be/src/vec/functions/function_string.h
@@ -21,18 +21,21 @@
#include <fmt/format.h>
#include <fmt/ranges.h>
+#include <cstdint>
#include <string_view>
#include "exprs/anyval_util.h"
#include "exprs/math_functions.h"
#include "exprs/string_functions.h"
#include "runtime/string_value.hpp"
+#include "udf/udf.h"
#include "util/md5.h"
#include "util/url_parser.h"
#include "vec/columns/column_decimal.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_string.h"
#include "vec/columns/columns_number.h"
+#include "vec/common/string_ref.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
#include "vec/data_types/data_type_string.h"
@@ -70,6 +73,25 @@ inline size_t get_char_len(const std::string_view& str,
std::vector<size_t>* str
return char_len;
}
+// Counts the UTF-8 characters in `str` and appends the starting byte offset of each
+// character to `*str_index` (existing entries are kept; the pointer must be non-null).
+// Each character's width is read from its lead byte via get_utf8_byte_length().
+inline size_t get_char_len(const StringVal& str, std::vector<size_t>* str_index) {
+    size_t count = 0;
+    size_t pos = 0;
+    while (pos < str.len) {
+        str_index->push_back(pos);
+        pos += get_utf8_byte_length(static_cast<unsigned>(str.ptr[pos]));
+        ++count;
+    }
+    return count;
+}
+
+// Counts the UTF-8 characters of `str` whose first byte lies before byte offset `end_pos`
+// (clamped to the string's length). Each character's width is read from its lead byte.
+inline size_t get_char_len(const StringValue& str, size_t end_pos) {
+    const size_t limit = std::min(str.len, end_pos);
+    size_t chars = 0;
+    for (size_t offset = 0; offset < limit; ++chars) {
+        offset += get_utf8_byte_length(static_cast<unsigned>(str.ptr[offset]));
+    }
+    return chars;
+}
+
struct StringOP {
static void push_empty_string(int index, ColumnString::Chars& chars,
ColumnString::Offsets& offsets) {
@@ -1079,7 +1101,7 @@ struct MoneyFormatDoubleImpl {
static DataTypes get_variadic_argument_types() { return
{std::make_shared<DataTypeFloat64>()}; }
static void execute(FunctionContext* context, ColumnString* result_column,
- const ColumnType* data_column, size_t
input_rows_count) {
+ const ColumnType* data_column, size_t
input_rows_count) {
for (size_t i = 0; i < input_rows_count; i++) {
double value =
MathFunctions::my_double_round(data_column->get_element(i), 2, false, false);
@@ -1095,7 +1117,7 @@ struct MoneyFormatInt64Impl {
static DataTypes get_variadic_argument_types() { return
{std::make_shared<DataTypeInt64>()}; }
static void execute(FunctionContext* context, ColumnString* result_column,
- const ColumnType* data_column, size_t
input_rows_count) {
+ const ColumnType* data_column, size_t
input_rows_count) {
for (size_t i = 0; i < input_rows_count; i++) {
Int64 value = data_column->get_element(i);
StringVal str = StringFunctions::do_money_format<Int64,
26>(context, value);
@@ -1110,7 +1132,7 @@ struct MoneyFormatInt128Impl {
static DataTypes get_variadic_argument_types() { return
{std::make_shared<DataTypeInt128>()}; }
static void execute(FunctionContext* context, ColumnString* result_column,
- const ColumnType* data_column, size_t
input_rows_count) {
+ const ColumnType* data_column, size_t
input_rows_count) {
for (size_t i = 0; i < input_rows_count; i++) {
Int128 value = data_column->get_element(i);
StringVal str = StringFunctions::do_money_format<Int128,
52>(context, value);
@@ -1127,7 +1149,7 @@ struct MoneyFormatDecimalImpl {
}
static void execute(FunctionContext* context, ColumnString* result_column,
- const ColumnType* data_column, size_t
input_rows_count) {
+ const ColumnType* data_column, size_t
input_rows_count) {
for (size_t i = 0; i < input_rows_count; i++) {
DecimalV2Val value = DecimalV2Val(data_column->get_element(i));
@@ -1142,4 +1164,84 @@ struct MoneyFormatDecimalImpl {
}
};
+// Three-argument overload of `locate(substr, str, pos)`: returns the 1-based character
+// position of the first occurrence of `substr` in `str` at or after character position
+// `pos`, or 0 when there is no such occurrence or `pos` is out of range.
+class FunctionStringLocatePos : public IFunction {
+public:
+    static constexpr auto name = "locate";
+    static FunctionPtr create() { return std::make_shared<FunctionStringLocatePos>(); }
+    String get_name() const override { return name; }
+    size_t get_number_of_arguments() const override { return 3; }
+
+    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
+        return std::make_shared<DataTypeInt32>();
+    }
+
+    // (String substr, String str, Int32 pos): distinguishes this overload from other
+    // variadic "locate" functions.
+    DataTypes get_variadic_argument_types_impl() const override {
+        return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(),
+                std::make_shared<DataTypeInt32>()};
+    }
+
+    bool is_variadic() const override { return true; }
+
+    bool use_default_implementation_for_constants() const override { return true; }
+
+    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                        size_t result, size_t input_rows_count) override {
+        // Materialize constant arguments so every column can be indexed per row.
+        auto col_substr =
+                block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
+        auto col_str =
+                block.get_by_position(arguments[1]).column->convert_to_full_column_if_const();
+        auto col_pos =
+                block.get_by_position(arguments[2]).column->convert_to_full_column_if_const();
+
+        ColumnInt32::MutablePtr col_res = ColumnInt32::create();
+
+        // Assumes the pos column is a ColumnInt32, matching the Int32 declared in
+        // get_variadic_argument_types_impl() -- TODO confirm the planner never widens it.
+        const auto& vec_pos = static_cast<const ColumnInt32*>(col_pos.get())->get_data();
+        auto& vec_res = col_res->get_data();
+        vec_res.resize(input_rows_count);
+
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            vec_res[i] = locate_pos(col_substr->get_data_at(i).to_string_val(),
+                                    col_str->get_data_at(i).to_string_val(), vec_pos[i]);
+        }
+
+        block.replace_by_position(result, std::move(col_res));
+        return Status::OK();
+    }
+
+private:
+    // Returns the 1-based character position of `substr` in `str`, searching from the
+    // 1-based character position `start_pos`; 0 when not found or `start_pos` is out of range.
+    static int locate_pos(StringVal substr, StringVal str, int start_pos) {
+        if (substr.len == 0) {
+            // Note: the upper bound below is str's byte length, not its character length.
+            if (start_pos <= 0) {
+                return 0;
+            } else if (start_pos == 1) {
+                return 1;
+            } else if (start_pos > str.len) {
+                return 0;
+            } else {
+                return start_pos;
+            }
+        }
+        // Hive returns 0 for *start_pos <= 0,
+        // but throws an exception for *start_pos > str->len.
+        // Since returning 0 seems to be Hive's error condition, return 0.
+        if (start_pos <= 0 || start_pos > str.len) {
+            return 0;
+        }
+        // Locate the byte offset of character number `start_pos` (1-based) by walking only
+        // the first `start_pos - 1` UTF-8 characters, instead of indexing the whole string
+        // into a freshly allocated vector on every row.
+        const auto str_len = static_cast<size_t>(str.len);
+        size_t start_byte = 0;
+        for (int skipped = 1; skipped < start_pos && start_byte < str_len; ++skipped) {
+            start_byte += get_utf8_byte_length(static_cast<unsigned>(str.ptr[start_byte]));
+        }
+        if (start_byte >= str_len) {
+            // `str` has fewer than `start_pos` characters.
+            return 0;
+        }
+        StringValue substr_sv = StringValue::from_string_val(substr);
+        StringSearch search(&substr_sv);
+        // Input start_pos starts from 1.
+        StringValue adjusted_str(reinterpret_cast<char*>(str.ptr) + start_byte,
+                                 str_len - start_byte);
+        int32_t match_pos = search.search(&adjusted_str);
+        if (match_pos >= 0) {
+            // Hive returns the position in the original string starting from 1.
+            return start_pos + get_char_len(adjusted_str, match_pos);
+        } else {
+            return 0;
+        }
+    }
+};
+
} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/functions/function_totype.h
b/be/src/vec/functions/function_totype.h
index 72e74b3..fef300e 100644
--- a/be/src/vec/functions/function_totype.h
+++ b/be/src/vec/functions/function_totype.h
@@ -195,9 +195,17 @@ public:
static FunctionPtr create() { return
std::make_shared<FunctionBinaryToType>(); }
String get_name() const override { return name; }
size_t get_number_of_arguments() const override { return 2; }
+
DataTypePtr get_return_type_impl(const DataTypes& arguments) const
override {
return std::make_shared<ResultDataType>();
}
+
+ DataTypes get_variadic_argument_types_impl() const override {
+ return {std::make_shared<DataTypeString>(),
std::make_shared<DataTypeString>()};
+ }
+
+ bool is_variadic() const override { return true; }
+
bool use_default_implementation_for_constants() const override { return
true; }
Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
diff --git a/be/test/vec/function/function_string_test.cpp
b/be/test/vec/function/function_string_test.cpp
index 5617a36..60184fc 100644
--- a/be/test/vec/function/function_string_test.cpp
+++ b/be/test/vec/function/function_string_test.cpp
@@ -25,8 +25,10 @@
#include "util/encryption_util.h"
#include "util/url_coding.h"
#include "vec/core/field.h"
+#include "vec/core/types.h"
namespace doris::vectorized {
+using namespace ut_type;
TEST(function_string_test, function_string_substr_test) {
std::string func_name = "substr";
@@ -440,17 +442,47 @@ TEST(function_string_test, function_instr_test) {
InputTypeSet input_types = {TypeIndex::String, TypeIndex::String};
- DataSet data_set = {{{std::string("abcdefg"), std::string("efg")}, 5},
- {{std::string("aa"), std::string("a")}, 1},
- {{std::string("我是"), std::string("是")}, 2},
- {{std::string("abcd"), std::string("e")}, 0},
- {{std::string("abcdef"), std::string("")}, 1},
- {{std::string(""), std::string("")}, 1},
- {{std::string("aaaab"), std::string("bb")}, 0}};
+ DataSet data_set = {
+ {{STRING("abcdefg"), STRING("efg")}, INT(5)}, {{STRING("aa"),
STRING("a")}, INT(1)},
+ {{STRING("我是"), STRING("是")}, INT(2)}, {{STRING("abcd"),
STRING("e")}, INT(0)},
+ {{STRING("abcdef"), STRING("")}, INT(1)}, {{STRING(""),
STRING("")}, INT(1)},
+ {{STRING("aaaab"), STRING("bb")}, INT(0)}};
check_function<DataTypeInt32, true>(func_name, input_types, data_set);
}
+// Covers both overloads of locate(): the two-argument form (substr, str) and the
+// three-argument form (substr, str, pos), where pos is a 1-based character index.
+TEST(function_string_test, function_locate_test) {
+    std::string func_name = "locate";
+
+    {
+        InputTypeSet input_types = {TypeIndex::String, TypeIndex::String};
+
+        DataSet data_set = {{{STRING("efg"), STRING("abcdefg")}, INT(5)},
+                            {{STRING("a"), STRING("aa")}, INT(1)},
+                            {{STRING("是"), STRING("我是")}, INT(2)},
+                            {{STRING("e"), STRING("abcd")}, INT(0)},
+                            {{STRING(""), STRING("abcdef")}, INT(1)},
+                            {{STRING(""), STRING("")}, INT(1)},
+                            {{STRING("bb"), STRING("aaaab")}, INT(0)}};
+
+        check_function<DataTypeInt32, true>(func_name, input_types, data_set);
+    }
+
+    {
+        InputTypeSet input_types = {TypeIndex::String, TypeIndex::String, TypeIndex::Int32};
+
+        DataSet data_set = {{{STRING("bar"), STRING("foobarbar"), INT(5)}, INT(7)},
+                            {{STRING("xbar"), STRING("foobar"), INT(1)}, INT(0)},
+                            {{STRING(""), STRING("foobar"), INT(2)}, INT(2)},
+                            {{STRING("A"), STRING("大A写的A"), INT(0)}, INT(0)},
+                            {{STRING("A"), STRING("大A写的A"), INT(1)}, INT(2)},
+                            {{STRING("A"), STRING("大A写的A"), INT(2)}, INT(2)},
+                            {{STRING("A"), STRING("大A写的A"), INT(3)}, INT(5)},
+                            // Boundary cases: start exactly at the last character, past the
+                            // last character, negative start, and an empty needle past the end.
+                            {{STRING("A"), STRING("大A写的A"), INT(5)}, INT(5)},
+                            {{STRING("A"), STRING("大A写的A"), INT(6)}, INT(0)},
+                            {{STRING("A"), STRING("大A写的A"), INT(-1)}, INT(0)},
+                            {{STRING("bar"), STRING("foobarbar"), INT(10)}, INT(0)},
+                            {{STRING(""), STRING("foobar"), INT(1)}, INT(1)},
+                            {{STRING(""), STRING("foobar"), INT(7)}, INT(0)}};
+
+        check_function<DataTypeInt32, true>(func_name, input_types, data_set);
+    }
+}
+
TEST(function_string_test, function_find_in_set_test) {
std::string func_name = "find_in_set";
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]