alamb commented on code in PR #6414:
URL: https://github.com/apache/arrow-datafusion/pull/6414#discussion_r1210293125


##########
benchmarks/expected-plans/q19.txt:
##########
@@ -6,7 +6,7 @@
 |               |     Projection: lineitem.l_extendedprice, 
lineitem.l_discount                                                             
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                         
                                                |
 |               |       Inner Join: lineitem.l_partkey = part.p_partkey 
Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM 
CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND 
lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = 
Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), 
Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= 
Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = 
Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), 
Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= 
Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)                        
                                                                                
               
                                                |
 |               |         Projection: lineitem.l_partkey, lineitem.l_quantity, 
lineitem.l_extendedprice, lineitem.l_discount                                   
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                      
                                                |
-|               |           Filter: (lineitem.l_quantity >= 
Decimal128(Some(100),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= 
Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= 
Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR REG") OR 
lineitem.l_shipmode = Utf8("AIR")) AND lineitem.l_shipinstruct = Utf8("DELIVER 
IN PERSON")                                                                     
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                               
                                                |
+|               |           Filter: (lineitem.l_quantity >= 
Decimal128(Some(100),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= 
Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= 
Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR") OR 
lineitem.l_shipmode = Utf8("AIR REG")) AND lineitem.l_shipinstruct = 
Utf8("DELIVER IN PERSON")                                                       
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                             
                                                |

Review Comment:
   These differences also appear to be only a change in the order of some of 
the OR predicates, which is fine as described above



##########
benchmarks/src/bin/tpch.rs:
##########
@@ -549,6 +555,55 @@ mod tests {
         Ok(())
     }
 
+    struct Line(Option<usize>);

Review Comment:
   I know this is a personal preference, but if you want to introduce a new 
crate for comparing / diffing, I recommend adding https://insta.rs/ 
   1. It includes automatic updating (`cargo insta review`)
   2. It comes with its own diff generation tool
   3. We have used it with good results in influxdb_iox



##########
benchmarks/src/bin/tpch.rs:
##########
@@ -549,6 +555,55 @@ mod tests {
         Ok(())
     }
 
+    struct Line(Option<usize>);
+
+    impl fmt::Display for Line {
+        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+            match self.0 {
+                None => write!(f, "    "),
+                Some(idx) => write!(f, "{:<4}", idx + 1),
+            }
+        }
+    }
+
+    fn assert_text_eq(expected: String, actual: String) {

Review Comment:
   In general, I think it would help a lot to remove tests from the benchmarks 
-- @liurenjie1024 is making good progress in this area with 
https://github.com/apache/arrow-datafusion/pull/6435
   
   Maybe we could avoid adding this extra comparison to the `tpch` binary and 
instead work on getting the tests out of the benchmark runner in the first 
place 🤔 



##########
datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs:
##########
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! This module implements a rule that simplifies OR expressions into IN list 
expressions
+
+use datafusion_common::tree_node::TreeNodeRewriter;
+use datafusion_common::Result;
+use datafusion_expr::expr::InList;
+use datafusion_expr::{BinaryExpr, Expr, Operator};
+
+/// Combine multiple OR expressions into a single IN list expression if 
possible
+///
+/// i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)`
+pub(super) struct OrInListSimplifier {}
+
+impl OrInListSimplifier {
+    pub(super) fn new() -> Self {
+        Self {}
+    }
+}
+
+impl TreeNodeRewriter for OrInListSimplifier {
+    type N = Expr;
+
+    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+        if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr {
+            if *op == Operator::Or {
+                let left = as_inlist(left);
+                let right = as_inlist(right);
+                if let (Some(lhs), Some(rhs)) = (left, right) {
+                    if lhs.expr.try_into_col().is_ok()
+                        && rhs.expr.try_into_col().is_ok()
+                        && lhs.expr == rhs.expr
+                        && !lhs.negated
+                        && !rhs.negated
+                    {
+                        let mut list = vec![];
+                        list.extend(lhs.list);
+                        list.extend(rhs.list);
+                        let merged_inlist = InList {
+                            expr: lhs.expr,
+                            list,
+                            negated: false,
+                        };
+                        return Ok(Expr::InList(merged_inlist));
+                    }
+                }
+            }
+        }
+
+        Ok(expr)
+    }
+}
+
+/// Try to convert an expression to an in-list expression
+fn as_inlist(expr: &Expr) -> Option<InList> {
+    match expr {
+        Expr::InList(inlist) => Some(inlist.clone()),
+        Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == 
Operator::Eq => {
+            let unboxed_left = *left.clone();
+            let unboxed_right = *right.clone();
+            match (&unboxed_left, &unboxed_right) {

Review Comment:
   You can avoid these clones by doing something like:
   
   ```
           Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == 
Operator::Eq => {
               match (left.as_ref(), right.as_ref()) {
   ```
   
   I tried it out locally and I think the following diff results in fewer 
clones:
   
   ```diff
   diff --git 
a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs
   index 10f3aa027..9192fbb77 100644
   --- a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs
   +++ b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs
   @@ -17,6 +17,8 @@
    
    //! This module implements a rule that simplifies OR expressions into IN 
list expressions
    
   +use std::borrow::Cow;
   +
    use datafusion_common::tree_node::TreeNodeRewriter;
    use datafusion_common::Result;
    use datafusion_expr::expr::InList;
   @@ -48,6 +50,8 @@ impl TreeNodeRewriter for OrInListSimplifier {
                            && !lhs.negated
                            && !rhs.negated
                        {
   +                        let lhs = lhs.into_owned();
   +                        let rhs = rhs.into_owned();
                            let mut list = vec![];
                            list.extend(lhs.list);
                            list.extend(rhs.list);
   @@ -67,23 +71,21 @@ impl TreeNodeRewriter for OrInListSimplifier {
    }
    
    /// Try to convert an expression to an in-list expression
   -fn as_inlist(expr: &Expr) -> Option<InList> {
   +fn as_inlist(expr: &Expr) -> Option<Cow<InList>> {
        match expr {
   -        Expr::InList(inlist) => Some(inlist.clone()),
   +        Expr::InList(inlist) => Some(Cow::Borrowed(&inlist)),
            Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == 
Operator::Eq => {
   -            let unboxed_left = *left.clone();
   -            let unboxed_right = *right.clone();
   -            match (&unboxed_left, &unboxed_right) {
   -                (Expr::Column(_), Expr::Literal(_)) => Some(InList {
   +            match (left.as_ref(), right.as_ref()) {
   +                (Expr::Column(_), Expr::Literal(_)) => 
Some(Cow::Owned(InList {
                        expr: left.clone(),
   -                    list: vec![unboxed_right],
   +                    list: vec![*right.clone()],
                        negated: false,
   -                }),
   -                (Expr::Literal(_), Expr::Column(_)) => Some(InList {
   +                })),
   +                (Expr::Literal(_), Expr::Column(_)) => 
Some(Cow::Owned(InList {
                        expr: right.clone(),
   -                    list: vec![unboxed_left],
   +                    list: vec![*left.clone()],
                        negated: false,
   -                }),
   +                })),
                    _ => None,
                }
            }
   ```



##########
datafusion/core/tests/sqllogictests/test_files/predicates.slt:
##########
@@ -249,6 +249,76 @@ SELECT * FROM test WHERE column1 IN ('foo', 'Bar', 'fazzz')
 foo
 fazzz
 
+
+###
+# Test logical plan simplifies large OR chains
+###
+
+statement ok
+set datafusion.explain.logical_plan_only = true
+
+# Number of OR statements is less than or equal to threshold
+query TT
+EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR column1 = 'bar' OR column1 
= 'fazzz'
+----
+logical_plan
+Filter: test.column1 = Utf8("foo") OR test.column1 = Utf8("bar") OR 
test.column1 = Utf8("fazzz")
+--TableScan: test projection=[column1]
+
+# Number of OR statements is greater than threshold
+query TT
+EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR column1 = 'bar' OR column1 
= 'fazzz' OR column1 = 'barfoo'
+----
+logical_plan
+Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), 
Utf8("barfoo")])
+--TableScan: test projection=[column1]
+
+# Complex OR statements
+query TT
+EXPLAIN SELECT * FROM test WHERE column1 = 'foo' OR column1 = 'bar' OR column1 
= 'fazzz' OR column1 = 'barfoo' OR false OR column1 = 'foobar'
+----
+logical_plan
+Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), 
Utf8("barfoo"), Utf8("foobar")])
+--TableScan: test projection=[column1]
+
+# Balanced OR structures
+query TT
+EXPLAIN SELECT * FROM test WHERE (column1 = 'foo' OR column1 = 'bar') OR 
(column1 = 'fazzz' OR column1 = 'barfoo')
+----
+logical_plan
+Filter: test.column1 IN ([Utf8("foo"), Utf8("bar"), Utf8("fazzz"), 
Utf8("barfoo")])
+--TableScan: test projection=[column1]
+
+# Right-deep OR structures

Review Comment:
   👍 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to