This is an automated email from the ASF dual-hosted git repository.

zanmato pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8ccdbe7806 GH-41336: [C++][Compute] Fix case_when kernel dispatch for 
decimals with different precisions and scales (#47479)
8ccdbe7806 is described below

commit 8ccdbe78063ad4b43872b8826aba37a1a73dc951
Author: Rossi Sun <zanmato1...@gmail.com>
AuthorDate: Wed Sep 3 18:14:07 2025 +0800

    GH-41336: [C++][Compute] Fix case_when kernel dispatch for decimals with 
different precisions and scales (#47479)
    
    ### Rationale for this change
    
    This is another case of a decimal kernel being unable to suppress exact 
matching when the precisions and scales of its arguments differ, which causes a 
wrong result type. Since #47297, we have a systematic way to suppress the exact 
match and guide dispatch to the "best match" (applying implicit casts).
    
    ### What changes are included in this PR?
    
    Added a match constraint that checks whether the precisions and scales of 
the decimal arguments are the same. Also added corresponding tests in the form 
of both expressions (exact match first, then best match) and function calls 
(best match only). In addition, the free function `MakeConstraint` is now the 
static member function `MatchConstraint::Make`.
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    None.
    * GitHub Issue: #41336
    
    Authored-by: Rossi Sun <zanmato1...@gmail.com>
    Signed-off-by: Rossi Sun <zanmato1...@gmail.com>
---
 cpp/src/arrow/compute/expression_test.cc           |  45 ++++++++
 cpp/src/arrow/compute/kernel.cc                    |   2 +-
 cpp/src/arrow/compute/kernel.h                     |   8 +-
 cpp/src/arrow/compute/kernel_test.cc               |   4 +-
 cpp/src/arrow/compute/kernels/scalar_if_else.cc    |  25 ++++-
 .../arrow/compute/kernels/scalar_if_else_test.cc   | 116 +++++++++++++++++++++
 6 files changed, 189 insertions(+), 11 deletions(-)

diff --git a/cpp/src/arrow/compute/expression_test.cc 
b/cpp/src/arrow/compute/expression_test.cc
index ce5cf3826b..7e90f552ce 100644
--- a/cpp/src/arrow/compute/expression_test.cc
+++ b/cpp/src/arrow/compute/expression_test.cc
@@ -809,6 +809,51 @@ TEST(Expression, BindWithImplicitCasts) {
                 call("is_in", {cast(field_ref("dict_str"), utf8())}, in_a));
 }
 
+TEST(Expression, BindWithImplicitCastsForCaseWhenOnDecimal) {
+  auto exciting_schema = schema(
+      {field("a", struct_({field("", boolean())})),
+       field("dec128_20_3", decimal128(20, 3)), field("dec128_21_3", 
decimal128(21, 3)),
+       field("dec128_20_1", decimal128(20, 1)), field("dec128_21_1", 
decimal128(21, 1)),
+       field("dec256_20_3", decimal256(20, 3)), field("dec256_21_3", 
decimal256(21, 3)),
+       field("dec256_20_1", decimal256(20, 1)), field("dec256_21_1", 
decimal256(21, 1))});
+  ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
+                                   field_ref("dec128_21_3")}),
+                call("case_when",
+                     {field_ref("a"), cast(field_ref("dec128_20_3"), 
decimal128(21, 3)),
+                      field_ref("dec128_21_3")}),
+                /*bound_out=*/nullptr, *exciting_schema);
+  ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_1"),
+                                   field_ref("dec128_21_3")}),
+                call("case_when",
+                     {field_ref("a"), cast(field_ref("dec128_20_1"), 
decimal128(22, 3)),
+                      cast(field_ref("dec128_21_3"), decimal128(22, 3))}),
+                /*bound_out=*/nullptr, *exciting_schema);
+  ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
+                                   field_ref("dec128_21_1")}),
+                call("case_when",
+                     {field_ref("a"), cast(field_ref("dec128_20_3"), 
decimal128(23, 3)),
+                      cast(field_ref("dec128_21_1"), decimal128(23, 3))}),
+                /*bound_out=*/nullptr, *exciting_schema);
+  ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec128_20_3"),
+                                   field_ref("dec256_21_3")}),
+                call("case_when",
+                     {field_ref("a"), cast(field_ref("dec128_20_3"), 
decimal256(21, 3)),
+                      field_ref("dec256_21_3")}),
+                /*bound_out=*/nullptr, *exciting_schema);
+  ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec256_20_1"),
+                                   field_ref("dec128_21_3")}),
+                call("case_when",
+                     {field_ref("a"), cast(field_ref("dec256_20_1"), 
decimal256(22, 3)),
+                      cast(field_ref("dec128_21_3"), decimal256(22, 3))}),
+                /*bound_out=*/nullptr, *exciting_schema);
+  ExpectBindsTo(call("case_when", {field_ref("a"), field_ref("dec256_20_3"),
+                                   field_ref("dec256_21_1")}),
+                call("case_when",
+                     {field_ref("a"), cast(field_ref("dec256_20_3"), 
decimal256(23, 3)),
+                      cast(field_ref("dec256_21_1"), decimal256(23, 3))}),
+                /*bound_out=*/nullptr, *exciting_schema);
+}
+
 TEST(Expression, BindNestedCall) {
   auto expr = add(field_ref("a"),
                   call("subtract", {call("multiply", {field_ref("b"), 
field_ref("c")}),
diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc
index 17f583c75f..addbb29edd 100644
--- a/cpp/src/arrow/compute/kernel.cc
+++ b/cpp/src/arrow/compute/kernel.cc
@@ -478,7 +478,7 @@ std::string OutputType::ToString() const {
 // ----------------------------------------------------------------------
 // MatchConstraint
 
-std::shared_ptr<MatchConstraint> MakeConstraint(
+std::shared_ptr<MatchConstraint> MatchConstraint::Make(
     std::function<bool(const std::vector<TypeHolder>&)> matches) {
   class FunctionMatchConstraint : public MatchConstraint {
    public:
diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h
index fa2e983469..0d4f9d6ff4 100644
--- a/cpp/src/arrow/compute/kernel.h
+++ b/cpp/src/arrow/compute/kernel.h
@@ -356,11 +356,11 @@ class ARROW_EXPORT MatchConstraint {
 
   /// \brief Return true if the input types satisfy the constraint.
   virtual bool Matches(const std::vector<TypeHolder>& types) const = 0;
-};
 
-/// \brief Convenience function to create a MatchConstraint from a match 
function.
-ARROW_EXPORT std::shared_ptr<MatchConstraint> MakeConstraint(
-    std::function<bool(const std::vector<TypeHolder>&)> matches);
+  /// \brief Convenience function to create a MatchConstraint from a match 
function.
+  static std::shared_ptr<MatchConstraint> Make(
+      std::function<bool(const std::vector<TypeHolder>&)> matches);
+};
 
 /// \brief Constraint that all input types are decimal types and have the same 
scale.
 ARROW_EXPORT std::shared_ptr<MatchConstraint> DecimalsHaveSameScale();
diff --git a/cpp/src/arrow/compute/kernel_test.cc 
b/cpp/src/arrow/compute/kernel_test.cc
index 374d1458f4..9317ae7a42 100644
--- a/cpp/src/arrow/compute/kernel_test.cc
+++ b/cpp/src/arrow/compute/kernel_test.cc
@@ -313,7 +313,7 @@ TEST(OutputType, Resolve) {
 TEST(MatchConstraint, ConvenienceMaker) {
   {
     auto always_match =
-        MakeConstraint([](const std::vector<TypeHolder>& types) { return true; 
});
+        MatchConstraint::Make([](const std::vector<TypeHolder>& types) { 
return true; });
 
     ASSERT_TRUE(always_match->Matches({}));
     ASSERT_TRUE(always_match->Matches({int8(), int16(), int32()}));
@@ -321,7 +321,7 @@ TEST(MatchConstraint, ConvenienceMaker) {
 
   {
     auto always_false =
-        MakeConstraint([](const std::vector<TypeHolder>& types) { return 
false; });
+        MatchConstraint::Make([](const std::vector<TypeHolder>& types) { 
return false; });
 
     ASSERT_FALSE(always_false->Matches({}));
     ASSERT_FALSE(always_false->Matches({int8(), int16(), int32()}));
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc 
b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index 753cc4de9f..d885db4cd9 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -1451,6 +1451,20 @@ struct CaseWhenFunction : ScalarFunction {
     if (auto kernel = DispatchExactImpl(this, *types)) return kernel;
     return arrow::compute::detail::NoMatchingKernel(this, *types);
   }
+
+  static std::shared_ptr<MatchConstraint> DecimalMatchConstraint() {
+    static auto constraint =
+        MatchConstraint::Make([](const std::vector<TypeHolder>& types) -> bool 
{
+          DCHECK_GE(types.size(), 2);
+          DCHECK(std::all_of(types.begin() + 1, types.end(), [](const 
TypeHolder& type) {
+            return is_decimal(type.id());
+          }));
+          return std::all_of(
+              types.begin() + 2, types.end(),
+              [&types](const TypeHolder& type) { return type == types[1]; });
+        });
+    return constraint;
+  }
 };
 
 // Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar 
conditions
@@ -2712,10 +2726,11 @@ struct ChooseFunction : ScalarFunction {
 };
 
 void AddCaseWhenKernel(const std::shared_ptr<CaseWhenFunction>& 
scalar_function,
-                       detail::GetTypeId get_id, ArrayKernelExec exec) {
+                       detail::GetTypeId get_id, ArrayKernelExec exec,
+                       std::shared_ptr<MatchConstraint> constraint = nullptr) {
   ScalarKernel kernel(
       KernelSignature::Make({InputType(Type::STRUCT), InputType(get_id.id)}, 
LastType,
-                            /*is_varargs=*/true),
+                            /*is_varargs=*/true, std::move(constraint)),
       exec);
   if (is_fixed_width(get_id.id)) {
     kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
@@ -2890,8 +2905,10 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
     AddPrimitiveCaseWhenKernels(func, {boolean(), null(), float16()});
     AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY,
                       CaseWhenFunctor<FixedSizeBinaryType>::Exec);
-    AddCaseWhenKernel(func, Type::DECIMAL128, 
CaseWhenFunctor<FixedSizeBinaryType>::Exec);
-    AddCaseWhenKernel(func, Type::DECIMAL256, 
CaseWhenFunctor<FixedSizeBinaryType>::Exec);
+    AddCaseWhenKernel(func, Type::DECIMAL128, 
CaseWhenFunctor<FixedSizeBinaryType>::Exec,
+                      CaseWhenFunction::DecimalMatchConstraint());
+    AddCaseWhenKernel(func, Type::DECIMAL256, 
CaseWhenFunctor<FixedSizeBinaryType>::Exec,
+                      CaseWhenFunction::DecimalMatchConstraint());
     AddBinaryCaseWhenKernels(func, BaseBinaryTypes());
     AddNestedCaseWhenKernels(func);
     DCHECK_OK(registry->AddFunction(std::move(func)));
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc 
b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index 2ff11dab43..e007a16d13 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -1807,6 +1807,74 @@ TEST(TestCaseWhen, Decimal) {
   }
 }
 
+TEST(TestCaseWhen, DecimalPromotion) {
+  auto check_case_when_decimal_promotion =
+      [](std::shared_ptr<Scalar> body_true, std::shared_ptr<Scalar> body_false,
+         std::shared_ptr<Scalar> promoted_true, std::shared_ptr<Scalar> 
promoted_false) {
+        auto cond_true = ScalarFromJSON(boolean(), "true");
+        auto cond_false = ScalarFromJSON(boolean(), "false");
+        CheckScalar("case_when", {MakeStruct({cond_true}), body_true, 
body_false},
+                    promoted_true);
+        CheckScalar("case_when", {MakeStruct({cond_false}), body_true, 
body_false},
+                    promoted_false);
+      };
+
+  const std::vector<std::pair<int, int>> precisions = {{10, 20}, {15, 15}, 
{20, 10}};
+  const std::vector<std::pair<int, int>> scales = {{3, 9}, {6, 6}, {9, 3}};
+  for (auto p : precisions) {
+    for (auto s : scales) {
+      auto p1 = p.first;
+      auto s1 = s.first;
+      auto p2 = p.second;
+      auto s2 = s.second;
+
+      auto max_scale = std::max({s1, s2});
+      auto scale_up_1 = max_scale - s1;
+      auto scale_up_2 = max_scale - s2;
+      auto max_precision = std::max({p1 + scale_up_1, p2 + scale_up_2});
+
+      // Operand string: 444.777...
+      std::string str_d1 =
+          R"(")" + std::string(p1 - s1, '4') + "." + std::string(s1, '7') + 
R"(")";
+      std::string str_d2 =
+          R"(")" + std::string(p2 - s2, '4') + "." + std::string(s2, '7') + 
R"(")";
+
+      // Promoted string: 444.777...000
+      std::string str_d1_promoted = R"(")" + std::string(p1 - s1, '4') + "." +
+                                    std::string(s1, '7') +
+                                    std::string(max_scale - s1, '0') + R"(")";
+      std::string str_d2_promoted = R"(")" + std::string(p2 - s2, '4') + "." +
+                                    std::string(s2, '7') +
+                                    std::string(max_scale - s2, '0') + R"(")";
+
+      auto d128_1 = decimal128(p1, s1);
+      auto d128_2 = decimal128(p2, s2);
+      auto d256_1 = decimal256(p1, s1);
+      auto d256_2 = decimal256(p2, s2);
+      auto d128_promoted = decimal128(max_precision, max_scale);
+      auto d256_promoted = decimal256(max_precision, max_scale);
+
+      auto scalar128_1 = ScalarFromJSON(d128_1, str_d1);
+      auto scalar128_2 = ScalarFromJSON(d128_2, str_d2);
+      auto scalar256_1 = ScalarFromJSON(d256_1, str_d1);
+      auto scalar256_2 = ScalarFromJSON(d256_2, str_d2);
+      auto scalar128_d1_promoted = ScalarFromJSON(d128_promoted, 
str_d1_promoted);
+      auto scalar128_d2_promoted = ScalarFromJSON(d128_promoted, 
str_d2_promoted);
+      auto scalar256_d1_promoted = ScalarFromJSON(d256_promoted, 
str_d1_promoted);
+      auto scalar256_d2_promoted = ScalarFromJSON(d256_promoted, 
str_d2_promoted);
+
+      check_case_when_decimal_promotion(scalar128_1, scalar128_2, 
scalar128_d1_promoted,
+                                        scalar128_d2_promoted);
+      check_case_when_decimal_promotion(scalar128_1, scalar256_2, 
scalar256_d1_promoted,
+                                        scalar256_d2_promoted);
+      check_case_when_decimal_promotion(scalar256_1, scalar128_2, 
scalar256_d1_promoted,
+                                        scalar256_d2_promoted);
+      check_case_when_decimal_promotion(scalar256_1, scalar256_2, 
scalar256_d1_promoted,
+                                        scalar256_d2_promoted);
+    }
+  }
+}
+
 TEST(TestCaseWhen, FixedSizeBinary) {
   auto type = fixed_size_binary(3);
   auto cond_true = ScalarFromJSON(boolean(), "true");
@@ -2509,6 +2577,28 @@ TEST(TestCaseWhen, UnionBoolStringRandom) {
   }
 }
 
+TEST(TestCaseWhen, DispatchExact) {
+  // Decimal types with same (p, s)
+  CheckDispatchExact("case_when", {struct_({field("", boolean())}), 
decimal128(20, 3),
+                                   decimal128(20, 3)});
+  CheckDispatchExact("case_when", {struct_({field("", boolean())}), 
decimal256(20, 3),
+                                   decimal256(20, 3)});
+
+  // Decimal types with different (p, s)
+  CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
+                                        decimal128(20, 3), decimal128(21, 3)});
+  CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
+                                        decimal128(20, 1), decimal128(20, 3)});
+  CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
+                                        decimal128(20, 3), decimal256(20, 3)});
+  CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
+                                        decimal256(20, 3), decimal128(21, 3)});
+  CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
+                                        decimal256(20, 3), decimal256(21, 3)});
+  CheckDispatchExactFails("case_when", {struct_({field("", boolean())}),
+                                        decimal256(20, 1), decimal256(20, 3)});
+}
+
 TEST(TestCaseWhen, DispatchBest) {
   CheckDispatchBest("case_when", {struct_({field("", boolean())}), int64(), 
int32()},
                     {struct_({field("", boolean())}), int64(), int64()});
@@ -2559,6 +2649,32 @@ TEST(TestCaseWhen, DispatchBest) {
   CheckDispatchBest(
       "case_when", {struct_({field("", boolean())}), dictionary(int64(), 
utf8()), utf8()},
       {struct_({field("", boolean())}), utf8(), utf8()});
+
+  // Decimal promotion
+  CheckDispatchBest(
+      "case_when",
+      {struct_({field("", boolean())}), decimal128(20, 3), decimal128(21, 3)},
+      {struct_({field("", boolean())}), decimal128(21, 3), decimal128(21, 3)});
+  CheckDispatchBest(
+      "case_when",
+      {struct_({field("", boolean())}), decimal128(20, 1), decimal128(21, 3)},
+      {struct_({field("", boolean())}), decimal128(22, 3), decimal128(22, 3)});
+  CheckDispatchBest(
+      "case_when",
+      {struct_({field("", boolean())}), decimal128(20, 3), decimal128(21, 1)},
+      {struct_({field("", boolean())}), decimal128(23, 3), decimal128(23, 3)});
+  CheckDispatchBest(
+      "case_when",
+      {struct_({field("", boolean())}), decimal128(20, 3), decimal256(21, 3)},
+      {struct_({field("", boolean())}), decimal256(21, 3), decimal256(21, 3)});
+  CheckDispatchBest(
+      "case_when",
+      {struct_({field("", boolean())}), decimal256(20, 1), decimal128(21, 3)},
+      {struct_({field("", boolean())}), decimal256(22, 3), decimal256(22, 3)});
+  CheckDispatchBest(
+      "case_when",
+      {struct_({field("", boolean())}), decimal256(20, 3), decimal256(21, 1)},
+      {struct_({field("", boolean())}), decimal256(23, 3), decimal256(23, 3)});
 }
 
 template <typename Type>

Reply via email to