This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new f100eff39f GH-40308: [C++][Gandiva] Add support for compute module's 
decimal promotion rules (#40434)
f100eff39f is described below

commit f100eff39fd37538c5ab4572083029622fc0f5aa
Author: ZhangHuiGui <[email protected]>
AuthorDate: Mon Mar 25 13:02:24 2024 +0800

    GH-40308: [C++][Gandiva] Add support for compute module's decimal promotion 
rules (#40434)
    
    ### Rationale for this change
    
    Gandiva's decimal divide rules differ from our compute module's rules. 
Some systems, such as Redshift, use the same rules as our compute module. 
So it is useful for Gandiva to support the compute module's rules too.
    
    ### What changes are included in this PR?
    Add an optional argument to GetResultType for compatibility with 
**the compute module's decimal promotion rules**.
    
    ### Are these changes tested?
    Yes
    
    ### Are there any user-facing changes?
    No
    
    * GitHub Issue: #40308
    
    Authored-by: ZhangHuiGui <[email protected]>
    Signed-off-by: Sutou Kouhei <[email protected]>
---
 cpp/src/gandiva/decimal_type_util.cc         | 19 ++++++++++---
 cpp/src/gandiva/decimal_type_util.h          | 14 +++++++++-
 cpp/src/gandiva/tests/decimal_single_test.cc | 40 +++++++++++++++++++++++-----
 3 files changed, 63 insertions(+), 10 deletions(-)

diff --git a/cpp/src/gandiva/decimal_type_util.cc 
b/cpp/src/gandiva/decimal_type_util.cc
index 2abc5a21ea..cce4292f3b 100644
--- a/cpp/src/gandiva/decimal_type_util.cc
+++ b/cpp/src/gandiva/decimal_type_util.cc
@@ -30,7 +30,8 @@ constexpr int32_t DecimalTypeUtil::kMinAdjustedScale;
 
 // Implementation of decimal rules.
 Status DecimalTypeUtil::GetResultType(Op op, const Decimal128TypeVector& 
in_types,
-                                      Decimal128TypePtr* out_type) {
+                                      Decimal128TypePtr* out_type,
+                                      bool use_compute_rules) {
   DCHECK_EQ(in_types.size(), 2);
 
   *out_type = nullptr;
@@ -59,7 +60,9 @@ Status DecimalTypeUtil::GetResultType(Op op, const 
Decimal128TypeVector& in_type
       break;
 
     case kOpDivide:
-      result_scale = std::max(kMinAdjustedScale, s1 + p2 + 1);
+      result_scale = use_compute_rules
+                         ? std::max(kMinComputeAdjustedScale, s1 + p2 - s2 + 1)
+                         : std::max(kMinAdjustedScale, s1 + p2 + 1);
       result_precision = p1 - s1 + s2 + result_scale;
       break;
 
@@ -68,7 +71,17 @@ Status DecimalTypeUtil::GetResultType(Op op, const 
Decimal128TypeVector& in_type
       result_precision = std::min(p1 - s1, p2 - s2) + result_scale;
       break;
   }
-  *out_type = MakeAdjustedType(result_precision, result_scale);
+
+  if (use_compute_rules) {
+    if (result_precision < kMinPrecision || result_precision > kMaxPrecision) {
+      return Status::Invalid("Decimal precision out of range [", 
int32_t(kMinPrecision),
+                             ", ", int32_t(kMaxPrecision), "]: ", 
result_precision);
+    }
+    *out_type = MakeType(result_precision, result_scale);
+  } else {
+    *out_type = MakeAdjustedType(result_precision, result_scale);
+  }
+
   return Status::OK();
 }
 
diff --git a/cpp/src/gandiva/decimal_type_util.h 
b/cpp/src/gandiva/decimal_type_util.h
index 2b496f6cbf..16ce544717 100644
--- a/cpp/src/gandiva/decimal_type_util.h
+++ b/cpp/src/gandiva/decimal_type_util.h
@@ -45,6 +45,9 @@ class GANDIVA_EXPORT DecimalTypeUtil {
   /// The maximum precision representable by a 8-byte decimal
   static constexpr int32_t kMaxDecimal64Precision = 18;
 
+  /// The minimum precision representable by a 16-byte decimal
+  static constexpr int32_t kMinPrecision = 1;
+
   /// The maximum precision representable by a 16-byte decimal
   static constexpr int32_t kMaxPrecision = 38;
 
@@ -57,10 +60,19 @@ class GANDIVA_EXPORT DecimalTypeUtil {
   // * There is no strong reason for 6, but both SQLServer and Impala use 6 
too.
   static constexpr int32_t kMinAdjustedScale = 6;
 
+  // The same function with kMinAdjustedScale, just for compatibility with
+  // compute module's decimal promotion rules.
+  static constexpr int32_t kMinComputeAdjustedScale = 4;
+
   // For specified operation and input scale/precision, determine the output
   // scale/precision.
+  //
+  // The 'use_compute_rules' is for compatibility with compute module's
+  // decimal promotion rules:
+  // https://arrow.apache.org/docs/cpp/compute.html#arithmetic-functions
   static Status GetResultType(Op op, const Decimal128TypeVector& in_types,
-                              Decimal128TypePtr* out_type);
+                              Decimal128TypePtr* out_type,
+                              bool use_compute_rules = false);
 
   static Decimal128TypePtr MakeType(int32_t precision, int32_t scale) {
     return std::dynamic_pointer_cast<arrow::Decimal128Type>(
diff --git a/cpp/src/gandiva/tests/decimal_single_test.cc 
b/cpp/src/gandiva/tests/decimal_single_test.cc
index 666ee4a68d..57c281a455 100644
--- a/cpp/src/gandiva/tests/decimal_single_test.cc
+++ b/cpp/src/gandiva/tests/decimal_single_test.cc
@@ -49,7 +49,8 @@ class TestDecimalOps : public ::testing::Test {
   ArrayPtr MakeDecimalVector(const DecimalScalar128& in);
 
   void Verify(DecimalTypeUtil::Op, const std::string& function, const 
DecimalScalar128& x,
-              const DecimalScalar128& y, const DecimalScalar128& expected);
+              const DecimalScalar128& y, const DecimalScalar128& expected,
+              bool use_compute_rules = false, bool verify_failed = false);
 
   void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
                     const DecimalScalar128& expected) {
@@ -67,8 +68,10 @@ class TestDecimalOps : public ::testing::Test {
   }
 
   void DivideAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
-                       const DecimalScalar128& expected) {
-    Verify(DecimalTypeUtil::kOpDivide, "divide", x, y, expected);
+                       const DecimalScalar128& expected, bool 
use_compute_rules = false,
+                       bool verify_failed = false) {
+    Verify(DecimalTypeUtil::kOpDivide, "divide", x, y, expected, 
use_compute_rules,
+           verify_failed);
   }
 
   void ModAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
@@ -91,7 +94,8 @@ ArrayPtr TestDecimalOps::MakeDecimalVector(const 
DecimalScalar128& in) {
 
 void TestDecimalOps::Verify(DecimalTypeUtil::Op op, const std::string& 
function,
                             const DecimalScalar128& x, const DecimalScalar128& 
y,
-                            const DecimalScalar128& expected) {
+                            const DecimalScalar128& expected, bool 
use_compute_rules,
+                            bool verify_failed) {
   auto x_type = std::make_shared<arrow::Decimal128Type>(x.precision(), 
x.scale());
   auto y_type = std::make_shared<arrow::Decimal128Type>(y.precision(), 
y.scale());
   auto field_x = field("x", x_type);
@@ -99,8 +103,14 @@ void TestDecimalOps::Verify(DecimalTypeUtil::Op op, const 
std::string& function,
   auto schema = arrow::schema({field_x, field_y});
 
   Decimal128TypePtr output_type;
-  auto status = DecimalTypeUtil::GetResultType(op, {x_type, y_type}, 
&output_type);
-  ARROW_EXPECT_OK(status);
+  auto status = DecimalTypeUtil::GetResultType(op, {x_type, y_type}, 
&output_type,
+                                               use_compute_rules);
+  if (verify_failed) {
+    ASSERT_NOT_OK(status);
+    return;
+  } else {
+    ARROW_EXPECT_OK(status);
+  }
 
   // output fields
   auto res = field("res", output_type);
@@ -283,13 +293,31 @@ TEST_F(TestDecimalOps, TestMultiply) {
 }
 
 TEST_F(TestDecimalOps, TestDivide) {
+  // fast-path
+  //
+  // origin Gandiva's rules
   DivideAndVerify(decimal_literal("201", 10, 3),              // x
                   decimal_literal("301", 10, 2),              // y
                   decimal_literal("6677740863787", 23, 14));  // expected
 
+  // compute module's rules
+  DivideAndVerify(decimal_literal("201", 10, 3),           // x
+                  decimal_literal("301", 10, 2),           // y
+                  decimal_literal("66777408638", 21, 12),  // expected
+                  /*use_compute_rules=*/true);
+
+  // max precision beyond 38
+  //
+  // normally under origin Gandiva rules
   DivideAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20),  // x
                   DecimalScalar128(std::string(35, '9'), 38, 20),  // x
                   DecimalScalar128("1000000000", 38, 6));
+
+  // invalid under compute module's rules
+  DivideAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20),  // x
+                  DecimalScalar128(std::string(35, '9'), 38, 20),  // x
+                  DecimalScalar128(std::string(35, '9'), 0, 0),    // useless 
expected
+                  /*use_compute_rules=*/true, /*verify_failed=*/true);
 }
 
 TEST_F(TestDecimalOps, TestMod) {

Reply via email to