This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 1dd3674a9f Fix inconsistent array type for binary numerical operators
result between array and scalar (#6269)
1dd3674a9f is described below
commit 1dd3674a9f623b6dda7c2f45f4db57948daf3fb4
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue May 9 11:09:16 2023 -0700
Fix inconsistent array type for binary numerical operators result between
array and scalar (#6269)
* Cast binary numerical operators result between array and scalar to
primitive array
* Add order by to stabilize query result
* Fix tests
* Fix clippy
---
.../tests/sqllogictests/test_files/aggregate.slt | 28 +++--
datafusion/physical-expr/src/expressions/binary.rs | 122 +++++++++++----------
2 files changed, 86 insertions(+), 64 deletions(-)
diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
index d2472e917f..0c0a2c49b4 100644
--- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
@@ -1719,14 +1719,26 @@ select max(x_dict) from value_dict where x_dict > 3;
----
5
-query error DataFusion error: External error: Arrow error: Invalid argument
error: RowConverter column schema mismatch, expected Int64 got
Dictionary\(Int64, Int64\)
-select sum(x_dict) from value_dict group by x_dict % 2;
+query I
+select sum(x_dict) from value_dict group by x_dict % 2 order by sum(x_dict);
+----
+8
+13
-query error DataFusion error: External error: Arrow error: Invalid argument
error: RowConverter column schema mismatch, expected Int64 got
Dictionary\(Int64, Int64\)
-select avg(x_dict) from value_dict group by x_dict % 2;
+query R
+select avg(x_dict) from value_dict group by x_dict % 2 order by avg(x_dict);
+----
+2.6
+2.666666666667
-query error DataFusion error: External error: Arrow error: Invalid argument
error: RowConverter column schema mismatch, expected Int64 got
Dictionary\(Int64, Int64\)
-select min(x_dict) from value_dict group by x_dict % 2;
+query I
+select min(x_dict) from value_dict group by x_dict % 2 order by min(x_dict);
+----
+1
+2
-query error DataFusion error: External error: Arrow error: Invalid argument
error: RowConverter column schema mismatch, expected Int64 got
Dictionary\(Int64, Int64\)
-select max(x_dict) from value_dict group by x_dict % 2;
+query I
+select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict);
+----
+4
+5
diff --git a/datafusion/physical-expr/src/expressions/binary.rs
b/datafusion/physical-expr/src/expressions/binary.rs
index 9b46d79258..7bdbba88a8 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -50,7 +50,7 @@ use arrow::compute::kernels::comparison::{
eq_dyn_utf8_scalar, gt_dyn_utf8_scalar, gt_eq_dyn_utf8_scalar,
lt_dyn_utf8_scalar,
lt_eq_dyn_utf8_scalar, neq_dyn_utf8_scalar,
};
-use arrow::compute::{try_unary, unary, CastOptions};
+use arrow::compute::{cast, try_unary, unary, CastOptions};
use arrow::datatypes::*;
use adapter::{eq_dyn, gt_dyn, gt_eq_dyn, lt_dyn, lt_eq_dyn, neq_dyn};
@@ -694,6 +694,9 @@ impl PhysicalExpr for BinaryExpr {
(ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => {
// if left is array and right is literal - use scalar
operations
self.evaluate_array_scalar(array, scalar.clone(),
&result_type)?
+ .map(|r| {
+ r.and_then(|a| to_result_type_array(&self.op, a,
&result_type))
+ })
}
(ColumnarValue::Scalar(scalar), ColumnarValue::Array(array)) => {
// if right is literal and left is array - reverse operator
and parameters
@@ -1027,6 +1030,35 @@ pub(crate) fn array_eq_scalar(lhs: &dyn Array, rhs:
&ScalarValue) -> Result<Arra
)?
}
+/// Casts dictionary array to result type for binary numerical operators. Such
operators
+/// between array and scalar produce a dictionary array other than primitive
array of the
+/// same operators between array and array. This leads to inconsistent result
types causing
+/// errors in the following query execution. For such operators between array
and scalar,
+/// we cast the dictionary array to primitive array.
+fn to_result_type_array(
+ op: &Operator,
+ array: ArrayRef,
+ result_type: &DataType,
+) -> Result<ArrayRef> {
+ if op.is_numerical_operators() {
+ match array.data_type() {
+ DataType::Dictionary(_, value_type) => {
+ if value_type.as_ref() == result_type {
+ Ok(cast(&array, result_type)?)
+ } else {
+ Err(DataFusionError::Internal(format!(
+ "Incompatible Dictionary value type {:?} with result
type {:?} of Binary operator {:?}",
+ value_type, result_type, op
+ )))
+ }
+ }
+ _ => Ok(array),
+ }
+ } else {
+ Ok(array)
+ }
+}
+
impl BinaryExpr {
/// Evaluate the expression of the left input is an array and
/// right is literal - use scalar operations
@@ -2699,13 +2731,8 @@ mod tests {
let a = dict_builder.finish();
- let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::new();
-
- dict_builder.append(2)?;
- dict_builder.append_null();
- dict_builder.append(3)?;
- dict_builder.append(6)?;
- let expected = dict_builder.finish();
+ let expected: PrimitiveArray<Int32Type> =
+ PrimitiveArray::from(vec![Some(2), None, Some(3), Some(6)]);
apply_arithmetic_scalar(
Arc::new(schema),
@@ -2742,13 +2769,17 @@ mod tests {
let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let a = DictionaryArray::try_new(keys, decimal_array)?;
- let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let decimal_array = Arc::new(create_decimal_array(
- &[Some(value + 1), None, Some(value), Some(value + 2)],
+ &[
+ Some(value + 1),
+ Some(value),
+ None,
+ Some(value + 2),
+ Some(value + 1),
+ ],
11,
0,
));
- let expected = DictionaryArray::try_new(keys, decimal_array)?;
apply_arithmetic_scalar(
Arc::new(schema),
@@ -2758,7 +2789,7 @@ mod tests {
Box::new(DataType::Int8),
Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
),
- Arc::new(expected),
+ decimal_array,
)?;
Ok(())
@@ -2918,13 +2949,8 @@ mod tests {
let a = dict_builder.finish();
- let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::new();
-
- dict_builder.append(0)?;
- dict_builder.append_null();
- dict_builder.append(1)?;
- dict_builder.append(4)?;
- let expected = dict_builder.finish();
+ let expected: PrimitiveArray<Int32Type> =
+ PrimitiveArray::from(vec![Some(0), None, Some(1), Some(4)]);
apply_arithmetic_scalar(
Arc::new(schema),
@@ -2961,13 +2987,17 @@ mod tests {
let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let a = DictionaryArray::try_new(keys, decimal_array)?;
- let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let decimal_array = Arc::new(create_decimal_array(
- &[Some(value - 1), None, Some(value - 2), Some(value)],
+ &[
+ Some(value - 1),
+ Some(value - 2),
+ None,
+ Some(value),
+ Some(value - 1),
+ ],
11,
0,
));
- let expected = DictionaryArray::try_new(keys, decimal_array)?;
apply_arithmetic_scalar(
Arc::new(schema),
@@ -2977,7 +3007,7 @@ mod tests {
Box::new(DataType::Int8),
Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
),
- Arc::new(expected),
+ decimal_array,
)?;
Ok(())
@@ -3133,13 +3163,8 @@ mod tests {
let a = dict_builder.finish();
- let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::new();
-
- dict_builder.append(2)?;
- dict_builder.append_null();
- dict_builder.append(4)?;
- dict_builder.append(10)?;
- let expected = dict_builder.finish();
+ let expected: PrimitiveArray<Int32Type> =
+ PrimitiveArray::from(vec![Some(2), None, Some(4), Some(10)]);
apply_arithmetic_scalar(
Arc::new(schema),
@@ -3176,13 +3201,11 @@ mod tests {
let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let a = DictionaryArray::try_new(keys, decimal_array)?;
- let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let decimal_array = Arc::new(create_decimal_array(
- &[Some(246), None, Some(244), Some(248)],
+ &[Some(246), Some(244), None, Some(248), Some(246)],
21,
0,
));
- let expected = DictionaryArray::try_new(keys, decimal_array)?;
apply_arithmetic_scalar(
Arc::new(schema),
@@ -3192,7 +3215,7 @@ mod tests {
Box::new(DataType::Int8),
Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
),
- Arc::new(expected),
+ decimal_array,
)?;
Ok(())
@@ -3360,13 +3383,8 @@ mod tests {
let a = dict_builder.finish();
- let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::new();
-
- dict_builder.append(0)?;
- dict_builder.append_null();
- dict_builder.append(1)?;
- dict_builder.append(2)?;
- let expected = dict_builder.finish();
+ let expected: PrimitiveArray<Int32Type> =
+ PrimitiveArray::from(vec![Some(0), None, Some(1), Some(2)]);
apply_arithmetic_scalar(
Arc::new(schema),
@@ -3403,18 +3421,17 @@ mod tests {
let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let a = DictionaryArray::try_new(keys, decimal_array)?;
- let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let decimal_array = Arc::new(create_decimal_array(
&[
Some(6150000000000),
- None,
Some(6100000000000),
+ None,
Some(6200000000000),
+ Some(6150000000000),
],
21,
11,
));
- let expected = DictionaryArray::try_new(keys, decimal_array)?;
apply_arithmetic_scalar(
Arc::new(schema),
@@ -3424,7 +3441,7 @@ mod tests {
Box::new(DataType::Int8),
Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
),
- Arc::new(expected),
+ decimal_array,
)?;
Ok(())
@@ -3582,13 +3599,8 @@ mod tests {
let a = dict_builder.finish();
- let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::new();
-
- dict_builder.append(1)?;
- dict_builder.append_null();
- dict_builder.append(0)?;
- dict_builder.append(1)?;
- let expected = dict_builder.finish();
+ let expected: PrimitiveArray<Int32Type> =
+ PrimitiveArray::from(vec![Some(1), None, Some(0), Some(1)]);
apply_arithmetic_scalar(
Arc::new(schema),
@@ -3625,13 +3637,11 @@ mod tests {
let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let a = DictionaryArray::try_new(keys, decimal_array)?;
- let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let decimal_array = Arc::new(create_decimal_array(
- &[Some(1), None, Some(0), Some(0)],
+ &[Some(1), Some(0), None, Some(0), Some(1)],
10,
0,
));
- let expected = DictionaryArray::try_new(keys, decimal_array)?;
apply_arithmetic_scalar(
Arc::new(schema),
@@ -3641,7 +3651,7 @@ mod tests {
Box::new(DataType::Int8),
Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
),
- Arc::new(expected),
+ decimal_array,
)?;
Ok(())