pitrou commented on a change in pull request #11641:
URL: https://github.com/apache/arrow/pull/11641#discussion_r745872869



##########
File path: cpp/src/arrow/array/array_nested.cc
##########
@@ -650,6 +651,44 @@ 
SparseUnionArray::SparseUnionArray(std::shared_ptr<DataType> type, int64_t lengt
   SetData(std::move(internal_data));
 }
 
+Result<std::shared_ptr<Array>> SparseUnionArray::GetFlattenedField(
+    int index, MemoryPool* pool) const {
+  if (index < 0 || index >= num_fields()) {
+    return Status::Invalid("Index out of range: ", index);
+  }
+  auto child_data = data_->child_data[index]->Copy();
+  // Adjust the result offset/length to be absolute.
+  if (data_->offset != 0 || data_->length != child_data->length) {
+    child_data = child_data->Slice(data_->offset, data_->length);
+  }
+  std::shared_ptr<Buffer> child_null_bitmap = child_data->buffers[0];
+  const int64_t child_offset = child_data->offset;
+
+  // Synthesize a null bitmap based on the union discriminant.
+  // Make sure the bitmap has extra bits corresponding to the child offset.
+  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> flattened_null_bitmap,
+                        AllocateEmptyBitmap(child_data->length + child_offset, 
pool));
+  const int8_t type_code = union_type()->type_codes()[index];
+  const int8_t* type_codes = raw_type_codes();
+  int64_t offset = 0;
+  internal::GenerateBitsUnrolled(flattened_null_bitmap->mutable_data(), 
child_offset,
+                                 data_->length,
+                                 [&] { return type_codes[offset++] == 
type_code; });
+
+  // The validity of a flattened datum is the logical AND of the synthesized
+  // null bitmap buffer and the individual field element's validity.
+  if (child_null_bitmap) {
+    BitmapAnd(flattened_null_bitmap->data(), child_offset, 
child_null_bitmap->data(),
+              child_offset, child_data->length, child_offset,
+              flattened_null_bitmap->mutable_data());
+  }
+
+  auto flattened_data = child_data->Copy();

Review comment:
       `child_data` was already a copy, so is this necessary?

##########
File path: cpp/src/arrow/compute/kernels/scalar_nested_test.cc
##########
@@ -107,6 +107,109 @@ TEST(TestScalarNested, ListElementInvalid) {
               Raises(StatusCode::Invalid));
 }
 
+TEST(TestScalarNested, StructField) {
+  StructFieldOptions trivial;
+  StructFieldOptions extract0({0});
+  StructFieldOptions extract20({2, 0});
+  StructFieldOptions invalid1({-1});
+  StructFieldOptions invalid2({2, 4});
+  StructFieldOptions invalid3({0, 1});
+  FieldVector fields = {field("a", int32()), field("b", utf8()),
+                        field("c", struct_({
+                                       field("d", int64()),
+                                       field("e", float64()),
+                                   }))};
+  {
+    auto arr = ArrayFromJSON(struct_(fields), R"([
+      [1, "a", [10, 10.0]],
+      [null, "b", [11, 11.0]],
+      [3, null, [12, 12.0]],
+      null
+    ])");
+    CheckScalar("struct_field", {arr}, arr, &trivial);
+    CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, 3, 
null]"),
+                &extract0);
+    CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[10, 11, 12, 
null]"),
+                &extract20);
+    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+                                    ::testing::HasSubstr("out-of-bounds field 
reference"),
+                                    CallFunction("struct_field", {arr}, 
&invalid1));
+    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+                                    ::testing::HasSubstr("out-of-bounds field 
reference"),
+                                    CallFunction("struct_field", {arr}, 
&invalid2));
+    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("cannot 
subscript"),

Review comment:
       Should it be `TypeError` in this case?

##########
File path: cpp/src/arrow/compute/kernels/scalar_nested.cc
##########
@@ -187,6 +188,150 @@ const FunctionDoc list_element_doc(
      "is emitted. Null values emit a null in the output."),
     {"lists", "index"});
 
+/// Kernel implementation of "struct_field": follows StructFieldOptions::indices
+/// down through nested struct / dense union / sparse union values, one child
+/// index per step. Union children are masked so that slots whose type code
+/// selects a different child come out as null.
+struct StructFieldFunctor {
+  /// Array path: descend one level per index; `current` always holds the
+  /// array at the depth reached so far. Empty indices return the input as-is.
+  static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    const auto& options = OptionsWrapper<StructFieldOptions>::Get(ctx);
+    std::shared_ptr<Array> current = batch[0].make_array();
+    for (const auto& index : options.indices) {
+      RETURN_NOT_OK(CheckIndex(index, *current->type()));
+      switch (current->type()->id()) {
+        case Type::STRUCT: {
+          const auto& struct_array = checked_cast<const StructArray&>(*current);
+          ARROW_ASSIGN_OR_RAISE(
+              current, struct_array.GetFlattenedField(index, ctx->memory_pool()));
+          break;
+        }
+        case Type::DENSE_UNION: {
+          // We implement this here instead of in DenseUnionArray since it's
+          // easiest to do via Take(), but DenseUnionArray can't rely on
+          // arrow::compute. See ARROW-8891.
+          const auto& union_array = checked_cast<const DenseUnionArray&>(*current);
+
+          // Build a validity bitmap for the offsets buffer from the type codes:
+          // a slot is valid only if its type code selects the requested child,
+          // so Take() emits null for every slot holding another child.
+          ARROW_ASSIGN_OR_RAISE(
+              std::shared_ptr<Buffer> take_bitmap,
+              ctx->AllocateBitmap(union_array.length() + union_array.offset()));
+          const int8_t* type_codes = union_array.raw_type_codes();
+          const int8_t type_code = union_array.union_type()->type_codes()[index];
+          int64_t offset = 0;
+          arrow::internal::GenerateBitsUnrolled(
+              take_bitmap->mutable_data(), union_array.offset(), union_array.length(),
+              [&] { return type_codes[offset++] == type_code; });
+
+          // Reuse the union's offsets buffer as int32 take indices, paired with
+          // the bitmap above as their validity.
+          Datum take_indices(
+              ArrayData(int32(), union_array.length(),
+                        {std::move(take_bitmap), union_array.value_offsets()},
+                        kUnknownNullCount, union_array.offset()));
+          // Do not slice the child since the indices are relative to the unsliced array.
+          // Run Take() in this kernel's ExecContext so it shares the kernel's
+          // execution settings (e.g. memory pool) instead of the default context.
+          ARROW_ASSIGN_OR_RAISE(
+              Datum result,
+              CallFunction("take", {union_array.field(index), std::move(take_indices)},
+                           ctx->exec_context()));
+          current = result.make_array();
+          break;
+        }
+        case Type::SPARSE_UNION: {
+          const auto& union_array = checked_cast<const SparseUnionArray&>(*current);
+          ARROW_ASSIGN_OR_RAISE(current,
+                                union_array.GetFlattenedField(index, ctx->memory_pool()));
+          break;
+        }
+        default:
+          // Should have been checked in ResolveStructFieldType
+          return Status::Invalid("struct_field: cannot reference child field of type ",
+                                 *current->type());
+      }
+    }
+    *out = current;
+    return Status::OK();
+  }
+
+  /// Scalar path: walk child scalars by pointer. A null parent, or a union
+  /// whose active type code is not the requested child, leaves `out` untouched.
+  static Status ExecScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    const auto& options = OptionsWrapper<StructFieldOptions>::Get(ctx);
+    const std::shared_ptr<Scalar>* current = &batch[0].scalar();
+    for (const auto& index : options.indices) {
+      RETURN_NOT_OK(CheckIndex(index, *(*current)->type));
+      if (!(*current)->is_valid) {
+        // out should already be a null scalar of the appropriate type
+        return Status::OK();
+      }
+
+      switch ((*current)->type->id()) {
+        case Type::STRUCT: {
+          current = &checked_cast<const StructScalar&>(**current).value[index];
+          break;
+        }
+        case Type::DENSE_UNION:
+        case Type::SPARSE_UNION: {
+          const auto& union_scalar = checked_cast<const UnionScalar&>(**current);
+          const auto& union_ty = checked_cast<const UnionType&>(*(*current)->type);
+          if (union_scalar.type_code != union_ty.type_codes()[index]) {
+            // out should already be a null scalar of the appropriate type
+            return Status::OK();
+          }
+          current = &union_scalar.value;
+          break;
+        }
+        default:
+          // Should have been checked in ResolveStructFieldType
+          return Status::Invalid("struct_field: cannot reference child field of type ",
+                                 *(*current)->type);
+      }
+    }
+    *out = *current;
+    return Status::OK();
+  }
+
+  /// Validate that `type` can be subscripted and that `index` is a valid
+  /// child position, i.e. in [0, type.num_fields()).
+  static Status CheckIndex(int index, const DataType& type) {
+    if (!ValidParentType(type)) {
+      return Status::Invalid("struct_field: cannot subscript field of type ", type);
+    } else if (index < 0 || index >= type.num_fields()) {
+      // Upper bound is exclusive: index == num_fields() would read past the
+      // children / type codes in the callers.
+      return Status::Invalid("struct_field: out-of-bounds field reference to field ",
+                             index, " in type ", type, " with ", type.num_fields(),
+                             " fields");
+    }
+    return Status::OK();
+  }
+
+  /// Only struct and union types have children that struct_field can select.
+  static bool ValidParentType(const DataType& type) {
+    return type.id() == Type::STRUCT || type.id() == Type::DENSE_UNION ||
+           type.id() == Type::SPARSE_UNION;
+  }
+};
+
+/// Output-type resolver for "struct_field": walks the input type along
+/// StructFieldOptions::indices and keeps the input's shape (array/scalar).
+Result<ValueDescr> ResolveStructFieldType(KernelContext* ctx,
+                                          const std::vector<ValueDescr>& descrs) {
+  const auto& field_indices = OptionsWrapper<StructFieldOptions>::Get(ctx).indices;
+  const ValueDescr& input = descrs.front();
+  const std::shared_ptr<DataType>* resolved = &input.type;
+  for (const auto field_index : field_indices) {
+    // Each index is validated against the current type before descending.
+    RETURN_NOT_OK(StructFieldFunctor::CheckIndex(field_index, **resolved));
+    resolved = &(*resolved)->field(field_index)->type();
+  }
+  return ValueDescr(*resolved, input.shape);
+}
+
+/// Register one "struct_field" kernel per combination of input shape
+/// {array, scalar} and parent type {struct, dense union, sparse union}.
+void AddStructFieldKernels(ScalarFunction* func) {
+  for (const auto shape : {ValueDescr::ARRAY, ValueDescr::SCALAR}) {
+    // Array and scalar inputs are handled by separate exec functions.
+    const auto exec = (shape == ValueDescr::ARRAY) ? &StructFieldFunctor::ExecArray
+                                                   : &StructFieldFunctor::ExecScalar;
+    for (const auto in_type : {Type::STRUCT, Type::DENSE_UNION, Type::SPARSE_UNION}) {
+      ScalarKernel kernel({InputType(in_type, shape)}, OutputType(ResolveStructFieldType),
+                          exec, OptionsWrapper<StructFieldOptions>::Init);
+      // The exec functions build their own output, validity included.
+      kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+      kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+      DCHECK_OK(func->AddKernel(std::move(kernel)));
+    }
+  }
+}
+
+const FunctionDoc struct_field_doc(
+    "Extract children of a struct or union value by index.",
+    ("Given a series of indices, extract the child array or scalar referenced "
+     "by the index. For union values, mask the child based on the type codes "
+     "of the union array. The indices are always the child index and not the "
+     "type code (for unions) - so the first child is always index 0."),

Review comment:
       Mention that the indices are given in StructFieldOptions?

##########
File path: cpp/src/arrow/array/array_union_test.cc
##########
@@ -68,6 +69,58 @@ TEST(TestUnionArray, TestSliceEquals) {
   CheckUnion(batch->column(1));
 }
 
+TEST(TestSparseUnionArray, GetFlattenedField) {
+  auto ty = sparse_union({field("ints", int64()), field("strs", utf8())}, {2, 7});
+  auto ints = ArrayFromJSON(int64(), "[0, 1, 2, 3]");
+  auto strs = ArrayFromJSON(utf8(), R"(["a", null, "c", "d"])");
+  auto ids = ArrayFromJSON(int8(), "[2, 7, 2, 7]")->data()->buffers[1];
+  const int length = 4;
+
+  {
+    // Unsliced children, then a sliced parent.
+    SparseUnionArray arr(ty, length, {ints, strs}, ids);
+    ASSERT_OK(arr.ValidateFull());
+
+    ASSERT_OK_AND_ASSIGN(auto flattened, arr.GetFlattenedField(0));
+    AssertArraysEqual(*ArrayFromJSON(int64(), "[0, null, 2, null]"), *flattened,
+                      /*verbose=*/true);
+
+    ASSERT_OK_AND_ASSIGN(flattened, arr.GetFlattenedField(1));
+    AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null, null, null, "d"])"), *flattened,
+                      /*verbose=*/true);
+
+    const auto sliced = checked_pointer_cast<SparseUnionArray>(arr.Slice(1, 2));
+
+    ASSERT_OK_AND_ASSIGN(flattened, sliced->GetFlattenedField(0));
+    AssertArraysEqual(*ArrayFromJSON(int64(), "[null, 2]"), *flattened, /*verbose=*/true);
+
+    ASSERT_OK_AND_ASSIGN(flattened, sliced->GetFlattenedField(1));
+    AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null, null])"), *flattened,
+                      /*verbose=*/true);
+
+    ASSERT_RAISES(Invalid, arr.GetFlattenedField(-1));
+    ASSERT_RAISES(Invalid, arr.GetFlattenedField(2));
+  }
+  {
+    // Children that are themselves slices (non-zero child offset).
+    SparseUnionArray arr(ty, length - 2, {ints->Slice(1, 2), strs->Slice(1, 2)}, ids);
+    ASSERT_OK(arr.ValidateFull());
+
+    ASSERT_OK_AND_ASSIGN(auto flattened, arr.GetFlattenedField(0));
+    AssertArraysEqual(*ArrayFromJSON(int64(), "[1, null]"), *flattened, /*verbose=*/true);
+
+    ASSERT_OK_AND_ASSIGN(flattened, arr.GetFlattenedField(1));
+    AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null, "c"])"), *flattened,
+                      /*verbose=*/true);
+
+    const auto sliced = checked_pointer_cast<SparseUnionArray>(arr.Slice(1, 1));
+
+    ASSERT_OK_AND_ASSIGN(flattened, sliced->GetFlattenedField(0));
+    AssertArraysEqual(*ArrayFromJSON(int64(), "[null]"), *flattened, /*verbose=*/true);
+
+    ASSERT_OK_AND_ASSIGN(flattened, sliced->GetFlattenedField(1));
+    AssertArraysEqual(*ArrayFromJSON(utf8(), R"(["c"])"), *flattened, /*verbose=*/true);
+  }
+  {
+    // Empty union array: flattening must yield empty children, not fail.
+    SparseUnionArray arr(ty, /*length=*/0, {ints->Slice(0, 0), strs->Slice(0, 0)}, ids);
+    ASSERT_OK(arr.ValidateFull());
+
+    ASSERT_OK_AND_ASSIGN(auto flattened, arr.GetFlattenedField(0));
+    AssertArraysEqual(*ArrayFromJSON(int64(), "[]"), *flattened, /*verbose=*/true);
+
+    ASSERT_OK_AND_ASSIGN(flattened, arr.GetFlattenedField(1));
+    AssertArraysEqual(*ArrayFromJSON(utf8(), "[]"), *flattened, /*verbose=*/true);
+  }
+}

Review comment:
       Also test with an empty union array?

##########
File path: cpp/src/arrow/compute/api_scalar.h
##########
@@ -223,6 +223,18 @@ class ARROW_EXPORT SetLookupOptions : public 
FunctionOptions {
   bool skip_nulls;
 };
 
+/// Options for struct_field function
+class ARROW_EXPORT StructFieldOptions : public FunctionOptions {
+ public:
+  explicit StructFieldOptions(std::vector<int> indices);

Review comment:
       I wonder whether this should also accept a `FieldRef`, or whether field resolution should be left to the caller.

##########
File path: cpp/src/arrow/compute/kernels/scalar_nested_test.cc
##########
@@ -107,6 +107,109 @@ TEST(TestScalarNested, ListElementInvalid) {
               Raises(StatusCode::Invalid));
 }
 
+TEST(TestScalarNested, StructField) {
+  StructFieldOptions trivial;
+  StructFieldOptions extract0({0});
+  StructFieldOptions extract20({2, 0});
+  StructFieldOptions invalid1({-1});
+  StructFieldOptions invalid2({2, 4});
+  StructFieldOptions invalid3({0, 1});
+  FieldVector fields = {field("a", int32()), field("b", utf8()),
+                        field("c", struct_({
+                                       field("d", int64()),
+                                       field("e", float64()),
+                                   }))};
+  {
+    auto arr = ArrayFromJSON(struct_(fields), R"([
+      [1, "a", [10, 10.0]],
+      [null, "b", [11, 11.0]],
+      [3, null, [12, 12.0]],
+      null
+    ])");
+    CheckScalar("struct_field", {arr}, arr, &trivial);
+    CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, 3, 
null]"),
+                &extract0);
+    CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[10, 11, 12, 
null]"),
+                &extract20);
+    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+                                    ::testing::HasSubstr("out-of-bounds field 
reference"),
+                                    CallFunction("struct_field", {arr}, 
&invalid1));
+    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+                                    ::testing::HasSubstr("out-of-bounds field 
reference"),
+                                    CallFunction("struct_field", {arr}, 
&invalid2));
+    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("cannot 
subscript"),
+                                    CallFunction("struct_field", {arr}, 
&invalid3));
+  }
+  {
+    auto ty = dense_union(fields, {2, 5, 8});
+    auto arr = ArrayFromJSON(ty, R"([
+      [2, 1],
+      [5, "foo"],
+      [8, null],
+      [8, [10, 10.0]]
+    ])");
+    CheckScalar("struct_field", {arr}, arr, &trivial);
+    CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, 
null]"),
+                &extract0);
+    CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, 
null, 10]"),
+                &extract20);
+    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+                                    ::testing::HasSubstr("out-of-bounds field 
reference"),
+                                    CallFunction("struct_field", {arr}, 
&invalid1));
+    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+                                    ::testing::HasSubstr("out-of-bounds field 
reference"),
+                                    CallFunction("struct_field", {arr}, 
&invalid2));
+    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("cannot 
subscript"),
+                                    CallFunction("struct_field", {arr}, 
&invalid3));
+
+    // Test edge cases for union representation
+    auto ints = ArrayFromJSON(fields[0]->type(), "[null, 2, 3]");
+    auto strs = ArrayFromJSON(fields[1]->type(), R"([null, "bar"])");
+    auto nested = ArrayFromJSON(fields[2]->type(), R"([null, [10, 10.0]])");
+    auto type_ids = ArrayFromJSON(int8(), "[2, 5, 8, 2, 5, 
8]")->data()->buffers[1];
+    auto offsets = ArrayFromJSON(int32(), "[0, 0, 0, 1, 1, 
1]")->data()->buffers[1];
+
+    arr = std::make_shared<DenseUnionArray>(ty, /*length=*/6,
+                                            ArrayVector{ints, strs, nested}, 
type_ids,
+                                            offsets, /*offset=*/0);
+    // Sliced parent
+    CheckScalar("struct_field", {arr->Slice(3, 3)},
+                ArrayFromJSON(int32(), "[2, null, null]"), &extract0);
+    // Sliced child
+    arr = std::make_shared<DenseUnionArray>(ty, /*length=*/6,
+                                            ArrayVector{ints->Slice(1, 2), 
strs, nested},
+                                            type_ids, offsets, /*offset=*/0);
+    CheckScalar("struct_field", {arr},
+                ArrayFromJSON(int32(), "[2, null, null, 3, null, null]"), 
&extract0);
+    // Sliced parent + sliced child
+    CheckScalar("struct_field", {arr->Slice(3, 3)},
+                ArrayFromJSON(int32(), "[3, null, null]"), &extract0);
+  }
+  {
+    // The underlying implementation is tested directly/more thoroughly in
+    // array_union_test.cc.
+    auto arr = ArrayFromJSON(sparse_union(fields, {2, 5, 8}), R"([
+      [2, 1],
+      [5, "foo"],
+      [8, null],
+      [8, [10, 10.0]]
+    ])");
+    CheckScalar("struct_field", {arr}, arr, &trivial);
+    CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, 
null]"),
+                &extract0);
+    CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, 
null, 10]"),
+                &extract20);
+    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+                                    ::testing::HasSubstr("out-of-bounds field 
reference"),
+                                    CallFunction("struct_field", {arr}, 
&invalid1));
+    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+                                    ::testing::HasSubstr("out-of-bounds field 
reference"),
+                                    CallFunction("struct_field", {arr}, 
&invalid2));
+    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("cannot 
subscript"),
+                                    CallFunction("struct_field", {arr}, 
&invalid3));
+  }

Review comment:
       What happens with non-nested `arr` and `trivial` options? Should it be tested here?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to