felipecrv commented on code in PR #43292:
URL: https://github.com/apache/arrow/pull/43292#discussion_r1682891645


##########
cpp/src/arrow/compute/kernels/vector_selection_test.cc:
##########
@@ -1101,67 +1102,277 @@ TEST(TestFilterMetaFunction, ArityChecking) {
 
 // ----------------------------------------------------------------------
 // Take tests
+//
+// Shorthand notation (as defined in `TakeMetaFunction`):
+//
+//   A = Array
+//   C = ChunkedArray
+//   R = RecordBatch
+//   T = Table
+//
+// (e.g. TakeCAC = Take(ChunkedArray, Array) -> ChunkedArray)
+//
+// The interface implemented by `TakeMetaFunction` is:
+//
+//   Take(A, A) -> A  (TakeAAA)
+//   Take(A, C) -> C  (TakeACC)
+//   Take(C, A) -> C  (TakeCAC)
+//   Take(C, C) -> C  (TakeCCC)
+//   Take(R, A) -> R  (TakeRAR)
+//   Take(T, A) -> T  (TakeTAT)
+//   Take(T, C) -> T  (TakeTCT)
+//
+// The tests extend the notation with a few "union kinds":
+//
+//   X = Array | ChunkedArray
+//
+// Examples:
+//
+//   TakeXA = {TakeAAA, TakeCAC},
+//   TakeXX = {TakeAAA, TakeACC, TakeCAC, TakeCCC}
 
-void AssertTakeArrays(const std::shared_ptr<Array>& values,
-                      const std::shared_ptr<Array>& indices,
-                      const std::shared_ptr<Array>& expected) {
-  ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> actual, Take(*values, *indices));
-  ValidateOutput(actual);
-  AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+Result<std::shared_ptr<Array>> TakeAAA(const Array& values, const Array& 
indices) {
+  return Take(values, indices);
 }
 
-Status TakeJSON(const std::shared_ptr<DataType>& type, const std::string& 
values,
-                const std::shared_ptr<DataType>& index_type, const 
std::string& indices,
-                std::shared_ptr<Array>* out) {
-  return Take(*ArrayFromJSON(type, values), *ArrayFromJSON(index_type, 
indices))
+Status TakeAAA(const std::shared_ptr<DataType>& type, const std::string& 
values,
+               const std::shared_ptr<DataType>& index_type, const std::string& 
indices,
+               std::shared_ptr<Array>* out) {
+  return TakeAAA(*ArrayFromJSON(type, values), *ArrayFromJSON(index_type, 
indices))
       .Value(out);
 }
 
-void DoCheckTake(const std::shared_ptr<Array>& values,
-                 const std::shared_ptr<Array>& indices,
-                 const std::shared_ptr<Array>& expected) {
-  AssertTakeArrays(values, indices, expected);
+Result<Datum> TakeACC(const std::shared_ptr<Array>& values,
+                      const std::shared_ptr<ChunkedArray>& indices) {
+  return Take(values, indices);
+}
+
+Result<Datum> TakeCAC(std::shared_ptr<ChunkedArray> values,
+                      std::shared_ptr<Array> indices) {
+  return Take(values, indices);
+}
+
+Status TakeCAC(const std::shared_ptr<DataType>& type,
+               const std::vector<std::string>& values, const std::string& 
indices,
+               std::shared_ptr<ChunkedArray>* out) {
+  ARROW_ASSIGN_OR_RAISE(Datum result, TakeCAC(ChunkedArrayFromJSON(type, 
values),
+                                              ArrayFromJSON(int8(), indices)));
+  *out = result.chunked_array();
+  return Status::OK();
+}
+
+Result<Datum> TakeCCC(const std::shared_ptr<ChunkedArray>& values,
+                      const std::shared_ptr<ChunkedArray>& indices) {
+  return Take(values, indices);
+}
+
+Status TakeCCC(const std::shared_ptr<DataType>& type,
+               const std::vector<std::string>& values,
+               const std::vector<std::string>& indices,
+               std::shared_ptr<ChunkedArray>* out) {
+  ARROW_ASSIGN_OR_RAISE(Datum result, Take(ChunkedArrayFromJSON(type, values),
+                                           ChunkedArrayFromJSON(int8(), 
indices)));
+  *out = result.chunked_array();
+  return Status::OK();
+}
+
+Status TakeRAR(const std::shared_ptr<Schema>& schm, const std::string& 
batch_json,
+               const std::shared_ptr<DataType>& index_type, const std::string& 
indices,
+               std::shared_ptr<RecordBatch>* out) {
+  auto batch = RecordBatchFromJSON(schm, batch_json);
+  ARROW_ASSIGN_OR_RAISE(Datum result,
+                        Take(Datum(batch), Datum(ArrayFromJSON(index_type, 
indices))));
+  *out = result.record_batch();
+  return Status::OK();
+}
+
+Status TakeTAT(const std::shared_ptr<Schema>& schm,
+               const std::vector<std::string>& values, const std::string& 
indices,
+               std::shared_ptr<Table>* out) {
+  ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(TableFromJSON(schm, values)),
+                                           Datum(ArrayFromJSON(int8(), 
indices))));
+  *out = result.table();
+  return Status::OK();
+}
+
+Status TakeTCT(const std::shared_ptr<Schema>& schm,
+               const std::vector<std::string>& values,
+               const std::vector<std::string>& indices, 
std::shared_ptr<Table>* out) {
+  ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(TableFromJSON(schm, values)),
+                                           Datum(ChunkedArrayFromJSON(int8(), 
indices))));
+  *out = result.table();
+  return Status::OK();
+}
+
+// Assert helpers for Take tests
+
+void DoAssertTakeAAA(const std::shared_ptr<Array>& values,
+                     const std::shared_ptr<Array>& indices,
+                     const std::shared_ptr<Array>& expected) {
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> actual, TakeAAA(*values, 
*indices));
+  ValidateOutput(actual);
+  AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+}
+
+void DoCheckTakeAAA(const std::shared_ptr<Array>& values,
+                    const std::shared_ptr<Array>& indices,
+                    const std::shared_ptr<Array>& expected) {
+  DoAssertTakeAAA(values, indices, expected);
 
   // Check sliced values
   ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(values->type(), 2));
   ASSERT_OK_AND_ASSIGN(auto values_sliced,
                        Concatenate({values_filler, values, values_filler}));
   values_sliced = values_sliced->Slice(2, values->length());
-  AssertTakeArrays(values_sliced, indices, expected);
+  DoAssertTakeAAA(values_sliced, indices, expected);
 
   // Check sliced indices
   ASSERT_OK_AND_ASSIGN(auto zero, MakeScalar(indices->type(), int8_t{0}));
   ASSERT_OK_AND_ASSIGN(auto indices_filler, MakeArrayFromScalar(*zero, 3));
   ASSERT_OK_AND_ASSIGN(auto indices_sliced,
                        Concatenate({indices_filler, indices, indices_filler}));
   indices_sliced = indices_sliced->Slice(3, indices->length());
-  AssertTakeArrays(values, indices_sliced, expected);
+  DoAssertTakeAAA(values, indices_sliced, expected);
 }
 
-void CheckTake(const std::shared_ptr<DataType>& type, const std::string& 
values_json,
-               const std::string& indices_json, const std::string& 
expected_json) {
+void CheckTakeAAA(const std::shared_ptr<DataType>& type, const std::string& 
values_json,
+                  const std::string& indices_json, const std::string& 
expected_json) {
   auto values = ArrayFromJSON(type, values_json);
   auto expected = ArrayFromJSON(type, expected_json);
   for (auto index_type : {int8(), uint32()}) {
     auto indices = ArrayFromJSON(index_type, indices_json);
-    DoCheckTake(values, indices, expected);
+    DoCheckTakeAAA(values, indices, expected);
   }
 }
 
-void AssertTakeNull(const std::string& values, const std::string& indices,
-                    const std::string& expected) {
-  CheckTake(null(), values, indices, expected);
+// TakeXA = {TakeAAA, TakeCAC}
+void CheckTakeXA(const std::shared_ptr<Array>& values,
+                 const std::shared_ptr<Array>& indices,
+                 const std::shared_ptr<Array>& expected) {
+  auto pool = default_memory_pool();
+
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> actual, TakeAAA(*values, 
*indices));
+  ValidateOutput(actual);
+  AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+
+  // We check TakeCAC by checking this equality:
+  //
+  // TakeAAA(Concat(V, V, V), I') == Concat(TakeCAC([V, V, V], I'))
+  // where
+  //   V = values
+  //   I = indices
+  //   I' = Concat(I + 2 * V.length, I,  I + V.length)

Review Comment:
   Right. Can I add it as a second "exercise" here? Because if there is a bug 
that can be caught by the simpler one, I would rather debug the simpler one.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to