pitrou commented on code in PR #43292:
URL: https://github.com/apache/arrow/pull/43292#discussion_r1682902369


##########
cpp/src/arrow/compute/kernels/vector_selection_test.cc:
##########
@@ -1101,67 +1102,277 @@ TEST(TestFilterMetaFunction, ArityChecking) {
 
 // ----------------------------------------------------------------------
 // Take tests
+//
+// Shorthand notation (as defined in `TakeMetaFunction`):
+//
+//   A = Array
+//   C = ChunkedArray
+//   R = RecordBatch
+//   T = Table
+//
+// (e.g. TakeCAC = Take(ChunkedArray, Array) -> ChunkedArray)
+//
+// The interface implemented by `TakeMetaFunction` is:
+//
+//   Take(A, A) -> A  (TakeAAA)
+//   Take(A, C) -> C  (TakeACC)
+//   Take(C, A) -> C  (TakeCAC)
+//   Take(C, C) -> C  (TakeCCC)
+//   Take(R, A) -> R  (TakeRAR)
+//   Take(T, A) -> T  (TakeTAT)
+//   Take(T, C) -> T  (TakeTCT)
+//
+// The tests extend the notation with a few "union kinds":
+//
+//   X = Array | ChunkedArray
+//
+// Examples:
+//
+//   TakeXA = {TakeAAA, TakeCAC},
+//   TakeXX = {TakeAAA, TakeACC, TakeCAC, TakeCCC}
 
-void AssertTakeArrays(const std::shared_ptr<Array>& values,
-                      const std::shared_ptr<Array>& indices,
-                      const std::shared_ptr<Array>& expected) {
-  ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> actual, Take(*values, *indices));
-  ValidateOutput(actual);
-  AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+Result<std::shared_ptr<Array>> TakeAAA(const Array& values, const Array& indices) {
+  return Take(values, indices);
 }
 
-Status TakeJSON(const std::shared_ptr<DataType>& type, const std::string& values,
-                const std::shared_ptr<DataType>& index_type, const std::string& indices,
-                std::shared_ptr<Array>* out) {
-  return Take(*ArrayFromJSON(type, values), *ArrayFromJSON(index_type, indices))
+Status TakeAAA(const std::shared_ptr<DataType>& type, const std::string& values,
+               const std::shared_ptr<DataType>& index_type, const std::string& indices,
+               std::shared_ptr<Array>* out) {
+  return TakeAAA(*ArrayFromJSON(type, values), *ArrayFromJSON(index_type, indices))
       .Value(out);
 }
 
-void DoCheckTake(const std::shared_ptr<Array>& values,
-                 const std::shared_ptr<Array>& indices,
-                 const std::shared_ptr<Array>& expected) {
-  AssertTakeArrays(values, indices, expected);
+Result<Datum> TakeACC(const std::shared_ptr<Array>& values,
+                      const std::shared_ptr<ChunkedArray>& indices) {
+  return Take(values, indices);
+}
+
+Result<Datum> TakeCAC(std::shared_ptr<ChunkedArray> values,
+                      std::shared_ptr<Array> indices) {
+  return Take(values, indices);
+}
+
+Status TakeCAC(const std::shared_ptr<DataType>& type,
+               const std::vector<std::string>& values, const std::string& indices,
+               std::shared_ptr<ChunkedArray>* out) {
+  ARROW_ASSIGN_OR_RAISE(Datum result, TakeCAC(ChunkedArrayFromJSON(type, values),
+                                              ArrayFromJSON(int8(), indices)));
+  *out = result.chunked_array();
+  return Status::OK();
+}
+
+Result<Datum> TakeCCC(const std::shared_ptr<ChunkedArray>& values,
+                      const std::shared_ptr<ChunkedArray>& indices) {
+  return Take(values, indices);
+}
+
+Status TakeCCC(const std::shared_ptr<DataType>& type,
+               const std::vector<std::string>& values,
+               const std::vector<std::string>& indices,
+               std::shared_ptr<ChunkedArray>* out) {
+  ARROW_ASSIGN_OR_RAISE(Datum result, Take(ChunkedArrayFromJSON(type, values),
+                                           ChunkedArrayFromJSON(int8(), indices)));
+  *out = result.chunked_array();
+  return Status::OK();
+}
+
+Status TakeRAR(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+               const std::shared_ptr<DataType>& index_type, const std::string& indices,
+               std::shared_ptr<RecordBatch>* out) {
+  auto batch = RecordBatchFromJSON(schm, batch_json);
+  ARROW_ASSIGN_OR_RAISE(Datum result,
+                        Take(Datum(batch), Datum(ArrayFromJSON(index_type, indices))));
+  *out = result.record_batch();
+  return Status::OK();
+}
+
+Status TakeTAT(const std::shared_ptr<Schema>& schm,
+               const std::vector<std::string>& values, const std::string& indices,
+               std::shared_ptr<Table>* out) {
+  ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(TableFromJSON(schm, values)),
+                                           Datum(ArrayFromJSON(int8(), indices))));
+  *out = result.table();
+  return Status::OK();
+}
+
+Status TakeTCT(const std::shared_ptr<Schema>& schm,
+               const std::vector<std::string>& values,
+               const std::vector<std::string>& indices, std::shared_ptr<Table>* out) {
+  ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(TableFromJSON(schm, values)),
+                                           Datum(ChunkedArrayFromJSON(int8(), indices))));
+  *out = result.table();
+  return Status::OK();
+}
+
+// Assert helpers for Take tests
+
+void DoAssertTakeAAA(const std::shared_ptr<Array>& values,
+                     const std::shared_ptr<Array>& indices,
+                     const std::shared_ptr<Array>& expected) {
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> actual, TakeAAA(*values, *indices));
+  ValidateOutput(actual);
+  AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+}
+
+void DoCheckTakeAAA(const std::shared_ptr<Array>& values,
+                    const std::shared_ptr<Array>& indices,
+                    const std::shared_ptr<Array>& expected) {
+  DoAssertTakeAAA(values, indices, expected);
 
   // Check sliced values
   ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(values->type(), 2));
   ASSERT_OK_AND_ASSIGN(auto values_sliced,
                        Concatenate({values_filler, values, values_filler}));
   values_sliced = values_sliced->Slice(2, values->length());
-  AssertTakeArrays(values_sliced, indices, expected);
+  DoAssertTakeAAA(values_sliced, indices, expected);
 
   // Check sliced indices
   ASSERT_OK_AND_ASSIGN(auto zero, MakeScalar(indices->type(), int8_t{0}));
   ASSERT_OK_AND_ASSIGN(auto indices_filler, MakeArrayFromScalar(*zero, 3));
   ASSERT_OK_AND_ASSIGN(auto indices_sliced,
                        Concatenate({indices_filler, indices, indices_filler}));
   indices_sliced = indices_sliced->Slice(3, indices->length());
-  AssertTakeArrays(values, indices_sliced, expected);
+  DoAssertTakeAAA(values, indices_sliced, expected);
 }
 
-void CheckTake(const std::shared_ptr<DataType>& type, const std::string& values_json,
-               const std::string& indices_json, const std::string& expected_json) {
+void CheckTakeAAA(const std::shared_ptr<DataType>& type, const std::string& values_json,
+                  const std::string& indices_json, const std::string& expected_json) {
   auto values = ArrayFromJSON(type, values_json);
   auto expected = ArrayFromJSON(type, expected_json);
   for (auto index_type : {int8(), uint32()}) {
     auto indices = ArrayFromJSON(index_type, indices_json);
-    DoCheckTake(values, indices, expected);
+    DoCheckTakeAAA(values, indices, expected);
   }
 }
 
-void AssertTakeNull(const std::string& values, const std::string& indices,
-                    const std::string& expected) {
-  CheckTake(null(), values, indices, expected);
+// TakeXA = {TakeAAA, TakeCAC}
+void CheckTakeXA(const std::shared_ptr<Array>& values,
+                 const std::shared_ptr<Array>& indices,
+                 const std::shared_ptr<Array>& expected) {
+  auto pool = default_memory_pool();
+
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> actual, TakeAAA(*values, *indices));
+  ValidateOutput(actual);
+  AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+
+  // We check TakeCAC by checking this equality:
+  //
+  // TakeAAA(Concat(V, V, V), I') == Concat(TakeCAC([V, V, V], I'))
+  // where
+  //   V = values
+  //   I = indices
+  //   I' = Concat(I + 2 * V.length, I,  I + V.length)
+  auto values3 = ArrayVector{values, values, values};
+  ASSERT_OK_AND_ASSIGN(auto concat_values3, Concatenate(values3, pool));
+  auto chunked_values3 = std::make_shared<ChunkedArray>(values3);
+  std::shared_ptr<Array> concat_indices3;
+  {
+    Int32Scalar double_length(static_cast<int32_t>(2 * values->length()));
+    Int32Scalar zero(static_cast<int32_t>(values->length()));
+    Int32Scalar length(static_cast<int32_t>(values->length()));
+    ASSERT_OK_AND_ASSIGN(auto indices_prefix, Add(indices, double_length));
+    ASSERT_OK_AND_ASSIGN(auto indices_middle, Add(indices, zero));
+    ASSERT_OK_AND_ASSIGN(auto indices_suffix, Add(indices, length));
+    auto indices3 = ArrayVector{
+        indices_prefix.make_array(),
+        indices_middle.make_array(),
+        indices_suffix.make_array(),
+    };
+    ASSERT_OK_AND_ASSIGN(concat_indices3, Concatenate(indices3, pool));
+  }
+  ASSERT_OK_AND_ASSIGN(auto concat_expected3,
+                       Concatenate({expected, expected, expected}));
+  ASSERT_OK_AND_ASSIGN(Datum chunked_actual, TakeCAC(chunked_values3, concat_indices3));
+  ValidateOutput(chunked_actual);
+  ASSERT_OK_AND_ASSIGN(auto concat_actual,
+                       Concatenate(chunked_actual.chunked_array()->chunks()));
+  AssertArraysEqual(*concat_expected3, *concat_actual, /*verbose=*/true);
+}
+
+void CheckTakeXADictionary(std::shared_ptr<DataType> value_type,
+                           const std::string& dictionary_values,
+                           const std::string& dictionary_indices,
+                           const std::string& indices,
+                           const std::string& expected_indices) {
+  auto dict = ArrayFromJSON(value_type, dictionary_values);
+  auto type = dictionary(int8(), value_type);
+  ASSERT_OK_AND_ASSIGN(
+      auto values,
+      DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), dictionary_indices), dict));
+  ASSERT_OK_AND_ASSIGN(
+      auto expected,
+      DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices), dict));
+  auto take_indices = ArrayFromJSON(int8(), indices);
+  CheckTakeXA(values, take_indices, expected);
+}
+
+void AssertTakeCAC(const std::shared_ptr<DataType>& type,
+                   const std::vector<std::string>& values, const std::string& indices,
+                   const std::vector<std::string>& expected) {
+  std::shared_ptr<ChunkedArray> actual;
+  ASSERT_OK(TakeCAC(type, values, indices, &actual));
+  ValidateOutput(actual);
+  AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual);
+}
+
+void AssertTakeCCC(const std::shared_ptr<DataType>& type,
+                   const std::vector<std::string>& values,
+                   const std::vector<std::string>& indices,
+                   const std::vector<std::string>& expected) {
+  std::shared_ptr<ChunkedArray> actual;
+  ASSERT_OK(TakeCCC(type, values, indices, &actual));
+  ValidateOutput(actual);
+  AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual);
+}
+
+void CheckTakeXCC(const Datum& values, const std::vector<std::string>& indices,
+                  const std::vector<std::string>& expected) {
+  EXPECT_TRUE(values.is_array() || values.is_chunked_array());
+  auto idx = ChunkedArrayFromJSON(int32(), indices);
+  ASSERT_OK_AND_ASSIGN(auto actual, Take(values, Datum{idx}));
+  ValidateOutput(actual);
+  AssertChunkedEqual(*ChunkedArrayFromJSON(values.type(), expected),
+                     *actual.chunked_array());
+}
+
+void AssertTakeRAR(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+                   const std::string& indices, const std::string& expected_batch) {
+  std::shared_ptr<RecordBatch> actual;
+
+  for (auto index_type : {int8(), uint32()}) {
+    ASSERT_OK(TakeRAR(schm, batch_json, index_type, indices, &actual));
+    ValidateOutput(actual);
+    ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual);
+  }
+}

Review Comment:
   Oops, sorry.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to