alamb commented on code in PR #6103:
URL: https://github.com/apache/arrow-datafusion/pull/6103#discussion_r1179304452


##########
datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs:
##########
@@ -506,31 +513,155 @@ pub(crate) fn subtract_dyn_decimal(
     decimal_array_with_precision_scale(array, precision, scale)
 }
 
-pub(crate) fn multiply_dyn_decimal(
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn math_op_dict<K, T, F>(
+    left: &DictionaryArray<K>,
+    right: &DictionaryArray<K>,
+    op: F,
+) -> Result<PrimitiveArray<T>>
+where
+    K: ArrowDictionaryKeyType + ArrowNumericType,
+    T: ArrowNumericType,
+    F: Fn(T::Native, T::Native) -> T::Native,
+{
+    if left.len() != right.len() {
+        return Err(DataFusionError::Internal(format!(
+            "Cannot perform operation on arrays of different length ({}, {})",
+            left.len(),
+            right.len()
+        )));
+    }
+
+    // Safety justification: Since the inputs are valid Arrow arrays, all 
values are
+    // valid indexes into the dictionary (which is verified during 
construction)
+
+    let left_iter = unsafe {
+        left.values()
+            .as_primitive::<T>()
+            .take_iter_unchecked(left.keys_iter())
+    };
+
+    let right_iter = unsafe {
+        right
+            .values()
+            .as_primitive::<T>()
+            .take_iter_unchecked(right.keys_iter())
+    };
+
+    let result = left_iter
+        .zip(right_iter)
+        .map(|(left_value, right_value)| {
+            if let (Some(left), Some(right)) = (left_value, right_value) {
+                Some(op(left, right))
+            } else {
+                None
+            }
+        })
+        .collect();
+
+    Ok(result)
+}
+
+/// Divide a decimal native value by given divisor and round the result.
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
+where
+    I: DecimalType,
+    I::Native: ArrowNativeTypeOp,
+{
+    let d = input.div_wrapping(div);
+    let r = input.mod_wrapping(div);
+
+    let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
+    let half_neg = half.neg_wrapping();
+    // Round result
+    match input >= I::Native::ZERO {
+        true if r >= half => d.add_wrapping(I::Native::ONE),
+        false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
+        _ => d,
+    }
+}
+
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn multiply_fixed_point_dyn(
     left: &dyn Array,
     right: &dyn Array,
-    result_type: &DataType,
+    required_scale: i8,
 ) -> Result<ArrayRef> {
-    let (precision, scale) = get_precision_scale(result_type)?;
+    match (left.data_type(), right.data_type()) {
+        (
+            DataType::Dictionary(_, lhs_value_type),
+            DataType::Dictionary(_, rhs_value_type),
+        ) if matches!(lhs_value_type.as_ref(), &DataType::Decimal128(_, _))
+            && matches!(rhs_value_type.as_ref(), &DataType::Decimal128(_, _)) 
=>
+        {
+            downcast_dictionary_array!(
+                left => match left.values().data_type() {
+                    DataType::Decimal128(_, _) => {
+                        let lhs_precision_scale = 
get_precision_scale(lhs_value_type.as_ref())?;
+                        let rhs_precision_scale = 
get_precision_scale(rhs_value_type.as_ref())?;
 
-    let op_type = decimal_op_mathematics_type(
-        &Operator::Multiply,
-        left.data_type(),
-        left.data_type(),
-    )
-    .unwrap();
-    let (_, op_scale) = get_precision_scale(&op_type)?;
+                        let product_scale = lhs_precision_scale.1 + 
rhs_precision_scale.1;
+                        let precision = min(lhs_precision_scale.0 + 
rhs_precision_scale.0 + 1, DECIMAL128_MAX_PRECISION);
 
-    let array = multiply_dyn(left, right)?;
-    if op_scale > scale {
-        let div = 10_i128.pow((op_scale - scale) as u32);
-        let array = divide_scalar_dyn::<Decimal128Type>(&array, div)?;
-        decimal_array_with_precision_scale(array, precision, scale)
-    } else {
-        decimal_array_with_precision_scale(array, precision, scale)
+                        if required_scale == product_scale {
+                            return Ok(multiply_dyn(left, 
right)?.as_primitive::<Decimal128Type>().clone()
+                                .with_precision_and_scale(precision, 
required_scale).map(|a| Arc::new(a) as ArrayRef)?);
+                        }
+
+                        if required_scale > product_scale {
+                            return Err(DataFusionError::Internal(format!(
+                                "Required scale {} is greater than product 
scale {}",
+                                required_scale, product_scale
+                            )));
+                        }
+
+                        let divisor =
+                            i256::from_i128(10).pow_wrapping((product_scale - 
required_scale) as u32);
+
+                        let right = as_dictionary_array::<_>(right);
+
+                        let array = math_op_dict::<_, Decimal128Type, _>(left, 
right, |a, b| {
+                            let a = i256::from_i128(a);
+                            let b = i256::from_i128(b);
+
+                            let mut mul = a.wrapping_mul(b);
+                            mul = divide_and_round::<Decimal256Type>(mul, 
divisor);
+                            mul.as_i128()
+                        }).map(|a| a.with_precision_and_scale(precision, 
required_scale).unwrap())?;
+
+                        Ok(Arc::new(array))
+                    }
+                    t => unreachable!("Unsupported dictionary value type {}", 
t),
+                },
+                t => unreachable!("Unsupported data type {}", t),
+            )
+        }
+        (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => {
+            let left = 
left.as_any().downcast_ref::<Decimal128Array>().unwrap();
+            let right = 
right.as_any().downcast_ref::<Decimal128Array>().unwrap();
+
+            Ok(multiply_fixed_point(left, right, required_scale)

Review Comment:
   I was confused for a while about why there is so much code for the dictionary case, while the normal decimal case just calls `multiply_fixed_point`.
   
   Then I saw that part of the issue is that the required precision / scale calculation is duplicated:
   
https://github.com/apache/arrow-rs/blob/9fa8125fbe14a3a85b4995617945bda51ee3b055/arrow-arith/src/arithmetic.rs#L1508-L1528
   
   I think this is good for DataFusion, and I'll comment on https://github.com/apache/arrow-rs/pull/4136 about removing the duplication.
   
   



##########
datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs:
##########
@@ -506,31 +513,155 @@ pub(crate) fn subtract_dyn_decimal(
     decimal_array_with_precision_scale(array, precision, scale)
 }
 
-pub(crate) fn multiply_dyn_decimal(
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn math_op_dict<K, T, F>(
+    left: &DictionaryArray<K>,
+    right: &DictionaryArray<K>,
+    op: F,
+) -> Result<PrimitiveArray<T>>
+where
+    K: ArrowDictionaryKeyType + ArrowNumericType,
+    T: ArrowNumericType,
+    F: Fn(T::Native, T::Native) -> T::Native,
+{
+    if left.len() != right.len() {
+        return Err(DataFusionError::Internal(format!(
+            "Cannot perform operation on arrays of different length ({}, {})",
+            left.len(),
+            right.len()
+        )));
+    }
+
+    // Safety justification: Since the inputs are valid Arrow arrays, all 
values are
+    // valid indexes into the dictionary (which is verified during 
construction)
+
+    let left_iter = unsafe {
+        left.values()
+            .as_primitive::<T>()
+            .take_iter_unchecked(left.keys_iter())
+    };
+
+    let right_iter = unsafe {
+        right
+            .values()
+            .as_primitive::<T>()
+            .take_iter_unchecked(right.keys_iter())
+    };
+
+    let result = left_iter
+        .zip(right_iter)
+        .map(|(left_value, right_value)| {
+            if let (Some(left), Some(right)) = (left_value, right_value) {
+                Some(op(left, right))
+            } else {
+                None
+            }
+        })
+        .collect();
+
+    Ok(result)
+}
+
+/// Divide a decimal native value by given divisor and round the result.
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
+where
+    I: DecimalType,
+    I::Native: ArrowNativeTypeOp,
+{
+    let d = input.div_wrapping(div);
+    let r = input.mod_wrapping(div);
+
+    let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
+    let half_neg = half.neg_wrapping();
+    // Round result
+    match input >= I::Native::ZERO {
+        true if r >= half => d.add_wrapping(I::Native::ONE),
+        false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
+        _ => d,
+    }
+}
+
+/// Remove this once arrow-rs provides `multiply_fixed_point_dyn`.
+fn multiply_fixed_point_dyn(

Review Comment:
   ```suggestion
   /// <https://github.com/apache/arrow-rs/issues/4135>
   fn multiply_fixed_point_dyn(
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to