This is an automated email from the ASF dual-hosted git repository.
liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new c75986506 Support arithmetic scalar operation with DictionaryArray
(#5151)
c75986506 is described below
commit c75986506ff27d60e8f79e30baa838cf5c717b86
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Feb 3 00:00:58 2023 -0800
Support arithmetic scalar operation with DictionaryArray (#5151)
* Support arithmetic dyn scalar
* For review: removing unnecessary macro parameter, adding one more type
coercion pattern
* For review: modify comment
---
datafusion/expr/src/type_coercion/binary.rs | 19 ++
datafusion/physical-expr/src/expressions/binary.rs | 328 ++++++++++++++++++++-
.../src/expressions/binary/kernels_arrow.rs | 172 +++++++----
3 files changed, 459 insertions(+), 60 deletions(-)
diff --git a/datafusion/expr/src/type_coercion/binary.rs
b/datafusion/expr/src/type_coercion/binary.rs
index 000011f9d..5e010f220 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -341,6 +341,15 @@ fn mathematics_numerical_coercion(
(Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _),
Null) => {
Some(dec_type.clone())
}
+ (Dictionary(key_type, value_type), _) => {
+ let value_type =
+ mathematics_numerical_coercion(mathematics_op, value_type,
rhs_type);
+ value_type
+ .map(|value_type| Dictionary(key_type.clone(),
Box::new(value_type)))
+ }
+ (_, Dictionary(_, value_type)) => {
+ mathematics_numerical_coercion(mathematics_op, lhs_type,
value_type)
+ }
(Decimal128(_, _), Float32 | Float64) => Some(Float64),
(Float32 | Float64, Decimal128(_, _)) => Some(Float64),
(Decimal128(_, _), _) => {
@@ -439,6 +448,16 @@ fn both_numeric_or_null_and_numeric(lhs_type: &DataType,
rhs_type: &DataType) ->
match (lhs_type, rhs_type) {
(_, DataType::Null) => is_numeric(lhs_type),
(DataType::Null, _) => is_numeric(rhs_type),
+ (
+ DataType::Dictionary(_, lhs_value_type),
+ DataType::Dictionary(_, rhs_value_type),
+ ) => is_numeric(lhs_value_type) && is_numeric(rhs_value_type),
+ (DataType::Dictionary(_, value_type), _) => {
+ is_numeric(value_type) && is_numeric(rhs_type)
+ }
+ (_, DataType::Dictionary(_, value_type)) => {
+ is_numeric(lhs_type) && is_numeric(value_type)
+ }
_ => is_numeric(lhs_type) && is_numeric(rhs_type),
}
}
diff --git a/datafusion/physical-expr/src/expressions/binary.rs
b/datafusion/physical-expr/src/expressions/binary.rs
index d2346d278..1c2bb065b 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -24,8 +24,10 @@ use std::{any::Any, sync::Arc};
use arrow::array::*;
use arrow::compute::kernels::arithmetic::{
- add, add_scalar, divide_opt, divide_scalar, modulus, modulus_scalar,
multiply,
- multiply_scalar, subtract, subtract_scalar,
+ add, add_scalar_dyn as add_dyn_scalar, divide_opt,
+ divide_scalar_dyn as divide_dyn_scalar, modulus, modulus_scalar, multiply,
+ multiply_scalar_dyn as multiply_dyn_scalar, subtract,
+ subtract_scalar_dyn as subtract_dyn_scalar,
};
use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
use arrow::compute::kernels::comparison::regexp_is_match_utf8;
@@ -49,6 +51,7 @@ use arrow::compute::kernels::comparison::{
use arrow::compute::kernels::comparison::{
eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar,
};
+use arrow::datatypes::*;
use adapter::{eq_dyn, gt_dyn, gt_eq_dyn, lt_dyn, lt_eq_dyn, neq_dyn};
use arrow::compute::kernels::concat_elements::concat_elements_utf8;
@@ -58,12 +61,12 @@ use kernels::{
bitwise_xor, bitwise_xor_scalar,
};
use kernels_arrow::{
- add_decimal, add_decimal_scalar, divide_decimal_scalar, divide_opt_decimal,
+ add_decimal, add_decimal_dyn_scalar, divide_decimal_dyn_scalar,
divide_opt_decimal,
is_distinct_from, is_distinct_from_bool, is_distinct_from_decimal,
is_distinct_from_null, is_distinct_from_utf8, is_not_distinct_from,
is_not_distinct_from_bool, is_not_distinct_from_decimal,
is_not_distinct_from_null,
is_not_distinct_from_utf8, modulus_decimal, modulus_decimal_scalar,
multiply_decimal,
- multiply_decimal_scalar, subtract_decimal, subtract_decimal_scalar,
+ multiply_decimal_dyn_scalar, subtract_decimal, subtract_decimal_dyn_scalar,
};
use arrow::datatypes::{DataType, Schema, TimeUnit};
@@ -315,6 +318,45 @@ macro_rules! compute_op_dyn_scalar {
}};
}
+/// Invoke a dyn compute kernel on a data array and a scalar value
+/// LEFT is Primitive or Dictionary array of numeric values, RIGHT is scalar
value
+/// OP_TYPE is the return type of scalar function
+/// SCALAR_TYPE is the type of the scalar value
+/// Different to `compute_op_dyn_scalar`, this calls the `_dyn_scalar`
functions that
+/// take a `SCALAR_TYPE`.
+macro_rules! compute_primitive_op_dyn_scalar {
+ ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr, $SCALAR_TYPE:ident) =>
{{
+ // generate the scalar function name, such as add_dyn_scalar, from the
$OP parameter
+ // (which could have a value of add) and the suffix _dyn_scalar
+ if let Some(value) = $RIGHT {
+ Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar>]::<$SCALAR_TYPE>}(
+ $LEFT,
+ value,
+ )?))
+ } else {
+ // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
+ Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
+ }
+ }};
+}
+
+/// Invoke a dyn decimal compute kernel on a data array and a scalar value
+/// LEFT is Decimal or Dictionary array of decimal values, RIGHT is scalar
value
+/// OP_TYPE is the return type of scalar function
+macro_rules! compute_primitive_decimal_op_dyn_scalar {
+ ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
+ // generate the scalar function name, such as add_decimal_dyn_scalar,
+ // from the $OP parameter (which could have a value of add) and the
+ // suffix _decimal_dyn_scalar
+ if let Some(value) = $RIGHT {
+ Ok(paste::expr! {[<$OP _decimal_dyn_scalar>]}($LEFT, value)?)
+ } else {
+ // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
+ Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
+ }
+ }};
+}
+
/// Invoke a compute kernel on array(s)
macro_rules! compute_op {
// invoke binary operator
@@ -376,6 +418,37 @@ macro_rules! binary_primitive_array_op {
}};
}
+/// Invoke a compute dyn kernel on an array and a scalar
+/// The binary_primitive_array_op_dyn_scalar macro evaluates for integer, float
+/// and Decimal128 scalars; a dictionary scalar is unwrapped to its value first.
+macro_rules! binary_primitive_array_op_dyn_scalar {
+ ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
+ // unwrap underlying (non dictionary) value
+ let right = unwrap_dict_value($RIGHT);
+ let op_type = $LEFT.data_type();
+
+ let result: Result<Arc<dyn Array>> = match right {
+ ScalarValue::Decimal128(v, _, _) =>
compute_primitive_decimal_op_dyn_scalar!($LEFT, v, $OP, op_type),
+ ScalarValue::Int8(v) => compute_primitive_op_dyn_scalar!($LEFT, v,
$OP, op_type, Int8Type),
+ ScalarValue::Int16(v) => compute_primitive_op_dyn_scalar!($LEFT,
v, $OP, op_type, Int16Type),
+ ScalarValue::Int32(v) => compute_primitive_op_dyn_scalar!($LEFT,
v, $OP, op_type, Int32Type),
+ ScalarValue::Int64(v) => compute_primitive_op_dyn_scalar!($LEFT,
v, $OP, op_type, Int64Type),
+ ScalarValue::UInt8(v) => compute_primitive_op_dyn_scalar!($LEFT,
v, $OP, op_type, UInt8Type),
+ ScalarValue::UInt16(v) => compute_primitive_op_dyn_scalar!($LEFT,
v, $OP, op_type, UInt16Type),
+ ScalarValue::UInt32(v) => compute_primitive_op_dyn_scalar!($LEFT,
v, $OP, op_type, UInt32Type),
+ ScalarValue::UInt64(v) => compute_primitive_op_dyn_scalar!($LEFT,
v, $OP, op_type, UInt64Type),
+ ScalarValue::Float32(v) => compute_primitive_op_dyn_scalar!($LEFT,
v, $OP, op_type, Float32Type),
+ ScalarValue::Float64(v) => compute_primitive_op_dyn_scalar!($LEFT,
v, $OP, op_type, Float64Type),
+ other => Err(DataFusionError::Internal(format!(
+ "Data type {:?} not supported for scalar operation '{}' on dyn
array",
+ other, stringify!($OP)))
+ )
+ };
+
+ Some(result)
+ }}
+}
+
/// Invoke a compute kernel on an array and a scalar
/// The binary_primitive_array_op_scalar macro only evaluates for primitive
/// types like integers and floats.
@@ -924,18 +997,19 @@ impl BinaryExpr {
binary_array_op_dyn_scalar!(array, scalar.clone(), neq,
bool_type)
}
Operator::Plus => {
- binary_primitive_array_op_scalar!(array, scalar.clone(), add)
+ binary_primitive_array_op_dyn_scalar!(array, scalar.clone(),
add)
}
Operator::Minus => {
- binary_primitive_array_op_scalar!(array, scalar.clone(),
subtract)
+ binary_primitive_array_op_dyn_scalar!(array, scalar.clone(),
subtract)
}
Operator::Multiply => {
- binary_primitive_array_op_scalar!(array, scalar.clone(),
multiply)
+ binary_primitive_array_op_dyn_scalar!(array, scalar.clone(),
multiply)
}
Operator::Divide => {
- binary_primitive_array_op_scalar!(array, scalar.clone(),
divide)
+ binary_primitive_array_op_dyn_scalar!(array, scalar.clone(),
divide)
}
Operator::Modulo => {
+ // todo: change to binary_primitive_array_op_dyn_scalar! once
modulo is implemented
binary_primitive_array_op_scalar!(array, scalar.clone(),
modulus)
}
Operator::RegexMatch => binary_string_array_flag_op_scalar!(
@@ -1115,8 +1189,8 @@ pub fn binary(
#[cfg(test)]
mod tests {
use super::*;
- use crate::expressions::try_cast;
use crate::expressions::{col, lit};
+ use crate::expressions::{try_cast, Literal};
use arrow::datatypes::{
ArrowNumericType, Decimal128Type, Field, Int32Type, SchemaRef,
};
@@ -1565,6 +1639,61 @@ mod tests {
Ok(())
}
+ #[test]
+ fn plus_op_scalar() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Int32,
false)]);
+ let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+
+ apply_arithmetic_scalar(
+ Arc::new(schema),
+ vec![Arc::new(a)],
+ Operator::Plus,
+ ScalarValue::Int32(Some(1)),
+ Arc::new(Int32Array::from(vec![2, 3, 4, 5, 6])),
+ )?;
+
+ Ok(())
+ }
+
+ #[test]
+ fn plus_op_dict_scalar() -> Result<()> {
+ let schema = Schema::new(vec![Field::new(
+ "a",
+ DataType::Dictionary(Box::new(DataType::Int8),
Box::new(DataType::Int32)),
+ true,
+ )]);
+
+ let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::new();
+
+ dict_builder.append(1)?;
+ dict_builder.append_null();
+ dict_builder.append(2)?;
+ dict_builder.append(5)?;
+
+ let a = dict_builder.finish();
+
+ let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::new();
+
+ dict_builder.append(2)?;
+ dict_builder.append_null();
+ dict_builder.append(3)?;
+ dict_builder.append(6)?;
+ let expected = dict_builder.finish();
+
+ apply_arithmetic_scalar(
+ Arc::new(schema),
+ vec![Arc::new(a)],
+ Operator::Plus,
+ ScalarValue::Dictionary(
+ Box::new(DataType::Int8),
+ Box::new(ScalarValue::Int32(Some(1))),
+ ),
+ Arc::new(expected),
+ )?;
+
+ Ok(())
+ }
+
#[test]
fn minus_op() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
@@ -1592,6 +1721,61 @@ mod tests {
Ok(())
}
+ #[test]
+ fn minus_op_scalar() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Int32,
false)]);
+ let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+
+ apply_arithmetic_scalar(
+ Arc::new(schema),
+ vec![Arc::new(a)],
+ Operator::Minus,
+ ScalarValue::Int32(Some(1)),
+ Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4])),
+ )?;
+
+ Ok(())
+ }
+
+ #[test]
+ fn minus_op_dict_scalar() -> Result<()> {
+ let schema = Schema::new(vec![Field::new(
+ "a",
+ DataType::Dictionary(Box::new(DataType::Int8),
Box::new(DataType::Int32)),
+ true,
+ )]);
+
+ let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::new();
+
+ dict_builder.append(1)?;
+ dict_builder.append_null();
+ dict_builder.append(2)?;
+ dict_builder.append(5)?;
+
+ let a = dict_builder.finish();
+
+ let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::new();
+
+ dict_builder.append(0)?;
+ dict_builder.append_null();
+ dict_builder.append(1)?;
+ dict_builder.append(4)?;
+ let expected = dict_builder.finish();
+
+ apply_arithmetic_scalar(
+ Arc::new(schema),
+ vec![Arc::new(a)],
+ Operator::Minus,
+ ScalarValue::Dictionary(
+ Box::new(DataType::Int8),
+ Box::new(ScalarValue::Int32(Some(1))),
+ ),
+ Arc::new(expected),
+ )?;
+
+ Ok(())
+ }
+
#[test]
fn multiply_op() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
@@ -1611,6 +1795,61 @@ mod tests {
Ok(())
}
+ #[test]
+ fn multiply_op_scalar() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Int32,
false)]);
+ let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+
+ apply_arithmetic_scalar(
+ Arc::new(schema),
+ vec![Arc::new(a)],
+ Operator::Multiply,
+ ScalarValue::Int32(Some(2)),
+ Arc::new(Int32Array::from(vec![2, 4, 6, 8, 10])),
+ )?;
+
+ Ok(())
+ }
+
+ #[test]
+ fn multiply_op_dict_scalar() -> Result<()> {
+ let schema = Schema::new(vec![Field::new(
+ "a",
+ DataType::Dictionary(Box::new(DataType::Int8),
Box::new(DataType::Int32)),
+ true,
+ )]);
+
+ let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::new();
+
+ dict_builder.append(1)?;
+ dict_builder.append_null();
+ dict_builder.append(2)?;
+ dict_builder.append(5)?;
+
+ let a = dict_builder.finish();
+
+ let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::new();
+
+ dict_builder.append(2)?;
+ dict_builder.append_null();
+ dict_builder.append(4)?;
+ dict_builder.append(10)?;
+ let expected = dict_builder.finish();
+
+ apply_arithmetic_scalar(
+ Arc::new(schema),
+ vec![Arc::new(a)],
+ Operator::Multiply,
+ ScalarValue::Dictionary(
+ Box::new(DataType::Int8),
+ Box::new(ScalarValue::Int32(Some(2))),
+ ),
+ Arc::new(expected),
+ )?;
+
+ Ok(())
+ }
+
#[test]
fn divide_op() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
@@ -1630,6 +1869,61 @@ mod tests {
Ok(())
}
+ #[test]
+ fn divide_op_scalar() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Int32,
false)]);
+ let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+
+ apply_arithmetic_scalar(
+ Arc::new(schema),
+ vec![Arc::new(a)],
+ Operator::Divide,
+ ScalarValue::Int32(Some(2)),
+ Arc::new(Int32Array::from(vec![0, 1, 1, 2, 2])),
+ )?;
+
+ Ok(())
+ }
+
+ #[test]
+ fn divide_op_dict_scalar() -> Result<()> {
+ let schema = Schema::new(vec![Field::new(
+ "a",
+ DataType::Dictionary(Box::new(DataType::Int8),
Box::new(DataType::Int32)),
+ true,
+ )]);
+
+ let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::new();
+
+ dict_builder.append(1)?;
+ dict_builder.append_null();
+ dict_builder.append(2)?;
+ dict_builder.append(5)?;
+
+ let a = dict_builder.finish();
+
+ let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::new();
+
+ dict_builder.append(0)?;
+ dict_builder.append_null();
+ dict_builder.append(1)?;
+ dict_builder.append(2)?;
+ let expected = dict_builder.finish();
+
+ apply_arithmetic_scalar(
+ Arc::new(schema),
+ vec![Arc::new(a)],
+ Operator::Divide,
+ ScalarValue::Dictionary(
+ Box::new(DataType::Int8),
+ Box::new(ScalarValue::Int32(Some(2))),
+ ),
+ Arc::new(expected),
+ )?;
+
+ Ok(())
+ }
+
#[test]
fn modulus_op() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
@@ -1664,6 +1958,22 @@ mod tests {
Ok(())
}
+ fn apply_arithmetic_scalar(
+ schema: SchemaRef,
+ data: Vec<ArrayRef>,
+ op: Operator,
+ literal: ScalarValue,
+ expected: ArrayRef,
+ ) -> Result<()> {
+ let lit = Arc::new(Literal::new(literal));
+ let arithmetic_op = binary_simple(col("a", &schema)?, op, lit,
&schema);
+ let batch = RecordBatch::try_new(schema, data)?;
+ let result =
arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
+
+ assert_eq!(&result, &expected);
+ Ok(())
+ }
+
fn apply_logic_op(
schema: &SchemaRef,
left: &ArrayRef,
diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
index 2135982b6..40e0d2b0e 100644
--- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
+++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
@@ -19,11 +19,16 @@
//! destined for arrow-rs but are in datafusion until they are ported.
use arrow::compute::{
- add, add_scalar, divide_opt, divide_scalar, modulus, modulus_scalar,
multiply,
- multiply_scalar, subtract, subtract_scalar,
+ add, add_scalar_dyn, divide_opt, divide_scalar, divide_scalar_dyn, modulus,
+ modulus_scalar, multiply, multiply_scalar, multiply_scalar_dyn, subtract,
+ subtract_scalar_dyn,
};
-use arrow::{array::*, datatypes::ArrowNumericType};
-use datafusion_common::Result;
+use arrow::datatypes::Decimal128Type;
+use arrow::{array::*, datatypes::ArrowNumericType, downcast_dictionary_array};
+use arrow_schema::DataType;
+use datafusion_common::cast::as_decimal128_array;
+use datafusion_common::{DataFusionError, Result};
+use std::sync::Arc;
// Simple (low performance) kernels until optimized kernels are added to arrow
// See https://github.com/apache/arrow-rs/issues/960
@@ -183,50 +188,123 @@ pub(crate) fn add_decimal(
Ok(array)
}
-pub(crate) fn add_decimal_scalar(
- left: &Decimal128Array,
+pub(crate) fn add_decimal_dyn_scalar(left: &dyn Array, right: i128) ->
Result<ArrayRef> {
+ let left_decimal =
left.as_any().downcast_ref::<Decimal128Array>().unwrap();
+
+ let array = add_scalar_dyn::<Decimal128Type>(left, right)?;
+ let decimal_array = as_decimal128_array(&array)?;
+ let decimal_array = decimal_array
+ .clone()
+ .with_precision_and_scale(left_decimal.precision(),
left_decimal.scale())?;
+ Ok(Arc::new(decimal_array))
+}
+
+pub(crate) fn subtract_decimal_dyn_scalar(
+ left: &dyn Array,
right: i128,
-) -> Result<Decimal128Array> {
- let array = add_scalar(left, right)?
- .with_precision_and_scale(left.precision(), left.scale())?;
- Ok(array)
+) -> Result<ArrayRef> {
+ let left_decimal =
left.as_any().downcast_ref::<Decimal128Array>().unwrap();
+
+ let array = subtract_scalar_dyn::<Decimal128Type>(left, right)?;
+ let decimal_array = as_decimal128_array(&array)?;
+ let decimal_array = decimal_array
+ .clone()
+ .with_precision_and_scale(left_decimal.precision(),
left_decimal.scale())?;
+ Ok(Arc::new(decimal_array))
}
-pub(crate) fn subtract_decimal(
- left: &Decimal128Array,
- right: &Decimal128Array,
-) -> Result<Decimal128Array> {
- let array = subtract(left, right)?
- .with_precision_and_scale(left.precision(), left.scale())?;
- Ok(array)
+fn get_precision_scale(left: &dyn Array) -> Result<(u8, i8)> {
+ match left.data_type() {
+ DataType::Decimal128(precision, scale) => Ok((*precision, *scale)),
+ DataType::Dictionary(_, value_type) => match value_type.as_ref() {
+ DataType::Decimal128(precision, scale) => Ok((*precision, *scale)),
+ _ => Err(DataFusionError::Internal(
+ "Unexpected data type".to_string(),
+ )),
+ },
+ _ => Err(DataFusionError::Internal(
+ "Unexpected data type".to_string(),
+ )),
+ }
}
-pub(crate) fn subtract_decimal_scalar(
- left: &Decimal128Array,
+fn decimal_array_with_precision_scale(
+ array: ArrayRef,
+ precision: u8,
+ scale: i8,
+) -> Result<ArrayRef> {
+ let array = array.as_ref();
+ let decimal_array = match array.data_type() {
+ DataType::Decimal128(_, _) => {
+ let array = as_decimal128_array(array)?;
+ Arc::new(array.clone().with_precision_and_scale(precision, scale)?)
+ as ArrayRef
+ }
+ DataType::Dictionary(_, _) => {
+ downcast_dictionary_array!(
+ array => match array.values().data_type() {
+ DataType::Decimal128(_, _) => {
+ let decimal_dict_array =
array.downcast_dict::<Decimal128Array>().unwrap();
+ let decimal_array =
decimal_dict_array.values().clone();
+ let decimal_array =
decimal_array.with_precision_and_scale(precision, scale)?;
+ Arc::new(array.with_values(&decimal_array)) as ArrayRef
+ }
+ t => return
Err(DataFusionError::Internal(format!("Unexpected dictionary value type {t}"))),
+ },
+ t => return Err(DataFusionError::Internal(format!("Unexpected
datatype {t}"))),
+ )
+ }
+ _ => {
+ return Err(DataFusionError::Internal(
+ "Unexpected data type".to_string(),
+ ))
+ }
+ };
+ Ok(decimal_array)
+}
+
+pub(crate) fn multiply_decimal_dyn_scalar(
+ left: &dyn Array,
right: i128,
-) -> Result<Decimal128Array> {
- let array = subtract_scalar(left, right)?
- .with_precision_and_scale(left.precision(), left.scale())?;
- Ok(array)
+) -> Result<ArrayRef> {
+ let (precision, scale) = get_precision_scale(left)?;
+
+ let array = multiply_scalar_dyn::<Decimal128Type>(left, right)?;
+
+ let divide = 10_i128.pow(scale as u32);
+ let array = divide_scalar_dyn::<Decimal128Type>(&array, divide)?;
+
+ decimal_array_with_precision_scale(array, precision, scale)
}
-pub(crate) fn multiply_decimal(
+pub(crate) fn divide_decimal_dyn_scalar(
+ left: &dyn Array,
+ right: i128,
+) -> Result<ArrayRef> {
+ let (precision, scale) = get_precision_scale(left)?;
+
+ let mul = 10_i128.pow(scale as u32);
+ let array = multiply_scalar_dyn::<Decimal128Type>(left, mul)?;
+
+ let array = divide_scalar_dyn::<Decimal128Type>(&array, right)?;
+ decimal_array_with_precision_scale(array, precision, scale)
+}
+
+pub(crate) fn subtract_decimal(
left: &Decimal128Array,
right: &Decimal128Array,
) -> Result<Decimal128Array> {
- let divide = 10_i128.pow(left.scale() as u32);
- let array = multiply(left, right)?;
- let array = divide_scalar(&array, divide)?
+ let array = subtract(left, right)?
.with_precision_and_scale(left.precision(), left.scale())?;
Ok(array)
}
-pub(crate) fn multiply_decimal_scalar(
+pub(crate) fn multiply_decimal(
left: &Decimal128Array,
- right: i128,
+ right: &Decimal128Array,
) -> Result<Decimal128Array> {
- let array = multiply_scalar(left, right)?;
let divide = 10_i128.pow(left.scale() as u32);
+ let array = multiply(left, right)?;
let array = divide_scalar(&array, divide)?
.with_precision_and_scale(left.precision(), left.scale())?;
Ok(array)
@@ -243,18 +321,6 @@ pub(crate) fn divide_opt_decimal(
Ok(array)
}
-pub(crate) fn divide_decimal_scalar(
- left: &Decimal128Array,
- right: i128,
-) -> Result<Decimal128Array> {
- let mul = 10_i128.pow(left.scale() as u32);
- let array = multiply_scalar(left, mul)?;
- // `0` of right will be checked in `divide_scalar`
- let array = divide_scalar(&array, right)?
- .with_precision_and_scale(left.precision(), left.scale())?;
- Ok(array)
-}
-
pub(crate) fn modulus_decimal(
left: &Decimal128Array,
right: &Decimal128Array,
@@ -371,25 +437,28 @@ mod tests {
let expect =
create_decimal_array(&[Some(246), None, Some(245), Some(247)], 25,
3);
assert_eq!(expect, result);
- let result = add_decimal_scalar(&left_decimal_array, 10)?;
+ let result = add_decimal_dyn_scalar(&left_decimal_array, 10)?;
+ let result = as_decimal128_array(&result)?;
let expect =
create_decimal_array(&[Some(133), None, Some(132), Some(134)], 25,
3);
- assert_eq!(expect, result);
+ assert_eq!(&expect, result);
// subtract
let result = subtract_decimal(&left_decimal_array,
&right_decimal_array)?;
let expect = create_decimal_array(&[Some(0), None, Some(-1), Some(1)],
25, 3);
assert_eq!(expect, result);
- let result = subtract_decimal_scalar(&left_decimal_array, 10)?;
+ let result = subtract_decimal_dyn_scalar(&left_decimal_array, 10)?;
+ let result = as_decimal128_array(&result)?;
let expect =
create_decimal_array(&[Some(113), None, Some(112), Some(114)], 25,
3);
- assert_eq!(expect, result);
+ assert_eq!(&expect, result);
// multiply
let result = multiply_decimal(&left_decimal_array,
&right_decimal_array)?;
let expect = create_decimal_array(&[Some(15), None, Some(15),
Some(15)], 25, 3);
assert_eq!(expect, result);
- let result = multiply_decimal_scalar(&left_decimal_array, 10)?;
+ let result = multiply_decimal_dyn_scalar(&left_decimal_array, 10)?;
+ let result = as_decimal128_array(&result)?;
let expect = create_decimal_array(&[Some(1), None, Some(1), Some(1)],
25, 3);
- assert_eq!(expect, result);
+ assert_eq!(&expect, result);
// divide
let left_decimal_array = create_decimal_array(
&[
@@ -414,7 +483,8 @@ mod tests {
3,
);
assert_eq!(expect, result);
- let result = divide_decimal_scalar(&left_decimal_array, 10)?;
+ let result = divide_decimal_dyn_scalar(&left_decimal_array, 10)?;
+ let result = as_decimal128_array(&result)?;
let expect = create_decimal_array(
&[
Some(123456700),
@@ -426,7 +496,7 @@ mod tests {
25,
3,
);
- assert_eq!(expect, result);
+ assert_eq!(&expect, result);
let result = modulus_decimal(&left_decimal_array,
&right_decimal_array)?;
let expect =
create_decimal_array(&[Some(7), None, Some(37), Some(16), None],
25, 3);
@@ -444,7 +514,7 @@ mod tests {
let left_decimal_array = create_decimal_array(&[Some(101)], 10, 1);
let right_decimal_array = create_decimal_array(&[Some(0)], 1, 1);
- let err = divide_decimal_scalar(&left_decimal_array, 0).unwrap_err();
+ let err = divide_decimal_dyn_scalar(&left_decimal_array,
0).unwrap_err();
assert_eq!("Arrow error: Divide by zero error", err.to_string());
let err = modulus_decimal(&left_decimal_array,
&right_decimal_array).unwrap_err();
assert_eq!("Arrow error: Divide by zero error", err.to_string());