pitrou commented on a change in pull request #8468:
URL: https://github.com/apache/arrow/pull/8468#discussion_r531163812



##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -1194,6 +1198,197 @@ void AddSplit(FunctionRegistry* registry) {
 #endif
 }
 
+// ----------------------------------------------------------------------
+// replace substring
+
+template <typename Type, typename Derived>
+struct ReplaceSubStringBase {
+  using ArrayType = typename TypeTraits<Type>::ArrayType;
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+  using BuilderType = typename TypeTraits<Type>::BuilderType;
+  using offset_type = typename Type::offset_type;
+  using ValueDataBuilder = TypedBufferBuilder<uint8_t>;
+  using OffsetBuilder = TypedBufferBuilder<offset_type>;
+  using State = OptionsWrapper<ReplaceSubstringOptions>;
+
+  static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    Derived derived(ctx, State::Get(ctx));
+    if (ctx->status().ok()) {
+      derived.Replace(ctx, batch, out);
+    }
+  }
+  void Replace(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    std::shared_ptr<ValueDataBuilder> value_data_builder =
+        std::make_shared<ValueDataBuilder>();
+    std::shared_ptr<OffsetBuilder> offset_builder = 
std::make_shared<OffsetBuilder>();
+
+    if (batch[0].kind() == Datum::ARRAY) {
+      // We already know how many strings we have, so we can use 
Reserve/UnsafeAppend
+      KERNEL_RETURN_IF_ERROR(ctx, 
offset_builder->Reserve(batch[0].array()->length));
+
+      const ArrayData& input = *batch[0].array();
+      KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Append(0));  // offsets 
start at 0
+      KERNEL_RETURN_IF_ERROR(
+          ctx, VisitArrayDataInline<Type>(
+                   input,
+                   [&](util::string_view s) {
+                     RETURN_NOT_OK(static_cast<Derived&>(*this).ReplaceString(
+                         s, value_data_builder.get()));
+                     offset_builder->UnsafeAppend(
+                         
static_cast<offset_type>(value_data_builder->length()));
+                     return Status::OK();
+                   },
+                   [&]() {
+                     // offset for null value
+                     offset_builder->UnsafeAppend(
+                         
static_cast<offset_type>(value_data_builder->length()));
+                     return Status::OK();
+                   }));
+      ArrayData* output = out->mutable_array();
+      KERNEL_RETURN_IF_ERROR(ctx, 
value_data_builder->Finish(&output->buffers[2]));
+      KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Finish(&output->buffers[1]));
+    } else {
+      const auto& input = checked_cast<const ScalarType&>(*batch[0].scalar());
+      auto result = std::make_shared<ScalarType>();
+      if (input.is_valid) {
+        util::string_view s = static_cast<util::string_view>(*input.value);
+        KERNEL_RETURN_IF_ERROR(
+            ctx, static_cast<Derived&>(*this).ReplaceString(s, 
value_data_builder.get()));
+        KERNEL_RETURN_IF_ERROR(ctx, 
value_data_builder->Finish(&result->value));
+        result->is_valid = true;
+      }
+      out->value = result;
+    }
+  }
+};
+
+template <typename Type>
+struct ReplaceSubString : ReplaceSubStringBase<Type, ReplaceSubString<Type>> {

Review comment:
       `ReplaceString` below is basically independent of `Type`, but with this CRTP
idiom it may be compiled once per `Type` instantiation. Can you find another way
to parameterize the kernel?
   (Hint: perhaps use composition rather than inheritance.)

##########
File path: cpp/src/arrow/compute/api_scalar.h
##########
@@ -68,6 +68,18 @@ struct ARROW_EXPORT SplitPatternOptions : public 
SplitOptions {
   std::string pattern;
 };
 
+struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions {
+  explicit ReplaceSubstringOptions(std::string pattern, std::string 
replacement,
+                                   int64_t max_replacements = -1)
+      : pattern(pattern), replacement(replacement), 
max_replacements(max_replacements) {}
+
+  /// Literal pattern, or regular expression depending on is_regex

Review comment:
       Hmm... the doc comment refers to `is_regex`, but I don't see such a member here?

##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -1194,6 +1198,197 @@ void AddSplit(FunctionRegistry* registry) {
 #endif
 }
 
+// ----------------------------------------------------------------------
+// replace substring
+
+template <typename Type, typename Derived>
+struct ReplaceSubStringBase {
+  using ArrayType = typename TypeTraits<Type>::ArrayType;
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+  using BuilderType = typename TypeTraits<Type>::BuilderType;
+  using offset_type = typename Type::offset_type;
+  using ValueDataBuilder = TypedBufferBuilder<uint8_t>;
+  using OffsetBuilder = TypedBufferBuilder<offset_type>;
+  using State = OptionsWrapper<ReplaceSubstringOptions>;
+
+  static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    Derived derived(ctx, State::Get(ctx));
+    if (ctx->status().ok()) {
+      derived.Replace(ctx, batch, out);
+    }
+  }
+  void Replace(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    std::shared_ptr<ValueDataBuilder> value_data_builder =
+        std::make_shared<ValueDataBuilder>();
+    std::shared_ptr<OffsetBuilder> offset_builder = 
std::make_shared<OffsetBuilder>();
+
+    if (batch[0].kind() == Datum::ARRAY) {
+      // We already know how many strings we have, so we can use 
Reserve/UnsafeAppend
+      KERNEL_RETURN_IF_ERROR(ctx, 
offset_builder->Reserve(batch[0].array()->length));
+
+      const ArrayData& input = *batch[0].array();
+      KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Append(0));  // offsets 
start at 0
+      KERNEL_RETURN_IF_ERROR(
+          ctx, VisitArrayDataInline<Type>(
+                   input,
+                   [&](util::string_view s) {
+                     RETURN_NOT_OK(static_cast<Derived&>(*this).ReplaceString(
+                         s, value_data_builder.get()));
+                     offset_builder->UnsafeAppend(
+                         
static_cast<offset_type>(value_data_builder->length()));
+                     return Status::OK();
+                   },
+                   [&]() {
+                     // offset for null value
+                     offset_builder->UnsafeAppend(
+                         
static_cast<offset_type>(value_data_builder->length()));
+                     return Status::OK();
+                   }));
+      ArrayData* output = out->mutable_array();
+      KERNEL_RETURN_IF_ERROR(ctx, 
value_data_builder->Finish(&output->buffers[2]));
+      KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Finish(&output->buffers[1]));
+    } else {
+      const auto& input = checked_cast<const ScalarType&>(*batch[0].scalar());
+      auto result = std::make_shared<ScalarType>();
+      if (input.is_valid) {
+        util::string_view s = static_cast<util::string_view>(*input.value);
+        KERNEL_RETURN_IF_ERROR(
+            ctx, static_cast<Derived&>(*this).ReplaceString(s, 
value_data_builder.get()));
+        KERNEL_RETURN_IF_ERROR(ctx, 
value_data_builder->Finish(&result->value));
+        result->is_valid = true;
+      }
+      out->value = result;
+    }
+  }
+};
+
+template <typename Type>
+struct ReplaceSubString : ReplaceSubStringBase<Type, ReplaceSubString<Type>> {
+  using Base = ReplaceSubStringBase<Type, ReplaceSubString<Type>>;
+  using ValueDataBuilder = typename Base::ValueDataBuilder;
+  using offset_type = typename Base::offset_type;
+
+  ReplaceSubstringOptions options;
+  explicit ReplaceSubString(KernelContext* ctx, ReplaceSubstringOptions 
options)
+      : options(options) {}
+
+  Status ReplaceString(util::string_view s, ValueDataBuilder* builder) {
+    const char* i = s.begin();
+    const char* end = s.end();
+    int64_t max_replacements = options.max_replacements;
+    while ((i < end) && (max_replacements != 0)) {
+      const char* pos =
+          std::search(i, end, options.pattern.begin(), options.pattern.end());
+      if (pos == end) {
+        RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+                                      static_cast<offset_type>(end - i)));
+        i = end;
+      } else {
+        // the string before the pattern
+        RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+                                      static_cast<offset_type>(pos - i)));
+        // the replacement
+        RETURN_NOT_OK(
+            builder->Append(reinterpret_cast<const 
uint8_t*>(options.replacement.data()),
+                            options.replacement.length()));
+        // skip pattern
+        i = pos + options.pattern.length();
+        max_replacements--;
+      }
+    }
+    // if we exited early due to max_replacements, add the trailing part
+    RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+                                  static_cast<offset_type>(end - i)));
+    return Status::OK();
+  }
+};
+
+const FunctionDoc replace_substring_doc(
+    "Replace non-overlapping substrings that match pattern by replacement",
+    ("For each string in `strings`, replace non-overlapping substrings that 
match\n"
+     "`pattern` by `replacement`. If `max_replacements != -1`, it determines 
the\n"
+     "maximum amount of replacements made, counting from the left. Null values 
emit\n"
+     "null."),
+    {"strings"}, "ReplaceSubstringOptions");
+
+#ifdef ARROW_WITH_RE2
+template <typename Type>
+struct ReplaceSubStringRE2 : ReplaceSubStringBase<Type, 
ReplaceSubStringRE2<Type>> {
+  using Base = ReplaceSubStringBase<Type, ReplaceSubStringRE2<Type>>;
+  using ValueDataBuilder = typename Base::ValueDataBuilder;
+  using offset_type = typename Base::offset_type;
+
+  ReplaceSubstringOptions options;
+  RE2 regex_find;
+  RE2 regex_replacement;
+  explicit ReplaceSubStringRE2(KernelContext* ctx, ReplaceSubstringOptions 
options)
+      : options(options),
+        regex_find("(" + options.pattern + ")"),
+        regex_replacement(options.pattern) {
+    // Using RE2::FindAndConsume we can only find the pattern if it is a 
group, therefore
+    // we have 2 regex, one with () around it, one without.
+    if (!(regex_find.ok() && regex_replacement.ok())) {
+      ctx->SetStatus(Status::Invalid("Regular expression error"));
+      return;
+    }
+  }
+  Status ReplaceString(util::string_view s, ValueDataBuilder* builder) {
+    re2::StringPiece replacement(options.replacement);
+    if (options.max_replacements == -1) {
+      std::string s_copy(s.to_string());
+      re2::RE2::GlobalReplace(&s_copy, regex_replacement, replacement);
+      RETURN_NOT_OK(builder->Append(reinterpret_cast<const 
uint8_t*>(s_copy.data()),
+                                    s_copy.length()));
+      return Status::OK();
+    }
+    // Since RE2 does not have the concept of max_replacements, we have to do 
some work
+    // ourselves.

Review comment:
       Note that the `GlobalReplace` loop works a bit differently: it calls
`Match` and then `Rewrite`, avoiding duplicate matching calls. I'm not sure it's
worth optimizing this, though:
   https://github.com/google/re2/blob/master/re2/re2.cc#L427

##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -1194,6 +1198,197 @@ void AddSplit(FunctionRegistry* registry) {
 #endif
 }
 
+// ----------------------------------------------------------------------
+// replace substring
+
+template <typename Type, typename Derived>
+struct ReplaceSubStringBase {
+  using ArrayType = typename TypeTraits<Type>::ArrayType;
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+  using BuilderType = typename TypeTraits<Type>::BuilderType;
+  using offset_type = typename Type::offset_type;
+  using ValueDataBuilder = TypedBufferBuilder<uint8_t>;
+  using OffsetBuilder = TypedBufferBuilder<offset_type>;
+  using State = OptionsWrapper<ReplaceSubstringOptions>;
+
+  static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    Derived derived(ctx, State::Get(ctx));
+    if (ctx->status().ok()) {
+      derived.Replace(ctx, batch, out);
+    }
+  }
+  void Replace(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    std::shared_ptr<ValueDataBuilder> value_data_builder =
+        std::make_shared<ValueDataBuilder>();
+    std::shared_ptr<OffsetBuilder> offset_builder = 
std::make_shared<OffsetBuilder>();
+
+    if (batch[0].kind() == Datum::ARRAY) {
+      // We already know how many strings we have, so we can use 
Reserve/UnsafeAppend
+      KERNEL_RETURN_IF_ERROR(ctx, 
offset_builder->Reserve(batch[0].array()->length));
+
+      const ArrayData& input = *batch[0].array();
+      KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Append(0));  // offsets 
start at 0
+      KERNEL_RETURN_IF_ERROR(
+          ctx, VisitArrayDataInline<Type>(
+                   input,
+                   [&](util::string_view s) {
+                     RETURN_NOT_OK(static_cast<Derived&>(*this).ReplaceString(
+                         s, value_data_builder.get()));
+                     offset_builder->UnsafeAppend(
+                         
static_cast<offset_type>(value_data_builder->length()));
+                     return Status::OK();
+                   },
+                   [&]() {
+                     // offset for null value
+                     offset_builder->UnsafeAppend(
+                         
static_cast<offset_type>(value_data_builder->length()));
+                     return Status::OK();
+                   }));
+      ArrayData* output = out->mutable_array();
+      KERNEL_RETURN_IF_ERROR(ctx, 
value_data_builder->Finish(&output->buffers[2]));
+      KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Finish(&output->buffers[1]));
+    } else {
+      const auto& input = checked_cast<const ScalarType&>(*batch[0].scalar());
+      auto result = std::make_shared<ScalarType>();
+      if (input.is_valid) {
+        util::string_view s = static_cast<util::string_view>(*input.value);
+        KERNEL_RETURN_IF_ERROR(
+            ctx, static_cast<Derived&>(*this).ReplaceString(s, 
value_data_builder.get()));
+        KERNEL_RETURN_IF_ERROR(ctx, 
value_data_builder->Finish(&result->value));
+        result->is_valid = true;
+      }
+      out->value = result;
+    }
+  }
+};
+
+template <typename Type>
+struct ReplaceSubString : ReplaceSubStringBase<Type, ReplaceSubString<Type>> {
+  using Base = ReplaceSubStringBase<Type, ReplaceSubString<Type>>;
+  using ValueDataBuilder = typename Base::ValueDataBuilder;
+  using offset_type = typename Base::offset_type;
+
+  ReplaceSubstringOptions options;
+  explicit ReplaceSubString(KernelContext* ctx, ReplaceSubstringOptions 
options)
+      : options(options) {}
+
+  Status ReplaceString(util::string_view s, ValueDataBuilder* builder) {
+    const char* i = s.begin();
+    const char* end = s.end();
+    int64_t max_replacements = options.max_replacements;
+    while ((i < end) && (max_replacements != 0)) {
+      const char* pos =
+          std::search(i, end, options.pattern.begin(), options.pattern.end());
+      if (pos == end) {
+        RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+                                      static_cast<offset_type>(end - i)));
+        i = end;
+      } else {
+        // the string before the pattern
+        RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+                                      static_cast<offset_type>(pos - i)));
+        // the replacement
+        RETURN_NOT_OK(
+            builder->Append(reinterpret_cast<const 
uint8_t*>(options.replacement.data()),
+                            options.replacement.length()));
+        // skip pattern
+        i = pos + options.pattern.length();
+        max_replacements--;
+      }
+    }
+    // if we exited early due to max_replacements, add the trailing part
+    RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+                                  static_cast<offset_type>(end - i)));
+    return Status::OK();
+  }
+};
+
+const FunctionDoc replace_substring_doc(
+    "Replace non-overlapping substrings that match pattern by replacement",
+    ("For each string in `strings`, replace non-overlapping substrings that 
match\n"
+     "`pattern` by `replacement`. If `max_replacements != -1`, it determines 
the\n"
+     "maximum amount of replacements made, counting from the left. Null values 
emit\n"
+     "null."),
+    {"strings"}, "ReplaceSubstringOptions");
+
+#ifdef ARROW_WITH_RE2
+template <typename Type>
+struct ReplaceSubStringRE2 : ReplaceSubStringBase<Type, 
ReplaceSubStringRE2<Type>> {

Review comment:
       As above, this looks basically independent of `Type`.

##########
File path: cpp/src/arrow/compute/kernels/scalar_string_test.cc
##########
@@ -416,6 +424,28 @@ TYPED_TEST(TestStringKernels, SplitWhitespaceUTF8Reverse) {
                    &options_max);
 }
 
+#ifdef ARROW_WITH_RE2
+TYPED_TEST(TestStringKernels, ReplaceSubstringNormal) {
+  ReplaceSubstringOptions options{"foo", "bazz"};
+  this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", 
null])",
+                   this->type(), R"(["bazz", "this bazz that bazz", null])", 
&options);
+  ReplaceSubstringOptions options_regex{"(fo+)\\s*", "\\1-bazz", -1};
+  this->CheckUnary("replace_substring_re2", R"(["foo ", "this foo   that foo", 
null])",
+                   this->type(), R"(["foo-bazz", "this foo-bazzthat foo-bazz", 
null])",
+                   &options_regex);

Review comment:
       Can you add a test with potentially tricky cases? For example,
`text="aaaaaa", pattern="(a.a)", replacement="ab\1"` (overlapping candidate matches plus a backreference).

##########
File path: docs/source/cpp/compute.rst
##########
@@ -355,19 +355,23 @@ The third set of functions examines string elements on a 
byte-per-byte basis:
 String transforms
 ~~~~~~~~~~~~~~~~~
 
-+--------------------------+------------+-------------------------+---------------------+---------+
-| Function name            | Arity      | Input types             | Output 
type         | Notes   |
-+==========================+============+=========================+=====================+=========+
-| ascii_lower              | Unary      | String-like             | 
String-like         | \(1)    |
-+--------------------------+------------+-------------------------+---------------------+---------+
-| ascii_upper              | Unary      | String-like             | 
String-like         | \(1)    |
-+--------------------------+------------+-------------------------+---------------------+---------+
-| binary_length            | Unary      | Binary- or String-like  | Int32 or 
Int64      | \(2)    |
-+--------------------------+------------+-------------------------+---------------------+---------+
-| utf8_lower               | Unary      | String-like             | 
String-like         | \(3)    |
-+--------------------------+------------+-------------------------+---------------------+---------+
-| utf8_upper               | Unary      | String-like             | 
String-like         | \(3)    |
-+--------------------------+------------+-------------------------+---------------------+---------+
++--------------------------+------------+-------------------------+---------------------+-------------------------------------------------+
+| Function name            | Arity      | Input types             | Output 
type         | Notes   | Options class                         |
++==========================+============+=========================+=====================+=========+=======================================+
+| ascii_lower              | Unary      | String-like             | 
String-like         | \(1)    |                                       |
++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
+| ascii_upper              | Unary      | String-like             | 
String-like         | \(1)    |                                       |
++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
+| binary_length            | Unary      | Binary- or String-like  | Int32 or 
Int64      | \(2)    |                                       |
++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
+| replace_substring        | Unary      | String-like             | 
String-like         | \(3)    | :struct:`ReplaceSubstringOptions`     |
++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+
+| replace_substring_re2    | Unary      | String-like             | 
String-like         | \(4)    | :struct:`ReplaceSubstringOptions`     |

Review comment:
       Please don't put "re2" in any of the public names or APIs. Using the re2 
library is just an implementation detail.
   "replace_regex" sounds just as good.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to