This is an automated email from the ASF dual-hosted git repository.

xudong963 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 73447b560 simplify the `between` expr during logical plan optimization (#3404)
73447b560 is described below

commit 73447b560edea20773114f4ed8b49a561b91799d
Author: Kirk Mitchener <[email protected]>
AuthorDate: Fri Sep 9 08:27:55 2022 -0400

    simplify the `between` expr during logical plan optimization (#3404)
    
    * rewrite the `between` expression so that it can be further optimized and pushed down
    
    * update tests
    
    * update for comment and test
    
    * fix common_subexpr_eliminate to retain predictable ordering between runs
---
 datafusion/core/tests/sql/predicates.rs            |  7 +-
 datafusion/core/tests/sql/select.rs                |  4 +-
 .../optimizer/src/common_subexpr_eliminate.rs      | 20 +++---
 datafusion/optimizer/src/simplify_expressions.rs   | 78 +++++++++++++++++++---
 4 files changed, 83 insertions(+), 26 deletions(-)

diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs
index 3c11b690d..32365090a 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -427,11 +427,12 @@ async fn multiple_or_predicates() -> Result<()> {
     let expected =vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #lineitem.l_partkey [l_partkey:Int64]",
-        "    Projection: #part.p_partkey = #lineitem.l_partkey AS 
#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, 
#lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size 
[#part.p_partkey = 
#lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, 
l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
-        "      Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand 
= Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND 
#lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size BETWEEN 
Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND 
#lineitem.l_quantity >= CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= 
CAST(Int64(20) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR 
#part.p_brand = Utf8(\"Brand#34\") AN [...]
+        "    Projection: #part.p_partkey = #lineitem.l_partkey AS 
#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, 
#part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, 
#lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size 
[#part.p_partkey = 
#lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, #part.p_size 
>= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Float64, 
p_brand:Utf8, p_size:Int32]",
+        "      Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand 
= Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND 
#lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size <= Int32(5) 
OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= 
CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(20) AS 
Float64) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") 
AND #lineitem.l_quantity >= CAST(Int64 [...]
         "        CrossJoin: [l_partkey:Int64, l_quantity:Float64, 
p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
         "          TableScan: lineitem projection=[l_partkey, l_quantity] 
[l_partkey:Int64, l_quantity:Float64]",
-        "          TableScan: part projection=[p_partkey, p_brand, p_size] 
[p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
+        "          Filter: #part.p_size >= Int32(1) [p_partkey:Int64, 
p_brand:Utf8, p_size:Int32]",
+        "            TableScan: part projection=[p_partkey, p_brand, p_size], 
partial_filters=[#part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8, 
p_size:Int32]",
     ];
     let formatted = plan.display_indent_schema().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs
index 06353167c..54d1b24e8 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -495,10 +495,10 @@ async fn use_between_expression_in_select_query() -> Result<()> {
         .unwrap()
         .to_string();
 
-    // Only test that the projection exprs arecorrect, rather than entire 
output
+    // Only test that the projection exprs are correct, rather than entire 
output
     let needle = "ProjectionExec: expr=[c1@0 >= 2 AND c1@0 <= 3 as test.c1 
BETWEEN Int64(2) AND Int64(3)]";
     assert_contains!(&formatted, needle);
-    let needle = "Projection: #test.c1 BETWEEN Int64(2) AND Int64(3)";
+    let needle = "Projection: #test.c1 >= Int64(2) AND #test.c1 <= Int64(3)";
     assert_contains!(&formatted, needle);
 
     Ok(())
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 239939f81..978b79d37 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -28,7 +28,7 @@ use datafusion_expr::{
     utils::from_plan,
     Expr, ExprSchemable,
 };
-use std::collections::{HashMap, HashSet};
+use std::collections::{BTreeSet, HashMap};
 use std::sync::Arc;
 
 /// A map from expression's identifier to tuple including
@@ -271,12 +271,12 @@ fn to_arrays(
 /// Build the "intermediate" projection plan that evaluates the extracted 
common expressions.
 fn build_project_plan(
     input: LogicalPlan,
-    affected_id: HashSet<Identifier>,
+    affected_id: BTreeSet<Identifier>,
     expr_set: &ExprSet,
 ) -> Result<LogicalPlan> {
     let mut project_exprs = vec![];
     let mut fields = vec![];
-    let mut fields_set = HashSet::new();
+    let mut fields_set = BTreeSet::new();
 
     for id in affected_id {
         match expr_set.get(&id) {
@@ -320,7 +320,7 @@ fn rewrite_expr(
     expr_set: &mut ExprSet,
     optimizer_config: &OptimizerConfig,
 ) -> Result<(Vec<Vec<Expr>>, LogicalPlan)> {
-    let mut affected_id = HashSet::<Identifier>::new();
+    let mut affected_id = BTreeSet::<Identifier>::new();
 
     let rewrote_exprs = exprs_list
         .iter()
@@ -482,7 +482,7 @@ struct CommonSubexprRewriter<'a> {
     expr_set: &'a mut ExprSet,
     id_array: &'a [(usize, Identifier)],
     /// Which identifier is replaced.
-    affected_id: &'a mut HashSet<Identifier>,
+    affected_id: &'a mut BTreeSet<Identifier>,
 
     /// the max series number we have rewritten. Other expression nodes
     /// with smaller series number is already replaced and shouldn't
@@ -561,7 +561,7 @@ fn replace_common_expr(
     expr: Expr,
     id_array: &[(usize, Identifier)],
     expr_set: &mut ExprSet,
-    affected_id: &mut HashSet<Identifier>,
+    affected_id: &mut BTreeSet<Identifier>,
 ) -> Result<Expr> {
     expr.rewrite(&mut CommonSubexprRewriter {
         expr_set,
@@ -752,7 +752,7 @@ mod test {
     #[test]
     fn redundant_project_fields() {
         let table_scan = test_table_scan().unwrap();
-        let affected_id: HashSet<Identifier> =
+        let affected_id: BTreeSet<Identifier> =
             ["c+a".to_string(), "d+a".to_string()].into_iter().collect();
         let expr_set = [
             ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
@@ -764,7 +764,7 @@ mod test {
             build_project_plan(table_scan, affected_id.clone(), 
&expr_set).unwrap();
         let project_2 = build_project_plan(project, affected_id, 
&expr_set).unwrap();
 
-        let mut field_set = HashSet::new();
+        let mut field_set = BTreeSet::new();
         for field in project_2.schema().fields() {
             assert!(field_set.insert(field.qualified_name()));
         }
@@ -779,7 +779,7 @@ mod test {
             .unwrap()
             .build()
             .unwrap();
-        let affected_id: HashSet<Identifier> =
+        let affected_id: BTreeSet<Identifier> =
             ["c+a".to_string(), "d+a".to_string()].into_iter().collect();
         let expr_set = [
             ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
@@ -790,7 +790,7 @@ mod test {
         let project = build_project_plan(join, affected_id.clone(), 
&expr_set).unwrap();
         let project_2 = build_project_plan(project, affected_id, 
&expr_set).unwrap();
 
-        let mut field_set = HashSet::new();
+        let mut field_set = BTreeSet::new();
         for field in project_2.schema().fields() {
             assert!(field_set.insert(field.qualified_name()));
         }
diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs
index d1afa3543..aa87c5318 100644
--- a/datafusion/optimizer/src/simplify_expressions.rs
+++ b/datafusion/optimizer/src/simplify_expressions.rs
@@ -164,8 +164,6 @@ fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool {
 
 /// returns the contained boolean value in `expr` as
 /// `Expr::Literal(ScalarValue::Boolean(v))`.
-///
-/// panics if expr is not a literal boolean
 fn as_bool_lit(expr: Expr) -> Result<Option<bool>> {
     match expr {
         Expr::Literal(ScalarValue::Boolean(v)) => Ok(v),
@@ -502,7 +500,7 @@ impl<'a> ConstEvaluator<'a> {
             ColumnarValue::Array(a) => {
                 if a.len() != 1 {
                     Err(DataFusionError::Execution(format!(
-                        "Could not evaluate the expressison, found a result of 
length {}",
+                        "Could not evaluate the expression, found a result of 
length {}",
                         a.len()
                     )))
                 } else {
@@ -803,6 +801,27 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
                 out_expr.rewrite(self)?
             }
 
+            //
+            // Rules for Between
+            //
+
+            // a between 3 and 5  -->  a >= 3 AND a <=5
+            // a not between 3 and 5  -->  a < 3 OR a > 5
+            Between {
+                expr,
+                low,
+                high,
+                negated,
+            } => {
+                if negated {
+                    let l = *expr.clone();
+                    let r = *expr;
+                    or(l.lt(*low), r.gt(*high))
+                } else {
+                    and(expr.clone().gt_eq(*low), expr.lt_eq(*high))
+                }
+            }
+
             expr => {
                 // no additional rewrites possible
                 expr
@@ -1555,8 +1574,13 @@ mod tests {
             high: Box::new(lit(10)),
         };
         let expr = expr.or(lit_bool_null());
-        let result = simplify(expr.clone());
-        assert_eq!(expr, result);
+        let result = simplify(expr);
+
+        let expected_expr = or(
+            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
+            lit_bool_null(),
+        );
+        assert_eq!(expected_expr, result);
     }
 
     #[test]
@@ -1579,8 +1603,8 @@ mod tests {
         assert_eq!(simplify(lit_bool_null().and(lit(false))), lit(false),);
 
         // c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL)
-        // it can be either NULL or FALSE depending on the value of `c1 
BETWEEN Int32(0) AND Int32(10`
-        // and should not be rewritten
+        // it can be either NULL or FALSE depending on the value of `c1 
BETWEEN Int32(0) AND Int32(10)`
+        // and the Boolean(NULL) should remain
         let expr = Expr::Between {
             expr: Box::new(col("c1")),
             negated: false,
@@ -1588,8 +1612,40 @@ mod tests {
             high: Box::new(lit(10)),
         };
         let expr = expr.and(lit_bool_null());
-        let result = simplify(expr.clone());
-        assert_eq!(expr, result);
+        let result = simplify(expr);
+
+        let expected_expr = and(
+            and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
+            lit_bool_null(),
+        );
+        assert_eq!(expected_expr, result);
+    }
+
+    #[test]
+    fn simplify_expr_between() {
+        // c2 between 3 and 4 is c2 >= 3 and c2 <= 4
+        let expr = Expr::Between {
+            expr: Box::new(col("c2")),
+            negated: false,
+            low: Box::new(lit(3)),
+            high: Box::new(lit(4)),
+        };
+        assert_eq!(
+            simplify(expr),
+            and(col("c2").gt_eq(lit(3)), col("c2").lt_eq(lit(4)))
+        );
+
+        // c2 not between 3 and 4 is c2 < 3 or c2 > 4
+        let expr = Expr::Between {
+            expr: Box::new(col("c2")),
+            negated: true,
+            low: Box::new(lit(3)),
+            high: Box::new(lit(4)),
+        };
+        assert_eq!(
+            simplify(expr),
+            or(col("c2").lt(lit(3)), col("c2").gt(lit(4)))
+        );
     }
 
     // ------------------------------
@@ -2167,7 +2223,7 @@ mod tests {
             .unwrap()
             .build()
             .unwrap();
-        let expected = "Filter: #test.d NOT BETWEEN Int32(1) AND Int32(10) AS 
NOT test.d BETWEEN Int32(1) AND Int32(10)\
+        let expected = "Filter: #test.d < Int32(1) OR #test.d > Int32(10) AS 
NOT test.d BETWEEN Int32(1) AND Int32(10)\
         \n  TableScan: test";
 
         assert_optimized_plan_eq(&plan, expected);
@@ -2188,7 +2244,7 @@ mod tests {
             .unwrap()
             .build()
             .unwrap();
-        let expected = "Filter: #test.d BETWEEN Int32(1) AND Int32(10) AS NOT 
test.d NOT BETWEEN Int32(1) AND Int32(10)\
+        let expected = "Filter: #test.d >= Int32(1) AND #test.d <= Int32(10) 
AS NOT test.d NOT BETWEEN Int32(1) AND Int32(10)\
         \n  TableScan: test";
 
         assert_optimized_plan_eq(&plan, expected);

Reply via email to