maartenbreddels commented on a change in pull request #8468:
URL: https://github.com/apache/arrow/pull/8468#discussion_r545869942



##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -1194,6 +1198,197 @@ void AddSplit(FunctionRegistry* registry) {
 #endif
 }
 
+// ----------------------------------------------------------------------
+// replace substring
+
+template <typename Type, typename Derived>
+struct ReplaceSubStringBase {
+  using ArrayType = typename TypeTraits<Type>::ArrayType;
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+  using BuilderType = typename TypeTraits<Type>::BuilderType;
+  using offset_type = typename Type::offset_type;
+  using ValueDataBuilder = TypedBufferBuilder<uint8_t>;
+  using OffsetBuilder = TypedBufferBuilder<offset_type>;
+  using State = OptionsWrapper<ReplaceSubstringOptions>;
+
+  static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    Derived derived(ctx, State::Get(ctx));
+    if (ctx->status().ok()) {
+      derived.Replace(ctx, batch, out);
+    }
+  }
+  void Replace(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    std::shared_ptr<ValueDataBuilder> value_data_builder =
+        std::make_shared<ValueDataBuilder>();
+    std::shared_ptr<OffsetBuilder> offset_builder = 
std::make_shared<OffsetBuilder>();
+
+    if (batch[0].kind() == Datum::ARRAY) {
+      // We already know how many strings we have, so we can use 
Reserve/UnsafeAppend
+      KERNEL_RETURN_IF_ERROR(ctx, 
offset_builder->Reserve(batch[0].array()->length));
+
+      const ArrayData& input = *batch[0].array();
+      KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Append(0));  // offsets 
start at 0
+      KERNEL_RETURN_IF_ERROR(
+          ctx, VisitArrayDataInline<Type>(
+                   input,
+                   [&](util::string_view s) {
+                     RETURN_NOT_OK(static_cast<Derived&>(*this).ReplaceString(
+                         s, value_data_builder.get()));
+                     offset_builder->UnsafeAppend(
+                         
static_cast<offset_type>(value_data_builder->length()));
+                     return Status::OK();
+                   },
+                   [&]() {
+                     // offset for null value
+                     offset_builder->UnsafeAppend(
+                         
static_cast<offset_type>(value_data_builder->length()));
+                     return Status::OK();
+                   }));
+      ArrayData* output = out->mutable_array();
+      KERNEL_RETURN_IF_ERROR(ctx, 
value_data_builder->Finish(&output->buffers[2]));
+      KERNEL_RETURN_IF_ERROR(ctx, offset_builder->Finish(&output->buffers[1]));
+    } else {
+      const auto& input = checked_cast<const ScalarType&>(*batch[0].scalar());
+      auto result = std::make_shared<ScalarType>();
+      if (input.is_valid) {
+        util::string_view s = static_cast<util::string_view>(*input.value);
+        KERNEL_RETURN_IF_ERROR(
+            ctx, static_cast<Derived&>(*this).ReplaceString(s, 
value_data_builder.get()));
+        KERNEL_RETURN_IF_ERROR(ctx, 
value_data_builder->Finish(&result->value));
+        result->is_valid = true;
+      }
+      out->value = result;
+    }
+  }
+};
+
+template <typename Type>
+struct ReplaceSubString : ReplaceSubStringBase<Type, ReplaceSubString<Type>> {
+  using Base = ReplaceSubStringBase<Type, ReplaceSubString<Type>>;
+  using ValueDataBuilder = typename Base::ValueDataBuilder;
+  using offset_type = typename Base::offset_type;
+
+  ReplaceSubstringOptions options;
+  explicit ReplaceSubString(KernelContext* ctx, ReplaceSubstringOptions 
options)
+      : options(options) {}
+
+  Status ReplaceString(util::string_view s, ValueDataBuilder* builder) {
+    const char* i = s.begin();
+    const char* end = s.end();
+    int64_t max_replacements = options.max_replacements;
+    while ((i < end) && (max_replacements != 0)) {
+      const char* pos =
+          std::search(i, end, options.pattern.begin(), options.pattern.end());
+      if (pos == end) {
+        RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+                                      static_cast<offset_type>(end - i)));
+        i = end;
+      } else {
+        // the string before the pattern
+        RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+                                      static_cast<offset_type>(pos - i)));
+        // the replacement
+        RETURN_NOT_OK(
+            builder->Append(reinterpret_cast<const 
uint8_t*>(options.replacement.data()),
+                            options.replacement.length()));
+        // skip pattern
+        i = pos + options.pattern.length();
+        max_replacements--;
+      }
+    }
+    // if we exited early due to max_replacements, add the trailing part
+    RETURN_NOT_OK(builder->Append(reinterpret_cast<const uint8_t*>(i),
+                                  static_cast<offset_type>(end - i)));
+    return Status::OK();
+  }
+};
+
+const FunctionDoc replace_substring_doc(
+    "Replace non-overlapping substrings that match pattern by replacement",
+    ("For each string in `strings`, replace non-overlapping substrings that 
match\n"
+     "`pattern` by `replacement`. If `max_replacements != -1`, it determines 
the\n"
+     "maximum amount of replacements made, counting from the left. Null values 
emit\n"
+     "null."),
+    {"strings"}, "ReplaceSubstringOptions");
+
+#ifdef ARROW_WITH_RE2
+template <typename Type>
+struct ReplaceSubStringRE2 : ReplaceSubStringBase<Type, 
ReplaceSubStringRE2<Type>> {
+  using Base = ReplaceSubStringBase<Type, ReplaceSubStringRE2<Type>>;
+  using ValueDataBuilder = typename Base::ValueDataBuilder;
+  using offset_type = typename Base::offset_type;
+
+  ReplaceSubstringOptions options;
+  RE2 regex_find;
+  RE2 regex_replacement;
+  explicit ReplaceSubStringRE2(KernelContext* ctx, ReplaceSubstringOptions 
options)
+      : options(options),
+        regex_find("(" + options.pattern + ")"),
+        regex_replacement(options.pattern) {
+    // Using RE2::FindAndConsume we can only find the pattern if it is a 
group, therefore
+    // we have 2 regex, one with () around it, one without.
+    if (!(regex_find.ok() && regex_replacement.ok())) {
+      ctx->SetStatus(Status::Invalid("Regular expression error"));
+      return;
+    }
+  }
+  Status ReplaceString(util::string_view s, ValueDataBuilder* builder) {
+    re2::StringPiece replacement(options.replacement);
+    if (options.max_replacements == -1) {
+      std::string s_copy(s.to_string());
+      re2::RE2::GlobalReplace(&s_copy, regex_replacement, replacement);
+      RETURN_NOT_OK(builder->Append(reinterpret_cast<const 
uint8_t*>(s_copy.data()),
+                                    s_copy.length()));
+      return Status::OK();
+    }
+    // Since RE2 does not have the concept of max_replacements, we have to do 
some work
+    // ourselves.

Review comment:
       Good idea, but I prefer to keep it as it is; I left a comment in the code so
this doesn't get lost.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to