This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new c2972a29c support decimal for inlist expr (#2764)
c2972a29c is described below

commit c2972a29c3165dc8cb92f8d437de80d56d99740e
Author: Kun Liu <[email protected]>
AuthorDate: Sat Jun 25 02:03:33 2022 +0800

    support decimal for inlist expr (#2764)
---
 datafusion/expr/src/binary_rule.rs                 |   6 +-
 .../physical-expr/src/expressions/in_list.rs       | 353 +++++++++++++++++++--
 datafusion/physical-expr/src/planner.rs            |   7 +-
 3 files changed, 335 insertions(+), 31 deletions(-)

diff --git a/datafusion/expr/src/binary_rule.rs 
b/datafusion/expr/src/binary_rule.rs
index 6770fccd7..88b4d95ec 100644
--- a/datafusion/expr/src/binary_rule.rs
+++ b/datafusion/expr/src/binary_rule.rs
@@ -150,7 +150,11 @@ fn bitwise_coercion(left_type: &DataType, right_type: 
&DataType) -> Option<DataT
     }
 }
 
-fn comparison_eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> 
Option<DataType> {
+/// Get the coerced data type for `eq` or `not eq` operation
+pub fn comparison_eq_coercion(
+    lhs_type: &DataType,
+    rhs_type: &DataType,
+) -> Option<DataType> {
     // can't compare dictionaries directly due to
     // https://github.com/apache/arrow-rs/issues/1201
     if lhs_type == rhs_type && !is_dictionary(lhs_type) {
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs 
b/datafusion/physical-expr/src/expressions/in_list.rs
index 7d15dd5e9..346eea472 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -33,11 +33,14 @@ use arrow::{
     record_batch::RecordBatch,
 };
 
+use crate::expressions::try_cast;
 use crate::{expressions, PhysicalExpr};
 use arrow::array::*;
 use arrow::buffer::{Buffer, MutableBuffer};
 use datafusion_common::ScalarValue;
+use datafusion_common::ScalarValue::Decimal128;
 use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::binary_rule::comparison_eq_coercion;
 use datafusion_expr::ColumnarValue;
 
 /// Size at which to use a Set rather than Vec for `IN` / `NOT IN`
@@ -82,6 +85,8 @@ pub struct InListExpr {
 /// InSet
 #[derive(Debug)]
 pub struct InSet {
+    // TODO(optimization): for `IN` / `NOT IN` the NULL value need not be kept in this
+    // set; since all values share one data type, it could then be a native `HashSet<T>`.
     set: HashSet<ScalarValue>,
 }
 
@@ -160,6 +165,7 @@ macro_rules! make_contains_primitive {
                 ColumnarValue::Scalar(s) => match s {
                     ScalarValue::$SCALAR_VALUE(Some(v)) => Some(*v),
                     ScalarValue::$SCALAR_VALUE(None) => None,
+                    // TODO: this is a bug; for primitive types the list exprs should be cast to the same data type
                     ScalarValue::Utf8(None) => None,
                     datatype => unimplemented!("Unexpected type {} for 
InList", datatype),
                 },
@@ -300,6 +306,90 @@ fn cast_static_filter_to_set(list: &[Arc<dyn 
PhysicalExpr>]) -> HashSet<ScalarVa
     }))
 }
 
+/// Evaluates `IN` / `NOT IN` for every row of a decimal `array` using SQL
+/// three-valued logic, consistent with the `IN (.., NULL)` results expected
+/// for the other data types:
+///
+/// * a NULL row always yields NULL;
+/// * a row found in the list yields `true` for `IN` and `false` for `NOT IN`;
+/// * a row not found yields NULL when the list contains a NULL, otherwise
+///   `false` for `IN` and `true` for `NOT IN`.
+///
+/// `contains` tests a row's raw `i128` value against the non-NULL list values;
+/// `list_has_null` says whether the list contains a NULL.
+fn decimal_in_list_result(
+    array: &DecimalArray,
+    contains: impl Fn(i128) -> bool,
+    list_has_null: bool,
+    negated: bool,
+) -> BooleanArray {
+    array
+        .iter()
+        .map(|row| -> Option<bool> {
+            // `?` maps a NULL row straight to a NULL result
+            if contains(row?) {
+                Some(!negated)
+            } else if list_has_null {
+                // no non-NULL value matched, but the NULL entry might: unknown
+                None
+            } else {
+                Some(negated)
+            }
+        })
+        .collect::<BooleanArray>()
+}
+
+/// Evaluates `array IN (list)` (or `NOT IN` when `negated`) for a decimal
+/// column against a short list of scalar values.
+///
+/// The list entries are expected to be `Decimal128` scalars (possibly NULL)
+/// that `in_list` has already cast to the common decimal type.
+///
+/// # Panics
+/// Panics if a list entry is a scalar other than `Decimal128` (including an
+/// untyped NULL), or is an array.
+fn make_list_contains_decimal(
+    array: &DecimalArray,
+    list: Vec<ColumnarValue>,
+    negated: bool,
+) -> BooleanArray {
+    let contains_null = list
+        .iter()
+        .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null()));
+    // raw i128 values of the non-NULL entries (`flat_map` drops the `None`s)
+    let values = list
+        .iter()
+        .flat_map(|v| match v {
+            ColumnarValue::Scalar(s) => match s {
+                Decimal128(v128op, _, _) => *v128op,
+                _ => {
+                    unreachable!(
+                        "InList can't reach other data type for decimal data type."
+                    )
+                }
+            },
+            ColumnarValue::Array(_) => {
+                unimplemented!("InList does not yet support nested columns.")
+            }
+        })
+        .collect::<Vec<_>>();
+
+    decimal_in_list_result(array, |v| values.contains(&v), contains_null, negated)
+}
+
+/// Evaluates `array IN (set)` (or `NOT IN` when `negated`) for a decimal
+/// column against the pre-computed value set of an `InSet`.
+///
+/// # Panics
+/// Panics if the set contains a value that is not a `Decimal128` scalar.
+fn make_set_contains_decimal(
+    array: &DecimalArray,
+    set: &HashSet<ScalarValue>,
+    negated: bool,
+) -> BooleanArray {
+    let contains_null = set.iter().any(|v| v.is_null());
+    // collect the raw i128 values once so each row is a native hash lookup
+    let native_set = set
+        .iter()
+        .flat_map(|v| match v {
+            Decimal128(v128op, _, _) => *v128op,
+            _ => {
+                unreachable!("InList can't reach other data type for decimal data type.")
+            }
+        })
+        .collect::<HashSet<i128>>();
+
+    decimal_in_list_result(array, |v| native_set.contains(&v), contains_null, negated)
+}
+
 impl InListExpr {
     /// Create a new InList expression
     pub fn new(
@@ -504,6 +594,11 @@ impl PhysicalExpr for InListExpr {
                         .unwrap();
                     set_contains_with_negated!(array, set, self.negated)
                 }
+                DataType::Decimal(_, _) => {
+                    let array = 
array.as_any().downcast_ref::<DecimalArray>().unwrap();
+                    let result = make_set_contains_decimal(array, set, 
self.negated);
+                    Ok(ColumnarValue::Array(Arc::new(result)))
+                }
                 datatype => 
Result::Err(DataFusionError::NotImplemented(format!(
                     "InSet does not support datatype {:?}.",
                     datatype
@@ -631,6 +726,16 @@ impl PhysicalExpr for InListExpr {
                     let null_array = new_null_array(&DataType::Boolean, 
array.len());
                     Ok(ColumnarValue::Array(Arc::new(null_array)))
                 }
+                DataType::Decimal(_, _) => {
+                    let decimal_array =
+                        array.as_any().downcast_ref::<DecimalArray>().unwrap();
+                    let result = make_list_contains_decimal(
+                        decimal_array,
+                        list_values,
+                        self.negated,
+                    );
+                    Ok(ColumnarValue::Array(Arc::new(result)))
+                }
                 datatype => 
Result::Err(DataFusionError::NotImplemented(format!(
                     "InList does not support datatype {:?}.",
                     datatype
@@ -640,13 +745,63 @@ impl PhysicalExpr for InListExpr {
     }
 }
 
+/// The value expression and the list expressions of an `IN` list after they
+/// have all been cast to one common data type.
+type InListCastResult = (Arc<dyn PhysicalExpr>, Vec<Arc<dyn PhysicalExpr>>);
+
 /// Creates a unary expression InList
 pub fn in_list(
     expr: Arc<dyn PhysicalExpr>,
     list: Vec<Arc<dyn PhysicalExpr>>,
     negated: &bool,
+    // schema used to resolve the data types of `expr` and `list` for coercion
+    input_schema: &Schema,
 ) -> Result<Arc<dyn PhysicalExpr>> {
-    Ok(Arc::new(InListExpr::new(expr, list, *negated)))
+    // Cast `expr` and every list element to a common data type first, so that
+    // they share one data type; returns an error if no such type exists.
+    let (cast_expr, cast_list) = in_list_cast(expr, list, input_schema)?;
+    Ok(Arc::new(InListExpr::new(cast_expr, cast_list, *negated)))
+}
+
+/// Casts `expr` and every expression in `list` to the common data type
+/// computed by `get_coerce_type`, so the `IN` list is evaluated on a single
+/// data type.
+///
+/// # Errors
+/// Returns an error if a data type cannot be resolved against `input_schema`,
+/// if no common coerced type exists, or if any of the casts cannot be built.
+fn in_list_cast(
+    expr: Arc<dyn PhysicalExpr>,
+    list: Vec<Arc<dyn PhysicalExpr>>,
+    input_schema: &Schema,
+) -> Result<InListCastResult> {
+    let expr_type = &expr.data_type(input_schema)?;
+    // propagate (rather than panic on) failures to resolve a list element's type
+    let list_types = list
+        .iter()
+        .map(|list_expr| list_expr.data_type(input_schema))
+        .collect::<Result<Vec<DataType>>>()?;
+    // TODO: arrow-rs should support casting the NULL data type / NULL values to
+    // Decimal: https://github.com/apache/arrow-datafusion/issues/2759
+    let data_type = get_coerce_type(expr_type, &list_types).ok_or_else(|| {
+        DataFusionError::Internal(format!(
+            "In expr can't find the coerced type for {:?} in {:?}",
+            expr_type, list_types
+        ))
+    })?;
+    let cast_expr = try_cast(expr, input_schema, data_type.clone())?;
+    // a list element that cannot be cast is reported as an error, not a panic
+    let cast_list_expr = list
+        .into_iter()
+        .map(|list_expr| try_cast(list_expr, input_schema, data_type.clone()))
+        .collect::<Result<Vec<_>>>()?;
+    Ok((cast_expr, cast_list_expr))
+}
+
+/// Returns the data type to which `expr_type` and every type in `list_type`
+/// can be coerced for an equality comparison, or `None` if there is none.
+///
+/// The types are folded left to right through `comparison_eq_coercion`,
+/// stopping at the first type that cannot be coerced with the running result.
+fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option<DataType> {
+    // TODO refactor a framework to do the data type coercion
+    list_type
+        .iter()
+        .try_fold(expr_type.clone(), |left_type, right_type| {
+            comparison_eq_coercion(&left_type, right_type)
+        })
 }
 
 #[cfg(test)]
@@ -659,8 +814,8 @@ mod tests {
 
     // applies the in_list expr to an input batch and list
     macro_rules! in_list {
-        ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr) => 
{{
-            let expr = in_list($COL, $LIST, $NEGATED).unwrap();
+        ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, 
$SCHEMA:expr) => {{
+            let expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap();
             let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows());
             let result = result
                 .as_any()
@@ -676,7 +831,7 @@ mod tests {
         let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
         let a = StringArray::from(vec![Some("a"), Some("d"), None]);
         let col_a = col("a", &schema)?;
-        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), 
vec![Arc::new(a)])?;
 
         // expression: "a in ("a", "b")"
         let list = vec![
@@ -688,7 +843,8 @@ mod tests {
             list,
             &false,
             vec![Some(true), Some(false), None],
-            col_a.clone()
+            col_a.clone(),
+            &schema
         );
 
         // expression: "a not in ("a", "b")"
@@ -701,7 +857,8 @@ mod tests {
             list,
             &true,
             vec![Some(false), Some(true), None],
-            col_a.clone()
+            col_a.clone(),
+            &schema
         );
 
         // expression: "a not in ("a", "b")"
@@ -715,7 +872,8 @@ mod tests {
             list,
             &false,
             vec![Some(true), None, None],
-            col_a.clone()
+            col_a.clone(),
+            &schema
         );
 
         // expression: "a not in ("a", "b")"
@@ -729,7 +887,8 @@ mod tests {
             list,
             &true,
             vec![Some(false), None, None],
-            col_a.clone()
+            col_a.clone(),
+            &schema
         );
 
         Ok(())
@@ -740,7 +899,7 @@ mod tests {
         let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
         let a = Int64Array::from(vec![Some(0), Some(2), None]);
         let col_a = col("a", &schema)?;
-        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), 
vec![Arc::new(a)])?;
 
         // expression: "a in (0, 1)"
         let list = vec![
@@ -752,7 +911,8 @@ mod tests {
             list,
             &false,
             vec![Some(true), Some(false), None],
-            col_a.clone()
+            col_a.clone(),
+            &schema
         );
 
         // expression: "a not in (0, 1)"
@@ -765,35 +925,38 @@ mod tests {
             list,
             &true,
             vec![Some(false), Some(true), None],
-            col_a.clone()
+            col_a.clone(),
+            &schema
         );
 
         // expression: "a in (0, 1, NULL)"
         let list = vec![
             lit(ScalarValue::Int64(Some(0))),
             lit(ScalarValue::Int64(Some(1))),
-            lit(ScalarValue::Utf8(None)),
+            lit(ScalarValue::Null),
         ];
         in_list!(
             batch,
             list,
             &false,
             vec![Some(true), None, None],
-            col_a.clone()
+            col_a.clone(),
+            &schema
         );
 
         // expression: "a not in (0, 1, NULL)"
         let list = vec![
             lit(ScalarValue::Int64(Some(0))),
             lit(ScalarValue::Int64(Some(1))),
-            lit(ScalarValue::Utf8(None)),
+            lit(ScalarValue::Null),
         ];
         in_list!(
             batch,
             list,
             &true,
             vec![Some(false), None, None],
-            col_a.clone()
+            col_a.clone(),
+            &schema
         );
 
         Ok(())
@@ -804,7 +967,7 @@ mod tests {
         let schema = Schema::new(vec![Field::new("a", DataType::Float64, 
true)]);
         let a = Float64Array::from(vec![Some(0.0), Some(0.2), None]);
         let col_a = col("a", &schema)?;
-        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), 
vec![Arc::new(a)])?;
 
         // expression: "a in (0.0, 0.2)"
         let list = vec![
@@ -816,7 +979,8 @@ mod tests {
             list,
             &false,
             vec![Some(true), Some(false), None],
-            col_a.clone()
+            col_a.clone(),
+            &schema
         );
 
         // expression: "a not in (0.0, 0.2)"
@@ -829,35 +993,38 @@ mod tests {
             list,
             &true,
             vec![Some(false), Some(true), None],
-            col_a.clone()
+            col_a.clone(),
+            &schema
         );
 
         // expression: "a in (0.0, 0.2, NULL)"
         let list = vec![
             lit(ScalarValue::Float64(Some(0.0))),
             lit(ScalarValue::Float64(Some(0.1))),
-            lit(ScalarValue::Utf8(None)),
+            lit(ScalarValue::Null),
         ];
         in_list!(
             batch,
             list,
             &false,
             vec![Some(true), None, None],
-            col_a.clone()
+            col_a.clone(),
+            &schema
         );
 
         // expression: "a not in (0.0, 0.2, NULL)"
         let list = vec![
             lit(ScalarValue::Float64(Some(0.0))),
             lit(ScalarValue::Float64(Some(0.1))),
-            lit(ScalarValue::Utf8(None)),
+            lit(ScalarValue::Null),
         ];
         in_list!(
             batch,
             list,
             &true,
             vec![Some(false), None, None],
-            col_a.clone()
+            col_a.clone(),
+            &schema
         );
 
         Ok(())
@@ -868,29 +1035,157 @@ mod tests {
         let schema = Schema::new(vec![Field::new("a", DataType::Boolean, 
true)]);
         let a = BooleanArray::from(vec![Some(true), None]);
         let col_a = col("a", &schema)?;
-        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), 
vec![Arc::new(a)])?;
 
         // expression: "a in (true)"
         let list = vec![lit(ScalarValue::Boolean(Some(true)))];
-        in_list!(batch, list, &false, vec![Some(true), None], col_a.clone());
+        in_list!(
+            batch,
+            list,
+            &false,
+            vec![Some(true), None],
+            col_a.clone(),
+            &schema
+        );
 
         // expression: "a not in (true)"
         let list = vec![lit(ScalarValue::Boolean(Some(true)))];
-        in_list!(batch, list, &true, vec![Some(false), None], col_a.clone());
+        in_list!(
+            batch,
+            list,
+            &true,
+            vec![Some(false), None],
+            col_a.clone(),
+            &schema
+        );
 
         // expression: "a in (true, NULL)"
         let list = vec![
             lit(ScalarValue::Boolean(Some(true))),
-            lit(ScalarValue::Utf8(None)),
+            lit(ScalarValue::Null),
         ];
-        in_list!(batch, list, &false, vec![Some(true), None], col_a.clone());
+        in_list!(
+            batch,
+            list,
+            &false,
+            vec![Some(true), None],
+            col_a.clone(),
+            &schema
+        );
 
         // expression: "a not in (true, NULL)"
         let list = vec![
             lit(ScalarValue::Boolean(Some(true))),
-            lit(ScalarValue::Utf8(None)),
+            lit(ScalarValue::Null),
+        ];
+        in_list!(
+            batch,
+            list,
+            &true,
+            vec![Some(false), None],
+            col_a.clone(),
+            &schema
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn in_list_decimal() -> Result<()> {
+        // Column `a` is DECIMAL(13, 4) with rows [100.0000, NULL, 200.5000]
+        let schema = Schema::new(vec![Field::new("a", DataType::Decimal(13, 4), true)]);
+        let array = vec![Some(100_0000_i128), None, Some(200_5000_i128)]
+            .into_iter()
+            .collect::<DecimalArray>();
+        let array = array.with_precision_and_scale(13, 4).unwrap();
+        let col_a = col("a", &schema)?;
+        let batch =
+            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)])?;
+
+        // expression: "a in (100, 200)"; the data type of the list is INT32
+        let list = vec![
+            lit(ScalarValue::Int32(Some(100))),
+            lit(ScalarValue::Int32(Some(200))),
+        ];
+        in_list!(
+            batch,
+            list,
+            &false,
+            vec![Some(true), None, Some(false)],
+            col_a.clone(),
+            &schema
+        );
+        // expression: "a not in (100, 200)"
+        let list = vec![
+            lit(ScalarValue::Int32(Some(100))),
+            lit(ScalarValue::Int32(Some(200))),
+        ];
+        in_list!(
+            batch,
+            list,
+            &true,
+            vec![Some(false), None, Some(true)],
+            col_a.clone(),
+            &schema
+        );
+
+        // TODO: enable once arrow-rs supports casting the NULL data type to Decimal.
+        // expression: "a in (100, NULL)"; the data types of the list are INT32 and NULL.
+        // A row matching no non-NULL value yields NULL, as for the other data types.
+        // let list = vec![lit(ScalarValue::Int32(Some(100))), lit(ScalarValue::Null)];
+        // in_list!(batch, list, &false, vec![Some(true), None, None], col_a.clone(), &schema);
+
+        // expression: "a in (200.5, 100)"; the data types of the list are FLOAT32 and INT32
+        let list = vec![
+            lit(ScalarValue::Float32(Some(200.50f32))),
+            lit(ScalarValue::Int32(Some(100))),
         ];
-        in_list!(batch, list, &true, vec![Some(false), None], col_a.clone());
+        in_list!(
+            batch,
+            list,
+            &false,
+            vec![Some(true), None, Some(true)],
+            col_a.clone(),
+            &schema
+        );
+
+        // expression: "a not in (200.5, 101)"; the data types of the list are FLOAT32 and INT32
+        let list = vec![
+            lit(ScalarValue::Float32(Some(200.50f32))),
+            lit(ScalarValue::Int32(Some(101))),
+        ];
+        in_list!(
+            batch,
+            list,
+            &true,
+            vec![Some(true), None, Some(false)],
+            col_a.clone(),
+            &schema
+        );
+
+        // Exercise the set-based (`InSet`) optimization with a long list of 201 values.
+        // expression: "a in (99, 100, ..., 299)"; the data type of the list is INT32
+        let list = (99..300)
+            .into_iter()
+            .map(|v| lit(ScalarValue::Int32(Some(v))))
+            .collect::<Vec<_>>();
+
+        in_list!(
+            batch,
+            list.clone(),
+            &false,
+            vec![Some(true), None, Some(false)],
+            col_a.clone(),
+            &schema
+        );
+
+        // expression: "a not in (99, 100, ..., 299)"
+        in_list!(
+            batch,
+            list,
+            &true,
+            vec![Some(false), None, Some(true)],
+            col_a.clone(),
+            &schema
+        );
 
         Ok(())
     }
diff --git a/datafusion/physical-expr/src/planner.rs 
b/datafusion/physical-expr/src/planner.rs
index 92580fce0..26583cd28 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -294,6 +294,9 @@ pub fn create_physical_expr(
                             input_schema,
                             execution_props,
                         ),
+                        // TODO: refactor the data type coercion logic. The data types of the
+                        // `list expr`s may conflict with the `value expr`, so we should not just
+                        // compare the data type of the `value expr` with each `list expr`.
                         _ => {
                             let list_expr = create_physical_expr(
                                 expr,
@@ -310,6 +313,8 @@ pub fn create_physical_expr(
                                 &list_expr_data_type,
                                 &value_expr_data_type,
                             ) {
+                                // TODO: we can't cast from the list type to the value type directly;
+                                // we should use the coercion rule to get the common data type.
                                 expressions::cast(
                                     list_expr,
                                     input_schema,
@@ -325,7 +330,7 @@ pub fn create_physical_expr(
                     })
                     .collect::<Result<Vec<_>>>()?;
 
-                expressions::in_list(value_expr, list_exprs, negated)
+                expressions::in_list(value_expr, list_exprs, negated, 
input_schema)
             }
         },
         other => Err(DataFusionError::NotImplemented(format!(

Reply via email to