pitrou commented on code in PR #45577: URL: https://github.com/apache/arrow/pull/45577#discussion_r1983592880
########## cpp/src/arrow/compute/kernels/scalar_string_test.cc: ########## @@ -1958,6 +1960,47 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegex) { R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "3"}])", &options); } +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpan) { + ExtractRegexSpanOptions options{"(?P<letter>[ab])(?P<digit>\\d)"}; + auto type_fixe_size_list = is_binary_like(this->type()->id()) ? int32() : int64(); + auto out_type = struct_({field("letter", fixed_size_list(type_fixe_size_list, 2)), + field("digit", fixed_size_list(type_fixe_size_list, 2))}); + this->CheckUnary("extract_regex_span", R"([])", out_type, R"([])", &options); + this->CheckUnary( + "extract_regex_span", R"(["a1", "b2", "c3", null])", out_type, Review Comment: Could you make the tests less trivial and use examples with variable-length captures? For example have the regex be: `(?P<letter>[ab]+)(?P<digit>\\d+)` and then test with "abb12", "abc13", etc. ########## cpp/src/arrow/compute/kernels/scalar_string_ascii.cc: ########## @@ -2347,6 +2356,137 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) { } DCHECK_OK(registry->AddFunction(std::move(func))); } +struct ExtractRegexSpanData : public BaseExtractRegexData { + static Result<ExtractRegexSpanData> Make(const std::string& pattern) { + auto data = ExtractRegexSpanData(pattern, true); + ARROW_RETURN_NOT_OK(data.Init()); + return data; + } + + Result<TypeHolder> ResolveOutputType(const std::vector<TypeHolder>& types) const { + const DataType* input_type = types[0].type; + if (input_type == NULLPTR) { + return NULLPTR; + } + DCHECK(is_base_binary_like(input_type->id())); + const size_t field_count = group_names_.size(); + FieldVector fields; + fields.reserve(field_count); + const auto owned_type = input_type->GetSharedPtr(); + for (const auto& group_name : group_names_) { + auto type = is_binary_like(owned_type->id()) ? 
int32() : int64(); + // size list is 2 as every span contains position and length + fields.push_back(field(group_name, fixed_size_list(type, 2))); + } + return struct_(fields); + } + + private: + ExtractRegexSpanData(const std::string& pattern, const bool is_utf8) + : BaseExtractRegexData(pattern, is_utf8) {} +}; + +template <typename Type> +struct ExtractRegexSpan : ExtractRegexBase { + using ArrayType = typename TypeTraits<Type>::ArrayType; + using BuilderType = typename TypeTraits<Type>::BuilderType; + using offset_type = typename Type::offset_type; + using OffsetBuilderType = + typename TypeTraits<typename CTypeTraits<offset_type>::ArrowType>::BuilderType; + using OffsetCType = + typename TypeTraits<typename CTypeTraits<offset_type>::ArrowType>::CType; + + using ExtractRegexBase::ExtractRegexBase; + + static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + auto options = OptionsWrapper<ExtractRegexSpanOptions>::Get(ctx); + ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexSpanData::Make(options.pattern)); + return ExtractRegexSpan{data}.Extract(ctx, batch, out); + } + Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + DCHECK_NE(out->array_data(), NULLPTR); + std::shared_ptr<DataType> out_type = out->array_data()->type; + DCHECK_NE(out_type, NULLPTR); + std::unique_ptr<ArrayBuilder> out_builder; + ARROW_RETURN_NOT_OK( + MakeBuilder(ctx->memory_pool(), out->type()->GetSharedPtr(), &out_builder)); + auto struct_builder = checked_pointer_cast<StructBuilder>(std::move(out_builder)); + ARROW_RETURN_NOT_OK(struct_builder->Reserve(batch[0].array.length)); + std::vector<FixedSizeListBuilder*> span_builders; + std::vector<OffsetBuilderType*> array_builders; + span_builders.reserve(group_count); + array_builders.reserve(group_count); + for (int i = 0; i < group_count; i++) { + span_builders.push_back( + checked_cast<FixedSizeListBuilder*>(struct_builder->field_builder(i))); + array_builders.push_back( + 
checked_cast<OffsetBuilderType*>(span_builders[i]->value_builder())); + RETURN_NOT_OK(span_builders.back()->Reserve(batch[0].array.length)); + RETURN_NOT_OK(array_builders.back()->Reserve(2 * batch[0].array.length)); + } + + auto visit_null = [&]() { return struct_builder->AppendNull(); }; + auto visit_value = [&](std::string_view element) -> Status { + if (Match(element)) { + for (int i = 0; i < group_count; i++) { + // https://github.com/google/re2/issues/24#issuecomment-97653183 + if (found_values[i].data() != NULLPTR) { Review Comment: You don't need to use `NULLPTR` in a `.cc` file. ```suggestion if (found_values[i].data() != nullptr) { ``` ########## cpp/src/arrow/compute/kernels/scalar_string_ascii.cc: ########## @@ -2280,15 +2290,14 @@ struct ExtractRegex : public ExtractRegexBase { static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { ExtractRegexOptions options = ExtractRegexState::Get(ctx); ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options, Type::is_utf8)); - return ExtractRegex{data}.Extract(ctx, batch, out); + return ExtractRegex(data).Extract(ctx, batch, out); } Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - // TODO: why is this needed? Type resolution should already be - // done and the output type set in the output variable - ARROW_ASSIGN_OR_RAISE(TypeHolder out_type, data.ResolveOutputType(batch.GetTypes())); - DCHECK_NE(out_type.type, nullptr); - std::shared_ptr<DataType> type = out_type.GetSharedPtr(); + ExtractRegexOptions options = ExtractRegexState::Get(ctx); + DCHECK_NE(out->array_data(), NULLPTR); + std::shared_ptr<DataType> type = out->array_data()->type; + DCHECK_NE(type, NULLPTR); Review Comment: This DCHECK is not needed. 
########## cpp/src/arrow/compute/kernels/scalar_string_ascii.cc: ########## @@ -2347,6 +2356,137 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) { } DCHECK_OK(registry->AddFunction(std::move(func))); } +struct ExtractRegexSpanData : public BaseExtractRegexData { + static Result<ExtractRegexSpanData> Make(const std::string& pattern) { + auto data = ExtractRegexSpanData(pattern, true); + ARROW_RETURN_NOT_OK(data.Init()); + return data; + } + + Result<TypeHolder> ResolveOutputType(const std::vector<TypeHolder>& types) const { + const DataType* input_type = types[0].type; + if (input_type == NULLPTR) { + return NULLPTR; + } + DCHECK(is_base_binary_like(input_type->id())); + const size_t field_count = group_names_.size(); + FieldVector fields; + fields.reserve(field_count); + const auto owned_type = input_type->GetSharedPtr(); + for (const auto& group_name : group_names_) { + auto type = is_binary_like(owned_type->id()) ? int32() : int64(); + // size list is 2 as every span contains position and length + fields.push_back(field(group_name, fixed_size_list(type, 2))); + } + return struct_(fields); + } + + private: + ExtractRegexSpanData(const std::string& pattern, const bool is_utf8) + : BaseExtractRegexData(pattern, is_utf8) {} +}; + +template <typename Type> +struct ExtractRegexSpan : ExtractRegexBase { + using ArrayType = typename TypeTraits<Type>::ArrayType; + using BuilderType = typename TypeTraits<Type>::BuilderType; + using offset_type = typename Type::offset_type; + using OffsetBuilderType = + typename TypeTraits<typename CTypeTraits<offset_type>::ArrowType>::BuilderType; + using OffsetCType = + typename TypeTraits<typename CTypeTraits<offset_type>::ArrowType>::CType; + + using ExtractRegexBase::ExtractRegexBase; + + static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + auto options = OptionsWrapper<ExtractRegexSpanOptions>::Get(ctx); + ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexSpanData::Make(options.pattern)); + 
return ExtractRegexSpan{data}.Extract(ctx, batch, out); + } + Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + DCHECK_NE(out->array_data(), NULLPTR); + std::shared_ptr<DataType> out_type = out->array_data()->type; + DCHECK_NE(out_type, NULLPTR); + std::unique_ptr<ArrayBuilder> out_builder; + ARROW_RETURN_NOT_OK( + MakeBuilder(ctx->memory_pool(), out->type()->GetSharedPtr(), &out_builder)); Review Comment: Let's simplify this: ```suggestion std::shared_ptr<DataType> out_type = out->array_data()->type; ARROW_ASSIGN_OR_RAISE(auto out_builder, MakeBuilder(ctx->memory_pool(), out_type)); ``` ########## cpp/src/arrow/compute/kernels/scalar_string_ascii.cc: ########## @@ -2347,6 +2356,137 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) { } DCHECK_OK(registry->AddFunction(std::move(func))); } +struct ExtractRegexSpanData : public BaseExtractRegexData { + static Result<ExtractRegexSpanData> Make(const std::string& pattern) { + auto data = ExtractRegexSpanData(pattern, true); Review Comment: Don't we want to pass `Type::is_utf8` for `is_utf8`? ########## cpp/src/arrow/compute/kernels/scalar_string_ascii.cc: ########## @@ -2280,15 +2290,14 @@ struct ExtractRegex : public ExtractRegexBase { static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { ExtractRegexOptions options = ExtractRegexState::Get(ctx); ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options, Type::is_utf8)); - return ExtractRegex{data}.Extract(ctx, batch, out); + return ExtractRegex(data).Extract(ctx, batch, out); } Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - // TODO: why is this needed? 
Type resolution should already be - // done and the output type set in the output variable - ARROW_ASSIGN_OR_RAISE(TypeHolder out_type, data.ResolveOutputType(batch.GetTypes())); - DCHECK_NE(out_type.type, nullptr); - std::shared_ptr<DataType> type = out_type.GetSharedPtr(); + ExtractRegexOptions options = ExtractRegexState::Get(ctx); Review Comment: `options` isn't used below, is it? ########## cpp/src/arrow/compute/kernels/scalar_string_ascii.cc: ########## @@ -2347,6 +2356,137 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) { } DCHECK_OK(registry->AddFunction(std::move(func))); } +struct ExtractRegexSpanData : public BaseExtractRegexData { + static Result<ExtractRegexSpanData> Make(const std::string& pattern) { + auto data = ExtractRegexSpanData(pattern, true); + ARROW_RETURN_NOT_OK(data.Init()); + return data; + } + + Result<TypeHolder> ResolveOutputType(const std::vector<TypeHolder>& types) const { + const DataType* input_type = types[0].type; + if (input_type == NULLPTR) { + return NULLPTR; + } Review Comment: ```suggestion if (input_type == nullptr) { return nullptr; } ``` ########## cpp/src/arrow/compute/kernels/scalar_string_ascii.cc: ########## @@ -2347,6 +2356,137 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) { } DCHECK_OK(registry->AddFunction(std::move(func))); } +struct ExtractRegexSpanData : public BaseExtractRegexData { + static Result<ExtractRegexSpanData> Make(const std::string& pattern) { + auto data = ExtractRegexSpanData(pattern, true); + ARROW_RETURN_NOT_OK(data.Init()); + return data; + } + + Result<TypeHolder> ResolveOutputType(const std::vector<TypeHolder>& types) const { + const DataType* input_type = types[0].type; + if (input_type == NULLPTR) { + return NULLPTR; + } + DCHECK(is_base_binary_like(input_type->id())); + const size_t field_count = group_names_.size(); + FieldVector fields; + fields.reserve(field_count); + const auto owned_type = input_type->GetSharedPtr(); + for (const auto& group_name : 
group_names_) { + auto type = is_binary_like(owned_type->id()) ? int32() : int64(); + // size list is 2 as every span contains position and length + fields.push_back(field(group_name, fixed_size_list(type, 2))); + } Review Comment: ```suggestion auto offset_type = is_large_binary_like(input_type->id()) ? int64() : int32(); for (const auto& group_name : group_names_) { // size list is 2 as every span contains position and length fields.push_back(field(group_name, fixed_size_list(offset_type, 2))); } ``` ########## cpp/src/arrow/compute/kernels/scalar_string_ascii.cc: ########## @@ -2184,29 +2184,39 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry* registry) { using ExtractRegexState = OptionsWrapper<ExtractRegexOptions>; -// TODO cache this once per ExtractRegexOptions -struct ExtractRegexData { - // Use unique_ptr<> because RE2 is non-movable (for ARROW_ASSIGN_OR_RAISE) - std::unique_ptr<RE2> regex; - std::vector<std::string> group_names; - - static Result<ExtractRegexData> Make(const ExtractRegexOptions& options, - bool is_utf8 = true) { - ExtractRegexData data(options.pattern, is_utf8); - RETURN_NOT_OK(RegexStatus(*data.regex)); +struct BaseExtractRegexData { + Status Init() { + RETURN_NOT_OK(RegexStatus(*regex_)); - const int group_count = data.regex->NumberOfCapturingGroups(); - const auto& name_map = data.regex->CapturingGroupNames(); - data.group_names.reserve(group_count); + const int group_count = regex_->NumberOfCapturingGroups(); + const auto& name_map = regex_->CapturingGroupNames(); + group_names_.reserve(group_count); for (int i = 0; i < group_count; i++) { auto item = name_map.find(i + 1); // re2 starts counting from 1 if (item == name_map.end()) { // XXX should we instead just create fields with an empty name? 
return Status::Invalid("Regular expression contains unnamed groups"); } - data.group_names.emplace_back(item->second); + group_names_.emplace_back(item->second); } + return Status::OK(); + } + int64_t num_group() const { return group_names_.size(); } Review Comment: Nit: plural ```suggestion int64_t num_groups() const { return group_names_.size(); } ``` ########## cpp/src/arrow/compute/kernels/scalar_string_ascii.cc: ########## @@ -2347,6 +2356,137 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) { } DCHECK_OK(registry->AddFunction(std::move(func))); } +struct ExtractRegexSpanData : public BaseExtractRegexData { + static Result<ExtractRegexSpanData> Make(const std::string& pattern) { + auto data = ExtractRegexSpanData(pattern, true); + ARROW_RETURN_NOT_OK(data.Init()); + return data; + } + + Result<TypeHolder> ResolveOutputType(const std::vector<TypeHolder>& types) const { + const DataType* input_type = types[0].type; + if (input_type == NULLPTR) { + return NULLPTR; + } + DCHECK(is_base_binary_like(input_type->id())); + const size_t field_count = group_names_.size(); + FieldVector fields; + fields.reserve(field_count); + const auto owned_type = input_type->GetSharedPtr(); + for (const auto& group_name : group_names_) { + auto type = is_binary_like(owned_type->id()) ? 
int32() : int64(); + // size list is 2 as every span contains position and length + fields.push_back(field(group_name, fixed_size_list(type, 2))); + } + return struct_(fields); + } + + private: + ExtractRegexSpanData(const std::string& pattern, const bool is_utf8) + : BaseExtractRegexData(pattern, is_utf8) {} +}; + +template <typename Type> +struct ExtractRegexSpan : ExtractRegexBase { + using ArrayType = typename TypeTraits<Type>::ArrayType; + using BuilderType = typename TypeTraits<Type>::BuilderType; + using offset_type = typename Type::offset_type; + using OffsetBuilderType = + typename TypeTraits<typename CTypeTraits<offset_type>::ArrowType>::BuilderType; + using OffsetCType = + typename TypeTraits<typename CTypeTraits<offset_type>::ArrowType>::CType; + + using ExtractRegexBase::ExtractRegexBase; + + static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + auto options = OptionsWrapper<ExtractRegexSpanOptions>::Get(ctx); + ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexSpanData::Make(options.pattern)); + return ExtractRegexSpan{data}.Extract(ctx, batch, out); + } + Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + DCHECK_NE(out->array_data(), NULLPTR); + std::shared_ptr<DataType> out_type = out->array_data()->type; + DCHECK_NE(out_type, NULLPTR); + std::unique_ptr<ArrayBuilder> out_builder; + ARROW_RETURN_NOT_OK( + MakeBuilder(ctx->memory_pool(), out->type()->GetSharedPtr(), &out_builder)); + auto struct_builder = checked_pointer_cast<StructBuilder>(std::move(out_builder)); + ARROW_RETURN_NOT_OK(struct_builder->Reserve(batch[0].array.length)); + std::vector<FixedSizeListBuilder*> span_builders; + std::vector<OffsetBuilderType*> array_builders; + span_builders.reserve(group_count); + array_builders.reserve(group_count); + for (int i = 0; i < group_count; i++) { + span_builders.push_back( + checked_cast<FixedSizeListBuilder*>(struct_builder->field_builder(i))); + array_builders.push_back( + 
checked_cast<OffsetBuilderType*>(span_builders[i]->value_builder())); Review Comment: Nit: let's be consistent and use `back()` as well ```suggestion array_builders.push_back( checked_cast<OffsetBuilderType*>(span_builders.back()->value_builder())); ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org