This is an automated email from the ASF dual-hosted git repository. zanmato pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push: new 2987165820 GH-41011: [C++][Compute] Fix the issue that comparison function could not handle decimal arguments with different scales (#47459) 2987165820 is described below commit 2987165820e95f949f846b6d81ae3e6f671ab3d1 Author: Rossi Sun <zanmato1...@gmail.com> AuthorDate: Thu Sep 4 15:18:58 2025 +0800 GH-41011: [C++][Compute] Fix the issue that comparison function could not handle decimal arguments with different scales (#47459) ### Rationale for this change Previously, we could not suppress exact matching for decimal arguments with different scales, even though the decimal comparison kernel actually requires both scales to be the same. This caused issues like #41011. The "match constraint" introduced in #47297 exists for exactly this kind of issue: it can be fixed simply by adding the proper constraint. ### What changes are included in this PR? Added a match constraint that requires all decimal inputs to have the same scale (as for decimal addition and subtraction). ### Are these changes tested? Yes. ### Are there any user-facing changes? None. 
* GitHub Issue: #41011 Lead-authored-by: Rossi Sun <zanmato1...@gmail.com> Co-authored-by: Antoine Pitrou <pit...@free.fr> Signed-off-by: Rossi Sun <zanmato1...@gmail.com> --- cpp/src/arrow/compute/expression_test.cc | 81 +++++++++++++++++++++++++ cpp/src/arrow/compute/kernels/scalar_compare.cc | 4 +- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/expression_test.cc b/cpp/src/arrow/compute/expression_test.cc index 7e90f552ce..bbab57feeb 100644 --- a/cpp/src/arrow/compute/expression_test.cc +++ b/cpp/src/arrow/compute/expression_test.cc @@ -740,6 +740,11 @@ TEST(Expression, BindWithDecimalDivision) { } TEST(Expression, BindWithImplicitCasts) { + auto exciting_schema = schema( + {field("i64", int64()), field("dec128_3_2", decimal128(3, 2)), + field("dec128_4_2", decimal128(4, 2)), field("dec128_5_3", decimal128(5, 3)), + field("dec256_3_2", decimal256(3, 2)), field("dec256_4_2", decimal256(4, 2)), + field("dec256_5_3", decimal256(5, 3))}); for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) { // cast arguments to common numeric type ExpectBindsTo(cmp(field_ref("i64"), field_ref("i32")), @@ -800,6 +805,82 @@ TEST(Expression, BindWithImplicitCasts) { ExpectBindsTo(cmp(field_ref("i32"), literal(std::make_shared<DoubleScalar>(10.0))), cmp(cast(field_ref("i32"), float32()), literal(std::make_shared<FloatScalar>(10.0f)))); + + // decimal int + ExpectBindsTo(cmp(field_ref("dec128_3_2"), field_ref("i64")), + cmp(field_ref("dec128_3_2"), cast(field_ref("i64"), decimal128(21, 2))), + /*bound_out=*/nullptr, *exciting_schema); + ExpectBindsTo(cmp(field_ref("i64"), field_ref("dec128_3_2")), + cmp(cast(field_ref("i64"), decimal128(21, 2)), field_ref("dec128_3_2")), + /*bound_out=*/nullptr, *exciting_schema); + + // decimal decimal with different widths different precisions but same scale + ExpectBindsTo( + cmp(field_ref("dec128_3_2"), field_ref("dec256_4_2")), + cmp(cast(field_ref("dec128_3_2"), decimal256(3, 2)), 
field_ref("dec256_4_2")), + /*bound_out=*/nullptr, *exciting_schema); + ExpectBindsTo( + cmp(field_ref("dec256_4_2"), field_ref("dec128_3_2")), + cmp(field_ref("dec256_4_2"), cast(field_ref("dec128_3_2"), decimal256(3, 2))), + /*bound_out=*/nullptr, *exciting_schema); + ExpectBindsTo( + cmp(field_ref("dec128_4_2"), field_ref("dec256_3_2")), + cmp(cast(field_ref("dec128_4_2"), decimal256(4, 2)), field_ref("dec256_3_2")), + /*bound_out=*/nullptr, *exciting_schema); + ExpectBindsTo( + cmp(field_ref("dec256_3_2"), field_ref("dec128_4_2")), + cmp(field_ref("dec256_3_2"), cast(field_ref("dec128_4_2"), decimal256(4, 2))), + /*bound_out=*/nullptr, *exciting_schema); + + // decimal decimal with different widths different scales + ExpectBindsTo( + cmp(field_ref("dec128_3_2"), field_ref("dec256_5_3")), + cmp(cast(field_ref("dec128_3_2"), decimal256(4, 3)), field_ref("dec256_5_3")), + /*bound_out=*/nullptr, *exciting_schema); + ExpectBindsTo( + cmp(field_ref("dec256_5_3"), field_ref("dec128_3_2")), + cmp(field_ref("dec256_5_3"), cast(field_ref("dec128_3_2"), decimal256(4, 3))), + /*bound_out=*/nullptr, *exciting_schema); + ExpectBindsTo(cmp(field_ref("dec128_5_3"), field_ref("dec256_3_2")), + cmp(cast(field_ref("dec128_5_3"), decimal256(5, 3)), + cast(field_ref("dec256_3_2"), decimal256(4, 3))), + /*bound_out=*/nullptr, *exciting_schema); + ExpectBindsTo(cmp(field_ref("dec256_3_2"), field_ref("dec128_5_3")), + cmp(cast(field_ref("dec256_3_2"), decimal256(4, 3)), + cast(field_ref("dec128_5_3"), decimal256(5, 3))), + /*bound_out=*/nullptr, *exciting_schema); + + // decimal decimal with same width different precisions but same scale (no cast) + ExpectBindsTo(cmp(field_ref("dec128_3_2"), field_ref("dec128_4_2")), + cmp(field_ref("dec128_3_2"), field_ref("dec128_4_2")), + /*bound_out=*/nullptr, *exciting_schema); + ExpectBindsTo(cmp(field_ref("dec128_4_2"), field_ref("dec128_3_2")), + cmp(field_ref("dec128_4_2"), field_ref("dec128_3_2")), + /*bound_out=*/nullptr, *exciting_schema); 
+ ExpectBindsTo(cmp(field_ref("dec256_3_2"), field_ref("dec256_4_2")), + cmp(field_ref("dec256_3_2"), field_ref("dec256_4_2")), + /*bound_out=*/nullptr, *exciting_schema); + ExpectBindsTo(cmp(field_ref("dec256_4_2"), field_ref("dec256_3_2")), + cmp(field_ref("dec256_4_2"), field_ref("dec256_3_2")), + /*bound_out=*/nullptr, *exciting_schema); + + // decimal decimal with same width but different scales + ExpectBindsTo( + cmp(field_ref("dec128_3_2"), field_ref("dec128_5_3")), + cmp(cast(field_ref("dec128_3_2"), decimal128(4, 3)), field_ref("dec128_5_3")), + /*bound_out=*/nullptr, *exciting_schema); + ExpectBindsTo( + cmp(field_ref("dec128_5_3"), field_ref("dec128_3_2")), + cmp(field_ref("dec128_5_3"), cast(field_ref("dec128_3_2"), decimal128(4, 3))), + /*bound_out=*/nullptr, *exciting_schema); + ExpectBindsTo( + cmp(field_ref("dec256_3_2"), field_ref("dec256_5_3")), + cmp(cast(field_ref("dec256_3_2"), decimal256(4, 3)), field_ref("dec256_5_3")), + /*bound_out=*/nullptr, *exciting_schema); + ExpectBindsTo( + cmp(field_ref("dec256_5_3"), field_ref("dec256_3_2")), + cmp(field_ref("dec256_5_3"), cast(field_ref("dec256_3_2"), decimal256(4, 3))), + /*bound_out=*/nullptr, *exciting_schema); } compute::SetLookupOptions in_a{ArrayFromJSON(utf8(), R"(["a"])")}; diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index ef5b6fb4aa..773a3f684b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -436,8 +436,8 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) { auto exec = GenerateDecimal<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id); - DCHECK_OK( - func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec))); + DCHECK_OK(func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec), + /*init=*/nullptr, DecimalsHaveSameScale())); 
} {