pitrou commented on code in PR #36825:
URL: https://github.com/apache/arrow/pull/36825#discussion_r1272117415


##########
cpp/src/arrow/compute/kernels/scalar_if_else_test.cc:
##########
@@ -126,75 +126,105 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeRand) {
   CheckIfElseOutput(cond, left, right, expected_data);
 }
 
-void CheckWithDifferentShapes(const std::shared_ptr<Array>& cond,
-                              const std::shared_ptr<Array>& left,
-                              const std::shared_ptr<Array>& right,
-                              const std::shared_ptr<Array>& expected) {
-  // this will check for whole arrays, every scalar at i'th index and slicing (offset)
-  CheckScalar("if_else", {cond, left, right}, expected);
+Datum ArrayOrBroadcastScalar(const Datum& input, int64_t length) {
+  if (input.is_scalar()) {
+    EXPECT_OK_AND_ASSIGN(auto array, MakeArrayFromScalar(*input.scalar(), length));
+    return array;
+  }
+  EXPECT_TRUE(input.is_array());
+  return input;
+}
 
-  auto len = left->length();
-  std::vector<int64_t> array_indices = {-1};  // sentinel for make_input
-  std::vector<int64_t> scalar_indices(len);
-  std::iota(scalar_indices.begin(), scalar_indices.end(), 0);
-  auto make_input = [&](const std::shared_ptr<Array>& array, int64_t index, Datum* input,
-                        Datum* input_broadcast, std::string* trace) {
-    if (index >= 0) {
-      // Use scalar from array[index] as input; broadcast scalar for computing expected
-      // result
-      ASSERT_OK_AND_ASSIGN(auto scalar, array->GetScalar(index));
-      *trace += "@" + std::to_string(index) + "=" + scalar->ToString();
-      *input = std::move(scalar);
-      ASSERT_OK_AND_ASSIGN(*input_broadcast, MakeArrayFromScalar(*input->scalar(), len));
+Result<Datum> ExpectedFromIfElse(const Datum& cond, const Datum& left, const Datum& right,
+                                 std::shared_ptr<DataType> type) {
+  if (cond.is_scalar() && left.is_scalar() && right.is_scalar()) {
+    const auto& scalar = cond.scalar_as<BooleanScalar>();
+    Datum expected;
+    if (scalar.is_valid) {
+      expected = scalar.value ? left : right;
     } else {
-      // Use array as input
-      *trace += "=Array";
-      *input = *input_broadcast = array;
+      expected = MakeNullScalar(left.type());
     }
-  };
-
-  enum { COND_SCALAR = 1, LEFT_SCALAR = 2, RIGHT_SCALAR = 4 };
-  for (int mask = 1; mask <= (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR); ++mask) {
-    for (int64_t cond_idx : (mask & COND_SCALAR) ? scalar_indices : array_indices) {
-      Datum cond_in, cond_bcast;
-      std::string trace_cond = "Cond";
-      make_input(cond, cond_idx, &cond_in, &cond_bcast, &trace_cond);
-
-      for (int64_t left_idx : (mask & LEFT_SCALAR) ? scalar_indices : array_indices) {
-        Datum left_in, left_bcast;
-        std::string trace_left = "Left";
-        make_input(left, left_idx, &left_in, &left_bcast, &trace_left);
-
-        for (int64_t right_idx : (mask & RIGHT_SCALAR) ? scalar_indices : array_indices) {
-          Datum right_in, right_bcast;
-          std::string trace_right = "Right";
-          make_input(right, right_idx, &right_in, &right_bcast, &trace_right);
-
-          SCOPED_TRACE(trace_right);
-          SCOPED_TRACE(trace_left);
-          SCOPED_TRACE(trace_cond);
-
-          Datum expected;
-          ASSERT_OK_AND_ASSIGN(auto actual, IfElse(cond_in, left_in, right_in));
-          if (mask == (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR)) {
-            const auto& scalar = cond_in.scalar_as<BooleanScalar>();
-            if (scalar.is_valid) {
-              expected = scalar.value ? left_in : right_in;
-            } else {
-              expected = MakeNullScalar(left_in.type());
-            }
-            if (!left_in.type()->Equals(*right_in.type())) {
-              ASSERT_OK_AND_ASSIGN(expected,
-                                   Cast(expected, CastOptions::Safe(actual.type())));
-            }
-          } else {
-            ASSERT_OK_AND_ASSIGN(expected, IfElse(cond_bcast, left_bcast, right_bcast));
-          }
-          AssertDatumsEqual(expected, actual, /*verbose=*/true);
+    if (!left.type()->Equals(*right.type())) {
+      return Cast(expected, CastOptions::Safe(std::move(type)));
+    }
+    return expected;
+  }
+  // When at least one of the inputs is an array, we expect the output
+  // to be the same as if all the scalars were broadcast to arrays.
+  const auto expected_length =
+      std::max(cond.length(), std::max(left.length(), right.length()));
+  SCOPED_TRACE("IfElseAAACall");
+  return IfElse(ArrayOrBroadcastScalar(cond, expected_length),
+                ArrayOrBroadcastScalar(left, expected_length),
+                ArrayOrBroadcastScalar(right, expected_length));
+}
+
+bool NextScalarOrWholeArray(const std::shared_ptr<Array>& array, int* index, Datum* out) {
+  if (*index < array->length()) {
+    EXPECT_OK_AND_ASSIGN(auto scalar, array->GetScalar(*index));
+    *out = std::move(scalar);
+    *index += 1;
+    return true;
+  }
+  *out = array;
+  return false;
+}
+
+std::string CodedCallName(const Datum& cond, const Datum& left, const Datum& right) {
+  std::string coded = "IfElse";
+  coded += cond.is_scalar() ? "S" : "A";
+  coded += left.is_scalar() ? "S" : "A";
+  coded += right.is_scalar() ? "S" : "A";
+  coded += "Call";
+  return coded;
+}
+
+void DoCheckWithDifferentShapes(const std::shared_ptr<Array>& cond,
+                                const std::shared_ptr<Array>& left,
+                                const std::shared_ptr<Array>& right) {
+  auto make_trace([&](const char* name, const Datum& datum, int index) {
+    std::string trace = name;
+    trace += " : ";
+    if (datum.is_scalar()) {
+      trace += "Scalar@" + std::to_string(index) + " = " + datum.scalar()->ToString();
+    } else {
+      EXPECT_TRUE(datum.is_array());
+      trace += "Array = [...]";
+    }
+    return trace;
+  });
+  Datum cond_in;
+  Datum left_in;
+  Datum right_in;
+  int cond_index = 0;
+  int left_index = 0;
+  int right_index = 0;
+  while (NextScalarOrWholeArray(cond, &cond_index, &cond_in)) {

Review Comment:
   `NextScalarOrWholeArray` returns false when it fills an array in its last argument, so the loop body never runs in that case and the whole-array input is never exercised. I suppose that's not intended?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to