This is an automated email from the ASF dual-hosted git repository.
bkietz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new d2f140dd75 GH-40069: [C++] Make scalar scratch space immutable after initialization (#40237)
d2f140dd75 is described below
commit d2f140dd7540420aad7685e035f887b0cc9baf91
Author: Rossi Sun <[email protected]>
AuthorDate: Wed Apr 24 00:06:30 2024 +0800
GH-40069: [C++] Make scalar scratch space immutable after initialization (#40237)
### Rationale for this change
As #40069 shows, TSAN reports a data race caused by concurrently filling the
scratch space of the same scalar instance. Such concurrent use of a single
scalar can happen when executing an Acero plan in parallel that contains a
literal (a "constant" simply represented by an underlying scalar), and this is
totally legit. The problem lies in the fact that the scratch space of the
scalar is filled "lazily" at the time it is involved in the
computation and transformed into an array span, f [...]
After piloting several approaches (relaxed atomics - an earlier version of
this PR, locking - #40260), @pitrou and @bkietz suggested an
immutable-after-initialization approach, which the latest version of this
PR implements.
### What changes are included in this PR?
There are generally two parts in this PR:
1. Mandate the initialization of the scratch space in the constructor of each
concrete subclass of `Scalar`.
2. To keep the content of the scratch space consistent with the
underlying `value` of the scalar, make the `value` constant. This effectively
invalidates legacy code that directly assigns to `value`, which is
refactored accordingly:
2.1 `BoxScalar` in
https://github.com/apache/arrow/pull/40237/files#diff-08d11e02c001c82b1aa89565e16760a8bcca4a608c22619fb45da42fd0ebebac
2.2 `Scalar::CastTo` in
https://github.com/apache/arrow/pull/40237/files#diff-b4b83682450006616fa7e4f6f2ea3031cf1a22d734f4bee42a99af313e808f9e
2.3 `ScalarMinMax` in
https://github.com/apache/arrow/pull/40237/files#diff-368ab7e748bd4432c92d9fdc26b51e131742b968e3eb32a6fcea4b9f02fa36aa
Besides, while refactoring 2.2, I found that the current `Scalar::CastTo` is not
fully covered by the existing tests, so I also added the missing ones.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
**This PR includes breaking changes to public APIs.**
The `value` member of `BaseBinaryScalar` and subclasses/`BaseListScalar`
and subclasses/`SparseUnionScalar`/`DenseUnionScalar`/`RunEndEncodedScalar` is
made constant, so code that directly assigns to this member will no longer compile.
Also, the `Scalar::mutable_data()` member function is removed because it
contradicts the immutable nature of `Scalar`.
However, the impact of these changes seems limited: I don't think much user
code depends on these two old pieces of code.
Also, after a quick search, I didn't find any documentation that needs to be
updated for this change. There may be none, but if there is, please point me to
it so I can update it. Thanks.
* GitHub Issue: #40069
Lead-authored-by: Ruoxi Sun <[email protected]>
Co-authored-by: Rossi Sun <[email protected]>
Signed-off-by: Benjamin Kietzman <[email protected]>
---
c_glib/arrow-glib/scalar.cpp | 6 +-
cpp/src/arrow/array/array_test.cc | 36 ++
cpp/src/arrow/array/data.cc | 90 +++--
cpp/src/arrow/compute/kernels/codegen_internal.h | 37 ---
cpp/src/arrow/compute/kernels/scalar_compare.cc | 17 +-
cpp/src/arrow/scalar.cc | 405 +++++++++++++++--------
cpp/src/arrow/scalar.h | 213 +++++++++---
cpp/src/arrow/scalar_test.cc | 314 ++++++++++++++----
8 files changed, 764 insertions(+), 354 deletions(-)
diff --git a/c_glib/arrow-glib/scalar.cpp b/c_glib/arrow-glib/scalar.cpp
index def6b15148..f965b49703 100644
--- a/c_glib/arrow-glib/scalar.cpp
+++ b/c_glib/arrow-glib/scalar.cpp
@@ -1063,7 +1063,8 @@ garrow_base_binary_scalar_get_value(GArrowBaseBinaryScalar *scalar)
if (!priv->value) {
const auto arrow_scalar =
std::static_pointer_cast<arrow::BaseBinaryScalar>(
garrow_scalar_get_raw(GARROW_SCALAR(scalar)));
- priv->value = garrow_buffer_new_raw(&(arrow_scalar->value));
+ priv->value = garrow_buffer_new_raw(
+ const_cast<std::shared_ptr<arrow::Buffer> *>(&(arrow_scalar->value)));
}
return priv->value;
}
@@ -1983,7 +1984,8 @@ garrow_base_list_scalar_get_value(GArrowBaseListScalar *scalar)
if (!priv->value) {
const auto arrow_scalar = std::static_pointer_cast<arrow::BaseListScalar>(
garrow_scalar_get_raw(GARROW_SCALAR(scalar)));
- priv->value = garrow_array_new_raw(&(arrow_scalar->value));
+ priv->value = garrow_array_new_raw(
+ const_cast<std::shared_ptr<arrow::Array> *>(&(arrow_scalar->value)));
}
return priv->value;
}
diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc
index 60efdb4768..b0d7fe740a 100644
--- a/cpp/src/arrow/array/array_test.cc
+++ b/cpp/src/arrow/array/array_test.cc
@@ -22,6 +22,7 @@
#include <cmath>
#include <cstdint>
#include <cstring>
+#include <future>
#include <limits>
#include <memory>
#include <numeric>
@@ -823,6 +824,41 @@ TEST_F(TestArray, TestFillFromScalar) {
}
}
+// GH-40069: Data-race when concurrent calling ArraySpan::FillFromScalar of the same
+// scalar instance.
+TEST_F(TestArray, TestConcurrentFillFromScalar) {
+ for (auto type : TestArrayUtilitiesAgainstTheseTypes()) {
+ ARROW_SCOPED_TRACE("type = ", type->ToString());
+ for (auto seed : {0u, 0xdeadbeef, 42u}) {
+ ARROW_SCOPED_TRACE("seed = ", seed);
+
+      Field field("", type, /*nullable=*/true,
+                  key_value_metadata({{"extension_allow_random_storage", "true"}}));
+ auto array = random::GenerateArray(field, 1, seed);
+
+ ASSERT_OK_AND_ASSIGN(auto scalar, array->GetScalar(0));
+
+      // Lambda to create fill an ArraySpan with the scalar and use the ArraySpan a bit.
+ auto array_span_from_scalar = [&]() {
+ ArraySpan span(*scalar);
+ auto roundtripped_array = span.ToArray();
+ ASSERT_OK(roundtripped_array->ValidateFull());
+
+ AssertArraysEqual(*array, *roundtripped_array);
+        ASSERT_OK_AND_ASSIGN(auto roundtripped_scalar, roundtripped_array->GetScalar(0));
+ AssertScalarsEqual(*scalar, *roundtripped_scalar);
+ };
+
+      // Two concurrent calls to the lambda are just enough for TSAN to detect a race
+      // condition.
+ auto fut1 = std::async(std::launch::async, array_span_from_scalar);
+ auto fut2 = std::async(std::launch::async, array_span_from_scalar);
+ fut1.get();
+ fut2.get();
+ }
+ }
+}
+
TEST_F(TestArray, ExtensionSpanRoundTrip) {
// Other types are checked in MakeEmptyArray but MakeEmptyArray doesn't
// work for extension types so we check that here
diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc
index 80c411dfa6..ff3112ec1f 100644
--- a/cpp/src/arrow/array/data.cc
+++ b/cpp/src/arrow/array/data.cc
@@ -283,25 +283,15 @@ void ArraySpan::SetMembers(const ArrayData& data) {
namespace {
-template <typename offset_type>
-BufferSpan OffsetsForScalar(uint8_t* scratch_space, offset_type value_size) {
- auto* offsets = reinterpret_cast<offset_type*>(scratch_space);
- offsets[0] = 0;
- offsets[1] = static_cast<offset_type>(value_size);
- static_assert(2 * sizeof(offset_type) <= 16);
- return {scratch_space, sizeof(offset_type) * 2};
+BufferSpan OffsetsForScalar(uint8_t* scratch_space, int64_t offset_width) {
+ return {scratch_space, offset_width * 2};
}
-template <typename offset_type>
std::pair<BufferSpan, BufferSpan> OffsetsAndSizesForScalar(uint8_t* scratch_space,
-                                                           offset_type value_size) {
+                                                           int64_t offset_width) {
auto* offsets = scratch_space;
- auto* sizes = scratch_space + sizeof(offset_type);
- reinterpret_cast<offset_type*>(offsets)[0] = 0;
- reinterpret_cast<offset_type*>(sizes)[0] = value_size;
- static_assert(2 * sizeof(offset_type) <= 16);
- return {BufferSpan{offsets, sizeof(offset_type)},
- BufferSpan{sizes, sizeof(offset_type)}};
+ auto* sizes = scratch_space + offset_width;
+ return {BufferSpan{offsets, offset_width}, BufferSpan{sizes, offset_width}};
}
int GetNumBuffers(const DataType& type) {
@@ -415,26 +405,23 @@ void ArraySpan::FillFromScalar(const Scalar& value) {
data_size = scalar.value->size();
}
if (is_binary_like(type_id)) {
-      this->buffers[1] =
-          OffsetsForScalar(scalar.scratch_space_, static_cast<int32_t>(data_size));
+      const auto& binary_scalar = checked_cast<const BinaryScalar&>(value);
+      this->buffers[1] = OffsetsForScalar(binary_scalar.scratch_space_, sizeof(int32_t));
} else {
// is_large_binary_like
-      this->buffers[1] = OffsetsForScalar(scalar.scratch_space_, data_size);
+      const auto& large_binary_scalar = checked_cast<const LargeBinaryScalar&>(value);
+      this->buffers[1] =
+          OffsetsForScalar(large_binary_scalar.scratch_space_, sizeof(int64_t));
}
this->buffers[2].data = const_cast<uint8_t*>(data_buffer);
this->buffers[2].size = data_size;
} else if (type_id == Type::BINARY_VIEW || type_id == Type::STRING_VIEW) {
- const auto& scalar = checked_cast<const BaseBinaryScalar&>(value);
+ const auto& scalar = checked_cast<const BinaryViewScalar&>(value);
this->buffers[1].size = BinaryViewType::kSize;
this->buffers[1].data = scalar.scratch_space_;
-    static_assert(sizeof(BinaryViewType::c_type) <= sizeof(scalar.scratch_space_));
- auto* view = new (&scalar.scratch_space_) BinaryViewType::c_type;
if (scalar.is_valid) {
- *view = util::ToBinaryView(std::string_view{*scalar.value}, 0, 0);
this->buffers[2] = internal::PackVariadicBuffers({&scalar.value, 1});
- } else {
- *view = {};
}
} else if (type_id == Type::FIXED_SIZE_BINARY) {
const auto& scalar = checked_cast<const BaseBinaryScalar&>(value);
@@ -443,12 +430,10 @@ void ArraySpan::FillFromScalar(const Scalar& value) {
} else if (is_var_length_list_like(type_id) || type_id == Type::FIXED_SIZE_LIST) {
const auto& scalar = checked_cast<const BaseListScalar&>(value);
- int64_t value_length = 0;
this->child_data.resize(1);
if (scalar.value != nullptr) {
// When the scalar is null, scalar.value can also be null
this->child_data[0].SetMembers(*scalar.value->data());
- value_length = scalar.value->length();
} else {
// Even when the value is null, we still must populate the
// child_data to yield a valid array. Tedious
@@ -456,17 +441,25 @@ void ArraySpan::FillFromScalar(const Scalar& value) {
&this->child_data[0]);
}
-    if (type_id == Type::LIST || type_id == Type::MAP) {
-      this->buffers[1] =
-          OffsetsForScalar(scalar.scratch_space_, static_cast<int32_t>(value_length));
+    if (type_id == Type::LIST) {
+      const auto& list_scalar = checked_cast<const ListScalar&>(value);
+      this->buffers[1] = OffsetsForScalar(list_scalar.scratch_space_, sizeof(int32_t));
+    } else if (type_id == Type::MAP) {
+      const auto& map_scalar = checked_cast<const MapScalar&>(value);
+      this->buffers[1] = OffsetsForScalar(map_scalar.scratch_space_, sizeof(int32_t));
    } else if (type_id == Type::LARGE_LIST) {
-      this->buffers[1] = OffsetsForScalar(scalar.scratch_space_, value_length);
+      const auto& large_list_scalar = checked_cast<const LargeListScalar&>(value);
+      this->buffers[1] =
+          OffsetsForScalar(large_list_scalar.scratch_space_, sizeof(int64_t));
    } else if (type_id == Type::LIST_VIEW) {
-      std::tie(this->buffers[1], this->buffers[2]) = OffsetsAndSizesForScalar(
-          scalar.scratch_space_, static_cast<int32_t>(value_length));
-    } else if (type_id == Type::LARGE_LIST_VIEW) {
+      const auto& list_view_scalar = checked_cast<const ListViewScalar&>(value);
      std::tie(this->buffers[1], this->buffers[2]) =
-          OffsetsAndSizesForScalar(scalar.scratch_space_, value_length);
+          OffsetsAndSizesForScalar(list_view_scalar.scratch_space_, sizeof(int32_t));
+    } else if (type_id == Type::LARGE_LIST_VIEW) {
+      const auto& large_list_view_scalar =
+          checked_cast<const LargeListViewScalar&>(value);
+      std::tie(this->buffers[1], this->buffers[2]) = OffsetsAndSizesForScalar(
+          large_list_view_scalar.scratch_space_, sizeof(int64_t));
} else {
DCHECK_EQ(type_id, Type::FIXED_SIZE_LIST);
// FIXED_SIZE_LIST: does not have a second buffer
@@ -480,27 +473,19 @@ void ArraySpan::FillFromScalar(const Scalar& value) {
this->child_data[i].FillFromScalar(*scalar.value[i]);
}
} else if (is_union(type_id)) {
- // Dense union needs scratch space to store both offsets and a type code
- struct UnionScratchSpace {
- alignas(int64_t) int8_t type_code;
- alignas(int64_t) uint8_t offsets[sizeof(int32_t) * 2];
- };
-    static_assert(sizeof(UnionScratchSpace) <= sizeof(UnionScalar::scratch_space_));
- auto* union_scratch_space = reinterpret_cast<UnionScratchSpace*>(
- &checked_cast<const UnionScalar&>(value).scratch_space_);
-
// First buffer is kept null since unions have no validity vector
this->buffers[0] = {};
-    union_scratch_space->type_code = checked_cast<const UnionScalar&>(value).type_code;
-    this->buffers[1].data = reinterpret_cast<uint8_t*>(&union_scratch_space->type_code);
- this->buffers[1].size = 1;
-
this->child_data.resize(this->type->num_fields());
if (type_id == Type::DENSE_UNION) {
const auto& scalar = checked_cast<const DenseUnionScalar&>(value);
-      this->buffers[2] =
-          OffsetsForScalar(union_scratch_space->offsets, static_cast<int32_t>(1));
+      auto* union_scratch_space =
+          reinterpret_cast<UnionScalar::UnionScratchSpace*>(&scalar.scratch_space_);
+
+      this->buffers[1].data = reinterpret_cast<uint8_t*>(&union_scratch_space->type_code);
+      this->buffers[1].size = 1;
+
+      this->buffers[2] = OffsetsForScalar(union_scratch_space->offsets, sizeof(int32_t));
// We can't "see" the other arrays in the union, but we put the "active"
// union array in the right place and fill zero-length arrays for the
// others
@@ -517,6 +502,12 @@ void ArraySpan::FillFromScalar(const Scalar& value) {
}
} else {
const auto& scalar = checked_cast<const SparseUnionScalar&>(value);
+      auto* union_scratch_space =
+          reinterpret_cast<UnionScalar::UnionScratchSpace*>(&scalar.scratch_space_);
+
+      this->buffers[1].data = reinterpret_cast<uint8_t*>(&union_scratch_space->type_code);
+      this->buffers[1].size = 1;
+
// Sparse union scalars have a full complement of child values even
// though only one of them is relevant, so we just fill them in here
for (int i = 0; i < static_cast<int>(this->child_data.size()); ++i) {
@@ -541,7 +532,6 @@ void ArraySpan::FillFromScalar(const Scalar& value) {
e.null_count = 0;
e.buffers[1].data = scalar.scratch_space_;
e.buffers[1].size = sizeof(run_end);
- reinterpret_cast<decltype(run_end)*>(scalar.scratch_space_)[0] = run_end;
};
switch (scalar.run_end_type()->id()) {
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index 72b29057b8..097ee1de45 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -369,43 +369,6 @@ struct UnboxScalar<Decimal256Type> {
}
};
-template <typename Type, typename Enable = void>
-struct BoxScalar;
-
-template <typename Type>
-struct BoxScalar<Type, enable_if_has_c_type<Type>> {
- using T = typename GetOutputType<Type>::T;
- static void Box(T val, Scalar* out) {
- // Enables BoxScalar<Int64Type> to work on a (for example) Time64Scalar
-    T* mutable_data = reinterpret_cast<T*>(
-        checked_cast<::arrow::internal::PrimitiveScalarBase*>(out)->mutable_data());
- *mutable_data = val;
- }
-};
-
-template <typename Type>
-struct BoxScalar<Type, enable_if_base_binary<Type>> {
- using T = typename GetOutputType<Type>::T;
- using ScalarType = typename TypeTraits<Type>::ScalarType;
- static void Box(T val, Scalar* out) {
- checked_cast<ScalarType*>(out)->value = std::make_shared<Buffer>(val);
- }
-};
-
-template <>
-struct BoxScalar<Decimal128Type> {
- using T = Decimal128;
- using ScalarType = Decimal128Scalar;
-  static void Box(T val, Scalar* out) { checked_cast<ScalarType*>(out)->value = val; }
-};
-
-template <>
-struct BoxScalar<Decimal256Type> {
- using T = Decimal256;
- using ScalarType = Decimal256Scalar;
-  static void Box(T val, Scalar* out) { checked_cast<ScalarType*>(out)->value = val; }
-};
-
// A VisitArraySpanInline variant that calls its visitor function with logical
// values, such as Decimal128 rather than std::string_view.
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc
index daf8ed76d6..9b2fd987d8 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc
@@ -491,8 +491,9 @@ template <typename OutType, typename Op>
struct ScalarMinMax {
using OutValue = typename GetOutputType<OutType>::T;
-  static void ExecScalar(const ExecSpan& batch,
-                         const ElementWiseAggregateOptions& options, Scalar* out) {
+ static Result<std::shared_ptr<Scalar>> ExecScalar(
+ const ExecSpan& batch, const ElementWiseAggregateOptions& options,
+ std::shared_ptr<DataType> type) {
// All arguments are scalar
OutValue value{};
bool valid = false;
@@ -502,8 +503,8 @@ struct ScalarMinMax {
const Scalar& scalar = *arg.scalar;
if (!scalar.is_valid) {
if (options.skip_nulls) continue;
- out->is_valid = false;
- return;
+ valid = false;
+ break;
}
if (!valid) {
value = UnboxScalar<OutType>::Unbox(scalar);
@@ -513,9 +514,10 @@ struct ScalarMinMax {
value, UnboxScalar<OutType>::Unbox(scalar));
}
}
- out->is_valid = valid;
if (valid) {
- BoxScalar<OutType>::Box(value, out);
+ return MakeScalar(std::move(type), std::move(value));
+ } else {
+ return MakeNullScalar(std::move(type));
}
}
@@ -537,8 +539,7 @@ struct ScalarMinMax {
bool initialize_output = true;
if (scalar_count > 0) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> temp_scalar,
- MakeScalar(out->type()->GetSharedPtr(), 0));
- ExecScalar(batch, options, temp_scalar.get());
+                          ExecScalar(batch, options, out->type()->GetSharedPtr()));
if (temp_scalar->is_valid) {
const auto value = UnboxScalar<OutType>::Unbox(*temp_scalar);
initialize_output = false;
diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc
index 6996b46c8b..8e8d390366 100644
--- a/cpp/src/arrow/scalar.cc
+++ b/cpp/src/arrow/scalar.cc
@@ -542,6 +542,12 @@ struct ScalarValidateImpl {
}
};
+template <typename T, size_t N>
+void FillScalarScratchSpace(void* scratch_space, T const (&arr)[N]) {
+ static_assert(sizeof(arr) <= internal::kScalarScratchSpaceSize);
+ std::memcpy(scratch_space, arr, sizeof(arr));
+}
+
} // namespace
size_t Scalar::hash() const { return ScalarHashImpl(*this).hash_; }
@@ -557,6 +563,28 @@ Status Scalar::ValidateFull() const {
BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr<DataType> type)
: BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {}
+void BinaryScalar::FillScratchSpace() {
+ FillScalarScratchSpace(
+ scratch_space_,
+ {int32_t(0), value ? static_cast<int32_t>(value->size()) : int32_t(0)});
+}
+
+void BinaryViewScalar::FillScratchSpace() {
+  static_assert(sizeof(BinaryViewType::c_type) <= internal::kScalarScratchSpaceSize);
+ auto* view = new (&scratch_space_) BinaryViewType::c_type;
+ if (value) {
+ *view = util::ToBinaryView(std::string_view{*value}, 0, 0);
+ } else {
+ *view = {};
+ }
+}
+
+void LargeBinaryScalar::FillScratchSpace() {
+ FillScalarScratchSpace(
+ scratch_space_,
+ {int64_t(0), value ? static_cast<int64_t>(value->size()) : int64_t(0)});
+}
+
FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::shared_ptr<Buffer> value,
std::shared_ptr<DataType> type,
bool is_valid)
@@ -578,21 +606,45 @@ FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::string s, bool is_valid)
BaseListScalar::BaseListScalar(std::shared_ptr<Array> value,
std::shared_ptr<DataType> type, bool is_valid)
: Scalar{std::move(type), is_valid}, value(std::move(value)) {
- ARROW_CHECK(this->type->field(0)->type()->Equals(this->value->type()));
+ if (this->value) {
+ ARROW_CHECK(this->type->field(0)->type()->Equals(this->value->type()));
+ }
}
ListScalar::ListScalar(std::shared_ptr<Array> value, bool is_valid)
: BaseListScalar(value, list(value->type()), is_valid) {}
+void ListScalar::FillScratchSpace() {
+ FillScalarScratchSpace(
+ scratch_space_,
+      {int32_t(0), value ? static_cast<int32_t>(value->length()) : int32_t(0)});
+}
+
LargeListScalar::LargeListScalar(std::shared_ptr<Array> value, bool is_valid)
: BaseListScalar(value, large_list(value->type()), is_valid) {}
+void LargeListScalar::FillScratchSpace() {
+ FillScalarScratchSpace(scratch_space_,
+ {int64_t(0), value ? value->length() : int64_t(0)});
+}
+
ListViewScalar::ListViewScalar(std::shared_ptr<Array> value, bool is_valid)
: BaseListScalar(value, list_view(value->type()), is_valid) {}
+void ListViewScalar::FillScratchSpace() {
+ FillScalarScratchSpace(
+ scratch_space_,
+      {int32_t(0), value ? static_cast<int32_t>(value->length()) : int32_t(0)});
+}
+
LargeListViewScalar::LargeListViewScalar(std::shared_ptr<Array> value, bool is_valid)
: BaseListScalar(value, large_list_view(value->type()), is_valid) {}
+void LargeListViewScalar::FillScratchSpace() {
+ FillScalarScratchSpace(scratch_space_,
+ {int64_t(0), value ? value->length() : int64_t(0)});
+}
+
inline std::shared_ptr<DataType> MakeMapType(const std::shared_ptr<DataType>& pair_type) {
ARROW_CHECK_EQ(pair_type->id(), Type::STRUCT);
ARROW_CHECK_EQ(pair_type->num_fields(), 2);
@@ -602,11 +654,19 @@ inline std::shared_ptr<DataType> MakeMapType(const std::shared_ptr<DataType>& pa
MapScalar::MapScalar(std::shared_ptr<Array> value, bool is_valid)
: BaseListScalar(value, MakeMapType(value->type()), is_valid) {}
+void MapScalar::FillScratchSpace() {
+ FillScalarScratchSpace(
+ scratch_space_,
+      {int32_t(0), value ? static_cast<int32_t>(value->length()) : int32_t(0)});
+}
+
FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr<Array> value,
                                         std::shared_ptr<DataType> type, bool is_valid)
- : BaseListScalar(value, std::move(type), is_valid) {
-  ARROW_CHECK_EQ(this->value->length(),
-                 checked_cast<const FixedSizeListType&>(*this->type).list_size());
+    : BaseListScalar(std::move(value), std::move(type), is_valid) {
+  if (this->value) {
+    ARROW_CHECK_EQ(this->value->length(),
+                   checked_cast<const FixedSizeListType&>(*this->type).list_size());
+  }
}
FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr<Array> value, bool is_valid)
@@ -656,6 +716,21 @@ RunEndEncodedScalar::RunEndEncodedScalar(const std::shared_ptr<DataType>& type)
RunEndEncodedScalar::~RunEndEncodedScalar() = default;
+void RunEndEncodedScalar::FillScratchSpace() {
+ auto run_end = run_end_type()->id();
+ switch (run_end) {
+ case Type::INT16:
+ FillScalarScratchSpace(scratch_space_, {int16_t(1)});
+ break;
+ case Type::INT32:
+ FillScalarScratchSpace(scratch_space_, {int32_t(1)});
+ break;
+ default:
+ DCHECK_EQ(run_end, Type::INT64);
+ FillScalarScratchSpace(scratch_space_, {int64_t(1)});
+ }
+}
+
DictionaryScalar::DictionaryScalar(std::shared_ptr<DataType> type)
: internal::PrimitiveScalarBase(std::move(type)),
      value{MakeNullScalar(checked_cast<const DictionaryType&>(*this->type).index_type()),
@@ -732,11 +807,14 @@ SparseUnionScalar::SparseUnionScalar(ValueType value, int8_t type_code,
std::shared_ptr<DataType> type)
: UnionScalar(std::move(type), type_code, /*is_valid=*/true),
value(std::move(value)) {
- this->child_id =
- checked_cast<const SparseUnionType&>(*this->type).child_ids()[type_code];
+  const auto child_ids = checked_cast<const SparseUnionType&>(*this->type).child_ids();
+ if (type_code >= 0 && static_cast<size_t>(type_code) < child_ids.size() &&
+ child_ids[type_code] != UnionType::kInvalidChildId) {
+ this->child_id = child_ids[type_code];
- // Fix nullness based on whether the selected child is null
- this->is_valid = this->value[this->child_id]->is_valid;
+ // Fix nullness based on whether the selected child is null
+ this->is_valid = this->value[this->child_id]->is_valid;
+ }
}
std::shared_ptr<Scalar> SparseUnionScalar::FromValue(std::shared_ptr<Scalar> value,
@@ -755,6 +833,17 @@ std::shared_ptr<Scalar> SparseUnionScalar::FromValue(std::shared_ptr<Scalar> val
return std::make_shared<SparseUnionScalar>(field_values, type_code,
std::move(type));
}
+void SparseUnionScalar::FillScratchSpace() {
+  auto* union_scratch_space = reinterpret_cast<UnionScratchSpace*>(&scratch_space_);
+  union_scratch_space->type_code = type_code;
+}
+
+void DenseUnionScalar::FillScratchSpace() {
+  auto* union_scratch_space = reinterpret_cast<UnionScratchSpace*>(&scratch_space_);
+  union_scratch_space->type_code = type_code;
+  FillScalarScratchSpace(union_scratch_space->offsets, {int32_t(0), int32_t(1)});
+}
+}
+
namespace {
template <typename T>
@@ -969,58 +1058,72 @@ std::shared_ptr<Buffer> FormatToBuffer(Formatter&& formatter, const ScalarType&
}
// error fallback
-Status CastImpl(const Scalar& from, Scalar* to) {
+template <typename To>
+Result<std::shared_ptr<Scalar>> CastImpl(const Scalar& from,
+ std::shared_ptr<DataType> to_type) {
  return Status::NotImplemented("casting scalars of type ", *from.type, " to type ",
-                                *to->type);
+                                *to_type);
}
// numeric to numeric
-template <typename From, typename To>
-Status CastImpl(const NumericScalar<From>& from, NumericScalar<To>* to) {
- to->value = static_cast<typename To::c_type>(from.value);
- return Status::OK();
+template <typename To, typename From>
+enable_if_number<To, Result<std::shared_ptr<Scalar>>> CastImpl(
+ const NumericScalar<From>& from, std::shared_ptr<DataType> to_type) {
+ using ToScalar = typename TypeTraits<To>::ScalarType;
+  return std::make_shared<ToScalar>(static_cast<typename To::c_type>(from.value),
+                                    std::move(to_type));
}
// numeric to boolean
-template <typename T>
-Status CastImpl(const NumericScalar<T>& from, BooleanScalar* to) {
- constexpr auto zero = static_cast<typename T::c_type>(0);
- to->value = from.value != zero;
- return Status::OK();
+template <typename To, typename From>
+enable_if_boolean<To, Result<std::shared_ptr<Scalar>>> CastImpl(
+ const NumericScalar<From>& from, std::shared_ptr<DataType> to_type) {
+ constexpr auto zero = static_cast<typename From::c_type>(0);
+  return std::make_shared<BooleanScalar>(from.value != zero, std::move(to_type));
}
// boolean to numeric
-template <typename T>
-Status CastImpl(const BooleanScalar& from, NumericScalar<T>* to) {
- to->value = static_cast<typename T::c_type>(from.value);
- return Status::OK();
+template <typename To>
+enable_if_number<To, Result<std::shared_ptr<Scalar>>> CastImpl(
+ const BooleanScalar& from, std::shared_ptr<DataType> to_type) {
+ using ToScalar = typename TypeTraits<To>::ScalarType;
+  return std::make_shared<ToScalar>(static_cast<typename To::c_type>(from.value),
+                                    std::move(to_type));
}
// numeric to temporal
-template <typename From, typename To>
+template <typename To, typename From>
typename std::enable_if<std::is_base_of<TemporalType, To>::value &&
!std::is_same<DayTimeIntervalType, To>::value &&
!std::is_same<MonthDayNanoIntervalType, To>::value,
- Status>::type
-CastImpl(const NumericScalar<From>& from, TemporalScalar<To>* to) {
- to->value = static_cast<typename To::c_type>(from.value);
- return Status::OK();
+ Result<std::shared_ptr<Scalar>>>::type
+CastImpl(const NumericScalar<From>& from, std::shared_ptr<DataType> to_type) {
+ using ToScalar = typename TypeTraits<To>::ScalarType;
+  return std::make_shared<ToScalar>(static_cast<typename To::c_type>(from.value),
+                                    std::move(to_type));
}
// temporal to numeric
-template <typename From, typename To>
-typename std::enable_if<std::is_base_of<TemporalType, From>::value &&
+template <typename To, typename From>
+typename std::enable_if<is_number_type<To>::value &&
+ std::is_base_of<TemporalType, From>::value &&
!std::is_same<DayTimeIntervalType, From>::value &&
                        !std::is_same<MonthDayNanoIntervalType, From>::value,
- Status>::type
-CastImpl(const TemporalScalar<From>& from, NumericScalar<To>* to) {
- to->value = static_cast<typename To::c_type>(from.value);
- return Status::OK();
+ Result<std::shared_ptr<Scalar>>>::type
+CastImpl(const TemporalScalar<From>& from, std::shared_ptr<DataType> to_type) {
+ using ToScalar = typename TypeTraits<To>::ScalarType;
+  return std::make_shared<ToScalar>(static_cast<typename To::c_type>(from.value),
+                                    std::move(to_type));
}
// timestamp to timestamp
-Status CastImpl(const TimestampScalar& from, TimestampScalar* to) {
-  return util::ConvertTimestampValue(from.type, to->type, from.value).Value(&to->value);
+template <typename To>
+enable_if_timestamp<To, Result<std::shared_ptr<Scalar>>> CastImpl(
+ const TimestampScalar& from, std::shared_ptr<DataType> to_type) {
+ using ToScalar = typename TypeTraits<To>::ScalarType;
+  ARROW_ASSIGN_OR_RAISE(auto value,
+                        util::ConvertTimestampValue(from.type, to_type, from.value));
+ return std::make_shared<ToScalar>(value, std::move(to_type));
}
template <typename TypeWithTimeUnit>
@@ -1029,101 +1132,117 @@ std::shared_ptr<DataType> AsTimestampType(const std::shared_ptr<DataType>& type)
}
// duration to duration
-Status CastImpl(const DurationScalar& from, DurationScalar* to) {
- return util::ConvertTimestampValue(AsTimestampType<DurationType>(from.type),
-                                     AsTimestampType<DurationType>(to->type), from.value)
- .Value(&to->value);
+template <typename To>
+enable_if_duration<To, Result<std::shared_ptr<Scalar>>> CastImpl(
+ const DurationScalar& from, std::shared_ptr<DataType> to_type) {
+ using ToScalar = typename TypeTraits<To>::ScalarType;
+ ARROW_ASSIGN_OR_RAISE(
+ auto value,
+      util::ConvertTimestampValue(AsTimestampType<DurationType>(from.type),
+                                  AsTimestampType<DurationType>(to_type), from.value));
+ return std::make_shared<ToScalar>(value, std::move(to_type));
}
// time to time
-template <typename F, typename ToScalar, typename T = typename ToScalar::TypeClass>
-enable_if_time<T, Status> CastImpl(const TimeScalar<F>& from, ToScalar* to) {
- return util::ConvertTimestampValue(AsTimestampType<F>(from.type),
- AsTimestampType<T>(to->type), from.value)
- .Value(&to->value);
+template <typename To, typename From, typename T = typename To::TypeClass>
+enable_if_time<To, Result<std::shared_ptr<Scalar>>> CastImpl(
+ const TimeScalar<From>& from, std::shared_ptr<DataType> to_type) {
+ using ToScalar = typename TypeTraits<To>::ScalarType;
+ ARROW_ASSIGN_OR_RAISE(
+      auto value, util::ConvertTimestampValue(AsTimestampType<From>(from.type),
+                                              AsTimestampType<To>(to_type), from.value));
+ return std::make_shared<ToScalar>(value, std::move(to_type));
}
constexpr int64_t kMillisecondsInDay = 86400000;
// date to date
-Status CastImpl(const Date32Scalar& from, Date64Scalar* to) {
- to->value = from.value * kMillisecondsInDay;
- return Status::OK();
+template <typename To>
+enable_if_t<std::is_same<To, Date64Scalar>::value, Result<std::shared_ptr<Scalar>>>
+CastImpl(const Date32Scalar& from, std::shared_ptr<DataType> to_type) {
+ return std::make_shared<Date64Scalar>(from.value * kMillisecondsInDay,
+ std::move(to_type));
}
-Status CastImpl(const Date64Scalar& from, Date32Scalar* to) {
- to->value = static_cast<int32_t>(from.value / kMillisecondsInDay);
- return Status::OK();
+template <typename To>
+enable_if_t<std::is_same<To, Date32Scalar>::value, Result<std::shared_ptr<Scalar>>>
+CastImpl(const Date64Scalar& from, std::shared_ptr<DataType> to_type) {
+ return std::make_shared<Date32Scalar>(
+      static_cast<int32_t>(from.value / kMillisecondsInDay), std::move(to_type));
}
// timestamp to date
-Status CastImpl(const TimestampScalar& from, Date64Scalar* to) {
+template <typename To>
+enable_if_t<std::is_same<To, Date64Scalar>::value, Result<std::shared_ptr<Scalar>>>
+CastImpl(const TimestampScalar& from, std::shared_ptr<DataType> to_type) {
ARROW_ASSIGN_OR_RAISE(
auto millis,
      util::ConvertTimestampValue(from.type, timestamp(TimeUnit::MILLI), from.value));
- to->value = millis - millis % kMillisecondsInDay;
- return Status::OK();
+ return std::make_shared<Date64Scalar>(millis - millis % kMillisecondsInDay,
+ std::move(to_type));
}
-Status CastImpl(const TimestampScalar& from, Date32Scalar* to) {
+template <typename To>
+enable_if_t<std::is_same<To, Date32Scalar>::value, Result<std::shared_ptr<Scalar>>>
+CastImpl(const TimestampScalar& from, std::shared_ptr<DataType> to_type) {
ARROW_ASSIGN_OR_RAISE(
auto millis,
      util::ConvertTimestampValue(from.type, timestamp(TimeUnit::MILLI), from.value));
- to->value = static_cast<int32_t>(millis / kMillisecondsInDay);
- return Status::OK();
+  return std::make_shared<Date32Scalar>(
+      static_cast<int32_t>(millis / kMillisecondsInDay), std::move(to_type));
}
// date to timestamp
-template <typename D>
-Status CastImpl(const DateScalar<D>& from, TimestampScalar* to) {
+template <typename To, typename From>
+enable_if_timestamp<Result<std::shared_ptr<To>>> CastImpl(
+ const DateScalar<From>& from, std::shared_ptr<DataType> to_type) {
+ using ToScalar = typename TypeTraits<To>::ScalarType;
int64_t millis = from.value;
- if (std::is_same<D, Date32Type>::value) {
+ if (std::is_same<From, Date32Type>::value) {
millis *= kMillisecondsInDay;
}
-  return util::ConvertTimestampValue(timestamp(TimeUnit::MILLI), to->type, millis)
-      .Value(&to->value);
+  ARROW_ASSIGN_OR_RAISE(auto value, util::ConvertTimestampValue(
+                                        timestamp(TimeUnit::MILLI), to_type, millis));
+  return std::make_shared<ToScalar>(value, std::move(to_type));
}
// string to any
-template <typename ScalarType>
-Status CastImpl(const StringScalar& from, ScalarType* to) {
-  ARROW_ASSIGN_OR_RAISE(auto out, Scalar::Parse(to->type, std::string_view(*from.value)));
- to->value = std::move(checked_cast<ScalarType&>(*out).value);
- return Status::OK();
+template <typename To>
+Result<std::shared_ptr<Scalar>> CastImpl(const StringScalar& from,
+ std::shared_ptr<DataType> to_type) {
+ using ToScalar = typename TypeTraits<To>::ScalarType;
+  ARROW_ASSIGN_OR_RAISE(auto out,
+                        Scalar::Parse(std::move(to_type), std::string_view(*from.value)));
+ DCHECK(checked_pointer_cast<ToScalar>(out) != nullptr);
+ return std::move(out);
}
// binary/large binary/large string to string
-template <typename ScalarType>
-enable_if_t<std::is_base_of_v<BaseBinaryScalar, ScalarType> &&
- !std::is_same<ScalarType, StringScalar>::value,
- Status>
-CastImpl(const ScalarType& from, StringScalar* to) {
- to->value = from.value;
- return Status::OK();
+template <typename To, typename From>
+enable_if_t<std::is_same<To, StringType>::value &&
+ std::is_base_of_v<BaseBinaryScalar, From> &&
+ !std::is_same<From, StringScalar>::value,
+ Result<std::shared_ptr<Scalar>>>
+CastImpl(const From& from, std::shared_ptr<DataType> to_type) {
+ return std::make_shared<StringScalar>(from.value, std::move(to_type));
}
// formattable to string
-template <typename ScalarType, typename T = typename ScalarType::TypeClass,
+template <typename To, typename From, typename T = typename From::TypeClass,
typename Formatter = internal::StringFormatter<T>,
// note: Value unused but necessary to trigger SFINAE if Formatter is
// undefined
typename Value = typename Formatter::value_type>
-Status CastImpl(const ScalarType& from, StringScalar* to) {
- to->value = FormatToBuffer(Formatter{from.type.get()}, from);
- return Status::OK();
-}
-
-Status CastImpl(const Decimal128Scalar& from, StringScalar* to) {
- auto from_type = checked_cast<const Decimal128Type*>(from.type.get());
- to->value = Buffer::FromString(from.value.ToString(from_type->scale()));
- return Status::OK();
-}
-
-Status CastImpl(const Decimal256Scalar& from, StringScalar* to) {
- auto from_type = checked_cast<const Decimal256Type*>(from.type.get());
- to->value = Buffer::FromString(from.value.ToString(from_type->scale()));
- return Status::OK();
+typename std::enable_if_t<std::is_same<To, StringType>::value,
+ Result<std::shared_ptr<Scalar>>>
+CastImpl(const From& from, std::shared_ptr<DataType> to_type) {
+  return std::make_shared<StringScalar>(FormatToBuffer(Formatter{from.type.get()}, from),
+                                        std::move(to_type));
}
-Status CastImpl(const StructScalar& from, StringScalar* to) {
+// struct to string
+template <typename To>
+typename std::enable_if_t<std::is_same<To, StringType>::value,
+ Result<std::shared_ptr<Scalar>>>
+CastImpl(const StructScalar& from, std::shared_ptr<DataType> to_type) {
std::stringstream ss;
ss << '{';
for (int i = 0; static_cast<size_t>(i) < from.value.size(); i++) {
@@ -1132,24 +1251,23 @@ Status CastImpl(const StructScalar& from, StringScalar* to) {
<< " = " << from.value[i]->ToString();
}
ss << '}';
- to->value = Buffer::FromString(ss.str());
- return Status::OK();
+  return std::make_shared<StringScalar>(Buffer::FromString(ss.str()), std::move(to_type));
}
// casts between variable-length and fixed-length list types
-template <typename ToScalar>
-enable_if_list_type<typename ToScalar::TypeClass, Status> CastImpl(
-    const BaseListScalar& from, ToScalar* to) {
-  if constexpr (sizeof(typename ToScalar::TypeClass::offset_type) < sizeof(int64_t)) {
-    if (from.value->length() >
-        std::numeric_limits<typename ToScalar::TypeClass::offset_type>::max()) {
+template <typename To, typename From>
+std::enable_if_t<is_list_type<To>::value && is_list_type<From>::value,
+                 Result<std::shared_ptr<Scalar>>>
+CastImpl(const From& from, std::shared_ptr<DataType> to_type) {
+  if constexpr (sizeof(typename To::offset_type) < sizeof(int64_t)) {
+    if (from.value->length() > std::numeric_limits<typename To::offset_type>::max()) {
       return Status::Invalid(from.type->ToString(), " too large to cast to ",
-                             to->type->ToString());
+                             to_type->ToString());
     }
   }
-  if constexpr (is_fixed_size_list_type<typename ToScalar::TypeClass>::value) {
-    const auto& fixed_size_list_type = checked_cast<const FixedSizeListType&>(*to->type);
+  if constexpr (is_fixed_size_list_type<To>::value) {
+    const auto& fixed_size_list_type = checked_cast<const FixedSizeListType&>(*to_type);
     if (from.value->length() != fixed_size_list_type.list_size()) {
       return Status::Invalid("Cannot cast ", from.type->ToString(), " of length ",
                              from.value->length(), " to fixed size list of length ",
@@ -1157,13 +1275,15 @@ enable_if_list_type<typename ToScalar::TypeClass, Status> CastImpl(
}
}
-  DCHECK_EQ(from.is_valid, to->is_valid);
-  to->value = from.value;
-  return Status::OK();
+  using ToScalar = typename TypeTraits<To>::ScalarType;
+  return std::make_shared<ToScalar>(from.value, std::move(to_type), from.is_valid);
}
// list based types (list, large list and map (fixed sized list too)) to string
-Status CastImpl(const BaseListScalar& from, StringScalar* to) {
+template <typename To>
+typename std::enable_if_t<std::is_same<To, StringType>::value,
+ Result<std::shared_ptr<Scalar>>>
+CastImpl(const BaseListScalar& from, std::shared_ptr<DataType> to_type) {
std::stringstream ss;
ss << from.type->ToString() << "[";
for (int64_t i = 0; i < from.value->length(); i++) {
@@ -1172,11 +1292,14 @@ Status CastImpl(const BaseListScalar& from, StringScalar* to) {
ss << value->ToString();
}
ss << ']';
- to->value = Buffer::FromString(ss.str());
- return Status::OK();
+  return std::make_shared<StringScalar>(Buffer::FromString(ss.str()), std::move(to_type));
}
-Status CastImpl(const UnionScalar& from, StringScalar* to) {
+// union types to string
+template <typename To>
+typename std::enable_if_t<std::is_same<To, StringType>::value,
+ Result<std::shared_ptr<Scalar>>>
+CastImpl(const UnionScalar& from, std::shared_ptr<DataType> to_type) {
const auto& union_ty = checked_cast<const UnionType&>(*from.type);
std::stringstream ss;
const Scalar* selected_value;
@@ -1188,8 +1311,7 @@ Status CastImpl(const UnionScalar& from, StringScalar* to) {
}
ss << "union{" <<
union_ty.field(union_ty.child_ids()[from.type_code])->ToString()
<< " = " << selected_value->ToString() << '}';
-  to->value = Buffer::FromString(ss.str());
-  return Status::OK();
+  return std::make_shared<StringScalar>(Buffer::FromString(ss.str()), std::move(to_type));
}
struct CastImplVisitor {
@@ -1199,59 +1321,49 @@ struct CastImplVisitor {
const Scalar& from_;
const std::shared_ptr<DataType>& to_type_;
- Scalar* out_;
+ std::shared_ptr<Scalar> out_ = nullptr;
};
template <typename ToType>
struct FromTypeVisitor : CastImplVisitor {
using ToScalar = typename TypeTraits<ToType>::ScalarType;
- FromTypeVisitor(const Scalar& from, const std::shared_ptr<DataType>& to_type,
- Scalar* out)
- : CastImplVisitor{from, to_type, out} {}
+ FromTypeVisitor(const Scalar& from, const std::shared_ptr<DataType>& to_type)
+ : CastImplVisitor{from, to_type} {}
template <typename FromType>
Status Visit(const FromType&) {
-    return CastImpl(checked_cast<const typename TypeTraits<FromType>::ScalarType&>(from_),
-                    checked_cast<ToScalar*>(out_));
+    ARROW_ASSIGN_OR_RAISE(
+        out_, CastImpl<ToType>(
+                  checked_cast<const typename TypeTraits<FromType>::ScalarType&>(from_),
+                  std::move(to_type_)));
+ return Status::OK();
}
// identity cast only for parameter free types
template <typename T1 = ToType>
typename std::enable_if_t<TypeTraits<T1>::is_parameter_free, Status> Visit(
const ToType&) {
-    checked_cast<ToScalar*>(out_)->value = checked_cast<const ToScalar&>(from_).value;
+    ARROW_ASSIGN_OR_RAISE(out_, MakeScalar(std::move(to_type_),
+                                           checked_cast<const ToScalar&>(from_).value));
return Status::OK();
}
- Status CastFromListLike(const BaseListType& base_list_type) {
- return CastImpl(checked_cast<const BaseListScalar&>(from_),
- checked_cast<ToScalar*>(out_));
- }
-
-  Status Visit(const ListType& list_type) { return CastFromListLike(list_type); }
-
- Status Visit(const LargeListType& large_list_type) {
- return CastFromListLike(large_list_type);
- }
-
- Status Visit(const FixedSizeListType& fixed_size_list_type) {
- return CastFromListLike(fixed_size_list_type);
- }
-
Status Visit(const NullType&) { return NotImplemented(); }
Status Visit(const DictionaryType&) { return NotImplemented(); }
Status Visit(const ExtensionType&) { return NotImplemented(); }
};
struct ToTypeVisitor : CastImplVisitor {
-  ToTypeVisitor(const Scalar& from, const std::shared_ptr<DataType>& to_type, Scalar* out)
-      : CastImplVisitor{from, to_type, out} {}
+ ToTypeVisitor(const Scalar& from, const std::shared_ptr<DataType>& to_type)
+ : CastImplVisitor{from, to_type} {}
template <typename ToType>
Status Visit(const ToType&) {
- FromTypeVisitor<ToType> unpack_from_type{from_, to_type_, out_};
- return VisitTypeInline(*from_.type, &unpack_from_type);
+ FromTypeVisitor<ToType> unpack_from_type{from_, to_type_};
+ ARROW_RETURN_NOT_OK(VisitTypeInline(*from_.type, &unpack_from_type));
+ out_ = std::move(unpack_from_type.out_);
+ return Status::OK();
}
Status Visit(const NullType&) {
@@ -1262,25 +1374,28 @@ struct ToTypeVisitor : CastImplVisitor {
}
Status Visit(const DictionaryType& dict_type) {
- auto& out = checked_cast<DictionaryScalar*>(out_)->value;
    ARROW_ASSIGN_OR_RAISE(auto cast_value, from_.CastTo(dict_type.value_type()));
- ARROW_ASSIGN_OR_RAISE(out.dictionary, MakeArrayFromScalar(*cast_value, 1));
- return Int32Scalar(0).CastTo(dict_type.index_type()).Value(&out.index);
+    ARROW_ASSIGN_OR_RAISE(auto dictionary, MakeArrayFromScalar(*cast_value, 1));
+    ARROW_ASSIGN_OR_RAISE(auto index, Int32Scalar(0).CastTo(dict_type.index_type()));
+ out_ = DictionaryScalar::Make(std::move(index), std::move(dictionary));
+ return Status::OK();
}
Status Visit(const ExtensionType&) { return NotImplemented(); }
+
+ Result<std::shared_ptr<Scalar>> Finish() && {
+ ARROW_RETURN_NOT_OK(VisitTypeInline(*to_type_, this));
+ return std::move(out_);
+ }
};
} // namespace
Result<std::shared_ptr<Scalar>> Scalar::CastTo(std::shared_ptr<DataType> to) const {
- std::shared_ptr<Scalar> out = MakeNullScalar(to);
if (is_valid) {
- out->is_valid = true;
- ToTypeVisitor unpack_to_type{*this, to, out.get()};
- RETURN_NOT_OK(VisitTypeInline(*to, &unpack_to_type));
+ return ToTypeVisitor{*this, std::move(to)}.Finish();
}
- return out;
+ return MakeNullScalar(std::move(to));
}
void PrintTo(const Scalar& scalar, std::ostream* os) { *os << scalar.ToString(); }
diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h
index 65c5ee4df0..a7ee6a417d 100644
--- a/cpp/src/arrow/scalar.h
+++ b/cpp/src/arrow/scalar.h
@@ -131,11 +131,19 @@ struct ARROW_EXPORT NullScalar : public Scalar {
namespace internal {
+constexpr auto kScalarScratchSpaceSize = sizeof(int64_t) * 2;
+
+template <typename Impl>
struct ARROW_EXPORT ArraySpanFillFromScalarScratchSpace {
// 16 bytes of scratch space to enable ArraySpan to be a view onto any
// Scalar- including binary scalars where we need to create a buffer
// that looks like two 32-bit or 64-bit offsets.
- alignas(int64_t) mutable uint8_t scratch_space_[sizeof(int64_t) * 2];
+ alignas(int64_t) mutable uint8_t scratch_space_[kScalarScratchSpaceSize];
+
+ private:
+  ArraySpanFillFromScalarScratchSpace() { static_cast<Impl*>(this)->FillScratchSpace(); }
+
+ friend Impl;
};
struct ARROW_EXPORT PrimitiveScalarBase : public Scalar {
@@ -145,8 +153,6 @@ struct ARROW_EXPORT PrimitiveScalarBase : public Scalar {
using Scalar::Scalar;
/// \brief Get a const pointer to the value of this scalar. May be null.
virtual const void* data() const = 0;
- /// \brief Get a mutable pointer to the value of this scalar. May be null.
- virtual void* mutable_data() = 0;
/// \brief Get an immutable view of the value of this scalar as bytes.
virtual std::string_view view() const = 0;
};
@@ -167,7 +173,6 @@ struct ARROW_EXPORT PrimitiveScalar : public PrimitiveScalarBase {
ValueType value{};
const void* data() const override { return &value; }
- void* mutable_data() override { return &value; }
std::string_view view() const override {
    return std::string_view(reinterpret_cast<const char*>(&value), sizeof(ValueType));
};
@@ -245,34 +250,38 @@ struct ARROW_EXPORT DoubleScalar : public NumericScalar<DoubleType> {
using NumericScalar<DoubleType>::NumericScalar;
};
-struct ARROW_EXPORT BaseBinaryScalar
- : public internal::PrimitiveScalarBase,
- private internal::ArraySpanFillFromScalarScratchSpace {
- using internal::PrimitiveScalarBase::PrimitiveScalarBase;
+struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase {
using ValueType = std::shared_ptr<Buffer>;
- std::shared_ptr<Buffer> value;
+  // The value is not supposed to be modified after construction, because subclasses
+  // have a scratch space whose content needs to be kept consistent with the value. It
+  // is also the responsibility of this class's users to ensure that the buffer is not
+  // written to accidentally.
+  const std::shared_ptr<Buffer> value = NULLPTR;
const void* data() const override {
return value ? reinterpret_cast<const void*>(value->data()) : NULLPTR;
}
- void* mutable_data() override {
- return value ? reinterpret_cast<void*>(value->mutable_data()) : NULLPTR;
- }
std::string_view view() const override {
return value ? std::string_view(*value) : std::string_view();
}
+ explicit BaseBinaryScalar(std::shared_ptr<DataType> type)
+ : internal::PrimitiveScalarBase(std::move(type)) {}
+
  BaseBinaryScalar(std::shared_ptr<Buffer> value, std::shared_ptr<DataType> type)
: internal::PrimitiveScalarBase{std::move(type), true},
value(std::move(value)) {}
- friend ArraySpan;
BaseBinaryScalar(std::string s, std::shared_ptr<DataType> type);
};
-struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar {
+struct ARROW_EXPORT BinaryScalar
+ : public BaseBinaryScalar,
+ private internal::ArraySpanFillFromScalarScratchSpace<BinaryScalar> {
using BaseBinaryScalar::BaseBinaryScalar;
using TypeClass = BinaryType;
+ using ArraySpanFillFromScalarScratchSpace =
+ internal::ArraySpanFillFromScalarScratchSpace<BinaryScalar>;
explicit BinaryScalar(std::shared_ptr<Buffer> value)
: BinaryScalar(std::move(value), binary()) {}
@@ -280,6 +289,12 @@ struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar {
   explicit BinaryScalar(std::string s) : BaseBinaryScalar(std::move(s), binary()) {}
BinaryScalar() : BinaryScalar(binary()) {}
+
+ private:
+ void FillScratchSpace();
+
+ friend ArraySpan;
+ friend ArraySpanFillFromScalarScratchSpace;
};
struct ARROW_EXPORT StringScalar : public BinaryScalar {
@@ -294,9 +309,13 @@ struct ARROW_EXPORT StringScalar : public BinaryScalar {
StringScalar() : StringScalar(utf8()) {}
};
-struct ARROW_EXPORT BinaryViewScalar : public BaseBinaryScalar {
+struct ARROW_EXPORT BinaryViewScalar
+ : public BaseBinaryScalar,
+ private internal::ArraySpanFillFromScalarScratchSpace<BinaryViewScalar> {
using BaseBinaryScalar::BaseBinaryScalar;
using TypeClass = BinaryViewType;
+ using ArraySpanFillFromScalarScratchSpace =
+ internal::ArraySpanFillFromScalarScratchSpace<BinaryViewScalar>;
explicit BinaryViewScalar(std::shared_ptr<Buffer> value)
: BinaryViewScalar(std::move(value), binary_view()) {}
@@ -307,6 +326,12 @@ struct ARROW_EXPORT BinaryViewScalar : public BaseBinaryScalar {
   BinaryViewScalar() : BinaryViewScalar(binary_view()) {}
   std::string_view view() const override { return std::string_view(*this->value); }
+
+ private:
+ void FillScratchSpace();
+
+ friend ArraySpan;
+ friend ArraySpanFillFromScalarScratchSpace;
};
struct ARROW_EXPORT StringViewScalar : public BinaryViewScalar {
@@ -322,9 +347,13 @@ struct ARROW_EXPORT StringViewScalar : public BinaryViewScalar {
StringViewScalar() : StringViewScalar(utf8_view()) {}
};
-struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar {
+struct ARROW_EXPORT LargeBinaryScalar
+    : public BaseBinaryScalar,
+      private internal::ArraySpanFillFromScalarScratchSpace<LargeBinaryScalar> {
using BaseBinaryScalar::BaseBinaryScalar;
using TypeClass = LargeBinaryType;
+ using ArraySpanFillFromScalarScratchSpace =
+ internal::ArraySpanFillFromScalarScratchSpace<LargeBinaryScalar>;
  LargeBinaryScalar(std::shared_ptr<Buffer> value, std::shared_ptr<DataType> type)
: BaseBinaryScalar(std::move(value), std::move(type)) {}
@@ -336,6 +365,12 @@ struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar {
: BaseBinaryScalar(std::move(s), large_binary()) {}
LargeBinaryScalar() : LargeBinaryScalar(large_binary()) {}
+
+ private:
+ void FillScratchSpace();
+
+ friend ArraySpan;
+ friend ArraySpanFillFromScalarScratchSpace;
};
struct ARROW_EXPORT LargeStringScalar : public LargeBinaryScalar {
@@ -482,10 +517,6 @@ struct ARROW_EXPORT DecimalScalar : public internal::PrimitiveScalarBase {
return reinterpret_cast<const void*>(value.native_endian_bytes());
}
- void* mutable_data() override {
- return reinterpret_cast<void*>(value.mutable_native_endian_bytes());
- }
-
std::string_view view() const override {
    return std::string_view(reinterpret_cast<const char*>(value.native_endian_bytes()),
                            ValueType::kByteWidth);
@@ -502,54 +533,102 @@ struct ARROW_EXPORT Decimal256Scalar : public DecimalScalar<Decimal256Type, Deci
using DecimalScalar::DecimalScalar;
};
-struct ARROW_EXPORT BaseListScalar
- : public Scalar,
- private internal::ArraySpanFillFromScalarScratchSpace {
- using Scalar::Scalar;
+struct ARROW_EXPORT BaseListScalar : public Scalar {
using ValueType = std::shared_ptr<Array>;
  BaseListScalar(std::shared_ptr<Array> value, std::shared_ptr<DataType> type,
                 bool is_valid = true);
- std::shared_ptr<Array> value;
-
- private:
- friend struct ArraySpan;
+  // The value is not supposed to be modified after construction, because subclasses
+  // have a scratch space whose content needs to be kept consistent with the value. It
+  // is also the responsibility of this class's users to ensure that the array is not
+  // modified accidentally.
+  const std::shared_ptr<Array> value;
};
-struct ARROW_EXPORT ListScalar : public BaseListScalar {
+struct ARROW_EXPORT ListScalar
+ : public BaseListScalar,
+ private internal::ArraySpanFillFromScalarScratchSpace<ListScalar> {
using TypeClass = ListType;
using BaseListScalar::BaseListScalar;
+ using ArraySpanFillFromScalarScratchSpace =
+ internal::ArraySpanFillFromScalarScratchSpace<ListScalar>;
explicit ListScalar(std::shared_ptr<Array> value, bool is_valid = true);
+
+ private:
+ void FillScratchSpace();
+
+ friend ArraySpan;
+ friend ArraySpanFillFromScalarScratchSpace;
};
-struct ARROW_EXPORT LargeListScalar : public BaseListScalar {
+struct ARROW_EXPORT LargeListScalar
+ : public BaseListScalar,
+ private internal::ArraySpanFillFromScalarScratchSpace<LargeListScalar> {
using TypeClass = LargeListType;
using BaseListScalar::BaseListScalar;
+ using ArraySpanFillFromScalarScratchSpace =
+ internal::ArraySpanFillFromScalarScratchSpace<LargeListScalar>;
explicit LargeListScalar(std::shared_ptr<Array> value, bool is_valid = true);
+
+ private:
+ void FillScratchSpace();
+
+ friend ArraySpan;
+ friend ArraySpanFillFromScalarScratchSpace;
};
-struct ARROW_EXPORT ListViewScalar : public BaseListScalar {
+struct ARROW_EXPORT ListViewScalar
+ : public BaseListScalar,
+ private internal::ArraySpanFillFromScalarScratchSpace<ListViewScalar> {
using TypeClass = ListViewType;
using BaseListScalar::BaseListScalar;
+ using ArraySpanFillFromScalarScratchSpace =
+ internal::ArraySpanFillFromScalarScratchSpace<ListViewScalar>;
explicit ListViewScalar(std::shared_ptr<Array> value, bool is_valid = true);
+
+ private:
+ void FillScratchSpace();
+
+ friend ArraySpan;
+ friend ArraySpanFillFromScalarScratchSpace;
};
-struct ARROW_EXPORT LargeListViewScalar : public BaseListScalar {
+struct ARROW_EXPORT LargeListViewScalar
+ : public BaseListScalar,
+      private internal::ArraySpanFillFromScalarScratchSpace<LargeListViewScalar> {
using TypeClass = LargeListViewType;
using BaseListScalar::BaseListScalar;
+ using ArraySpanFillFromScalarScratchSpace =
+ internal::ArraySpanFillFromScalarScratchSpace<LargeListViewScalar>;
  explicit LargeListViewScalar(std::shared_ptr<Array> value, bool is_valid = true);
+
+ private:
+ void FillScratchSpace();
+
+ friend ArraySpan;
+ friend ArraySpanFillFromScalarScratchSpace;
};
-struct ARROW_EXPORT MapScalar : public BaseListScalar {
+struct ARROW_EXPORT MapScalar
+ : public BaseListScalar,
+ private internal::ArraySpanFillFromScalarScratchSpace<MapScalar> {
using TypeClass = MapType;
using BaseListScalar::BaseListScalar;
+ using ArraySpanFillFromScalarScratchSpace =
+ internal::ArraySpanFillFromScalarScratchSpace<MapScalar>;
explicit MapScalar(std::shared_ptr<Array> value, bool is_valid = true);
+
+ private:
+ void FillScratchSpace();
+
+ friend ArraySpan;
+ friend ArraySpanFillFromScalarScratchSpace;
};
struct ARROW_EXPORT FixedSizeListScalar : public BaseListScalar {
@@ -576,9 +655,10 @@ struct ARROW_EXPORT StructScalar : public Scalar {
                          std::vector<std::string> field_names);
};
-struct ARROW_EXPORT UnionScalar : public Scalar,
-                                  private internal::ArraySpanFillFromScalarScratchSpace {
-  int8_t type_code;
+struct ARROW_EXPORT UnionScalar : public Scalar {
+  // The type code is not supposed to be modified after construction, because the
+  // scratch space's content needs to be kept consistent with it.
+  const int8_t type_code;
virtual const std::shared_ptr<Scalar>& child_value() const = 0;
@@ -586,17 +666,31 @@ struct ARROW_EXPORT UnionScalar : public Scalar,
UnionScalar(std::shared_ptr<DataType> type, int8_t type_code, bool is_valid)
: Scalar(std::move(type), is_valid), type_code(type_code) {}
- friend struct ArraySpan;
+ struct UnionScratchSpace {
+ alignas(int64_t) int8_t type_code;
+ alignas(int64_t) uint8_t offsets[sizeof(int32_t) * 2];
+ };
+  static_assert(sizeof(UnionScratchSpace) <= internal::kScalarScratchSpaceSize);
+
+ friend ArraySpan;
};
-struct ARROW_EXPORT SparseUnionScalar : public UnionScalar {
+struct ARROW_EXPORT SparseUnionScalar
+ : public UnionScalar,
+      private internal::ArraySpanFillFromScalarScratchSpace<SparseUnionScalar> {
using TypeClass = SparseUnionType;
+ using ArraySpanFillFromScalarScratchSpace =
+ internal::ArraySpanFillFromScalarScratchSpace<SparseUnionScalar>;
// Even though only one of the union values is relevant for this scalar, we
// nonetheless construct a vector of scalars, one per union value, to have
// enough data to reconstruct a valid ArraySpan of length 1 from this scalar
using ValueType = std::vector<std::shared_ptr<Scalar>>;
- ValueType value;
+  // The value is not supposed to be modified after construction, because the scratch
+  // space's content needs to be kept consistent with the value. It is also the
+  // responsibility of this class's users to ensure that the scalars in the vector are
+  // not modified accidentally.
+  const ValueType value;
// The value index corresponding to the active type code
int child_id;
@@ -611,30 +705,56 @@ struct ARROW_EXPORT SparseUnionScalar : public UnionScalar {
/// to construct a vector of scalars
  static std::shared_ptr<Scalar> FromValue(std::shared_ptr<Scalar> value, int field_index,
                                           std::shared_ptr<DataType> type);
+
+ private:
+ void FillScratchSpace();
+
+ friend ArraySpan;
+ friend ArraySpanFillFromScalarScratchSpace;
};
-struct ARROW_EXPORT DenseUnionScalar : public UnionScalar {
+struct ARROW_EXPORT DenseUnionScalar
+ : public UnionScalar,
+ private internal::ArraySpanFillFromScalarScratchSpace<DenseUnionScalar> {
using TypeClass = DenseUnionType;
+ using ArraySpanFillFromScalarScratchSpace =
+ internal::ArraySpanFillFromScalarScratchSpace<DenseUnionScalar>;
// For DenseUnionScalar, we can make a valid ArraySpan of length 1 from this
// scalar
using ValueType = std::shared_ptr<Scalar>;
- ValueType value;
+  // The value is not supposed to be modified after construction, because the scratch
+  // space's content needs to be kept consistent with the value. It is also the
+  // responsibility of this class's users to ensure that the wrapped scalar is not
+  // modified accidentally.
+  const ValueType value;
  const std::shared_ptr<Scalar>& child_value() const override { return this->value; }
  DenseUnionScalar(ValueType value, int8_t type_code, std::shared_ptr<DataType> type)
      : UnionScalar(std::move(type), type_code, value->is_valid),
        value(std::move(value)) {}
+
+ private:
+ void FillScratchSpace();
+
+ friend ArraySpan;
+ friend ArraySpanFillFromScalarScratchSpace;
};
struct ARROW_EXPORT RunEndEncodedScalar
: public Scalar,
- private internal::ArraySpanFillFromScalarScratchSpace {
+      private internal::ArraySpanFillFromScalarScratchSpace<RunEndEncodedScalar> {
using TypeClass = RunEndEncodedType;
using ValueType = std::shared_ptr<Scalar>;
+ using ArraySpanFillFromScalarScratchSpace =
+ internal::ArraySpanFillFromScalarScratchSpace<RunEndEncodedScalar>;
- ValueType value;
+  // The value is not supposed to be modified after construction, because the scratch
+  // space's content needs to be kept consistent with the value. It is also the
+  // responsibility of this class's users to ensure that the wrapped scalar is not
+  // modified accidentally.
+  const ValueType value;
  RunEndEncodedScalar(std::shared_ptr<Scalar> value, std::shared_ptr<DataType> type);
@@ -652,7 +772,10 @@ struct ARROW_EXPORT RunEndEncodedScalar
private:
  const TypeClass& ree_type() const { return internal::checked_cast<TypeClass&>(*type); }
+ void FillScratchSpace();
+
friend ArraySpan;
+ friend ArraySpanFillFromScalarScratchSpace;
};
/// \brief A Scalar value for DictionaryType
@@ -680,10 +803,6 @@ struct ARROW_EXPORT DictionaryScalar : public internal::PrimitiveScalarBase {
const void* data() const override {
    return internal::checked_cast<internal::PrimitiveScalarBase&>(*value.index).data();
}
- void* mutable_data() override {
- return internal::checked_cast<internal::PrimitiveScalarBase&>(*value.index)
- .mutable_data();
- }
std::string_view view() const override {
    return internal::checked_cast<const internal::PrimitiveScalarBase&>(*value.index)
.view();
diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc
index 09dfde3227..104a5697b5 100644
--- a/cpp/src/arrow/scalar_test.cc
+++ b/cpp/src/arrow/scalar_test.cc
@@ -95,6 +95,68 @@ TEST(TestNullScalar, ValidateErrors) {
AssertValidationFails(scalar);
}
+TEST(TestNullScalar, Cast) {
+ NullScalar scalar;
+ for (auto to_type : {
+ int8(),
+ float64(),
+ date32(),
+ time32(TimeUnit::SECOND),
+ timestamp(TimeUnit::SECOND),
+ duration(TimeUnit::SECOND),
+ utf8(),
+ large_binary(),
+ list(int32()),
+ struct_({field("f", int32())}),
+ map(utf8(), int32()),
+ decimal(12, 2),
+ list_view(int32()),
+ large_list(int32()),
+ dense_union({field("string", utf8()), field("number", uint64())}),
+ sparse_union({field("string", utf8()), field("number", uint64())}),
+ }) {
+    // The Cast() function doesn't support casting a null scalar; use
+    // Scalar::CastTo() instead.
+ ASSERT_OK_AND_ASSIGN(auto casted, scalar.CastTo(to_type));
+ ASSERT_EQ(casted->type->id(), to_type->id());
+ ASSERT_FALSE(casted->is_valid);
+ }
+}
+
+TEST(TestBooleanScalar, Cast) {
+ for (auto b : {true, false}) {
+ BooleanScalar scalar(b);
+ ARROW_SCOPED_TRACE("boolean value: ", scalar.ToString());
+
+ // Boolean type (identity cast).
+ {
+ ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, boolean()));
+      ASSERT_TRUE(casted.scalar()->Equals(scalar)) << casted.scalar()->ToString();
+ }
+
+ // Numeric types.
+ for (auto to_type : {
+ int8(),
+ uint16(),
+ int32(),
+ uint64(),
+ float32(),
+ float64(),
+ }) {
+ ARROW_SCOPED_TRACE("to type: ", to_type->ToString());
+ ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, to_type));
+ ASSERT_EQ(casted.scalar()->type->id(), to_type->id());
+ ASSERT_EQ(casted.scalar()->ToString(), std::to_string(b));
+ }
+
+ // String type.
+ {
+ ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, utf8()));
+ ASSERT_EQ(casted.scalar()->type->id(), utf8()->id());
+ ASSERT_EQ(casted.scalar()->ToString(), scalar.ToString());
+ }
+ }
+}
+
template <typename T>
class TestNumericScalar : public ::testing::Test {
public:
@@ -464,12 +526,23 @@ class TestDecimalScalar : public ::testing::Test {
::testing::HasSubstr("does not fit in
precision of"),
invalid.ValidateFull());
}
+
+ void TestCast() {
+ const auto ty = std::make_shared<T>(3, 2);
+ const auto pi = ScalarType(ValueType(314), ty);
+
+ ASSERT_OK_AND_ASSIGN(auto casted, Cast(pi, utf8()));
+ ASSERT_TRUE(casted.scalar()->Equals(StringScalar("3.14")))
+ << casted.scalar()->ToString();
+ }
};
TYPED_TEST_SUITE(TestDecimalScalar, DecimalArrowTypes);
TYPED_TEST(TestDecimalScalar, Basics) { this->TestBasics(); }
+TYPED_TEST(TestDecimalScalar, Cast) { this->TestCast(); }
+
TEST(TestBinaryScalar, Basics) {
std::string data = "test data";
auto buf = std::make_shared<Buffer>(data);
@@ -551,6 +624,14 @@ TEST(TestBinaryScalar, ValidateErrors) {
AssertValidationFails(*null_scalar);
}
+TEST(TestBinaryScalar, Cast) {
+ BinaryScalar scalar(Buffer::FromString("test data"));
+ ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, utf8()));
+ ASSERT_EQ(casted.scalar()->type->id(), utf8()->id());
+ AssertBufferEqual(*checked_cast<const StringScalar&>(*casted.scalar()).value,
+ *scalar.value);
+}
+
template <typename T>
class TestStringScalar : public ::testing::Test {
public:
@@ -580,19 +661,25 @@ class TestStringScalar : public ::testing::Test {
}
void TestValidateErrors() {
- // Inconsistent is_valid / value
- ScalarType scalar(Buffer::FromString("xxx"));
- scalar.is_valid = false;
- AssertValidationFails(scalar);
+ {
+ // Inconsistent is_valid / value
+ ScalarType scalar(Buffer::FromString("xxx"));
+ scalar.is_valid = false;
+ AssertValidationFails(scalar);
+ }
- auto null_scalar = MakeNullScalar(type_);
- null_scalar->is_valid = true;
- AssertValidationFails(*null_scalar);
+ {
+ auto null_scalar = MakeNullScalar(type_);
+ null_scalar->is_valid = true;
+ AssertValidationFails(*null_scalar);
+ }
- // Invalid UTF8
- scalar = ScalarType(Buffer::FromString("\xff"));
- ASSERT_OK(scalar.Validate());
- ASSERT_RAISES(Invalid, scalar.ValidateFull());
+ {
+ // Invalid UTF8
+ ScalarType scalar(Buffer::FromString("\xff"));
+ ASSERT_OK(scalar.Validate());
+ ASSERT_RAISES(Invalid, scalar.ValidateFull());
+ }
}
protected:
@@ -676,8 +763,16 @@ TEST(TestFixedSizeBinaryScalar, ValidateErrors) {
FixedSizeBinaryScalar scalar(buf, type);
ASSERT_OK(scalar.ValidateFull());
- scalar.value = SliceBuffer(buf, 1);
- AssertValidationFails(scalar);
+ ASSERT_RAISES(Invalid, MakeScalar(type, SliceBuffer(buf, 1)));
+}
+
+TEST(TestFixedSizeBinaryScalar, Cast) {
+ std::string data = "test data";
+ FixedSizeBinaryScalar scalar(data);
+ ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, utf8()));
+ ASSERT_EQ(casted.scalar()->type->id(), utf8()->id());
+ AssertBufferEqual(*checked_cast<const StringScalar&>(*casted.scalar()).value,
+ *scalar.value);
}
TEST(TestDateScalars, Basics) {
@@ -1136,24 +1231,25 @@ class TestListLikeScalar : public ::testing::Test {
}
void TestValidateErrors() {
- ScalarType scalar(value_);
- scalar.is_valid = false;
- ASSERT_OK(scalar.ValidateFull());
-
- // Value must be defined
- scalar = ScalarType(value_);
- scalar.value = nullptr;
- AssertValidationFails(scalar);
+ {
+ ScalarType scalar(value_);
+ scalar.is_valid = false;
+ ASSERT_OK(scalar.ValidateFull());
+ }
- // Inconsistent child type
- scalar = ScalarType(value_);
- scalar.value = ArrayFromJSON(int32(), "[1, 2, null]");
- AssertValidationFails(scalar);
+ {
+ // Value must be defined
+ ScalarType scalar(nullptr, type_);
+ scalar.is_valid = true;
+ AssertValidationFails(scalar);
+ }
- // Invalid UTF8 in child data
- scalar = ScalarType(ArrayFromJSON(utf8(), "[null, null, \"\xff\"]"));
- ASSERT_OK(scalar.Validate());
- ASSERT_RAISES(Invalid, scalar.ValidateFull());
+ {
+ // Invalid UTF8 in child data
+ ScalarType scalar(ArrayFromJSON(utf8(), "[null, null, \"\xff\"]"));
+ ASSERT_OK(scalar.Validate());
+ ASSERT_RAISES(Invalid, scalar.ValidateFull());
+ }
}
void TestHashing() {
@@ -1195,6 +1291,12 @@ class TestListLikeScalar : public ::testing::Test {
auto invalid_cast_type = fixed_size_list(value_->type(), 5);
CheckListCastError(scalar, invalid_cast_type);
+
+ // Cast() function doesn't support casting list-like to string, use Scalar::CastTo()
+ // instead.
+ ASSERT_OK_AND_ASSIGN(auto casted_str, scalar.CastTo(utf8()));
+ ASSERT_EQ(casted_str->type->id(), utf8()->id());
+ ASSERT_EQ(casted_str->ToString(), scalar.ToString());
}
protected:
@@ -1224,6 +1326,24 @@ TEST(TestFixedSizeListScalar, ValidateErrors) {
AssertValidationFails(scalar);
}
+TEST(TestFixedSizeListScalar, Cast) {
+ const auto ty = fixed_size_list(int16(), 3);
+ FixedSizeListScalar scalar(ArrayFromJSON(int16(), "[1, 2, 5]"), ty);
+
+ CheckListCast(scalar, list(int16()));
+ CheckListCast(scalar, large_list(int16()));
+ CheckListCast(scalar, fixed_size_list(int16(), 3));
+
+ auto invalid_cast_type = fixed_size_list(int16(), 4);
+ CheckListCastError(scalar, invalid_cast_type);
+
+ // Cast() function doesn't support casting list-like to string, use Scalar::CastTo()
+ // instead.
+ ASSERT_OK_AND_ASSIGN(auto casted_str, scalar.CastTo(utf8()));
+ ASSERT_EQ(casted_str->type->id(), utf8()->id());
+ ASSERT_EQ(casted_str->ToString(), scalar.ToString());
+}
+
TEST(TestMapScalar, Basics) {
auto value =
ArrayFromJSON(struct_({field("key", utf8(), false), field("value", int8())}),
@@ -1253,6 +1373,12 @@ TEST(TestMapScalar, Cast) {
auto invalid_cast_type = fixed_size_list(key_value_type, 5);
CheckListCastError(scalar, invalid_cast_type);
+
+ // Cast() function doesn't support casting map to string, use Scalar::CastTo() instead.
+ ASSERT_OK_AND_ASSIGN(auto casted_str, scalar.CastTo(utf8()));
+ ASSERT_TRUE(casted_str->Equals(StringScalar(
+ R"(map<string, int8>[{key:string = a, value:int8 = 1}, {key:string = b, value:int8 = 2}])")))
+ << casted_str->ToString();
}
TEST(TestStructScalar, FieldAccess) {
@@ -1345,6 +1471,16 @@ TEST(TestStructScalar, ValidateErrors) {
ASSERT_RAISES(Invalid, scalar.ValidateFull());
}
+TEST(TestStructScalar, Cast) {
+ auto ty = struct_({field("i", int32()), field("s", utf8())});
+ StructScalar scalar({MakeScalar(42), MakeScalar("xxx")}, ty);
+
+ // Cast() function doesn't support casting struct to string, use Scalar::CastTo() instead.
+ ASSERT_OK_AND_ASSIGN(auto casted_str, scalar.CastTo(utf8()));
+ ASSERT_TRUE(casted_str->Equals(StringScalar(R"({i:int32 = 42, s:string = xxx})")))
+ << casted_str->ToString();
+}
+
TEST(TestDictionaryScalar, Basics) {
for (auto index_ty : all_dictionary_index_types()) {
auto ty = dictionary(index_ty, utf8());
@@ -1534,17 +1670,41 @@ void CheckGetNullUnionScalar(const Array& arr, int64_t index) {
ASSERT_FALSE(checked_cast<const UnionScalar&>(*scalar).child_value()->is_valid);
}
+std::shared_ptr<Scalar> MakeUnionScalar(const SparseUnionType& type, int8_t type_code,
+ std::shared_ptr<Scalar> field_value,
+ int field_index) {
+ ScalarVector field_values;
+ for (int i = 0; i < type.num_fields(); ++i) {
+ if (i == field_index) {
+ field_values.emplace_back(std::move(field_value));
+ } else {
+ field_values.emplace_back(MakeNullScalar(type.field(i)->type()));
+ }
+ }
+ return std::make_shared<SparseUnionScalar>(std::move(field_values), type_code,
+ type.GetSharedPtr());
+}
+
std::shared_ptr<Scalar> MakeUnionScalar(const SparseUnionType& type,
std::shared_ptr<Scalar> field_value,
int field_index) {
- return SparseUnionScalar::FromValue(field_value, field_index, type.GetSharedPtr());
+ return SparseUnionScalar::FromValue(std::move(field_value), field_index,
+ type.GetSharedPtr());
+}
+
+std::shared_ptr<Scalar> MakeUnionScalar(const DenseUnionType& type, int8_t type_code,
+ std::shared_ptr<Scalar> field_value,
+ int field_index) {
+ return std::make_shared<DenseUnionScalar>(std::move(field_value), type_code,
+ type.GetSharedPtr());
}
std::shared_ptr<Scalar> MakeUnionScalar(const DenseUnionType& type,
std::shared_ptr<Scalar> field_value,
int field_index) {
int8_t type_code = type.type_codes()[field_index];
- return std::make_shared<DenseUnionScalar>(field_value, type_code, type.GetSharedPtr());
+ return std::make_shared<DenseUnionScalar>(std::move(field_value), type_code,
+ type.GetSharedPtr());
}
std::shared_ptr<Scalar> MakeSpecificNullScalar(const DenseUnionType& type,
@@ -1592,7 +1752,13 @@ class TestUnionScalar : public ::testing::Test {
std::shared_ptr<Scalar> ScalarFromValue(int field_index,
std::shared_ptr<Scalar> field_value) {
- return MakeUnionScalar(*union_type_, field_value, field_index);
+ return MakeUnionScalar(*union_type_, std::move(field_value), field_index);
+ }
+
+ std::shared_ptr<Scalar> ScalarFromTypeCodeAndValue(int8_t type_code,
+ std::shared_ptr<Scalar>
field_value,
+ int field_index) {
+ return MakeUnionScalar(*union_type_, type_code, std::move(field_value), field_index);
}
std::shared_ptr<Scalar> SpecificNull(int field_index) {
@@ -1610,40 +1776,48 @@ class TestUnionScalar : public ::testing::Test {
}
void TestValidateErrors() {
- // Type code doesn't exist
- auto scalar = ScalarFromValue(0, alpha_);
- UnionScalar* union_scalar = static_cast<UnionScalar*>(scalar.get());
-
- // Invalid type code
- union_scalar->type_code = 0;
- AssertValidationFails(*union_scalar);
+ {
+ // Invalid type code
+ auto scalar = ScalarFromTypeCodeAndValue(0, alpha_, 0);
+ AssertValidationFails(*scalar);
+ }
- union_scalar->is_valid = false;
- AssertValidationFails(*union_scalar);
+ {
+ auto scalar = ScalarFromTypeCodeAndValue(0, alpha_, 0);
+ scalar->is_valid = false;
+ AssertValidationFails(*scalar);
+ }
- union_scalar->type_code = -42;
- union_scalar->is_valid = true;
- AssertValidationFails(*union_scalar);
+ {
+ auto scalar = ScalarFromTypeCodeAndValue(-42, alpha_, 0);
+ AssertValidationFails(*scalar);
+ }
- union_scalar->is_valid = false;
- AssertValidationFails(*union_scalar);
+ {
+ auto scalar = ScalarFromTypeCodeAndValue(-42, alpha_, 0);
+ scalar->is_valid = false;
+ AssertValidationFails(*scalar);
+ }
// Type code doesn't correspond to child type
if (type_->id() == ::arrow::Type::DENSE_UNION) {
- union_scalar->type_code = 42;
- union_scalar->is_valid = true;
- AssertValidationFails(*union_scalar);
-
- scalar = ScalarFromValue(2, two_);
- union_scalar = static_cast<UnionScalar*>(scalar.get());
- union_scalar->type_code = 3;
- AssertValidationFails(*union_scalar);
+ {
+ auto scalar = ScalarFromTypeCodeAndValue(42, alpha_, 0);
+ AssertValidationFails(*scalar);
+ }
+
+ {
+ auto scalar = ScalarFromTypeCodeAndValue(3, two_, 2);
+ AssertValidationFails(*scalar);
+ }
}
- // underlying value has invalid UTF8
- scalar = ScalarFromValue(0, std::make_shared<StringScalar>("\xff"));
- ASSERT_OK(scalar->Validate());
- ASSERT_RAISES(Invalid, scalar->ValidateFull());
+ {
+ // underlying value has invalid UTF8
+ auto scalar = ScalarFromValue(0, std::make_shared<StringScalar>("\xff"));
+ ASSERT_OK(scalar->Validate());
+ ASSERT_RAISES(Invalid, scalar->ValidateFull());
+ }
}
void TestEquals() {
@@ -1680,6 +1854,14 @@ class TestUnionScalar : public ::testing::Test {
}
}
+ void TestCast() {
+ // Cast() function doesn't support casting union to string, use Scalar::CastTo()
+ // instead.
+ ASSERT_OK_AND_ASSIGN(auto casted, union_alpha_->CastTo(utf8()));
+ ASSERT_TRUE(casted->Equals(StringScalar(R"(union{string: string = alpha})")))
+ << casted->ToString();
+ }
+
protected:
std::shared_ptr<DataType> type_;
const UnionType* union_type_;
@@ -1698,6 +1880,8 @@ TYPED_TEST(TestUnionScalar, Equals) { this->TestEquals(); }
TYPED_TEST(TestUnionScalar, MakeNullScalar) { this->TestMakeNullScalar(); }
+TYPED_TEST(TestUnionScalar, Cast) { this->TestCast(); }
+
class TestSparseUnionScalar : public TestUnionScalar<SparseUnionType> {};
TEST_F(TestSparseUnionScalar, GetScalar) {
@@ -1974,14 +2158,14 @@ TEST_F(TestExtensionScalar, ValidateErrors) {
scalar.is_valid = false;
ASSERT_OK(scalar.ValidateFull());
- // Invalid storage scalar (wrong length)
- std::shared_ptr<Scalar> invalid_storage = MakeNullScalar(storage_type_);
- invalid_storage->is_valid = true;
- static_cast<FixedSizeBinaryScalar*>(invalid_storage.get())->value =
- std::make_shared<Buffer>("123");
- AssertValidationFails(*invalid_storage);
+ // Invalid storage scalar (invalid UTF8)
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Scalar> invalid_storage,
+ MakeScalar(utf8(), std::make_shared<Buffer>("\xff")));
+ ASSERT_OK(invalid_storage->Validate());
+ ASSERT_RAISES(Invalid, invalid_storage->ValidateFull());
scalar = ExtensionScalar(invalid_storage, type_);
- AssertValidationFails(scalar);
+ ASSERT_OK(scalar.Validate());
+ ASSERT_RAISES(Invalid, scalar.ValidateFull());
}
} // namespace arrow