This is an automated email from the ASF dual-hosted git repository.
felipecrv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 1f67c1a16a GH-43291: [C++] Expand the 'take' function tests to cover
more chunked-array cases (#43292)
1f67c1a16a is described below
commit 1f67c1a16a426d27a52d9aa31fc1b39602bad161
Author: Felipe Oliveira Carvalho <[email protected]>
AuthorDate: Thu Jul 25 13:22:51 2024 -0300
GH-43291: [C++] Expand the 'take' function tests to cover more
chunked-array cases (#43292)
### Rationale for this change
#41700 (in its current state) passes all the C++ tests even though it
contains a few bugs (caught by manual repro steps and by tests of the Ruby
bindings). The C++ tests should be able to catch these kinds of bugs and
exercise code paths beyond the TakeAAA (Take(Array, Array) -> Array) cases.
### What changes are included in this PR?
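- Explicitly naming which TakeXX variation is being checked in the tests
and assert helpers
- Using `AssertChunkedEqual` instead of `AssertChunkedEquivalent` (previously
reached via `AssertDatumsEqual`), so the chunk layout of results is asserted
and not just their contents
- Checking TakeCAC alongside TakeAAA for the array-based test cases, and
running the list tests as typed tests over list and list-view types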
### Are these changes tested?
Yes. The improved tests now catch these bugs.
* GitHub Issue: #43291
Authored-by: Felipe Oliveira Carvalho <[email protected]>
Signed-off-by: Felipe Oliveira Carvalho <[email protected]>
---
.../arrow/compute/kernels/vector_selection_test.cc | 1037 ++++++++++++--------
1 file changed, 608 insertions(+), 429 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc
b/cpp/src/arrow/compute/kernels/vector_selection_test.cc
index aba016d6b7..b38f3fcbd8 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_test.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc
@@ -28,6 +28,7 @@
#include "arrow/chunked_array.h"
#include "arrow/compute/api.h"
#include "arrow/compute/kernels/test_util.h"
+#include "arrow/scalar.h"
#include "arrow/table.h"
#include "arrow/testing/builder.h"
#include "arrow/testing/fixed_width_test_util.h"
@@ -1101,33 +1102,114 @@ TEST(TestFilterMetaFunction, ArityChecking) {
// ----------------------------------------------------------------------
// Take tests
+//
+// Shorthand notation (as defined in `TakeMetaFunction`):
+//
+// A = Array
+// C = ChunkedArray
+// R = RecordBatch
+// T = Table
+//
+// (e.g. TakeCAC = Take(ChunkedArray, Array) -> ChunkedArray)
+//
+// The interface implemented by `TakeMetaFunction` is:
+//
+// Take(A, A) -> A (TakeAAA)
+// Take(A, C) -> C (TakeACC)
+// Take(C, A) -> C (TakeCAC)
+// Take(C, C) -> C (TakeCCC)
+// Take(R, A) -> R (TakeRAR)
+// Take(T, A) -> T (TakeTAT)
+// Take(T, C) -> T (TakeTCT)
+//
+// The tests extend the notation with a few "union kinds":
+//
+// X = Array | ChunkedArray
+//
+// Examples:
+//
+// TakeXA = {TakeAAA, TakeCAC},
+// TakeXX = {TakeAAA, TakeACC, TakeCAC, TakeCCC}
+namespace {
-void AssertTakeArrays(const std::shared_ptr<Array>& values,
- const std::shared_ptr<Array>& indices,
- const std::shared_ptr<Array>& expected) {
- ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> actual, Take(*values, *indices));
- ValidateOutput(actual);
- AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+Result<std::shared_ptr<Array>> TakeAAA(const Array& values, const Array&
indices) {
+ ARROW_ASSIGN_OR_RAISE(Datum out, Take(Datum(values), Datum(indices)));
+ return out.make_array();
}
-Status TakeJSON(const std::shared_ptr<DataType>& type, const std::string&
values,
- const std::shared_ptr<DataType>& index_type, const
std::string& indices,
- std::shared_ptr<Array>* out) {
- return Take(*ArrayFromJSON(type, values), *ArrayFromJSON(index_type,
indices))
- .Value(out);
+Result<std::shared_ptr<Array>> TakeAAA(
+ const std::shared_ptr<DataType>& type, const std::string& values,
+ const std::string& indices, const std::shared_ptr<DataType>& index_type =
int32()) {
+ return TakeAAA(*ArrayFromJSON(type, values), *ArrayFromJSON(index_type,
indices));
}
-void DoCheckTake(const std::shared_ptr<Array>& values,
- const std::shared_ptr<Array>& indices,
- const std::shared_ptr<Array>& expected) {
- AssertTakeArrays(values, indices, expected);
+// TakeACC is never tested directly, so it is not defined here
+
+Result<Datum> TakeCAC(std::shared_ptr<ChunkedArray> values,
+ std::shared_ptr<Array> indices) {
+ return Take(Datum{std::move(values)}, Datum{std::move(indices)});
+}
+
+Result<Datum> TakeCAC(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values, const
std::string& indices,
+ const std::shared_ptr<DataType>& index_type = int8()) {
+ return TakeCAC(ChunkedArrayFromJSON(type, values), ArrayFromJSON(index_type,
indices));
+}
+
+Result<Datum> TakeCCC(std::shared_ptr<ChunkedArray> values,
+ std::shared_ptr<ChunkedArray> indices) {
+ return Take(Datum{std::move(values)}, Datum{std::move(indices)});
+}
+
+Result<Datum> TakeCCC(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values,
+ const std::vector<std::string>& indices) {
+ return TakeCCC(ChunkedArrayFromJSON(type, values),
+ ChunkedArrayFromJSON(int8(), indices));
+}
+
+Result<Datum> TakeRAR(const std::shared_ptr<Schema>& schm, const std::string&
batch_json,
+ const std::string& indices,
+ const std::shared_ptr<DataType>& index_type = int8()) {
+ auto batch = RecordBatchFromJSON(schm, batch_json);
+ return Take(Datum{std::move(batch)}, Datum{ArrayFromJSON(index_type,
indices)});
+}
+
+Result<Datum> TakeTAT(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& values, const
std::string& indices,
+ const std::shared_ptr<DataType>& index_type = int8()) {
+ return Take(Datum{TableFromJSON(schm, values)},
+ Datum{ArrayFromJSON(index_type, indices)});
+}
+
+Result<Datum> TakeTCT(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& values,
+ const std::vector<std::string>& indices) {
+ return Take(Datum{TableFromJSON(schm, values)},
+ Datum{ChunkedArrayFromJSON(int8(), indices)});
+}
+
+// Assert helpers for Take tests
+
+void DoAssertTakeAAA(const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& indices,
+ const std::shared_ptr<Array>& expected) {
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> actual, TakeAAA(*values,
*indices));
+ ValidateOutput(actual);
+ AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+}
+
+void DoCheckTakeAAA(const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& indices,
+ const std::shared_ptr<Array>& expected) {
+ DoAssertTakeAAA(values, indices, expected);
// Check sliced values
ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(values->type(), 2));
ASSERT_OK_AND_ASSIGN(auto values_sliced,
Concatenate({values_filler, values, values_filler}));
values_sliced = values_sliced->Slice(2, values->length());
- AssertTakeArrays(values_sliced, indices, expected);
+ DoAssertTakeAAA(values_sliced, indices, expected);
// Check sliced indices
ASSERT_OK_AND_ASSIGN(auto zero, MakeScalar(indices->type(), int8_t{0}));
@@ -1135,33 +1217,171 @@ void DoCheckTake(const std::shared_ptr<Array>& values,
ASSERT_OK_AND_ASSIGN(auto indices_sliced,
Concatenate({indices_filler, indices, indices_filler}));
indices_sliced = indices_sliced->Slice(3, indices->length());
- AssertTakeArrays(values, indices_sliced, expected);
-}
-
-void CheckTake(const std::shared_ptr<DataType>& type, const std::string&
values_json,
- const std::string& indices_json, const std::string&
expected_json) {
+ DoAssertTakeAAA(values, indices_sliced, expected);
+}
+
+void DoCheckTakeCACWithArrays(const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& indices,
+ const std::shared_ptr<Array>& expected) {
+ auto pool = default_memory_pool();
+ const bool indices_null_count_is_known = indices->null_count() !=
kUnknownNullCount;
+
+ // We check TakeCAC by checking this equality:
+ //
+ // TakeAAA(Concat(V, V, V), I') == Concat(TakeCAC([V, V, V], I'))
+ // where
+ // V = values
+ // I = indices
+ // I' = Concat(I + 2 * V.length, I, I + V.length)
+ auto values3 = ArrayVector{values, values, values};
+ ASSERT_OK_AND_ASSIGN(auto concat_values3, Concatenate(values3, pool));
+ auto chunked_values3 = std::make_shared<ChunkedArray>(values3);
+ std::shared_ptr<Array> concat_indices3;
+ {
+ auto double_length =
+ MakeScalar(indices->type(), static_cast<int>(2 * values->length()));
+ auto zero = MakeScalar(indices->type(), 0);
+ auto length = MakeScalar(indices->type(),
static_cast<int>(values->length()));
+ ASSERT_OK_AND_ASSIGN(auto indices_prefix, Add(indices, *double_length));
+ ASSERT_OK_AND_ASSIGN(auto indices_middle, Add(indices, *zero));
+ ASSERT_OK_AND_ASSIGN(auto indices_suffix, Add(indices, *length));
+ auto indices3 = ArrayVector{
+ indices_prefix.make_array(),
+ indices_middle.make_array(),
+ indices_suffix.make_array(),
+ };
+ ASSERT_OK_AND_ASSIGN(concat_indices3, Concatenate(indices3, pool));
+ // Preserve the fact that indices->null_count() is unknown if it is
unknown.
+ if (!indices_null_count_is_known) {
+ concat_indices3->data()->null_count = kUnknownNullCount;
+ }
+ }
+ ASSERT_OK_AND_ASSIGN(auto concat_expected3,
+ Concatenate({expected, expected, expected}));
+ ASSERT_OK_AND_ASSIGN(Datum chunked_actual, TakeCAC(chunked_values3,
concat_indices3));
+ ValidateOutput(chunked_actual);
+ ASSERT_OK_AND_ASSIGN(auto concat_actual,
+ Concatenate(chunked_actual.chunked_array()->chunks()));
+ AssertArraysEqual(*concat_expected3, *concat_actual, /*verbose=*/true);
+
+ // We check TakeCAC again by checking this equality:
+ //
+ // TakeAAA(V, I) == Concat(TakeCAC(C, I))
+ // where
+ // K = V.length // 4
+ // C = [V.slice(0, K), V.slice(K, 2*K), V.slice(3*K, N - 3*K)]
+ // V = values
+ // I = indices
+ const int64_t n = values->length();
+ const int64_t k = n / 4;
+ if (k > 0) {
+ auto value_slices = ArrayVector{values->Slice(0, k), values->Slice(k, 2 *
k),
+ values->Slice(3 * k, n - k)};
+ auto chunked_values = std::make_shared<ChunkedArray>(value_slices);
+ ASSERT_OK_AND_ASSIGN(chunked_actual, TakeCAC(chunked_values, indices));
+ ValidateOutput(chunked_actual);
+ ASSERT_OK_AND_ASSIGN(concat_actual,
+
Concatenate(chunked_actual.chunked_array()->chunks()));
+ AssertArraysEqual(*concat_actual, *expected, /*verbose=*/true);
+ }
+}
+
+// TakeXA = {TakeAAA, TakeCAC}
+void DoCheckTakeXA(const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& indices,
+ const std::shared_ptr<Array>& expected) {
+ DoCheckTakeAAA(values, indices, expected);
+ DoCheckTakeCACWithArrays(values, indices, expected);
+}
+
+// TakeXA = {TakeAAA, TakeCAC}
+void CheckTakeXA(const std::shared_ptr<DataType>& type, const std::string&
values_json,
+ const std::string& indices_json, const std::string&
expected_json) {
auto values = ArrayFromJSON(type, values_json);
auto expected = ArrayFromJSON(type, expected_json);
for (auto index_type : {int8(), uint32()}) {
auto indices = ArrayFromJSON(index_type, indices_json);
- DoCheckTake(values, indices, expected);
+ DoCheckTakeXA(values, indices, expected);
}
}
-void AssertTakeNull(const std::string& values, const std::string& indices,
- const std::string& expected) {
- CheckTake(null(), values, indices, expected);
+void CheckTakeXADictionary(std::shared_ptr<DataType> value_type,
+ const std::string& dictionary_values,
+ const std::string& dictionary_indices,
+ const std::string& indices,
+ const std::string& expected_indices) {
+ auto dict = ArrayFromJSON(value_type, dictionary_values);
+ auto type = dictionary(int8(), value_type);
+ ASSERT_OK_AND_ASSIGN(
+ auto values,
+ DictionaryArray::FromArrays(type, ArrayFromJSON(int8(),
dictionary_indices), dict));
+ ASSERT_OK_AND_ASSIGN(
+ auto expected,
+ DictionaryArray::FromArrays(type, ArrayFromJSON(int8(),
expected_indices), dict));
+ auto take_indices = ArrayFromJSON(int8(), indices);
+ DoCheckTakeXA(values, take_indices, expected);
}
-void AssertTakeBoolean(const std::string& values, const std::string& indices,
- const std::string& expected) {
- CheckTake(boolean(), values, indices, expected);
+void AssertTakeCAC(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values, const std::string&
indices,
+ const std::vector<std::string>& expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, TakeCAC(type, values, indices));
+ ValidateOutput(actual);
+ AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected),
*actual.chunked_array());
}
+void AssertTakeCCC(const std::shared_ptr<DataType>& type,
+ const std::vector<std::string>& values,
+ const std::vector<std::string>& indices,
+ const std::vector<std::string>& expected) {
+ ASSERT_OK_AND_ASSIGN(auto actual, TakeCCC(type, values, indices));
+ ValidateOutput(actual);
+ AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected),
*actual.chunked_array());
+}
+
+void CheckTakeXCC(const Datum& values, const std::vector<std::string>& indices,
+ const std::vector<std::string>& expected) {
+ EXPECT_TRUE(values.is_array() || values.is_chunked_array());
+ auto idx = ChunkedArrayFromJSON(int32(), indices);
+ ASSERT_OK_AND_ASSIGN(auto actual, Take(values, Datum{idx}));
+ ValidateOutput(actual);
+ AssertChunkedEqual(*ChunkedArrayFromJSON(values.type(), expected),
+ *actual.chunked_array());
+}
+
+void AssertTakeRAR(const std::shared_ptr<Schema>& schm, const std::string&
batch_json,
+ const std::string& indices, const std::string&
expected_batch) {
+ for (auto index_type : {int8(), uint32()}) {
+ ASSERT_OK_AND_ASSIGN(auto actual, TakeRAR(schm, batch_json, indices,
index_type));
+ ValidateOutput(actual);
+ ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch),
+ *actual.record_batch());
+ }
+}
+
+void AssertTakeTAT(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& table_json, const
std::string& filter,
+ const std::vector<std::string>& expected_table) {
+ ASSERT_OK_AND_ASSIGN(auto actual, TakeTAT(schm, table_json, filter));
+ ValidateOutput(actual);
+ ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual.table());
+}
+
+void AssertTakeTCT(const std::shared_ptr<Schema>& schm,
+ const std::vector<std::string>& table_json,
+ const std::vector<std::string>& filter,
+ const std::vector<std::string>& expected_table) {
+ ASSERT_OK_AND_ASSIGN(auto actual, TakeTCT(schm, table_json, filter));
+ ValidateOutput(actual);
+ ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual.table());
+}
+
+// Validators used by random data tests
+
template <typename ValuesType, typename IndexType>
-void ValidateTakeImpl(const std::shared_ptr<Array>& values,
- const std::shared_ptr<Array>& indices,
- const std::shared_ptr<Array>& result) {
+void ValidateTakeXAImpl(const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& indices,
+ const std::shared_ptr<Array>& result) {
using ValuesArrayType = typename TypeTraits<ValuesType>::ArrayType;
using IndexArrayType = typename TypeTraits<IndexType>::ArrayType;
auto typed_values = checked_pointer_cast<ValuesArrayType>(values);
@@ -1185,39 +1405,45 @@ void ValidateTakeImpl(const std::shared_ptr<Array>&
values,
<< i;
}
}
+ // DoCheckTakeCACWithArrays transforms the indices which has a risk of
+ // overflow, so we only call it if the index type is not too wide.
+ if (indices->type()->byte_width() <= 4) {
+ auto cast_options = CastOptions::Safe(TypeHolder{int64()});
+ ASSERT_OK_AND_ASSIGN(auto indices64, Cast(indices, cast_options));
+ DoCheckTakeCACWithArrays(values, indices64.make_array(),
/*expected=*/result);
+ }
}
template <typename ValuesType>
-void ValidateTake(const std::shared_ptr<Array>& values,
- const std::shared_ptr<Array>& indices) {
- ASSERT_OK_AND_ASSIGN(Datum out, Take(values, indices));
- auto taken = out.make_array();
+void ValidateTakeXA(const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& indices) {
+ ASSERT_OK_AND_ASSIGN(auto taken, TakeAAA(*values, *indices));
ValidateOutput(taken);
ASSERT_EQ(indices->length(), taken->length());
switch (indices->type_id()) {
case Type::INT8:
- ValidateTakeImpl<ValuesType, Int8Type>(values, indices, taken);
+ ValidateTakeXAImpl<ValuesType, Int8Type>(values, indices, taken);
break;
case Type::INT16:
- ValidateTakeImpl<ValuesType, Int16Type>(values, indices, taken);
+ ValidateTakeXAImpl<ValuesType, Int16Type>(values, indices, taken);
break;
case Type::INT32:
- ValidateTakeImpl<ValuesType, Int32Type>(values, indices, taken);
+ ValidateTakeXAImpl<ValuesType, Int32Type>(values, indices, taken);
break;
case Type::INT64:
- ValidateTakeImpl<ValuesType, Int64Type>(values, indices, taken);
+ ValidateTakeXAImpl<ValuesType, Int64Type>(values, indices, taken);
break;
case Type::UINT8:
- ValidateTakeImpl<ValuesType, UInt8Type>(values, indices, taken);
+ ValidateTakeXAImpl<ValuesType, UInt8Type>(values, indices, taken);
break;
case Type::UINT16:
- ValidateTakeImpl<ValuesType, UInt16Type>(values, indices, taken);
+ ValidateTakeXAImpl<ValuesType, UInt16Type>(values, indices, taken);
break;
case Type::UINT32:
- ValidateTakeImpl<ValuesType, UInt32Type>(values, indices, taken);
+ ValidateTakeXAImpl<ValuesType, UInt32Type>(values, indices, taken);
break;
case Type::UINT64:
- ValidateTakeImpl<ValuesType, UInt64Type>(values, indices, taken);
+ ValidateTakeXAImpl<ValuesType, UInt64Type>(values, indices, taken);
break;
default:
FAIL() << "Invalid index type";
@@ -1225,6 +1451,8 @@ void ValidateTake(const std::shared_ptr<Array>& values,
}
}
+// ----
+
template <typename T>
T GetMaxIndex(int64_t values_length) {
int64_t max_index = values_length - 1;
@@ -1239,13 +1467,15 @@ uint64_t GetMaxIndex(int64_t values_length) {
return static_cast<uint64_t>(values_length - 1);
}
+} // namespace
+
class TestTakeKernel : public ::testing::Test {
- public:
- void TestNoValidityBitmapButUnknownNullCount(const std::shared_ptr<Array>&
values,
- const std::shared_ptr<Array>&
indices) {
+ private:
+ void DoTestNoValidityBitmapButUnknownNullCount(const std::shared_ptr<Array>&
values,
+ const std::shared_ptr<Array>&
indices) {
ASSERT_EQ(values->null_count(), 0);
ASSERT_EQ(indices->null_count(), 0);
- auto expected = (*Take(values, indices)).make_array();
+ ASSERT_OK_AND_ASSIGN(auto expected, TakeAAA(*values, *indices));
auto new_values = MakeArray(values->data()->Copy());
new_values->data()->buffers[0].reset();
@@ -1253,67 +1483,95 @@ class TestTakeKernel : public ::testing::Test {
auto new_indices = MakeArray(indices->data()->Copy());
new_indices->data()->buffers[0].reset();
new_indices->data()->null_count = kUnknownNullCount;
- auto result = (*Take(new_values, new_indices)).make_array();
-
- AssertArraysEqual(*expected, *result);
+ DoCheckTakeXA(new_values, new_indices, expected);
}
- void TestNoValidityBitmapButUnknownNullCount(const
std::shared_ptr<DataType>& type,
- const std::string& values,
- const std::string& indices) {
- TestNoValidityBitmapButUnknownNullCount(ArrayFromJSON(type, values),
- ArrayFromJSON(int16(), indices));
+ public:
+ void DoTestNoValidityBitmapButUnknownNullCount(
+ const std::shared_ptr<DataType>& type, const std::string& values,
+ const std::string& indices, std::shared_ptr<DataType> index_type =
int8()) {
+ DoTestNoValidityBitmapButUnknownNullCount(ArrayFromJSON(type, values),
+ ArrayFromJSON(index_type,
indices));
}
void TestNumericBasics(const std::shared_ptr<DataType>& type) {
ARROW_SCOPED_TRACE("type = ", *type);
- CheckTake(type, "[7, 8, 9]", "[]", "[]");
- CheckTake(type, "[7, 8, 9]", "[0, 1, 0]", "[7, 8, 7]");
- CheckTake(type, "[null, 8, 9]", "[0, 1, 0]", "[null, 8, null]");
- CheckTake(type, "[7, 8, 9]", "[null, 1, 0]", "[null, 8, 7]");
- CheckTake(type, "[null, 8, 9]", "[]", "[]");
- CheckTake(type, "[7, 8, 9]", "[0, 0, 0, 0, 0, 0, 2]", "[7, 7, 7, 7, 7, 7,
9]");
-
+ CheckTakeXA(type, "[7, 8, 9]", "[]", "[]");
+ CheckTakeXA(type, "[7, 8, 9]", "[0, 1, 0]", "[7, 8, 7]");
+ CheckTakeXA(type, "[null, 8, 9]", "[0, 1, 0]", "[null, 8, null]");
+ CheckTakeXA(type, "[7, 8, 9]", "[null, 1, 0]", "[null, 8, 7]");
+ CheckTakeXA(type, "[null, 8, 9]", "[]", "[]");
+ CheckTakeXA(type, "[7, 8, 9]", "[0, 0, 0, 0, 0, 0, 2]", "[7, 7, 7, 7, 7,
7, 9]");
+
+ const std::string k789 = "[7, 8, 9]";
std::shared_ptr<Array> arr;
- ASSERT_RAISES(IndexError, TakeJSON(type, "[7, 8, 9]", int8(), "[0, 9, 0]",
&arr));
- ASSERT_RAISES(IndexError, TakeJSON(type, "[7, 8, 9]", int8(), "[0, -1,
0]", &arr));
+ ASSERT_RAISES(IndexError, TakeAAA(type, k789, "[0, 9, 0]").Value(&arr));
+ ASSERT_RAISES(IndexError, TakeAAA(type, k789, "[0, -1, 0]").Value(&arr));
+ Datum chunked_arr;
+ ASSERT_RAISES(IndexError,
+ TakeCAC(type, {k789, k789}, "[0, 9,
0]").Value(&chunked_arr));
+ ASSERT_RAISES(IndexError,
+ TakeCAC(type, {k789, k789}, "[0, -1,
0]").Value(&chunked_arr));
}
};
template <typename ArrowType>
-class TestTakeKernelTyped : public TestTakeKernel {};
+class TestTakeKernelTyped : public TestTakeKernel {
+ protected:
+ virtual std::shared_ptr<DataType> value_type() const {
+ if constexpr (is_parameter_free_type<ArrowType>::value) {
+ return TypeTraits<ArrowType>::type_singleton();
+ } else {
+ EXPECT_TRUE(false) << "value_type() must be overridden for parameterized
types";
+ return nullptr;
+ }
+ }
+
+ void TestNoValidityBitmapButUnknownNullCount(
+ const std::string& values, const std::string& indices,
+ const std::shared_ptr<DataType>& index_type = int8()) {
+ return DoTestNoValidityBitmapButUnknownNullCount(this->value_type(),
values, indices,
+ index_type);
+ }
+
+ void CheckTakeXA(const std::string& values, const std::string& indices,
+ const std::string& expected) {
+ compute::CheckTakeXA(this->value_type(), values, indices, expected);
+ }
+};
+
+static const char kNull3[] = "[null, null, null]";
TEST_F(TestTakeKernel, TakeNull) {
- AssertTakeNull("[null, null, null]", "[0, 1, 0]", "[null, null, null]");
- AssertTakeNull("[null, null, null]", "[0, 2]", "[null, null]");
+ CheckTakeXA(null(), kNull3, "[0, 1, 0]", "[null, null, null]");
+ CheckTakeXA(null(), kNull3, "[0, 2]", "[null, null]");
std::shared_ptr<Array> arr;
+ ASSERT_RAISES(IndexError, TakeAAA(null(), kNull3, "[0, 9, 0]").Value(&arr));
+ ASSERT_RAISES(IndexError, TakeAAA(boolean(), kNull3, "[0, -1,
0]").Value(&arr));
+ Datum chunked_arr;
ASSERT_RAISES(IndexError,
- TakeJSON(null(), "[null, null, null]", int8(), "[0, 9, 0]",
&arr));
+ TakeCAC(null(), {kNull3, kNull3}, "[0, 9,
0]").Value(&chunked_arr));
ASSERT_RAISES(IndexError,
- TakeJSON(boolean(), "[null, null, null]", int8(), "[0, -1,
0]", &arr));
+ TakeCAC(boolean(), {kNull3, kNull3}, "[0, -1,
0]").Value(&chunked_arr));
}
TEST_F(TestTakeKernel, InvalidIndexType) {
std::shared_ptr<Array> arr;
- ASSERT_RAISES(NotImplemented, TakeJSON(null(), "[null, null, null]",
float32(),
- "[0.0, 1.0, 0.1]", &arr));
+ ASSERT_RAISES(NotImplemented,
+ TakeAAA(null(), kNull3, "[0.0, 1.0, 0.1]",
float32()).Value(&arr));
+ Datum chunked_arr;
+ ASSERT_RAISES(NotImplemented,
+ TakeCAC(null(), {kNull3, kNull3}, "[0.0, 1.0, 0.1]", float32())
+ .Value(&chunked_arr));
}
-TEST_F(TestTakeKernel, TakeCCEmptyIndices) {
- Datum dat = ChunkedArrayFromJSON(int8(), {"[]"});
- Datum idx = ChunkedArrayFromJSON(int32(), {});
- ASSERT_OK_AND_ASSIGN(auto out, Take(dat, idx));
- ValidateOutput(out);
- AssertDatumsEqual(ChunkedArrayFromJSON(int8(), {"[]"}), out, true);
-}
-
-TEST_F(TestTakeKernel, TakeACEmptyIndices) {
- Datum dat = ArrayFromJSON(int8(), {"[]"});
- Datum idx = ChunkedArrayFromJSON(int32(), {});
- ASSERT_OK_AND_ASSIGN(auto out, Take(dat, idx));
- ValidateOutput(out);
- AssertDatumsEqual(ChunkedArrayFromJSON(int8(), {"[]"}), out, true);
+TEST_F(TestTakeKernel, TakeXCCEmptyIndices) {
+ auto expected = std::vector<std::string>{"[]"};
+ auto values = ArrayFromJSON(int8(), {"[1, 3, 3, 7]"});
+ CheckTakeXCC(values, {"[]"}, expected);
+ auto chunked_values = std::make_shared<ChunkedArray>(values);
+ CheckTakeXCC(chunked_values, {"[]"}, expected);
}
TEST_F(TestTakeKernel, DefaultOptions) {
@@ -1329,18 +1587,25 @@ TEST_F(TestTakeKernel, DefaultOptions) {
}
TEST_F(TestTakeKernel, TakeBoolean) {
- AssertTakeBoolean("[7, 8, 9]", "[]", "[]");
- AssertTakeBoolean("[true, false, true]", "[0, 1, 0]", "[true, false, true]");
- AssertTakeBoolean("[null, false, true]", "[0, 1, 0]", "[null, false, null]");
- AssertTakeBoolean("[true, false, true]", "[null, 1, 0]", "[null, false,
true]");
+ CheckTakeXA(boolean(), "[7, 8, 9]", "[]", "[]");
+ CheckTakeXA(boolean(), "[true, false, true]", "[0, 1, 0]", "[true, false,
true]");
+ CheckTakeXA(boolean(), "[null, false, true]", "[0, 1, 0]", "[null, false,
null]");
+ CheckTakeXA(boolean(), "[true, false, true]", "[null, 1, 0]", "[null, false,
true]");
- TestNoValidityBitmapButUnknownNullCount(boolean(), "[true, false, true]",
"[1, 0, 0]");
+ DoTestNoValidityBitmapButUnknownNullCount(boolean(), "[true, false, true]",
+ "[1, 0, 0]");
+ const std::string kTrueFalseTrue = "[true, false, true]";
std::shared_ptr<Array> arr;
+ ASSERT_RAISES(IndexError, TakeAAA(boolean(), kTrueFalseTrue, "[0, 9,
0]").Value(&arr));
+ ASSERT_RAISES(IndexError, TakeAAA(boolean(), kTrueFalseTrue, "[0, -1,
0]").Value(&arr));
+ Datum chunked_arr;
ASSERT_RAISES(IndexError,
- TakeJSON(boolean(), "[true, false, true]", int8(), "[0, 9,
0]", &arr));
+ TakeCAC(boolean(), {kTrueFalseTrue, kTrueFalseTrue}, "[0, 9,
0]")
+ .Value(&chunked_arr));
ASSERT_RAISES(IndexError,
- TakeJSON(boolean(), "[true, false, true]", int8(), "[0, -1,
0]", &arr));
+ TakeCAC(boolean(), {kTrueFalseTrue, kTrueFalseTrue}, "[0, -1,
0]")
+ .Value(&chunked_arr));
}
TEST_F(TestTakeKernel, Temporal) {
@@ -1349,8 +1614,8 @@ TEST_F(TestTakeKernel, Temporal) {
this->TestNumericBasics(timestamp(TimeUnit::NANO, "Europe/Paris"));
this->TestNumericBasics(duration(TimeUnit::SECOND));
this->TestNumericBasics(date32());
- CheckTake(date64(), "[0, 86400000, null]", "[null, 1, 1, 0]",
- "[null, 86400000, 86400000, 0]");
+ CheckTakeXA(date64(), "[0, 86400000, null]", "[null, 1, 1, 0]",
+ "[null, 86400000, 86400000, 0]");
}
TEST_F(TestTakeKernel, Duration) {
@@ -1363,177 +1628,184 @@ TEST_F(TestTakeKernel, Interval) {
this->TestNumericBasics(month_interval());
auto type = day_time_interval();
- CheckTake(type, "[[1, -600], [2, 3000], null]", "[0, null, 2, 1]",
- "[[1, -600], null, null, [2, 3000]]");
+ CheckTakeXA(type, "[[1, -600], [2, 3000], null]", "[0, null, 2, 1]",
+ "[[1, -600], null, null, [2, 3000]]");
type = month_day_nano_interval();
- CheckTake(type, "[[1, -2, 34567890123456789], [2, 3, -34567890123456789],
null]",
- "[0, null, 2, 1]",
- "[[1, -2, 34567890123456789], null, null, [2, 3,
-34567890123456789]]");
+ CheckTakeXA(type, "[[1, -2, 34567890123456789], [2, 3, -34567890123456789],
null]",
+ "[0, null, 2, 1]",
+ "[[1, -2, 34567890123456789], null, null, [2, 3,
-34567890123456789]]");
}
template <typename ArrowType>
-class TestTakeKernelWithNumeric : public TestTakeKernelTyped<ArrowType> {
- protected:
- void AssertTake(const std::string& values, const std::string& indices,
- const std::string& expected) {
- CheckTake(type_singleton(), values, indices, expected);
- }
-
- std::shared_ptr<DataType> type_singleton() {
- return TypeTraits<ArrowType>::type_singleton();
- }
-};
+class TestTakeKernelWithNumeric : public TestTakeKernelTyped<ArrowType> {};
TYPED_TEST_SUITE(TestTakeKernelWithNumeric, NumericArrowTypes);
TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) {
- this->TestNumericBasics(this->type_singleton());
+ this->TestNumericBasics(this->value_type());
}
template <typename TypeClass>
class TestTakeKernelWithString : public TestTakeKernelTyped<TypeClass> {
public:
- std::shared_ptr<DataType> value_type() {
- return TypeTraits<TypeClass>::type_singleton();
- }
-
- void AssertTake(const std::string& values, const std::string& indices,
- const std::string& expected) {
- CheckTake(value_type(), values, indices, expected);
- }
-
- void AssertTakeDictionary(const std::string& dictionary_values,
- const std::string& dictionary_indices,
- const std::string& indices,
- const std::string& expected_indices) {
- auto dict = ArrayFromJSON(value_type(), dictionary_values);
- auto type = dictionary(int8(), value_type());
- ASSERT_OK_AND_ASSIGN(auto values,
- DictionaryArray::FromArrays(
- type, ArrayFromJSON(int8(), dictionary_indices),
dict));
- ASSERT_OK_AND_ASSIGN(
- auto expected,
- DictionaryArray::FromArrays(type, ArrayFromJSON(int8(),
expected_indices), dict));
- auto take_indices = ArrayFromJSON(int8(), indices);
- AssertTakeArrays(values, take_indices, expected);
+ void AssertTakeXADictionary(const std::string& dictionary_values,
+ const std::string& dictionary_indices,
+ const std::string& indices,
+ const std::string& expected_indices) {
+ return CheckTakeXADictionary(this->value_type(), dictionary_values,
+ dictionary_indices, indices,
expected_indices);
}
};
TYPED_TEST_SUITE(TestTakeKernelWithString, BaseBinaryArrowTypes);
TYPED_TEST(TestTakeKernelWithString, TakeString) {
- this->AssertTake(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["a", "b", "a"])");
- this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", "[null, \"b\", null]");
- this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b",
"a"])");
+ this->CheckTakeXA(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["a", "b", "a"])");
+ this->CheckTakeXA(R"([null, "b", "c"])", "[0, 1, 0]", "[null, \"b\", null]");
+ this->CheckTakeXA(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b",
"a"])");
- this->TestNoValidityBitmapButUnknownNullCount(this->value_type(), R"(["a",
"b", "c"])",
- "[0, 1, 0]");
+ this->TestNoValidityBitmapButUnknownNullCount(R"(["a", "b", "c"])", "[0, 1,
0]");
std::shared_ptr<DataType> type = this->value_type();
+ const std::string kABC = R"(["a", "b", "c"])";
std::shared_ptr<Array> arr;
- ASSERT_RAISES(IndexError,
- TakeJSON(type, R"(["a", "b", "c"])", int8(), "[0, 9, 0]",
&arr));
- ASSERT_RAISES(IndexError, TakeJSON(type, R"(["a", "b", null, "ddd", "ee"])",
int64(),
- "[2, 5]", &arr));
+ ASSERT_RAISES(IndexError, TakeAAA(type, kABC, "[0, 9, 0]").Value(&arr));
+ ASSERT_RAISES(IndexError, TakeAAA(type, kABC, "[2, 5]").Value(&arr));
+ Datum chunked_arr;
+ ASSERT_RAISES(IndexError, TakeCAC(type, {kABC, kABC}, "[0, 9,
0]").Value(&chunked_arr));
+ ASSERT_RAISES(IndexError, TakeCAC(type, {kABC, kABC}, "[4,
10]").Value(&chunked_arr));
}
TYPED_TEST(TestTakeKernelWithString, TakeDictionary) {
auto dict = R"(["a", "b", "c", "d", "e"])";
- this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[3, 4, 3]");
- this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[null, 4,
null]");
- this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4,
3]");
+ this->AssertTakeXADictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[3, 4, 3]");
+ this->AssertTakeXADictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[null, 4,
null]");
+ this->AssertTakeXADictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4,
3]");
}
class TestTakeKernelFSB : public TestTakeKernelTyped<FixedSizeBinaryType> {
public:
- std::shared_ptr<DataType> value_type() { return fixed_size_binary(3); }
-
- void AssertTake(const std::string& values, const std::string& indices,
- const std::string& expected) {
- CheckTake(value_type(), values, indices, expected);
- }
+ std::shared_ptr<DataType> value_type() const override { return
fixed_size_binary(3); }
};
TEST_F(TestTakeKernelFSB, TakeFixedSizeBinary) {
- this->AssertTake(R"(["aaa", "bbb", "ccc"])", "[0, 1, 0]", R"(["aaa", "bbb",
"aaa"])");
- this->AssertTake(R"([null, "bbb", "ccc"])", "[0, 1, 0]", "[null, \"bbb\",
null]");
- this->AssertTake(R"(["aaa", "bbb", "ccc"])", "[null, 1, 0]", R"([null,
"bbb", "aaa"])");
+ const std::string kABC = R"(["aaa", "bbb", "ccc"])";
+ this->CheckTakeXA(kABC, "[0, 1, 0]", R"(["aaa", "bbb", "aaa"])");
+ this->CheckTakeXA(R"([null, "bbb", "ccc"])", "[0, 1, 0]", "[null, \"bbb\",
null]");
+ this->CheckTakeXA(kABC, "[null, 1, 0]", R"([null, "bbb", "aaa"])");
- this->TestNoValidityBitmapButUnknownNullCount(this->value_type(),
- R"(["aaa", "bbb", "ccc"])",
"[0, 1, 0]");
+ this->TestNoValidityBitmapButUnknownNullCount(kABC, "[0, 1, 0]");
std::shared_ptr<DataType> type = this->value_type();
+ const std::string kABNullDE = R"(["aaa", "bbb", null, "ddd", "eee"])";
std::shared_ptr<Array> arr;
+ ASSERT_RAISES(IndexError, TakeAAA(type, kABC, "[0, 9, 0]").Value(&arr));
+ ASSERT_RAISES(IndexError, TakeAAA(type, kABNullDE, "[2, 5]").Value(&arr));
+ Datum chunked_arr;
+ ASSERT_RAISES(IndexError, TakeCAC(type, {kABC, kABC}, "[0, 9,
0]").Value(&chunked_arr));
ASSERT_RAISES(IndexError,
- TakeJSON(type, R"(["aaa", "bbb", "ccc"])", int8(), "[0, 9,
0]", &arr));
- ASSERT_RAISES(IndexError, TakeJSON(type, R"(["aaa", "bbb", null, "ddd",
"eee"])",
- int64(), "[2, 5]", &arr));
+ TakeCAC(type, {kABNullDE, kABC}, "[4,
10]").Value(&chunked_arr));
}
-class TestTakeKernelWithList : public TestTakeKernelTyped<ListType> {};
+using ListAndListViewArrowTypes =
+ ::testing::Types<ListType, LargeListType, ListViewType, LargeListViewType>;
+
+template <typename ArrowListType>
+class TestTakeKernelWithList : public TestTakeKernelTyped<ListType> {
+ protected:
+ std::shared_ptr<DataType> inner_type_ = nullptr;
+
+ std::shared_ptr<DataType> value_type(std::shared_ptr<DataType> inner_type)
const {
+ return std::make_shared<ArrowListType>(std::move(inner_type));
+ }
+
+ std::shared_ptr<DataType> value_type() const override {
+ EXPECT_TRUE(inner_type_);
+ return value_type(inner_type_);
+ }
+
+ std::vector<std::shared_ptr<DataType>> InnerListTypes() const {
+ return std::vector<std::shared_ptr<DataType>>{
+ list(int32()),
+ large_list(int32()),
+ list_view(int32()),
+ large_list_view(int32()),
+ };
+ }
+};
+
+TYPED_TEST_SUITE(TestTakeKernelWithList, ListAndListViewArrowTypes);
-TEST_F(TestTakeKernelWithList, TakeListInt32) {
+TYPED_TEST(TestTakeKernelWithList, TakeListInt32) {
+ this->inner_type_ = int32();
std::string list_json = "[[], [1,2], null, [3]]";
- for (auto& type : kListAndListViewTypes) {
- CheckTake(type, list_json, "[]", "[]");
- CheckTake(type, list_json, "[3, 2, 1]", "[[3], null, [1,2]]");
- CheckTake(type, list_json, "[null, 3, 0]", "[null, [3], []]");
- CheckTake(type, list_json, "[null, null]", "[null, null]");
- CheckTake(type, list_json, "[3, 0, 0, 3]", "[[3], [], [], [3]]");
- CheckTake(type, list_json, "[0, 1, 2, 3]", list_json);
- CheckTake(type, list_json, "[0, 0, 0, 0, 0, 0, 1]",
- "[[], [], [], [], [], [], [1, 2]]");
+ {
+ this->CheckTakeXA(list_json, "[]", "[]");
+ this->CheckTakeXA(list_json, "[3, 2, 1]", "[[3], null, [1,2]]");
+ this->CheckTakeXA(list_json, "[null, 3, 0]", "[null, [3], []]");
+ this->CheckTakeXA(list_json, "[null, null]", "[null, null]");
+ this->CheckTakeXA(list_json, "[3, 0, 0, 3]", "[[3], [], [], [3]]");
+ this->CheckTakeXA(list_json, "[0, 1, 2, 3]", list_json);
+ this->CheckTakeXA(list_json, "[0, 0, 0, 0, 0, 0, 1]",
+ "[[], [], [], [], [], [], [1, 2]]");
- this->TestNoValidityBitmapButUnknownNullCount(type, "[[], [1,2], [3]]",
"[0, 1, 0]");
+ this->TestNoValidityBitmapButUnknownNullCount("[[], [1,2], [3]]", "[0, 1,
0]");
}
}
-TEST_F(TestTakeKernelWithList, TakeListListInt32) {
+TYPED_TEST(TestTakeKernelWithList, TakeListListInt32) {
std::string list_json = R"([
[],
[[1], [2, null, 2], []],
null,
[[3, null], null]
])";
- for (auto& type : kNestedListAndListViewTypes) {
- ARROW_SCOPED_TRACE("type = ", *type);
- CheckTake(type, list_json, "[]", "[]");
- CheckTake(type, list_json, "[3, 2, 1]", R"([
+ for (auto& inner_type : this->InnerListTypes()) {
+ this->inner_type_ = inner_type;
+ ARROW_SCOPED_TRACE("type = ", *this->value_type());
+ this->CheckTakeXA(list_json, "[]", "[]");
+ this->CheckTakeXA(list_json, "[3, 2, 1]", R"([
[[3, null], null],
null,
[[1], [2, null, 2], []]
])");
- CheckTake(type, list_json, "[null, 3, 0]", R"([
+ this->CheckTakeXA(list_json, "[null, 3, 0]", R"([
null,
[[3, null], null],
[]
])");
- CheckTake(type, list_json, "[null, null]", "[null, null]");
- CheckTake(type, list_json, "[3, 0, 0, 3]",
- "[[[3, null], null], [], [], [[3, null], null]]");
- CheckTake(type, list_json, "[0, 1, 2, 3]", list_json);
- CheckTake(type, list_json, "[0, 0, 0, 0, 0, 0, 1]",
- "[[], [], [], [], [], [], [[1], [2, null, 2], []]]");
+ this->CheckTakeXA(list_json, "[null, null]", "[null, null]");
+ this->CheckTakeXA(list_json, "[3, 0, 0, 3]",
+ "[[[3, null], null], [], [], [[3, null], null]]");
+ this->CheckTakeXA(list_json, "[0, 1, 2, 3]", list_json);
+ this->CheckTakeXA(list_json, "[0, 0, 0, 0, 0, 0, 1]",
+ "[[], [], [], [], [], [], [[1], [2, null, 2], []]]");
this->TestNoValidityBitmapButUnknownNullCount(
- type, "[[[1], [2, null, 2], []], [[3, null]]]", "[0, 1, 0]");
+ "[[[1], [2, null, 2], []], [[3, null]]]", "[0, 1, 0]");
}
}
-class TestTakeKernelWithLargeList : public TestTakeKernelTyped<LargeListType>
{};
-
-TEST_F(TestTakeKernelWithLargeList, TakeLargeListInt32) {
+TYPED_TEST(TestTakeKernelWithList, TakeLargeListInt32) {
+ this->inner_type_ = int32();
std::string list_json = "[[], [1,2], null, [3]]";
- for (auto& type : kLargeListAndListViewTypes) {
- ARROW_SCOPED_TRACE("type = ", *type);
- CheckTake(type, list_json, "[]", "[]");
- CheckTake(type, list_json, "[null, 1, 2, 0]", "[null, [1,2], null, []]");
+ {
+ ARROW_SCOPED_TRACE("type = ", *this->value_type());
+ this->CheckTakeXA(list_json, "[]", "[]");
+ this->CheckTakeXA(list_json, "[null, 1, 2, 0]", "[null, [1,2], null, []]");
}
}
class TestTakeKernelWithFixedSizeList : public
TestTakeKernelTyped<FixedSizeListType> {
protected:
- void CheckTakeOnNestedLists(const std::shared_ptr<DataType>& inner_type,
- const std::vector<int>& list_sizes, int64_t
length) {
+ std::shared_ptr<DataType> inner_type_ = nullptr;
+
+ std::shared_ptr<DataType> value_type() const override {
+ EXPECT_TRUE(inner_type_);
+ return fixed_size_list(inner_type_, 3);
+ }
+
+ void CheckTakeXAOnNestedLists(const std::shared_ptr<DataType>& inner_type,
+ const std::vector<int>& list_sizes, int64_t
length) {
using NLG = ::arrow::util::internal::NestedListGenerator;
// Create two equivalent lists: one as a FixedSizeList and another as a
List.
ASSERT_OK_AND_ASSIGN(auto fsl_list,
@@ -1544,51 +1816,50 @@ class TestTakeKernelWithFixedSizeList : public
TestTakeKernelTyped<FixedSizeList
auto indices = ArrayFromJSON(int64(), "[1, 2, 4]");
// Use the Take on ListType as the reference implementation.
- ASSERT_OK_AND_ASSIGN(auto expected_list, Take(*list, *indices));
+ ASSERT_OK_AND_ASSIGN(auto expected_list, TakeAAA(*list, *indices));
ASSERT_OK_AND_ASSIGN(auto expected_fsl, Cast(*expected_list,
fsl_list->type()));
- DoCheckTake(fsl_list, indices, expected_fsl);
+ DoCheckTakeXA(fsl_list, indices, expected_fsl);
}
};
TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) {
+ inner_type_ = int32();
std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]";
- CheckTake(fixed_size_list(int32(), 3), list_json, "[]", "[]");
- CheckTake(fixed_size_list(int32(), 3), list_json, "[3, 2, 1]",
- "[[7, 8, null], [4, 5, 6], [1, null, 3]]");
- CheckTake(fixed_size_list(int32(), 3), list_json, "[null, 2, 0]",
- "[null, [4, 5, 6], null]");
- CheckTake(fixed_size_list(int32(), 3), list_json, "[null, null]", "[null,
null]");
- CheckTake(fixed_size_list(int32(), 3), list_json, "[3, 0, 0, 3]",
- "[[7, 8, null], null, null, [7, 8, null]]");
- CheckTake(fixed_size_list(int32(), 3), list_json, "[0, 1, 2, 3]", list_json);
+ CheckTakeXA(list_json, "[]", "[]");
+ CheckTakeXA(list_json, "[3, 2, 1]", "[[7, 8, null], [4, 5, 6], [1, null,
3]]");
+ CheckTakeXA(list_json, "[null, 2, 0]", "[null, [4, 5, 6], null]");
+ CheckTakeXA(list_json, "[null, null]", "[null, null]");
+ CheckTakeXA(list_json, "[3, 0, 0, 3]", "[[7, 8, null], null, null, [7, 8,
null]]");
+ CheckTakeXA(list_json, "[0, 1, 2, 3]", list_json);
// No nulls in inner list values trigger the use of FixedWidthTakeExec() in
// FSLTakeExec()
std::string no_nulls_list_json = "[[0, 0, 0], [1, 2, 3], [4, 5, 6], [7, 8,
9]]";
- CheckTake(
- fixed_size_list(int32(), 3), no_nulls_list_json, "[2, 2, 2, 2, 2, 2, 1]",
+ CheckTakeXA(
+ no_nulls_list_json, "[2, 2, 2, 2, 2, 2, 1]",
"[[4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [1,
2, 3]]");
- this->TestNoValidityBitmapButUnknownNullCount(fixed_size_list(int32(), 3),
- "[[1, null, 3], [4, 5, 6], [7,
8, null]]",
+ this->TestNoValidityBitmapButUnknownNullCount("[[1, null, 3], [4, 5, 6], [7,
8, null]]",
"[0, 1, 0]");
}
TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListVarWidth) {
+ inner_type_ = utf8();
std::string list_json =
R"([["zero", "one", ""], ["two", "", "three"], ["four", "five", "six"],
["seven", "eight", ""]])";
- CheckTake(fixed_size_list(utf8(), 3), list_json, "[]", "[]");
- CheckTake(fixed_size_list(utf8(), 3), list_json, "[3, 2, 1]",
- R"([["seven", "eight", ""], ["four", "five", "six"], ["two", "",
"three"]])");
- CheckTake(fixed_size_list(utf8(), 3), list_json, "[null, 2, 0]",
- R"([null, ["four", "five", "six"], ["zero", "one", ""]])");
- CheckTake(fixed_size_list(utf8(), 3), list_json, R"([null, null])", "[null,
null]");
- CheckTake(
- fixed_size_list(utf8(), 3), list_json, "[3, 0, 0,3]",
+ CheckTakeXA(list_json, "[]", "[]");
+ CheckTakeXA(
+ list_json, "[3, 2, 1]",
+ R"([["seven", "eight", ""], ["four", "five", "six"], ["two", "",
"three"]])");
+ CheckTakeXA(list_json, "[null, 2, 0]",
+ R"([null, ["four", "five", "six"], ["zero", "one", ""]])");
+ CheckTakeXA(list_json, R"([null, null])", "[null, null]");
+ CheckTakeXA(
+ list_json, "[3, 0, 0,3]",
R"([["seven", "eight", ""], ["zero", "one", ""], ["zero", "one", ""],
["seven", "eight", ""]])");
- CheckTake(fixed_size_list(utf8(), 3), list_json, "[0, 1, 2, 3]", list_json);
- CheckTake(fixed_size_list(utf8(), 3), list_json, "[2, 2, 2, 2, 2, 2, 1]",
- R"([
+ CheckTakeXA(list_json, "[0, 1, 2, 3]", list_json);
+ CheckTakeXA(list_json, "[2, 2, 2, 2, 2, 2, 1]",
+ R"([
["four", "five", "six"], ["four", "five", "six"],
["four", "five", "six"], ["four", "five", "six"],
["four", "five", "six"], ["four", "five", "six"],
@@ -1606,11 +1877,14 @@ TEST_F(TestTakeKernelWithFixedSizeList,
TakeFixedSizeListModuloNesting) {
NLG::VisitAllNestedListConfigurations(
value_types, [this](const std::shared_ptr<DataType>& inner_type,
const std::vector<int>& list_sizes) {
- this->CheckTakeOnNestedLists(inner_type, list_sizes, /*length=*/5);
+ this->CheckTakeXAOnNestedLists(inner_type, list_sizes, /*length=*/5);
});
}
-class TestTakeKernelWithMap : public TestTakeKernelTyped<MapType> {};
+class TestTakeKernelWithMap : public TestTakeKernelTyped<MapType> {
+ protected:
+ std::shared_ptr<DataType> value_type() const override { return map(utf8(),
int32()); }
+};
TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) {
std::string map_json = R"([
@@ -1619,21 +1893,20 @@ TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) {
[["cap", 8]],
[]
])";
- CheckTake(map(utf8(), int32()), map_json, "[]", "[]");
- CheckTake(map(utf8(), int32()), map_json, "[3, 1, 3, 1, 3]",
- "[[], null, [], null, []]");
- CheckTake(map(utf8(), int32()), map_json, "[2, 1, null]", R"([
+ CheckTakeXA(map_json, "[]", "[]");
+ CheckTakeXA(map_json, "[3, 1, 3, 1, 3]", "[[], null, [], null, []]");
+ CheckTakeXA(map_json, "[2, 1, null]", R"([
[["cap", 8]],
null,
null
])");
- CheckTake(map(utf8(), int32()), map_json, "[2, 1, 0]", R"([
+ CheckTakeXA(map_json, "[2, 1, 0]", R"([
[["cap", 8]],
null,
[["joe", 0], ["mark", null]]
])");
- CheckTake(map(utf8(), int32()), map_json, "[0, 1, 2, 3]", map_json);
- CheckTake(map(utf8(), int32()), map_json, "[0, 0, 0, 0, 0, 0, 3]", R"([
+ CheckTakeXA(map_json, "[0, 1, 2, 3]", map_json);
+ CheckTakeXA(map_json, "[0, 0, 0, 0, 0, 0, 3]", R"([
[["joe", 0], ["mark", null]],
[["joe", 0], ["mark", null]],
[["joe", 0], ["mark", null]],
@@ -1644,31 +1917,34 @@ TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) {
])");
}
-class TestTakeKernelWithStruct : public TestTakeKernelTyped<StructType> {};
+class TestTakeKernelWithStruct : public TestTakeKernelTyped<StructType> {
+ std::shared_ptr<DataType> value_type() const override {
+ return struct_({field("a", int32()), field("b", utf8())});
+ }
+};
TEST_F(TestTakeKernelWithStruct, TakeStruct) {
- auto struct_type = struct_({field("a", int32()), field("b", utf8())});
auto struct_json = R"([
null,
{"a": 1, "b": ""},
{"a": 2, "b": "hello"},
{"a": 4, "b": "eh"}
])";
- CheckTake(struct_type, struct_json, "[]", "[]");
- CheckTake(struct_type, struct_json, "[3, 1, 3, 1, 3]", R"([
+ this->CheckTakeXA(struct_json, "[]", "[]");
+ this->CheckTakeXA(struct_json, "[3, 1, 3, 1, 3]", R"([
{"a": 4, "b": "eh"},
{"a": 1, "b": ""},
{"a": 4, "b": "eh"},
{"a": 1, "b": ""},
{"a": 4, "b": "eh"}
])");
- CheckTake(struct_type, struct_json, "[3, 1, 0]", R"([
+ this->CheckTakeXA(struct_json, "[3, 1, 0]", R"([
{"a": 4, "b": "eh"},
{"a": 1, "b": ""},
null
])");
- CheckTake(struct_type, struct_json, "[0, 1, 2, 3]", struct_json);
- CheckTake(struct_type, struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
+ this->CheckTakeXA(struct_json, "[0, 1, 2, 3]", struct_json);
+ this->CheckTakeXA(struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
null,
{"a": 2, "b": "hello"},
{"a": 2, "b": "hello"},
@@ -1678,16 +1954,30 @@ TEST_F(TestTakeKernelWithStruct, TakeStruct) {
{"a": 2, "b": "hello"}
])");
- this->TestNoValidityBitmapButUnknownNullCount(
- struct_type, R"([{"a": 1}, {"a": 2, "b": "hello"}])", "[0, 1, 0]");
+ this->TestNoValidityBitmapButUnknownNullCount(R"([{"a": 1}, {"a": 2, "b":
"hello"}])",
+ "[0, 1, 0]");
}
-class TestTakeKernelWithUnion : public TestTakeKernelTyped<UnionType> {};
+template <typename ArrowUnionType>
+class TestTakeKernelWithUnion : public TestTakeKernelTyped<ArrowUnionType> {
+ protected:
+ std::shared_ptr<DataType> value_type() const override {
+ return std::make_shared<ArrowUnionType>(
+ FieldVector{
+ field("a", int32()),
+ field("b", utf8()),
+ },
+ std::vector<int8_t>{
+ 2,
+ 5,
+ });
+ }
+};
+
+TYPED_TEST_SUITE(TestTakeKernelWithUnion, UnionArrowTypes);
-TEST_F(TestTakeKernelWithUnion, TakeUnion) {
- for (const auto& union_type :
- {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}),
- sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) {
+TYPED_TEST(TestTakeKernelWithUnion, TakeUnion) {
+ {
auto union_json = R"([
[2, 222],
[2, null],
@@ -1697,22 +1987,22 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) {
[2, 111],
[5, null]
])";
- CheckTake(union_type, union_json, "[]", "[]");
- CheckTake(union_type, union_json, "[3, 0, 3, 0, 3]", R"([
+ this->CheckTakeXA(union_json, "[]", "[]");
+ this->CheckTakeXA(union_json, "[3, 0, 3, 0, 3]", R"([
[5, "eh"],
[2, 222],
[5, "eh"],
[2, 222],
[5, "eh"]
])");
- CheckTake(union_type, union_json, "[4, 2, 0, 6]", R"([
+ this->CheckTakeXA(union_json, "[4, 2, 0, 6]", R"([
[2, null],
[5, "hello"],
[2, 222],
[5, null]
])");
- CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json);
- CheckTake(union_type, union_json, "[1, 2, 2, 2, 2, 2, 2]", R"([
+ this->CheckTakeXA(union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json);
+ this->CheckTakeXA(union_json, "[1, 2, 2, 2, 2, 2, 2]", R"([
[2, null],
[5, "hello"],
[5, "hello"],
@@ -1721,7 +2011,7 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) {
[5, "hello"],
[5, "hello"]
])");
- CheckTake(union_type, union_json, "[0, null, 1, null, 2, 2, 2]", R"([
+ this->CheckTakeXA(union_json, "[0, null, 1, null, 2, 2, 2]", R"([
[2, 222],
[2, null],
[2, null],
@@ -1735,72 +2025,58 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) {
class TestPermutationsWithTake : public ::testing::Test {
protected:
- void DoTake(const Int16Array& values, const Int16Array& indices,
- std::shared_ptr<Int16Array>* out) {
- ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> boxed_out, Take(values,
indices));
+ Result<std::shared_ptr<Int16Array>> DoTakeAAA(
+ const std::shared_ptr<Int16Array>& values,
+ const std::shared_ptr<Int16Array>& indices) {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> boxed_out, TakeAAA(*values,
*indices));
ValidateOutput(boxed_out);
- *out = checked_pointer_cast<Int16Array>(std::move(boxed_out));
+ return checked_pointer_cast<Int16Array>(std::move(boxed_out));
}
- std::shared_ptr<Int16Array> DoTake(const Int16Array& values,
- const Int16Array& indices) {
- std::shared_ptr<Int16Array> out;
- DoTake(values, indices, &out);
- return out;
- }
-
- std::shared_ptr<Int16Array> DoTakeN(uint64_t n, std::shared_ptr<Int16Array>
array) {
+ Result<std::shared_ptr<Int16Array>> DoTakeN(uint64_t n,
+ std::shared_ptr<Int16Array>
array) {
auto power_of_2 = array;
- array = Identity(array->length());
+ ARROW_ASSIGN_OR_RAISE(array, Identity(array->length()));
while (n != 0) {
if (n & 1) {
- array = DoTake(*array, *power_of_2);
+ ARROW_ASSIGN_OR_RAISE(array, DoTakeAAA(array, power_of_2));
}
- power_of_2 = DoTake(*power_of_2, *power_of_2);
+ ARROW_ASSIGN_OR_RAISE(power_of_2, DoTakeAAA(power_of_2, power_of_2));
n >>= 1;
}
return array;
}
template <typename Rng>
- void Shuffle(const Int16Array& array, Rng& gen, std::shared_ptr<Int16Array>*
shuffled) {
+ Result<std::shared_ptr<Int16Array>> Shuffle(const Int16Array& array, Rng&
gen) {
auto byte_length = array.length() * sizeof(int16_t);
- ASSERT_OK_AND_ASSIGN(auto data, array.values()->CopySlice(0, byte_length));
+ ARROW_ASSIGN_OR_RAISE(auto data, array.values()->CopySlice(0,
byte_length));
auto mutable_data = reinterpret_cast<int16_t*>(data->mutable_data());
std::shuffle(mutable_data, mutable_data + array.length(), gen);
- shuffled->reset(new Int16Array(array.length(), data));
- }
-
- template <typename Rng>
- std::shared_ptr<Int16Array> Shuffle(const Int16Array& array, Rng& gen) {
- std::shared_ptr<Int16Array> out;
- Shuffle(array, gen, &out);
- return out;
+ return std::make_shared<Int16Array>(array.length(), data);
}
- void Identity(int64_t length, std::shared_ptr<Int16Array>* identity) {
+ Result<std::shared_ptr<Int16Array>> Identity(int64_t length) {
+ std::shared_ptr<Int16Array> identity;
Int16Builder identity_builder;
- ASSERT_OK(identity_builder.Resize(length));
+ RETURN_NOT_OK(identity_builder.Resize(length));
for (int16_t i = 0; i < length; ++i) {
identity_builder.UnsafeAppend(i);
}
- ASSERT_OK(identity_builder.Finish(identity));
- }
-
- std::shared_ptr<Int16Array> Identity(int64_t length) {
- std::shared_ptr<Int16Array> out;
- Identity(length, &out);
- return out;
+ RETURN_NOT_OK(identity_builder.Finish(&identity));
+ return identity;
}
- std::shared_ptr<Int16Array> Inverse(const std::shared_ptr<Int16Array>&
permutation) {
+ Result<std::shared_ptr<Int16Array>> Inverse(
+ const std::shared_ptr<Int16Array>& permutation) {
auto length = static_cast<int16_t>(permutation->length());
std::vector<bool> cycle_lengths(length + 1, false);
auto permutation_to_the_i = permutation;
for (int16_t cycle_length = 1; cycle_length <= length; ++cycle_length) {
cycle_lengths[cycle_length] = HasTrivialCycle(*permutation_to_the_i);
- permutation_to_the_i = DoTake(*permutation, *permutation_to_the_i);
+ ARROW_ASSIGN_OR_RAISE(permutation_to_the_i,
+ DoTakeAAA(permutation, permutation_to_the_i));
}
uint64_t cycle_to_identity_length = 1;
@@ -1836,42 +2112,18 @@ TEST_F(TestPermutationsWithTake, InvertPermutation) {
for (auto seed : std::vector<random::SeedType>({0, kRandomSeed, kRandomSeed
* 2 - 1})) {
std::default_random_engine gen(seed);
for (int16_t length = 0; length < 1 << 10; ++length) {
- auto identity = Identity(length);
- auto permutation = Shuffle(*identity, gen);
- auto inverse = Inverse(permutation);
+ ASSERT_OK_AND_ASSIGN(auto identity, Identity(length));
+ ASSERT_OK_AND_ASSIGN(auto permutation, Shuffle(*identity, gen));
+ ASSERT_OK_AND_ASSIGN(auto inverse, Inverse(permutation));
if (inverse == nullptr) {
break;
}
- ASSERT_TRUE(DoTake(*inverse, *permutation)->Equals(identity));
+ DoCheckTakeXA(inverse, permutation, identity);
}
}
}
-class TestTakeKernelWithRecordBatch : public TestTakeKernelTyped<RecordBatch> {
- public:
- void AssertTake(const std::shared_ptr<Schema>& schm, const std::string&
batch_json,
- const std::string& indices, const std::string&
expected_batch) {
- std::shared_ptr<RecordBatch> actual;
-
- for (auto index_type : {int8(), uint32()}) {
- ASSERT_OK(TakeJSON(schm, batch_json, index_type, indices, &actual));
- ValidateOutput(actual);
- ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch),
*actual);
- }
- }
-
- Status TakeJSON(const std::shared_ptr<Schema>& schm, const std::string&
batch_json,
- const std::shared_ptr<DataType>& index_type, const
std::string& indices,
- std::shared_ptr<RecordBatch>* out) {
- auto batch = RecordBatchFromJSON(schm, batch_json);
- ARROW_ASSIGN_OR_RAISE(Datum result,
- Take(Datum(batch), Datum(ArrayFromJSON(index_type,
indices))));
- *out = result.record_batch();
- return Status::OK();
- }
-};
-
-TEST_F(TestTakeKernelWithRecordBatch, TakeRecordBatch) {
+TEST(TestTakeKernelWithRecordBatch, TakeRecordBatch) {
std::vector<std::shared_ptr<Field>> fields = {field("a", int32()),
field("b", utf8())};
auto schm = schema(fields);
@@ -1881,21 +2133,21 @@ TEST_F(TestTakeKernelWithRecordBatch, TakeRecordBatch) {
{"a": 2, "b": "hello"},
{"a": 4, "b": "eh"}
])";
- this->AssertTake(schm, struct_json, "[]", "[]");
- this->AssertTake(schm, struct_json, "[3, 1, 3, 1, 3]", R"([
+ AssertTakeRAR(schm, struct_json, "[]", "[]");
+ AssertTakeRAR(schm, struct_json, "[3, 1, 3, 1, 3]", R"([
{"a": 4, "b": "eh"},
{"a": 1, "b": ""},
{"a": 4, "b": "eh"},
{"a": 1, "b": ""},
{"a": 4, "b": "eh"}
])");
- this->AssertTake(schm, struct_json, "[3, 1, 0]", R"([
+ AssertTakeRAR(schm, struct_json, "[3, 1, 0]", R"([
{"a": 4, "b": "eh"},
{"a": 1, "b": ""},
{"a": null, "b": "yo"}
])");
- this->AssertTake(schm, struct_json, "[0, 1, 2, 3]", struct_json);
- this->AssertTake(schm, struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
+ AssertTakeRAR(schm, struct_json, "[0, 1, 2, 3]", struct_json);
+ AssertTakeRAR(schm, struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
{"a": null, "b": "yo"},
{"a": 2, "b": "hello"},
{"a": 2, "b": "hello"},
@@ -1906,115 +2158,41 @@ TEST_F(TestTakeKernelWithRecordBatch, TakeRecordBatch)
{
])");
}
-class TestTakeKernelWithChunkedArray : public
TestTakeKernelTyped<ChunkedArray> {
- public:
- void AssertTake(const std::shared_ptr<DataType>& type,
- const std::vector<std::string>& values, const std::string&
indices,
- const std::vector<std::string>& expected) {
- std::shared_ptr<ChunkedArray> actual;
- ASSERT_OK(this->TakeWithArray(type, values, indices, &actual));
- ValidateOutput(actual);
- AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual);
+TEST(TestTakeKernelWithChunkedIndices, TakeChunkedArray) {
+ for (auto& ty : {boolean(), int8(), uint64()}) {
+ AssertTakeCAC(ty, {"[]"}, "[]", {"[]"});
+ AssertTakeCCC(ty, {}, {}, {});
+ AssertTakeCCC(ty, {}, {"[]"}, {"[]"});
+ AssertTakeCCC(ty, {}, {"[null]"}, {"[null]"});
+ AssertTakeCCC(ty, {"[]"}, {}, {});
+ AssertTakeCCC(ty, {"[]"}, {"[]"}, {"[]"});
+ AssertTakeCCC(ty, {"[]"}, {"[null]"}, {"[null]"});
}
- void AssertChunkedTake(const std::shared_ptr<DataType>& type,
- const std::vector<std::string>& values,
- const std::vector<std::string>& indices,
- const std::vector<std::string>& expected) {
- std::shared_ptr<ChunkedArray> actual;
- ASSERT_OK(this->TakeWithChunkedArray(type, values, indices, &actual));
- ValidateOutput(actual);
- AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual);
- }
+ AssertTakeCAC(boolean(), {"[true]", "[false, true]"}, "[0, 1, 0, 2]",
+ {"[true, false, true, true]"});
+ AssertTakeCCC(boolean(), {"[false]", "[true, false]"}, {"[0, 1, 0]", "[]",
"[2]"},
+ {"[false, true, false]", "[]", "[false]"});
+ AssertTakeCAC(boolean(), {"[true]", "[false, true]"}, "[2, 1]", {"[true,
false]"});
- Status TakeWithArray(const std::shared_ptr<DataType>& type,
- const std::vector<std::string>& values, const
std::string& indices,
- std::shared_ptr<ChunkedArray>* out) {
- ARROW_ASSIGN_OR_RAISE(Datum result, Take(ChunkedArrayFromJSON(type,
values),
- ArrayFromJSON(int8(), indices)));
- *out = result.chunked_array();
- return Status::OK();
- }
+ Datum chunked_arr;
+ for (auto& int_ty : SignedIntTypes()) {
+ AssertTakeCAC(int_ty, {"[7]", "[8, 9]"}, "[0, 1, 0, 2]", {"[7, 8, 7, 9]"});
+ AssertTakeCCC(int_ty, {"[7]", "[8, 9]"}, {"[0, 1, 0]", "[]", "[2]"},
+ {"[7, 8, 7]", "[]", "[9]"});
+ AssertTakeCAC(int_ty, {"[7]", "[8, 9]"}, "[2, 1]", {"[9, 8]"});
- Status TakeWithChunkedArray(const std::shared_ptr<DataType>& type,
- const std::vector<std::string>& values,
- const std::vector<std::string>& indices,
- std::shared_ptr<ChunkedArray>* out) {
- ARROW_ASSIGN_OR_RAISE(Datum result, Take(ChunkedArrayFromJSON(type,
values),
- ChunkedArrayFromJSON(int8(),
indices)));
- *out = result.chunked_array();
- return Status::OK();
+ ASSERT_RAISES(IndexError,
+ TakeCAC(int_ty, {"[7]", "[8, 9]"}, "[0,
5]").Value(&chunked_arr));
+ ASSERT_RAISES(
+ IndexError,
+ TakeCCC(int_ty, {"[7]", "[8, 9]"}, {"[0, 1, 0]", "[5,
1]"}).Value(&chunked_arr));
+ ASSERT_RAISES(IndexError, TakeCCC(int_ty, {},
{"[0]"}).Value(&chunked_arr));
+ ASSERT_RAISES(IndexError, TakeCCC(int_ty, {"[]"},
{"[0]"}).Value(&chunked_arr));
}
-};
-
-TEST_F(TestTakeKernelWithChunkedArray, TakeChunkedArray) {
- this->AssertTake(int8(), {"[]"}, "[]", {"[]"});
- this->AssertChunkedTake(int8(), {}, {}, {});
- this->AssertChunkedTake(int8(), {}, {"[]"}, {"[]"});
- this->AssertChunkedTake(int8(), {}, {"[null]"}, {"[null]"});
- this->AssertChunkedTake(int8(), {"[]"}, {}, {});
- this->AssertChunkedTake(int8(), {"[]"}, {"[]"}, {"[]"});
- this->AssertChunkedTake(int8(), {"[]"}, {"[null]"}, {"[null]"});
-
- this->AssertTake(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0, 2]", {"[7, 8, 7,
9]"});
- this->AssertChunkedTake(int8(), {"[7]", "[8, 9]"}, {"[0, 1, 0]", "[]",
"[2]"},
- {"[7, 8, 7]", "[]", "[9]"});
- this->AssertTake(int8(), {"[7]", "[8, 9]"}, "[2, 1]", {"[9, 8]"});
-
- std::shared_ptr<ChunkedArray> arr;
- ASSERT_RAISES(IndexError,
- this->TakeWithArray(int8(), {"[7]", "[8, 9]"}, "[0, 5]",
&arr));
- ASSERT_RAISES(IndexError, this->TakeWithChunkedArray(int8(), {"[7]", "[8,
9]"},
- {"[0, 1, 0]", "[5,
1]"}, &arr));
- ASSERT_RAISES(IndexError, this->TakeWithChunkedArray(int8(), {}, {"[0]"},
&arr));
- ASSERT_RAISES(IndexError, this->TakeWithChunkedArray(int8(), {"[]"},
{"[0]"}, &arr));
}
-class TestTakeKernelWithTable : public TestTakeKernelTyped<Table> {
- public:
- void AssertTake(const std::shared_ptr<Schema>& schm,
- const std::vector<std::string>& table_json, const
std::string& filter,
- const std::vector<std::string>& expected_table) {
- std::shared_ptr<Table> actual;
-
- ASSERT_OK(this->TakeWithArray(schm, table_json, filter, &actual));
- ValidateOutput(actual);
- ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual);
- }
-
- void AssertChunkedTake(const std::shared_ptr<Schema>& schm,
- const std::vector<std::string>& table_json,
- const std::vector<std::string>& filter,
- const std::vector<std::string>& expected_table) {
- std::shared_ptr<Table> actual;
-
- ASSERT_OK(this->TakeWithChunkedArray(schm, table_json, filter, &actual));
- ValidateOutput(actual);
- ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual);
- }
-
- Status TakeWithArray(const std::shared_ptr<Schema>& schm,
- const std::vector<std::string>& values, const
std::string& indices,
- std::shared_ptr<Table>* out) {
- ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(TableFromJSON(schm,
values)),
- Datum(ArrayFromJSON(int8(),
indices))));
- *out = result.table();
- return Status::OK();
- }
-
- Status TakeWithChunkedArray(const std::shared_ptr<Schema>& schm,
- const std::vector<std::string>& values,
- const std::vector<std::string>& indices,
- std::shared_ptr<Table>* out) {
- ARROW_ASSIGN_OR_RAISE(Datum result,
- Take(Datum(TableFromJSON(schm, values)),
- Datum(ChunkedArrayFromJSON(int8(), indices))));
- *out = result.table();
- return Status::OK();
- }
-};
-
-TEST_F(TestTakeKernelWithTable, TakeTable) {
+TEST(TestTakeKernelWithTable, TakeTable) {
std::vector<std::shared_ptr<Field>> fields = {field("a", int32()),
field("b", utf8())};
auto schm = schema(fields);
@@ -2022,11 +2200,12 @@ TEST_F(TestTakeKernelWithTable, TakeTable) {
"[{\"a\": null, \"b\": \"yo\"},{\"a\": 1, \"b\": \"\"}]",
"[{\"a\": 2, \"b\": \"hello\"},{\"a\": 4, \"b\": \"eh\"}]"};
- this->AssertTake(schm, table_json, "[]", {"[]"});
+ AssertTakeTAT(schm, table_json, "[]", {"[]"});
std::vector<std::string> expected_310 = {
- "[{\"a\": 4, \"b\": \"eh\"},{\"a\": 1, \"b\": \"\"},{\"a\": null, \"b\":
\"yo\"}]"};
- this->AssertTake(schm, table_json, "[3, 1, 0]", expected_310);
- this->AssertChunkedTake(schm, table_json, {"[0, 1]", "[2, 3]"}, table_json);
+ "[{\"a\": 4, \"b\": \"eh\"},{\"a\": 1, \"b\": \"\"},{\"a\": null, \"b\":
"
+ "\"yo\"}]"};
+ AssertTakeTAT(schm, table_json, "[3, 1, 0]", expected_310);
+ AssertTakeTCT(schm, table_json, {"[0, 1]", "[2, 3]"}, table_json);
}
TEST(TestTakeMetaFunction, ArityChecking) {
@@ -2066,14 +2245,14 @@ void CheckTakeRandom(const std::shared_ptr<Array>&
values, int64_t indices_lengt
max_index, null_probability);
auto indices_no_nulls = rand->Numeric<IndexType>(
indices_length, static_cast<IndexCType>(0), max_index,
/*null_probability=*/0.0);
- ValidateTake<ValuesType>(values, indices);
- ValidateTake<ValuesType>(values, indices_no_nulls);
+ ValidateTakeXA<ValuesType>(values, indices);
+ ValidateTakeXA<ValuesType>(values, indices_no_nulls);
// Sliced indices array
if (indices_length >= 2) {
indices = indices->Slice(1, indices_length - 2);
indices_no_nulls = indices_no_nulls->Slice(1, indices_length - 2);
- ValidateTake<ValuesType>(values, indices);
- ValidateTake<ValuesType>(values, indices_no_nulls);
+ ValidateTakeXA<ValuesType>(values, indices);
+ ValidateTakeXA<ValuesType>(values, indices_no_nulls);
}
}