This is an automated email from the ASF dual-hosted git repository.
kazuyukitanimura pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 25957ddf fix: Remove casting on decimals with a small precision to decimal256 (#741)
25957ddf is described below
commit 25957ddf7431ec59bb18711aeeea4a1db2c62664
Author: KAZUYUKI TANIMURA <[email protected]>
AuthorDate: Thu Aug 1 11:03:26 2024 -0700
fix: Remove casting on decimals with a small precision to decimal256 (#741)
## Which issue does this PR close?
Part of #670
## Rationale for this change
This PR improves native execution performance for decimals with small precision.
## What changes are included in this PR?
This PR stops promoting decimal128 to decimal256 when the operand precisions are small enough.
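For example, adding two DECIMAL(10, 2) values needs at most max(2, 2) + max(10 - 2, 10 - 2) + 1 = 11 digits, well below Decimal128's limit of 38, so the whole operation can stay in decimal128. Multiplying two DECIMAL(20, 4) values may need 20 + 20 + 1 = 41 digits, so that case still takes the decimal256 path.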
## How are these changes tested?
Existing tests
---
native/core/src/execution/datafusion/planner.rs | 17 +++++---
native/spark-expr/src/scalar_funcs.rs | 52 ++++++++++++++++---------
2 files changed, 45 insertions(+), 24 deletions(-)
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index 6d6102ae..5bfd7679 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -17,9 +17,7 @@
//! Converts Spark physical plan to DataFusion physical plan
-use std::{collections::HashMap, sync::Arc};
-
-use arrow_schema::{DataType, Field, Schema, TimeUnit};
+use arrow_schema::{DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION};
use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf};
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
@@ -62,6 +60,8 @@ use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
use itertools::Itertools;
use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};
+use std::cmp::max;
+use std::{collections::HashMap, sync::Arc};
use crate::{
errors::ExpressionError,
@@ -410,7 +410,7 @@ impl PhysicalPlanner {
// Spark Substring's start is 1-based when start > 0
let start = expr.start - i32::from(expr.start > 0);
// substring negative len is treated as 0 in Spark
- let len = std::cmp::max(expr.len, 0);
+ let len = max(expr.len, 0);
Ok(Arc::new(SubstringExpr::new(
child,
@@ -664,7 +664,14 @@ impl PhysicalPlanner {
| DataFusionOperator::Modulo,
Ok(DataType::Decimal128(p1, s1)),
Ok(DataType::Decimal128(p2, s2)),
- ) => {
+ ) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus)
+     && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8) >= DECIMAL128_MAX_PRECISION)
+     || (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION)
+     || (op == DataFusionOperator::Modulo
+         && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8) > DECIMAL128_MAX_PRECISION) =>
+ {
let data_type = return_type.map(to_arrow_datatype).unwrap();
// For some Decimal128 operations, we need wider internal digits.
// Cast left and right to Decimal256 and cast the result back to Decimal128
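The new match guard is the core of this hunk. Restated as a standalone predicate it looks roughly like the sketch below (illustrative only: Op and needs_decimal256 are hypothetical names; DECIMAL128_MAX_PRECISION is arrow_schema's constant, 38).

    const DECIMAL128_MAX_PRECISION: u8 = 38; // mirrors arrow_schema::DECIMAL128_MAX_PRECISION

    enum Op { Plus, Minus, Multiply, Modulo }

    /// True when an operation on Decimal128(p1, s1) x Decimal128(p2, s2) may need
    /// more than 38 digits, i.e. when the planner must still cast to Decimal256.
    fn needs_decimal256(op: Op, p1: u8, s1: i8, p2: u8, s2: i8) -> bool {
        let digits = std::cmp::max(s1, s2) as u8 + std::cmp::max(p1 - s1 as u8, p2 - s2 as u8);
        match op {
            Op::Plus | Op::Minus => digits >= DECIMAL128_MAX_PRECISION,
            Op::Multiply => p1 + p2 >= DECIMAL128_MAX_PRECISION,
            Op::Modulo => digits > DECIMAL128_MAX_PRECISION,
        }
    }

    fn main() {
        assert!(!needs_decimal256(Op::Plus, 10, 2, 10, 2));    // small precision: stays decimal128
        assert!(needs_decimal256(Op::Multiply, 20, 4, 20, 4)); // wide result: still promoted
    }

When this predicate is false, the planner now leaves the operands as Decimal128 instead of inserting Decimal256 casts around the arithmetic.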
diff --git a/native/spark-expr/src/scalar_funcs.rs b/native/spark-expr/src/scalar_funcs.rs
index c50b98ba..7cbaf12a 100644
--- a/native/spark-expr/src/scalar_funcs.rs
+++ b/native/spark-expr/src/scalar_funcs.rs
@@ -25,7 +25,7 @@ use arrow::{
datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
};
use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
-use arrow_schema::DataType;
+use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
use datafusion_common::{
cast::as_generic_string_array, exec_err, internal_err, DataFusionError,
@@ -460,27 +460,41 @@ pub fn spark_decimal_div(
};
let left = left.as_primitive::<Decimal128Type>();
let right = right.as_primitive::<Decimal128Type>();
- let (_, s1) = get_precision_scale(left.data_type());
- let (_, s2) = get_precision_scale(right.data_type());
+ let (p1, s1) = get_precision_scale(left.data_type());
+ let (p2, s2) = get_precision_scale(right.data_type());
- let ten = BigInt::from(10);
let l_exp = ((s2 + s3 + 1) as u32).saturating_sub(s1 as u32);
let r_exp = (s1 as u32).saturating_sub((s2 + s3 + 1) as u32);
- let l_mul = ten.pow(l_exp);
- let r_mul = ten.pow(r_exp);
- let five = BigInt::from(5);
- let zero = BigInt::from(0);
- let result: Decimal128Array = arrow::compute::kernels::arity::binary(left, right, |l, r| {
- let l = BigInt::from(l) * &l_mul;
- let r = BigInt::from(r) * &r_mul;
- let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
- let res = if div.is_negative() {
- div - &five
- } else {
- div + &five
- } / &ten;
- res.to_i128().unwrap_or(i128::MAX)
- })?;
+ let result: Decimal128Array = if p1 as u32 + l_exp > DECIMAL128_MAX_PRECISION as u32
+ || p2 as u32 + r_exp > DECIMAL128_MAX_PRECISION as u32
+ {
+ let ten = BigInt::from(10);
+ let l_mul = ten.pow(l_exp);
+ let r_mul = ten.pow(r_exp);
+ let five = BigInt::from(5);
+ let zero = BigInt::from(0);
+ arrow::compute::kernels::arity::binary(left, right, |l, r| {
+ let l = BigInt::from(l) * &l_mul;
+ let r = BigInt::from(r) * &r_mul;
+ let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
+ let res = if div.is_negative() {
+ div - &five
+ } else {
+ div + &five
+ } / &ten;
+ res.to_i128().unwrap_or(i128::MAX)
+ })?
+ } else {
+ let l_mul = 10_i128.pow(l_exp);
+ let r_mul = 10_i128.pow(r_exp);
+ arrow::compute::kernels::arity::binary(left, right, |l, r| {
+ let l = l * l_mul;
+ let r = r * r_mul;
+ let div = if r == 0 { 0 } else { l / r };
+ let res = if div.is_negative() { div - 5 } else { div + 5 } / 10;
+ res.to_i128().unwrap_or(i128::MAX)
+ })?
+ };
let result = result.with_data_type(DataType::Decimal128(p3, s3));
Ok(ColumnarValue::Array(Arc::new(result)))
}
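In spark_decimal_div the operands are rescaled by 10^l_exp and 10^r_exp so the quotient carries one extra digit beyond the result scale s3, then rounded half away from zero. The BigInt path is now kept only when p1 + l_exp or p2 + r_exp could exceed Decimal128's 38 digits after rescaling; otherwise the kernel runs entirely on i128. A minimal sketch of the fast path's rounding step, on already-rescaled i128 values rather than arrow arrays (div_round_half_away is a hypothetical helper name):

    /// Divide two rescaled i128 decimal representations and round half away
    /// from zero, as the i128 fast path above does; a zero divisor yields 0.
    fn div_round_half_away(l: i128, r: i128) -> i128 {
        let div = if r == 0 { 0 } else { l / r };
        if div.is_negative() { (div - 5) / 10 } else { (div + 5) / 10 }
    }

    fn main() {
        // With s1 = s2 = s3 = 0, l_exp = 1, so 7 / 2 becomes div_round_half_away(70, 2):
        // 70 / 2 = 35, then (35 + 5) / 10 = 4, i.e. 3.5 rounds away from zero.
        assert_eq!(div_round_half_away(70, 2), 4);
    }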
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]