This is an automated email from the ASF dual-hosted git repository.
zanmato pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new dadc21f31b GH-47287: [C++][Compute] Add constraint for kernel
signature matching and use it for binary decimal arithmetic kernels (#47297)
dadc21f31b is described below
commit dadc21f31b1d43cab4eb763f5e65691ce1f08fd5
Author: Rossi Sun <[email protected]>
AuthorDate: Wed Aug 20 23:20:54 2025 +0800
GH-47287: [C++][Compute] Add constraint for kernel signature matching and
use it for binary decimal arithmetic kernels (#47297)
### Rationale for this change
A rework of #40223 using a more systematic alternative.
### What changes are included in this PR?
Introduce a structure `MatchConstraint` for applying an extra (and optional)
matching constraint during kernel signature matching, in addition to the plain
input type checks.
Also implement two concrete `MatchConstraint`s for binary decimal
arithmetic kernels. They suppress an exact match even when the input types are
otherwise acceptable: all decimals must have the same scale for `add` and
`subtract`, and the first scale must be at least the second (s1 >= s2) for `divide`.
This should also serve as a foundation for resolving similar
issues such as:
* #35843
* #39875
* #40911
* #41011
* #41336
(I haven't tried each of them yet; I may do so once this PR is merged.)
### Are these changes tested?
UT included.
### Are there any user-facing changes?
New public class `MatchConstraint`.
* GitHub Issue: #47287
Authored-by: Rossi Sun <[email protected]>
Signed-off-by: Rossi Sun <[email protected]>
---
cpp/src/arrow/compute/function.cc | 7 +-
cpp/src/arrow/compute/function.h | 3 +-
cpp/src/arrow/compute/kernel.cc | 83 ++++++++++++++++++++--
cpp/src/arrow/compute/kernel.h | 35 +++++++--
cpp/src/arrow/compute/kernel_test.cc | 80 +++++++++++++++++++++
cpp/src/arrow/compute/kernels/scalar_arithmetic.cc | 17 ++---
.../compute/kernels/scalar_arithmetic_test.cc | 55 ++++++++++++++
.../arrow/compute/kernels/test_util_internal.cc | 10 +++
cpp/src/arrow/compute/kernels/test_util_internal.h | 6 ++
9 files changed, 271 insertions(+), 25 deletions(-)
diff --git a/cpp/src/arrow/compute/function.cc
b/cpp/src/arrow/compute/function.cc
index d7842f14ef..b0b12a690f 100644
--- a/cpp/src/arrow/compute/function.cc
+++ b/cpp/src/arrow/compute/function.cc
@@ -410,14 +410,15 @@ Status Function::Validate() const {
}
Status ScalarFunction::AddKernel(std::vector<InputType> in_types, OutputType
out_type,
- ArrayKernelExec exec, KernelInit init) {
+ ArrayKernelExec exec, KernelInit init,
+ std::shared_ptr<MatchConstraint> constraint) {  // Null means no extra constraint.
RETURN_NOT_OK(CheckArity(in_types.size()));
if (arity_.is_varargs && in_types.size() != 1) {
return Status::Invalid("VarArgs signatures must have exactly one input
type");
}
- auto sig =
- KernelSignature::Make(std::move(in_types), std::move(out_type),
arity_.is_varargs);
+ auto sig = KernelSignature::Make(std::move(in_types), std::move(out_type),
+ arity_.is_varargs, std::move(constraint));  // Applied in MatchesInputs.
kernels_.emplace_back(std::move(sig), exec, init);
return Status::OK();
}
diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h
index 2b86f64216..399081e2a7 100644
--- a/cpp/src/arrow/compute/function.h
+++ b/cpp/src/arrow/compute/function.h
@@ -308,7 +308,8 @@ class ARROW_EXPORT ScalarFunction : public
detail::FunctionImpl<ScalarKernel> {
/// initialization, preallocation for fixed-width types, and default null
/// handling (intersect validity bitmaps of inputs).
Status AddKernel(std::vector<InputType> in_types, OutputType out_type,
- ArrayKernelExec exec, KernelInit init = NULLPTR);
+ ArrayKernelExec exec, KernelInit init = NULLPTR,
+ std::shared_ptr<MatchConstraint> constraint = NULLPTR);  // Optional extra match check.
/// \brief Add a kernel (function implementation). Returns error if the
/// kernel's signature does not match the function's arity.
diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc
index 4211593adf..f7fecc9247 100644
--- a/cpp/src/arrow/compute/kernel.cc
+++ b/cpp/src/arrow/compute/kernel.cc
@@ -475,23 +475,93 @@ std::string OutputType::ToString() const {
return "computed";
}
+// ----------------------------------------------------------------------
+// MatchConstraint
+
+std::shared_ptr<MatchConstraint> MakeConstraint(
+ std::function<bool(const std::vector<TypeHolder>&)> matches) {
+ class FunctionMatchConstraint : public MatchConstraint {
+ public:
+ explicit FunctionMatchConstraint(
+ std::function<bool(const std::vector<TypeHolder>&)> matches)
+ : matches_(std::move(matches)) {}
+
+ bool Matches(const std::vector<TypeHolder>& types) const override {
+ return matches_(types);
+ }
+
+ private:
+ std::function<bool(const std::vector<TypeHolder>&)> matches_;
+ };
+
+ return std::make_shared<FunctionMatchConstraint>(std::move(matches));
+}
+
+std::shared_ptr<MatchConstraint> DecimalsHaveSameScale() {
+ class DecimalsHaveSameScaleConstraint : public MatchConstraint {
+ public:
+ bool Matches(const std::vector<TypeHolder>& types) const override {
+ DCHECK_GE(types.size(), 2);
+ // Verified at runtime rather than DCHECK'd, so that non-decimal inputs fail the
+ // match instead of being mis-cast to DecimalType in release builds.
+ if (!std::all_of(types.begin(), types.end(),
+ [](const TypeHolder& type) { return is_decimal(type.id()); })) {
+ return false;
+ }
+ if (types.empty()) return true;
+ const auto s0 = checked_cast<const DecimalType&>(*types[0].type).scale();
+ return std::all_of(types.begin() + 1, types.end(), [s0](const TypeHolder& type) {
+ return checked_cast<const DecimalType&>(*type.type).scale() == s0;
+ });
+ }
+ };
+ static auto instance = std::make_shared<DecimalsHaveSameScaleConstraint>();
+ return instance;
+}
+
+namespace {
+
+template <typename Op>
+class BinaryDecimalScaleComparisonConstraint : public MatchConstraint {
+ public:
+ bool Matches(const std::vector<TypeHolder>& types) const override {
+ DCHECK_EQ(types.size(), 2);
+ if (types.size() != 2 || !is_decimal(types[0].id()) || !is_decimal(types[1].id())) {
+ return false;  // Non-decimal inputs (or wrong arity without DCHECKs) fail the match.
+ }
+ return Op{}(checked_cast<const DecimalType&>(*types[0].type).scale(),
+ checked_cast<const DecimalType&>(*types[1].type).scale());
+ }
+};
+
+} // namespace
+
+std::shared_ptr<MatchConstraint> BinaryDecimalScale1GeScale2() {
+ using BinaryDecimalScale1GeScale2Constraint =
+ BinaryDecimalScaleComparisonConstraint<std::greater_equal<>>;
+ static auto instance =
std::make_shared<BinaryDecimalScale1GeScale2Constraint>();
+ return instance;
+}
+
// ----------------------------------------------------------------------
// KernelSignature
KernelSignature::KernelSignature(std::vector<InputType> in_types, OutputType
out_type,
- bool is_varargs)
+ bool is_varargs,
+ std::shared_ptr<MatchConstraint> constraint)
: in_types_(std::move(in_types)),
out_type_(std::move(out_type)),
is_varargs_(is_varargs),
+ constraint_(std::move(constraint)),
hash_code_(0) {
DCHECK(!is_varargs || (is_varargs && (in_types_.size() >= 1)));
}
-std::shared_ptr<KernelSignature> KernelSignature::Make(std::vector<InputType>
in_types,
- OutputType out_type,
- bool is_varargs) {
+std::shared_ptr<KernelSignature> KernelSignature::Make(
+ std::vector<InputType> in_types, OutputType out_type, bool is_varargs,
+ std::shared_ptr<MatchConstraint> constraint) {
return std::make_shared<KernelSignature>(std::move(in_types),
std::move(out_type),
- is_varargs);
+ is_varargs, std::move(constraint));
}
bool KernelSignature::Equals(const KernelSignature& other) const {
@@ -526,6 +596,9 @@ bool KernelSignature::MatchesInputs(const
std::vector<TypeHolder>& types) const
}
}
}
+ if (constraint_ && !constraint_->Matches(types)) {
+ return false;
+ }
return true;
}
diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h
index b429c40ac3..fdcdb134de 100644
--- a/cpp/src/arrow/compute/kernel.h
+++ b/cpp/src/arrow/compute/kernel.h
@@ -348,7 +348,28 @@ class ARROW_EXPORT OutputType {
Resolver resolver_ = NULLPTR;
};
-/// \brief Holds the input types and output type of the kernel.
+/// \brief Additional constraint on the input types of a kernel, checked when matching
+/// a kernel signature after every input type has matched its InputType.
+class ARROW_EXPORT MatchConstraint {
+ public:
+ virtual ~MatchConstraint() = default;
+
+ /// \brief Return true if the input types satisfy the constraint.
+ virtual bool Matches(const std::vector<TypeHolder>& types) const = 0;
+};
+
+/// \brief Convenience function to create a MatchConstraint from a predicate over the input types.
+ARROW_EXPORT std::shared_ptr<MatchConstraint> MakeConstraint(
+ std::function<bool(const std::vector<TypeHolder>&)> matches);
+
+/// \brief Constraint that all input types are decimal types and have the same
scale.
+ARROW_EXPORT std::shared_ptr<MatchConstraint> DecimalsHaveSameScale();
+
+/// \brief Constraint that all binary input types are decimal types and the
first type's
+/// scale >= the second type's.
+ARROW_EXPORT std::shared_ptr<MatchConstraint> BinaryDecimalScale1GeScale2();
+
+/// \brief Holds the input types, optional match constraint and output type of
the kernel.
///
/// VarArgs functions with minimum N arguments should pass up to N input types
to be
/// used to validate the input types of a function invocation. The first N-1
types
@@ -357,15 +378,16 @@ class ARROW_EXPORT OutputType {
class ARROW_EXPORT KernelSignature {
public:
KernelSignature(std::vector<InputType> in_types, OutputType out_type,
- bool is_varargs = false);
+ bool is_varargs = false,
+ std::shared_ptr<MatchConstraint> constraint = NULLPTR);
/// \brief Convenience ctor since make_shared can be awkward
- static std::shared_ptr<KernelSignature> Make(std::vector<InputType> in_types,
- OutputType out_type,
- bool is_varargs = false);
+ static std::shared_ptr<KernelSignature> Make(
+ std::vector<InputType> in_types, OutputType out_type, bool is_varargs =
false,
+ std::shared_ptr<MatchConstraint> constraint = NULLPTR);
/// \brief Return true if the signature is compatible with the list of input
- /// value descriptors.
+ /// value descriptors and satisfies the match constraint, if any.
bool MatchesInputs(const std::vector<TypeHolder>& types) const;
/// \brief Returns true if the input types of each signature are
@@ -401,6 +423,7 @@ class ARROW_EXPORT KernelSignature {
std::vector<InputType> in_types_;
OutputType out_type_;
bool is_varargs_;
+ std::shared_ptr<MatchConstraint> constraint_;  // May be null (no extra constraint). NOTE(review): confirm whether Equals()/Hash() should account for it.
// For caching the hash code after it's computed the first time
mutable uint64_t hash_code_;
diff --git a/cpp/src/arrow/compute/kernel_test.cc
b/cpp/src/arrow/compute/kernel_test.cc
index e9664b104d..deaddaddc6 100644
--- a/cpp/src/arrow/compute/kernel_test.cc
+++ b/cpp/src/arrow/compute/kernel_test.cc
@@ -307,6 +307,57 @@ TEST(OutputType, Resolve) {
ASSERT_EQ(result, int32());
}
+// ----------------------------------------------------------------------
+// MatchConstraint
+
+TEST(MatchConstraint, ConvenienceMaker) {
+ {
+ auto always_match =
+ MakeConstraint([](const std::vector<TypeHolder>& types) { return true;
});
+
+ ASSERT_TRUE(always_match->Matches({}));
+ ASSERT_TRUE(always_match->Matches({int8(), int16(), int32()}));
+ }
+
+ {
+ auto always_false =
+ MakeConstraint([](const std::vector<TypeHolder>& types) { return
false; });
+
+ ASSERT_FALSE(always_false->Matches({}));
+ ASSERT_FALSE(always_false->Matches({int8(), int16(), int32()}));
+ }
+}
+
+TEST(MatchConstraint, DecimalsHaveSameScale) {
+ auto c = DecimalsHaveSameScale();
+ constexpr int32_t precision = 12, scale = 2;
+ ASSERT_TRUE(c->Matches({decimal128(precision, scale), decimal128(precision,
scale)}));
+ ASSERT_TRUE(c->Matches({decimal128(precision, scale), decimal256(precision,
scale)}));
+ ASSERT_TRUE(c->Matches({decimal256(precision, scale), decimal128(precision,
scale)}));
+ ASSERT_TRUE(c->Matches({decimal256(precision, scale), decimal256(precision,
scale)}));
+ ASSERT_FALSE(
+ c->Matches({decimal128(precision, scale), decimal128(precision, scale +
1)}));
+ ASSERT_FALSE(c->Matches({decimal128(precision, scale), decimal128(precision,
scale),
+ decimal128(precision, scale + 1)}));
+}
+
+TEST(MatchConstraint, BinaryDecimalScaleComparisonGE) {
+ auto c = BinaryDecimalScale1GeScale2();
+ constexpr int32_t precision = 12, small_scale = 2, big_scale = 3;
+ ASSERT_TRUE(
+ c->Matches({decimal128(precision, big_scale), decimal128(precision,
small_scale)}));
+ ASSERT_TRUE(
+ c->Matches({decimal128(precision, big_scale), decimal256(precision,
small_scale)}));
+ ASSERT_TRUE(
+ c->Matches({decimal256(precision, big_scale), decimal128(precision,
small_scale)}));
+ ASSERT_TRUE(
+ c->Matches({decimal256(precision, big_scale), decimal256(precision,
small_scale)}));
+ ASSERT_TRUE(c->Matches(
+ {decimal128(precision, small_scale), decimal128(precision,
small_scale)}));
+ ASSERT_FALSE(
+ c->Matches({decimal128(precision, small_scale), decimal128(precision,
big_scale)}));
+}
+
// ----------------------------------------------------------------------
// KernelSignature
@@ -419,6 +470,35 @@ TEST(KernelSignature, VarArgsMatchesInputs) {
}
}
+TEST(KernelSignature, MatchesInputsWithConstraint) {
+ constexpr int32_t precision = 12, small_scale = 2, big_scale = 3;
+
+ auto small_scale_decimal = decimal128(precision, small_scale);
+ auto big_scale_decimal = decimal128(precision, big_scale);
+
+ // No constraint: every combination of scales matches.
+ KernelSignature sig_no_constraint({Type::DECIMAL128, Type::DECIMAL128},
boolean());
+ ASSERT_TRUE(
+ sig_no_constraint.MatchesInputs({small_scale_decimal,
small_scale_decimal}));
+ ASSERT_TRUE(sig_no_constraint.MatchesInputs({small_scale_decimal,
big_scale_decimal}));
+ ASSERT_TRUE(sig_no_constraint.MatchesInputs({big_scale_decimal, small_scale_decimal}));
+ ASSERT_TRUE(sig_no_constraint.MatchesInputs({big_scale_decimal, big_scale_decimal}));
+
+ // With a constraint, the signature matches exactly when the constraint does.
+ for (auto constraint : {DecimalsHaveSameScale(),
BinaryDecimalScale1GeScale2()}) {
+ KernelSignature sig({Type::DECIMAL128, Type::DECIMAL128}, boolean(),
+ /*is_varargs=*/false, constraint);
+ ASSERT_EQ(constraint->Matches({small_scale_decimal, small_scale_decimal}),
+ sig.MatchesInputs({small_scale_decimal, small_scale_decimal}));
+ ASSERT_EQ(constraint->Matches({small_scale_decimal, big_scale_decimal}),
+ sig.MatchesInputs({small_scale_decimal, big_scale_decimal}));
+ ASSERT_EQ(constraint->Matches({big_scale_decimal, small_scale_decimal}),
+ sig.MatchesInputs({big_scale_decimal, small_scale_decimal}));
+ ASSERT_EQ(constraint->Matches({big_scale_decimal, big_scale_decimal}),
+ sig.MatchesInputs({big_scale_decimal, big_scale_decimal}));
+ }
+}
+
TEST(KernelSignature, ToString) {
std::vector<InputType> in_types = {InputType(int8()),
InputType(Type::DECIMAL),
InputType(utf8())};
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
index e536b3d886..ccbd361362 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
@@ -598,10 +598,6 @@ Result<TypeHolder>
ResolveDecimalAdditionOrSubtractionOutput(
types,
[](int32_t p1, int32_t s1, int32_t p2,
int32_t s2) -> Result<std::pair<int32_t, int32_t>> {
- if (s1 != s2) {
- return Status::Invalid("Addition or subtraction of two decimal ",
- "types scale1 != scale2. (", s1, s2, ").");
- }
DCHECK_EQ(s1, s2);
const int32_t scale = s1;
const int32_t precision = std::max(p1 - s1, p2 - s2) + scale + 1;
@@ -627,10 +623,6 @@ Result<TypeHolder>
ResolveDecimalDivisionOutput(KernelContext*,
types,
[](int32_t p1, int32_t s1, int32_t p2,
int32_t s2) -> Result<std::pair<int32_t, int32_t>> {
- if (s1 < s2) {
- return Status::Invalid("Division of two decimal types scale1 <
scale2. ", "(",
- s1, s2, ").");
- }
DCHECK_GE(s1, s2);
const int32_t scale = s1 - s2;
const int32_t precision = p1;
@@ -669,13 +661,16 @@ void AddDecimalUnaryKernels(ScalarFunction* func) {
template <typename Op>
void AddDecimalBinaryKernels(const std::string& name, ScalarFunction* func) {
OutputType out_type(null());
+ std::shared_ptr<MatchConstraint> constraint = nullptr;  // Stays null for multiply.
const std::string op = name.substr(0, name.find("_"));
if (op == "add" || op == "subtract") {
out_type = OutputType(ResolveDecimalAdditionOrSubtractionOutput);
+ constraint = DecimalsHaveSameScale();  // Guarantees the resolver's DCHECK_EQ(s1, s2).
} else if (op == "multiply") {
out_type = OutputType(ResolveDecimalMultiplicationOutput);
} else if (op == "divide") {
out_type = OutputType(ResolveDecimalDivisionOutput);
+ constraint = BinaryDecimalScale1GeScale2();  // Guarantees the resolver's DCHECK_GE(s1, s2).
} else {
DCHECK(false);
}
@@ -684,8 +679,10 @@ void AddDecimalBinaryKernels(const std::string& name,
ScalarFunction* func) {
auto in_type256 = InputType(Type::DECIMAL256);
auto exec128 = ScalarBinaryNotNullEqualTypes<Decimal128Type, Decimal128Type,
Op>::Exec;
auto exec256 = ScalarBinaryNotNullEqualTypes<Decimal256Type, Decimal256Type,
Op>::Exec;
- DCHECK_OK(func->AddKernel({in_type128, in_type128}, out_type, exec128));
- DCHECK_OK(func->AddKernel({in_type256, in_type256}, out_type, exec256));
+ DCHECK_OK(func->AddKernel({in_type128, in_type128}, out_type, exec128,
/*init=*/nullptr,
+ constraint));
+ DCHECK_OK(func->AddKernel({in_type256, in_type256}, out_type, exec256,
/*init=*/nullptr,
+ constraint));
}
template <typename Op>
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
index c6b38df3db..0956168fc3 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
@@ -1900,6 +1900,61 @@ TEST_F(TestUnaryArithmeticDecimal, TrigTan) {
class TestBinaryArithmeticDecimal : public TestArithmeticDecimal {};
+TEST_F(TestBinaryArithmeticDecimal, DispatchExact) {
+ for (std::string base : {"add", "subtract"}) {
+ for (std::string suffix : {"", "_checked"}) {
+ const std::string name = base + suffix;  // Fresh per suffix; never accumulates.
+ ARROW_SCOPED_TRACE(name);
+
+ CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 1)});
+ CheckDispatchExact(name, {decimal128(3, 1), decimal128(2, 1)});
+ CheckDispatchExactFails(name, {decimal128(2, 0), decimal128(2, 1)});
+ CheckDispatchExactFails(name, {decimal128(2, 1), decimal128(2, 0)});
+
+ CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 1)});
+ CheckDispatchExact(name, {decimal256(3, 1), decimal256(2, 1)});
+ CheckDispatchExactFails(name, {decimal256(2, 0), decimal256(2, 1)});
+ CheckDispatchExactFails(name, {decimal256(2, 1), decimal256(2, 0)});
+ }
+ }
+
+ {
+ const std::string base = "multiply";
+ for (std::string suffix : {"", "_checked"}) {
+ const std::string name = base + suffix;
+ ARROW_SCOPED_TRACE(name);
+
+ CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 1)});
+ CheckDispatchExact(name, {decimal128(3, 1), decimal128(2, 1)});
+ CheckDispatchExact(name, {decimal128(2, 0), decimal128(2, 1)});
+ CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 0)});
+
+ CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 1)});
+ CheckDispatchExact(name, {decimal256(3, 1), decimal256(2, 1)});
+ CheckDispatchExact(name, {decimal256(2, 0), decimal256(2, 1)});
+ CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 0)});
+ }
+ }
+
+ {
+ const std::string base = "divide";
+ for (std::string suffix : {"", "_checked"}) {
+ const std::string name = base + suffix;
+ ARROW_SCOPED_TRACE(name);
+
+ CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 1)});
+ CheckDispatchExact(name, {decimal128(3, 1), decimal128(2, 1)});
+ CheckDispatchExact(name, {decimal128(2, 1), decimal128(2, 0)});
+ CheckDispatchExactFails(name, {decimal128(2, 0), decimal128(2, 1)});
+
+ CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 1)});
+ CheckDispatchExact(name, {decimal256(3, 1), decimal256(2, 1)});
+ CheckDispatchExact(name, {decimal256(2, 1), decimal256(2, 0)});
+ CheckDispatchExactFails(name, {decimal256(2, 0), decimal256(2, 1)});
+ }
+ }
+}
+
TEST_F(TestBinaryArithmeticDecimal, DispatchBest) {
// decimal, floating point
for (std::string name : {"add", "subtract", "multiply", "divide"}) {
diff --git a/cpp/src/arrow/compute/kernels/test_util_internal.cc
b/cpp/src/arrow/compute/kernels/test_util_internal.cc
index 3b0c9bdd85..d1cad235cc 100644
--- a/cpp/src/arrow/compute/kernels/test_util_internal.cc
+++ b/cpp/src/arrow/compute/kernels/test_util_internal.cc
@@ -285,6 +285,11 @@ void CheckScalarBinaryCommutative(std::string func_name,
Datum left_input,
CheckScalar(func_name, {right_input, left_input}, expected, options);
}
+void CheckDispatchExact(std::string func_name, std::vector<TypeHolder> types) {
+ ASSERT_OK_AND_ASSIGN(auto function,
GetFunctionRegistry()->GetFunction(func_name));
+ ASSERT_OK(function->DispatchExact(types));  // Fails the test if no kernel matches `types`.
+}
+
void CheckDispatchBest(std::string func_name, std::vector<TypeHolder>
original_values,
std::vector<TypeHolder> expected_equivalent_values) {
ASSERT_OK_AND_ASSIGN(auto function,
GetFunctionRegistry()->GetFunction(func_name));
@@ -306,6 +311,11 @@ void CheckDispatchBest(std::string func_name,
std::vector<TypeHolder> original_v
}
}
+void CheckDispatchExactFails(std::string func_name, std::vector<TypeHolder>
types) {
+ ASSERT_OK_AND_ASSIGN(auto function,
GetFunctionRegistry()->GetFunction(func_name));
+ ASSERT_NOT_OK(function->DispatchExact(types));  // Fails the test if a kernel matches `types`.
+}
+
void CheckDispatchFails(std::string func_name, std::vector<TypeHolder> types) {
ASSERT_OK_AND_ASSIGN(auto function,
GetFunctionRegistry()->GetFunction(func_name));
ASSERT_NOT_OK(function->DispatchBest(&types));
diff --git a/cpp/src/arrow/compute/kernels/test_util_internal.h
b/cpp/src/arrow/compute/kernels/test_util_internal.h
index e3a27ab9ad..1077101377 100644
--- a/cpp/src/arrow/compute/kernels/test_util_internal.h
+++ b/cpp/src/arrow/compute/kernels/test_util_internal.h
@@ -155,11 +155,17 @@ void TestRandomPrimitiveCTypes() {
DoTestFunctor<DurationType>::Test(duration(TimeUnit::MILLI));
}
+// Check that DispatchExact on the named function finds a Kernel for exactly these types
+void CheckDispatchExact(std::string func_name, std::vector<TypeHolder> types);
+
// Check that DispatchBest on a given function yields the same Kernel as
// produced by DispatchExact on another set of types
void CheckDispatchBest(std::string func_name, std::vector<TypeHolder> types,
std::vector<TypeHolder> exact_types);
+// Check that DispatchExact on the named function finds no Kernel for the set of types
+void CheckDispatchExactFails(std::string func_name, std::vector<TypeHolder>
types);
+
// Check that function fails to produce a Kernel for the set of types
void CheckDispatchFails(std::string func_name, std::vector<TypeHolder> types);