This is an automated email from the ASF dual-hosted git repository.

bkietz pushed a commit to branch feature/format-string-view
in repository https://gitbox.apache.org/repos/asf/arrow.git
commit 7474342cf2214d88778dc33526013ec82537636a Author: Benjamin Kietzman <[email protected]> AuthorDate: Fri Nov 18 13:07:57 2022 -0500 Added validation for StringView arrays --- cpp/src/arrow/array/array_base.cc | 4 +- cpp/src/arrow/array/array_binary.h | 38 +++++- cpp/src/arrow/array/array_binary_test.cc | 67 +++++++--- cpp/src/arrow/array/array_test.cc | 4 +- cpp/src/arrow/array/builder_base.cc | 17 ++- cpp/src/arrow/array/builder_binary.h | 4 +- cpp/src/arrow/array/util.cc | 28 +++- cpp/src/arrow/array/validate.cc | 147 +++++++++++++++++++-- cpp/src/arrow/compare.cc | 8 +- .../arrow/compute/kernels/scalar_nested_test.cc | 3 + .../arrow/compute/kernels/scalar_string_test.cc | 10 +- cpp/src/arrow/compute/kernels/vector_hash.cc | 94 ++++--------- cpp/src/arrow/scalar.cc | 20 +-- cpp/src/arrow/scalar.h | 18 ++- cpp/src/arrow/testing/gtest_util.h | 6 +- cpp/src/arrow/type.h | 11 +- 16 files changed, 331 insertions(+), 148 deletions(-) diff --git a/cpp/src/arrow/array/array_base.cc b/cpp/src/arrow/array/array_base.cc index de9ab2e985..f4f860ca95 100644 --- a/cpp/src/arrow/array/array_base.cc +++ b/cpp/src/arrow/array/array_base.cc @@ -83,7 +83,9 @@ struct ScalarFromArraySlotImpl { } Status Visit(const BinaryViewArray& a) { - return Status::NotImplemented("ScalarFromArraySlot -> BinaryView"); + StringHeader header = a.Value(index_); + std::string_view view{header}; + return Finish(std::string{view}); } Status Visit(const FixedSizeBinaryArray& a) { return Finish(a.GetString(index_)); } diff --git a/cpp/src/arrow/array/array_binary.h b/cpp/src/arrow/array/array_binary.h index 03ee77fab8..1c8947dde3 100644 --- a/cpp/src/arrow/array/array_binary.h +++ b/cpp/src/arrow/array/array_binary.h @@ -230,16 +230,37 @@ class ARROW_EXPORT BinaryViewArray : public PrimitiveArray { explicit BinaryViewArray(const std::shared_ptr<ArrayData>& data); - BinaryViewArray(int64_t length, const std::shared_ptr<Buffer>& data, - const std::shared_ptr<Buffer>& null_bitmap = NULLPTR, + /// By 
default, ValidateFull() will check each view in a BinaryViewArray or + /// StringViewArray to ensure it references a memory range owned by one of the array's + /// buffers. + /// + /// If the last character buffer is null, ValidateFull will skip this step. Use this + /// for arrays which view memory elsewhere. + static BufferVector DoNotValidateViews(BufferVector char_buffers) { + char_buffers.push_back(NULLPTR); + return char_buffers; + } + + static bool OptedOutOfViewValidation(const ArrayData& data) { + return data.buffers.back() == NULLPTR; + } + bool OptedOutOfViewValidation() const { return OptedOutOfViewValidation(*data_); } + + BinaryViewArray(int64_t length, std::shared_ptr<Buffer> data, BufferVector char_buffers, + std::shared_ptr<Buffer> null_bitmap = NULLPTR, int64_t null_count = kUnknownNullCount, int64_t offset = 0) - : PrimitiveArray(binary_view(), length, data, null_bitmap, null_count, offset) {} + : PrimitiveArray(binary_view(), length, std::move(data), std::move(null_bitmap), + null_count, offset) { + for (auto& char_buffer : char_buffers) { + data_->buffers.push_back(std::move(char_buffer)); + } + } const StringHeader* raw_values() const { return reinterpret_cast<const StringHeader*>(raw_values_) + data_->offset; } - StringHeader Value(int64_t i) const { return raw_values()[i]; } + const StringHeader& Value(int64_t i) const { return raw_values()[i]; } // For API compatibility with BinaryArray etc. 
std::string_view GetView(int64_t i) const { return std::string_view(Value(i)); } @@ -264,10 +285,13 @@ class ARROW_EXPORT StringViewArray : public BinaryViewArray { explicit StringViewArray(const std::shared_ptr<ArrayData>& data); - StringViewArray(int64_t length, const std::shared_ptr<Buffer>& data, - const std::shared_ptr<Buffer>& null_bitmap = NULLPTR, + StringViewArray(int64_t length, std::shared_ptr<Buffer> data, BufferVector char_buffers, + std::shared_ptr<Buffer> null_bitmap = NULLPTR, int64_t null_count = kUnknownNullCount, int64_t offset = 0) - : BinaryViewArray(utf8_view(), length, data, null_bitmap, null_count, offset) {} + : BinaryViewArray(length, std::move(data), std::move(char_buffers), + std::move(null_bitmap), null_count, offset) { + data_->type = utf8_view(); + } /// \brief Validate that this array contains only valid UTF8 entries /// diff --git a/cpp/src/arrow/array/array_binary_test.cc b/cpp/src/arrow/array/array_binary_test.cc index c9f1b1cfab..92fc16f775 100644 --- a/cpp/src/arrow/array/array_binary_test.cc +++ b/cpp/src/arrow/array/array_binary_test.cc @@ -32,6 +32,7 @@ #include "arrow/status.h" #include "arrow/testing/builder.h" #include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" #include "arrow/testing/util.h" #include "arrow/type.h" #include "arrow/type_traits.h" @@ -365,38 +366,73 @@ TYPED_TEST(TestStringArray, TestValidateOffsets) { this->TestValidateOffsets(); TYPED_TEST(TestStringArray, TestValidateData) { this->TestValidateData(); } +TEST(StringViewArray, Validate) { + auto MakeArray = [](std::vector<StringHeader> headers, BufferVector char_buffers) { + auto length = static_cast<int64_t>(headers.size()); + return StringViewArray(length, Buffer::Wrap(std::move(headers)), + std::move(char_buffers)); + }; + + // empty array is valid + EXPECT_THAT(MakeArray({}, {}).ValidateFull(), Ok()); + + // inline views need not have a corresponding buffer + EXPECT_THAT(MakeArray({"hello", "world", "inline me"}, 
{}).ValidateFull(), Ok()); + + auto buffer_s = Buffer::FromString("supercalifragilistic(sp?)"); + auto buffer_y = Buffer::FromString("yyyyyyyyyyyyyyyyyyyyyyyyy"); + + // non-inline views are expected to reside in a buffer managed by the array + EXPECT_THAT(MakeArray({StringHeader(std::string_view{*buffer_s}), + StringHeader(std::string_view{*buffer_y})}, + {buffer_s, buffer_y}) + .ValidateFull(), + Ok()); + + EXPECT_THAT(MakeArray({StringHeader(std::string_view{*buffer_s}), + // if a view points outside the buffers, that is invalid + StringHeader("from a galaxy far, far away"), + StringHeader(std::string_view{*buffer_y})}, + {buffer_s, buffer_y}) + .ValidateFull(), + Raises(StatusCode::Invalid)); + + // ... unless specifically overridden + EXPECT_THAT( + MakeArray({"from a galaxy far, far away"}, StringViewArray::DoNotValidateViews({})) + .ValidateFull(), + Ok()); +} + template <typename T> class TestUTF8Array : public ::testing::Test { public: using TypeClass = T; - using offset_type = typename TypeClass::offset_type; using ArrayType = typename TypeTraits<TypeClass>::ArrayType; - Status ValidateUTF8(int64_t length, std::vector<offset_type> offsets, - std::string_view data, int64_t offset = 0) { - ArrayType arr(length, Buffer::Wrap(offsets), std::make_shared<Buffer>(data), - /*null_bitmap=*/nullptr, /*null_count=*/0, offset); - return arr.ValidateUTF8(); + Status ValidateUTF8(const Array& arr) { + return checked_cast<const ArrayType&>(arr).ValidateUTF8(); } - Status ValidateUTF8(const std::string& json) { - auto ty = TypeTraits<T>::type_singleton(); - auto arr = ArrayFromJSON(ty, json); - return checked_cast<const ArrayType&>(*arr).ValidateUTF8(); + Status ValidateUTF8(std::vector<std::string> values) { + std::shared_ptr<Array> arr; + ArrayFromVector<T, std::string>(values, &arr); + return ValidateUTF8(*arr); } void TestValidateUTF8() { - ASSERT_OK(ValidateUTF8(R"(["Voix", "ambiguë", "d’un", "cœur"])")); - ASSERT_OK(ValidateUTF8(1, {0, 4}, "\xf4\x8f\xbf\xbf")); // 
\U0010ffff + ASSERT_OK(ValidateUTF8(*ArrayFromJSON(TypeTraits<T>::type_singleton(), + R"(["Voix", "ambiguë", "d’un", "cœur"])"))); + ASSERT_OK(ValidateUTF8({"\xf4\x8f\xbf\xbf"})); // \U0010ffff - ASSERT_RAISES(Invalid, ValidateUTF8(1, {0, 1}, "\xf4")); + ASSERT_RAISES(Invalid, ValidateUTF8({"\xf4"})); // More tests in TestValidateData() above // (ValidateFull() calls ValidateUTF8() internally) } }; -TYPED_TEST_SUITE(TestUTF8Array, StringArrowTypes); +TYPED_TEST_SUITE(TestUTF8Array, StringOrStringViewArrowTypes); TYPED_TEST(TestUTF8Array, TestValidateUTF8) { this->TestValidateUTF8(); } @@ -908,9 +944,6 @@ class TestBaseBinaryDataVisitor : public ::testing::Test { std::shared_ptr<DataType> type_; }; -using BinaryAndBin = ::testing::Types<BinaryType, LargeBinaryType, StringType, - LargeStringType, BinaryViewType, StringViewType>; - TYPED_TEST_SUITE(TestBaseBinaryDataVisitor, BaseBinaryOrBinaryViewLikeArrowTypes); TYPED_TEST(TestBaseBinaryDataVisitor, Basics) { this->TestBasics(); } diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index d4ad1578b7..c14d4f21ac 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -544,12 +544,14 @@ static ScalarVector GetScalars() { std::make_shared<DurationScalar>(60, duration(TimeUnit::SECOND)), std::make_shared<BinaryScalar>(hello), std::make_shared<LargeBinaryScalar>(hello), + std::make_shared<BinaryViewScalar>(hello), std::make_shared<FixedSizeBinaryScalar>( hello, fixed_size_binary(static_cast<int32_t>(hello->size()))), std::make_shared<Decimal128Scalar>(Decimal128(10), decimal(16, 4)), std::make_shared<Decimal256Scalar>(Decimal256(10), decimal(76, 38)), std::make_shared<StringScalar>(hello), std::make_shared<LargeStringScalar>(hello), + std::make_shared<StringViewScalar>(hello), std::make_shared<ListScalar>(ArrayFromJSON(int8(), "[1, 2, 3]")), ScalarFromJSON(map(int8(), utf8()), R"([[1, "foo"], [2, "bar"]])"), 
std::make_shared<LargeListScalar>(ArrayFromJSON(int8(), "[1, 1, 2, 2, 3, 3]")), @@ -594,7 +596,7 @@ TEST_F(TestArray, TestMakeArrayFromScalar) { ASSERT_EQ(array->null_count(), 0); // test case for ARROW-13321 - for (int64_t i : std::vector<int64_t>{0, length / 2, length - 1}) { + for (int64_t i : {int64_t{0}, length / 2, length - 1}) { ASSERT_OK_AND_ASSIGN(auto s, array->GetScalar(i)); AssertScalarsEqual(*s, *scalar, /*verbose=*/true); } diff --git a/cpp/src/arrow/array/builder_base.cc b/cpp/src/arrow/array/builder_base.cc index e9d5fb44ac..3b2ee570f9 100644 --- a/cpp/src/arrow/array/builder_base.cc +++ b/cpp/src/arrow/array/builder_base.cc @@ -103,10 +103,7 @@ namespace { struct AppendScalarImpl { template <typename T> - enable_if_t<has_c_type<T>::value || is_decimal_type<T>::value || - is_fixed_size_binary_type<T>::value, - Status> - Visit(const T&) { + Status HandleFixedWidth(const T&) { auto builder = checked_cast<typename TypeTraits<T>::BuilderType*>(builder_); RETURN_NOT_OK(builder->Reserve(n_repeats_ * (scalars_end_ - scalars_begin_))); @@ -125,7 +122,17 @@ struct AppendScalarImpl { } template <typename T> - enable_if_base_binary<T, Status> Visit(const T&) { + enable_if_t<has_c_type<T>::value, Status> Visit(const T& t) { + return HandleFixedWidth(t); + } + + Status Visit(const FixedSizeBinaryType& t) { return HandleFixedWidth(t); } + Status Visit(const Decimal128Type& t) { return HandleFixedWidth(t); } + Status Visit(const Decimal256Type& t) { return HandleFixedWidth(t); } + + template <typename T> + enable_if_t<is_binary_like_type<T>::value || is_string_like_type<T>::value, Status> + Visit(const T&) { int64_t data_size = 0; for (const std::shared_ptr<Scalar>* raw = scalars_begin_; raw != scalars_end_; raw++) { diff --git a/cpp/src/arrow/array/builder_binary.h b/cpp/src/arrow/array/builder_binary.h index b9d926cb16..30ab4b9d4a 100644 --- a/cpp/src/arrow/array/builder_binary.h +++ b/cpp/src/arrow/array/builder_binary.h @@ -576,7 +576,6 @@ class ARROW_EXPORT 
BinaryViewBuilder : public ArrayBuilder { Status Append(StringHeader value) { ARROW_RETURN_NOT_OK(Reserve(1)); UnsafeAppend(value); - UnsafeAppendToBitmap(true); return Status::OK(); } @@ -591,7 +590,6 @@ class ARROW_EXPORT BinaryViewBuilder : public ArrayBuilder { value = data_heap_builder_.UnsafeAppend(value, length); } UnsafeAppend(StringHeader(value, length)); - UnsafeAppendToBitmap(true); } void UnsafeAppend(const char* value, int64_t length) { @@ -653,7 +651,7 @@ class ARROW_EXPORT BinaryViewBuilder : public ArrayBuilder { } void UnsafeAppendEmptyValue() { - data_builder_.UnsafeAppend(StringHeader("")); + data_builder_.UnsafeAppend(StringHeader()); UnsafeAppendToBitmap(true); } diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index ac9d76d469..fe5a0dd575 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -355,6 +355,10 @@ class NullArrayFactory { return MaxOf(sizeof(typename T::offset_type) * (length_ + 1)); } + Status Visit(const BinaryViewType& type) { + return MaxOf(sizeof(StringHeader) * length_); + } + Status Visit(const FixedSizeListType& type) { return MaxOf(GetBufferLength(type.value_type(), type.list_size() * length_)); } @@ -463,6 +467,11 @@ class NullArrayFactory { return Status::OK(); } + Status Visit(const BinaryViewType&) { + out_->buffers.resize(2, buffer_); + return Status::OK(); + } + template <typename T> enable_if_var_size_list<T, Status> Visit(const T& type) { out_->buffers.resize(2, buffer_); @@ -599,14 +608,27 @@ class RepeatedArrayFactory { RETURN_NOT_OK(CreateBufferOf(value->data(), value->size(), &values_buffer)); auto size = static_cast<typename T::offset_type>(value->size()); RETURN_NOT_OK(CreateOffsetsBuffer(size, &offsets_buffer)); - out_ = std::make_shared<typename TypeTraits<T>::ArrayType>(length_, offsets_buffer, - values_buffer); + out_ = std::make_shared<typename TypeTraits<T>::ArrayType>( + length_, std::move(offsets_buffer), std::move(values_buffer)); return Status::OK(); } 
template <typename T> enable_if_binary_view_like<T, Status> Visit(const T&) { - return Status::NotImplemented("binary / string view"); + const std::shared_ptr<Buffer>& value = + checked_cast<const typename TypeTraits<T>::ScalarType&>(scalar_).value; + + StringHeader header{std::string_view{*value}}; + std::shared_ptr<Buffer> header_buffer; + RETURN_NOT_OK(CreateBufferOf(&header, sizeof(header), &header_buffer)); + + BufferVector char_buffers; + if (!header.IsInline()) { + char_buffers.push_back(value); + } + out_ = std::make_shared<typename TypeTraits<T>::ArrayType>( + length_, std::move(header_buffer), std::move(char_buffers)); + return Status::OK(); } template <typename T> diff --git a/cpp/src/arrow/array/validate.cc b/cpp/src/arrow/array/validate.cc index cddb086005..53d74ba148 100644 --- a/cpp/src/arrow/array/validate.cc +++ b/cpp/src/arrow/array/validate.cc @@ -30,6 +30,7 @@ #include "arrow/util/decimal.h" #include "arrow/util/int_util_overflow.h" #include "arrow/util/logging.h" +#include "arrow/util/unreachable.h" #include "arrow/util/utf8.h" #include "arrow/visit_data_inline.h" #include "arrow/visit_type_inline.h" @@ -42,10 +43,7 @@ namespace { struct UTF8DataValidator { const ArrayData& data; - Status Visit(const DataType&) { - // Default, should be unreachable - return Status::NotImplemented(""); - } + Status Visit(const DataType&) { Unreachable("utf-8 validation of non string type"); } Status Visit(const StringViewType&) { util::InitializeUTF8(); @@ -86,10 +84,7 @@ struct BoundsChecker { int64_t min_value; int64_t max_value; - Status Visit(const DataType&) { - // Default, should be unreachable - return Status::NotImplemented(""); - } + Status Visit(const DataType&) { Unreachable("bounds checking of non integer type"); } template <typename IntegerType> enable_if_integer<IntegerType, Status> Visit(const IntegerType&) { @@ -260,9 +255,7 @@ struct ValidateArrayImpl { Status Visit(const LargeBinaryType& type) { return ValidateBinaryLike(type); } - Status 
Visit(const BinaryViewType& type) { - return Status::NotImplemented("binary / string view"); - } + Status Visit(const BinaryViewType& type) { return ValidateBinaryView(type); } Status Visit(const ListType& type) { return ValidateListLike(type); } @@ -455,7 +448,14 @@ struct ValidateArrayImpl { return Status::Invalid("Array length is negative"); } - if (data.buffers.size() != layout.buffers.size()) { + if (layout.variadic_spec) { + if (data.buffers.size() < layout.buffers.size()) { + return Status::Invalid("Expected at least ", layout.buffers.size(), + " buffers in array " + "of type ", + type.ToString(), ", got ", data.buffers.size()); + } + } else if (data.buffers.size() != layout.buffers.size()) { return Status::Invalid("Expected ", layout.buffers.size(), " buffers in array " "of type ", @@ -471,7 +471,9 @@ struct ValidateArrayImpl { for (int i = 0; i < static_cast<int>(data.buffers.size()); ++i) { const auto& buffer = data.buffers[i]; - const auto& spec = layout.buffers[i]; + const auto& spec = i < static_cast<int>(layout.buffers.size()) + ? 
layout.buffers[i] + : *layout.variadic_spec; if (buffer == nullptr) { continue; @@ -594,6 +596,125 @@ struct ValidateArrayImpl { return Status::OK(); } + Status ValidateBinaryView(const BinaryViewType& type) { + int64_t headers_byte_size = data.buffers[1]->size(); + int64_t required_headers = data.length + data.offset; + if (static_cast<int64_t>(headers_byte_size / sizeof(StringHeader)) < + required_headers) { + return Status::Invalid("Header buffer size (bytes): ", headers_byte_size, + " isn't large enough for length: ", data.length, + " and offset: ", data.offset); + } + + if (!full_validation || BinaryViewArray::OptedOutOfViewValidation(data)) { + return Status::OK(); + } + + auto* headers = data.GetValues<StringHeader>(1); + std::string_view buffer_containing_previous_view; + + auto IsSubrangeOf = [](std::string_view super, std::string_view sub) { + return super.data() <= sub.data() && + super.data() + super.size() <= sub.data() + sub.size(); + }; + + std::vector<std::string_view> buffers; + for (auto it = data.buffers.begin() + 2; it != data.buffers.end(); ++it) { + buffers.emplace_back(**it); + } + + auto CheckViews = [&](auto in_a_buffer, auto check_previous_buffer) { + if constexpr (check_previous_buffer) { + buffer_containing_previous_view = buffers.front(); + } + + for (int64_t i = 0; i < data.length; ++i) { + if (headers[i].IsInline()) continue; + + std::string_view view{headers[i]}; + + if constexpr (check_previous_buffer) { + if (ARROW_PREDICT_TRUE(IsSubrangeOf(buffer_containing_previous_view, view))) { + // Fast path: for most string view arrays, we'll have runs + // of views into the same buffer. 
+ continue; + } + } + + if (!in_a_buffer(view)) { + return Status::Invalid( + "String view at slot ", i, + " views memory not resident in any buffer managed by the array"); + } + } + return Status::OK(); + }; + + if (buffers.empty()) { + // there are no character buffers; the only way this array + // can be valid is if all views are inline + return CheckViews([](std::string_view) { return std::false_type{}; }, + /*check_previous_buffer=*/std::false_type{}); + } + + // Simplest check for view-in-buffer: loop through buffers and check each one. + auto Linear = [&](std::string_view view) { + for (std::string_view buffer : buffers) { + if (IsSubrangeOf(buffer, view)) { + buffer_containing_previous_view = buffer; + return true; + } + } + return false; + }; + + if (buffers.size() <= 32) { + // If there are few buffers to search through, sorting/binary search is not + // worthwhile. TODO(bkietz) benchmark this and get a less magic number here. + return CheckViews(Linear, + /*check_previous_buffer=*/std::true_type{}); + } + + auto DataPtrLess = [](std::string_view l, std::string_view r) { + return l.data() < r.data(); + }; + + std::sort(buffers.begin(), buffers.end(), DataPtrLess); + bool non_overlapping = + buffers.end() != + std::adjacent_find(buffers.begin(), buffers.end(), + [](std::string_view before, std::string_view after) { + return before.data() + before.size() <= after.data(); + }); + if (ARROW_PREDICT_FALSE(!non_overlapping)) { + // Using a binary search with overlapping buffers would not *uniquely* identify + // a potentially-containing buffer. Moreover this should be a fairly rare case + // so optimizing for it seems premature. + return CheckViews(Linear, + /*check_previous_buffer=*/std::true_type{}); + } + + // More sophisticated check for view-in-buffer: binary search through the buffers. + return CheckViews( + [&](std::string_view view) { + // Find the first buffer whose data starts after the data in view- + // only buffers *before* this could contain view. 
Since we've additionally + // checked that the buffers do not overlap, only the buffer *immediately before* + // this could contain view. + auto one_past_potential_super = + std::upper_bound(buffers.begin(), buffers.end(), view, DataPtrLess); + + if (one_past_potential_super == buffers.begin()) return false; + + auto potential_super = *(one_past_potential_super - 1); + if (!IsSubrangeOf(potential_super, view)) return false; + + buffer_containing_previous_view = potential_super; + return true; + }, + /*check_previous_buffer=*/std::true_type{}); + } + template <typename ListType> Status ValidateListLike(const ListType& type) { const ArrayData& values = *data.child_data[0]; diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 8ccc645046..68250f0288 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -727,19 +727,13 @@ class ScalarEqualsVisitor { Status Visit(const DoubleScalar& left) { return CompareFloating(left); } template <typename T> - typename std::enable_if<std::is_base_of<BaseBinaryScalar, T>::value, Status>::type + enable_if_t<std::is_base_of<BaseBinaryScalar, T>::value, Status> Visit(const T& left) { const auto& right = checked_cast<const BaseBinaryScalar&>(right_); result_ = internal::SharedPtrEquals(left.value, right.value); return Status::OK(); } - Status Visit(const BinaryViewScalar& left) { - const auto& right = checked_cast<const BinaryViewScalar&>(right_); - result_ = left.value == right.value; - return Status::OK(); - } - Status Visit(const Decimal128Scalar& left) { const auto& right = checked_cast<const Decimal128Scalar&>(right_); result_ = left.value == right.value; diff --git a/cpp/src/arrow/compute/kernels/scalar_nested_test.cc b/cpp/src/arrow/compute/kernels/scalar_nested_test.cc index 744f188908..523e20c4a7 100644 --- a/cpp/src/arrow/compute/kernels/scalar_nested_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_nested_test.cc @@ -796,6 +796,9 @@ TEST(MakeStruct, Array) { EXPECT_THAT(MakeStructor({i32, 
str}, {"i", "s"}), ResultWith(Datum(*StructArray::Make({i32, str}, field_names)))); + EXPECT_THAT(*MakeScalar("aa"), testing::Eq(StringScalar("aa"))); + EXPECT_EQ(*MakeStructor({i32, MakeScalar("aa")}, {"i", "s"})->type(), + StructType({field("i", i32->type()), field("s", str->type())})); // Scalars are broadcast to the length of the arrays EXPECT_THAT(MakeStructor({i32, MakeScalar("aa")}, {"i", "s"}), ResultWith(Datum(*StructArray::Make({i32, str}, field_names)))); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 2498e7f562..b390a36b4c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -47,7 +47,6 @@ namespace compute { template <typename TestType> class BaseTestStringKernels : public ::testing::Test { protected: - using OffsetType = typename TypeTraits<TestType>::OffsetType; using ScalarType = typename TypeTraits<TestType>::ScalarType; void CheckUnary(std::string func_name, std::string json_input, @@ -97,7 +96,14 @@ class BaseTestStringKernels : public ::testing::Test { } std::shared_ptr<DataType> offset_type() { - return TypeTraits<OffsetType>::type_singleton(); + if constexpr (is_binary_view_like_type<TestType>::value) { + // Views do not have offsets, but Functions like binary_length + // will return the length as uint32 + return uint32(); + } else { + using OffsetType = typename TypeTraits<TestType>::OffsetType; + return TypeTraits<OffsetType>::type_singleton(); + } } template <typename CType = const char*> diff --git a/cpp/src/arrow/compute/kernels/vector_hash.cc b/cpp/src/arrow/compute/kernels/vector_hash.cc index f2d4c29f0e..f9637a2f71 100644 --- a/cpp/src/arrow/compute/kernels/vector_hash.cc +++ b/cpp/src/arrow/compute/kernels/vector_hash.cc @@ -30,6 +30,7 @@ #include "arrow/compute/kernels/common.h" #include "arrow/result.h" #include "arrow/util/hashing.h" +#include "arrow/util/unreachable.h" namespace 
arrow { @@ -261,7 +262,7 @@ class HashKernel : public KernelState { // Base class for all "regular" hash kernel implementations // (NullType has a separate implementation) -template <typename Type, typename Scalar, typename Action, +template <typename Type, typename Action, typename Scalar = typename Type::c_type, bool with_error_status = Action::with_error_status> class RegularHashKernel : public HashKernel { public: @@ -501,39 +502,13 @@ class DictionaryHashKernel : public HashKernel { }; // ---------------------------------------------------------------------- - -template <typename Type, typename Action, typename Enable = void> -struct HashKernelTraits {}; - -template <typename Type, typename Action> -struct HashKernelTraits<Type, Action, enable_if_null<Type>> { - using HashKernel = NullHashKernel<Action>; -}; - -template <typename Type, typename Action> -struct HashKernelTraits<Type, Action, enable_if_has_c_type<Type>> { - using HashKernel = RegularHashKernel<Type, typename Type::c_type, Action>; -}; - -template <typename Type, typename Action> -struct HashKernelTraits<Type, Action, enable_if_has_string_view<Type>> { - using HashKernel = RegularHashKernel<Type, std::string_view, Action>; -}; - -template <typename Type, typename Action> -Result<std::unique_ptr<HashKernel>> HashInitImpl(KernelContext* ctx, - const KernelInitArgs& args) { - using HashKernelType = typename HashKernelTraits<Type, Action>::HashKernel; - auto result = std::make_unique<HashKernelType>(args.inputs[0].GetSharedPtr(), - args.options, ctx->memory_pool()); - RETURN_NOT_OK(result->Reset()); - return std::move(result); -} - -template <typename Type, typename Action> +template <typename HashKernel> Result<std::unique_ptr<KernelState>> HashInit(KernelContext* ctx, const KernelInitArgs& args) { - return HashInitImpl<Type, Action>(ctx, args); + auto result = std::make_unique<HashKernel>(args.inputs[0].GetSharedPtr(), args.options, + ctx->memory_pool()); + RETURN_NOT_OK(result->Reset()); + return 
std::move(result); } template <typename Action> @@ -542,22 +517,22 @@ KernelInit GetHashInit(Type::type type_id) { // representation switch (type_id) { case Type::NA: - return HashInit<NullType, Action>; + return HashInit<NullHashKernel<Action>>; case Type::BOOL: - return HashInit<BooleanType, Action>; + return HashInit<RegularHashKernel<BooleanType, Action>>; case Type::INT8: case Type::UINT8: - return HashInit<UInt8Type, Action>; + return HashInit<RegularHashKernel<UInt8Type, Action>>; case Type::INT16: case Type::UINT16: - return HashInit<UInt16Type, Action>; + return HashInit<RegularHashKernel<UInt16Type, Action>>; case Type::INT32: case Type::UINT32: case Type::FLOAT: case Type::DATE32: case Type::TIME32: case Type::INTERVAL_MONTHS: - return HashInit<UInt32Type, Action>; + return HashInit<RegularHashKernel<UInt32Type, Action>>; case Type::INT64: case Type::UINT64: case Type::DOUBLE: @@ -566,22 +541,23 @@ KernelInit GetHashInit(Type::type type_id) { case Type::TIMESTAMP: case Type::DURATION: case Type::INTERVAL_DAY_TIME: - return HashInit<UInt64Type, Action>; + return HashInit<RegularHashKernel<UInt64Type, Action>>; case Type::BINARY: case Type::STRING: - return HashInit<BinaryType, Action>; + case Type::BINARY_VIEW: + case Type::STRING_VIEW: + return HashInit<RegularHashKernel<BinaryType, Action, std::string_view>>; case Type::LARGE_BINARY: case Type::LARGE_STRING: - return HashInit<LargeBinaryType, Action>; + return HashInit<RegularHashKernel<LargeBinaryType, Action, std::string_view>>; case Type::FIXED_SIZE_BINARY: case Type::DECIMAL128: case Type::DECIMAL256: - return HashInit<FixedSizeBinaryType, Action>; + return HashInit<RegularHashKernel<FixedSizeBinaryType, Action, std::string_view>>; case Type::INTERVAL_MONTH_DAY_NANO: - return HashInit<MonthDayNanoIntervalType, Action>; + return HashInit<RegularHashKernel<MonthDayNanoIntervalType, Action>>; default: - DCHECK(false); - return nullptr; + Unreachable("non hashable type"); } } @@ -591,31 +567,11 @@ 
template <typename Action> Result<std::unique_ptr<KernelState>> DictionaryHashInit(KernelContext* ctx, const KernelInitArgs& args) { const auto& dict_type = checked_cast<const DictionaryType&>(*args.inputs[0].type); - Result<std::unique_ptr<HashKernel>> indices_hasher; - switch (dict_type.index_type()->id()) { - case Type::INT8: - case Type::UINT8: - indices_hasher = HashInitImpl<UInt8Type, Action>(ctx, args); - break; - case Type::INT16: - case Type::UINT16: - indices_hasher = HashInitImpl<UInt16Type, Action>(ctx, args); - break; - case Type::INT32: - case Type::UINT32: - indices_hasher = HashInitImpl<UInt32Type, Action>(ctx, args); - break; - case Type::INT64: - case Type::UINT64: - indices_hasher = HashInitImpl<UInt64Type, Action>(ctx, args); - break; - default: - DCHECK(false) << "Unsupported dictionary index type"; - break; - } - RETURN_NOT_OK(indices_hasher); - return std::make_unique<DictionaryHashKernel>(std::move(indices_hasher.ValueOrDie()), - dict_type.value_type()); + ARROW_ASSIGN_OR_RAISE(auto indices_hasher, + GetHashInit<Action>(dict_type.index_type()->id())(ctx, args)); + return std::make_unique<DictionaryHashKernel>( + checked_pointer_cast<HashKernel>(std::move(indices_hasher)), + dict_type.value_type()); } Status HashExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index bfe8a49a9e..aca767907c 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -226,13 +226,11 @@ struct ScalarValidateImpl { Status Visit(const StringScalar& s) { return ValidateStringScalar(s); } - Status Visit(const BinaryViewScalar& s) { - return Status::NotImplemented("Binary view"); - } + Status Visit(const BinaryViewScalar& s) { return ValidateBinaryScalar(s); } - Status Visit(const StringViewScalar& s) { - return Status::NotImplemented("String view"); - } + Status Visit(const StringViewScalar& s) { return ValidateStringScalar(s); } + + Status Visit(const LargeBinaryScalar& s) { 
return ValidateBinaryScalar(s); } Status Visit(const LargeStringScalar& s) { return ValidateStringScalar(s); } @@ -499,14 +497,8 @@ Status Scalar::ValidateFull() const { return ScalarValidateImpl(/*full_validation=*/true).Validate(*this); } -BinaryScalar::BinaryScalar(std::string s) - : BinaryScalar(Buffer::FromString(std::move(s))) {} - -LargeBinaryScalar::LargeBinaryScalar(std::string s) - : LargeBinaryScalar(Buffer::FromString(std::move(s))) {} - -LargeStringScalar::LargeStringScalar(std::string s) - : LargeStringScalar(Buffer::FromString(std::move(s))) {} +BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr<DataType> type) + : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::shared_ptr<Buffer> value, std::shared_ptr<DataType> type, diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 9f41ad0975..6042f0b434 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -253,6 +253,8 @@ struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase { BaseBinaryScalar(std::shared_ptr<Buffer> value, std::shared_ptr<DataType> type) : internal::PrimitiveScalarBase{std::move(type), true}, value(std::move(value)) {} + + BaseBinaryScalar(std::string s, std::shared_ptr<DataType> type); }; struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { @@ -262,7 +264,7 @@ struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { explicit BinaryScalar(std::shared_ptr<Buffer> value) : BinaryScalar(std::move(value), binary()) {} - explicit BinaryScalar(std::string s); + explicit BinaryScalar(std::string s) : BaseBinaryScalar(std::move(s), binary()) {} BinaryScalar() : BinaryScalar(binary()) {} }; @@ -274,6 +276,8 @@ struct ARROW_EXPORT StringScalar : public BinaryScalar { explicit StringScalar(std::shared_ptr<Buffer> value) : StringScalar(std::move(value), utf8()) {} + explicit StringScalar(std::string s) : BinaryScalar(std::move(s), utf8()) {} + 
StringScalar() : StringScalar(utf8()) {} }; @@ -284,6 +288,9 @@ struct ARROW_EXPORT BinaryViewScalar : public BaseBinaryScalar { explicit BinaryViewScalar(std::shared_ptr<Buffer> value) : BinaryViewScalar(std::move(value), binary_view()) {} + explicit BinaryViewScalar(std::string s) + : BaseBinaryScalar(std::move(s), binary_view()) {} + BinaryViewScalar() : BinaryViewScalar(binary_view()) {} std::string_view view() const override { return std::string_view(*this->value); } @@ -296,6 +303,9 @@ struct ARROW_EXPORT StringViewScalar : public BinaryViewScalar { explicit StringViewScalar(std::shared_ptr<Buffer> value) : StringViewScalar(std::move(value), utf8_view()) {} + explicit StringViewScalar(std::string s) + : BinaryViewScalar(std::move(s), utf8_view()) {} + StringViewScalar() : StringViewScalar(utf8_view()) {} }; @@ -309,7 +319,8 @@ struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar { explicit LargeBinaryScalar(std::shared_ptr<Buffer> value) : LargeBinaryScalar(std::move(value), large_binary()) {} - explicit LargeBinaryScalar(std::string s); + explicit LargeBinaryScalar(std::string s) + : BaseBinaryScalar(std::move(s), large_binary()) {} LargeBinaryScalar() : LargeBinaryScalar(large_binary()) {} }; @@ -321,7 +332,8 @@ struct ARROW_EXPORT LargeStringScalar : public LargeBinaryScalar { explicit LargeStringScalar(std::shared_ptr<Buffer> value) : LargeStringScalar(std::move(value), large_utf8()) {} - explicit LargeStringScalar(std::string s); + explicit LargeStringScalar(std::string s) + : LargeBinaryScalar(std::move(s), large_utf8()) {} LargeStringScalar() : LargeStringScalar(large_utf8()) {} }; diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index fc319a6d10..4d29706829 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -177,12 +177,16 @@ using BaseBinaryArrowTypes = ::testing::Types<BinaryType, LargeBinaryType, StringType, LargeStringType>; using 
BaseBinaryOrBinaryViewLikeArrowTypes = - ::testing::Types<BinaryType, LargeBinaryType, StringType, LargeStringType>; + ::testing::Types<BinaryType, LargeBinaryType, BinaryViewType, StringType, + LargeStringType, StringViewType>; using BinaryArrowTypes = ::testing::Types<BinaryType, LargeBinaryType>; using StringArrowTypes = ::testing::Types<StringType, LargeStringType>; +using StringOrStringViewArrowTypes = + ::testing::Types<StringType, LargeStringType, StringViewType>; + using ListArrowTypes = ::testing::Types<ListType, LargeListType>; using UnionArrowTypes = ::testing::Types<SparseUnionType, DenseUnionType>; diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index f4e082b3f6..faa2eb2af0 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -114,8 +114,14 @@ struct ARROW_EXPORT DataTypeLayout { std::vector<BufferSpec> buffers; /// Whether this type expects an associated dictionary array. bool has_dictionary = false; + /// If this is provided, the number of buffers expected is only lower-bounded by + /// buffers.size(). Buffers beyond this lower bound are expected to conform to + /// variadic_spec. + std::optional<BufferSpec> variadic_spec; - explicit DataTypeLayout(std::vector<BufferSpec> v) : buffers(std::move(v)) {} + explicit DataTypeLayout(std::vector<BufferSpec> buffers, + std::optional<BufferSpec> variadic_spec = {}) + : buffers(std::move(buffers)), variadic_spec(variadic_spec) {} }; /// \brief Base class for all data types @@ -701,7 +707,8 @@ class ARROW_EXPORT BinaryViewType : public DataType { DataTypeLayout layout() const override { return DataTypeLayout( - {DataTypeLayout::Bitmap(), DataTypeLayout::FixedWidth(sizeof(StringHeader))}); + {DataTypeLayout::Bitmap(), DataTypeLayout::FixedWidth(sizeof(StringHeader))}, + DataTypeLayout::VariableWidth()); } std::string ToString() const override;
