This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new ebcf7bc25f GH-36905: [C++] Add support for SparseUnion to selection 
functions (#36906)
ebcf7bc25f is described below

commit ebcf7bc25fcd47137523cb934a740cac0fc0fb76
Author: Jin Shang <[email protected]>
AuthorDate: Thu Aug 10 19:53:22 2023 +0800

    GH-36905: [C++] Add support for SparseUnion to selection functions (#36906)
    
    
    
    ### Rationale for this change
    
    Dense unions are already supported in Take, Filter and DropNull but sparse 
ones are not.
    
    ### What changes are included in this PR?
    
    Add kernels for sparse unions to those functions.
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    No.
    
    * Closes: #36905
    
    Lead-authored-by: Jin Shang <[email protected]>
    Co-authored-by: Antoine Pitrou <[email protected]>
    Signed-off-by: Antoine Pitrou <[email protected]>
---
 .../kernels/vector_selection_filter_internal.cc    | 34 +++++----
 .../compute/kernels/vector_selection_internal.cc   | 73 ++++++++++++++++++--
 .../compute/kernels/vector_selection_internal.h    | 13 ++--
 .../kernels/vector_selection_take_internal.cc      |  1 +
 .../arrow/compute/kernels/vector_selection_test.cc | 80 +++++++++++-----------
 docs/source/cpp/compute.rst                        | 18 +++--
 6 files changed, 144 insertions(+), 75 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc 
b/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc
index 13e92ba27e..be6d1653b5 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc
@@ -27,6 +27,7 @@
 #include "arrow/chunked_array.h"
 #include "arrow/compute/api_vector.h"
 #include "arrow/compute/exec.h"
+#include "arrow/compute/kernel.h"
 #include "arrow/compute/kernels/codegen_internal.h"
 #include "arrow/compute/kernels/vector_selection_filter_internal.h"
 #include "arrow/compute/kernels/vector_selection_internal.h"
@@ -49,8 +50,7 @@ using internal::CopyBitmap;
 using internal::CountSetBits;
 using internal::OptionalBitBlockCounter;
 
-namespace compute {
-namespace internal {
+namespace compute::internal {
 
 namespace {
 
@@ -863,20 +863,29 @@ Status ExtensionFilterExec(KernelContext* ctx, const 
ExecSpan& batch, ExecResult
   return Status::OK();
 }
 
-Status StructFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* 
out) {
-  // Transform filter to selection indices and then use Take.
+// Transform filter to selection indices and then use Take.
+Status FilterWithTakeExec(const ArrayKernelExec& take_exec, KernelContext* ctx,
+                          const ExecSpan& batch, ExecResult* out) {
   std::shared_ptr<ArrayData> indices;
   RETURN_NOT_OK(GetTakeIndices(batch[1].array,
                                FilterState::Get(ctx).null_selection_behavior,
                                ctx->memory_pool())
                     .Value(&indices));
+  KernelContext take_ctx(*ctx);
+  TakeState state{TakeOptions::NoBoundsCheck()};
+  take_ctx.SetState(&state);
+  ExecSpan take_batch({batch[0], ArraySpan(*indices)}, batch.length);
+  return take_exec(&take_ctx, take_batch, out);
+}
 
-  Datum result;
-  RETURN_NOT_OK(Take(batch[0].array.ToArrayData(), Datum(indices),
-                     TakeOptions::NoBoundsCheck(), ctx->exec_context())
-                    .Value(&result));
-  out->value = result.array();
-  return Status::OK();
+// Due to the special treatment in their Take kernels, we filter Struct and SparseUnion
+// arrays by transforming the filter to selection indices and then calling Take.
+Status StructFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* 
out) {
+  return FilterWithTakeExec(StructTakeExec, ctx, batch, out);
+}
+
+Status SparseUnionFilterExec(KernelContext* ctx, const ExecSpan& batch, 
ExecResult* out) {
+  return FilterWithTakeExec(SparseUnionTakeExec, ctx, batch, out);
 }
 
 // ----------------------------------------------------------------------
@@ -1047,6 +1056,7 @@ void 
PopulateFilterKernels(std::vector<SelectionKernelData>* out) {
       {InputType(Type::LARGE_LIST), plain_filter, LargeListFilterExec},
       {InputType(Type::FIXED_SIZE_LIST), plain_filter, FSLFilterExec},
       {InputType(Type::DENSE_UNION), plain_filter, DenseUnionFilterExec},
+      {InputType(Type::SPARSE_UNION), plain_filter, SparseUnionFilterExec},
       {InputType(Type::STRUCT), plain_filter, StructFilterExec},
       {InputType(Type::MAP), plain_filter, MapFilterExec},
 
@@ -1064,12 +1074,12 @@ void 
PopulateFilterKernels(std::vector<SelectionKernelData>* out) {
       {InputType(Type::LARGE_LIST), ree_filter, LargeListFilterExec},
       {InputType(Type::FIXED_SIZE_LIST), ree_filter, FSLFilterExec},
       {InputType(Type::DENSE_UNION), ree_filter, DenseUnionFilterExec},
+      {InputType(Type::SPARSE_UNION), ree_filter, SparseUnionFilterExec},
       {InputType(Type::STRUCT), ree_filter, StructFilterExec},
       {InputType(Type::MAP), ree_filter, MapFilterExec},
   };
 }
 
-}  // namespace internal
-}  // namespace compute
+}  // namespace compute::internal
 
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc 
b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc
index 23b8b75bfa..98eb37e9c5 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc
@@ -45,8 +45,7 @@ namespace arrow {
 
 using internal::CheckIndexBounds;
 
-namespace compute {
-namespace internal {
+namespace compute::internal {
 
 void RegisterSelectionFunction(const std::string& name, FunctionDoc doc,
                                VectorKernel base_kernel,
@@ -171,9 +170,6 @@ void VisitPlainxREEFilterOutputSegments(
 
 namespace {
 
-using FilterState = OptionsWrapper<FilterOptions>;
-using TakeState = OptionsWrapper<TakeOptions>;
-
 // ----------------------------------------------------------------------
 // Implement take for other data types where there is less performance
 // sensitivity by visiting the selected indices.
@@ -741,6 +737,66 @@ struct DenseUnionSelectionImpl
   }
 };
 
+// We need a slightly different approach for SparseUnion. For Take, we can
+// invoke Take on each child's data with bounds checking disabled. For
+// Filter on the other hand, if we naively call Filter on each child, then the
+// filter output length will have to be redundantly computed. Thus, for Filter
+// we instead convert the filter to selection indices and then invoke Take.
+
+// SparseUnion selection implementation. ONLY used for Take
+struct SparseUnionSelectionImpl
+    : public Selection<SparseUnionSelectionImpl, SparseUnionType> {
+  using Base = Selection<SparseUnionSelectionImpl, SparseUnionType>;
+  LIFT_BASE_MEMBERS();
+
+  TypedBufferBuilder<int8_t> child_id_buffer_builder_;
+  const int8_t type_code_for_null_;
+
+  SparseUnionSelectionImpl(KernelContext* ctx, const ExecSpan& batch,
+                           int64_t output_length, ExecResult* out)
+      : Base(ctx, batch, output_length, out),
+        child_id_buffer_builder_(ctx->memory_pool()),
+        type_code_for_null_(
+            checked_cast<const 
UnionType&>(*this->values.type).type_codes()[0]) {}
+
+  template <typename Adapter>
+  Status GenerateOutput() {
+    SparseUnionArray typed_values(this->values.ToArrayData());
+    Adapter adapter(this);
+    RETURN_NOT_OK(adapter.Generate(
+        [&](int64_t index) {
+          child_id_buffer_builder_.UnsafeAppend(typed_values.type_code(index));
+          return Status::OK();
+        },
+        [&]() {
+          child_id_buffer_builder_.UnsafeAppend(type_code_for_null_);
+          return Status::OK();
+        }));
+    return Status::OK();
+  }
+
+  Status Init() override {
+    RETURN_NOT_OK(child_id_buffer_builder_.Reserve(output_length));
+    return Status::OK();
+  }
+
+  Status Finish() override {
+    ARROW_ASSIGN_OR_RAISE(auto child_ids_buffer, 
child_id_buffer_builder_.Finish());
+    SparseUnionArray typed_values(this->values.ToArrayData());
+    auto num_fields = typed_values.num_fields();
+    auto num_rows = child_ids_buffer->size();
+    BufferVector buffers{nullptr, std::move(child_ids_buffer)};
+    *out = ArrayData(typed_values.type(), num_rows, std::move(buffers), 0);
+    out->child_data.reserve(num_fields);
+    for (auto i = 0; i < num_fields; i++) {
+      ARROW_ASSIGN_OR_RAISE(auto child_datum,
+                            Take(*typed_values.field(i), 
*this->selection.ToArrayData()));
+      out->child_data.emplace_back(std::move(child_datum).array());
+    }
+    return Status::OK();
+  }
+};
+
 struct FSLSelectionImpl : public Selection<FSLSelectionImpl, 
FixedSizeListType> {
   Int64Builder child_index_builder;
 
@@ -909,6 +965,10 @@ Status DenseUnionTakeExec(KernelContext* ctx, const 
ExecSpan& batch, ExecResult*
   return TakeExec<DenseUnionSelectionImpl>(ctx, batch, out);
 }
 
+Status SparseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, 
ExecResult* out) {
+  return TakeExec<SparseUnionSelectionImpl>(ctx, batch, out);
+}
+
 Status StructTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* 
out) {
   return TakeExec<StructSelectionImpl>(ctx, batch, out);
 }
@@ -917,6 +977,5 @@ Status MapTakeExec(KernelContext* ctx, const ExecSpan& 
batch, ExecResult* out) {
   return TakeExec<ListSelectionImpl<MapType>>(ctx, batch, out);
 }
 
-}  // namespace internal
-}  // namespace compute
+}  // namespace compute::internal
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.h 
b/cpp/src/arrow/compute/kernels/vector_selection_internal.h
index bcffdd820d..b9eba6ea66 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_internal.h
+++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.h
@@ -26,10 +26,12 @@
 #include "arrow/compute/exec.h"
 #include "arrow/compute/function.h"
 #include "arrow/compute/kernel.h"
+#include "arrow/compute/kernels/codegen_internal.h"
 
-namespace arrow {
-namespace compute {
-namespace internal {
+namespace arrow::compute::internal {
+
+using FilterState = OptionsWrapper<FilterOptions>;
+using TakeState = OptionsWrapper<TakeOptions>;
 
 struct SelectionKernelData {
   InputType value_type;
@@ -82,9 +84,8 @@ Status ListTakeExec(KernelContext*, const ExecSpan&, 
ExecResult*);
 Status LargeListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
 Status FSLTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
 Status DenseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
+Status SparseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
 Status StructTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
 Status MapTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
 
-}  // namespace internal
-}  // namespace compute
-}  // namespace arrow
+}  // namespace arrow::compute::internal
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc 
b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
index ab80127731..612de8505d 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc
@@ -844,6 +844,7 @@ void PopulateTakeKernels(std::vector<SelectionKernelData>* 
out) {
       {InputType(Type::LARGE_LIST), take_indices, LargeListTakeExec},
       {InputType(Type::FIXED_SIZE_LIST), take_indices, FSLTakeExec},
       {InputType(Type::DENSE_UNION), take_indices, DenseUnionTakeExec},
+      {InputType(Type::SPARSE_UNION), take_indices, SparseUnionTakeExec},
       {InputType(Type::STRUCT), take_indices, StructTakeExec},
       {InputType(Type::MAP), take_indices, MapTakeExec},
   };
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc 
b/cpp/src/arrow/compute/kernels/vector_selection_test.cc
index 5b624911ff..30e85c1f71 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_test.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc
@@ -282,11 +282,6 @@ class TestFilterKernel : public ::testing::Test {
                     const std::shared_ptr<Array>& expected) {
     DoAssertFilter(values, filter, expected);
 
-    if (values->type_id() == Type::DENSE_UNION) {
-      // Concatenation of dense union not supported
-      return;
-    }
-
     // Check slicing: add M(=3) dummy values at the start and end of `values`,
     // add N(=2) dummy values at the start and end of `filter`.
     ARROW_SCOPED_TRACE("for sliced values and filter");
@@ -759,8 +754,10 @@ TEST_F(TestFilterKernelWithStruct, FilterStruct) {
 class TestFilterKernelWithUnion : public TestFilterKernel {};
 
 TEST_F(TestFilterKernelWithUnion, FilterUnion) {
-  auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 
5});
-  auto union_json = R"([
+  for (const auto& union_type :
+       {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}),
+        sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) {
+    auto union_json = R"([
       [2, null],
       [2, 222],
       [5, "hello"],
@@ -769,31 +766,21 @@ TEST_F(TestFilterKernelWithUnion, FilterUnion) {
       [2, 111],
       [5, null]
     ])";
-  this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]");
-  this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([
+    this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]");
+    this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([
       [2, 222],
       [5, "hello"],
       [2, null],
       [2, 111],
       [5, null]
     ])");
-  this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([
+    this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([
       [2, null],
       [5, "hello"],
       [2, null]
     ])");
-  this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", 
union_json);
-
-  // Sliced
-  // (check this manually as concatenation of dense unions isn't supported: 
ARROW-4975)
-  auto values = ArrayFromJSON(union_type, union_json)->Slice(2, 4);
-  auto filter = ArrayFromJSON(boolean(), "[0, 1, 1, null, 0, 1, 1]")->Slice(2, 
4);
-  auto expected = ArrayFromJSON(union_type, R"([
-      [5, "hello"],
-      [2, null],
-      [2, 111]
-    ])");
-  this->AssertFilter(values, filter, expected);
+    this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", 
union_json);
+  }
 }
 
 class TestFilterKernelWithRecordBatch : public TestFilterKernel {
@@ -1026,13 +1013,11 @@ void CheckTake(const std::shared_ptr<DataType>& type, 
const std::string& values_
     AssertTakeArrays(values, indices, expected);
 
     // Check sliced values
-    if (type->id() != Type::DENSE_UNION) {
-      ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(type, 2));
-      ASSERT_OK_AND_ASSIGN(auto values_sliced,
-                           Concatenate({values_filler, values, 
values_filler}));
-      values_sliced = values_sliced->Slice(2, values->length());
-      AssertTakeArrays(values_sliced, indices, expected);
-    }
+    ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(type, 2));
+    ASSERT_OK_AND_ASSIGN(auto values_sliced,
+                         Concatenate({values_filler, values, values_filler}));
+    values_sliced = values_sliced->Slice(2, values->length());
+    AssertTakeArrays(values_sliced, indices, expected);
 
     // Check sliced indices
     ASSERT_OK_AND_ASSIGN(auto zero, MakeScalar(index_type, int8_t{0}));
@@ -1477,32 +1462,34 @@ TEST_F(TestTakeKernelWithStruct, TakeStruct) {
 class TestTakeKernelWithUnion : public TestTakeKernelTyped<UnionType> {};
 
 TEST_F(TestTakeKernelWithUnion, TakeUnion) {
-  auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 
5});
-  auto union_json = R"([
-      [2, null],
+  for (const auto& union_type :
+       {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}),
+        sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) {
+    auto union_json = R"([
       [2, 222],
+      [2, null],
       [5, "hello"],
       [5, "eh"],
       [2, null],
       [2, 111],
       [5, null]
     ])";
-  CheckTake(union_type, union_json, "[]", "[]");
-  CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([
+    CheckTake(union_type, union_json, "[]", "[]");
+    CheckTake(union_type, union_json, "[3, 0, 3, 0, 3]", R"([
       [5, "eh"],
       [2, 222],
       [5, "eh"],
       [2, 222],
       [5, "eh"]
     ])");
-  CheckTake(union_type, union_json, "[4, 2, 1, 6]", R"([
+    CheckTake(union_type, union_json, "[4, 2, 0, 6]", R"([
       [2, null],
       [5, "hello"],
       [2, 222],
       [5, null]
     ])");
-  CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json);
-  CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
+    CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json);
+    CheckTake(union_type, union_json, "[1, 2, 2, 2, 2, 2, 2]", R"([
       [2, null],
       [5, "hello"],
       [5, "hello"],
@@ -1511,6 +1498,16 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) {
       [5, "hello"],
       [5, "hello"]
     ])");
+    CheckTake(union_type, union_json, "[0, null, 1, null, 2, 2, 2]", R"([
+      [2, 222],
+      [2, null],
+      [2, null],
+      [2, null],
+      [5, "hello"],
+      [5, "hello"],
+      [5, "hello"]
+    ])");
+  }
 }
 
 class TestPermutationsWithTake : public ::testing::Test {
@@ -2162,8 +2159,10 @@ TEST_F(TestDropNullKernelWithStruct, DropNullStruct) {
 class TestDropNullKernelWithUnion : public TestDropNullKernelTyped<UnionType> 
{};
 
 TEST_F(TestDropNullKernelWithUnion, DropNullUnion) {
-  auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 
5});
-  auto union_json = R"([
+  for (const auto& union_type :
+       {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}),
+        sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) {
+    auto union_json = R"([
       [2, null],
       [2, 222],
       [5, "hello"],
@@ -2172,7 +2171,8 @@ TEST_F(TestDropNullKernelWithUnion, DropNullUnion) {
       [2, 111],
       [5, null]
     ])";
-  CheckDropNull(union_type, union_json, union_json);
+    CheckDropNull(union_type, union_json, union_json);
+  }
 }
 
 class TestDropNullKernelWithRecordBatch : public 
TestDropNullKernelTyped<RecordBatch> {
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 9cc1dd7993..44f43cbc87 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -1692,28 +1692,26 @@ These functions select and return a subset of their 
input.
 
+---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
 | Function name | Arity  | Input type 1 | Input type 2 | Output type  | 
Options class           | Notes     |
 
+===============+========+==============+==============+==============+=========================+===========+
-| array_filter  | Binary | Any          | Boolean      | Input type 1 | 
:struct:`FilterOptions` | \(1) \(3) |
+| array_filter  | Binary | Any          | Boolean      | Input type 1 | 
:struct:`FilterOptions` | \(2)      |
 
+---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
-| array_take    | Binary | Any          | Integer      | Input type 1 | 
:struct:`TakeOptions`   | \(1) \(4) |
+| array_take    | Binary | Any          | Integer      | Input type 1 | 
:struct:`TakeOptions`   | \(3)      |
 
+---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
-| drop_null     | Unary  | Any          | -            | Input type 1 |        
                 | \(1) \(2) |
+| drop_null     | Unary  | Any          | -            | Input type 1 |        
                 | \(1)      |
 
+---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
-| filter        | Binary | Any          | Boolean      | Input type 1 | 
:struct:`FilterOptions` | \(1) \(3) |
+| filter        | Binary | Any          | Boolean      | Input type 1 | 
:struct:`FilterOptions` | \(2)      |
 
+---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
-| take          | Binary | Any          | Integer      | Input type 1 | 
:struct:`TakeOptions`   | \(1) \(4) |
+| take          | Binary | Any          | Integer      | Input type 1 | 
:struct:`TakeOptions`   | \(3)      |
 
+---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
 
-* \(1) Sparse unions are unsupported.
-
-* \(2) Each element in the input is appended to the output iff it is non-null.
+* \(1) Each element in the input is appended to the output iff it is non-null.
   If the input is a record batch or table, any null value in a column drops
   the entire row.
 
-* \(3) Each element in input 1 (the values) is appended to the output iff
+* \(2) Each element in input 1 (the values) is appended to the output iff
   the corresponding element in input 2 (the filter) is true.  How
   nulls in the filter are handled can be configured using FilterOptions.
 
-* \(4) For each element *i* in input 2 (the indices), the *i*'th element
+* \(3) For each element *i* in input 2 (the indices), the *i*'th element
   in input 1 (the values) is appended to the output.
 
 Containment tests

Reply via email to