This is an automated email from the ASF dual-hosted git repository.
ravindra pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 319e225f0a ARROW-15568: [C++][Gandiva] Implement Translate Function
(#12333)
319e225f0a is described below
commit 319e225f0a82416e8cc1901a0f7ea115f652e74f
Author: Vinícius Roque <[email protected]>
AuthorDate: Mon Jun 27 02:34:55 2022 -0700
ARROW-15568: [C++][Gandiva] Implement Translate Function (#12333)
Lead-authored-by: Vinicius Roque <[email protected]>
Co-authored-by: ViniciusSouzaRoque <[email protected]>
Signed-off-by: Pindikura Ravindra <[email protected]>
---
cpp/src/gandiva/function_registry_string.cc | 6 +-
cpp/src/gandiva/gdv_function_stubs.h | 5 +
cpp/src/gandiva/gdv_function_stubs_test.cc | 43 +++++
cpp/src/gandiva/gdv_string_function_stubs.cc | 233 +++++++++++++++++++++++++++
cpp/src/gandiva/tests/projector_test.cc | 50 ++++++
5 files changed, 336 insertions(+), 1 deletion(-)
diff --git a/cpp/src/gandiva/function_registry_string.cc
b/cpp/src/gandiva/function_registry_string.cc
index ea3672c89b..21681775c6 100644
--- a/cpp/src/gandiva/function_registry_string.cc
+++ b/cpp/src/gandiva/function_registry_string.cc
@@ -511,7 +511,11 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
NativeFunction::kNeedsContext),
NativeFunction("instr", {}, DataTypeVector{utf8(), utf8()}, int32(),
- kResultNullIfNull, "instr_utf8")};
+ kResultNullIfNull, "instr_utf8"),
+
+ NativeFunction("translate", {}, DataTypeVector{utf8(), utf8(), utf8()},
utf8(),
+ kResultNullIfNull, "translate_utf8_utf8_utf8",
+ NativeFunction::kNeedsContext |
NativeFunction::kCanReturnErrors)};
return string_fn_registry_;
}
diff --git a/cpp/src/gandiva/gdv_function_stubs.h
b/cpp/src/gandiva/gdv_function_stubs.h
index a8ce58698e..265d48fb72 100644
--- a/cpp/src/gandiva/gdv_function_stubs.h
+++ b/cpp/src/gandiva/gdv_function_stubs.h
@@ -333,4 +333,9 @@ const char* gdv_fn_sha256_timestamp(int64_t context,
gdv_timestamp value, bool v
GANDIVA_EXPORT
const char* gdv_fn_sha256_utf8(int64_t context, gdv_utf8 value, int32_t
value_length,
bool value_validity, int32_t* out_length);
+
+GANDIVA_EXPORT
+const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t
in_len,
+ const char* from, int32_t from_len, const
char* to,
+ int32_t to_len, int32_t* out_len);
}
diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc
b/cpp/src/gandiva/gdv_function_stubs_test.cc
index 65f70a8edc..7f282c172b 100644
--- a/cpp/src/gandiva/gdv_function_stubs_test.cc
+++ b/cpp/src/gandiva/gdv_function_stubs_test.cc
@@ -950,4 +950,47 @@ TEST(TestGdvFnStubs, TestMaskLastN) {
EXPECT_EQ(expected, std::string(result, out_len));
}
+TEST(TestGdvFnStubs, TestTranslate) {
+ gandiva::ExecutionContext ctx;
+ int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
+ int32_t out_len = 0;
+
+ std::string expected = "ACACACA";
+ const char* result =
+ translate_utf8_utf8_utf8(ctx_ptr, "ABABABA", 7, "B", 1, "C", 1,
&out_len);
+ EXPECT_EQ(expected, std::string(result, out_len));
+
+ expected = "acde";
+ result = translate_utf8_utf8_utf8(ctx_ptr, "a b c d e", 9, " b", 2, "", 0,
&out_len);
+ EXPECT_EQ(expected, std::string(result, out_len));
+
+ expected = "h3110, h0w ar3 y0u/";
+ result = translate_utf8_utf8_utf8(ctx_ptr, "hello, how are you?", 19,
"elo?", 4, "310/",
+ 4, &out_len);
+ EXPECT_EQ(expected, std::string(result, out_len));
+
+ expected = "1b9ef";
+ result = translate_utf8_utf8_utf8(ctx_ptr, "abcdef", 6, "adc", 3, "19", 2,
&out_len);
+ EXPECT_EQ(expected, std::string(result, out_len));
+
+ expected = "abcd";
+ result = translate_utf8_utf8_utf8(ctx_ptr, "a b c d", 7, " ", 1, "", 0,
&out_len);
+ EXPECT_EQ(expected, std::string(result, out_len));
+
+ expected = "1b9c9e1f";
+ result =
+ translate_utf8_utf8_utf8(ctx_ptr, "abdcdeaf", 8, "adad", 4, "192", 3,
&out_len);
+ EXPECT_EQ(expected, std::string(result, out_len));
+
+ expected = "012345678";
+ result = translate_utf8_utf8_utf8(ctx_ptr, "123456789", 9, "987654321", 9,
"0123456789",
+ 10, &out_len);
+ EXPECT_EQ(expected, std::string(result, out_len));
+
+ expected = "012345678";
+ result = translate_utf8_utf8_utf8(ctx_ptr, "987654321", 9, "123456789", 9,
"0123456789",
+ 10, &out_len);
+ EXPECT_EQ(expected, std::string(result, out_len));
+}
+
} // namespace gandiva
diff --git a/cpp/src/gandiva/gdv_string_function_stubs.cc
b/cpp/src/gandiva/gdv_string_function_stubs.cc
index 1948d3a3e1..eb3831ccb4 100644
--- a/cpp/src/gandiva/gdv_string_function_stubs.cc
+++ b/cpp/src/gandiva/gdv_string_function_stubs.cc
@@ -21,6 +21,7 @@
#include <utf8proc.h>
#include <string>
+#include <unordered_map>
#include <vector>
#include "arrow/util/double_conversion.h"
@@ -449,6 +450,222 @@ const char* gdv_fn_initcap_utf8(int64_t context, const
char* data, int32_t data_
*out_len = out_idx;
return out;
}
+GANDIVA_EXPORT
+const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t
in_len,
+ const char* from, int32_t from_len, const
char* to,
+ int32_t to_len, int32_t* out_len) {
+ if (in_len <= 0) {
+ *out_len = 0;
+ return "";
+ }
+
+ if (from_len <= 0) {
+ *out_len = in_len;
+ return in;
+ }
+
+ // This variable is to control if there are multi-byte utf8 entries
+ bool has_multi_byte = false;
+
+ // This variable is to store the final result
+ char* result;
+ int result_len;
+
+ // Searching multi-bytes in In
+ for (int i = 0; i < in_len; i++) {
+ unsigned char char_single_byte = in[i];
+ if (char_single_byte > 127) {
+ // found a multi-byte utf-8 char
+ has_multi_byte = true;
+ break;
+ }
+ }
+
+ // Searching multi-bytes in From
+ if (!has_multi_byte) {
+ for (int i = 0; i < from_len; i++) {
+ unsigned char char_single_byte = from[i];
+ if (char_single_byte > 127) {
+ // found a multi-byte utf-8 char
+ has_multi_byte = true;
+ break;
+ }
+ }
+ }
+
+ // Searching multi-bytes in To
+ if (!has_multi_byte) {
+ for (int i = 0; i < to_len; i++) {
+ unsigned char char_single_byte = to[i];
+ if (char_single_byte > 127) {
+ // found a multi-byte utf-8 char
+ has_multi_byte = true;
+ break;
+ }
+ }
+ }
+
+ // If there are no multibytes in the input, work only with char
+ if (!has_multi_byte) {
+ // This variable is for receive the substitutions
+ result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
in_len));
+
+ if (result == nullptr) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not allocate memory for output
string");
+ *out_len = 0;
+ return "";
+ }
+ result_len = 0;
+
+ // Creating a Map to mark substitutions to make
+ std::unordered_map<char, char> subs_list;
+
+ // This variable is for controlling the position in entry TO, for never
repeat the
+ // changes
+ int start_compare;
+
+ if (to_len > 0) {
+ start_compare = 0;
+ } else {
+ start_compare = -1;
+ }
+
+ // If the position in TO is out of range, this variable will be associated
to map
+ // list, to mark deletion positions
+ const char empty = '\0';
+
+ for (int in_for = 0; in_for < in_len; in_for++) {
+ if (subs_list.find(in[in_for]) != subs_list.end()) {
+ if (subs_list[in[in_for]] != empty) {
+ // If exist in map, only add the correspondent value in result
+ result[result_len] = subs_list[in[in_for]];
+ result_len++;
+ }
+ } else {
+ for (int from_for = 0; from_for <= from_len; from_for++) {
+ if (from_for == from_len) {
+ // If it's not in the FROM list, just add it to the map and the
result.
+ subs_list.insert(std::pair<char, char>(in[in_for], in[in_for]));
+ result[result_len] = in[in_for];
+ result_len++;
+ break;
+ }
+ if (in[in_for] != from[from_for]) {
+ // If this character does not exist in FROM list, don't need
treatment
+ continue;
+ } else if (start_compare == -1 || start_compare == to_len) {
+ // If exist but the start_compare is out of range, add to map as
empty, to
+ // deletion later
+ subs_list.insert(std::pair<char, char>(in[in_for], empty));
+ break;
+ } else {
+ // If exist and the start_compare is in range, add to map with the
+ // corresponding TO in position start_compare
+ subs_list.insert(std::pair<char, char>(in[in_for],
to[start_compare]));
+ result[result_len] = subs_list[in[in_for]];
+ result_len++;
+ start_compare++;
+ break; // for ignore duplicates entries in FROM, ex: ("adad")
+ }
+ }
+ }
+ }
+ } else { // If there are no multibytes in the input, work with std::strings
+ // This variable is for receive the substitutions, malloc is in_len * 4 to
receive
+ // possible inputs with 4 bytes
+ result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context,
in_len * 4));
+
+ if (result == nullptr) {
+ gdv_fn_context_set_error_msg(context,
+ "Could not allocate memory for output
string");
+ *out_len = 0;
+ return "";
+ }
+ result_len = 0;
+
+ // This map is std::string to store multi-bytes entries
+ std::unordered_map<std::string, std::string> subs_list;
+
+ // This variable is for controlling the position in entry TO, for never
repeat the
+ // changes
+ int start_compare;
+
+ if (to_len > 0) {
+ start_compare = 0;
+ } else {
+ start_compare = -1;
+ }
+
+ // If the position in TO is out of range, this variable will be associated
to map
+ // list, to mark deletion positions
+ const std::string empty = "";
+
+ // This variables is to control len of multi-bytes entries
+ int len_char_in = 0;
+ int len_char_from = 0;
+ int len_char_to = 0;
+
+ for (int in_for = 0; in_for < in_len; in_for += len_char_in) {
+ // Updating len to char in this position
+ len_char_in = gdv_fn_utf8_char_length(in[in_for]);
+ // Making copy to std::string with length for this char position
+ std::string insert_copy_key(in + in_for, len_char_in);
+ if (subs_list.find(insert_copy_key) != subs_list.end()) {
+ if (subs_list[insert_copy_key] != empty) {
+ // If exist in map, only add the correspondent value in result
+ memcpy(result + result_len, subs_list[insert_copy_key].c_str(),
+ subs_list[insert_copy_key].length());
+ result_len += static_cast<int>(subs_list[insert_copy_key].length());
+ }
+ } else {
+ for (int from_for = 0; from_for <= from_len; from_for +=
len_char_from) {
+ // Updating len to char in this position
+ len_char_from = gdv_fn_utf8_char_length(from[from_for]);
+ // Making copy to std::string with length for this char position
+ std::string copy_from_compare(from + from_for, len_char_from);
+ if (from_for == from_len) {
+ // If it's not in the FROM list, just add it to the map and the
result.
+ std::string insert_copy_value(in + in_for, len_char_in);
+ // Insert in map to next loops
+ subs_list.insert(
+ std::pair<std::string, std::string>(insert_copy_key,
insert_copy_value));
+ memcpy(result + result_len, subs_list[insert_copy_key].c_str(),
+ subs_list[insert_copy_key].length());
+ result_len +=
static_cast<int>(subs_list[insert_copy_key].length());
+ break;
+ }
+
+ if (insert_copy_key != copy_from_compare) {
+ // If this character does not exist in FROM list, don't need
treatment
+ continue;
+ } else if (start_compare == -1 || start_compare >= to_len) {
+ // If exist but the start_compare is out of range, add to map as
empty, to
+ // deletion later
+ subs_list.insert(std::pair<std::string,
std::string>(insert_copy_key, empty));
+ break;
+ } else {
+ // If exist and the start_compare is in range, add to map with the
+ // corresponding TO in position start_compare
+ len_char_to = gdv_fn_utf8_char_length(to[start_compare]);
+ std::string insert_copy_value(to + start_compare, len_char_to);
+ // Insert in map to next loops
+ subs_list.insert(
+ std::pair<std::string, std::string>(insert_copy_key,
insert_copy_value));
+ memcpy(result + result_len, subs_list[insert_copy_key].c_str(),
+ subs_list[insert_copy_key].length());
+ result_len +=
static_cast<int>(subs_list[insert_copy_key].length());
+ start_compare += len_char_to;
+ break; // for ignore duplicates entries in FROM, ex: ("adad")
+ }
+ }
+ }
+ }
+ }
+
+ *out_len = result_len;
+ return result;
+}
}
namespace gandiva {
@@ -649,5 +866,21 @@ void ExportedStringFunctions::AddMappings(Engine* engine)
const {
engine->AddGlobalMappingForFunc("gdv_fn_initcap_utf8",
types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_initcap_utf8));
+
+ // translate_utf8_utf8_utf8
+ args = {
+ types->i64_type(), // context
+ types->i8_ptr_type(), // const char*
+ types->i32_type(), // value_length
+ types->i8_ptr_type(), // const char*
+ types->i32_type(), // value_length
+ types->i8_ptr_type(), // const char*
+ types->i32_type(), // value_length
+ types->i32_ptr_type() // out_length
+ };
+
+ engine->AddGlobalMappingForFunc("translate_utf8_utf8_utf8",
+ types->i8_ptr_type() /*return_type*/, args,
+
reinterpret_cast<void*>(translate_utf8_utf8_utf8));
}
} // namespace gandiva
diff --git a/cpp/src/gandiva/tests/projector_test.cc
b/cpp/src/gandiva/tests/projector_test.cc
index ad36315ea1..60acc4e21b 100644
--- a/cpp/src/gandiva/tests/projector_test.cc
+++ b/cpp/src/gandiva/tests/projector_test.cc
@@ -2857,4 +2857,54 @@ TEST_F(TestProjector, TestLCase) {
EXPECT_ARROW_ARRAY_EQUALS(out_1, outputs.at(0));
}
+TEST_F(TestProjector, TestTranslate) {
+ // schema for input fields
+ auto field0 = field("f0", arrow::utf8());
+ auto field1 = field("f1", arrow::utf8());
+ auto field2 = field("f2", arrow::utf8());
+
+ auto schema_translate = arrow::schema({field0, field1, field2});
+
+ // output fields
+ auto field_translate = field("translate", arrow::utf8());
+
+ // Build expression
+ auto translate_expr = TreeExprBuilder::MakeExpression(
+ "translate", {field0, field1, field2}, field_translate);
+
+ std::shared_ptr<Projector> projector;
+ auto status = Projector::Make(schema_translate, {translate_expr},
TestConfiguration(),
+ &projector);
+
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Create a row-batch with some sample data
+ int num_records = 7;
+ auto array0 = MakeArrowArrayUtf8({"a b c d", "abcde", "My Name Is JHONNY",
"\n\n/n/n",
+ "x大路学路x", "学大路学路大", "abcd"},
+ {true, true, true, true, true, true, true});
+
+ auto array1 = MakeArrowArrayUtf8({" ", "bd", "JHONNY", "/\n", "x", "学大",
"abc"},
+ {true, true, true, true, true, true, true});
+
+ auto array2 = MakeArrowArrayUtf8({"", "xb", "XXXXX", "a~", "b", "12", "学大路"},
+ {true, true, true, true, true, true, true});
+
+ // expected output
+ auto exp_translate = MakeArrowArrayUtf8({"abcd", "axcbe", "My Xame Is
XXXXXX", "aa~n~n",
+ "b大路学路b", "12路1路2", "学大路d"},
+ {true, true, true, true, true, true,
true});
+
+ // prepare input record batch
+ auto in_batch =
+ arrow::RecordBatch::Make(schema_translate, num_records, {array0, array1,
array2});
+
+ // Evaluate expression
+ arrow::ArrayVector outputs;
+ status = projector->Evaluate(*in_batch, pool_, &outputs);
+ EXPECT_TRUE(status.ok()) << status.message();
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp_translate, outputs.at(0));
+}
} // namespace gandiva