This is an automated email from the ASF dual-hosted git repository.

zanmato pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new dadc21f31b GH-47287: [C++][Compute] Add constraint for kernel signature matching and use it for binary decimal arithmetic kernels (#47297)
dadc21f31b is described below

commit dadc21f31b1d43cab4eb763f5e65691ce1f08fd5
Author: Rossi Sun <[email protected]>
AuthorDate: Wed Aug 20 23:20:54 2025 +0800

    GH-47287: [C++][Compute] Add constraint for kernel signature matching and use it for binary decimal arithmetic kernels (#47297)
    
    ### Rationale for this change
    
    A rework of #40223 using a more systematic alternative.
    
    ### What changes are included in this PR?
    
    Introduce a structure `MatchConstraint` for applying extra (and optional) matching constraints during kernel signature matching, in addition to the simple input type checks.
    
    Also implement two concrete `MatchConstraint`s for binary decimal arithmetic kernels. They suppress an exact match even when the input types are otherwise acceptable, for example by requiring all decimals to have the same scale for `add` and `subtract`, and requiring s1 >= s2 for `divide`.
    
    This should also provide a foundation for resolving similar issues, such as:
    * #35843
    * #39875
    * #40911
    * #41011
    * #41336
    (I haven't tried each of them yet; I may do so once this PR is merged.)
    
    ### Are these changes tested?
    
    UT included.
    
    ### Are there any user-facing changes?
    
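    New public class `MatchConstraint`, plus the public helper functions `MakeConstraint`, `DecimalsHaveSameScale` and `BinaryDecimalScale1GeScale2`. `ScalarFunction::AddKernel` and `KernelSignature` also accept a new optional `constraint` argument.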
    * GitHub Issue: #47287
    
    Authored-by: Rossi Sun <[email protected]>
    Signed-off-by: Rossi Sun <[email protected]>
---
 cpp/src/arrow/compute/function.cc                  |  7 +-
 cpp/src/arrow/compute/function.h                   |  3 +-
 cpp/src/arrow/compute/kernel.cc                    | 83 ++++++++++++++++++++--
 cpp/src/arrow/compute/kernel.h                     | 35 +++++++--
 cpp/src/arrow/compute/kernel_test.cc               | 80 +++++++++++++++++++++
 cpp/src/arrow/compute/kernels/scalar_arithmetic.cc | 17 ++---
 .../compute/kernels/scalar_arithmetic_test.cc      | 55 ++++++++++++++
 .../arrow/compute/kernels/test_util_internal.cc    | 10 +++
 cpp/src/arrow/compute/kernels/test_util_internal.h |  6 ++
 9 files changed, 271 insertions(+), 25 deletions(-)

diff --git a/cpp/src/arrow/compute/function.cc 
b/cpp/src/arrow/compute/function.cc
index d7842f14ef..b0b12a690f 100644
--- a/cpp/src/arrow/compute/function.cc
+++ b/cpp/src/arrow/compute/function.cc
@@ -410,14 +410,15 @@ Status Function::Validate() const {
 }
 
 Status ScalarFunction::AddKernel(std::vector<InputType> in_types, OutputType 
out_type,
-                                 ArrayKernelExec exec, KernelInit init) {
+                                 ArrayKernelExec exec, KernelInit init,
+                                 std::shared_ptr<MatchConstraint> constraint) {
   RETURN_NOT_OK(CheckArity(in_types.size()));
 
   if (arity_.is_varargs && in_types.size() != 1) {
     return Status::Invalid("VarArgs signatures must have exactly one input 
type");
   }
-  auto sig =
-      KernelSignature::Make(std::move(in_types), std::move(out_type), 
arity_.is_varargs);
+  auto sig = KernelSignature::Make(std::move(in_types), std::move(out_type),
+                                   arity_.is_varargs, std::move(constraint));
   kernels_.emplace_back(std::move(sig), exec, init);
   return Status::OK();
 }
diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h
index 2b86f64216..399081e2a7 100644
--- a/cpp/src/arrow/compute/function.h
+++ b/cpp/src/arrow/compute/function.h
@@ -308,7 +308,8 @@ class ARROW_EXPORT ScalarFunction : public 
detail::FunctionImpl<ScalarKernel> {
   /// initialization, preallocation for fixed-width types, and default null
   /// handling (intersect validity bitmaps of inputs).
   Status AddKernel(std::vector<InputType> in_types, OutputType out_type,
-                   ArrayKernelExec exec, KernelInit init = NULLPTR);
+                   ArrayKernelExec exec, KernelInit init = NULLPTR,
+                   std::shared_ptr<MatchConstraint> constraint = NULLPTR);
 
   /// \brief Add a kernel (function implementation). Returns error if the
   /// kernel's signature does not match the function's arity.
diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc
index 4211593adf..f7fecc9247 100644
--- a/cpp/src/arrow/compute/kernel.cc
+++ b/cpp/src/arrow/compute/kernel.cc
@@ -475,23 +475,93 @@ std::string OutputType::ToString() const {
   return "computed";
 }
 
+// ----------------------------------------------------------------------
+// MatchConstraint
+
+std::shared_ptr<MatchConstraint> MakeConstraint(
+    std::function<bool(const std::vector<TypeHolder>&)> matches) {
+  class FunctionMatchConstraint : public MatchConstraint {
+   public:
+    explicit FunctionMatchConstraint(
+        std::function<bool(const std::vector<TypeHolder>&)> matches)
+        : matches_(std::move(matches)) {}
+
+    bool Matches(const std::vector<TypeHolder>& types) const override {
+      return matches_(types);
+    }
+
+   private:
+    std::function<bool(const std::vector<TypeHolder>&)> matches_;
+  };
+
+  return std::make_shared<FunctionMatchConstraint>(std::move(matches));
+}
+
+std::shared_ptr<MatchConstraint> DecimalsHaveSameScale() {
+  class DecimalsHaveSameScaleConstraint : public MatchConstraint {
+   public:
+    bool Matches(const std::vector<TypeHolder>& types) const override {
+      DCHECK_GE(types.size(), 2);
+      DCHECK(std::all_of(types.begin(), types.end(),
+                         [](const TypeHolder& type) { return 
is_decimal(type.id()); }));
+      const auto& ty0 = checked_cast<const DecimalType&>(*types[0].type);
+      auto s0 = ty0.scale();
+      for (size_t i = 1; i < types.size(); ++i) {
+        const auto& ty = checked_cast<const DecimalType&>(*types[i].type);
+        if (ty.scale() != s0) {
+          return false;
+        }
+      }
+      return true;
+    }
+  };
+  static auto instance = std::make_shared<DecimalsHaveSameScaleConstraint>();
+  return instance;
+}
+
+namespace {
+
+template <typename Op>
+class BinaryDecimalScaleComparisonConstraint : public MatchConstraint {
+ public:
+  bool Matches(const std::vector<TypeHolder>& types) const override {
+    DCHECK_EQ(types.size(), 2);
+    DCHECK(is_decimal(types[0].id()));
+    DCHECK(is_decimal(types[1].id()));
+    const auto& ty0 = checked_cast<const DecimalType&>(*types[0].type);
+    const auto& ty1 = checked_cast<const DecimalType&>(*types[1].type);
+    return Op{}(ty0.scale(), ty1.scale());
+  }
+};
+
+}  // namespace
+
+std::shared_ptr<MatchConstraint> BinaryDecimalScale1GeScale2() {
+  using BinaryDecimalScale1GeScale2Constraint =
+      BinaryDecimalScaleComparisonConstraint<std::greater_equal<>>;
+  static auto instance = 
std::make_shared<BinaryDecimalScale1GeScale2Constraint>();
+  return instance;
+}
+
 // ----------------------------------------------------------------------
 // KernelSignature
 
 KernelSignature::KernelSignature(std::vector<InputType> in_types, OutputType 
out_type,
-                                 bool is_varargs)
+                                 bool is_varargs,
+                                 std::shared_ptr<MatchConstraint> constraint)
     : in_types_(std::move(in_types)),
       out_type_(std::move(out_type)),
       is_varargs_(is_varargs),
+      constraint_(std::move(constraint)),
       hash_code_(0) {
   DCHECK(!is_varargs || (is_varargs && (in_types_.size() >= 1)));
 }
 
-std::shared_ptr<KernelSignature> KernelSignature::Make(std::vector<InputType> 
in_types,
-                                                       OutputType out_type,
-                                                       bool is_varargs) {
+std::shared_ptr<KernelSignature> KernelSignature::Make(
+    std::vector<InputType> in_types, OutputType out_type, bool is_varargs,
+    std::shared_ptr<MatchConstraint> constraint) {
   return std::make_shared<KernelSignature>(std::move(in_types), 
std::move(out_type),
-                                           is_varargs);
+                                           is_varargs, std::move(constraint));
 }
 
 bool KernelSignature::Equals(const KernelSignature& other) const {
@@ -526,6 +596,9 @@ bool KernelSignature::MatchesInputs(const 
std::vector<TypeHolder>& types) const
       }
     }
   }
+  if (constraint_ && !constraint_->Matches(types)) {
+    return false;
+  }
   return true;
 }
 
diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h
index b429c40ac3..fdcdb134de 100644
--- a/cpp/src/arrow/compute/kernel.h
+++ b/cpp/src/arrow/compute/kernel.h
@@ -348,7 +348,28 @@ class ARROW_EXPORT OutputType {
   Resolver resolver_ = NULLPTR;
 };
 
-/// \brief Holds the input types and output type of the kernel.
+/// \brief Additional constraints to apply to the input types of a kernel when 
matching a
+/// specific kernel signature.
+class ARROW_EXPORT MatchConstraint {
+ public:
+  virtual ~MatchConstraint() = default;
+
+  /// \brief Return true if the input types satisfy the constraint.
+  virtual bool Matches(const std::vector<TypeHolder>& types) const = 0;
+};
+
+/// \brief Convenience function to create a MatchConstraint from a match 
function.
+ARROW_EXPORT std::shared_ptr<MatchConstraint> MakeConstraint(
+    std::function<bool(const std::vector<TypeHolder>&)> matches);
+
+/// \brief Constraint that all input types are decimal types and have the same 
scale.
+ARROW_EXPORT std::shared_ptr<MatchConstraint> DecimalsHaveSameScale();
+
+/// \brief Constraint that all binary input types are decimal types and the 
first type's
+/// scale >= the second type's.
+ARROW_EXPORT std::shared_ptr<MatchConstraint> BinaryDecimalScale1GeScale2();
+
+/// \brief Holds the input types, optional match constraint and output type of 
the kernel.
 ///
 /// VarArgs functions with minimum N arguments should pass up to N input types 
to be
 /// used to validate the input types of a function invocation. The first N-1 
types
@@ -357,15 +378,16 @@ class ARROW_EXPORT OutputType {
 class ARROW_EXPORT KernelSignature {
  public:
   KernelSignature(std::vector<InputType> in_types, OutputType out_type,
-                  bool is_varargs = false);
+                  bool is_varargs = false,
+                  std::shared_ptr<MatchConstraint> constraint = NULLPTR);
 
   /// \brief Convenience ctor since make_shared can be awkward
-  static std::shared_ptr<KernelSignature> Make(std::vector<InputType> in_types,
-                                               OutputType out_type,
-                                               bool is_varargs = false);
+  static std::shared_ptr<KernelSignature> Make(
+      std::vector<InputType> in_types, OutputType out_type, bool is_varargs = 
false,
+      std::shared_ptr<MatchConstraint> constraint = NULLPTR);
 
   /// \brief Return true if the signature is compatible with the list of input
-  /// value descriptors.
+  /// value descriptors and satisfies the match constraint, if any.
   bool MatchesInputs(const std::vector<TypeHolder>& types) const;
 
   /// \brief Returns true if the input types of each signature are
@@ -401,6 +423,7 @@ class ARROW_EXPORT KernelSignature {
   std::vector<InputType> in_types_;
   OutputType out_type_;
   bool is_varargs_;
+  std::shared_ptr<MatchConstraint> constraint_;
 
   // For caching the hash code after it's computed the first time
   mutable uint64_t hash_code_;
diff --git a/cpp/src/arrow/compute/kernel_test.cc 
b/cpp/src/arrow/compute/kernel_test.cc
index e9664b104d..deaddaddc6 100644
--- a/cpp/src/arrow/compute/kernel_test.cc
+++ b/cpp/src/arrow/compute/kernel_test.cc
@@ -307,6 +307,57 @@ TEST(OutputType, Resolve) {
   ASSERT_EQ(result, int32());
 }
 
+// ----------------------------------------------------------------------
+// MatchConstraint
+
+TEST(MatchConstraint, ConvenienceMaker) {
+  {
+    auto always_match =
+        MakeConstraint([](const std::vector<TypeHolder>& types) { return true; 
});
+
+    ASSERT_TRUE(always_match->Matches({}));
+    ASSERT_TRUE(always_match->Matches({int8(), int16(), int32()}));
+  }
+
+  {
+    auto always_false =
+        MakeConstraint([](const std::vector<TypeHolder>& types) { return 
false; });
+
+    ASSERT_FALSE(always_false->Matches({}));
+    ASSERT_FALSE(always_false->Matches({int8(), int16(), int32()}));
+  }
+}
+
+TEST(MatchConstraint, DecimalsHaveSameScale) {
+  auto c = DecimalsHaveSameScale();
+  constexpr int32_t precision = 12, scale = 2;
+  ASSERT_TRUE(c->Matches({decimal128(precision, scale), decimal128(precision, 
scale)}));
+  ASSERT_TRUE(c->Matches({decimal128(precision, scale), decimal256(precision, 
scale)}));
+  ASSERT_TRUE(c->Matches({decimal256(precision, scale), decimal128(precision, 
scale)}));
+  ASSERT_TRUE(c->Matches({decimal256(precision, scale), decimal256(precision, 
scale)}));
+  ASSERT_FALSE(
+      c->Matches({decimal128(precision, scale), decimal128(precision, scale + 
1)}));
+  ASSERT_FALSE(c->Matches({decimal128(precision, scale), decimal128(precision, 
scale),
+                           decimal128(precision, scale + 1)}));
+}
+
+TEST(MatchConstraint, BinaryDecimalScaleComparisonGE) {
+  auto c = BinaryDecimalScale1GeScale2();
+  constexpr int32_t precision = 12, small_scale = 2, big_scale = 3;
+  ASSERT_TRUE(
+      c->Matches({decimal128(precision, big_scale), decimal128(precision, 
small_scale)}));
+  ASSERT_TRUE(
+      c->Matches({decimal128(precision, big_scale), decimal256(precision, 
small_scale)}));
+  ASSERT_TRUE(
+      c->Matches({decimal256(precision, big_scale), decimal128(precision, 
small_scale)}));
+  ASSERT_TRUE(
+      c->Matches({decimal256(precision, big_scale), decimal256(precision, 
small_scale)}));
+  ASSERT_TRUE(c->Matches(
+      {decimal128(precision, small_scale), decimal128(precision, 
small_scale)}));
+  ASSERT_FALSE(
+      c->Matches({decimal128(precision, small_scale), decimal128(precision, 
big_scale)}));
+}
+
 // ----------------------------------------------------------------------
 // KernelSignature
 
@@ -419,6 +470,35 @@ TEST(KernelSignature, VarArgsMatchesInputs) {
   }
 }
 
+TEST(KernelSignature, MatchesInputsWithConstraint) {
+  constexpr int32_t precision = 12, small_scale = 2, big_scale = 3;
+
+  auto small_scale_decimal = decimal128(precision, small_scale);
+  auto big_scale_decimal = decimal128(precision, big_scale);
+
+  // No constraint.
+  KernelSignature sig_no_constraint({Type::DECIMAL128, Type::DECIMAL128}, 
boolean());
+  ASSERT_TRUE(
+      sig_no_constraint.MatchesInputs({small_scale_decimal, 
small_scale_decimal}));
+  ASSERT_TRUE(sig_no_constraint.MatchesInputs({small_scale_decimal, 
big_scale_decimal}));
+  ASSERT_TRUE(
+      sig_no_constraint.MatchesInputs({small_scale_decimal, 
small_scale_decimal}));
+  ASSERT_TRUE(sig_no_constraint.MatchesInputs({small_scale_decimal, 
big_scale_decimal}));
+
+  for (auto constraint : {DecimalsHaveSameScale(), 
BinaryDecimalScale1GeScale2()}) {
+    KernelSignature sig({Type::DECIMAL128, Type::DECIMAL128}, boolean(),
+                        /*is_varargs=*/false, constraint);
+    ASSERT_EQ(constraint->Matches({small_scale_decimal, small_scale_decimal}),
+              sig.MatchesInputs({small_scale_decimal, small_scale_decimal}));
+    ASSERT_EQ(constraint->Matches({small_scale_decimal, big_scale_decimal}),
+              sig.MatchesInputs({small_scale_decimal, big_scale_decimal}));
+    ASSERT_EQ(constraint->Matches({big_scale_decimal, small_scale_decimal}),
+              sig.MatchesInputs({big_scale_decimal, small_scale_decimal}));
+    ASSERT_EQ(constraint->Matches({big_scale_decimal, big_scale_decimal}),
+              sig.MatchesInputs({big_scale_decimal, big_scale_decimal}));
+  }
+}
+
 TEST(KernelSignature, ToString) {
   std::vector<InputType> in_types = {InputType(int8()), 
InputType(Type::DECIMAL),
                                      InputType(utf8())};
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc 
b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
index e536b3d886..ccbd361362 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
@@ -598,10 +598,6 @@ Result<TypeHolder> 
ResolveDecimalAdditionOrSubtractionOutput(
       types,
       [](int32_t p1, int32_t s1, int32_t p2,
          int32_t s2) -> Result<std::pair<int32_t, int32_t>> {
-        if (s1 != s2) {
-          return Status::Invalid("Addition or subtraction of two decimal ",
-                                 "types scale1 != scale2. (", s1, s2, ").");
-        }
         DCHECK_EQ(s1, s2);
         const int32_t scale = s1;
         const int32_t precision = std::max(p1 - s1, p2 - s2) + scale + 1;
@@ -627,10 +623,6 @@ Result<TypeHolder> 
ResolveDecimalDivisionOutput(KernelContext*,
       types,
       [](int32_t p1, int32_t s1, int32_t p2,
          int32_t s2) -> Result<std::pair<int32_t, int32_t>> {
-        if (s1 < s2) {
-          return Status::Invalid("Division of two decimal types scale1 < 
scale2. ", "(",
-                                 s1, s2, ").");
-        }
         DCHECK_GE(s1, s2);
         const int32_t scale = s1 - s2;
         const int32_t precision = p1;
@@ -669,13 +661,16 @@ void AddDecimalUnaryKernels(ScalarFunction* func) {
 template <typename Op>
 void AddDecimalBinaryKernels(const std::string& name, ScalarFunction* func) {
   OutputType out_type(null());
+  std::shared_ptr<MatchConstraint> constraint = nullptr;
   const std::string op = name.substr(0, name.find("_"));
   if (op == "add" || op == "subtract") {
     out_type = OutputType(ResolveDecimalAdditionOrSubtractionOutput);
+    constraint = DecimalsHaveSameScale();
   } else if (op == "multiply") {
     out_type = OutputType(ResolveDecimalMultiplicationOutput);
   } else if (op == "divide") {
     out_type = OutputType(ResolveDecimalDivisionOutput);
+    constraint = BinaryDecimalScale1GeScale2();
   } else {
     DCHECK(false);
   }
@@ -684,8 +679,10 @@ void AddDecimalBinaryKernels(const std::string& name, 
ScalarFunction* func) {
   auto in_type256 = InputType(Type::DECIMAL256);
   auto exec128 = ScalarBinaryNotNullEqualTypes<Decimal128Type, Decimal128Type, 
Op>::Exec;
   auto exec256 = ScalarBinaryNotNullEqualTypes<Decimal256Type, Decimal256Type, 
Op>::Exec;
-  DCHECK_OK(func->AddKernel({in_type128, in_type128}, out_type, exec128));
-  DCHECK_OK(func->AddKernel({in_type256, in_type256}, out_type, exec256));
+  DCHECK_OK(func->AddKernel({in_type128, in_type128}, out_type, exec128, 
/*init=*/nullptr,
+                            constraint));
+  DCHECK_OK(func->AddKernel({in_type256, in_type256}, out_type, exec256, 
/*init=*/nullptr,
+                            constraint));
 }
 
 template <typename Op>
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc 
b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
index c6b38df3db..0956168fc3 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
@@ -1900,6 +1900,61 @@ TEST_F(TestUnaryArithmeticDecimal, TrigTan) {
 
 class TestBinaryArithmeticDecimal : public TestArithmeticDecimal {};
 
+TEST_F(TestBinaryArithmeticDecimal, DispatchExact) {
+  for (std::string name : {"add", "subtract"}) {
+    for (std::string suffix : {"", "_checked"}) {
+      name += suffix;
+      ARROW_SCOPED_TRACE(name);
+
+      CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 1)});
+      CheckDispatchExact(name, {decimal128(3, 1), decimal128(2, 1)});
+      CheckDispatchExactFails(name, {decimal128(2, 0), decimal128(2, 1)});
+      CheckDispatchExactFails(name, {decimal128(2, 1), decimal128(2, 0)});
+
+      CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 1)});
+      CheckDispatchExact(name, {decimal256(3, 1), decimal256(2, 1)});
+      CheckDispatchExactFails(name, {decimal256(2, 0), decimal256(2, 1)});
+      CheckDispatchExactFails(name, {decimal256(2, 1), decimal256(2, 0)});
+    }
+  }
+
+  {
+    std::string name = "multiply";
+    for (std::string suffix : {"", "_checked"}) {
+      name += suffix;
+      ARROW_SCOPED_TRACE(name);
+
+      CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 1)});
+      CheckDispatchExact(name, {decimal128(3, 1), decimal128(2, 1)});
+      CheckDispatchExact(name, {decimal128(2, 0), decimal128(2, 1)});
+      CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 0)});
+
+      CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 1)});
+      CheckDispatchExact(name, {decimal256(3, 1), decimal256(2, 1)});
+      CheckDispatchExact(name, {decimal256(2, 0), decimal256(2, 1)});
+      CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 0)});
+    }
+  }
+
+  {
+    std::string name = "divide";
+    for (std::string suffix : {"", "_checked"}) {
+      name += suffix;
+      ARROW_SCOPED_TRACE(name);
+
+      CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 1)});
+      CheckDispatchExact(name, {decimal128(3, 1), decimal128(2, 1)});
+      CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 0)});
+      CheckDispatchExactFails(name, {decimal128(2, 0), decimal128(2, 1)});
+
+      CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 1)});
+      CheckDispatchExact(name, {decimal256(3, 1), decimal256(2, 1)});
+      CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 0)});
+      CheckDispatchExactFails(name, {decimal256(2, 0), decimal256(2, 1)});
+    }
+  }
+}
+
 TEST_F(TestBinaryArithmeticDecimal, DispatchBest) {
   // decimal, floating point
   for (std::string name : {"add", "subtract", "multiply", "divide"}) {
diff --git a/cpp/src/arrow/compute/kernels/test_util_internal.cc 
b/cpp/src/arrow/compute/kernels/test_util_internal.cc
index 3b0c9bdd85..d1cad235cc 100644
--- a/cpp/src/arrow/compute/kernels/test_util_internal.cc
+++ b/cpp/src/arrow/compute/kernels/test_util_internal.cc
@@ -285,6 +285,11 @@ void CheckScalarBinaryCommutative(std::string func_name, 
Datum left_input,
   CheckScalar(func_name, {right_input, left_input}, expected, options);
 }
 
+void CheckDispatchExact(std::string func_name, std::vector<TypeHolder> types) {
+  ASSERT_OK_AND_ASSIGN(auto function, 
GetFunctionRegistry()->GetFunction(func_name));
+  ASSERT_OK(function->DispatchExact(types));
+}
+
 void CheckDispatchBest(std::string func_name, std::vector<TypeHolder> 
original_values,
                        std::vector<TypeHolder> expected_equivalent_values) {
   ASSERT_OK_AND_ASSIGN(auto function, 
GetFunctionRegistry()->GetFunction(func_name));
@@ -306,6 +311,11 @@ void CheckDispatchBest(std::string func_name, 
std::vector<TypeHolder> original_v
   }
 }
 
+void CheckDispatchExactFails(std::string func_name, std::vector<TypeHolder> 
types) {
+  ASSERT_OK_AND_ASSIGN(auto function, 
GetFunctionRegistry()->GetFunction(func_name));
+  ASSERT_NOT_OK(function->DispatchExact(types));
+}
+
 void CheckDispatchFails(std::string func_name, std::vector<TypeHolder> types) {
   ASSERT_OK_AND_ASSIGN(auto function, 
GetFunctionRegistry()->GetFunction(func_name));
   ASSERT_NOT_OK(function->DispatchBest(&types));
diff --git a/cpp/src/arrow/compute/kernels/test_util_internal.h 
b/cpp/src/arrow/compute/kernels/test_util_internal.h
index e3a27ab9ad..1077101377 100644
--- a/cpp/src/arrow/compute/kernels/test_util_internal.h
+++ b/cpp/src/arrow/compute/kernels/test_util_internal.h
@@ -155,11 +155,17 @@ void TestRandomPrimitiveCTypes() {
   DoTestFunctor<DurationType>::Test(duration(TimeUnit::MILLI));
 }
 
+// Check that DispatchExact on a given function yields a valid Kernel
+void CheckDispatchExact(std::string func_name, std::vector<TypeHolder> types);
+
 // Check that DispatchBest on a given function yields the same Kernel as
 // produced by DispatchExact on another set of types
 void CheckDispatchBest(std::string func_name, std::vector<TypeHolder> types,
                        std::vector<TypeHolder> exact_types);
 
+// Check that function fails to produce a Kernel via DispatchExact for the set 
of types
+void CheckDispatchExactFails(std::string func_name, std::vector<TypeHolder> 
types);
+
 // Check that function fails to produce a Kernel for the set of types
 void CheckDispatchFails(std::string func_name, std::vector<TypeHolder> types);
 

Reply via email to