This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 1dd3674a9f Fix inconsistent array type for binary numerical operators result between array and scalar (#6269)
1dd3674a9f is described below

commit 1dd3674a9f623b6dda7c2f45f4db57948daf3fb4
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue May 9 11:09:16 2023 -0700

    Fix inconsistent array type for binary numerical operators result between array and scalar (#6269)
    
    * Cast binary numerical operators result between array and scalar to primitive array
    
    * Add order by to stablize query result
    
    * Fix tests
    
    * Fix clippy
---
 .../tests/sqllogictests/test_files/aggregate.slt   |  28 +++--
 datafusion/physical-expr/src/expressions/binary.rs | 122 +++++++++++----------
 2 files changed, 86 insertions(+), 64 deletions(-)

diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
index d2472e917f..0c0a2c49b4 100644
--- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
@@ -1719,14 +1719,26 @@ select max(x_dict) from value_dict where x_dict > 3;
 ----
 5
 
-query error DataFusion error: External error: Arrow error: Invalid argument error: RowConverter column schema mismatch, expected Int64 got Dictionary\(Int64, Int64\)
-select sum(x_dict) from value_dict group by x_dict % 2;
+query I
+select sum(x_dict) from value_dict group by x_dict % 2 order by sum(x_dict);
+----
+8
+13
 
-query error DataFusion error: External error: Arrow error: Invalid argument error: RowConverter column schema mismatch, expected Int64 got Dictionary\(Int64, Int64\)
-select avg(x_dict) from value_dict group by x_dict % 2;
+query R
+select avg(x_dict) from value_dict group by x_dict % 2 order by avg(x_dict);
+----
+2.6
+2.666666666667
 
-query error DataFusion error: External error: Arrow error: Invalid argument error: RowConverter column schema mismatch, expected Int64 got Dictionary\(Int64, Int64\)
-select min(x_dict) from value_dict group by x_dict % 2;
+query I
+select min(x_dict) from value_dict group by x_dict % 2 order by min(x_dict);
+----
+1
+2
 
-query error DataFusion error: External error: Arrow error: Invalid argument error: RowConverter column schema mismatch, expected Int64 got Dictionary\(Int64, Int64\)
-select max(x_dict) from value_dict group by x_dict % 2;
+query I
+select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict);
+----
+4
+5
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index 9b46d79258..7bdbba88a8 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -50,7 +50,7 @@ use arrow::compute::kernels::comparison::{
     eq_dyn_utf8_scalar, gt_dyn_utf8_scalar, gt_eq_dyn_utf8_scalar, lt_dyn_utf8_scalar,
     lt_eq_dyn_utf8_scalar, neq_dyn_utf8_scalar,
 };
-use arrow::compute::{try_unary, unary, CastOptions};
+use arrow::compute::{cast, try_unary, unary, CastOptions};
 use arrow::datatypes::*;
 
 use adapter::{eq_dyn, gt_dyn, gt_eq_dyn, lt_dyn, lt_eq_dyn, neq_dyn};
@@ -694,6 +694,9 @@ impl PhysicalExpr for BinaryExpr {
             (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => {
                 // if left is array and right is literal - use scalar operations
                 self.evaluate_array_scalar(array, scalar.clone(), &result_type)?
+                    .map(|r| {
+                        r.and_then(|a| to_result_type_array(&self.op, a, &result_type))
+                    })
             }
             (ColumnarValue::Scalar(scalar), ColumnarValue::Array(array)) => {
                 // if right is literal and left is array - reverse operator and parameters
@@ -1027,6 +1030,35 @@ pub(crate) fn array_eq_scalar(lhs: &dyn Array, rhs: &ScalarValue) -> Result<Arra
     )?
 }
 
+/// Casts dictionary array to result type for binary numerical operators. Such operators
+/// between array and scalar produce a dictionary array other than primitive array of the
+/// same operators between array and array. This leads to inconsistent result types causing
+/// errors in the following query execution. For such operators between array and scalar,
+/// we cast the dictionary array to primitive array.
+fn to_result_type_array(
+    op: &Operator,
+    array: ArrayRef,
+    result_type: &DataType,
+) -> Result<ArrayRef> {
+    if op.is_numerical_operators() {
+        match array.data_type() {
+            DataType::Dictionary(_, value_type) => {
+                if value_type.as_ref() == result_type {
+                    Ok(cast(&array, result_type)?)
+                } else {
+                    Err(DataFusionError::Internal(format!(
+                        "Incompatible Dictionary value type {:?} with result type {:?} of Binary operator {:?}",
+                        value_type, result_type, op
+                    )))
+                }
+            }
+            _ => Ok(array),
+        }
+    } else {
+        Ok(array)
+    }
+}
+
 impl BinaryExpr {
     /// Evaluate the expression of the left input is an array and
     /// right is literal - use scalar operations
@@ -2699,13 +2731,8 @@ mod tests {
 
         let a = dict_builder.finish();
 
-        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
-
-        dict_builder.append(2)?;
-        dict_builder.append_null();
-        dict_builder.append(3)?;
-        dict_builder.append(6)?;
-        let expected = dict_builder.finish();
+        let expected: PrimitiveArray<Int32Type> =
+            PrimitiveArray::from(vec![Some(2), None, Some(3), Some(6)]);
 
         apply_arithmetic_scalar(
             Arc::new(schema),
@@ -2742,13 +2769,17 @@ mod tests {
         let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
         let a = DictionaryArray::try_new(keys, decimal_array)?;
 
-        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
         let decimal_array = Arc::new(create_decimal_array(
-            &[Some(value + 1), None, Some(value), Some(value + 2)],
+            &[
+                Some(value + 1),
+                Some(value),
+                None,
+                Some(value + 2),
+                Some(value + 1),
+            ],
             11,
             0,
         ));
-        let expected = DictionaryArray::try_new(keys, decimal_array)?;
 
         apply_arithmetic_scalar(
             Arc::new(schema),
@@ -2758,7 +2789,7 @@ mod tests {
                 Box::new(DataType::Int8),
                 Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
             ),
-            Arc::new(expected),
+            decimal_array,
         )?;
 
         Ok(())
@@ -2918,13 +2949,8 @@ mod tests {
 
         let a = dict_builder.finish();
 
-        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
-
-        dict_builder.append(0)?;
-        dict_builder.append_null();
-        dict_builder.append(1)?;
-        dict_builder.append(4)?;
-        let expected = dict_builder.finish();
+        let expected: PrimitiveArray<Int32Type> =
+            PrimitiveArray::from(vec![Some(0), None, Some(1), Some(4)]);
 
         apply_arithmetic_scalar(
             Arc::new(schema),
@@ -2961,13 +2987,17 @@ mod tests {
         let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
         let a = DictionaryArray::try_new(keys, decimal_array)?;
 
-        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
         let decimal_array = Arc::new(create_decimal_array(
-            &[Some(value - 1), None, Some(value - 2), Some(value)],
+            &[
+                Some(value - 1),
+                Some(value - 2),
+                None,
+                Some(value),
+                Some(value - 1),
+            ],
             11,
             0,
         ));
-        let expected = DictionaryArray::try_new(keys, decimal_array)?;
 
         apply_arithmetic_scalar(
             Arc::new(schema),
@@ -2977,7 +3007,7 @@ mod tests {
                 Box::new(DataType::Int8),
                 Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
             ),
-            Arc::new(expected),
+            decimal_array,
         )?;
 
         Ok(())
@@ -3133,13 +3163,8 @@ mod tests {
 
         let a = dict_builder.finish();
 
-        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
-
-        dict_builder.append(2)?;
-        dict_builder.append_null();
-        dict_builder.append(4)?;
-        dict_builder.append(10)?;
-        let expected = dict_builder.finish();
+        let expected: PrimitiveArray<Int32Type> =
+            PrimitiveArray::from(vec![Some(2), None, Some(4), Some(10)]);
 
         apply_arithmetic_scalar(
             Arc::new(schema),
@@ -3176,13 +3201,11 @@ mod tests {
         let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
         let a = DictionaryArray::try_new(keys, decimal_array)?;
 
-        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
         let decimal_array = Arc::new(create_decimal_array(
-            &[Some(246), None, Some(244), Some(248)],
+            &[Some(246), Some(244), None, Some(248), Some(246)],
             21,
             0,
         ));
-        let expected = DictionaryArray::try_new(keys, decimal_array)?;
 
         apply_arithmetic_scalar(
             Arc::new(schema),
@@ -3192,7 +3215,7 @@ mod tests {
                 Box::new(DataType::Int8),
                 Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
             ),
-            Arc::new(expected),
+            decimal_array,
         )?;
 
         Ok(())
@@ -3360,13 +3383,8 @@ mod tests {
 
         let a = dict_builder.finish();
 
-        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
-
-        dict_builder.append(0)?;
-        dict_builder.append_null();
-        dict_builder.append(1)?;
-        dict_builder.append(2)?;
-        let expected = dict_builder.finish();
+        let expected: PrimitiveArray<Int32Type> =
+            PrimitiveArray::from(vec![Some(0), None, Some(1), Some(2)]);
 
         apply_arithmetic_scalar(
             Arc::new(schema),
@@ -3403,18 +3421,17 @@ mod tests {
         let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
         let a = DictionaryArray::try_new(keys, decimal_array)?;
 
-        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
         let decimal_array = Arc::new(create_decimal_array(
             &[
                 Some(6150000000000),
-                None,
                 Some(6100000000000),
+                None,
                 Some(6200000000000),
+                Some(6150000000000),
             ],
             21,
             11,
         ));
-        let expected = DictionaryArray::try_new(keys, decimal_array)?;
 
         apply_arithmetic_scalar(
             Arc::new(schema),
@@ -3424,7 +3441,7 @@ mod tests {
                 Box::new(DataType::Int8),
                 Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
             ),
-            Arc::new(expected),
+            decimal_array,
         )?;
 
         Ok(())
@@ -3582,13 +3599,8 @@ mod tests {
 
         let a = dict_builder.finish();
 
-        let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
-
-        dict_builder.append(1)?;
-        dict_builder.append_null();
-        dict_builder.append(0)?;
-        dict_builder.append(1)?;
-        let expected = dict_builder.finish();
+        let expected: PrimitiveArray<Int32Type> =
+            PrimitiveArray::from(vec![Some(1), None, Some(0), Some(1)]);
 
         apply_arithmetic_scalar(
             Arc::new(schema),
@@ -3625,13 +3637,11 @@ mod tests {
         let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
         let a = DictionaryArray::try_new(keys, decimal_array)?;
 
-        let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
         let decimal_array = Arc::new(create_decimal_array(
-            &[Some(1), None, Some(0), Some(0)],
+            &[Some(1), Some(0), None, Some(0), Some(1)],
             10,
             0,
         ));
-        let expected = DictionaryArray::try_new(keys, decimal_array)?;
 
         apply_arithmetic_scalar(
             Arc::new(schema),
@@ -3641,7 +3651,7 @@ mod tests {
                 Box::new(DataType::Int8),
                 Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
             ),
-            Arc::new(expected),
+            decimal_array,
         )?;
 
         Ok(())

Reply via email to