lidavidm commented on a change in pull request #10557:
URL: https://github.com/apache/arrow/pull/10557#discussion_r663098052



##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
##########
@@ -316,5 +318,165 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) {
   CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), 
int16()});
 }
 
+void CheckVarArgs(const std::string& name, const std::vector<Datum>& inputs,
+                  Datum expected) {
+  ASSERT_OK_AND_ASSIGN(Datum datum_out, CallFunction(name, inputs));
+  if (datum_out.is_array()) {
+    std::shared_ptr<Array> result = datum_out.make_array();
+    ASSERT_OK(result->ValidateFull());
+    std::shared_ptr<Array> expected_ = expected.make_array();
+    AssertArraysEqual(*expected_, *result, /*verbose=*/true);
+
+    for (int64_t i = 0; i < result->length(); i++) {
+      // Check scalar
+      ASSERT_OK_AND_ASSIGN(auto expected_scalar, expected_->GetScalar(i));
+      std::vector<Datum> inputs_scalar;
+      for (const auto& input : inputs) {
+        if (input.is_scalar()) {
+          inputs_scalar.push_back(input);
+        } else {
+          auto array = input.make_array();
+          ASSERT_OK_AND_ASSIGN(auto input_scalar, array->GetScalar(i));
+          inputs_scalar.push_back(input_scalar);
+        }
+      }
+      ASSERT_OK_AND_ASSIGN(auto scalar_out, CallFunction(name, inputs_scalar));
+      ASSERT_TRUE(scalar_out.is_scalar());
+      AssertScalarsEqual(*expected_scalar, *scalar_out.scalar(), 
/*verbose=*/true);
+
+      // Check slice
+      inputs_scalar.clear();
+      auto expected_array = expected_->Slice(i);
+      for (const auto& input : inputs) {
+        if (input.is_scalar()) {
+          inputs_scalar.push_back(input);
+        } else {
+          inputs_scalar.push_back(input.make_array()->Slice(i));
+        }
+      }
+      ASSERT_OK_AND_ASSIGN(auto array_out, CallFunction(name, inputs_scalar));
+      ASSERT_TRUE(array_out.is_array());
+      AssertArraysEqual(*expected_array, *array_out.make_array(), 
/*verbose=*/true);
+    }
+  } else {
+    const std::shared_ptr<Scalar>& result = datum_out.scalar();
+    const std::shared_ptr<Scalar>& expected_ = expected.scalar();
+    AssertScalarsEqual(*expected_, *result, /*verbose=*/true);
+  }
+}
+
+template <typename Type>
+class TestCaseWhenNumeric : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestCaseWhenNumeric, NumericBasedTypes);
+
+void CheckCaseWhenCases(const std::shared_ptr<DataType>& type, const 
std::string& value1,
+                        const std::string& value2) {
+  auto scalar_true = ScalarFromJSON(boolean(), "true");
+  auto scalar_false = ScalarFromJSON(boolean(), "false");
+  auto scalar_null = ScalarFromJSON(boolean(), "null");
+  auto cond1 = ArrayFromJSON(boolean(), "[true, false, false, null]");
+  auto cond2 = ArrayFromJSON(boolean(), "[true, false, null, true]");
+  auto value_null = ScalarFromJSON(type, "null");
+  auto scalar1 = ScalarFromJSON(type, value1);
+  auto scalar2 = ScalarFromJSON(type, value2);
+  auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+  std::stringstream builder;
+  builder << "[null, " << value1 << ',' << value1 << ',' << value1 << ']';
+  auto values1 = ArrayFromJSON(type, builder.str());
+  builder.str("");
+  builder << '[' << value2 << ',' << value2 << ',' << value2 << ',' << value2 
<< ']';
+  auto values2 = ArrayFromJSON(type, builder.str());
+  // N.B. all-scalar cases are checked in CheckCaseWhen
+  // Only an else array
+  CheckVarArgs("case_when", {values1}, values1);
+  // No else clause, scalar cond, array values
+  CheckVarArgs("case_when", {scalar_true, values1}, values1);
+  CheckVarArgs("case_when", {scalar_false, values1}, values_null);
+  CheckVarArgs("case_when", {scalar_null, values1}, values_null);
+  CheckVarArgs("case_when", {scalar_true, values1, scalar_null, values1}, 
values1);
+  CheckVarArgs("case_when", {scalar_null, values2, scalar_true, values1}, 
values1);
+  CheckVarArgs("case_when", {scalar_true, values1, scalar_true, values2}, 
values1);
+  // No else clause, array cond, scalar values
+  builder.str("");
+  builder << '[' << value1 << ", null, null, null]";
+  CheckVarArgs("case_when", {cond1, scalar1}, ArrayFromJSON(type, 
builder.str()));
+  CheckVarArgs("case_when", {cond1, value_null}, values_null);
+  builder.str("");
+  builder << '[' << value1 << ", null, null, " << value2 << ']';
+  CheckVarArgs("case_when", {cond1, scalar1, cond2, scalar2},
+               ArrayFromJSON(type, builder.str()));
+  // No else clause, array cond, array values
+  builder.str("");
+  builder << "[null, null, null, " << value2 << ']';
+  CheckVarArgs("case_when", {cond1, values1, cond2, values2},
+               ArrayFromJSON(type, builder.str()));
+  // Else clauses/mixed scalar and array
+  builder.str("");
+  builder << "[null, " << value1 << ',' << value1 << ',' << value2 << ']';
+  CheckVarArgs("case_when", {cond1, values1, cond2, values2, scalar1},
+               ArrayFromJSON(type, builder.str()));
+  CheckVarArgs("case_when", {cond1, values1, cond2, values2, values1},
+               ArrayFromJSON(type, builder.str()));
+}
+
+TYPED_TEST(TestCaseWhenNumeric, FixedSize) {
+  auto type = default_type_instance<TypeParam>();
+  CheckCaseWhenCases(type, "10", "42");

Review comment:
       I'll inline all tests as I've done with coalesce and choose.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to