pitrou commented on a change in pull request #8468:
URL: https://github.com/apache/arrow/pull/8468#discussion_r531163812
##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -1194,6 +1198,197 @@ void AddSplit(FunctionRegistry* registry) {
#endif
}
+// ----------------------------------------------------------------------
+// replace substring
+
+template <typename Type, typename Derived>
+struct ReplaceSubStringBase {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ using ScalarType = typename TypeTraits<Type>::ScalarType;
+ using BuilderType = typename TypeTraits<Type>::BuilderType;
+ using offset_type = typename Type::offset_type;
+ using ValueDataBuilder = TypedBufferBuilder<uint8_t>;
+ using OffsetBuilder = TypedBufferBuilder<offset_type>;
+ using State = OptionsWrapper<ReplaceSubstringOptions>;
+
+ static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ Derived derived(ctx, State::Get(ctx));
+ if (ctx->status().ok()) {
+ derived.Replace(ctx, batch, out);
+ }
+ }
+ void Replace(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ std::shared_ptr<ValueDataBuilder> value_data_builder =
+ std::make_shared<ValueDataBuilder>();
+ std::shared_ptr<OffsetBuilder> offset_builder =
std::make_shared<OffsetBuilder>();
+
+ if (batch[0].kind() == Datum::ARRAY) {
+ // We already know how many strings we have, so we can use
Reserve/UnsafeAppend
+ KERNEL_RETURN_IF_ERROR(ctx,
offset_builder->Reserve(batch[0].array()->length));
+
+ const ArrayData& input = *batch[0].array();
+ KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Append(0)); // offsets
start at 0
+ KERNEL_RETURN_IF_ERROR(
+ ctx, VisitArrayDataInline<Type>(
+ input,
+ [&](util::string_view s) {
+ RETURN_NOT_OK(static_cast<Derived&>(*this).ReplaceString(
+ s, value_data_builder.get()));
+ offset_builder->UnsafeAppend(
+
static_cast<offset_type>(value_data_builder->length()));
+ return Status::OK();
+ },
+ [&]() {
+ // offset for null value
+ offset_builder->UnsafeAppend(
+
static_cast<offset_type>(value_data_builder->length()));
+ return Status::OK();
+ }));
+ ArrayData* output = out->mutable_array();
+ KERNEL_RETURN_IF_ERROR(ctx,
value_data_builder->Finish(&output->buffers[2]));
+ KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Finish(&output->buffers[1]));
+ } else {
+ const auto& input = checked_cast<const ScalarType&>(*batch[0].scalar());
+ auto result = std::make_shared<ScalarType>();
+ if (input.is_valid) {
+ util::string_view s = static_cast<util::string_view>(*input.value);
+ KERNEL_RETURN_IF_ERROR(
+ ctx, static_cast<Derived&>(*this).ReplaceString(s,
value_data_builder.get()));
+ KERNEL_RETURN_IF_ERROR(ctx,
value_data_builder->Finish(&result->value));
+ result->is_valid = true;
+ }
+ out->value = result;
+ }
+ }
+};
+
+template <typename Type>
+struct ReplaceSubString : ReplaceSubStringBase<Type, ReplaceSubString<Type>> {
Review comment:
`ReplaceString` below is basically independent of `Type`, but using
this idiom may compile it twice. Can you find another way to parametrize the
kernel?
(hint: perhaps use composition rather than inheritance)
##########
File path: cpp/src/arrow/compute/api_scalar.h
##########
@@ -68,6 +68,18 @@ struct ARROW_EXPORT SplitPatternOptions : public
SplitOptions {
std::string pattern;
};
+struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions {
+ explicit ReplaceSubstringOptions(std::string pattern, std::string
replacement,
+ int64_t max_replacements = -1)
+ : pattern(pattern), replacement(replacement),
max_replacements(max_replacements) {}
+
+ /// Literal pattern, or regular expression depending on is_regex
Review comment:
Hmm... I don't see `is_regex` here?
##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -1194,6 +1198,197 @@ void AddSplit(FunctionRegistry* registry) {
#endif
}
+// ----------------------------------------------------------------------
+// replace substring
+
+template <typename Type, typename Derived>
+struct ReplaceSubStringBase {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ using ScalarType = typename TypeTraits<Type>::ScalarType;
+ using BuilderType = typename TypeTraits<Type>::BuilderType;
+ using offset_type = typename Type::offset_type;
+ using ValueDataBuilder = TypedBufferBuilder<uint8_t>;
+ using OffsetBuilder = TypedBufferBuilder<offset_type>;
+ using State = OptionsWrapper<ReplaceSubstringOptions>;
+
+ static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ Derived derived(ctx, State::Get(ctx));
+ if (ctx->status().ok()) {
+ derived.Replace(ctx, batch, out);
+ }
+ }
+ void Replace(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ std::shared_ptr<ValueDataBuilder> value_data_builder =
+ std::make_shared<ValueDataBuilder>();
+ std::shared_ptr<OffsetBuilder> offset_builder =
std::make_shared<OffsetBuilder>();
+
+ if (batch[0].kind() == Datum::ARRAY) {
+ // We already know how many strings we have, so we can use
Reserve/UnsafeAppend
+ KERNEL_RETURN_IF_ERROR(ctx,
offset_builder->Reserve(batch[0].array()->length));
+
+ const ArrayData& input = *batch[0].array();
+ KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Append(0)); // offsets
start at 0
+ KERNEL_RETURN_IF_ERROR(
+ ctx, VisitArrayDataInline<Type>(
+ input,
+ [&](util::string_view s) {
+ RETURN_NOT_OK(static_cast<Derived&>(*this).ReplaceString(
+ s, value_data_builder.get()));
+ offset_builder->UnsafeAppend(
+
static_cast<offset_type>(value_data_builder->length()));
+ return Status::OK();
+ },
+ [&]() {
+ // offset for null value
+ offset_builder->UnsafeAppend(
+
static_cast<offset_type>(value_data_builder->length()));
+ return Status::OK();
+ }));
+ ArrayData* output = out->mutable_array();
+ KERNEL_RETURN_IF_ERROR(ctx,
value_data_builder->Finish(&output->buffers[2]));
+ KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Finish(&output->buffers[1]));
+ } else {
+ const auto& input = checked_cast<const ScalarType&>(*batch[0].scalar());
+ auto result = std::make_shared<ScalarType>();
+ if (input.is_valid) {
+ util::string_view s = static_cast<util::string_view>(*input.value);
+ KERNEL_RETURN_IF_ERROR(
+ ctx, static_cast<Derived&>(*this).ReplaceString(s,
value_data_builder.get()));
+ KERNEL_RETURN_IF_ERROR(ctx,
value_data_builder->Finish(&result->value));
+ result->is_valid = true;
+ }
+ out->value = result;
+ }
+ }
+};
+
+template <typename Type>
+struct ReplaceSubString : ReplaceSubStringBase<Type, ReplaceSubString<Type>> {
+ using Base = ReplaceSubStringBase<Type, ReplaceSubString<Type>>;
+ using ValueDataBuilder = typename Base::ValueDataBuilder;
+ using offset_type = typename Base::offset_type;
+
+ ReplaceSubstringOptions options;
+ explicit ReplaceSubString(KernelContext* ctx, ReplaceSubstringOptions
options)
+ : options(options) {}
+
+ Status ReplaceString(util::string_view s, ValueDataBuilder* builder) {
+ const char* i = s.begin();
+ const char* end = s.end();
+ int64_t max_replacements = options.max_replacements;
+ while ((i < end) && (max_replacements != 0)) {
+ const char* pos =
+ std::search(i, end, options.pattern.begin(), options.pattern.end());
+ if (pos == end) {
+ RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+ static_cast<offset_type>(end - i)));
+ i = end;
+ } else {
+ // the string before the pattern
+ RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+ static_cast<offset_type>(pos - i)));
+ // the replacement
+ RETURN_NOT_OK(
+ builder->Append(reinterpret_cast<const
uint8_t*>(options.replacement.data()),
+ options.replacement.length()));
+ // skip pattern
+ i = pos + options.pattern.length();
+ max_replacements--;
+ }
+ }
+ // if we exited early due to max_replacements, add the trailing part
+ RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+ static_cast<offset_type>(end - i)));
+ return Status::OK();
+ }
+};
+
+const FunctionDoc replace_substring_doc(
+ "Replace non-overlapping substrings that match pattern by replacement",
+ ("For each string in `strings`, replace non-overlapping substrings that
match\n"
+ "`pattern` by `replacement`. If `max_replacements != -1`, it determines
the\n"
+ "maximum amount of replacements made, counting from the left. Null values
emit\n"
+ "null."),
+ {"strings"}, "ReplaceSubstringOptions");
+
+#ifdef ARROW_WITH_RE2
+template <typename Type>
+struct ReplaceSubStringRE2 : ReplaceSubStringBase<Type,
ReplaceSubStringRE2<Type>> {
+ using Base = ReplaceSubStringBase<Type, ReplaceSubStringRE2<Type>>;
+ using ValueDataBuilder = typename Base::ValueDataBuilder;
+ using offset_type = typename Base::offset_type;
+
+ ReplaceSubstringOptions options;
+ RE2 regex_find;
+ RE2 regex_replacement;
+ explicit ReplaceSubStringRE2(KernelContext* ctx, ReplaceSubstringOptions
options)
+ : options(options),
+ regex_find("(" + options.pattern + ")"),
+ regex_replacement(options.pattern) {
+ // Using RE2::FindAndConsume we can only find the pattern if it is a
group, therefore
+ // we have 2 regex, one with () around it, one without.
+ if (!(regex_find.ok() && regex_replacement.ok())) {
+ ctx->SetStatus(Status::Invalid("Regular expression error"));
+ return;
+ }
+ }
+ Status ReplaceString(util::string_view s, ValueDataBuilder* builder) {
+ re2::StringPiece replacement(options.replacement);
+ if (options.max_replacements == -1) {
+ std::string s_copy(s.to_string());
+ re2::RE2::GlobalReplace(&s_copy, regex_replacement, replacement);
+ RETURN_NOT_OK(builder->Append(reinterpret_cast<const
uint8_t*>(s_copy.data()),
+ s_copy.length()));
+ return Status::OK();
+ }
+ // Since RE2 does not have the concept of max_replacements, we have to do
some work
+ // ourselves.
Review comment:
Note that the `GlobalReplace` loop works a bit differently: it calls
`Match` then `Rewrite`, avoiding the duplicate matching calls. Not sure it's
worth optimizing this, though:
https://github.com/google/re2/blob/master/re2/re2.cc#L427
##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -1194,6 +1198,197 @@ void AddSplit(FunctionRegistry* registry) {
#endif
}
+// ----------------------------------------------------------------------
+// replace substring
+
+template <typename Type, typename Derived>
+struct ReplaceSubStringBase {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ using ScalarType = typename TypeTraits<Type>::ScalarType;
+ using BuilderType = typename TypeTraits<Type>::BuilderType;
+ using offset_type = typename Type::offset_type;
+ using ValueDataBuilder = TypedBufferBuilder<uint8_t>;
+ using OffsetBuilder = TypedBufferBuilder<offset_type>;
+ using State = OptionsWrapper<ReplaceSubstringOptions>;
+
+ static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ Derived derived(ctx, State::Get(ctx));
+ if (ctx->status().ok()) {
+ derived.Replace(ctx, batch, out);
+ }
+ }
+ void Replace(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ std::shared_ptr<ValueDataBuilder> value_data_builder =
+ std::make_shared<ValueDataBuilder>();
+ std::shared_ptr<OffsetBuilder> offset_builder =
std::make_shared<OffsetBuilder>();
+
+ if (batch[0].kind() == Datum::ARRAY) {
+ // We already know how many strings we have, so we can use
Reserve/UnsafeAppend
+ KERNEL_RETURN_IF_ERROR(ctx,
offset_builder->Reserve(batch[0].array()->length));
+
+ const ArrayData& input = *batch[0].array();
+ KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Append(0)); // offsets
start at 0
+ KERNEL_RETURN_IF_ERROR(
+ ctx, VisitArrayDataInline<Type>(
+ input,
+ [&](util::string_view s) {
+ RETURN_NOT_OK(static_cast<Derived&>(*this).ReplaceString(
+ s, value_data_builder.get()));
+ offset_builder->UnsafeAppend(
+
static_cast<offset_type>(value_data_builder->length()));
+ return Status::OK();
+ },
+ [&]() {
+ // offset for null value
+ offset_builder->UnsafeAppend(
+
static_cast<offset_type>(value_data_builder->length()));
+ return Status::OK();
+ }));
+ ArrayData* output = out->mutable_array();
+ KERNEL_RETURN_IF_ERROR(ctx,
value_data_builder->Finish(&output->buffers[2]));
+ KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Finish(&output->buffers[1]));
+ } else {
+ const auto& input = checked_cast<const ScalarType&>(*batch[0].scalar());
+ auto result = std::make_shared<ScalarType>();
+ if (input.is_valid) {
+ util::string_view s = static_cast<util::string_view>(*input.value);
+ KERNEL_RETURN_IF_ERROR(
+ ctx, static_cast<Derived&>(*this).ReplaceString(s,
value_data_builder.get()));
+ KERNEL_RETURN_IF_ERROR(ctx,
value_data_builder->Finish(&result->value));
+ result->is_valid = true;
+ }
+ out->value = result;
+ }
+ }
+};
+
+template <typename Type>
+struct ReplaceSubString : ReplaceSubStringBase<Type, ReplaceSubString<Type>> {
+ using Base = ReplaceSubStringBase<Type, ReplaceSubString<Type>>;
+ using ValueDataBuilder = typename Base::ValueDataBuilder;
+ using offset_type = typename Base::offset_type;
+
+ ReplaceSubstringOptions options;
+ explicit ReplaceSubString(KernelContext* ctx, ReplaceSubstringOptions
options)
+ : options(options) {}
+
+ Status ReplaceString(util::string_view s, ValueDataBuilder* builder) {
+ const char* i = s.begin();
+ const char* end = s.end();
+ int64_t max_replacements = options.max_replacements;
+ while ((i < end) && (max_replacements != 0)) {
+ const char* pos =
+ std::search(i, end, options.pattern.begin(), options.pattern.end());
+ if (pos == end) {
+ RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+ static_cast<offset_type>(end - i)));
+ i = end;
+ } else {
+ // the string before the pattern
+ RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+ static_cast<offset_type>(pos - i)));
+ // the replacement
+ RETURN_NOT_OK(
+ builder->Append(reinterpret_cast<const
uint8_t*>(options.replacement.data()),
+ options.replacement.length()));
+ // skip pattern
+ i = pos + options.pattern.length();
+ max_replacements--;
+ }
+ }
+ // if we exited early due to max_replacements, add the trailing part
+ RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+ static_cast<offset_type>(end - i)));
+ return Status::OK();
+ }
+};
+
+const FunctionDoc replace_substring_doc(
+ "Replace non-overlapping substrings that match pattern by replacement",
+ ("For each string in `strings`, replace non-overlapping substrings that
match\n"
+ "`pattern` by `replacement`. If `max_replacements != -1`, it determines
the\n"
+ "maximum amount of replacements made, counting from the left. Null values
emit\n"
+ "null."),
+ {"strings"}, "ReplaceSubstringOptions");
+
+#ifdef ARROW_WITH_RE2
+template <typename Type>
+struct ReplaceSubStringRE2 : ReplaceSubStringBase<Type,
ReplaceSubStringRE2<Type>> {
Review comment:
As above, this looks basically independent of `Type`.
##########
File path: cpp/src/arrow/compute/kernels/scalar_string_test.cc
##########
@@ -416,6 +424,28 @@ TYPED_TEST(TestStringKernels, SplitWhitespaceUTF8Reverse) {
&options_max);
}
+#ifdef ARROW_WITH_RE2
+TYPED_TEST(TestStringKernels, ReplaceSubstringNormal) {
+ ReplaceSubstringOptions options{"foo", "bazz"};
+ this->CheckUnary("replace_substring", R"(["foo", "this foo that foo",
null])",
+ this->type(), R"(["bazz", "this bazz that bazz", null])",
&options);
+ ReplaceSubstringOptions options_regex{"(fo+)\\s*", "\\1-bazz", -1};
+ this->CheckUnary("replace_substring_re2", R"(["foo ", "this foo that foo",
null])",
+ this->type(), R"(["foo-bazz", "this foo-bazzthat foo-bazz",
null])",
+ &options_regex);
Review comment:
Can you add a test with potentially tricky cases? For example
`text="aaaaaa", match="(a.a)", replacement="ab\1"`.
##########
File path: docs/source/cpp/compute.rst
##########
@@ -355,19 +355,23 @@ The third set of functions examines string elements on a
byte-per-byte basis:
String transforms
~~~~~~~~~~~~~~~~~
-+--------------------------+------------+-------------------------+---------------------+---------+
-| Function name | Arity | Input types | Output
type | Notes |
-+==========================+============+=========================+=====================+=========+
-| ascii_lower | Unary | String-like |
String-like | \(1) |
-+--------------------------+------------+-------------------------+---------------------+---------+
-| ascii_upper | Unary | String-like |
String-like | \(1) |
-+--------------------------+------------+-------------------------+---------------------+---------+
-| binary_length | Unary | Binary- or String-like | Int32 or
Int64 | \(2) |
-+--------------------------+------------+-------------------------+---------------------+---------+
-| utf8_lower | Unary | String-like |
String-like | \(3) |
-+--------------------------+------------+-------------------------+---------------------+---------+
-| utf8_upper | Unary | String-like |
String-like | \(3) |
-+--------------------------+------------+-------------------------+---------------------+---------+
++--------------------------+------------+-------------------------+---------------------+-------------------------------------------------+
+| Function name | Arity | Input types | Output
type | Notes | Options class |
++==========================+============+=========================+=====================+=========+=======================================+
+| ascii_lower | Unary | String-like |
String-like | \(1) | |
++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
+| ascii_upper | Unary | String-like |
String-like | \(1) | |
++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
+| binary_length | Unary | Binary- or String-like | Int32 or
Int64 | \(2) | |
++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
+| replace_substring | Unary | String-like |
String-like | \(3) | :struct:`ReplaceSubstringOptions` |
++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
+| replace_substring_re2 | Unary | String-like |
String-like | \(4) | :struct:`ReplaceSubstringOptions` |
Review comment:
Please don't put "re2" in any of the public names or APIs. Using the re2
library is just an implementation detail.
"replace_regex" sounds just as good.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]