This is an automated email from the ASF dual-hosted git repository.

kazuyukitanimura pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 25957ddf fix: Remove casting on decimals with a small precision to 
decimal256  (#741)
25957ddf is described below

commit 25957ddf7431ec59bb18711aeeea4a1db2c62664
Author: KAZUYUKI TANIMURA <[email protected]>
AuthorDate: Thu Aug 1 11:03:26 2024 -0700

    fix: Remove casting on decimals with a small precision to decimal256
(#741)
    
    ## Which issue does this PR close?
    
    Part of #670
    
    ## Rationale for this change
    
    This PR improves native execution performance for decimals with a small
precision.
    
    ## What changes are included in this PR?
    
    This PR stops promoting decimal128 to decimal256 when the precisions
are small enough.
    
    ## How are these changes tested?
    
    Existing tests
---
 native/core/src/execution/datafusion/planner.rs | 17 +++++---
 native/spark-expr/src/scalar_funcs.rs           | 52 ++++++++++++++++---------
 2 files changed, 45 insertions(+), 24 deletions(-)

diff --git a/native/core/src/execution/datafusion/planner.rs 
b/native/core/src/execution/datafusion/planner.rs
index 6d6102ae..5bfd7679 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -17,9 +17,7 @@
 
 //! Converts Spark physical plan to DataFusion physical plan
 
-use std::{collections::HashMap, sync::Arc};
-
-use arrow_schema::{DataType, Field, Schema, TimeUnit};
+use arrow_schema::{DataType, Field, Schema, TimeUnit, 
DECIMAL128_MAX_PRECISION};
 use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, 
bit_or_udaf, bit_xor_udaf};
 use datafusion::functions_aggregate::count::count_udaf;
 use datafusion::functions_aggregate::sum::sum_udaf;
@@ -62,6 +60,8 @@ use 
datafusion_physical_expr_common::aggregate::create_aggregate_expr;
 use itertools::Itertools;
 use jni::objects::GlobalRef;
 use num::{BigInt, ToPrimitive};
+use std::cmp::max;
+use std::{collections::HashMap, sync::Arc};
 
 use crate::{
     errors::ExpressionError,
@@ -410,7 +410,7 @@ impl PhysicalPlanner {
                 // Spark Substring's start is 1-based when start > 0
                 let start = expr.start - i32::from(expr.start > 0);
                 // substring negative len is treated as 0 in Spark
-                let len = std::cmp::max(expr.len, 0);
+                let len = max(expr.len, 0);
 
                 Ok(Arc::new(SubstringExpr::new(
                     child,
@@ -664,7 +664,14 @@ impl PhysicalPlanner {
                 | DataFusionOperator::Modulo,
                 Ok(DataType::Decimal128(p1, s1)),
                 Ok(DataType::Decimal128(p2, s2)),
-            ) => {
+            ) if ((op == DataFusionOperator::Plus || op == 
DataFusionOperator::Minus)
+                && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
+                    >= DECIMAL128_MAX_PRECISION)
+                || (op == DataFusionOperator::Multiply && p1 + p2 >= 
DECIMAL128_MAX_PRECISION)
+                || (op == DataFusionOperator::Modulo
+                    && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
+                        > DECIMAL128_MAX_PRECISION) =>
+            {
                 let data_type = return_type.map(to_arrow_datatype).unwrap();
                 // For some Decimal128 operations, we need wider internal 
digits.
                 // Cast left and right to Decimal256 and cast the result back 
to Decimal128
diff --git a/native/spark-expr/src/scalar_funcs.rs 
b/native/spark-expr/src/scalar_funcs.rs
index c50b98ba..7cbaf12a 100644
--- a/native/spark-expr/src/scalar_funcs.rs
+++ b/native/spark-expr/src/scalar_funcs.rs
@@ -25,7 +25,7 @@ use arrow::{
     datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
 };
 use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
-use arrow_schema::DataType;
+use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
 use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
 use datafusion_common::{
     cast::as_generic_string_array, exec_err, internal_err, DataFusionError,
@@ -460,27 +460,41 @@ pub fn spark_decimal_div(
     };
     let left = left.as_primitive::<Decimal128Type>();
     let right = right.as_primitive::<Decimal128Type>();
-    let (_, s1) = get_precision_scale(left.data_type());
-    let (_, s2) = get_precision_scale(right.data_type());
+    let (p1, s1) = get_precision_scale(left.data_type());
+    let (p2, s2) = get_precision_scale(right.data_type());
 
-    let ten = BigInt::from(10);
     let l_exp = ((s2 + s3 + 1) as u32).saturating_sub(s1 as u32);
     let r_exp = (s1 as u32).saturating_sub((s2 + s3 + 1) as u32);
-    let l_mul = ten.pow(l_exp);
-    let r_mul = ten.pow(r_exp);
-    let five = BigInt::from(5);
-    let zero = BigInt::from(0);
-    let result: Decimal128Array = arrow::compute::kernels::arity::binary(left, 
right, |l, r| {
-        let l = BigInt::from(l) * &l_mul;
-        let r = BigInt::from(r) * &r_mul;
-        let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
-        let res = if div.is_negative() {
-            div - &five
-        } else {
-            div + &five
-        } / &ten;
-        res.to_i128().unwrap_or(i128::MAX)
-    })?;
+    let result: Decimal128Array = if p1 as u32 + l_exp > 
DECIMAL128_MAX_PRECISION as u32
+        || p2 as u32 + r_exp > DECIMAL128_MAX_PRECISION as u32
+    {
+        let ten = BigInt::from(10);
+        let l_mul = ten.pow(l_exp);
+        let r_mul = ten.pow(r_exp);
+        let five = BigInt::from(5);
+        let zero = BigInt::from(0);
+        arrow::compute::kernels::arity::binary(left, right, |l, r| {
+            let l = BigInt::from(l) * &l_mul;
+            let r = BigInt::from(r) * &r_mul;
+            let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
+            let res = if div.is_negative() {
+                div - &five
+            } else {
+                div + &five
+            } / &ten;
+            res.to_i128().unwrap_or(i128::MAX)
+        })?
+    } else {
+        let l_mul = 10_i128.pow(l_exp);
+        let r_mul = 10_i128.pow(r_exp);
+        arrow::compute::kernels::arity::binary(left, right, |l, r| {
+            let l = l * l_mul;
+            let r = r * r_mul;
+            let div = if r == 0 { 0 } else { l / r };
+            let res = if div.is_negative() { div - 5 } else { div + 5 } / 10;
+            res.to_i128().unwrap_or(i128::MAX)
+        })?
+    };
     let result = result.with_data_type(DataType::Decimal128(p3, s3));
     Ok(ColumnarValue::Array(Arc::new(result)))
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to