This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new bc0ba6a724 Enhance simplifier by adding Canonicalize (#8780)
bc0ba6a724 is described below
commit bc0ba6a724aaaf312b770451718cbce696de6640
Author: Junhao Liu <[email protected]>
AuthorDate: Wed Jan 24 06:39:53 2024 -0600
Enhance simplifier by adding Canonicalize (#8780)
* Enhance simplifier by adding Canonicalize
* fix swap operation
* Using Match to match cases
* feat: refactor code using better match and revise doc
* Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Co-authored-by: Andrew Lamb <[email protected]>
* add more robust test case and code format more Rustic
* Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Co-authored-by: Jeffrey Vo <[email protected]>
* Make Join unrelated to order
* fmt doc
* Check Join for both sides
* No canonicalize for Join
* cargo fmt
* Add comment for dup codes
* Fix test cases
* remove wrong git commit
* Update datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
Co-authored-by: Jeffrey Vo <[email protected]>
* remove outdated change
---------
Co-authored-by: Andrew Lamb <[email protected]>
Co-authored-by: Jeffrey Vo <[email protected]>
---
.../src/simplify_expressions/expr_simplifier.rs | 104 ++++++++++++++++++++-
.../src/simplify_expressions/simplify_exprs.rs | 38 ++++++--
datafusion/sqllogictest/test_files/joins.slt | 6 +-
datafusion/sqllogictest/test_files/subquery.slt | 2 +-
.../sqllogictest/test_files/tpch/q12.slt.part | 6 +-
.../sqllogictest/test_files/tpch/q4.slt.part | 6 +-
6 files changed, 141 insertions(+), 21 deletions(-)
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 95536e9fc5..561fe1d12d 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -151,6 +151,10 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
.rewrite(&mut simplifier)
}
+    /// Rewrites `expr` (including its sub-expressions) into a canonical form by
+    /// reordering the operands of binary expressions, so that predicates which
+    /// differ only in operand order become structurally equal:
+    ///
+    /// * `<literal> <op> <col>` becomes `<col> <swapped op> <literal>`
+    ///   (`1 < c2` becomes `c2 > 1`)
+    /// * `<col1> <op> <col2>` is reordered so that the greater column is on the
+    ///   left (`c1 < c2` becomes `c2 > c1`)
+    ///
+    /// Operands are only swapped when the operator has a mirrored form; every
+    /// other expression is returned unchanged. Running this before
+    /// [`Self::simplify`] lets the simplifier detect and remove such duplicates.
+    ///
+    /// Because operands may be swapped, do not apply this where operand order
+    /// is significant, such as the `ON` pairs of a join.
+    pub fn canonicalize(&self, expr: Expr) -> Result<Expr> {
+        let mut canonicalizer = Canonicalizer::new();
+        expr.rewrite(&mut canonicalizer)
+    }
/// Apply type coercion to an [`Expr`] so that it can be
/// evaluated as a
[`PhysicalExpr`](datafusion_physical_expr::PhysicalExpr).
///
@@ -227,6 +231,51 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
}
}
+/// Rewrites [`BinaryExpr`]s whose operands are not in canonical order.
+///
+/// Putting equivalent comparisons into a single form lets the simplifier
+/// recognize them as equal (for example, `1 < c2 AND c2 > 1` can then be
+/// reduced to `c2 > 1`). Operands are only swapped when the operator has a
+/// mirrored form (`Operator::swap` returns `Some`); any other expression is
+/// returned unchanged.
+///
+/// * `<literal> <op> <col>` is rewritten to `<col> <swapped op> <literal>`,
+///   e.g. `5 > c1` becomes `c1 < 5`.
+/// * `<col1> <op> <col2>` is rewritten so that the greater column (according
+///   to the `PartialOrd` impl of `Column`) is on the left, e.g. `a < b`
+///   becomes `b > a`, while `b > a` is already canonical and left as is.
+struct Canonicalizer {}
+
+impl Canonicalizer {
+    fn new() -> Self {
+        Self {}
+    }
+}
+
+impl TreeNodeRewriter for Canonicalizer {
+    type N = Expr;
+
+    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+        let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
+            return Ok(expr);
+        };
+        // Decide whether the operands are out of canonical order.
+        let out_of_order = match (left.as_ref(), right.as_ref()) {
+            // <col1> <op> <col2>: the greater column belongs on the left
+            (Expr::Column(left_col), Expr::Column(right_col)) => right_col > left_col,
+            // <literal> <op> <col>: the column belongs on the left
+            (Expr::Literal(_), Expr::Column(_)) => true,
+            _ => false,
+        };
+        // Only swap when the operator can be mirrored (e.g. `<` <-> `>`).
+        match op.swap() {
+            Some(swapped_op) if out_of_order => Ok(Expr::BinaryExpr(BinaryExpr {
+                left: right,
+                op: swapped_op,
+                right: left,
+            })),
+            _ => Ok(Expr::BinaryExpr(BinaryExpr { left, op, right })),
+        }
+    }
+}
+
#[allow(rustdoc::private_intra_doc_links)]
/// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time.
///
@@ -1612,6 +1661,58 @@ mod tests {
// --- Simplifier tests -----
// ------------------------------
+    /// Predicates that differ only in operand order must be deduplicated once
+    /// they are canonicalized.
+    #[test]
+    fn test_simplify_canonicalize() {
+        {
+            // <literal> <op> <col> is flipped, then the duplicate is removed
+            let expr = lit(1).lt(col("c2")).and(col("c2").gt(lit(1)));
+            let expected = col("c2").gt(lit(1));
+            assert_eq!(simplify(expr), expected);
+        }
+        {
+            // <col1> <op> <col2> puts the greater column on the left
+            let expr = col("c1").lt(col("c2")).and(col("c2").gt(col("c1")));
+            let expected = col("c2").gt(col("c1"));
+            assert_eq!(simplify(expr), expected);
+        }
+        {
+            let expr = col("c1")
+                .eq(lit(1))
+                .and(lit(1).eq(col("c1")))
+                .and(col("c1").eq(lit(3)));
+            let expected = col("c1").eq(lit(1)).and(col("c1").eq(lit(3)));
+            assert_eq!(simplify(expr), expected);
+        }
+        {
+            let expr = col("c1")
+                .eq(col("c2"))
+                .and(col("c1").gt(lit(5)))
+                .and(col("c2").eq(col("c1")));
+            let expected = col("c2").eq(col("c1")).and(col("c1").gt(lit(5)));
+            assert_eq!(simplify(expr), expected);
+        }
+        {
+            // duplicates nested inside an OR are also detected
+            let expr = col("c1")
+                .eq(lit(1))
+                .and(col("c2").gt(lit(3)).or(lit(3).lt(col("c2"))));
+            let expected = col("c1").eq(lit(1)).and(col("c2").gt(lit(3)));
+            assert_eq!(simplify(expr), expected);
+        }
+        {
+            // distinct predicates that are already canonical are kept as is
+            let expr = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
+            let expected = col("c1").lt(lit(5)).and(col("c1").gt_eq(lit(5)));
+            assert_eq!(simplify(expr), expected);
+        }
+        {
+            let expr = col("c1").gt(col("c2")).and(col("c1").gt(col("c2")));
+            let expected = col("c2").lt(col("c1"));
+            assert_eq!(simplify(expr), expected);
+        }
+    }
+
#[test]
fn test_simplify_or_true() {
let expr_a = col("c2").or(lit(true));
@@ -2807,7 +2908,8 @@ mod tests {
let simplifier = ExprSimplifier::new(
SimplifyContext::new(&execution_props).with_schema(schema),
);
- simplifier.simplify(expr)
+ let cano = simplifier.canonicalize(expr)?;
+ simplifier.simplify(cano)
}
fn simplify(expr: Expr) -> Expr {
diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
index cfd02547b8..7265b17dd0 100644
--- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
+++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
@@ -93,16 +93,34 @@ impl SimplifyExpressions {
.map(|input| Self::optimize_internal(input, execution_props))
.collect::<Result<Vec<_>>>()?;
- let expr = plan
- .expressions()
- .into_iter()
- .map(|e| {
- // TODO: unify with `rewrite_preserving_name`
- let original_name = e.name_for_alias()?;
- let new_e = simplifier.simplify(e)?;
- new_e.alias_if_changed(original_name)
- })
- .collect::<Result<Vec<_>>>()?;
+        // Canonicalization may swap the operands of a binary expression, which
+        // is not safe for a Join: the left and right expressions of its `ON`
+        // clause are not commutative, since their order must match the order
+        // of the join's children. Join expressions are therefore only simplified.
+        let should_canonicalize = !matches!(plan, LogicalPlan::Join(_));
+        let expr = plan
+            .expressions()
+            .into_iter()
+            .map(|e| {
+                // TODO: unify with `rewrite_preserving_name`
+                let original_name = e.name_for_alias()?;
+                let e = if should_canonicalize {
+                    simplifier.canonicalize(e)?
+                } else {
+                    e
+                };
+                let new_e = simplifier.simplify(e)?;
+                new_e.alias_if_changed(original_name)
+            })
+            .collect::<Result<Vec<_>>>()?;
plan.with_new_exprs(expr, &new_inputs)
}
diff --git a/datafusion/sqllogictest/test_files/joins.slt
b/datafusion/sqllogictest/test_files/joins.slt
index a7146a5a91..e605813b20 100644
--- a/datafusion/sqllogictest/test_files/joins.slt
+++ b/datafusion/sqllogictest/test_files/joins.slt
@@ -1109,7 +1109,7 @@ RIGHT JOIN join_t2 on join_t1.t1_id = join_t2.t2_id
WHERE NOT (join_t1.t1_int = join_t2.t2_int)
----
logical_plan
-Inner Join: join_t1.t1_id = join_t2.t2_id Filter: join_t1.t1_int !=
join_t2.t2_int
+Inner Join: join_t1.t1_id = join_t2.t2_id Filter: join_t2.t2_int !=
join_t1.t1_int
--TableScan: join_t1 projection=[t1_id, t1_name, t1_int]
--TableScan: join_t2 projection=[t2_id, t2_name, t2_int]
@@ -3472,13 +3472,13 @@ FROM annotated_data as l, annotated_data as r
WHERE l.a > r.a
----
logical_plan
-Inner Join: Filter: l.a > r.a
+Inner Join: Filter: r.a < l.a
--SubqueryAlias: l
----TableScan: annotated_data projection=[a0, a, b, c, d]
--SubqueryAlias: r
----TableScan: annotated_data projection=[a0, a, b, c, d]
physical_plan
-NestedLoopJoinExec: join_type=Inner, filter=a@0 > a@1
+NestedLoopJoinExec: join_type=Inner, filter=a@1 < a@0
--RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
----CsvExec: file_groups={1 group:
[[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a,
b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST],
has_header=true
--CsvExec: file_groups={1 group:
[[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a,
b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST],
has_header=true
diff --git a/datafusion/sqllogictest/test_files/subquery.slt
b/datafusion/sqllogictest/test_files/subquery.slt
index 3e0fcb7aa9..1ca9045f1b 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -396,7 +396,7 @@ LeftSemi Join: t1.t1_id = __correlated_sq_1.t1_int
--TableScan: t1 projection=[t1_id, t1_name, t1_int]
--SubqueryAlias: __correlated_sq_1
----Projection: t1.t1_int
-------Filter: t1.t1_id > t1.t1_int
+------Filter: t1.t1_int < t1.t1_id
--------TableScan: t1 projection=[t1_id, t1_int]
#in_subquery_nested_exist_subquery
diff --git a/datafusion/sqllogictest/test_files/tpch/q12.slt.part
b/datafusion/sqllogictest/test_files/tpch/q12.slt.part
index 09939359ce..68ef41b382 100644
--- a/datafusion/sqllogictest/test_files/tpch/q12.slt.part
+++ b/datafusion/sqllogictest/test_files/tpch/q12.slt.part
@@ -55,8 +55,8 @@ Sort: lineitem.l_shipmode ASC NULLS LAST
------Projection: lineitem.l_shipmode, orders.o_orderpriority
--------Inner Join: lineitem.l_orderkey = orders.o_orderkey
----------Projection: lineitem.l_orderkey, lineitem.l_shipmode
-------------Filter: (lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode
= Utf8("SHIP")) AND lineitem.l_commitdate < lineitem.l_receiptdate AND
lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >=
Date32("8766") AND lineitem.l_receiptdate < Date32("9131")
---------------TableScan: lineitem projection=[l_orderkey, l_shipdate,
l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode
= Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP"), lineitem.l_commitdate <
lineitem.l_receiptdate, lineitem.l_shipdate < lineitem.l_commitdate,
lineitem.l_receiptdate >= Date32("8766"), lineitem.l_receiptdate <
Date32("9131")]
+------------Filter: (lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode
= Utf8("SHIP")) AND lineitem.l_receiptdate > lineitem.l_commitdate AND
lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >=
Date32("8766") AND lineitem.l_receiptdate < Date32("9131")
+--------------TableScan: lineitem projection=[l_orderkey, l_shipdate,
l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode
= Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP"), lineitem.l_receiptdate >
lineitem.l_commitdate, lineitem.l_shipdate < lineitem.l_commitdate,
lineitem.l_receiptdate >= Date32("8766"), lineitem.l_receiptdate <
Date32("9131")]
----------TableScan: orders projection=[o_orderkey, o_orderpriority]
physical_plan
SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST]
@@ -73,7 +73,7 @@ SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST]
----------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4),
input_partitions=4
------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey,
l_shipmode@4 as l_shipmode]
--------------------------CoalesceBatchesExec: target_batch_size=8192
-----------------------------FilterExec: (l_shipmode@4 = MAIL OR l_shipmode@4 =
SHIP) AND l_commitdate@2 < l_receiptdate@3 AND l_shipdate@1 < l_commitdate@2
AND l_receiptdate@3 >= 8766 AND l_receiptdate@3 < 9131
+----------------------------FilterExec: (l_shipmode@4 = MAIL OR l_shipmode@4 =
SHIP) AND l_receiptdate@3 > l_commitdate@2 AND l_shipdate@1 < l_commitdate@2
AND l_receiptdate@3 >= 8766 AND l_receiptdate@3 < 9131
------------------------------CsvExec: file_groups={4 groups:
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749],
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498],
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247],
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]},
projection=[l_orderkey, l_shipdate, l_commitdate, l_re [...]
--------------------CoalesceBatchesExec: target_batch_size=8192
----------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4),
input_partitions=4
diff --git a/datafusion/sqllogictest/test_files/tpch/q4.slt.part
b/datafusion/sqllogictest/test_files/tpch/q4.slt.part
index 690ef64bc2..1709ae04aa 100644
--- a/datafusion/sqllogictest/test_files/tpch/q4.slt.part
+++ b/datafusion/sqllogictest/test_files/tpch/q4.slt.part
@@ -50,8 +50,8 @@ Sort: orders.o_orderpriority ASC NULLS LAST
--------------TableScan: orders projection=[o_orderkey, o_orderdate,
o_orderpriority], partial_filters=[orders.o_orderdate >= Date32("8582"),
orders.o_orderdate < Date32("8674")]
----------SubqueryAlias: __correlated_sq_1
------------Projection: lineitem.l_orderkey
---------------Filter: lineitem.l_commitdate < lineitem.l_receiptdate
-----------------TableScan: lineitem projection=[l_orderkey, l_commitdate,
l_receiptdate], partial_filters=[lineitem.l_commitdate < lineitem.l_receiptdate]
+--------------Filter: lineitem.l_receiptdate > lineitem.l_commitdate
+----------------TableScan: lineitem projection=[l_orderkey, l_commitdate,
l_receiptdate], partial_filters=[lineitem.l_receiptdate > lineitem.l_commitdate]
physical_plan
SortPreservingMergeExec: [o_orderpriority@0 ASC NULLS LAST]
--SortExec: expr=[o_orderpriority@0 ASC NULLS LAST]
@@ -73,7 +73,7 @@ SortPreservingMergeExec: [o_orderpriority@0 ASC NULLS LAST]
----------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4),
input_partitions=4
------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey]
--------------------------CoalesceBatchesExec: target_batch_size=8192
-----------------------------FilterExec: l_commitdate@1 < l_receiptdate@2
+----------------------------FilterExec: l_receiptdate@2 > l_commitdate@1
------------------------------CsvExec: file_groups={4 groups:
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749],
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498],
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247],
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]},
projection=[l_orderkey, l_commitdate, l_receiptdate], [...]