This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 191af8dcd7 Make decimal multiplication allow precision-loss in DataFusion (#6103)
191af8dcd7 is described below

commit 191af8dcd7c4ee756b7d98e06bd7b40dc23202b7
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Apr 28 00:07:43 2023 -0700

    Make decimal multiplication allow precision-loss in DataFusion (#6103)
    
    * Use multiply_fixed_point_dyn to allow precision-loss decimal multiplication
    
    * Fix clippy
    
    * Fix format
    
    * Add unit test for kernel
    
    * For review
    
    * Fix API call
    
    * Update datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 .../core/tests/sqllogictests/test_files/tpch.slt   |   2 +-
 .../src/expressions/binary/kernels_arrow.rs        | 248 +++++++++++++++++++--
 2 files changed, 229 insertions(+), 21 deletions(-)

diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch.slt b/datafusion/core/tests/sqllogictests/test_files/tpch.slt
index 619fc8e0f6..f8c5763496 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch.slt
@@ -129,7 +129,7 @@ select
     sum(l_quantity) as sum_qty,
     sum(l_extendedprice) as sum_base_price,
     sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
-    sum(cast(l_extendedprice as decimal(12,2)) * (1 - l_discount) * (1 + l_tax)) as sum_charge,
+    sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
     avg(l_quantity) as avg_qty,
     avg(l_extendedprice) as avg_price,
     avg(l_discount) as avg_disc,
diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
index 3b93d6f792..b796bf888d 100644
--- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
+++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
@@ -20,11 +20,17 @@
 
 use arrow::compute::{
     add_dyn, add_scalar_dyn, divide_dyn_opt, divide_scalar_dyn, modulus_dyn,
-    modulus_scalar_dyn, multiply_dyn, multiply_scalar_dyn, subtract_dyn,
-    subtract_scalar_dyn, try_unary,
+    modulus_scalar_dyn, multiply_dyn, multiply_fixed_point, multiply_scalar_dyn,
+    subtract_dyn, subtract_scalar_dyn, try_unary,
+};
+use arrow::datatypes::{
+    i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
+    DECIMAL128_MAX_PRECISION,
 };
-use arrow::datatypes::{Date32Type, Date64Type, Decimal128Type};
 use arrow::{array::*, datatypes::ArrowNumericType, downcast_dictionary_array};
+use arrow_array::types::{ArrowDictionaryKeyType, DecimalType};
+use arrow_array::ArrowNativeTypeOp;
+use arrow_buffer::ArrowNativeType;
 use arrow_schema::DataType;
 use datafusion_common::cast::{as_date32_array, as_date64_array, as_decimal128_array};
 use datafusion_common::scalar::{date32_add, date64_add};
@@ -32,6 +38,7 @@ use datafusion_common::{DataFusionError, Result, ScalarValue};
 use datafusion_expr::type_coercion::binary::decimal_op_mathematics_type;
 use datafusion_expr::ColumnarValue;
 use datafusion_expr::Operator;
+use std::cmp::min;
 use std::sync::Arc;
 
 use super::{
@@ -506,31 +513,156 @@ pub(crate) fn subtract_dyn_decimal(
     decimal_array_with_precision_scale(array, precision, scale)
 }
 
-pub(crate) fn multiply_dyn_decimal(
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn math_op_dict<K, T, F>(
+    left: &DictionaryArray<K>,
+    right: &DictionaryArray<K>,
+    op: F,
+) -> Result<PrimitiveArray<T>>
+where
+    K: ArrowDictionaryKeyType + ArrowNumericType,
+    T: ArrowNumericType,
+    F: Fn(T::Native, T::Native) -> T::Native,
+{
+    if left.len() != right.len() {
+        return Err(DataFusionError::Internal(format!(
+            "Cannot perform operation on arrays of different length ({}, {})",
+            left.len(),
+            right.len()
+        )));
+    }
+
+    // Safety justification: Since the inputs are valid Arrow arrays, all values are
+    // valid indexes into the dictionary (which is verified during construction)
+
+    let left_iter = unsafe {
+        left.values()
+            .as_primitive::<T>()
+            .take_iter_unchecked(left.keys_iter())
+    };
+
+    let right_iter = unsafe {
+        right
+            .values()
+            .as_primitive::<T>()
+            .take_iter_unchecked(right.keys_iter())
+    };
+
+    let result = left_iter
+        .zip(right_iter)
+        .map(|(left_value, right_value)| {
+            if let (Some(left), Some(right)) = (left_value, right_value) {
+                Some(op(left, right))
+            } else {
+                None
+            }
+        })
+        .collect();
+
+    Ok(result)
+}
+
+/// Divide a decimal native value by given divisor and round the result.
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
+where
+    I: DecimalType,
+    I::Native: ArrowNativeTypeOp,
+{
+    let d = input.div_wrapping(div);
+    let r = input.mod_wrapping(div);
+
+    let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
+    let half_neg = half.neg_wrapping();
+    // Round result
+    match input >= I::Native::ZERO {
+        true if r >= half => d.add_wrapping(I::Native::ONE),
+        false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
+        _ => d,
+    }
+}
+
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+/// <https://github.com/apache/arrow-rs/issues/4135>
+fn multiply_fixed_point_dyn(
     left: &dyn Array,
     right: &dyn Array,
-    result_type: &DataType,
+    required_scale: i8,
 ) -> Result<ArrayRef> {
-    let (precision, scale) = get_precision_scale(result_type)?;
+    match (left.data_type(), right.data_type()) {
+        (
+            DataType::Dictionary(_, lhs_value_type),
+            DataType::Dictionary(_, rhs_value_type),
+        ) if matches!(lhs_value_type.as_ref(), &DataType::Decimal128(_, _))
+            && matches!(rhs_value_type.as_ref(), &DataType::Decimal128(_, _)) =>
+        {
+            downcast_dictionary_array!(
+                left => match left.values().data_type() {
+                    DataType::Decimal128(_, _) => {
+                        let lhs_precision_scale = get_precision_scale(lhs_value_type.as_ref())?;
+                        let rhs_precision_scale = get_precision_scale(rhs_value_type.as_ref())?;
 
-    let op_type = decimal_op_mathematics_type(
-        &Operator::Multiply,
-        left.data_type(),
-        left.data_type(),
-    )
-    .unwrap();
-    let (_, op_scale) = get_precision_scale(&op_type)?;
+                        let product_scale = lhs_precision_scale.1 + rhs_precision_scale.1;
+                        let precision = min(lhs_precision_scale.0 + rhs_precision_scale.0 + 1, DECIMAL128_MAX_PRECISION);
 
-    let array = multiply_dyn(left, right)?;
-    if op_scale > scale {
-        let div = 10_i128.pow((op_scale - scale) as u32);
-        let array = divide_scalar_dyn::<Decimal128Type>(&array, div)?;
-        decimal_array_with_precision_scale(array, precision, scale)
-    } else {
-        decimal_array_with_precision_scale(array, precision, scale)
+                        if required_scale == product_scale {
+                            return Ok(multiply_dyn(left, right)?.as_primitive::<Decimal128Type>().clone()
+                                .with_precision_and_scale(precision, required_scale).map(|a| Arc::new(a) as ArrayRef)?);
+                        }
+
+                        if required_scale > product_scale {
+                            return Err(DataFusionError::Internal(format!(
+                                "Required scale {} is greater than product scale {}",
+                                required_scale, product_scale
+                            )));
+                        }
+
+                        let divisor =
+                            i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
+
+                        let right = as_dictionary_array::<_>(right);
+
+                        let array = math_op_dict::<_, Decimal128Type, _>(left, right, |a, b| {
+                            let a = i256::from_i128(a);
+                            let b = i256::from_i128(b);
+
+                            let mut mul = a.wrapping_mul(b);
+                            mul = divide_and_round::<Decimal256Type>(mul, divisor);
+                            mul.as_i128()
+                        }).map(|a| a.with_precision_and_scale(precision, required_scale).unwrap())?;
+
+                        Ok(Arc::new(array))
+                    }
+                    t => unreachable!("Unsupported dictionary value type {}", t),
+                },
+                t => unreachable!("Unsupported data type {}", t),
+            )
+        }
+        (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => {
+            let left = left.as_any().downcast_ref::<Decimal128Array>().unwrap();
+            let right = right.as_any().downcast_ref::<Decimal128Array>().unwrap();
+
+            Ok(multiply_fixed_point(left, right, required_scale)
+                .map(|a| Arc::new(a) as ArrayRef)?)
+        }
+        (_, _) => Err(DataFusionError::Internal(format!(
+            "Unsupported data type {}, {}",
+            left.data_type(),
+            right.data_type()
+        ))),
     }
 }
 
+pub(crate) fn multiply_dyn_decimal(
+    left: &dyn Array,
+    right: &dyn Array,
+    result_type: &DataType,
+) -> Result<ArrayRef> {
+    let (precision, scale) = get_precision_scale(result_type)?;
+    let array = multiply_fixed_point_dyn(left, right, scale)?;
+    decimal_array_with_precision_scale(array, precision, scale)
+}
+
 pub(crate) fn divide_dyn_opt_decimal(
     left: &dyn Array,
     right: &dyn Array,
@@ -888,4 +1020,80 @@ mod tests {
         );
         Ok(())
     }
+
+    #[test]
+    fn test_decimal_multiply_fixed_point_dyn() {
+        // [123456789]
+        let a = Decimal128Array::from(vec![123456789000000000000000000])
+            .with_precision_and_scale(38, 18)
+            .unwrap();
+
+        // [10]
+        let b = Decimal128Array::from(vec![10000000000000000000])
+            .with_precision_and_scale(38, 18)
+            .unwrap();
+
+        // Avoid overflow by reducing the scale.
+        let result = multiply_fixed_point_dyn(&a, &b, 28).unwrap();
+        // [1234567890]
+        let expected = Arc::new(
+            Decimal128Array::from(vec![12345678900000000000000000000000000000])
+                .with_precision_and_scale(38, 28)
+                .unwrap(),
+        ) as ArrayRef;
+
+        assert_eq!(&expected, &result);
+        assert_eq!(
+            result.as_primitive::<Decimal128Type>().value_as_string(0),
+            "1234567890.0000000000000000000000000000"
+        );
+
+        // [123456789, 10]
+        let a = Decimal128Array::from(vec![
+            123456789000000000000000000,
+            10000000000000000000,
+        ])
+        .with_precision_and_scale(38, 18)
+        .unwrap();
+
+        // [10, 123456789, 12]
+        let b = Decimal128Array::from(vec![
+            10000000000000000000,
+            123456789000000000000000000,
+            12000000000000000000,
+        ])
+        .with_precision_and_scale(38, 18)
+        .unwrap();
+
+        let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1), None]);
+        let array1 = DictionaryArray::new(keys, Arc::new(a));
+        let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(2), None]);
+        let array2 = DictionaryArray::new(keys, Arc::new(b));
+
+        let result = multiply_fixed_point_dyn(&array1, &array2, 28).unwrap();
+        let expected = Arc::new(
+            Decimal128Array::from(vec![
+                Some(12345678900000000000000000000000000000),
+                Some(12345678900000000000000000000000000000),
+                Some(1200000000000000000000000000000),
+                None,
+            ])
+            .with_precision_and_scale(38, 28)
+            .unwrap(),
+        ) as ArrayRef;
+
+        assert_eq!(&expected, &result);
+        assert_eq!(
+            result.as_primitive::<Decimal128Type>().value_as_string(0),
+            "1234567890.0000000000000000000000000000"
+        );
+        assert_eq!(
+            result.as_primitive::<Decimal128Type>().value_as_string(1),
+            "1234567890.0000000000000000000000000000"
+        );
+        assert_eq!(
+            result.as_primitive::<Decimal128Type>().value_as_string(2),
+            "120.0000000000000000000000000000"
+        );
+    }
 }

Reply via email to