This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 7fab5ac53c Move inlist rule to expr_simplifier (#9692)
7fab5ac53c is described below

commit 7fab5ac53c1e715743aee7a51111c2976add8a99
Author: Jay Zhan <[email protected]>
AuthorDate: Wed Mar 20 00:58:10 2024 +0800

    Move inlist rule to expr_simplifier (#9692)
    
    * move inlist rule to expr_simplifier
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * clippy
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 .../src/simplify_expressions/expr_simplifier.rs    | 220 +++++++++++++++++++--
 .../src/simplify_expressions/inlist_simplifier.rs  | 122 +-----------
 datafusion/sqllogictest/test_files/predicates.slt  |   2 +-
 3 files changed, 210 insertions(+), 134 deletions(-)

diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 5b5bca75dd..61e002ece9 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -21,7 +21,7 @@ use std::borrow::Cow;
 use std::collections::HashSet;
 use std::ops::Not;
 
-use super::inlist_simplifier::{InListSimplifier, ShortenInListSimplifier};
+use super::inlist_simplifier::ShortenInListSimplifier;
 use super::utils::*;
 use crate::analyzer::type_coercion::TypeCoercionRewriter;
 use crate::simplify_expressions::guarantees::GuaranteeRewriter;
@@ -175,7 +175,6 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
         let mut simplifier = Simplifier::new(&self.info);
         let mut const_evaluator = 
ConstEvaluator::try_new(self.info.execution_props())?;
         let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
-        let mut inlist_simplifier = InListSimplifier::new();
         let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);
 
         if self.canonicalize {
@@ -190,8 +189,6 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
             .data()?
             .rewrite(&mut simplifier)
             .data()?
-            .rewrite(&mut inlist_simplifier)
-            .data()?
             .rewrite(&mut guarantee_rewriter)
             .data()?
             // run both passes twice to try an minimize simplifications that 
we missed
@@ -1452,13 +1449,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for 
Simplifier<'a, S> {
                 op: Operator::Or,
                 right,
             }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => {
-                let left = as_inlist(left.as_ref());
-                let right = as_inlist(right.as_ref());
-
-                let lhs = left.unwrap();
-                let rhs = right.unwrap();
-                let lhs = lhs.into_owned();
-                let rhs = rhs.into_owned();
+                let lhs = to_inlist(*left).unwrap();
+                let rhs = to_inlist(*right).unwrap();
                 let mut seen: HashSet<Expr> = HashSet::new();
                 let list = lhs
                     .list
@@ -1473,7 +1465,123 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for 
Simplifier<'a, S> {
                     negated: false,
                 };
 
-                return Ok(Transformed::yes(Expr::InList(merged_inlist)));
+                Transformed::yes(Expr::InList(merged_inlist))
+            }
+
+            // Simplify expressions that are guaranteed to be true or false to a literal boolean expression
+            //
+            // Rules:
+            // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists
+            //   Intersection:
+            //     1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false`
+            //     2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)`
+            //     3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)`
+            //   Union:
+            //     4. `a not in (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)`
+            //     # This rule is handled by the `Operator::Or` arm above (guarded by `are_inlist_and_eq`)
+            //     5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)`
+            // If one of the expressions is `IN` and the other is `NOT IN`, then we take the difference (except) of the `IN` list minus the `NOT IN` list
+            //     6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false`
+            //     7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5`
+            //     8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)`
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::And,
+                right,
+            }) if are_inlist_and_eq_and_match_neg(
+                left.as_ref(),
+                right.as_ref(),
+                false,
+                false,
+            ) =>
+            {
+                match (*left, *right) {
+                    (Expr::InList(l1), Expr::InList(l2)) => {
+                        return inlist_intersection(l1, l2, 
false).map(Transformed::yes);
+                    }
+                    // Matched previously once
+                    _ => unreachable!(),
+                }
+            }
+
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::And,
+                right,
+            }) if are_inlist_and_eq_and_match_neg(
+                left.as_ref(),
+                right.as_ref(),
+                true,
+                true,
+            ) =>
+            {
+                match (*left, *right) {
+                    (Expr::InList(l1), Expr::InList(l2)) => {
+                        return inlist_union(l1, l2, 
true).map(Transformed::yes);
+                    }
+                    // Matched previously once
+                    _ => unreachable!(),
+                }
+            }
+
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::And,
+                right,
+            }) if are_inlist_and_eq_and_match_neg(
+                left.as_ref(),
+                right.as_ref(),
+                false,
+                true,
+            ) =>
+            {
+                match (*left, *right) {
+                    (Expr::InList(l1), Expr::InList(l2)) => {
+                        return inlist_except(l1, l2).map(Transformed::yes);
+                    }
+                    // Matched previously once
+                    _ => unreachable!(),
+                }
+            }
+
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::And,
+                right,
+            }) if are_inlist_and_eq_and_match_neg(
+                left.as_ref(),
+                right.as_ref(),
+                true,
+                false,
+            ) =>
+            {
+                match (*left, *right) {
+                    (Expr::InList(l1), Expr::InList(l2)) => {
+                        return inlist_except(l2, l1).map(Transformed::yes);
+                    }
+                    // Matched previously once
+                    _ => unreachable!(),
+                }
+            }
+
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::Or,
+                right,
+            }) if are_inlist_and_eq_and_match_neg(
+                left.as_ref(),
+                right.as_ref(),
+                true,
+                true,
+            ) =>
+            {
+                match (*left, *right) {
+                    (Expr::InList(l1), Expr::InList(l2)) => {
+                        return inlist_intersection(l1, l2, 
true).map(Transformed::yes);
+                    }
+                    // Matched previously once
+                    _ => unreachable!(),
+                }
             }
 
             // no additional rewrites possible
@@ -1482,6 +1590,22 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for 
Simplifier<'a, S> {
     }
 }
 
+// TODO: We might not need this after deref patterns for Box are stabilized. https://github.com/rust-lang/rust/issues/87121
+fn are_inlist_and_eq_and_match_neg(
+    left: &Expr,
+    right: &Expr,
+    is_left_neg: bool,
+    is_right_neg: bool,
+) -> bool {
+    match (left, right) {
+        (Expr::InList(l), Expr::InList(r)) => {
+            l.expr == r.expr && l.negated == is_left_neg && r.negated == 
is_right_neg
+        }
+        _ => false,
+    }
+}
+
+// TODO: We might not need this after deref patterns for Box are stabilized. https://github.com/rust-lang/rust/issues/87121
 fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool {
     let left = as_inlist(left);
     let right = as_inlist(right);
@@ -1519,6 +1643,78 @@ fn as_inlist(expr: &Expr) -> Option<Cow<InList>> {
     }
 }
 
+fn to_inlist(expr: Expr) -> Option<InList> {
+    match expr {
+        Expr::InList(inlist) => Some(inlist),
+        Expr::BinaryExpr(BinaryExpr {
+            left,
+            op: Operator::Eq,
+            right,
+        }) => match (left.as_ref(), right.as_ref()) {
+            (Expr::Column(_), Expr::Literal(_)) => Some(InList {
+                expr: left,
+                list: vec![*right],
+                negated: false,
+            }),
+            (Expr::Literal(_), Expr::Column(_)) => Some(InList {
+                expr: right,
+                list: vec![*left],
+                negated: false,
+            }),
+            _ => None,
+        },
+        _ => None,
+    }
+}
+
+/// Return the union of two inlist expressions
+/// maintaining the order of the elements in the two lists
+fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result<Expr> {
+    // extend the list in l1 with the elements in l2 that are not already in l1
+    let l1_items: HashSet<_> = l1.list.iter().collect();
+
+    // keep all l2 items that do not also appear in l1
+    let keep_l2: Vec<_> = l2
+        .list
+        .into_iter()
+        .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) })
+        .collect();
+
+    l1.list.extend(keep_l2);
+    l1.negated = negated;
+    Ok(Expr::InList(l1))
+}
+
+/// Return the intersection of two inlist expressions
+/// maintaining the order of the elements in the two lists
+fn inlist_intersection(mut l1: InList, l2: InList, negated: bool) -> 
Result<Expr> {
+    let l2_items = l2.list.iter().collect::<HashSet<_>>();
+
+    // remove all items from l1 that are not in l2
+    l1.list.retain(|e| l2_items.contains(e));
+
+    // e in () is always false
+    // e not in () is always true
+    if l1.list.is_empty() {
+        return Ok(lit(negated));
+    }
+    Ok(Expr::InList(l1))
+}
+
+/// Return the all items in l1 that are not in l2
+/// maintaining the order of the elements in the two lists
+fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
+    let l2_items = l2.list.iter().collect::<HashSet<_>>();
+
+    // keep only items from l1 that are not in l2
+    l1.list.retain(|e| !l2_items.contains(e));
+
+    if l1.list.is_empty() {
+        return Ok(lit(false));
+    }
+    Ok(Expr::InList(l1))
+}
+
 #[cfg(test)]
 mod tests {
     use std::{
diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs
index 5d1cf27827..9dcb8ed155 100644
--- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs
@@ -19,12 +19,10 @@
 
 use super::THRESHOLD_INLINE_INLIST;
 
-use std::collections::HashSet;
-
 use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
 use datafusion_common::Result;
 use datafusion_expr::expr::InList;
-use datafusion_expr::{lit, BinaryExpr, Expr, Operator};
+use datafusion_expr::Expr;
 
 pub(super) struct ShortenInListSimplifier {}
 
@@ -97,121 +95,3 @@ impl TreeNodeRewriter for ShortenInListSimplifier {
         Ok(Transformed::no(expr))
     }
 }
-
-pub(super) struct InListSimplifier {}
-
-impl InListSimplifier {
-    pub(super) fn new() -> Self {
-        Self {}
-    }
-}
-
-impl TreeNodeRewriter for InListSimplifier {
-    type Node = Expr;
-
-    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
-        // Simplify expressions that is guaranteed to be true or false to a 
literal boolean expression
-        //
-        // Rules:
-        // If both expressions are `IN` or `NOT IN`, then we can apply 
intersection or union on both lists
-        //   Intersection:
-        //     1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false`
-        //     2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)`
-        //     3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)`
-        //   Union:
-        //     4. `a not int (1,2,3) AND a not in (4,5,6) -> a not in 
(1,2,3,4,5,6)`
-        //     # This rule is handled by `or_in_list_simplifier.rs`
-        //     5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)`
-        // If one of the expressions is `IN` and another one is `NOT IN`, then 
we apply exception on `In` expression
-        //     6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which 
is false`
-        //     7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5`
-        //     8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)`
-        if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr.clone() 
{
-            match (*left, op, *right) {
-                (Expr::InList(l1), Operator::And, Expr::InList(l2))
-                    if l1.expr == l2.expr && !l1.negated && !l2.negated =>
-                {
-                    return inlist_intersection(l1, l2, 
false).map(Transformed::yes);
-                }
-                (Expr::InList(l1), Operator::And, Expr::InList(l2))
-                    if l1.expr == l2.expr && l1.negated && l2.negated =>
-                {
-                    return inlist_union(l1, l2, true).map(Transformed::yes);
-                }
-                (Expr::InList(l1), Operator::And, Expr::InList(l2))
-                    if l1.expr == l2.expr && !l1.negated && l2.negated =>
-                {
-                    return inlist_except(l1, l2).map(Transformed::yes);
-                }
-                (Expr::InList(l1), Operator::And, Expr::InList(l2))
-                    if l1.expr == l2.expr && l1.negated && !l2.negated =>
-                {
-                    return inlist_except(l2, l1).map(Transformed::yes);
-                }
-                (Expr::InList(l1), Operator::Or, Expr::InList(l2))
-                    if l1.expr == l2.expr && l1.negated && l2.negated =>
-                {
-                    return inlist_intersection(l1, l2, 
true).map(Transformed::yes);
-                }
-                (left, op, right) => {
-                    // put the expression back together
-                    return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr {
-                        left: Box::new(left),
-                        op,
-                        right: Box::new(right),
-                    })));
-                }
-            }
-        }
-
-        Ok(Transformed::no(expr))
-    }
-}
-
-/// Return the union of two inlist expressions
-/// maintaining the order of the elements in the two lists
-fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result<Expr> {
-    // extend the list in l1 with the elements in l2 that are not already in l1
-    let l1_items: HashSet<_> = l1.list.iter().collect();
-
-    // keep all l2 items that do not also appear in l1
-    let keep_l2: Vec<_> = l2
-        .list
-        .into_iter()
-        .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) })
-        .collect();
-
-    l1.list.extend(keep_l2);
-    l1.negated = negated;
-    Ok(Expr::InList(l1))
-}
-
-/// Return the intersection of two inlist expressions
-/// maintaining the order of the elements in the two lists
-fn inlist_intersection(mut l1: InList, l2: InList, negated: bool) -> 
Result<Expr> {
-    let l2_items = l2.list.iter().collect::<HashSet<_>>();
-
-    // remove all items from l1 that are not in l2
-    l1.list.retain(|e| l2_items.contains(e));
-
-    // e in () is always false
-    // e not in () is always true
-    if l1.list.is_empty() {
-        return Ok(lit(negated));
-    }
-    Ok(Expr::InList(l1))
-}
-
-/// Return the all items in l1 that are not in l2
-/// maintaining the order of the elements in the two lists
-fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
-    let l2_items = l2.list.iter().collect::<HashSet<_>>();
-
-    // keep only items from l1 that are not in l2
-    l1.list.retain(|e| !l2_items.contains(e));
-
-    if l1.list.is_empty() {
-        return Ok(lit(false));
-    }
-    Ok(Expr::InList(l1))
-}
diff --git a/datafusion/sqllogictest/test_files/predicates.slt 
b/datafusion/sqllogictest/test_files/predicates.slt
index 4c9254beef..33c9ff7c3e 100644
--- a/datafusion/sqllogictest/test_files/predicates.slt
+++ b/datafusion/sqllogictest/test_files/predicates.slt
@@ -781,4 +781,4 @@ logical_plan EmptyRelation
 physical_plan EmptyExec
 
 statement ok
-drop table t;
+drop table t;
\ No newline at end of file

Reply via email to