This is an automated email from the ASF dual-hosted git repository.

liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new c75986506 Support arithmetic scalar operation with DictionaryArray (#5151)
c75986506 is described below

commit c75986506ff27d60e8f79e30baa838cf5c717b86
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Feb 3 00:00:58 2023 -0800

    Support arithmetic scalar operation with DictionaryArray (#5151)
    
    * Support arithmetic dyn scalar
    
    * For review: removing unnecessary macro parameter, adding one more type coercion pattern
    
    * For review: modify comment
---
 datafusion/expr/src/type_coercion/binary.rs        |  19 ++
 datafusion/physical-expr/src/expressions/binary.rs | 328 ++++++++++++++++++++-
 .../src/expressions/binary/kernels_arrow.rs        | 172 +++++++----
 3 files changed, 459 insertions(+), 60 deletions(-)

diff --git a/datafusion/expr/src/type_coercion/binary.rs 
b/datafusion/expr/src/type_coercion/binary.rs
index 000011f9d..5e010f220 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -341,6 +341,15 @@ fn mathematics_numerical_coercion(
         (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), 
Null) => {
             Some(dec_type.clone())
         }
+        (Dictionary(key_type, value_type), _) => {
+            let value_type =
+                mathematics_numerical_coercion(mathematics_op, value_type, 
rhs_type);
+            value_type
+                .map(|value_type| Dictionary(key_type.clone(), 
Box::new(value_type)))
+        }
+        (_, Dictionary(_, value_type)) => {
+            mathematics_numerical_coercion(mathematics_op, lhs_type, 
value_type)
+        }
         (Decimal128(_, _), Float32 | Float64) => Some(Float64),
         (Float32 | Float64, Decimal128(_, _)) => Some(Float64),
         (Decimal128(_, _), _) => {
@@ -439,6 +448,16 @@ fn both_numeric_or_null_and_numeric(lhs_type: &DataType, 
rhs_type: &DataType) ->
     match (lhs_type, rhs_type) {
         (_, DataType::Null) => is_numeric(lhs_type),
         (DataType::Null, _) => is_numeric(rhs_type),
+        (
+            DataType::Dictionary(_, lhs_value_type),
+            DataType::Dictionary(_, rhs_value_type),
+        ) => is_numeric(lhs_value_type) && is_numeric(rhs_value_type),
+        (DataType::Dictionary(_, value_type), _) => {
+            is_numeric(value_type) && is_numeric(rhs_type)
+        }
+        (_, DataType::Dictionary(_, value_type)) => {
+            is_numeric(lhs_type) && is_numeric(value_type)
+        }
         _ => is_numeric(lhs_type) && is_numeric(rhs_type),
     }
 }
diff --git a/datafusion/physical-expr/src/expressions/binary.rs 
b/datafusion/physical-expr/src/expressions/binary.rs
index d2346d278..1c2bb065b 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -24,8 +24,10 @@ use std::{any::Any, sync::Arc};
 
 use arrow::array::*;
 use arrow::compute::kernels::arithmetic::{
-    add, add_scalar, divide_opt, divide_scalar, modulus, modulus_scalar, 
multiply,
-    multiply_scalar, subtract, subtract_scalar,
+    add, add_scalar_dyn as add_dyn_scalar, divide_opt,
+    divide_scalar_dyn as divide_dyn_scalar, modulus, modulus_scalar, multiply,
+    multiply_scalar_dyn as multiply_dyn_scalar, subtract,
+    subtract_scalar_dyn as subtract_dyn_scalar,
 };
 use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
 use arrow::compute::kernels::comparison::regexp_is_match_utf8;
@@ -49,6 +51,7 @@ use arrow::compute::kernels::comparison::{
 use arrow::compute::kernels::comparison::{
     eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar,
 };
+use arrow::datatypes::*;
 
 use adapter::{eq_dyn, gt_dyn, gt_eq_dyn, lt_dyn, lt_eq_dyn, neq_dyn};
 use arrow::compute::kernels::concat_elements::concat_elements_utf8;
@@ -58,12 +61,12 @@ use kernels::{
     bitwise_xor, bitwise_xor_scalar,
 };
 use kernels_arrow::{
-    add_decimal, add_decimal_scalar, divide_decimal_scalar, divide_opt_decimal,
+    add_decimal, add_decimal_dyn_scalar, divide_decimal_dyn_scalar, 
divide_opt_decimal,
     is_distinct_from, is_distinct_from_bool, is_distinct_from_decimal,
     is_distinct_from_null, is_distinct_from_utf8, is_not_distinct_from,
     is_not_distinct_from_bool, is_not_distinct_from_decimal, 
is_not_distinct_from_null,
     is_not_distinct_from_utf8, modulus_decimal, modulus_decimal_scalar, 
multiply_decimal,
-    multiply_decimal_scalar, subtract_decimal, subtract_decimal_scalar,
+    multiply_decimal_dyn_scalar, subtract_decimal, subtract_decimal_dyn_scalar,
 };
 
 use arrow::datatypes::{DataType, Schema, TimeUnit};
@@ -315,6 +318,45 @@ macro_rules! compute_op_dyn_scalar {
     }};
 }
 
+/// Invoke a dyn compute kernel on a data array and a scalar value
+/// LEFT is Primitive or Dictionary array of numeric values, RIGHT is scalar 
value
+/// OP_TYPE is the return type of scalar function
+/// SCALAR_TYPE is the type of the scalar value
+/// Different to `compute_op_dyn_scalar`, this calls the `_dyn_scalar` 
functions that
+/// take a `SCALAR_TYPE`.
+macro_rules! compute_primitive_op_dyn_scalar {
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr, $SCALAR_TYPE:ident) => 
{{
+        // generate the scalar function name, such as lt_dyn_scalar, from the 
$OP parameter
+        // (which could have a value of lt_dyn) and the suffix _scalar
+        if let Some(value) = $RIGHT {
+            Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar>]::<$SCALAR_TYPE>}(
+                $LEFT,
+                value,
+            )?))
+        } else {
+            // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
+            Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
+        }
+    }};
+}
+
+/// Invoke a dyn decimal compute kernel on a data array and a scalar value
+/// LEFT is Decimal or Dictionary array of decimal values, RIGHT is scalar 
value
+/// OP_TYPE is the return type of scalar function
+macro_rules! compute_primitive_decimal_op_dyn_scalar {
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
+        // generate the scalar function name, such as add_decimal_dyn_scalar,
+        // from the $OP parameter (which could have a value of add) and the
+        // suffix _decimal_dyn_scalar
+        if let Some(value) = $RIGHT {
+            Ok(paste::expr! {[<$OP _decimal_dyn_scalar>]}($LEFT, value)?)
+        } else {
+            // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
+            Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
+        }
+    }};
+}
+
 /// Invoke a compute kernel on array(s)
 macro_rules! compute_op {
     // invoke binary operator
@@ -376,6 +418,37 @@ macro_rules! binary_primitive_array_op {
     }};
 }
 
+/// Invoke a compute dyn kernel on an array and a scalar
+/// The binary_primitive_array_op_dyn_scalar macro only evaluates for primitive
+/// types like integers and floats.
+macro_rules! binary_primitive_array_op_dyn_scalar {
+    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
+        // unwrap underlying (non dictionary) value
+        let right = unwrap_dict_value($RIGHT);
+        let op_type = $LEFT.data_type();
+
+        let result: Result<Arc<dyn Array>> = match right {
+            ScalarValue::Decimal128(v, _, _) => 
compute_primitive_decimal_op_dyn_scalar!($LEFT, v, $OP, op_type),
+            ScalarValue::Int8(v) => compute_primitive_op_dyn_scalar!($LEFT, v, 
$OP, op_type, Int8Type),
+            ScalarValue::Int16(v) => compute_primitive_op_dyn_scalar!($LEFT, 
v, $OP, op_type, Int16Type),
+            ScalarValue::Int32(v) => compute_primitive_op_dyn_scalar!($LEFT, 
v, $OP, op_type, Int32Type),
+            ScalarValue::Int64(v) => compute_primitive_op_dyn_scalar!($LEFT, 
v, $OP, op_type, Int64Type),
+            ScalarValue::UInt8(v) => compute_primitive_op_dyn_scalar!($LEFT, 
v, $OP, op_type, UInt8Type),
+            ScalarValue::UInt16(v) => compute_primitive_op_dyn_scalar!($LEFT, 
v, $OP, op_type, UInt16Type),
+            ScalarValue::UInt32(v) => compute_primitive_op_dyn_scalar!($LEFT, 
v, $OP, op_type, UInt32Type),
+            ScalarValue::UInt64(v) => compute_primitive_op_dyn_scalar!($LEFT, 
v, $OP, op_type, UInt64Type),
+            ScalarValue::Float32(v) => compute_primitive_op_dyn_scalar!($LEFT, 
v, $OP, op_type, Float32Type),
+            ScalarValue::Float64(v) => compute_primitive_op_dyn_scalar!($LEFT, 
v, $OP, op_type, Float64Type),
+            other => Err(DataFusionError::Internal(format!(
+                "Data type {:?} not supported for scalar operation '{}' on dyn 
array",
+                other, stringify!($OP)))
+            )
+        };
+
+        Some(result)
+    }}
+}
+
 /// Invoke a compute kernel on an array and a scalar
 /// The binary_primitive_array_op_scalar macro only evaluates for primitive
 /// types like integers and floats.
@@ -924,18 +997,19 @@ impl BinaryExpr {
                 binary_array_op_dyn_scalar!(array, scalar.clone(), neq, 
bool_type)
             }
             Operator::Plus => {
-                binary_primitive_array_op_scalar!(array, scalar.clone(), add)
+                binary_primitive_array_op_dyn_scalar!(array, scalar.clone(), 
add)
             }
             Operator::Minus => {
-                binary_primitive_array_op_scalar!(array, scalar.clone(), 
subtract)
+                binary_primitive_array_op_dyn_scalar!(array, scalar.clone(), 
subtract)
             }
             Operator::Multiply => {
-                binary_primitive_array_op_scalar!(array, scalar.clone(), 
multiply)
+                binary_primitive_array_op_dyn_scalar!(array, scalar.clone(), 
multiply)
             }
             Operator::Divide => {
-                binary_primitive_array_op_scalar!(array, scalar.clone(), 
divide)
+                binary_primitive_array_op_dyn_scalar!(array, scalar.clone(), 
divide)
             }
             Operator::Modulo => {
+                // todo: change to binary_primitive_array_op_dyn_scalar! once 
modulo is implemented
                 binary_primitive_array_op_scalar!(array, scalar.clone(), 
modulus)
             }
             Operator::RegexMatch => binary_string_array_flag_op_scalar!(
@@ -1115,8 +1189,8 @@ pub fn binary(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::expressions::try_cast;
     use crate::expressions::{col, lit};
+    use crate::expressions::{try_cast, Literal};
     use arrow::datatypes::{
         ArrowNumericType, Decimal128Type, Field, Int32Type, SchemaRef,
     };
@@ -1565,6 +1639,61 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn plus_op_scalar() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, 
false)]);
+        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+
+        apply_arithmetic_scalar(
+            Arc::new(schema),
+            vec![Arc::new(a)],
+            Operator::Plus,
+            ScalarValue::Int32(Some(1)),
+            Arc::new(Int32Array::from(vec![2, 3, 4, 5, 6])),
+        )?;
+
+        Ok(())
+    }
+
+    #[test]
+    fn plus_op_dict_scalar() -> Result<()> {
+        let schema = Schema::new(vec![Field::new(
+            "a",
+            DataType::Dictionary(Box::new(DataType::Int8), 
Box::new(DataType::Int32)),
+            true,
+        )]);
+
+        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, 
Int32Type>::new();
+
+        dict_builder.append(1)?;
+        dict_builder.append_null();
+        dict_builder.append(2)?;
+        dict_builder.append(5)?;
+
+        let a = dict_builder.finish();
+
+        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, 
Int32Type>::new();
+
+        dict_builder.append(2)?;
+        dict_builder.append_null();
+        dict_builder.append(3)?;
+        dict_builder.append(6)?;
+        let expected = dict_builder.finish();
+
+        apply_arithmetic_scalar(
+            Arc::new(schema),
+            vec![Arc::new(a)],
+            Operator::Plus,
+            ScalarValue::Dictionary(
+                Box::new(DataType::Int8),
+                Box::new(ScalarValue::Int32(Some(1))),
+            ),
+            Arc::new(expected),
+        )?;
+
+        Ok(())
+    }
+
     #[test]
     fn minus_op() -> Result<()> {
         let schema = Arc::new(Schema::new(vec![
@@ -1592,6 +1721,61 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn minus_op_scalar() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, 
false)]);
+        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+
+        apply_arithmetic_scalar(
+            Arc::new(schema),
+            vec![Arc::new(a)],
+            Operator::Minus,
+            ScalarValue::Int32(Some(1)),
+            Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4])),
+        )?;
+
+        Ok(())
+    }
+
+    #[test]
+    fn minus_op_dict_scalar() -> Result<()> {
+        let schema = Schema::new(vec![Field::new(
+            "a",
+            DataType::Dictionary(Box::new(DataType::Int8), 
Box::new(DataType::Int32)),
+            true,
+        )]);
+
+        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, 
Int32Type>::new();
+
+        dict_builder.append(1)?;
+        dict_builder.append_null();
+        dict_builder.append(2)?;
+        dict_builder.append(5)?;
+
+        let a = dict_builder.finish();
+
+        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, 
Int32Type>::new();
+
+        dict_builder.append(0)?;
+        dict_builder.append_null();
+        dict_builder.append(1)?;
+        dict_builder.append(4)?;
+        let expected = dict_builder.finish();
+
+        apply_arithmetic_scalar(
+            Arc::new(schema),
+            vec![Arc::new(a)],
+            Operator::Minus,
+            ScalarValue::Dictionary(
+                Box::new(DataType::Int8),
+                Box::new(ScalarValue::Int32(Some(1))),
+            ),
+            Arc::new(expected),
+        )?;
+
+        Ok(())
+    }
+
     #[test]
     fn multiply_op() -> Result<()> {
         let schema = Arc::new(Schema::new(vec![
@@ -1611,6 +1795,61 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn multiply_op_scalar() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, 
false)]);
+        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+
+        apply_arithmetic_scalar(
+            Arc::new(schema),
+            vec![Arc::new(a)],
+            Operator::Multiply,
+            ScalarValue::Int32(Some(2)),
+            Arc::new(Int32Array::from(vec![2, 4, 6, 8, 10])),
+        )?;
+
+        Ok(())
+    }
+
+    #[test]
+    fn multiply_op_dict_scalar() -> Result<()> {
+        let schema = Schema::new(vec![Field::new(
+            "a",
+            DataType::Dictionary(Box::new(DataType::Int8), 
Box::new(DataType::Int32)),
+            true,
+        )]);
+
+        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, 
Int32Type>::new();
+
+        dict_builder.append(1)?;
+        dict_builder.append_null();
+        dict_builder.append(2)?;
+        dict_builder.append(5)?;
+
+        let a = dict_builder.finish();
+
+        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, 
Int32Type>::new();
+
+        dict_builder.append(2)?;
+        dict_builder.append_null();
+        dict_builder.append(4)?;
+        dict_builder.append(10)?;
+        let expected = dict_builder.finish();
+
+        apply_arithmetic_scalar(
+            Arc::new(schema),
+            vec![Arc::new(a)],
+            Operator::Multiply,
+            ScalarValue::Dictionary(
+                Box::new(DataType::Int8),
+                Box::new(ScalarValue::Int32(Some(2))),
+            ),
+            Arc::new(expected),
+        )?;
+
+        Ok(())
+    }
+
     #[test]
     fn divide_op() -> Result<()> {
         let schema = Arc::new(Schema::new(vec![
@@ -1630,6 +1869,61 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn divide_op_scalar() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, 
false)]);
+        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+
+        apply_arithmetic_scalar(
+            Arc::new(schema),
+            vec![Arc::new(a)],
+            Operator::Divide,
+            ScalarValue::Int32(Some(2)),
+            Arc::new(Int32Array::from(vec![0, 1, 1, 2, 2])),
+        )?;
+
+        Ok(())
+    }
+
+    #[test]
+    fn divide_op_dict_scalar() -> Result<()> {
+        let schema = Schema::new(vec![Field::new(
+            "a",
+            DataType::Dictionary(Box::new(DataType::Int8), 
Box::new(DataType::Int32)),
+            true,
+        )]);
+
+        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, 
Int32Type>::new();
+
+        dict_builder.append(1)?;
+        dict_builder.append_null();
+        dict_builder.append(2)?;
+        dict_builder.append(5)?;
+
+        let a = dict_builder.finish();
+
+        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, 
Int32Type>::new();
+
+        dict_builder.append(0)?;
+        dict_builder.append_null();
+        dict_builder.append(1)?;
+        dict_builder.append(2)?;
+        let expected = dict_builder.finish();
+
+        apply_arithmetic_scalar(
+            Arc::new(schema),
+            vec![Arc::new(a)],
+            Operator::Divide,
+            ScalarValue::Dictionary(
+                Box::new(DataType::Int8),
+                Box::new(ScalarValue::Int32(Some(2))),
+            ),
+            Arc::new(expected),
+        )?;
+
+        Ok(())
+    }
+
     #[test]
     fn modulus_op() -> Result<()> {
         let schema = Arc::new(Schema::new(vec![
@@ -1664,6 +1958,22 @@ mod tests {
         Ok(())
     }
 
+    fn apply_arithmetic_scalar(
+        schema: SchemaRef,
+        data: Vec<ArrayRef>,
+        op: Operator,
+        literal: ScalarValue,
+        expected: ArrayRef,
+    ) -> Result<()> {
+        let lit = Arc::new(Literal::new(literal));
+        let arithmetic_op = binary_simple(col("a", &schema)?, op, lit, 
&schema);
+        let batch = RecordBatch::try_new(schema, data)?;
+        let result = 
arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
+
+        assert_eq!(&result, &expected);
+        Ok(())
+    }
+
     fn apply_logic_op(
         schema: &SchemaRef,
         left: &ArrayRef,
diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs 
b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
index 2135982b6..40e0d2b0e 100644
--- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
+++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
@@ -19,11 +19,16 @@
 //! destined for arrow-rs but are in datafusion until they are ported.
 
 use arrow::compute::{
-    add, add_scalar, divide_opt, divide_scalar, modulus, modulus_scalar, 
multiply,
-    multiply_scalar, subtract, subtract_scalar,
+    add, add_scalar_dyn, divide_opt, divide_scalar, divide_scalar_dyn, modulus,
+    modulus_scalar, multiply, multiply_scalar, multiply_scalar_dyn, subtract,
+    subtract_scalar_dyn,
 };
-use arrow::{array::*, datatypes::ArrowNumericType};
-use datafusion_common::Result;
+use arrow::datatypes::Decimal128Type;
+use arrow::{array::*, datatypes::ArrowNumericType, downcast_dictionary_array};
+use arrow_schema::DataType;
+use datafusion_common::cast::as_decimal128_array;
+use datafusion_common::{DataFusionError, Result};
+use std::sync::Arc;
 
 // Simple (low performance) kernels until optimized kernels are added to arrow
 // See https://github.com/apache/arrow-rs/issues/960
@@ -183,50 +188,123 @@ pub(crate) fn add_decimal(
     Ok(array)
 }
 
-pub(crate) fn add_decimal_scalar(
-    left: &Decimal128Array,
+pub(crate) fn add_decimal_dyn_scalar(left: &dyn Array, right: i128) -> 
Result<ArrayRef> {
+    let left_decimal = 
left.as_any().downcast_ref::<Decimal128Array>().unwrap();
+
+    let array = add_scalar_dyn::<Decimal128Type>(left, right)?;
+    let decimal_array = as_decimal128_array(&array)?;
+    let decimal_array = decimal_array
+        .clone()
+        .with_precision_and_scale(left_decimal.precision(), 
left_decimal.scale())?;
+    Ok(Arc::new(decimal_array))
+}
+
+pub(crate) fn subtract_decimal_dyn_scalar(
+    left: &dyn Array,
     right: i128,
-) -> Result<Decimal128Array> {
-    let array = add_scalar(left, right)?
-        .with_precision_and_scale(left.precision(), left.scale())?;
-    Ok(array)
+) -> Result<ArrayRef> {
+    let left_decimal = 
left.as_any().downcast_ref::<Decimal128Array>().unwrap();
+
+    let array = subtract_scalar_dyn::<Decimal128Type>(left, right)?;
+    let decimal_array = as_decimal128_array(&array)?;
+    let decimal_array = decimal_array
+        .clone()
+        .with_precision_and_scale(left_decimal.precision(), 
left_decimal.scale())?;
+    Ok(Arc::new(decimal_array))
 }
 
-pub(crate) fn subtract_decimal(
-    left: &Decimal128Array,
-    right: &Decimal128Array,
-) -> Result<Decimal128Array> {
-    let array = subtract(left, right)?
-        .with_precision_and_scale(left.precision(), left.scale())?;
-    Ok(array)
+fn get_precision_scale(left: &dyn Array) -> Result<(u8, i8)> {
+    match left.data_type() {
+        DataType::Decimal128(precision, scale) => Ok((*precision, *scale)),
+        DataType::Dictionary(_, value_type) => match value_type.as_ref() {
+            DataType::Decimal128(precision, scale) => Ok((*precision, *scale)),
+            _ => Err(DataFusionError::Internal(
+                "Unexpected data type".to_string(),
+            )),
+        },
+        _ => Err(DataFusionError::Internal(
+            "Unexpected data type".to_string(),
+        )),
+    }
 }
 
-pub(crate) fn subtract_decimal_scalar(
-    left: &Decimal128Array,
+fn decimal_array_with_precision_scale(
+    array: ArrayRef,
+    precision: u8,
+    scale: i8,
+) -> Result<ArrayRef> {
+    let array = array.as_ref();
+    let decimal_array = match array.data_type() {
+        DataType::Decimal128(_, _) => {
+            let array = as_decimal128_array(array)?;
+            Arc::new(array.clone().with_precision_and_scale(precision, scale)?)
+                as ArrayRef
+        }
+        DataType::Dictionary(_, _) => {
+            downcast_dictionary_array!(
+                array => match array.values().data_type() {
+                    DataType::Decimal128(_, _) => {
+                        let decimal_dict_array = 
array.downcast_dict::<Decimal128Array>().unwrap();
+                        let decimal_array = 
decimal_dict_array.values().clone();
+                        let decimal_array = 
decimal_array.with_precision_and_scale(precision, scale)?;
+                        Arc::new(array.with_values(&decimal_array)) as ArrayRef
+                    }
+                    t => return 
Err(DataFusionError::Internal(format!("Unexpected dictionary value type {t}"))),
+                },
+                t => return Err(DataFusionError::Internal(format!("Unexpected 
datatype {t}"))),
+            )
+        }
+        _ => {
+            return Err(DataFusionError::Internal(
+                "Unexpected data type".to_string(),
+            ))
+        }
+    };
+    Ok(decimal_array)
+}
+
+pub(crate) fn multiply_decimal_dyn_scalar(
+    left: &dyn Array,
     right: i128,
-) -> Result<Decimal128Array> {
-    let array = subtract_scalar(left, right)?
-        .with_precision_and_scale(left.precision(), left.scale())?;
-    Ok(array)
+) -> Result<ArrayRef> {
+    let (precision, scale) = get_precision_scale(left)?;
+
+    let array = multiply_scalar_dyn::<Decimal128Type>(left, right)?;
+
+    let divide = 10_i128.pow(scale as u32);
+    let array = divide_scalar_dyn::<Decimal128Type>(&array, divide)?;
+
+    decimal_array_with_precision_scale(array, precision, scale)
 }
 
-pub(crate) fn multiply_decimal(
+pub(crate) fn divide_decimal_dyn_scalar(
+    left: &dyn Array,
+    right: i128,
+) -> Result<ArrayRef> {
+    let (precision, scale) = get_precision_scale(left)?;
+
+    let mul = 10_i128.pow(scale as u32);
+    let array = multiply_scalar_dyn::<Decimal128Type>(left, mul)?;
+
+    let array = divide_scalar_dyn::<Decimal128Type>(&array, right)?;
+    decimal_array_with_precision_scale(array, precision, scale)
+}
+
+pub(crate) fn subtract_decimal(
     left: &Decimal128Array,
     right: &Decimal128Array,
 ) -> Result<Decimal128Array> {
-    let divide = 10_i128.pow(left.scale() as u32);
-    let array = multiply(left, right)?;
-    let array = divide_scalar(&array, divide)?
+    let array = subtract(left, right)?
         .with_precision_and_scale(left.precision(), left.scale())?;
     Ok(array)
 }
 
-pub(crate) fn multiply_decimal_scalar(
+pub(crate) fn multiply_decimal(
     left: &Decimal128Array,
-    right: i128,
+    right: &Decimal128Array,
 ) -> Result<Decimal128Array> {
-    let array = multiply_scalar(left, right)?;
     let divide = 10_i128.pow(left.scale() as u32);
+    let array = multiply(left, right)?;
     let array = divide_scalar(&array, divide)?
         .with_precision_and_scale(left.precision(), left.scale())?;
     Ok(array)
@@ -243,18 +321,6 @@ pub(crate) fn divide_opt_decimal(
     Ok(array)
 }
 
-pub(crate) fn divide_decimal_scalar(
-    left: &Decimal128Array,
-    right: i128,
-) -> Result<Decimal128Array> {
-    let mul = 10_i128.pow(left.scale() as u32);
-    let array = multiply_scalar(left, mul)?;
-    // `0` of right will be checked in `divide_scalar`
-    let array = divide_scalar(&array, right)?
-        .with_precision_and_scale(left.precision(), left.scale())?;
-    Ok(array)
-}
-
 pub(crate) fn modulus_decimal(
     left: &Decimal128Array,
     right: &Decimal128Array,
@@ -371,25 +437,28 @@ mod tests {
         let expect =
             create_decimal_array(&[Some(246), None, Some(245), Some(247)], 25, 
3);
         assert_eq!(expect, result);
-        let result = add_decimal_scalar(&left_decimal_array, 10)?;
+        let result = add_decimal_dyn_scalar(&left_decimal_array, 10)?;
+        let result = as_decimal128_array(&result)?;
         let expect =
             create_decimal_array(&[Some(133), None, Some(132), Some(134)], 25, 
3);
-        assert_eq!(expect, result);
+        assert_eq!(&expect, result);
         // subtract
         let result = subtract_decimal(&left_decimal_array, 
&right_decimal_array)?;
         let expect = create_decimal_array(&[Some(0), None, Some(-1), Some(1)], 
25, 3);
         assert_eq!(expect, result);
-        let result = subtract_decimal_scalar(&left_decimal_array, 10)?;
+        let result = subtract_decimal_dyn_scalar(&left_decimal_array, 10)?;
+        let result = as_decimal128_array(&result)?;
         let expect =
             create_decimal_array(&[Some(113), None, Some(112), Some(114)], 25, 
3);
-        assert_eq!(expect, result);
+        assert_eq!(&expect, result);
         // multiply
         let result = multiply_decimal(&left_decimal_array, 
&right_decimal_array)?;
         let expect = create_decimal_array(&[Some(15), None, Some(15), 
Some(15)], 25, 3);
         assert_eq!(expect, result);
-        let result = multiply_decimal_scalar(&left_decimal_array, 10)?;
+        let result = multiply_decimal_dyn_scalar(&left_decimal_array, 10)?;
+        let result = as_decimal128_array(&result)?;
         let expect = create_decimal_array(&[Some(1), None, Some(1), Some(1)], 
25, 3);
-        assert_eq!(expect, result);
+        assert_eq!(&expect, result);
         // divide
         let left_decimal_array = create_decimal_array(
             &[
@@ -414,7 +483,8 @@ mod tests {
             3,
         );
         assert_eq!(expect, result);
-        let result = divide_decimal_scalar(&left_decimal_array, 10)?;
+        let result = divide_decimal_dyn_scalar(&left_decimal_array, 10)?;
+        let result = as_decimal128_array(&result)?;
         let expect = create_decimal_array(
             &[
                 Some(123456700),
@@ -426,7 +496,7 @@ mod tests {
             25,
             3,
         );
-        assert_eq!(expect, result);
+        assert_eq!(&expect, result);
         let result = modulus_decimal(&left_decimal_array, 
&right_decimal_array)?;
         let expect =
             create_decimal_array(&[Some(7), None, Some(37), Some(16), None], 
25, 3);
@@ -444,7 +514,7 @@ mod tests {
         let left_decimal_array = create_decimal_array(&[Some(101)], 10, 1);
         let right_decimal_array = create_decimal_array(&[Some(0)], 1, 1);
 
-        let err = divide_decimal_scalar(&left_decimal_array, 0).unwrap_err();
+        let err = divide_decimal_dyn_scalar(&left_decimal_array, 
0).unwrap_err();
         assert_eq!("Arrow error: Divide by zero error", err.to_string());
         let err = modulus_decimal(&left_decimal_array, 
&right_decimal_array).unwrap_err();
         assert_eq!("Arrow error: Divide by zero error", err.to_string());

Reply via email to