This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 191af8dcd7 Make decimal multiplication allow precision-loss in DataFusion (#6103)
191af8dcd7 is described below
commit 191af8dcd7c4ee756b7d98e06bd7b40dc23202b7
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Apr 28 00:07:43 2023 -0700
Make decimal multiplication allow precision-loss in DataFusion (#6103)
* Use multiply_fixed_point_dyn to allow precision-loss decimal multiplication
* Fix clippy
* Fix format
* Add unit test for kernel
* For review
* Fix API call
* Update datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
Co-authored-by: Andrew Lamb <[email protected]>
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
.../core/tests/sqllogictests/test_files/tpch.slt | 2 +-
.../src/expressions/binary/kernels_arrow.rs | 248 +++++++++++++++++++--
2 files changed, 229 insertions(+), 21 deletions(-)
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch.slt b/datafusion/core/tests/sqllogictests/test_files/tpch.slt
index 619fc8e0f6..f8c5763496 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch.slt
@@ -129,7 +129,7 @@ select
sum(l_quantity) as sum_qty,
sum(l_extendedprice) as sum_base_price,
sum(l_extendedprice * (1 - l_discount)) as sum_disc_price,
- sum(cast(l_extendedprice as decimal(12,2)) * (1 - l_discount) * (1 + l_tax)) as sum_charge,
+ sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge,
avg(l_quantity) as avg_qty,
avg(l_extendedprice) as avg_price,
avg(l_discount) as avg_disc,
diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
index 3b93d6f792..b796bf888d 100644
--- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
+++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
@@ -20,11 +20,17 @@
use arrow::compute::{
add_dyn, add_scalar_dyn, divide_dyn_opt, divide_scalar_dyn, modulus_dyn,
- modulus_scalar_dyn, multiply_dyn, multiply_scalar_dyn, subtract_dyn,
- subtract_scalar_dyn, try_unary,
+ modulus_scalar_dyn, multiply_dyn, multiply_fixed_point, multiply_scalar_dyn,
+ subtract_dyn, subtract_scalar_dyn, try_unary,
+};
+use arrow::datatypes::{
+ i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
+ DECIMAL128_MAX_PRECISION,
};
-use arrow::datatypes::{Date32Type, Date64Type, Decimal128Type};
use arrow::{array::*, datatypes::ArrowNumericType, downcast_dictionary_array};
+use arrow_array::types::{ArrowDictionaryKeyType, DecimalType};
+use arrow_array::ArrowNativeTypeOp;
+use arrow_buffer::ArrowNativeType;
use arrow_schema::DataType;
use datafusion_common::cast::{as_date32_array, as_date64_array, as_decimal128_array};
use datafusion_common::scalar::{date32_add, date64_add};
@@ -32,6 +38,7 @@ use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::type_coercion::binary::decimal_op_mathematics_type;
use datafusion_expr::ColumnarValue;
use datafusion_expr::Operator;
+use std::cmp::min;
use std::sync::Arc;
use super::{
@@ -506,31 +513,156 @@ pub(crate) fn subtract_dyn_decimal(
decimal_array_with_precision_scale(array, precision, scale)
}
-pub(crate) fn multiply_dyn_decimal(
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn math_op_dict<K, T, F>(
+ left: &DictionaryArray<K>,
+ right: &DictionaryArray<K>,
+ op: F,
+) -> Result<PrimitiveArray<T>>
+where
+ K: ArrowDictionaryKeyType + ArrowNumericType,
+ T: ArrowNumericType,
+ F: Fn(T::Native, T::Native) -> T::Native,
+{
+ if left.len() != right.len() {
+ return Err(DataFusionError::Internal(format!(
+ "Cannot perform operation on arrays of different length ({}, {})",
+ left.len(),
+ right.len()
+ )));
+ }
+
+ // Safety justification: Since the inputs are valid Arrow arrays, all values are
+ // valid indexes into the dictionary (which is verified during construction)
+
+ let left_iter = unsafe {
+ left.values()
+ .as_primitive::<T>()
+ .take_iter_unchecked(left.keys_iter())
+ };
+
+ let right_iter = unsafe {
+ right
+ .values()
+ .as_primitive::<T>()
+ .take_iter_unchecked(right.keys_iter())
+ };
+
+ let result = left_iter
+ .zip(right_iter)
+ .map(|(left_value, right_value)| {
+ if let (Some(left), Some(right)) = (left_value, right_value) {
+ Some(op(left, right))
+ } else {
+ None
+ }
+ })
+ .collect();
+
+ Ok(result)
+}
+
+/// Divide a decimal native value by given divisor and round the result.
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
+where
+ I: DecimalType,
+ I::Native: ArrowNativeTypeOp,
+{
+ let d = input.div_wrapping(div);
+ let r = input.mod_wrapping(div);
+
+ let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
+ let half_neg = half.neg_wrapping();
+ // Round result
+ match input >= I::Native::ZERO {
+ true if r >= half => d.add_wrapping(I::Native::ONE),
+ false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
+ _ => d,
+ }
+}
+
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+/// <https://github.com/apache/arrow-rs/issues/4135>
+fn multiply_fixed_point_dyn(
left: &dyn Array,
right: &dyn Array,
- result_type: &DataType,
+ required_scale: i8,
) -> Result<ArrayRef> {
- let (precision, scale) = get_precision_scale(result_type)?;
+ match (left.data_type(), right.data_type()) {
+ (
+ DataType::Dictionary(_, lhs_value_type),
+ DataType::Dictionary(_, rhs_value_type),
+ ) if matches!(lhs_value_type.as_ref(), &DataType::Decimal128(_, _))
+ && matches!(rhs_value_type.as_ref(), &DataType::Decimal128(_, _)) =>
+ {
+ downcast_dictionary_array!(
+ left => match left.values().data_type() {
+ DataType::Decimal128(_, _) => {
+ let lhs_precision_scale = get_precision_scale(lhs_value_type.as_ref())?;
+ let rhs_precision_scale = get_precision_scale(rhs_value_type.as_ref())?;
- let op_type = decimal_op_mathematics_type(
- &Operator::Multiply,
- left.data_type(),
- left.data_type(),
- )
- .unwrap();
- let (_, op_scale) = get_precision_scale(&op_type)?;
+ let product_scale = lhs_precision_scale.1 + rhs_precision_scale.1;
+ let precision = min(lhs_precision_scale.0 + rhs_precision_scale.0 + 1, DECIMAL128_MAX_PRECISION);
- let array = multiply_dyn(left, right)?;
- if op_scale > scale {
- let div = 10_i128.pow((op_scale - scale) as u32);
- let array = divide_scalar_dyn::<Decimal128Type>(&array, div)?;
- decimal_array_with_precision_scale(array, precision, scale)
- } else {
- decimal_array_with_precision_scale(array, precision, scale)
+ if required_scale == product_scale {
+ return Ok(multiply_dyn(left, right)?.as_primitive::<Decimal128Type>().clone()
+ .with_precision_and_scale(precision, required_scale).map(|a| Arc::new(a) as ArrayRef)?);
+ }
+
+ if required_scale > product_scale {
+ return Err(DataFusionError::Internal(format!(
+ "Required scale {} is greater than product scale {}",
+ required_scale, product_scale
+ )));
+ }
+
+ let divisor =
+ i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
+
+ let right = as_dictionary_array::<_>(right);
+
+ let array = math_op_dict::<_, Decimal128Type, _>(left, right, |a, b| {
+ let a = i256::from_i128(a);
+ let b = i256::from_i128(b);
+
+ let mut mul = a.wrapping_mul(b);
+ mul = divide_and_round::<Decimal256Type>(mul, divisor);
+ mul.as_i128()
+ }).map(|a| a.with_precision_and_scale(precision, required_scale).unwrap())?;
+
+ Ok(Arc::new(array))
+ }
+ t => unreachable!("Unsupported dictionary value type {}", t),
+ },
+ t => unreachable!("Unsupported data type {}", t),
+ )
+ }
+ (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => {
+ let left = left.as_any().downcast_ref::<Decimal128Array>().unwrap();
+ let right = right.as_any().downcast_ref::<Decimal128Array>().unwrap();
+
+ Ok(multiply_fixed_point(left, right, required_scale)
+ .map(|a| Arc::new(a) as ArrayRef)?)
+ }
+ (_, _) => Err(DataFusionError::Internal(format!(
+ "Unsupported data type {}, {}",
+ left.data_type(),
+ right.data_type()
+ ))),
}
}
+pub(crate) fn multiply_dyn_decimal(
+ left: &dyn Array,
+ right: &dyn Array,
+ result_type: &DataType,
+) -> Result<ArrayRef> {
+ let (precision, scale) = get_precision_scale(result_type)?;
+ let array = multiply_fixed_point_dyn(left, right, scale)?;
+ decimal_array_with_precision_scale(array, precision, scale)
+}
+
pub(crate) fn divide_dyn_opt_decimal(
left: &dyn Array,
right: &dyn Array,
@@ -888,4 +1020,80 @@ mod tests {
);
Ok(())
}
+
+ #[test]
+ fn test_decimal_multiply_fixed_point_dyn() {
+ // [123456789]
+ let a = Decimal128Array::from(vec![123456789000000000000000000])
+ .with_precision_and_scale(38, 18)
+ .unwrap();
+
+ // [10]
+ let b = Decimal128Array::from(vec![10000000000000000000])
+ .with_precision_and_scale(38, 18)
+ .unwrap();
+
+ // Avoid overflow by reducing the scale.
+ let result = multiply_fixed_point_dyn(&a, &b, 28).unwrap();
+ // [1234567890]
+ let expected = Arc::new(
+ Decimal128Array::from(vec![12345678900000000000000000000000000000])
+ .with_precision_and_scale(38, 28)
+ .unwrap(),
+ ) as ArrayRef;
+
+ assert_eq!(&expected, &result);
+ assert_eq!(
+ result.as_primitive::<Decimal128Type>().value_as_string(0),
+ "1234567890.0000000000000000000000000000"
+ );
+
+ // [123456789, 10]
+ let a = Decimal128Array::from(vec![
+ 123456789000000000000000000,
+ 10000000000000000000,
+ ])
+ .with_precision_and_scale(38, 18)
+ .unwrap();
+
+ // [10, 123456789, 12]
+ let b = Decimal128Array::from(vec![
+ 10000000000000000000,
+ 123456789000000000000000000,
+ 12000000000000000000,
+ ])
+ .with_precision_and_scale(38, 18)
+ .unwrap();
+
+ let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1), None]);
+ let array1 = DictionaryArray::new(keys, Arc::new(a));
+ let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(2), None]);
+ let array2 = DictionaryArray::new(keys, Arc::new(b));
+
+ let result = multiply_fixed_point_dyn(&array1, &array2, 28).unwrap();
+ let expected = Arc::new(
+ Decimal128Array::from(vec![
+ Some(12345678900000000000000000000000000000),
+ Some(12345678900000000000000000000000000000),
+ Some(1200000000000000000000000000000),
+ None,
+ ])
+ .with_precision_and_scale(38, 28)
+ .unwrap(),
+ ) as ArrayRef;
+
+ assert_eq!(&expected, &result);
+ assert_eq!(
+ result.as_primitive::<Decimal128Type>().value_as_string(0),
+ "1234567890.0000000000000000000000000000"
+ );
+ assert_eq!(
+ result.as_primitive::<Decimal128Type>().value_as_string(1),
+ "1234567890.0000000000000000000000000000"
+ );
+ assert_eq!(
+ result.as_primitive::<Decimal128Type>().value_as_string(2),
+ "120.0000000000000000000000000000"
+ );
+ }
}