This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 558b3d6cbe Combine multiple `IN` lists in `ExprSimplifier` (#8949)
558b3d6cbe is described below

commit 558b3d6cbe0b2515540008871e09125f6b397fa0
Author: Jay Zhan <[email protected]>
AuthorDate: Wed Jan 24 04:21:42 2024 +0800

    Combine multiple `IN` lists in `ExprSimplifier` (#8949)
    
    * simplified always true or false expression
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fmt
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * clippy
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * a in and b not in cases
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix except
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * update comments
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * update comments
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix bugs
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix union
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * clippy
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add more test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rename
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 .../src/simplify_expressions/expr_simplifier.rs    | 117 +++++++++++++++++-
 .../src/simplify_expressions/inlist_simplifier.rs  | 136 +++++++++++++++++++++
 .../optimizer/src/simplify_expressions/mod.rs      |   1 +
 .../simplify_expressions/or_in_list_simplifier.rs  |  12 +-
 datafusion/sqllogictest/test_files/predicates.slt  |  37 ++++++
 5 files changed, 297 insertions(+), 6 deletions(-)

diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 674e85a55c..35450b1f32 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -19,8 +19,10 @@
 
 use std::ops::Not;
 
-use super::or_in_list_simplifier::OrInListSimplifier;
 use super::utils::*;
+use super::{
+    inlist_simplifier::InListSimplifier, 
or_in_list_simplifier::OrInListSimplifier,
+};
 use crate::analyzer::type_coercion::TypeCoercionRewriter;
 use crate::simplify_expressions::guarantees::GuaranteeRewriter;
 use crate::simplify_expressions::regex::simplify_regex_expr;
@@ -133,6 +135,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
         let mut simplifier = Simplifier::new(&self.info);
         let mut const_evaluator = 
ConstEvaluator::try_new(self.info.execution_props())?;
         let mut or_in_list_simplifier = OrInListSimplifier::new();
+        let mut inlist_simplifier = InListSimplifier::new();
         let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);
 
         // TODO iterate until no changes are made during rewrite
@@ -142,6 +145,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
         expr.rewrite(&mut const_evaluator)?
             .rewrite(&mut simplifier)?
             .rewrite(&mut or_in_list_simplifier)?
+            .rewrite(&mut inlist_simplifier)?
             .rewrite(&mut guarantee_rewriter)?
             // run both passes twice to try an minimize simplifications that 
we missed
             .rewrite(&mut const_evaluator)?
@@ -3201,11 +3205,118 @@ mod tests {
             col("c1").eq(subquery1).or(col("c1").eq(subquery2))
         );
 
-        // c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) ->
-        // c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8)
+        // 1. c1 IN (1,2,3,4) AND c1 IN (5,6,7,8) -> false
+        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], 
false).and(
+            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false),
+        );
+        assert_eq!(simplify(expr.clone()), lit(false));
+
+        // 2. c1 IN (1,2,3,4) AND c1 IN (4,5,6,7) -> c1 = 4
+        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], 
false).and(
+            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], false),
+        );
+        assert_eq!(simplify(expr.clone()), col("c1").eq(lit(4)));
+
+        // 3. c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) -> true
         let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], 
true).or(
             in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
         );
+        assert_eq!(simplify(expr.clone()), lit(true));
+
+        // 4. c1 NOT IN (1,2,3,4) AND c1 NOT IN (4,5,6,7) -> c1 NOT IN 
(1,2,3,4,5,6,7)
+        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], 
true).and(
+            in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true),
+        );
+        assert_eq!(
+            simplify(expr.clone()),
+            in_list(
+                col("c1"),
+                vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6), lit(7)],
+                true
+            )
+        );
+
+        // 5. c1 IN (1,2,3,4) OR c1 IN (2,3,4,5) -> c1 IN (1,2,3,4,5)
+        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], 
false).or(
+            in_list(col("c1"), vec![lit(2), lit(3), lit(4), lit(5)], false),
+        );
+        assert_eq!(
+            simplify(expr.clone()),
+            in_list(
+                col("c1"),
+                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
+                false
+            )
+        );
+
+        // 6. c1 IN (1,2,3) AND c1 NOT IN (1,2,3,4,5) -> false
+        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3)], 
false).and(in_list(
+            col("c1"),
+            vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
+            true,
+        ));
+        assert_eq!(simplify(expr.clone()), lit(false));
+
+        // 7. c1 NOT IN (1,2,3,4) AND c1 IN (1,2,3,4,5) -> c1 = 5
+        let expr =
+            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], 
true).and(in_list(
+                col("c1"),
+                vec![lit(1), lit(2), lit(3), lit(4), lit(5)],
+                false,
+            ));
+        assert_eq!(simplify(expr.clone()), col("c1").eq(lit(5)));
+
+        // 8. c1 IN (1,2,3,4) AND c1 NOT IN (5,6,7,8) -> c1 IN (1,2,3,4)
+        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], 
false).and(
+            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true),
+        );
+        assert_eq!(
+            simplify(expr.clone()),
+            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false)
+        );
+
+        // inlist with more than two expressions
+        // c1 IN (1,2,3,4,5,6) AND c1 IN (1,3,5,6) AND c1 IN (3,6) -> c1 = 3 
OR c1 = 6
+        let expr = in_list(
+            col("c1"),
+            vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6)],
+            false,
+        )
+        .and(in_list(
+            col("c1"),
+            vec![lit(1), lit(3), lit(5), lit(6)],
+            false,
+        ))
+        .and(in_list(col("c1"), vec![lit(3), lit(6)], false));
+        assert_eq!(
+            simplify(expr.clone()),
+            col("c1").eq(lit(3)).or(col("c1").eq(lit(6)))
+        );
+
+        // c1 NOT IN (1,2,3,4) AND c1 IN (5,6,7,8) AND c1 NOT IN (3,4,5,6) AND 
c1 IN (8,9,10) -> c1 = 8
+        let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], 
true).and(
+            in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false)
+                .and(in_list(
+                    col("c1"),
+                    vec![lit(3), lit(4), lit(5), lit(6)],
+                    true,
+                ))
+                .and(in_list(col("c1"), vec![lit(8), lit(9), lit(10)], false)),
+        );
+        assert_eq!(simplify(expr.clone()), col("c1").eq(lit(8)));
+
+        // Contains non-InList expression
+        // c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9) -> c1 NOT IN 
(1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9)
+        let expr =
+            in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], 
true).or(col("c1")
+                .not_eq(lit(5))
+                .or(in_list(
+                    col("c1"),
+                    vec![lit(6), lit(7), lit(8), lit(9)],
+                    true,
+                )));
+        // TODO: Further simplify this expression
+        // assert_eq!(simplify(expr.clone()), lit(true));
         assert_eq!(simplify(expr.clone()), expr);
     }
 
diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs
new file mode 100644
index 0000000000..fa95f1688e
--- /dev/null
+++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs
@@ -0,0 +1,136 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! This module implements a rule that simplifies the values for `InList`s
+
+use std::collections::HashSet;
+
+use datafusion_common::tree_node::TreeNodeRewriter;
+use datafusion_common::Result;
+use datafusion_expr::expr::InList;
+use datafusion_expr::{lit, BinaryExpr, Expr, Operator};
+
+/// Simplify expressions that are guaranteed to be true or false to a literal 
boolean expression
+///
+/// Rules:
+/// If both expressions are `IN` or `NOT IN`, then we can apply intersection 
or union on both lists
+///   Intersection:
+///     1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false`
+///     2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)`
+///     3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)`
+///   Union:
+///     4. `a not in (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)`
+///     # This rule is handled by `or_in_list_simplifier.rs`
+///     5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)`
+/// If one of the expressions is `IN` and the other is `NOT IN`, then we 
apply set difference (except) to the `IN` list
+///     6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false`
+///     7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5`
+///     8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)`
+pub(super) struct InListSimplifier {}
+
+impl InListSimplifier {
+    pub(super) fn new() -> Self {
+        Self {}
+    }
+}
+
+impl TreeNodeRewriter for InListSimplifier {
+    type N = Expr;
+
+    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+        if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr {
+            if let (Expr::InList(l1), Operator::And, Expr::InList(l2)) =
+                (left.as_ref(), op, right.as_ref())
+            {
+                if l1.expr == l2.expr && !l1.negated && !l2.negated {
+                    return inlist_intersection(l1, l2, false);
+                } else if l1.expr == l2.expr && l1.negated && l2.negated {
+                    return inlist_union(l1, l2, true);
+                } else if l1.expr == l2.expr && !l1.negated && l2.negated {
+                    return inlist_except(l1, l2);
+                } else if l1.expr == l2.expr && l1.negated && !l2.negated {
+                    return inlist_except(l2, l1);
+                }
+            } else if let (Expr::InList(l1), Operator::Or, Expr::InList(l2)) =
+                (left.as_ref(), op, right.as_ref())
+            {
+                if l1.expr == l2.expr && l1.negated && l2.negated {
+                    return inlist_intersection(l1, l2, true);
+                }
+            }
+        }
+
+        Ok(expr)
+    }
+}
+
+fn inlist_union(l1: &InList, l2: &InList, negated: bool) -> Result<Expr> {
+    let mut seen: HashSet<Expr> = HashSet::new();
+    let list = l1
+        .list
+        .iter()
+        .chain(l2.list.iter())
+        .filter(|&e| seen.insert(e.to_owned()))
+        .cloned()
+        .collect::<Vec<_>>();
+    let merged_inlist = InList {
+        expr: l1.expr.clone(),
+        list,
+        negated,
+    };
+    Ok(Expr::InList(merged_inlist))
+}
+
+fn inlist_intersection(l1: &InList, l2: &InList, negated: bool) -> 
Result<Expr> {
+    let l1_set: HashSet<Expr> = l1.list.iter().cloned().collect();
+    let intersect_list: Vec<Expr> = l2
+        .list
+        .iter()
+        .filter(|x| l1_set.contains(x))
+        .cloned()
+        .collect();
+    // e in () is always false
+    // e not in () is always true
+    if intersect_list.is_empty() {
+        return Ok(lit(negated));
+    }
+    let merged_inlist = InList {
+        expr: l1.expr.clone(),
+        list: intersect_list,
+        negated,
+    };
+    Ok(Expr::InList(merged_inlist))
+}
+
+fn inlist_except(l1: &InList, l2: &InList) -> Result<Expr> {
+    let l2_set: HashSet<Expr> = l2.list.iter().cloned().collect();
+    let except_list: Vec<Expr> = l1
+        .list
+        .iter()
+        .filter(|x| !l2_set.contains(x))
+        .cloned()
+        .collect();
+    if except_list.is_empty() {
+        return Ok(lit(false));
+    }
+    let merged_inlist = InList {
+        expr: l1.expr.clone(),
+        list: except_list,
+        negated: false,
+    };
+    Ok(Expr::InList(merged_inlist))
+}
diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs 
b/datafusion/optimizer/src/simplify_expressions/mod.rs
index 2cf6ed166c..44ba5b3e3b 100644
--- a/datafusion/optimizer/src/simplify_expressions/mod.rs
+++ b/datafusion/optimizer/src/simplify_expressions/mod.rs
@@ -18,6 +18,7 @@
 pub mod context;
 pub mod expr_simplifier;
 mod guarantees;
+mod inlist_simplifier;
 mod or_in_list_simplifier;
 mod regex;
 pub mod simplify_exprs;
diff --git 
a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs
index cebaaccc41..fd5c9ecaf8 100644
--- a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs
@@ -18,6 +18,7 @@
 //! This module implements a rule that simplifies OR expressions into IN list 
expressions
 
 use std::borrow::Cow;
+use std::collections::HashSet;
 
 use datafusion_common::tree_node::TreeNodeRewriter;
 use datafusion_common::Result;
@@ -52,9 +53,14 @@ impl TreeNodeRewriter for OrInListSimplifier {
                     {
                         let lhs = lhs.into_owned();
                         let rhs = rhs.into_owned();
-                        let mut list = vec![];
-                        list.extend(lhs.list);
-                        list.extend(rhs.list);
+                        let mut seen: HashSet<Expr> = HashSet::new();
+                        let list = lhs
+                            .list
+                            .into_iter()
+                            .chain(rhs.list)
+                            .filter(|e| seen.insert(e.to_owned()))
+                            .collect::<Vec<_>>();
+
                         let merged_inlist = InList {
                             expr: lhs.expr,
                             list,
diff --git a/datafusion/sqllogictest/test_files/predicates.slt 
b/datafusion/sqllogictest/test_files/predicates.slt
index e32e415338..b5347f997a 100644
--- a/datafusion/sqllogictest/test_files/predicates.slt
+++ b/datafusion/sqllogictest/test_files/predicates.slt
@@ -725,3 +725,40 @@ AggregateExec: mode=SinglePartitioned, gby=[p_partkey@2 as 
p_partkey], aggr=[SUM
 --------CoalesceBatchesExec: target_batch_size=8192
 ----------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), 
input_partitions=1
 ------------MemoryExec: partitions=1, partition_sizes=[1]
+
+# Inlist simplification
+
+statement ok
+create table t(x int) as values (1), (2), (3);
+
+query TT
+explain select x from t where x IN (1,2,3) AND x IN (4,5);
+----
+logical_plan EmptyRelation
+physical_plan EmptyExec
+
+query TT
+explain select x from t where x NOT IN (1,2,3,4) OR x NOT IN (5,6,7,8);
+----
+logical_plan TableScan: t projection=[x]
+physical_plan MemoryExec: partitions=1, partition_sizes=[1]
+
+query TT
+explain select x from t where x IN (1,2,3,4,5) AND x NOT IN (1,2,3,4);
+----
+logical_plan
+Filter: t.x = Int32(5)
+--TableScan: t projection=[x]
+physical_plan
+CoalesceBatchesExec: target_batch_size=8192
+--FilterExec: x@0 = 5
+----MemoryExec: partitions=1, partition_sizes=[1]
+
+query TT
+explain select x from t where x NOT IN (1,2,3,4,5) AND x IN (1,2,3);
+----
+logical_plan EmptyRelation
+physical_plan EmptyExec
+
+statement ok
+drop table t;

Reply via email to