This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new dec9adcbe Optimize the evaluation of `IN` for large lists using InSet (#2156)
dec9adcbe is described below

commit dec9adcbe1c7b878e2544c43c9351ab9ded50e4f
Author: Yang Jiang <[email protected]>
AuthorDate: Sat Apr 9 02:36:46 2022 +0800

    Optimize the evaluation of `IN` for large lists using InSet (#2156)
    
    * commit 1
    
    * Add an InSet as an optimized version for IN_LIST
    
    * fix clippy
    
    * fix ut
    
    * fix fmt
    
    * fix clippy
    
    * make clear in explain
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * add UT and change threshold
    
    * fix clippy
    
    * change OPTIMIZER_INSET_THRESHOLD
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/src/physical_plan/planner.rs       |  60 ++-
 datafusion/core/tests/sql/predicates.rs            |  17 +
 .../physical-expr/src/expressions/in_list.rs       | 413 +++++++++++++++------
 3 files changed, 372 insertions(+), 118 deletions(-)

diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 1e6c13f9c..273ec0a85 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -1706,7 +1706,7 @@ mod tests {
         .build()?;
         let execution_plan = plan(&logical_plan).await?;
         // verify that the plan correctly adds cast from Int64(1) to Utf8
-        let expected = "InListExpr { expr: Column { name: \"c1\", index: 0 }, 
list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: 
Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], 
negated: false }";
+        let expected = "InListExpr { expr: Column { name: \"c1\", index: 0 }, 
list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: 
Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], 
negated: false, set: None }";
         assert!(format!("{:?}", execution_plan).contains(expected));
 
         // expression: "a in (true, 'a')"
@@ -1742,6 +1742,64 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn in_set_test() -> Result<()> {
+        let testdata = crate::test_util::arrow_test_data();
+        let path = format!("{}/csv/aggregate_test_100.csv", testdata);
+        let options = CsvReadOptions::new().schema_infer_max_records(100);
+
+        // OPTIMIZER_INSET_THRESHOLD = 10
+        // expression: "a in ('a', 1, 2, ..30)"
+        let mut list = 
vec![Expr::Literal(ScalarValue::Utf8(Some("a".to_string())))];
+        for i in 1..31 {
+            list.push(Expr::Literal(ScalarValue::Int64(Some(i))));
+        }
+
+        let logical_plan = LogicalPlanBuilder::scan_csv(
+            Arc::new(LocalFileSystem {}),
+            &path,
+            options,
+            None,
+            1,
+        )
+        .await?
+        .filter(col("c12").lt(lit(0.05)))?
+        .project(vec![col("c1").in_list(list, false)])?
+        .build()?;
+        let execution_plan = plan(&logical_plan).await?;
+        let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", 
index: 0 }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { 
value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } 
}, CastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8, cast_options: 
CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(3) }, 
cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: 
Literal { value: Int64(4)  [...]
+        assert!(format!("{:?}", execution_plan).contains(expected));
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn in_set_null_test() -> Result<()> {
+        let testdata = crate::test_util::arrow_test_data();
+        let path = format!("{}/csv/aggregate_test_100.csv", testdata);
+        let options = CsvReadOptions::new().schema_infer_max_records(100);
+        // test NULL
+        let mut list = vec![Expr::Literal(ScalarValue::Int64(None))];
+        for i in 1..31 {
+            list.push(Expr::Literal(ScalarValue::Int64(Some(i))));
+        }
+
+        let logical_plan = LogicalPlanBuilder::scan_csv(
+            Arc::new(LocalFileSystem {}),
+            &path,
+            options,
+            None,
+            1,
+        )
+        .await?
+        .filter(col("c12").lt(lit(0.05)))?
+        .project(vec![col("c1").in_list(list, false)])?
+        .build()?;
+        let execution_plan = plan(&logical_plan).await?;
+        let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", 
index: 0 }, list: [CastExpr { expr: Literal { value: Int64(NULL) }, cast_type: 
Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { 
value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } 
}, CastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8, cast_options: 
CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(3) }, 
cast_type: Utf8, cast_opti [...]
+        assert!(format!("{:?}", execution_plan).contains(expected));
+        Ok(())
+    }
+
     #[tokio::test]
     async fn hash_agg_input_schema() -> Result<()> {
         let testdata = crate::test_util::arrow_test_data();
diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs
index 1369baa75..ea79e2b14 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -369,3 +369,20 @@ async fn test_expect_distinct() -> Result<()> {
     assert_batches_eq!(expected, &actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn csv_in_set_test() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    let sql = "SELECT count(*) FROM aggregate_test_100 WHERE c7 in 
('25','155','204','77','208','67','139','191','26','7','202','113','129','197','249','146','129','220','154','163','220','19','71','243','150','231','196','170','99','255');";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-----------------+",
+        "| COUNT(UInt8(1)) |",
+        "+-----------------+",
+        "| 36              |",
+        "+-----------------+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index 2aee0d87d..a6894b938 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -18,6 +18,7 @@
 //! InList expression
 
 use std::any::Any;
+use std::collections::HashSet;
 use std::sync::Arc;
 
 use arrow::array::GenericStringArray;
@@ -32,13 +33,19 @@ use arrow::{
     record_batch::RecordBatch,
 };
 
-use crate::PhysicalExpr;
+use crate::{expressions, PhysicalExpr};
 use arrow::array::*;
 use arrow::buffer::{Buffer, MutableBuffer};
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::ColumnarValue;
 
+/// Size at which to use a Set rather than Vec for `IN` / `NOT IN`
+/// Value chosen by the benchmark at
+/// https://github.com/apache/arrow-datafusion/pull/2156#discussion_r845198369
+/// TODO: add switch codeGen in In_List
+static OPTIMIZER_INSET_THRESHOLD: usize = 30;
+
 macro_rules! compare_op_scalar {
     ($left: expr, $right:expr, $op:expr) => {{
         let null_bit_buffer = $left.data().null_buffer().cloned();
@@ -69,6 +76,23 @@ pub struct InListExpr {
     expr: Arc<dyn PhysicalExpr>,
     list: Vec<Arc<dyn PhysicalExpr>>,
     negated: bool,
+    set: Option<InSet>,
+}
+
+/// InSet
+#[derive(Debug)]
+pub struct InSet {
+    set: HashSet<ScalarValue>,
+}
+
+impl InSet {
+    pub fn new(set: HashSet<ScalarValue>) -> Self {
+        Self { set }
+    }
+
+    pub fn get_set(&self) -> &HashSet<ScalarValue> {
+        &self.set
+    }
 }
 
 macro_rules! make_contains {
@@ -181,6 +205,26 @@ macro_rules! make_contains_primitive {
     }};
 }
 
+macro_rules! set_contains_with_negated {
+    ($ARRAY:expr, $LIST_VALUES:expr, $NEGATED:expr) => {{
+        if $NEGATED {
+            return Ok(ColumnarValue::Array(Arc::new(
+                $ARRAY
+                    .iter()
+                    .map(|x| x.map(|v| 
!$LIST_VALUES.contains(&v.try_into().unwrap())))
+                    .collect::<BooleanArray>(),
+            )));
+        } else {
+            return Ok(ColumnarValue::Array(Arc::new(
+                $ARRAY
+                    .iter()
+                    .map(|x| x.map(|v| 
$LIST_VALUES.contains(&v.try_into().unwrap())))
+                    .collect::<BooleanArray>(),
+            )));
+        }
+    }};
+}
+
 // whether each value on the left (can be null) is contained in the non-null 
list
 fn in_list_primitive<T: ArrowPrimitiveType>(
     array: &PrimitiveArray<T>,
@@ -220,6 +264,42 @@ fn not_in_list_utf8<OffsetSize: StringOffsetSizeTrait>(
     compare_op_scalar!(array, values, |x, v: &[&str]| !v.contains(&x))
 }
 
+//check all filter values of In clause are static.
+//include `CastExpr + Literal` or `Literal`
+fn check_all_static_filter_expr(list: &[Arc<dyn PhysicalExpr>]) -> bool {
+    list.iter().all(|v| {
+        let cast = v.as_any().downcast_ref::<expressions::CastExpr>();
+        if let Some(c) = cast {
+            c.expr()
+                .as_any()
+                .downcast_ref::<expressions::Literal>()
+                .is_some()
+        } else {
+            let cast = v.as_any().downcast_ref::<expressions::Literal>();
+            cast.is_some()
+        }
+    })
+}
+
+fn cast_static_filter_to_set(list: &[Arc<dyn PhysicalExpr>]) -> 
HashSet<ScalarValue> {
+    HashSet::from_iter(list.iter().map(|expr| {
+        if let Some(cast) = 
expr.as_any().downcast_ref::<expressions::CastExpr>() {
+            cast.expr()
+                .as_any()
+                .downcast_ref::<expressions::Literal>()
+                .unwrap()
+                .value()
+                .clone()
+        } else {
+            expr.as_any()
+                .downcast_ref::<expressions::Literal>()
+                .unwrap()
+                .value()
+                .clone()
+        }
+    }))
+}
+
 impl InListExpr {
     /// Create a new InList expression
     pub fn new(
@@ -227,10 +307,20 @@ impl InListExpr {
         list: Vec<Arc<dyn PhysicalExpr>>,
         negated: bool,
     ) -> Self {
-        Self {
-            expr,
-            list,
-            negated,
+        if list.len() > OPTIMIZER_INSET_THRESHOLD && 
check_all_static_filter_expr(&list) {
+            Self {
+                expr,
+                set: Some(InSet::new(cast_static_filter_to_set(&list))),
+                list,
+                negated,
+            }
+        } else {
+            Self {
+                expr,
+                list,
+                negated,
+                set: None,
+            }
         }
     }
 
@@ -318,7 +408,13 @@ impl InListExpr {
 impl std::fmt::Display for InListExpr {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
         if self.negated {
-            write!(f, "{} NOT IN ({:?})", self.expr, self.list)
+            if self.set.is_some() {
+                write!(f, "{} NOT IN (SET) ({:?})", self.expr, self.list)
+            } else {
+                write!(f, "{} NOT IN ({:?})", self.expr, self.list)
+            }
+        } else if self.set.is_some() {
+            write!(f, "Use {} IN (SET) ({:?})", self.expr, self.list)
         } else {
             write!(f, "{} IN ({:?})", self.expr, self.list)
         }
@@ -342,119 +438,202 @@ impl PhysicalExpr for InListExpr {
     fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
         let value = self.expr.evaluate(batch)?;
         let value_data_type = value.data_type();
-        let list_values = self
-            .list
-            .iter()
-            .map(|expr| expr.evaluate(batch))
-            .collect::<Result<Vec<_>>>()?;
-
-        let array = match value {
-            ColumnarValue::Array(array) => array,
-            ColumnarValue::Scalar(scalar) => scalar.to_array(),
-        };
 
-        match value_data_type {
-            DataType::Float32 => {
-                make_contains_primitive!(
-                    array,
-                    list_values,
-                    self.negated,
-                    Float32,
-                    Float32Array
-                )
-            }
-            DataType::Float64 => {
-                make_contains_primitive!(
-                    array,
-                    list_values,
-                    self.negated,
-                    Float64,
-                    Float64Array
-                )
-            }
-            DataType::Int16 => {
-                make_contains_primitive!(
-                    array,
-                    list_values,
-                    self.negated,
-                    Int16,
-                    Int16Array
-                )
-            }
-            DataType::Int32 => {
-                make_contains_primitive!(
-                    array,
-                    list_values,
-                    self.negated,
-                    Int32,
-                    Int32Array
-                )
-            }
-            DataType::Int64 => {
-                make_contains_primitive!(
-                    array,
-                    list_values,
-                    self.negated,
-                    Int64,
-                    Int64Array
-                )
-            }
-            DataType::Int8 => {
-                make_contains_primitive!(
-                    array,
-                    list_values,
-                    self.negated,
-                    Int8,
-                    Int8Array
-                )
-            }
-            DataType::UInt16 => {
-                make_contains_primitive!(
-                    array,
-                    list_values,
-                    self.negated,
-                    UInt16,
-                    UInt16Array
-                )
-            }
-            DataType::UInt32 => {
-                make_contains_primitive!(
-                    array,
-                    list_values,
-                    self.negated,
-                    UInt32,
-                    UInt32Array
-                )
-            }
-            DataType::UInt64 => {
-                make_contains_primitive!(
-                    array,
-                    list_values,
-                    self.negated,
-                    UInt64,
-                    UInt64Array
-                )
-            }
-            DataType::UInt8 => {
-                make_contains_primitive!(
-                    array,
-                    list_values,
-                    self.negated,
-                    UInt8,
-                    UInt8Array
-                )
-            }
-            DataType::Boolean => {
-                make_contains!(array, list_values, self.negated, Boolean, 
BooleanArray)
-            }
-            DataType::Utf8 => self.compare_utf8::<i32>(array, list_values, 
self.negated),
-            DataType::LargeUtf8 => {
-                self.compare_utf8::<i64>(array, list_values, self.negated)
+        if let Some(in_set) = &self.set {
+            let array = match value {
+                ColumnarValue::Array(array) => array,
+                ColumnarValue::Scalar(scalar) => scalar.to_array(),
+            };
+            let set = in_set.get_set();
+            match value_data_type {
+                DataType::Boolean => {
+                    let array = 
array.as_any().downcast_ref::<BooleanArray>().unwrap();
+                    set_contains_with_negated!(array, set, self.negated)
+                }
+                DataType::Int8 => {
+                    let array = 
array.as_any().downcast_ref::<Int8Array>().unwrap();
+                    set_contains_with_negated!(array, set, self.negated)
+                }
+                DataType::Int16 => {
+                    let array = 
array.as_any().downcast_ref::<Int16Array>().unwrap();
+                    set_contains_with_negated!(array, set, self.negated)
+                }
+                DataType::Int32 => {
+                    let array = 
array.as_any().downcast_ref::<Int32Array>().unwrap();
+                    set_contains_with_negated!(array, set, self.negated)
+                }
+                DataType::Int64 => {
+                    let array = 
array.as_any().downcast_ref::<Int64Array>().unwrap();
+                    set_contains_with_negated!(array, set, self.negated)
+                }
+                DataType::UInt8 => {
+                    let array = 
array.as_any().downcast_ref::<UInt8Array>().unwrap();
+                    set_contains_with_negated!(array, set, self.negated)
+                }
+                DataType::UInt16 => {
+                    let array = 
array.as_any().downcast_ref::<UInt16Array>().unwrap();
+                    set_contains_with_negated!(array, set, self.negated)
+                }
+                DataType::UInt32 => {
+                    let array = 
array.as_any().downcast_ref::<UInt32Array>().unwrap();
+                    set_contains_with_negated!(array, set, self.negated)
+                }
+                DataType::UInt64 => {
+                    let array = 
array.as_any().downcast_ref::<UInt64Array>().unwrap();
+                    set_contains_with_negated!(array, set, self.negated)
+                }
+                DataType::Float32 => {
+                    let array = 
array.as_any().downcast_ref::<Float32Array>().unwrap();
+                    set_contains_with_negated!(array, set, self.negated)
+                }
+                DataType::Float64 => {
+                    let array = 
array.as_any().downcast_ref::<Float64Array>().unwrap();
+                    set_contains_with_negated!(array, set, self.negated)
+                }
+                DataType::Utf8 => {
+                    let array = array
+                        .as_any()
+                        .downcast_ref::<GenericStringArray<i32>>()
+                        .unwrap();
+                    set_contains_with_negated!(array, set, self.negated)
+                }
+                DataType::LargeUtf8 => {
+                    let array = array
+                        .as_any()
+                        .downcast_ref::<GenericStringArray<i64>>()
+                        .unwrap();
+                    set_contains_with_negated!(array, set, self.negated)
+                }
+                datatype => {
+                    return Result::Err(DataFusionError::NotImplemented(format!(
+                        "InSet does not support datatype {:?}.",
+                        datatype
+                    )))
+                }
+            };
+        } else {
+            let list_values = self
+                .list
+                .iter()
+                .map(|expr| expr.evaluate(batch))
+                .collect::<Result<Vec<_>>>()?;
+
+            let array = match value {
+                ColumnarValue::Array(array) => array,
+                ColumnarValue::Scalar(scalar) => scalar.to_array(),
+            };
+
+            match value_data_type {
+                DataType::Float32 => {
+                    make_contains_primitive!(
+                        array,
+                        list_values,
+                        self.negated,
+                        Float32,
+                        Float32Array
+                    )
+                }
+                DataType::Float64 => {
+                    make_contains_primitive!(
+                        array,
+                        list_values,
+                        self.negated,
+                        Float64,
+                        Float64Array
+                    )
+                }
+                DataType::Int16 => {
+                    make_contains_primitive!(
+                        array,
+                        list_values,
+                        self.negated,
+                        Int16,
+                        Int16Array
+                    )
+                }
+                DataType::Int32 => {
+                    make_contains_primitive!(
+                        array,
+                        list_values,
+                        self.negated,
+                        Int32,
+                        Int32Array
+                    )
+                }
+                DataType::Int64 => {
+                    make_contains_primitive!(
+                        array,
+                        list_values,
+                        self.negated,
+                        Int64,
+                        Int64Array
+                    )
+                }
+                DataType::Int8 => {
+                    make_contains_primitive!(
+                        array,
+                        list_values,
+                        self.negated,
+                        Int8,
+                        Int8Array
+                    )
+                }
+                DataType::UInt16 => {
+                    make_contains_primitive!(
+                        array,
+                        list_values,
+                        self.negated,
+                        UInt16,
+                        UInt16Array
+                    )
+                }
+                DataType::UInt32 => {
+                    make_contains_primitive!(
+                        array,
+                        list_values,
+                        self.negated,
+                        UInt32,
+                        UInt32Array
+                    )
+                }
+                DataType::UInt64 => {
+                    make_contains_primitive!(
+                        array,
+                        list_values,
+                        self.negated,
+                        UInt64,
+                        UInt64Array
+                    )
+                }
+                DataType::UInt8 => {
+                    make_contains_primitive!(
+                        array,
+                        list_values,
+                        self.negated,
+                        UInt8,
+                        UInt8Array
+                    )
+                }
+                DataType::Boolean => {
+                    make_contains!(
+                        array,
+                        list_values,
+                        self.negated,
+                        Boolean,
+                        BooleanArray
+                    )
+                }
+                DataType::Utf8 => {
+                    self.compare_utf8::<i32>(array, list_values, self.negated)
+                }
+                DataType::LargeUtf8 => {
+                    self.compare_utf8::<i64>(array, list_values, self.negated)
+                }
+                datatype => 
Result::Err(DataFusionError::NotImplemented(format!(
+                    "InList does not support datatype {:?}.",
+                    datatype
+                ))),
             }
-            datatype => Result::Err(DataFusionError::NotImplemented(format!(
-                "InList does not support datatype {:?}.",
-                datatype
-            ))),
         }
     }
 }

Reply via email to