This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 0ce6f1b1f fix bug: the static/constant set value must be evaluated and get the result (#2834)
0ce6f1b1f is described below

commit 0ce6f1b1fd8fbf94238db913f8bc884e3c3c6aeb
Author: Kun Liu <[email protected]>
AuthorDate: Thu Jul 7 18:17:41 2022 +0800

    fix bug: the static/constant set value must be evaluated and get the result (#2834)
---
 datafusion/core/src/physical_plan/planner.rs       |   5 +-
 .../physical-expr/src/expressions/in_list.rs       | 129 +++++++++++++--------
 datafusion/physical-expr/src/planner.rs            |   2 +-
 3 files changed, 82 insertions(+), 54 deletions(-)

diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 0cb037620..40daef606 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -1890,14 +1890,13 @@ mod tests {
         for i in 1..31 {
             list.push(Expr::Literal(ScalarValue::Int64(Some(i))));
         }
-
         let logical_plan = test_csv_scan()
             .await?
             .filter(col("c12").lt(lit(0.05)))?
             .project(vec![col("c1").in_list(list, false)])?
             .build()?;
         let execution_plan = plan(&logical_plan).await?;
-        let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", 
index: 0 }, list: [Literal { value: Utf8(\"a\") }, TryCastExpr { expr: Literal 
{ value: Int64(1) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: 
Int64(2) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(3) }, 
cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(4) }, cast_type: 
Utf8 }, TryCastExpr { expr: Literal { value: Int64(5) }, cast_type: Utf8 }, 
TryCastExpr { expr: Literal [...]
+        let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", 
index: 0 }, list: [Literal { value: Utf8(\"a\") }, TryCastExpr { expr: Literal 
{ value: Int64(1) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: 
Int64(2) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(3) }, 
cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(4) }, cast_type: 
Utf8 }, TryCastExpr { expr: Literal { value: Int64(5) }, cast_type: Utf8 }, 
TryCastExpr { expr: Literal [...]
         assert!(format!("{:?}", execution_plan).contains(expected));
         Ok(())
     }
@@ -1916,7 +1915,7 @@ mod tests {
             .project(vec![col("c1").in_list(list, false)])?
             .build()?;
         let execution_plan = plan(&logical_plan).await?;
-        let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", 
index: 0 }, list: [TryCastExpr { expr: Literal { value: Int64(NULL) }, 
cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(1) }, cast_type: 
Utf8 }, TryCastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8 }, 
TryCastExpr { expr: Literal { value: Int64(3) }, cast_type: Utf8 }, TryCastExpr 
{ expr: Literal { value: Int64(4) }, cast_type: Utf8 }, TryCastExpr { expr: 
Literal { value: Int64(5) }, cast_ty [...]
+        let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", 
index: 0 }, list: [TryCastExpr { expr: Literal { value: Int64(NULL) }, 
cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(1) }, cast_type: 
Utf8 }, TryCastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8 }, 
TryCastExpr { expr: Literal { value: Int64(3) }, cast_type: Utf8 }, TryCastExpr 
{ expr: Literal { value: Int64(4) }, cast_type: Utf8 }, TryCastExpr { expr: 
Literal { value: Int64(5) }, cast_ty [...]
         assert!(format!("{:?}", execution_plan).contains(expected));
         Ok(())
     }
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index c9a3e419a..40a253c33 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -32,7 +32,7 @@ use arrow::{
     record_batch::RecordBatch,
 };
 
-use crate::{expressions, PhysicalExpr};
+use crate::PhysicalExpr;
 use arrow::array::*;
 use arrow::buffer::{Buffer, MutableBuffer};
 use datafusion_common::ScalarValue;
@@ -269,40 +269,26 @@ fn not_in_list_utf8<OffsetSize: OffsetSizeTrait>(
     compare_op_scalar!(array, values, |x, v: &[&str]| !v.contains(&x))
 }
 
-//check all filter values of In clause are static.
-//include `CastExpr + Literal` or `Literal`
-fn check_all_static_filter_expr(list: &[Arc<dyn PhysicalExpr>]) -> bool {
-    list.iter().all(|v| {
-        let cast = v.as_any().downcast_ref::<expressions::CastExpr>();
-        if let Some(c) = cast {
-            c.expr()
-                .as_any()
-                .downcast_ref::<expressions::Literal>()
-                .is_some()
-        } else {
-            let cast = v.as_any().downcast_ref::<expressions::Literal>();
-            cast.is_some()
-        }
-    })
-}
-
-fn cast_static_filter_to_set(list: &[Arc<dyn PhysicalExpr>]) -> 
HashSet<ScalarValue> {
-    HashSet::from_iter(list.iter().map(|expr| {
-        if let Some(cast) = 
expr.as_any().downcast_ref::<expressions::CastExpr>() {
-            cast.expr()
-                .as_any()
-                .downcast_ref::<expressions::Literal>()
-                .unwrap()
-                .value()
-                .clone()
-        } else {
-            expr.as_any()
-                .downcast_ref::<expressions::Literal>()
-                .unwrap()
-                .value()
-                .clone()
-        }
-    }))
+/// Evaluates every expression in `list` against an empty batch with `schema`
+/// and collects the resulting scalar values into a set.
+///
+/// Succeeds only if every expression is constant, i.e. evaluates to a
+/// `ColumnarValue::Scalar`. An expression that yields an array (such as a
+/// column reference) produces a `NotImplemented` error, and evaluation errors
+/// are propagated; `InListExpr::new` then falls back to `set: None`.
+fn try_cast_static_filter_to_set(
+    list: &[Arc<dyn PhysicalExpr>],
+    schema: &Schema,
+) -> Result<HashSet<ScalarValue>> {
+    let batch = RecordBatch::new_empty(Arc::new(schema.to_owned()));
+    list.iter()
+        .map(|expr| match expr.evaluate(&batch)? {
+            ColumnarValue::Array(_) => Err(DataFusionError::NotImplemented(
+                "InList doesn't support to evaluate the array result".to_string(),
+            )),
+            ColumnarValue::Scalar(s) => Ok(s),
+        })
+        .collect()
 }
 
 fn make_list_contains_decimal(
@@ -379,22 +365,24 @@ impl InListExpr {
         expr: Arc<dyn PhysicalExpr>,
         list: Vec<Arc<dyn PhysicalExpr>>,
         negated: bool,
+        schema: &Schema,
     ) -> Self {
-        if list.len() > OPTIMIZER_INSET_THRESHOLD && 
check_all_static_filter_expr(&list) {
-            Self {
-                expr,
-                set: Some(InSet::new(cast_static_filter_to_set(&list))),
-                list,
-                negated,
-            }
-        } else {
-            Self {
-                expr,
-                list,
-                negated,
-                set: None,
+        if list.len() > OPTIMIZER_INSET_THRESHOLD {
+            if let Ok(set) = try_cast_static_filter_to_set(&list, schema) {
+                return Self {
+                    expr,
+                    set: Some(InSet::new(set)),
+                    list,
+                    negated,
+                };
             }
         }
+        Self {
+            expr,
+            list,
+            negated,
+            set: None,
+        }
     }
 
     /// Input expression
@@ -795,8 +783,9 @@ pub fn in_list(
     expr: Arc<dyn PhysicalExpr>,
     list: Vec<Arc<dyn PhysicalExpr>>,
     negated: &bool,
+    schema: &Schema,
 ) -> Result<Arc<dyn PhysicalExpr>> {
-    Ok(Arc::new(InListExpr::new(expr, list, *negated)))
+    Ok(Arc::new(InListExpr::new(expr, list, *negated, schema)))
 }
 
 #[cfg(test)]
@@ -804,6 +793,7 @@ mod tests {
     use arrow::{array::StringArray, datatypes::Field};
 
     use super::*;
+    use crate::expressions;
     use crate::expressions::{col, lit};
     use crate::planner::in_list_cast;
     use datafusion_common::Result;
@@ -812,7 +802,7 @@ mod tests {
     macro_rules! in_list {
         ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, 
$SCHEMA:expr) => {{
             let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, 
$SCHEMA)?;
-            let expr = in_list(cast_expr, cast_list_exprs, $NEGATED).unwrap();
+            let expr = in_list(cast_expr, cast_list_exprs, $NEGATED, 
$SCHEMA).unwrap();
             let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows());
             let result = result
                 .as_any()
@@ -1326,4 +1316,43 @@ mod tests {
         );
         Ok(())
     }
+
+    #[test]
+    fn test_cast_static_filter_to_set() -> Result<()> {
+        // the schema only matters for the column expression at the end
+        let schema = Schema::new(vec![Field::new("a", DataType::Decimal(13, 4), true)]);
+        // literal, cast(literal) and try_cast(literal) are all constants
+        let mut phy_exprs = vec![
+            lit(1i64),
+            expressions::cast(lit(2i32), &schema, DataType::Int64)?,
+            expressions::try_cast(lit(3.13f32), &schema, DataType::Int64)?,
+        ];
+        let result = try_cast_static_filter_to_set(&phy_exprs, &schema)?;
+
+        assert_eq!(result.len(), 3);
+        assert!(result.contains(&ScalarValue::Int64(Some(1))));
+        assert!(result.contains(&ScalarValue::Int64(Some(2))));
+        assert!(result.contains(&ScalarValue::Int64(Some(3))));
+
+        // cast(cast(lit())) to the same data type: the outer cast is elided
+        phy_exprs.push(expressions::cast(
+            expressions::cast(lit(2i32), &schema, DataType::Int64)?,
+            &schema,
+            DataType::Int64,
+        )?);
+        assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_ok());
+        // cast(cast(lit())) to a different data type
+        phy_exprs.push(expressions::cast(
+            expressions::cast(lit(2i32), &schema, DataType::Int64)?,
+            &schema,
+            DataType::Int32,
+        )?);
+        assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_ok());
+
+        // a column reference is not constant, so the whole list is rejected
+        phy_exprs.push(expressions::col("a", &schema)?);
+        assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_err());
+
+        Ok(())
+    }
 }
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 1eb975bed..570f438e2 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -302,7 +302,7 @@ pub fn create_physical_expr(
 
                 let (cast_expr, cast_list_exprs) =
                     in_list_cast(value_expr, list_exprs, input_schema)?;
-                expressions::in_list(cast_expr, cast_list_exprs, negated)
+                expressions::in_list(cast_expr, cast_list_exprs, negated, 
input_schema)
             }
         },
         other => Err(DataFusionError::NotImplemented(format!(

Reply via email to