This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new bc0ba6a724 Enhance simplifier by adding Canonicalize (#8780)
bc0ba6a724 is described below

commit bc0ba6a724aaaf312b770451718cbce696de6640
Author: Junhao Liu <[email protected]>
AuthorDate: Wed Jan 24 06:39:53 2024 -0600

    Enhance simplifier by adding Canonicalize (#8780)
    
    * Enhance simplifier by adding Canonicalize
    
    * fix swap operation
    
    * Using Match to match cases
    
    * feat: refactor code using better match and revise doc
    
    * Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * add more robust test case and code format more Rustic
    
    * Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
    
    Co-authored-by: Jeffrey Vo <[email protected]>
    
    * Make Join unrelated to order
    
    * fmt doc
    
    * Check Join for both sides
    
    * No canonicalize for Join
    
    * cargo fmt
    
    * Add comment for dup codes
    
    * Fix test cases
    
    * remove wrong git commit
    
    * Update datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
    
    Co-authored-by: Jeffrey Vo <[email protected]>
    
    * remove outdated change
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
    Co-authored-by: Jeffrey Vo <[email protected]>
---
 .../src/simplify_expressions/expr_simplifier.rs    | 104 ++++++++++++++++++++-
 .../src/simplify_expressions/simplify_exprs.rs     |  38 ++++++--
 datafusion/sqllogictest/test_files/joins.slt       |   6 +-
 datafusion/sqllogictest/test_files/subquery.slt    |   2 +-
 .../sqllogictest/test_files/tpch/q12.slt.part      |   6 +-
 .../sqllogictest/test_files/tpch/q4.slt.part       |   6 +-
 6 files changed, 141 insertions(+), 21 deletions(-)

diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 95536e9fc5..561fe1d12d 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -151,6 +151,10 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
             .rewrite(&mut simplifier)
     }
 
+    pub fn canonicalize(&self, expr: Expr) -> Result<Expr> {
+        let mut canonicalizer = Canonicalizer::new();
+        expr.rewrite(&mut canonicalizer)
+    }
     /// Apply type coercion to an [`Expr`] so that it can be
     /// evaluated as a [`PhysicalExpr`](datafusion_physical_expr::PhysicalExpr).
     ///
@@ -227,6 +231,51 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
     }
 }
 
+/// Canonicalize any BinaryExprs that are not in canonical form
+///
+/// `<literal> <op> <col>` is rewritten to `<col> <op> <literal>`
+///
+/// `<col1> <op> <col2>` is rewritten so that the name of `col1` sorts higher
+/// than `col2` (`a > b` would be canonicalized to `b < a`)
+struct Canonicalizer {}
+
+impl Canonicalizer {
+    fn new() -> Self {
+        Self {}
+    }
+}
+
+impl TreeNodeRewriter for Canonicalizer {
+    type N = Expr;
+
+    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+        let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
+            return Ok(expr);
+        };
+        match (left.as_ref(), right.as_ref(), op.swap()) {
+            // <col1> <op> <col2>
+            (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op))
+                if right_col > left_col =>
+            {
+                Ok(Expr::BinaryExpr(BinaryExpr {
+                    left: right,
+                    op: swapped_op,
+                    right: left,
+                }))
+            }
+            // <literal> <op> <col>
+            (Expr::Literal(_a), Expr::Column(_b), Some(swapped_op)) => {
+                Ok(Expr::BinaryExpr(BinaryExpr {
+                    left: right,
+                    op: swapped_op,
+                    right: left,
+                }))
+            }
+            _ => Ok(Expr::BinaryExpr(BinaryExpr { left, op, right })),
+        }
+    }
+}
+
 #[allow(rustdoc::private_intra_doc_links)]
 /// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time.
 ///
@@ -1612,6 +1661,58 @@ mod tests {
     // --- Simplifier tests -----
     // ------------------------------
 
+    #[test]
+    fn test_simplify_canonicalize() {
+        {
+            let expr = lit(1).lt(col("c2")).and(col("c2").gt(lit(1)));
+            let expected = col("c2").gt(lit(1));
+            assert_eq!(simplify(expr), expected);
+        }
+        {
+            let expr = col("c1").lt(col("c2")).and(col("c2").gt(col("c1")));
+            let expected = col("c2").gt(col("c1"));
+            assert_eq!(simplify(expr), expected);
+        }
+        {
+            let expr = col("c1")
+                .eq(lit(1))
+                .and(lit(1).eq(col("c1")))
+                .and(col("c1").eq(lit(3)));
+            let expected = col("c1").eq(lit(1)).and(col("c1").eq(lit(3)));
+            assert_eq!(simplify(expr), expected);
+        }
+        {
+            let expr = col("c1")
+                .eq(col("c2"))
+                .and(col("c1").gt(lit(5)))
+                .and(col("c2").eq(col("c1")));
+            let expected = col("c2").eq(col("c1")).and(col("c1").gt(lit(5)));
+            assert_eq!(simplify(expr), expected);
+        }
+        {
+            let expr = col("c1")
+                .eq(lit(1))
+                .and(col("c2").gt(lit(3)).or(lit(3).lt(col("c2"))));
+            let expected = col("c1").eq(lit(1)).and(col("c2").gt(lit(3)));
+            assert_eq!(simplify(expr), expected);
+        }
+        {
+            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
+            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
+            assert_eq!(simplify(expr), expected);
+        }
+        {
+            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
+            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
+            assert_eq!(simplify(expr), expected);
+        }
+        {
+            let expr = col("c1").gt(col("c2")).and(col("c1").gt(col("c2")));
+            let expected = col("c2").lt(col("c1"));
+            assert_eq!(simplify(expr), expected);
+        }
+    }
+
     #[test]
     fn test_simplify_or_true() {
         let expr_a = col("c2").or(lit(true));
@@ -2807,7 +2908,8 @@ mod tests {
         let simplifier = ExprSimplifier::new(
             SimplifyContext::new(&execution_props).with_schema(schema),
         );
-        simplifier.simplify(expr)
+        let cano = simplifier.canonicalize(expr)?;
+        simplifier.simplify(cano)
     }
 
     fn simplify(expr: Expr) -> Expr {
diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs 
b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
index cfd02547b8..7265b17dd0 100644
--- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
+++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
@@ -93,16 +93,34 @@ impl SimplifyExpressions {
             .map(|input| Self::optimize_internal(input, execution_props))
             .collect::<Result<Vec<_>>>()?;
 
-        let expr = plan
-            .expressions()
-            .into_iter()
-            .map(|e| {
-                // TODO: unify with `rewrite_preserving_name`
-                let original_name = e.name_for_alias()?;
-                let new_e = simplifier.simplify(e)?;
-                new_e.alias_if_changed(original_name)
-            })
-            .collect::<Result<Vec<_>>>()?;
+        let expr = match plan {
+            // Canonicalize step won't reorder expressions in a Join on clause.
+            // The left and right expressions in a Join on clause are not commutative,
+            // since the order of the columns must match the order of the children.
+            LogicalPlan::Join(_) => {
+                plan.expressions()
+                    .into_iter()
+                    .map(|e| {
+                        // TODO: unify with `rewrite_preserving_name`
+                        let original_name = e.name_for_alias()?;
+                        let new_e = simplifier.simplify(e)?;
+                        new_e.alias_if_changed(original_name)
+                    })
+                    .collect::<Result<Vec<_>>>()?
+            }
+            _ => {
+                plan.expressions()
+                    .into_iter()
+                    .map(|e| {
+                        // TODO: unify with `rewrite_preserving_name`
+                        let original_name = e.name_for_alias()?;
+                        let cano_e = simplifier.canonicalize(e)?;
+                        let new_e = simplifier.simplify(cano_e)?;
+                        new_e.alias_if_changed(original_name)
+                    })
+                    .collect::<Result<Vec<_>>>()?
+            }
+        };
 
         plan.with_new_exprs(expr, &new_inputs)
     }
diff --git a/datafusion/sqllogictest/test_files/joins.slt 
b/datafusion/sqllogictest/test_files/joins.slt
index a7146a5a91..e605813b20 100644
--- a/datafusion/sqllogictest/test_files/joins.slt
+++ b/datafusion/sqllogictest/test_files/joins.slt
@@ -1109,7 +1109,7 @@ RIGHT JOIN join_t2 on join_t1.t1_id = join_t2.t2_id
 WHERE NOT (join_t1.t1_int = join_t2.t2_int)
 ----
 logical_plan
-Inner Join: join_t1.t1_id = join_t2.t2_id Filter: join_t1.t1_int != 
join_t2.t2_int
+Inner Join: join_t1.t1_id = join_t2.t2_id Filter: join_t2.t2_int != 
join_t1.t1_int
 --TableScan: join_t1 projection=[t1_id, t1_name, t1_int]
 --TableScan: join_t2 projection=[t2_id, t2_name, t2_int]
 
@@ -3472,13 +3472,13 @@ FROM annotated_data as l, annotated_data as r
 WHERE l.a > r.a
 ----
 logical_plan
-Inner Join:  Filter: l.a > r.a
+Inner Join:  Filter: r.a < l.a
 --SubqueryAlias: l
 ----TableScan: annotated_data projection=[a0, a, b, c, d]
 --SubqueryAlias: r
 ----TableScan: annotated_data projection=[a0, a, b, c, d]
 physical_plan
-NestedLoopJoinExec: join_type=Inner, filter=a@0 > a@1
+NestedLoopJoinExec: join_type=Inner, filter=a@1 < a@0
 --RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
 ----CsvExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, 
b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], 
has_header=true
 --CsvExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, 
b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], 
has_header=true
diff --git a/datafusion/sqllogictest/test_files/subquery.slt 
b/datafusion/sqllogictest/test_files/subquery.slt
index 3e0fcb7aa9..1ca9045f1b 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -396,7 +396,7 @@ LeftSemi Join: t1.t1_id = __correlated_sq_1.t1_int
 --TableScan: t1 projection=[t1_id, t1_name, t1_int]
 --SubqueryAlias: __correlated_sq_1
 ----Projection: t1.t1_int
-------Filter: t1.t1_id > t1.t1_int
+------Filter: t1.t1_int < t1.t1_id
 --------TableScan: t1 projection=[t1_id, t1_int]
 
 #in_subquery_nested_exist_subquery
diff --git a/datafusion/sqllogictest/test_files/tpch/q12.slt.part 
b/datafusion/sqllogictest/test_files/tpch/q12.slt.part
index 09939359ce..68ef41b382 100644
--- a/datafusion/sqllogictest/test_files/tpch/q12.slt.part
+++ b/datafusion/sqllogictest/test_files/tpch/q12.slt.part
@@ -55,8 +55,8 @@ Sort: lineitem.l_shipmode ASC NULLS LAST
 ------Projection: lineitem.l_shipmode, orders.o_orderpriority
 --------Inner Join: lineitem.l_orderkey = orders.o_orderkey
 ----------Projection: lineitem.l_orderkey, lineitem.l_shipmode
-------------Filter: (lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode 
= Utf8("SHIP")) AND lineitem.l_commitdate < lineitem.l_receiptdate AND 
lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= 
Date32("8766") AND lineitem.l_receiptdate < Date32("9131")
---------------TableScan: lineitem projection=[l_orderkey, l_shipdate, 
l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode 
= Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP"), lineitem.l_commitdate < 
lineitem.l_receiptdate, lineitem.l_shipdate < lineitem.l_commitdate, 
lineitem.l_receiptdate >= Date32("8766"), lineitem.l_receiptdate < 
Date32("9131")]
+------------Filter: (lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode 
= Utf8("SHIP")) AND lineitem.l_receiptdate > lineitem.l_commitdate AND 
lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= 
Date32("8766") AND lineitem.l_receiptdate < Date32("9131")
+--------------TableScan: lineitem projection=[l_orderkey, l_shipdate, 
l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode 
= Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP"), lineitem.l_receiptdate > 
lineitem.l_commitdate, lineitem.l_shipdate < lineitem.l_commitdate, 
lineitem.l_receiptdate >= Date32("8766"), lineitem.l_receiptdate < 
Date32("9131")]
 ----------TableScan: orders projection=[o_orderkey, o_orderpriority]
 physical_plan
 SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST]
@@ -73,7 +73,7 @@ SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST]
 ----------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), 
input_partitions=4
 ------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, 
l_shipmode@4 as l_shipmode]
 --------------------------CoalesceBatchesExec: target_batch_size=8192
-----------------------------FilterExec: (l_shipmode@4 = MAIL OR l_shipmode@4 = 
SHIP) AND l_commitdate@2 < l_receiptdate@3 AND l_shipdate@1 < l_commitdate@2 
AND l_receiptdate@3 >= 8766 AND l_receiptdate@3 < 9131
+----------------------------FilterExec: (l_shipmode@4 = MAIL OR l_shipmode@4 = 
SHIP) AND l_receiptdate@3 > l_commitdate@2 AND l_shipdate@1 < l_commitdate@2 
AND l_receiptdate@3 >= 8766 AND l_receiptdate@3 < 9131
 ------------------------------CsvExec: file_groups={4 groups: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749],
 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498],
 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247],
 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]},
 projection=[l_orderkey, l_shipdate, l_commitdate, l_re [...]
 --------------------CoalesceBatchesExec: target_batch_size=8192
 ----------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), 
input_partitions=4
diff --git a/datafusion/sqllogictest/test_files/tpch/q4.slt.part 
b/datafusion/sqllogictest/test_files/tpch/q4.slt.part
index 690ef64bc2..1709ae04aa 100644
--- a/datafusion/sqllogictest/test_files/tpch/q4.slt.part
+++ b/datafusion/sqllogictest/test_files/tpch/q4.slt.part
@@ -50,8 +50,8 @@ Sort: orders.o_orderpriority ASC NULLS LAST
 --------------TableScan: orders projection=[o_orderkey, o_orderdate, 
o_orderpriority], partial_filters=[orders.o_orderdate >= Date32("8582"), 
orders.o_orderdate < Date32("8674")]
 ----------SubqueryAlias: __correlated_sq_1
 ------------Projection: lineitem.l_orderkey
---------------Filter: lineitem.l_commitdate < lineitem.l_receiptdate
-----------------TableScan: lineitem projection=[l_orderkey, l_commitdate, 
l_receiptdate], partial_filters=[lineitem.l_commitdate < lineitem.l_receiptdate]
+--------------Filter: lineitem.l_receiptdate > lineitem.l_commitdate
+----------------TableScan: lineitem projection=[l_orderkey, l_commitdate, 
l_receiptdate], partial_filters=[lineitem.l_receiptdate > lineitem.l_commitdate]
 physical_plan
 SortPreservingMergeExec: [o_orderpriority@0 ASC NULLS LAST]
 --SortExec: expr=[o_orderpriority@0 ASC NULLS LAST]
@@ -73,7 +73,7 @@ SortPreservingMergeExec: [o_orderpriority@0 ASC NULLS LAST]
 ----------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), 
input_partitions=4
 ------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey]
 --------------------------CoalesceBatchesExec: target_batch_size=8192
-----------------------------FilterExec: l_commitdate@1 < l_receiptdate@2
+----------------------------FilterExec: l_receiptdate@2 > l_commitdate@1
 ------------------------------CsvExec: file_groups={4 groups: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749],
 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498],
 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247],
 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]},
 projection=[l_orderkey, l_commitdate, l_receiptdate],  [...]
 
 

Reply via email to