bkietz commented on code in PR #38252:
URL: https://github.com/apache/arrow/pull/38252#discussion_r1424347648


##########
cpp/src/arrow/array/util.cc:
##########
@@ -367,231 +369,254 @@ static Result<std::shared_ptr<Scalar>> MakeScalarForRunEndValue(
   return std::make_shared<Int64Scalar>(run_end);
 }
 
-// get the maximum buffer length required, then allocate a single zeroed buffer
-// to use anywhere a buffer is required
 class NullArrayFactory {
  public:
-  struct GetBufferLength {
-    GetBufferLength(const std::shared_ptr<DataType>& type, int64_t length)
-        : type_(*type), length_(length), 
buffer_length_(bit_util::BytesForBits(length)) {}
-
-    Result<int64_t> Finish() && {
-      RETURN_NOT_OK(VisitTypeInline(type_, this));
-      return buffer_length_;
-    }
-
-    template <typename T, typename = 
decltype(TypeTraits<T>::bytes_required(0))>
-    Status Visit(const T&) {
-      return MaxOf(TypeTraits<T>::bytes_required(length_));
-    }
-
-    template <typename T>
-    enable_if_var_size_list<T, Status> Visit(const T& type) {
-      // values array may be empty, but there must be at least one offset of 0
-      RETURN_NOT_OK(MaxOf(sizeof(typename T::offset_type) * (length_ + 1)));
-      RETURN_NOT_OK(MaxOf(GetBufferLength(type.value_type(), /*length=*/0)));
-      return Status::OK();
-    }
-
-    template <typename T>
-    enable_if_list_view<T, Status> Visit(const T& type) {
-      RETURN_NOT_OK(MaxOf(sizeof(typename T::offset_type) * length_));
-      RETURN_NOT_OK(MaxOf(GetBufferLength(type.value_type(), /*length=*/0)));
-      return Status::OK();
-    }
-
-    template <typename T>
-    enable_if_base_binary<T, Status> Visit(const T&) {
-      // values buffer may be empty, but there must be at least one offset of 0
-      return MaxOf(sizeof(typename T::offset_type) * (length_ + 1));
-    }
-
-    Status Visit(const BinaryViewType& type) {
-      return MaxOf(sizeof(BinaryViewType::c_type) * length_);
-    }
-
-    Status Visit(const FixedSizeListType& type) {
-      return MaxOf(GetBufferLength(type.value_type(), type.list_size() * 
length_));
-    }
-
-    Status Visit(const FixedSizeBinaryType& type) {
-      return MaxOf(type.byte_width() * length_);
-    }
-
-    Status Visit(const StructType& type) {
-      for (const auto& child : type.fields()) {
-        RETURN_NOT_OK(MaxOf(GetBufferLength(child->type(), length_)));
-      }
-      return Status::OK();
-    }
-
-    Status Visit(const SparseUnionType& type) {
-      // type codes
-      RETURN_NOT_OK(MaxOf(length_));
-      // will create children of the same length as the union
-      for (const auto& child : type.fields()) {
-        RETURN_NOT_OK(MaxOf(GetBufferLength(child->type(), length_)));
-      }
-      return Status::OK();
-    }
+  // For most types, every buffer in an entirely null array will contain nothing but
+  // zeroes. For arrays of such types, we can allocate a single buffer and use that in
+  // every position of the array data. The first stage of visitation handles assessment
+  // of this buffer's size, the second uses the resulting buffer to build the null array.
+  //
+  // The first stage may not allocate from the MemoryPool or raise a failing status.
+  //
+  // In the second stage, `zero_buffer_` has been allocated and `out_` has:
+  // - type = type_
+  // - length = length_
+  // - null_count = length_ unless current output may have direct nulls,
+  //                0 otherwise
+  // - offset = 0
+  // - buffers = []
+  // - child_data = [nullptr] * type.num_fields()
+  // - dictionary = nullptr
+  bool presizing_zero_buffer_;
 
-    Status Visit(const DenseUnionType& type) {
-      // type codes
-      RETURN_NOT_OK(MaxOf(length_));
-      // offsets
-      RETURN_NOT_OK(MaxOf(sizeof(int32_t) * length_));
-      // will create children of length 1
-      for (const auto& child : type.fields()) {
-        RETURN_NOT_OK(MaxOf(GetBufferLength(child->type(), 1)));
-      }
-      return Status::OK();
-    }
+  NullArrayFactory(const std::shared_ptr<DataType>& type, bool nullable, 
int64_t length)
+      : presizing_zero_buffer_{true},
+        type_{type},
+        nullable_{nullable},
+        length_{length},
+        zero_buffer_length_{MayHaveDirectNulls() ? 
bit_util::BytesForBits(length) : 0} {}
 
-    Status Visit(const DictionaryType& type) {
-      RETURN_NOT_OK(MaxOf(GetBufferLength(type.value_type(), length_)));
-      return MaxOf(GetBufferLength(type.index_type(), length_));
-    }
+  NullArrayFactory(const std::shared_ptr<DataType>& type, bool nullable, 
int64_t length,
+                   const std::shared_ptr<Buffer>& zero_buffer, MemoryPool* 
pool)
+      : presizing_zero_buffer_{false},
+        type_{type},
+        nullable_{nullable},
+        length_{length},
+        zero_buffer_length_{MayHaveDirectNulls() ? 
bit_util::BytesForBits(length) : 0},
+        zero_buffer_{&zero_buffer},
+        pool_{pool} {}
 
-    Status Visit(const RunEndEncodedType& type) {
-      // RunEndEncodedType has no buffers, only child arrays
-      buffer_length_ = 0;
-      return Status::OK();
-    }
-
-    Status Visit(const ExtensionType& type) {
-      // XXX is an extension array's length always == storage length
-      return MaxOf(GetBufferLength(type.storage_type(), length_));
-    }
+  template <typename... Args>
+  explicit NullArrayFactory(const std::shared_ptr<Field>& field, const 
Args&... args)
+      : NullArrayFactory{field->type(), field->nullable(), args...} {}
 
-    Status Visit(const DataType& type) {
-      return Status::NotImplemented("construction of all-null ", type);
-    }
+  bool MayHaveDirectNulls() const {
+    if (type_->storage_id() == Type::NA) return true;
+    return nullable_ && internal::HasValidityBitmap(type_->storage_id());
+  }
 
-   private:
-    Status MaxOf(GetBufferLength&& other) {
-      ARROW_ASSIGN_OR_RAISE(int64_t buffer_length, std::move(other).Finish());
-      return MaxOf(buffer_length);
-    }
+  void ZeroBufferMustBeAtLeast(int64_t length) {
+    DCHECK(presizing_zero_buffer_);
+    zero_buffer_length_ = std::max(zero_buffer_length_, length);
+  }
 
-    Status MaxOf(int64_t buffer_length) {
-      if (buffer_length > buffer_length_) {
-        buffer_length_ = buffer_length;
-      }
-      return Status::OK();
-    }
+  std::shared_ptr<Buffer> GetValidityBitmap() const {
+    DCHECK(!presizing_zero_buffer_);
+    return MayHaveDirectNulls() ? *zero_buffer_ : nullptr;
+  }
 
-    const DataType& type_;
-    int64_t length_, buffer_length_;
-  };
+  static int64_t GetZeroBufferLength(const std::shared_ptr<DataType>& type, 
bool nullable,
+                                     int64_t length) {
+    NullArrayFactory factory{type, nullable, length};
+    DCHECK_OK(VisitTypeInline(*type, &factory));
+    return factory.zero_buffer_length_;
+  }
 
-  NullArrayFactory(MemoryPool* pool, const std::shared_ptr<DataType>& type,
-                   int64_t length)
-      : pool_(pool), type_(type), length_(length) {}
+  static int64_t GetZeroBufferLength(const std::shared_ptr<Field>& field,
+                                     int64_t length) {
+    return GetZeroBufferLength(field->type(), field->nullable(), length);
+  }
 
-  Status CreateBuffer() {
-    if (type_->id() == Type::RUN_END_ENCODED) {
-      buffer_ = NULLPTR;
-      return Status::OK();
+  Status Visit(const NullType&) {
+    if (presizing_zero_buffer_) {
+      // null needs no buffers; don't touch the zero buffer size
+    } else {
+      out_->buffers = {nullptr};
     }
-    ARROW_ASSIGN_OR_RAISE(int64_t buffer_length,
-                          GetBufferLength(type_, length_).Finish());
-    ARROW_ASSIGN_OR_RAISE(buffer_, AllocateBuffer(buffer_length, pool_));
-    std::memset(buffer_->mutable_data(), 0, buffer_->size());
     return Status::OK();
   }
 
-  Result<std::shared_ptr<ArrayData>> Create() {
-    if (buffer_ == nullptr) {
-      RETURN_NOT_OK(CreateBuffer());
+  Status Visit(const BooleanType& type) {
+    if (presizing_zero_buffer_) {
+      ZeroBufferMustBeAtLeast(bit_util::BytesForBits(length_));
+      return Status::OK();
     }
-    std::vector<std::shared_ptr<ArrayData>> child_data(type_->num_fields());
-    auto buffer_slice =
-        buffer_ ? SliceBuffer(buffer_, 0, bit_util::BytesForBits(length_)) : 
NULLPTR;
-    out_ = ArrayData::Make(type_, length_, {std::move(buffer_slice)}, 
child_data, length_,
-                           0);
-    RETURN_NOT_OK(VisitTypeInline(*type_, this));
-    return out_;
-  }
-
-  Status Visit(const NullType&) {
-    out_->buffers.resize(1, nullptr);
+    out_->buffers = {GetValidityBitmap(), *zero_buffer_};
     return Status::OK();
   }
 
-  Status Visit(const FixedWidthType&) {
-    out_->buffers.resize(2, buffer_);
+  Status Visit(const FixedWidthType& type) {
+    if (presizing_zero_buffer_) {
+      ZeroBufferMustBeAtLeast(type.byte_width() * length_);
+      return Status::OK();
+    }
+    out_->buffers = {GetValidityBitmap(), *zero_buffer_};
     return Status::OK();
   }
 
   template <typename T>
   enable_if_base_binary<T, Status> Visit(const T&) {
-    out_->buffers.resize(3, buffer_);
+    if (presizing_zero_buffer_) {
+      // values buffer may be empty, but there must be at least one offset of 0
+      ZeroBufferMustBeAtLeast(sizeof(typename T::offset_type) * (length_ + 1));
+      return Status::OK();
+    }
+    out_->buffers = {GetValidityBitmap(), *zero_buffer_, *zero_buffer_};
     return Status::OK();
   }
 
   Status Visit(const BinaryViewType&) {
-    out_->buffers.resize(2, buffer_);
+    if (presizing_zero_buffer_) {
+      ZeroBufferMustBeAtLeast(sizeof(BinaryViewType::c_type) * length_);
+      return Status::OK();
+    }
+    out_->buffers = {GetValidityBitmap(), *zero_buffer_};
     return Status::OK();
   }
 
   template <typename T>
   enable_if_var_length_list_like<T, Status> Visit(const T& type) {
-    out_->buffers.resize(is_list_view(T::type_id) ? 3 : 2, buffer_);
-    ARROW_ASSIGN_OR_RAISE(out_->child_data[0], CreateChild(type, 0, 
/*length=*/0));
-    return Status::OK();
+    constexpr bool kIsView = is_list_view(T::type_id);
+    if (presizing_zero_buffer_) {
+      auto offsets_length = length_;
+      if constexpr (!kIsView) {
+        // there must be at least one offset of 0
+        offsets_length += 1;
+      }
+      ZeroBufferMustBeAtLeast(sizeof(typename T::offset_type) * 
offsets_length);
+      // include length required for zero length child
+      ZeroBufferMustBeAtLeast(GetZeroBufferLength(type.value_field(), 0));
+      return Status::OK();
+    }
+    if constexpr (!kIsView) {
+      out_->buffers = {GetValidityBitmap(), *zero_buffer_};
+    } else {
+      out_->buffers = {GetValidityBitmap(), *zero_buffer_, *zero_buffer_};
+    }
+    return CreateChild(0, /*length=*/0);
   }
 
   Status Visit(const FixedSizeListType& type) {
-    ARROW_ASSIGN_OR_RAISE(out_->child_data[0],
-                          CreateChild(type, 0, length_ * type.list_size()));
-    return Status::OK();
+    if (presizing_zero_buffer_) {
+      ZeroBufferMustBeAtLeast(
+          GetZeroBufferLength(type.value_field(), type.list_size() * length_));
+      return Status::OK();
+    }
+    out_->buffers = {GetValidityBitmap()};
+    return CreateChild(0, type.list_size() * length_);
   }
 
   Status Visit(const StructType& type) {
-    for (int i = 0; i < type_->num_fields(); ++i) {
-      ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(type, i, 
length_));
+    if (presizing_zero_buffer_) {
+      for (const auto& child : type.fields()) {
+        ZeroBufferMustBeAtLeast(GetZeroBufferLength(child, length_));
+      }
+      return Status::OK();
+    }
+    out_->buffers = {GetValidityBitmap()};
+    for (int i = 0; i < type.num_fields(); ++i) {
+      RETURN_NOT_OK(CreateChild(i, length_));
     }
     return Status::OK();
   }
 
+  static Result<int8_t> GetIdOfFirstNullableUnionMember(const UnionType& type) 
{
+    for (auto [field, id] : Zip(type.fields(), type.type_codes())) {
+      if (field->nullable()) return id;
+    }
+    return Status::Invalid("Cannot produce an array of null ", type,

Review Comment:
   That makes sense



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to