This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 03d8ba1f0d Refactor `Optimizer` to use owned plans and `TreeNode` API
(10% faster planning) (#9948)
03d8ba1f0d is described below
commit 03d8ba1f0d94bac6bb8bb33e95f00f9f6fb5275a
Author: Andrew Lamb <[email protected]>
AuthorDate: Wed Apr 10 09:27:35 2024 -0400
Refactor `Optimizer` to use owned plans and `TreeNode` API (10% faster
planning) (#9948)
* Rewrite Optimizer to use TreeNode API
* fmt
---
datafusion-examples/examples/rewrite_expr.rs | 2 +-
datafusion/common/src/tree_node.rs | 12 +-
datafusion/core/src/execution/context/mod.rs | 4 +-
datafusion/core/tests/optimizer_integration.rs | 2 +-
.../src/decorrelate_predicate_subquery.rs | 96 +++---
.../optimizer/src/eliminate_duplicated_expr.rs | 6 +-
datafusion/optimizer/src/eliminate_filter.rs | 14 +-
datafusion/optimizer/src/eliminate_join.rs | 6 +-
datafusion/optimizer/src/eliminate_limit.rs | 32 +-
datafusion/optimizer/src/eliminate_nested_union.rs | 22 +-
datafusion/optimizer/src/eliminate_one_union.rs | 6 +-
datafusion/optimizer/src/eliminate_outer_join.rs | 12 +-
.../optimizer/src/extract_equijoin_predicate.rs | 18 +-
datafusion/optimizer/src/filter_null_join_keys.rs | 14 +-
datafusion/optimizer/src/optimize_projections.rs | 50 +--
datafusion/optimizer/src/optimizer.rs | 348 ++++++++++-----------
.../optimizer/src/propagate_empty_relation.rs | 22 +-
datafusion/optimizer/src/push_down_filter.rs | 151 +++++----
datafusion/optimizer/src/push_down_limit.rs | 70 ++---
datafusion/optimizer/src/push_down_projection.rs | 76 +++--
.../optimizer/src/replace_distinct_aggregate.rs | 4 +-
.../optimizer/src/scalar_subquery_to_join.rs | 28 +-
.../optimizer/src/single_distinct_to_groupby.rs | 40 +--
datafusion/optimizer/src/test/mod.rs | 68 ++--
.../optimizer/tests/optimizer_integration.rs | 6 +-
datafusion/sqllogictest/test_files/join.slt | 2 +-
26 files changed, 535 insertions(+), 576 deletions(-)
diff --git a/datafusion-examples/examples/rewrite_expr.rs
b/datafusion-examples/examples/rewrite_expr.rs
index 541448ebf1..dcebbb55fb 100644
--- a/datafusion-examples/examples/rewrite_expr.rs
+++ b/datafusion-examples/examples/rewrite_expr.rs
@@ -59,7 +59,7 @@ pub fn main() -> Result<()> {
// then run the optimizer with our custom rule
let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]);
- let optimized_plan = optimizer.optimize(&analyzed_plan, &config, observe)?;
+ let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?;
println!(
"Optimized Logical Plan:\n\n{}\n",
optimized_plan.display_indent()
diff --git a/datafusion/common/src/tree_node.rs
b/datafusion/common/src/tree_node.rs
index bb268e048d..dff22d4959 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -56,6 +56,9 @@ pub trait TreeNode: Sized {
/// Visit the tree node using the given [`TreeNodeVisitor`], performing a
/// depth-first walk of the node and its children.
///
+ /// See also:
+ /// * [`Self::rewrite`] to rewrite owned `TreeNode`s
+ ///
/// Consider the following tree structure:
/// ```text
/// ParentNode
@@ -93,6 +96,9 @@ pub trait TreeNode: Sized {
/// Implements the [visitor
pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
/// recursively transforming [`TreeNode`]s.
///
+ /// See also:
+ /// * [`Self::visit`] for inspecting (without modification) `TreeNode`s
+ ///
/// Consider the following tree structure:
/// ```text
/// ParentNode
@@ -310,13 +316,15 @@ pub trait TreeNode: Sized {
}
/// Apply the closure `F` to the node's children.
+ ///
+ /// See [`Self::map_children`] for rewriting the node's children
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
f: F,
) -> Result<TreeNodeRecursion>;
- /// Apply transform `F` to the node's children. Note that the transform `F`
- /// might have a direction (pre-order or post-order).
+ /// Apply transform `F` to potentially rewrite the node's children. Note
+ /// that the transform `F` might have a direction (pre-order or
post-order).
fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: F,
diff --git a/datafusion/core/src/execution/context/mod.rs
b/datafusion/core/src/execution/context/mod.rs
index 9e48c7b8a6..5cf8969aa4 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -1881,7 +1881,7 @@ impl SessionState {
// optimize the child plan, capturing the output of each optimizer
let optimized_plan = self.optimizer.optimize(
- &analyzed_plan,
+ analyzed_plan,
self,
|optimized_plan, optimizer| {
let optimizer_name = optimizer.name().to_string();
@@ -1911,7 +1911,7 @@ impl SessionState {
let analyzed_plan =
self.analyzer
.execute_and_check(plan, self.options(), |_, _| {})?;
- self.optimizer.optimize(&analyzed_plan, self, |_, _| {})
+ self.optimizer.optimize(analyzed_plan, self, |_, _| {})
}
}
diff --git a/datafusion/core/tests/optimizer_integration.rs
b/datafusion/core/tests/optimizer_integration.rs
index 60010bdddf..6e938361dd 100644
--- a/datafusion/core/tests/optimizer_integration.rs
+++ b/datafusion/core/tests/optimizer_integration.rs
@@ -110,7 +110,7 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
let optimizer = Optimizer::new();
// analyze and optimize the logical plan
let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?;
- optimizer.optimize(&plan, &config, |_, _| {})
+ optimizer.optimize(plan, &config, |_, _| {})
}
#[derive(Default)]
diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
index 019e7507b1..d9fc5a6ce2 100644
--- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
+++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
@@ -338,7 +338,7 @@ mod tests {
Operator,
};
- fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) ->
Result<()> {
+ fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) ->
Result<()> {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
plan,
@@ -378,7 +378,7 @@ mod tests {
\n SubqueryAlias: __correlated_sq_2 [c:UInt32]\
\n Projection: sq_2.c [c:UInt32]\
\n TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test for IN subquery with additional AND filter
@@ -404,7 +404,7 @@ mod tests {
\n Projection: sq.c [c:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test for IN subquery with additional OR filter
@@ -430,7 +430,7 @@ mod tests {
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -458,7 +458,7 @@ mod tests {
\n Projection: sq2.c [c:UInt32]\
\n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test for nested IN subqueries
@@ -487,7 +487,7 @@ mod tests {
\n Projection: sq_nested.c [c:UInt32]\
\n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test for filter input modification in case filter not supported
@@ -519,7 +519,7 @@ mod tests {
\n Projection: sq_inner.c [c:UInt32]\
\n TableScan: sq_inner [a:UInt32, b:UInt32, c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test multiple correlated subqueries
@@ -557,7 +557,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -607,7 +607,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -642,7 +642,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -675,7 +675,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -706,7 +706,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -739,7 +739,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -772,7 +772,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -806,7 +806,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
@@ -863,7 +863,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -896,7 +896,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -962,7 +962,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -1000,7 +1000,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -1030,7 +1030,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -1054,7 +1054,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -1078,7 +1078,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -1107,7 +1107,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -1142,7 +1142,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -1178,7 +1178,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -1224,7 +1224,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -1255,7 +1255,7 @@ mod tests {
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelatePredicateSubquery::new()),
- &plan,
+ plan,
expected,
);
Ok(())
@@ -1289,7 +1289,7 @@ mod tests {
\n SubqueryAlias: __correlated_sq_2
[o_custkey:Int64]\
\n Projection: orders.o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test recursive correlated subqueries
@@ -1332,7 +1332,7 @@ mod tests {
\n SubqueryAlias: __correlated_sq_2
[l_orderkey:Int64]\
\n Projection: lineitem.l_orderkey
[l_orderkey:Int64]\
\n TableScan: lineitem [l_orderkey:Int64,
l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64,
l_extendedprice:Float64]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test for correlated exists subquery filter with additional subquery
filters
@@ -1362,7 +1362,7 @@ mod tests {
\n Filter: orders.o_orderkey = Int32(1)
[o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1387,7 +1387,7 @@ mod tests {
\n Projection: orders.o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test for exists subquery with both columns in schema
@@ -1405,7 +1405,7 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
-
assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()),
&plan)
+
assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan)
}
/// Test for correlated exists subquery not equal
@@ -1433,7 +1433,7 @@ mod tests {
\n Projection: orders.o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test for correlated exists subquery less than
@@ -1461,7 +1461,7 @@ mod tests {
\n Projection: orders.o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test for correlated exists subquery filter with subquery disjunction
@@ -1490,7 +1490,7 @@ mod tests {
\n Projection: orders.o_custkey,
orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\
\n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test for correlated exists without projection
@@ -1516,7 +1516,7 @@ mod tests {
\n SubqueryAlias: __correlated_sq_1
[o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test for correlated exists expressions
@@ -1544,7 +1544,7 @@ mod tests {
\n Projection: orders.o_custkey + Int32(1),
orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test for correlated exists subquery filter with additional filters
@@ -1572,7 +1572,7 @@ mod tests {
\n Projection: orders.o_custkey
[o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64,
o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test for correlated exists subquery filter with disjustions
@@ -1599,7 +1599,7 @@ mod tests {
TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test for correlated EXISTS subquery filter
@@ -1624,7 +1624,7 @@ mod tests {
\n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
/// Test for single exists subquery filter
@@ -1636,7 +1636,7 @@ mod tests {
.project(vec![col("test.b")])?
.build()?;
-
assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()),
&plan)
+
assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan)
}
/// Test for single NOT exists subquery filter
@@ -1648,7 +1648,7 @@ mod tests {
.project(vec![col("test.b")])?
.build()?;
-
assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()),
&plan)
+
assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan)
}
#[test]
@@ -1687,7 +1687,7 @@ mod tests {
\n Projection: sq2.c, sq2.a [c:UInt32,
a:UInt32]\
\n TableScan: sq2 [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1713,7 +1713,7 @@ mod tests {
\n Projection: UInt32(1), sq.a [UInt32(1):UInt32,
a:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1739,7 +1739,7 @@ mod tests {
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1767,7 +1767,7 @@ mod tests {
\n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1795,7 +1795,7 @@ mod tests {
\n Projection: sq.b + sq.c, sq.a [sq.b +
sq.c:UInt32, a:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1823,6 +1823,6 @@ mod tests {
\n Projection: UInt32(1), sq.c, sq.a
[UInt32(1):UInt32, c:UInt32, a:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
}
diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs
b/datafusion/optimizer/src/eliminate_duplicated_expr.rs
index 349d4d8878..ee44a328f8 100644
--- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs
+++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs
@@ -116,7 +116,7 @@ mod tests {
use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder};
use std::sync::Arc;
- fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) ->
Result<()> {
+ fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) ->
Result<()> {
crate::test::assert_optimized_plan_eq(
Arc::new(EliminateDuplicatedExpr::new()),
plan,
@@ -134,7 +134,7 @@ mod tests {
let expected = "Limit: skip=5, fetch=10\
\n Sort: test.a, test.b, test.c\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -153,6 +153,6 @@ mod tests {
let expected = "Limit: skip=5, fetch=10\
\n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
}
diff --git a/datafusion/optimizer/src/eliminate_filter.rs
b/datafusion/optimizer/src/eliminate_filter.rs
index 9411dc192b..2bf5cfa303 100644
--- a/datafusion/optimizer/src/eliminate_filter.rs
+++ b/datafusion/optimizer/src/eliminate_filter.rs
@@ -91,7 +91,7 @@ mod tests {
use crate::test::*;
- fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) ->
Result<()> {
+ fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) ->
Result<()> {
assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan,
expected)
}
@@ -107,7 +107,7 @@ mod tests {
// No aggregate / scan / limit
let expected = "EmptyRelation";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -122,7 +122,7 @@ mod tests {
// No aggregate / scan / limit
let expected = "EmptyRelation";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -144,7 +144,7 @@ mod tests {
\n EmptyRelation\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\
\n TableScan: test";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -159,7 +159,7 @@ mod tests {
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\
\n TableScan: test";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -182,7 +182,7 @@ mod tests {
\n TableScan: test\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\
\n TableScan: test";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -205,6 +205,6 @@ mod tests {
// Filter is removed
let expected = "Projection: test.a\
\n EmptyRelation";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
}
diff --git a/datafusion/optimizer/src/eliminate_join.rs
b/datafusion/optimizer/src/eliminate_join.rs
index e685229c61..caf45dda98 100644
--- a/datafusion/optimizer/src/eliminate_join.rs
+++ b/datafusion/optimizer/src/eliminate_join.rs
@@ -83,7 +83,7 @@ mod tests {
use datafusion_expr::{logical_plan::builder::LogicalPlanBuilder, Expr,
LogicalPlan};
use std::sync::Arc;
- fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) ->
Result<()> {
+ fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) ->
Result<()> {
assert_optimized_plan_eq(Arc::new(EliminateJoin::new()), plan,
expected)
}
@@ -98,7 +98,7 @@ mod tests {
.build()?;
let expected = "EmptyRelation";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -115,6 +115,6 @@ mod tests {
CrossJoin:\
\n EmptyRelation\
\n EmptyRelation";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
}
diff --git a/datafusion/optimizer/src/eliminate_limit.rs
b/datafusion/optimizer/src/eliminate_limit.rs
index fb5d0d17b8..39231d784e 100644
--- a/datafusion/optimizer/src/eliminate_limit.rs
+++ b/datafusion/optimizer/src/eliminate_limit.rs
@@ -94,24 +94,19 @@ mod tests {
use crate::push_down_limit::PushDownLimit;
- fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) ->
Result<()> {
+ fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
+ fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) ->
Result<()> {
let optimizer =
Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]);
- let optimized_plan = optimizer
- .optimize_recursively(
- optimizer.rules.first().unwrap(),
- plan,
- &OptimizerContext::new(),
- )?
- .unwrap_or_else(|| plan.clone());
+ let optimized_plan =
+ optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
let formatted_plan = format!("{optimized_plan:?}");
assert_eq!(formatted_plan, expected);
- assert_eq!(plan.schema(), optimized_plan.schema());
Ok(())
}
fn assert_optimized_plan_eq_with_pushdown(
- plan: &LogicalPlan,
+ plan: LogicalPlan,
expected: &str,
) -> Result<()> {
fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
@@ -125,7 +120,6 @@ mod tests {
.expect("failed to optimize plan");
let formatted_plan = format!("{optimized_plan:?}");
assert_eq!(formatted_plan, expected);
- assert_eq!(plan.schema(), optimized_plan.schema());
Ok(())
}
@@ -138,7 +132,7 @@ mod tests {
.build()?;
// No aggregate / scan / limit
let expected = "EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -158,7 +152,7 @@ mod tests {
\n EmptyRelation\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -172,7 +166,7 @@ mod tests {
// No aggregate / scan / limit
let expected = "EmptyRelation";
- assert_optimized_plan_eq_with_pushdown(&plan, expected)
+ assert_optimized_plan_eq_with_pushdown(plan, expected)
}
#[test]
@@ -192,7 +186,7 @@ mod tests {
\n Limit: skip=0, fetch=2\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\
\n TableScan: test";
- assert_optimized_plan_eq_with_pushdown(&plan, expected)
+ assert_optimized_plan_eq_with_pushdown(plan, expected)
}
#[test]
@@ -210,7 +204,7 @@ mod tests {
\n Limit: skip=0, fetch=2\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -228,7 +222,7 @@ mod tests {
\n Limit: skip=2, fetch=1\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -250,7 +244,7 @@ mod tests {
\n Limit: skip=2, fetch=1\
\n TableScan: test\
\n TableScan: test1";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -263,6 +257,6 @@ mod tests {
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
}
diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs
b/datafusion/optimizer/src/eliminate_nested_union.rs
index 924a085341..da2a6a1721 100644
--- a/datafusion/optimizer/src/eliminate_nested_union.rs
+++ b/datafusion/optimizer/src/eliminate_nested_union.rs
@@ -114,7 +114,7 @@ mod tests {
])
}
- fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) ->
Result<()> {
+ fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) ->
Result<()> {
assert_optimized_plan_eq(Arc::new(EliminateNestedUnion::new()), plan,
expected)
}
@@ -131,7 +131,7 @@ mod tests {
Union\
\n TableScan: table\
\n TableScan: table";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -147,7 +147,7 @@ mod tests {
\n Union\
\n TableScan: table\
\n TableScan: table";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -167,7 +167,7 @@ mod tests {
\n TableScan: table\
\n TableScan: table\
\n TableScan: table";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -188,7 +188,7 @@ mod tests {
\n TableScan: table\
\n TableScan: table\
\n TableScan: table";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -210,7 +210,7 @@ mod tests {
\n TableScan: table\
\n TableScan: table\
\n TableScan: table";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -230,7 +230,7 @@ mod tests {
\n TableScan: table\
\n TableScan: table\
\n TableScan: table";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
// We don't need to use project_with_column_index in logical optimizer,
@@ -261,7 +261,7 @@ mod tests {
\n TableScan: table\
\n Projection: table.id AS id, table.key, table.value\
\n TableScan: table";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -291,7 +291,7 @@ mod tests {
\n TableScan: table\
\n Projection: table.id AS id, table.key, table.value\
\n TableScan: table";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -337,7 +337,7 @@ mod tests {
\n TableScan: table_1\
\n Projection: CAST(table_1.id AS Int64) AS id, table_1.key,
CAST(table_1.value AS Float64) AS value\
\n TableScan: table_1";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -384,6 +384,6 @@ mod tests {
\n TableScan: table_1\
\n Projection: CAST(table_1.id AS Int64) AS id, table_1.key,
CAST(table_1.value AS Float64) AS value\
\n TableScan: table_1";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
}
diff --git a/datafusion/optimizer/src/eliminate_one_union.rs
b/datafusion/optimizer/src/eliminate_one_union.rs
index 63c3e789da..95a3370ab1 100644
--- a/datafusion/optimizer/src/eliminate_one_union.rs
+++ b/datafusion/optimizer/src/eliminate_one_union.rs
@@ -76,7 +76,7 @@ mod tests {
])
}
- fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) ->
Result<()> {
+ fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) ->
Result<()> {
assert_optimized_plan_eq_with_rules(
vec![Arc::new(EliminateOneUnion::new())],
plan,
@@ -97,7 +97,7 @@ mod tests {
Union\
\n TableScan: table\
\n TableScan: table";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -113,6 +113,6 @@ mod tests {
});
let expected = "TableScan: table";
- assert_optimized_plan_equal(&single_union_plan, expected)
+ assert_optimized_plan_equal(single_union_plan, expected)
}
}
diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs
b/datafusion/optimizer/src/eliminate_outer_join.rs
index a004da2bff..63b8b887bb 100644
--- a/datafusion/optimizer/src/eliminate_outer_join.rs
+++ b/datafusion/optimizer/src/eliminate_outer_join.rs
@@ -306,7 +306,7 @@ mod tests {
Operator::{And, Or},
};
- fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) ->
Result<()> {
+ fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) ->
Result<()> {
assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan,
expected)
}
@@ -330,7 +330,7 @@ mod tests {
\n Left Join: t1.a = t2.a\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -353,7 +353,7 @@ mod tests {
\n Inner Join: t1.a = t2.a\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -380,7 +380,7 @@ mod tests {
\n Inner Join: t1.a = t2.a\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -407,7 +407,7 @@ mod tests {
\n Inner Join: t1.a = t2.a\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -434,6 +434,6 @@ mod tests {
\n Inner Join: t1.a = t2.a\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
}
diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs
b/datafusion/optimizer/src/extract_equijoin_predicate.rs
index 4cfcd07b47..60b9ba3031 100644
--- a/datafusion/optimizer/src/extract_equijoin_predicate.rs
+++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs
@@ -164,7 +164,7 @@ mod tests {
col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType,
};
- fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
+ fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq_display_indent(
Arc::new(ExtractEquijoinPredicate {}),
plan,
@@ -186,7 +186,7 @@ mod tests {
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
- assert_plan_eq(&plan, expected)
+ assert_plan_eq(plan, expected)
}
#[test]
@@ -205,7 +205,7 @@ mod tests {
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
- assert_plan_eq(&plan, expected)
+ assert_plan_eq(plan, expected)
}
#[test]
@@ -228,7 +228,7 @@ mod tests {
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
- assert_plan_eq(&plan, expected)
+ assert_plan_eq(plan, expected)
}
#[test]
@@ -255,7 +255,7 @@ mod tests {
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
- assert_plan_eq(&plan, expected)
+ assert_plan_eq(plan, expected)
}
#[test]
@@ -281,7 +281,7 @@ mod tests {
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
- assert_plan_eq(&plan, expected)
+ assert_plan_eq(plan, expected)
}
#[test]
@@ -318,7 +318,7 @@ mod tests {
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";
- assert_plan_eq(&plan, expected)
+ assert_plan_eq(plan, expected)
}
#[test]
@@ -351,7 +351,7 @@ mod tests {
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";
- assert_plan_eq(&plan, expected)
+ assert_plan_eq(plan, expected)
}
#[test]
@@ -375,6 +375,6 @@ mod tests {
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
- assert_plan_eq(&plan, expected)
+ assert_plan_eq(plan, expected)
}
}
diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs
b/datafusion/optimizer/src/filter_null_join_keys.rs
index 16039b182b..fcf85327fd 100644
--- a/datafusion/optimizer/src/filter_null_join_keys.rs
+++ b/datafusion/optimizer/src/filter_null_join_keys.rs
@@ -116,7 +116,7 @@ mod tests {
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{col, lit, logical_plan::JoinType,
LogicalPlanBuilder};
- fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
+ fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected)
}
@@ -128,7 +128,7 @@ mod tests {
\n Filter: t1.optional_id IS NOT NULL\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -139,7 +139,7 @@ mod tests {
\n Filter: t1.optional_id IS NOT NULL\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -176,7 +176,7 @@ mod tests {
\n Filter: t1.optional_id IS NOT NULL\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -197,7 +197,7 @@ mod tests {
\n Filter: t1.optional_id + UInt32(1) IS NOT NULL\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -218,7 +218,7 @@ mod tests {
\n TableScan: t1\
\n Filter: t2.optional_id + UInt32(1) IS NOT NULL\
\n TableScan: t2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -241,7 +241,7 @@ mod tests {
\n TableScan: t1\
\n Filter: t2.optional_id + UInt32(1) IS NOT NULL\
\n TableScan: t2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
fn build_plan(
diff --git a/datafusion/optimizer/src/optimize_projections.rs
b/datafusion/optimizer/src/optimize_projections.rs
index 69905c990a..6967b28f30 100644
--- a/datafusion/optimizer/src/optimize_projections.rs
+++ b/datafusion/optimizer/src/optimize_projections.rs
@@ -941,7 +941,7 @@ mod tests {
UserDefinedLogicalNodeCore,
};
- fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
+ fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected)
}
@@ -1090,7 +1090,7 @@ mod tests {
let expected = "Projection: Int32(1) + test.a\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1104,7 +1104,7 @@ mod tests {
let expected = "Projection: Int32(1) + test.a\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1117,7 +1117,7 @@ mod tests {
let expected = "Projection: test.a AS alias\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1130,7 +1130,7 @@ mod tests {
let expected = "Projection: test.a AS alias\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1152,7 +1152,7 @@ mod tests {
\n Projection: \
\n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\
\n TableScan: ?table? projection=[]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1175,7 +1175,7 @@ mod tests {
.build()?;
let expected = "Projection: (?table?.s)[x]\
\n TableScan: ?table? projection=[s]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1187,7 +1187,7 @@ mod tests {
let expected = "Projection: (- test.a)\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1199,7 +1199,7 @@ mod tests {
let expected = "Projection: test.a IS NULL\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1211,7 +1211,7 @@ mod tests {
let expected = "Projection: test.a IS NOT NULL\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1223,7 +1223,7 @@ mod tests {
let expected = "Projection: test.a IS TRUE\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1235,7 +1235,7 @@ mod tests {
let expected = "Projection: test.a IS NOT TRUE\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1247,7 +1247,7 @@ mod tests {
let expected = "Projection: test.a IS FALSE\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1259,7 +1259,7 @@ mod tests {
let expected = "Projection: test.a IS NOT FALSE\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1271,7 +1271,7 @@ mod tests {
let expected = "Projection: test.a IS UNKNOWN\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1283,7 +1283,7 @@ mod tests {
let expected = "Projection: test.a IS NOT UNKNOWN\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1295,7 +1295,7 @@ mod tests {
let expected = "Projection: NOT test.a\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1307,7 +1307,7 @@ mod tests {
let expected = "Projection: TRY_CAST(test.a AS Float64)\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1323,7 +1323,7 @@ mod tests {
let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -1335,7 +1335,7 @@ mod tests {
let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
// Test outer projection isn't discarded despite the same schema as inner
@@ -1356,7 +1356,7 @@ mod tests {
let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN
Int32(10) ELSE d END AS d\
\n Projection: test.a, Int32(0) AS d\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
// Since only column `a` is referred at the output. Scan should only
contain projection=[a].
@@ -1377,7 +1377,7 @@ mod tests {
let expected = "Projection: test.a, Int32(0) AS d\
\n NoOpUserDefined\
\n TableScan: test projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
// Only column `a` is referred at the output. However, User defined node
itself uses column `b`
@@ -1404,7 +1404,7 @@ mod tests {
let expected = "Projection: test.a, Int32(0) AS d\
\n NoOpUserDefined\
\n TableScan: test projection=[a, b]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
// Only column `a` is referred at the output. However, User defined node
itself uses expression `b+c`
@@ -1439,7 +1439,7 @@ mod tests {
let expected = "Projection: test.a, Int32(0) AS d\
\n NoOpUserDefined\
\n TableScan: test projection=[a, b, c]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
// Columns `l.a`, `l.c`, `r.a` is referred at the output.
@@ -1464,6 +1464,6 @@ mod tests {
\n UserDefinedCrossJoin\
\n TableScan: l projection=[a, c]\
\n TableScan: r projection=[a]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
}
diff --git a/datafusion/optimizer/src/optimizer.rs
b/datafusion/optimizer/src/optimizer.rs
index 03ff402c3e..032f9c5732 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -20,6 +20,16 @@
use std::collections::HashSet;
use std::sync::Arc;
+use chrono::{DateTime, Utc};
+use log::{debug, warn};
+
+use datafusion_common::alias::AliasGenerator;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::instant::Instant;
+use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
+use datafusion_common::{DFSchema, DataFusionError, Result};
+use datafusion_expr::logical_plan::LogicalPlan;
+
use crate::common_subexpr_eliminate::CommonSubexprEliminate;
use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery;
use crate::eliminate_cross_join::EliminateCrossJoin;
@@ -45,15 +55,6 @@ use
crate::single_distinct_to_groupby::SingleDistinctToGroupBy;
use crate::unwrap_cast_in_comparison::UnwrapCastInComparison;
use crate::utils::log_plan;
-use datafusion_common::alias::AliasGenerator;
-use datafusion_common::config::ConfigOptions;
-use datafusion_common::instant::Instant;
-use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::logical_plan::LogicalPlan;
-
-use chrono::{DateTime, Utc};
-use log::{debug, warn};
-
/// `OptimizerRule`s transforms one [`LogicalPlan`] into another which
/// computes the same results, but in a potentially more efficient
/// way. If there are no suitable transformations for the input plan,
@@ -184,41 +185,15 @@ pub struct Optimizer {
pub rules: Vec<Arc<dyn OptimizerRule + Send + Sync>>,
}
-/// If a rule is with `ApplyOrder`, it means the optimizer will derive to
handle children instead of
-/// recursively handling in rule.
-/// We just need handle a subtree pattern itself.
-///
-/// Notice: **sometime** result after optimize still can be optimized, we need
apply again.
+/// Specifies how recursion for an `OptimizerRule` should be handled.
///
-/// Usage Example: Merge Limit (subtree pattern is: Limit-Limit)
-/// ```rust
-/// use datafusion_expr::{Limit, LogicalPlan, LogicalPlanBuilder};
-/// use datafusion_common::Result;
-/// fn merge_limit(parent: &Limit, child: &Limit) -> LogicalPlan {
-/// // just for run
-/// return parent.input.as_ref().clone();
-/// }
-/// fn try_optimize(plan: &LogicalPlan) -> Result<Option<LogicalPlan>> {
-/// match plan {
-/// LogicalPlan::Limit(limit) => match limit.input.as_ref() {
-/// LogicalPlan::Limit(child_limit) => {
-/// // merge limit ...
-/// let optimized_plan = merge_limit(limit, child_limit);
-/// // due to optimized_plan may be optimized again,
-/// // for example: plan is Limit-Limit-Limit
-/// Ok(Some(
-/// try_optimize(&optimized_plan)?
-/// .unwrap_or_else(|| optimized_plan.clone()),
-/// ))
-/// }
-/// _ => Ok(None),
-/// },
-/// _ => Ok(None),
-/// }
-/// }
-/// ```
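+/// The value returned by `OptimizerRule::apply_order` selects the strategy:
+/// * `Some(apply_order)`: the Optimizer recursively applies the rule to the
+///   plan and all of its inputs, in the given order.
+/// * `None`: the rule must handle any required recursion itself.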
+#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ApplyOrder {
+ /// Apply the rule to the node before its inputs (pre-order: the
+ /// optimizer's rewriter invokes the rule from `TreeNodeRewriter::f_down`)
TopDown,
+ /// Apply the rule to the node after its inputs (post-order: the
+ /// optimizer's rewriter invokes the rule from `TreeNodeRewriter::f_up`)
BottomUp,
}
@@ -274,22 +249,78 @@ impl Optimizer {
pub fn with_rules(rules: Vec<Arc<dyn OptimizerRule + Send + Sync>>) ->
Self {
Self { rules }
}
+}
+
+/// [`TreeNodeRewriter`] that applies a single [`OptimizerRule`] to every node
+/// of a [`LogicalPlan`], either before (`TopDown`) or after (`BottomUp`) the
+/// node's inputs are visited.
+struct Rewriter<'a> {
+    /// Selects whether the rule runs on the way down or on the way up
+    apply_order: ApplyOrder,
+    /// The rule being applied to each node
+    rule: &'a dyn OptimizerRule,
+    /// Configuration forwarded to the rule for every invocation
+    config: &'a dyn OptimizerConfig,
+}
+
+impl<'a> Rewriter<'a> {
+    /// Builds a rewriter that applies `rule` in `apply_order`, using `config`.
+    fn new(
+        apply_order: ApplyOrder,
+        rule: &'a dyn OptimizerRule,
+        config: &'a dyn OptimizerConfig,
+    ) -> Self {
+        Rewriter { apply_order, rule, config }
+    }
+}
+
+impl<'a> TreeNodeRewriter for Rewriter<'a> {
+    type Node = LogicalPlan;
+
+    /// Pre-order hook, invoked before the node's inputs are visited.
+    /// Applies the rule only when `apply_order` is [`ApplyOrder::TopDown`];
+    /// otherwise the node is passed through unchanged.
+    fn f_down(&mut self, node: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
+        if self.apply_order == ApplyOrder::TopDown {
+            optimize_plan_node(node, self.rule, self.config)
+        } else {
+            Ok(Transformed::no(node))
+        }
+    }
+
+    /// Post-order hook, invoked after the node's inputs have been visited.
+    /// Applies the rule only when `apply_order` is [`ApplyOrder::BottomUp`];
+    /// otherwise the node is passed through unchanged.
+    fn f_up(&mut self, node: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
+        if self.apply_order == ApplyOrder::BottomUp {
+            optimize_plan_node(node, self.rule, self.config)
+        } else {
+            Ok(Transformed::no(node))
+        }
+    }
+}
+
+/// Invokes `rule` once on `plan`; no recursion into inputs happens here.
+///
+/// Takes ownership of `plan`. If the rule returns `Ok(Some(new_plan))`,
+/// `new_plan` is returned wrapped in [`Transformed::yes`]. If it returns
+/// `Ok(None)`, the original `plan` is returned, without cloning, wrapped in
+/// [`Transformed::no`]. Errors from the rule are propagated unchanged.
+fn optimize_plan_node(
+    plan: LogicalPlan,
+    rule: &dyn OptimizerRule,
+    config: &dyn OptimizerConfig,
+) -> Result<Transformed<LogicalPlan>> {
+    // TODO: add API to OptimizerRule to allow rewriting by ownership.
+    // Until then the rule only borrows `plan`, which lets us hand the
+    // original back when the rule makes no change.
+    let transformed = match rule.try_optimize(&plan, config)? {
+        Some(new_plan) => Transformed::yes(new_plan),
+        None => Transformed::no(plan),
+    };
+    Ok(transformed)
+}
+
+impl Optimizer {
/// Optimizes the logical plan by applying optimizer rules, and
/// invoking observer function after each call
pub fn optimize<F>(
&self,
- plan: &LogicalPlan,
+ plan: LogicalPlan,
config: &dyn OptimizerConfig,
mut observer: F,
) -> Result<LogicalPlan>
where
F: FnMut(&LogicalPlan, &dyn OptimizerRule),
{
- let options = config.options();
- let mut new_plan = plan.clone();
-
let start_time = Instant::now();
+ let options = config.options();
+ let mut new_plan = plan;
let mut previous_plans = HashSet::with_capacity(16);
previous_plans.insert(LogicalPlanSignature::new(&new_plan));
@@ -299,44 +330,71 @@ impl Optimizer {
log_plan(&format!("Optimizer input (pass {i})"), &new_plan);
for rule in &self.rules {
- let result =
- self.optimize_recursively(rule, &new_plan, config)
- .and_then(|plan| {
- if let Some(plan) = &plan {
- assert_schema_is_the_same(rule.name(), plan,
&new_plan)?;
- }
- Ok(plan)
- });
- match result {
- Ok(Some(plan)) => {
- new_plan = plan;
- observer(&new_plan, rule.as_ref());
- log_plan(rule.name(), &new_plan);
- }
- Ok(None) => {
+ // If skipping failed rules, copy plan before attempting to
rewrite
+ // as rewriting is destructive
+ let prev_plan = options
+ .optimizer
+ .skip_failed_rules
+ .then(|| new_plan.clone());
+
+ let starting_schema = new_plan.schema().clone();
+
+ let result = match rule.apply_order() {
+ // optimizer handles recursion
+ Some(apply_order) => new_plan.rewrite(&mut Rewriter::new(
+ apply_order,
+ rule.as_ref(),
+ config,
+ )),
+ // rule handles recursion itself
+ None => optimize_plan_node(new_plan, rule.as_ref(),
config),
+ }
+ // verify the rule didn't change the schema
+ .and_then(|tnr| {
+ assert_schema_is_the_same(rule.name(), &starting_schema,
&tnr.data)?;
+ Ok(tnr)
+ });
+
+ // Handle results
+ match (result, prev_plan) {
+ // OptimizerRule was successful
+ (
+ Ok(Transformed {
+ data, transformed, ..
+ }),
+ _,
+ ) => {
+ new_plan = data;
observer(&new_plan, rule.as_ref());
- debug!(
- "Plan unchanged by optimizer rule '{}' (pass {})",
- rule.name(),
- i
- );
+ if transformed {
+ log_plan(rule.name(), &new_plan);
+ } else {
+ debug!(
+ "Plan unchanged by optimizer rule '{}' (pass
{})",
+ rule.name(),
+ i
+ );
+ }
}
- Err(e) => {
- if options.optimizer.skip_failed_rules {
- // Note to future readers: if you see this warning
it signals a
- // bug in the DataFusion optimizer. Please
consider filing a ticket
- // https://github.com/apache/arrow-datafusion
- warn!(
+ // OptimizerRule was unsuccessful, but skipped failed
rules is on
+ // so use the previous plan
+ (Err(e), Some(orig_plan)) => {
+ // Note to future readers: if you see this warning it
signals a
+ // bug in the DataFusion optimizer. Please consider
filing a ticket
+ // https://github.com/apache/arrow-datafusion
+ warn!(
"Skipping optimizer rule '{}' due to unexpected
error: {}",
rule.name(),
e
);
- } else {
- return Err(DataFusionError::Context(
- format!("Optimizer rule '{}' failed",
rule.name(),),
- Box::new(e),
- ));
- }
+ new_plan = orig_plan;
+ }
+ // OptimizerRule was unsuccessful, but skipped failed
rules is off, return error
+ (Err(e), None) => {
+ return Err(e.context(format!(
+ "Optimizer rule '{}' failed",
+ rule.name()
+ )));
}
}
}
@@ -356,97 +414,22 @@ impl Optimizer {
debug!("Optimizer took {} ms", start_time.elapsed().as_millis());
Ok(new_plan)
}
-
- fn optimize_node(
- &self,
- rule: &Arc<dyn OptimizerRule + Send + Sync>,
- plan: &LogicalPlan,
- config: &dyn OptimizerConfig,
- ) -> Result<Option<LogicalPlan>> {
- // TODO: future feature: We can do Batch optimize
- rule.try_optimize(plan, config)
- }
-
- fn optimize_inputs(
- &self,
- rule: &Arc<dyn OptimizerRule + Send + Sync>,
- plan: &LogicalPlan,
- config: &dyn OptimizerConfig,
- ) -> Result<Option<LogicalPlan>> {
- let inputs = plan.inputs();
- let result = inputs
- .iter()
- .map(|sub_plan| self.optimize_recursively(rule, sub_plan, config))
- .collect::<Result<Vec<_>>>()?;
- if result.is_empty() || result.iter().all(|o| o.is_none()) {
- return Ok(None);
- }
-
- let new_inputs = result
- .into_iter()
- .zip(inputs)
- .map(|(new_plan, old_plan)| match new_plan {
- Some(plan) => plan,
- None => old_plan.clone(),
- })
- .collect();
-
- let exprs = plan.expressions();
- plan.with_new_exprs(exprs, new_inputs).map(Some)
- }
-
- /// Use a rule to optimize the whole plan.
- /// If the rule with `ApplyOrder`, we don't need to recursively handle
children in rule.
- pub fn optimize_recursively(
- &self,
- rule: &Arc<dyn OptimizerRule + Send + Sync>,
- plan: &LogicalPlan,
- config: &dyn OptimizerConfig,
- ) -> Result<Option<LogicalPlan>> {
- match rule.apply_order() {
- Some(order) => match order {
- ApplyOrder::TopDown => {
- let optimize_self_opt = self.optimize_node(rule, plan,
config)?;
- let optimize_inputs_opt = match &optimize_self_opt {
- Some(optimized_plan) => {
- self.optimize_inputs(rule, optimized_plan, config)?
- }
- _ => self.optimize_inputs(rule, plan, config)?,
- };
- Ok(optimize_inputs_opt.or(optimize_self_opt))
- }
- ApplyOrder::BottomUp => {
- let optimize_inputs_opt = self.optimize_inputs(rule, plan,
config)?;
- let optimize_self_opt = match &optimize_inputs_opt {
- Some(optimized_plan) => {
- self.optimize_node(rule, optimized_plan, config)?
- }
- _ => self.optimize_node(rule, plan, config)?,
- };
- Ok(optimize_self_opt.or(optimize_inputs_opt))
- }
- },
- _ => rule.try_optimize(plan, config),
- }
- }
}
-/// Returns an error if plans have different schemas.
+/// Returns an error if `new_plan`'s schema is different than `prev_schema`
///
/// It ignores metadata and nullability.
pub(crate) fn assert_schema_is_the_same(
rule_name: &str,
- prev_plan: &LogicalPlan,
+ prev_schema: &DFSchema,
new_plan: &LogicalPlan,
) -> Result<()> {
- let equivalent = new_plan
- .schema()
- .equivalent_names_and_types(prev_plan.schema());
+ let equivalent = new_plan.schema().equivalent_names_and_types(prev_schema);
if !equivalent {
let e = DataFusionError::Internal(format!(
"Failed due to a difference in schemas, original schema: {:?}, new
schema: {:?}",
- prev_plan.schema(),
+ prev_schema,
new_plan.schema()
));
Err(DataFusionError::Context(
@@ -462,14 +445,15 @@ pub(crate) fn assert_schema_is_the_same(
mod tests {
use std::sync::{Arc, Mutex};
- use super::ApplyOrder;
+ use datafusion_common::{plan_err, DFSchema, DFSchemaRef, Result};
+ use datafusion_expr::logical_plan::EmptyRelation;
+ use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder,
Projection};
+
use crate::optimizer::Optimizer;
use crate::test::test_table_scan;
use crate::{OptimizerConfig, OptimizerContext, OptimizerRule};
- use datafusion_common::{plan_err, DFSchema, DFSchemaRef, Result};
- use datafusion_expr::logical_plan::EmptyRelation;
- use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder,
Projection};
+ use super::ApplyOrder;
#[test]
fn skip_failing_rule() {
@@ -479,7 +463,7 @@ mod tests {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
});
- opt.optimize(&plan, &config, &observe).unwrap();
+ opt.optimize(plan, &config, &observe).unwrap();
}
#[test]
@@ -490,7 +474,7 @@ mod tests {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
});
- let err = opt.optimize(&plan, &config, &observe).unwrap_err();
+ let err = opt.optimize(plan, &config, &observe).unwrap_err();
assert_eq!(
"Optimizer rule 'bad rule' failed\ncaused by\n\
Error during planning: rule failed",
@@ -506,21 +490,27 @@ mod tests {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
});
- let err = opt.optimize(&plan, &config, &observe).unwrap_err();
+ let err = opt.optimize(plan, &config, &observe).unwrap_err();
assert_eq!(
- "Optimizer rule 'get table_scan rule' failed\ncaused by\nget
table_scan rule\ncaused by\n\
- Internal error: Failed due to a difference in schemas, original
schema: \
- DFSchema { inner: Schema { fields: \
- [Field { name: \"a\", data_type: UInt32, nullable: false,
dict_id: 0, dict_is_ordered: false, metadata: {} }, \
- Field { name: \"b\", data_type: UInt32, nullable: false,
dict_id: 0, dict_is_ordered: false, metadata: {} }, \
- Field { name: \"c\", data_type: UInt32, nullable: false,
dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} }, \
- field_qualifiers: [Some(Bare { table: \"test\" }), Some(Bare {
table: \"test\" }), Some(Bare { table: \"test\" })], \
- functional_dependencies: FunctionalDependencies { deps: [] }
}, \
+ "Optimizer rule 'get table_scan rule' failed\n\
+ caused by\nget table_scan rule\ncaused by\n\
+ Internal error: Failed due to a difference in schemas, \
+ original schema: DFSchema { inner: Schema { \
+ fields: [], \
+ metadata: {} }, \
+ field_qualifiers: [], \
+ functional_dependencies: FunctionalDependencies { deps: [] } \
+ }, \
new schema: DFSchema { inner: Schema { \
- fields: [], metadata: {} }, \
- field_qualifiers: [], \
- functional_dependencies: FunctionalDependencies { deps: [] }
}.\n\
- This was likely caused by a bug in DataFusion's code and we
would welcome that you file an bug report in our issue tracker",
+ fields: [\
+ Field { name: \"a\", data_type: UInt32, nullable: false,
dict_id: 0, dict_is_ordered: false, metadata: {} }, \
+ Field { name: \"b\", data_type: UInt32, nullable: false,
dict_id: 0, dict_is_ordered: false, metadata: {} }, \
+ Field { name: \"c\", data_type: UInt32, nullable: false,
dict_id: 0, dict_is_ordered: false, metadata: {} }\
+ ], \
+ metadata: {} }, \
+ field_qualifiers: [Some(Bare { table: \"test\" }), Some(Bare {
table: \"test\" }), Some(Bare { table: \"test\" })], \
+ functional_dependencies: FunctionalDependencies { deps: [] } }.\n\
+ This was likely caused by a bug in DataFusion's code and we would
welcome that you file an bug report in our issue tracker",
err.strip_backtrace()
);
}
@@ -533,7 +523,7 @@ mod tests {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
});
- opt.optimize(&plan, &config, &observe).unwrap();
+ opt.optimize(plan, &config, &observe).unwrap();
}
#[test]
@@ -554,7 +544,7 @@ mod tests {
// optimizing should be ok, but the schema will have changed (no
metadata)
assert_ne!(plan.schema().as_ref(), input_schema.as_ref());
- let optimized_plan = opt.optimize(&plan, &config, &observe)?;
+ let optimized_plan = opt.optimize(plan, &config, &observe)?;
// metadata was removed
assert_eq!(optimized_plan.schema().as_ref(), input_schema.as_ref());
Ok(())
@@ -575,7 +565,7 @@ mod tests {
let mut plans: Vec<LogicalPlan> = Vec::new();
let final_plan =
- opt.optimize(&initial_plan, &config, |p, _|
plans.push(p.clone()))?;
+ opt.optimize(initial_plan.clone(), &config, |p, _|
plans.push(p.clone()))?;
// initial_plan is not observed, so we have 3 plans
assert_eq!(3, plans.len());
@@ -601,7 +591,7 @@ mod tests {
let mut plans: Vec<LogicalPlan> = Vec::new();
let final_plan =
- opt.optimize(&initial_plan, &config, |p, _|
plans.push(p.clone()))?;
+ opt.optimize(initial_plan, &config, |p, _| plans.push(p.clone()))?;
// initial_plan is not observed, so we have 4 plans
assert_eq!(4, plans.len());
diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs
b/datafusion/optimizer/src/propagate_empty_relation.rs
index 2aca6f9325..445109bbdf 100644
--- a/datafusion/optimizer/src/propagate_empty_relation.rs
+++ b/datafusion/optimizer/src/propagate_empty_relation.rs
@@ -198,12 +198,12 @@ mod tests {
use super::*;
- fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
+ fn assert_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq(Arc::new(PropagateEmptyRelation::new()), plan, expected)
}
fn assert_together_optimized_plan_eq(
- plan: &LogicalPlan,
+ plan: LogicalPlan,
expected: &str,
) -> Result<()> {
assert_optimized_plan_eq_with_rules(
@@ -226,7 +226,7 @@ mod tests {
.build()?;
let expected = "EmptyRelation";
- assert_eq(&plan, expected)
+ assert_eq(plan, expected)
}
#[test]
@@ -249,7 +249,7 @@ mod tests {
.build()?;
let expected = "EmptyRelation";
- assert_together_optimized_plan_eq(&plan, expected)
+ assert_together_optimized_plan_eq(plan, expected)
}
#[test]
@@ -262,7 +262,7 @@ mod tests {
let plan = LogicalPlanBuilder::from(left).union(right)?.build()?;
let expected = "TableScan: test";
- assert_together_optimized_plan_eq(&plan, expected)
+ assert_together_optimized_plan_eq(plan, expected)
}
#[test]
@@ -287,7 +287,7 @@ mod tests {
let expected = "Union\
\n TableScan: test1\
\n TableScan: test4";
- assert_together_optimized_plan_eq(&plan, expected)
+ assert_together_optimized_plan_eq(plan, expected)
}
#[test]
@@ -312,7 +312,7 @@ mod tests {
.build()?;
let expected = "EmptyRelation";
- assert_together_optimized_plan_eq(&plan, expected)
+ assert_together_optimized_plan_eq(plan, expected)
}
#[test]
@@ -339,7 +339,7 @@ mod tests {
let expected = "Union\
\n TableScan: test2\
\n TableScan: test3";
- assert_together_optimized_plan_eq(&plan, expected)
+ assert_together_optimized_plan_eq(plan, expected)
}
#[test]
@@ -352,7 +352,7 @@ mod tests {
let plan = LogicalPlanBuilder::from(left).union(right)?.build()?;
let expected = "TableScan: test";
- assert_together_optimized_plan_eq(&plan, expected)
+ assert_together_optimized_plan_eq(plan, expected)
}
#[test]
@@ -367,7 +367,7 @@ mod tests {
.build()?;
let expected = "EmptyRelation";
- assert_together_optimized_plan_eq(&plan, expected)
+ assert_together_optimized_plan_eq(plan, expected)
}
#[test]
@@ -400,6 +400,6 @@ mod tests {
let expected = "Projection: a, b, c\
\n TableScan: test";
- assert_together_optimized_plan_eq(&plan, expected)
+ assert_together_optimized_plan_eq(plan, expected)
}
}
diff --git a/datafusion/optimizer/src/push_down_filter.rs
b/datafusion/optimizer/src/push_down_filter.rs
index f3ce8bbcde..2b123e3559 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -1028,11 +1028,11 @@ fn contain(e: &Expr, check_map: &HashMap<String, Expr>)
-> bool {
#[cfg(test)]
mod tests {
+ use super::*;
use std::any::Any;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
- use super::*;
use crate::optimizer::Optimizer;
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use crate::test::*;
@@ -1040,6 +1040,7 @@ mod tests {
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_common::{DFSchema, DFSchemaRef, ScalarValue};
+ use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum,
BinaryExpr,
@@ -1049,9 +1050,9 @@ mod tests {
};
use async_trait::async_trait;
- use datafusion_expr::expr::ScalarFunction;
+ /// No-op observer passed to `Optimizer::optimize`; these tests only inspect the final plan
+ fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
- fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) ->
Result<()> {
+ fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) ->
Result<()> {
crate::test::assert_optimized_plan_eq(
Arc::new(PushDownFilter::new()),
plan,
@@ -1060,29 +1061,17 @@ mod tests {
}
fn assert_optimized_plan_eq_with_rewrite_predicate(
- plan: &LogicalPlan,
+ plan: LogicalPlan,
expected: &str,
) -> Result<()> {
let optimizer = Optimizer::with_rules(vec![
Arc::new(RewriteDisjunctivePredicate::new()),
Arc::new(PushDownFilter::new()),
]);
- let mut optimized_plan = optimizer
- .optimize_recursively(
- optimizer.rules.first().unwrap(),
- plan,
- &OptimizerContext::new(),
- )?
- .unwrap_or_else(|| plan.clone());
- optimized_plan = optimizer
- .optimize_recursively(
- optimizer.rules.get(1).unwrap(),
- &optimized_plan,
- &OptimizerContext::new(),
- )?
- .unwrap_or_else(|| plan.clone());
+ // `optimize` takes ownership of `plan`, so capture the input schema first
+ // to keep verifying that the rewrite leaves the output schema unchanged
+ let starting_schema = plan.schema().clone();
+ let optimized_plan =
+ optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
+
let formatted_plan = format!("{optimized_plan:?}");
- assert_eq!(plan.schema(), optimized_plan.schema());
+ assert_eq!(&starting_schema, optimized_plan.schema());
assert_eq!(expected, formatted_plan);
Ok(())
}
@@ -1098,7 +1087,7 @@ mod tests {
let expected = "\
Projection: test.a, test.b\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -1115,7 +1104,7 @@ mod tests {
\n Limit: skip=0, fetch=10\
\n Projection: test.a, test.b\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -1125,7 +1114,7 @@ mod tests {
.filter(lit(0i64).eq(lit(1i64)))?
.build()?;
let expected = "TableScan: test, full_filters=[Int64(0) = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -1141,7 +1130,7 @@ mod tests {
Projection: test.c, test.b\
\n Projection: test.a, test.b, test.c\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -1155,7 +1144,7 @@ mod tests {
let expected = "\
Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS
total_salary]]\
\n TableScan: test, full_filters=[test.a > Int64(10)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -1168,7 +1157,7 @@ mod tests {
let expected = "Filter: test.b > Int64(10)\
\n Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a),
test.b]]\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -1180,7 +1169,7 @@ mod tests {
let expected =
"Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a),
test.b]]\
\n TableScan: test, full_filters=[test.b + test.a > Int64(10)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -1195,7 +1184,7 @@ mod tests {
Filter: b > Int64(10)\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// verifies that a filter is pushed to before a projection, the filter
expression is correctly re-written
@@ -1210,7 +1199,7 @@ mod tests {
let expected = "\
Projection: test.a AS b, test.c\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
fn add(left: Expr, right: Expr) -> Expr {
@@ -1254,7 +1243,7 @@ mod tests {
let expected = "\
Projection: test.a * Int32(2) + test.c AS b, test.c\
\n TableScan: test, full_filters=[test.a * Int32(2) + test.c =
Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// verifies that when a filter is pushed to after 2 projections, the
filter expression is correctly re-written
@@ -1286,7 +1275,7 @@ mod tests {
Projection: b * Int32(3) AS a, test.c\
\n Projection: test.a * Int32(2) + test.c AS b, test.c\
\n TableScan: test, full_filters=[(test.a * Int32(2) + test.c) *
Int32(3) = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[derive(Debug, PartialEq, Eq, Hash)]
@@ -1349,7 +1338,7 @@ mod tests {
let expected = "\
NoopPlan\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_optimized_plan_eq(plan, expected)?;
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(NoopPlan {
@@ -1366,7 +1355,7 @@ mod tests {
Filter: test.c = Int64(2)\
\n NoopPlan\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_optimized_plan_eq(plan, expected)?;
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(NoopPlan {
@@ -1383,7 +1372,7 @@ mod tests {
NoopPlan\
\n TableScan: test, full_filters=[test.a = Int64(1)]\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_optimized_plan_eq(plan, expected)?;
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(NoopPlan {
@@ -1401,7 +1390,7 @@ mod tests {
\n NoopPlan\
\n TableScan: test, full_filters=[test.a = Int64(1)]\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// verifies that when two filters apply after an aggregation that only
allows one to be pushed, one is pushed
@@ -1434,7 +1423,7 @@ mod tests {
\n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\
\n Projection: test.a AS b, test.c\
\n TableScan: test, full_filters=[test.a > Int64(10)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// verifies that when a filter with two predicates is applied after an
aggregation that only allows one to be pushed, one is pushed
@@ -1468,7 +1457,7 @@ mod tests {
\n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\
\n Projection: test.a AS b, test.c\
\n TableScan: test, full_filters=[test.a > Int64(10)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// verifies that when two limits are in place, we jump neither
@@ -1490,7 +1479,7 @@ mod tests {
\n Limit: skip=0, fetch=20\
\n Projection: test.a, test.b\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -1505,7 +1494,7 @@ mod tests {
let expected = "Union\
\n TableScan: test, full_filters=[test.a = Int64(1)]\
\n TableScan: test2, full_filters=[test2.a = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -1528,7 +1517,7 @@ mod tests {
\n SubqueryAlias: test2\
\n Projection: test.a AS b\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -1559,7 +1548,7 @@ mod tests {
\n Projection: test1.d, test1.e, test1.f\
\n TableScan: test1, full_filters=[test1.d > Int32(2)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -1585,7 +1574,7 @@ mod tests {
\n TableScan: test, full_filters=[test.a = Int32(1)]\
\n Projection: test1.a, test1.b, test1.c\
\n TableScan: test1, full_filters=[test1.a > Int32(2)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// verifies that filters with the same columns are correctly placed
@@ -1619,7 +1608,7 @@ mod tests {
\n Projection: test.a\
\n TableScan: test, full_filters=[test.a <= Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// verifies that filters to be placed on the same depth are ANDed
@@ -1649,7 +1638,7 @@ mod tests {
\n Limit: skip=0, fetch=1\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// verifies that filters on a plan with user nodes are not lost
@@ -1675,7 +1664,7 @@ mod tests {
TestUserDefined\
\n TableScan: test, full_filters=[test.a <= Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// post-on-join predicates on a column common to both sides is pushed to
both sides
@@ -1713,7 +1702,7 @@ mod tests {
\n TableScan: test, full_filters=[test.a <= Int64(1)]\
\n Projection: test2.a\
\n TableScan: test2, full_filters=[test2.a <= Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// post-using-join predicates on a column common to both sides is pushed
to both sides
@@ -1750,7 +1739,7 @@ mod tests {
\n TableScan: test, full_filters=[test.a <= Int64(1)]\
\n Projection: test2.a\
\n TableScan: test2, full_filters=[test2.a <= Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// post-join predicates with columns from both sides are converted to
join filterss
@@ -1792,7 +1781,7 @@ mod tests {
\n TableScan: test\
\n Projection: test2.a, test2.b\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// post-join predicates with columns from one side of a join are pushed
only to that side
@@ -1834,7 +1823,7 @@ mod tests {
\n TableScan: test, full_filters=[test.b <= Int64(1)]\
\n Projection: test2.a, test2.c\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// post-join predicates on the right side of a left join are not
duplicated
@@ -1873,7 +1862,7 @@ mod tests {
\n TableScan: test\
\n Projection: test2.a\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// post-join predicates on the left side of a right join are not
duplicated
@@ -1911,7 +1900,7 @@ mod tests {
\n TableScan: test\
\n Projection: test2.a\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// post-left-join predicate on a column common to both sides is only
pushed to the left side
@@ -1949,7 +1938,7 @@ mod tests {
\n TableScan: test, full_filters=[test.a <= Int64(1)]\
\n Projection: test2.a\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// post-right-join predicate on a column common to both sides is only
pushed to the right side
@@ -1987,7 +1976,7 @@ mod tests {
\n TableScan: test\
\n Projection: test2.a\
\n TableScan: test2, full_filters=[test2.a <= Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// single table predicate parts of ON condition should be pushed to both
inputs
@@ -2030,7 +2019,7 @@ mod tests {
\n TableScan: test, full_filters=[test.c > UInt32(1)]\
\n Projection: test2.a, test2.b, test2.c\
\n TableScan: test2, full_filters=[test2.c > UInt32(4)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// join filter should be completely removed after pushdown
@@ -2072,7 +2061,7 @@ mod tests {
\n TableScan: test, full_filters=[test.b > UInt32(1)]\
\n Projection: test2.a, test2.b, test2.c\
\n TableScan: test2, full_filters=[test2.c > UInt32(4)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// predicate on join key in filter expression should be pushed down to
both inputs
@@ -2112,7 +2101,7 @@ mod tests {
\n TableScan: test, full_filters=[test.a > UInt32(1)]\
\n Projection: test2.b\
\n TableScan: test2, full_filters=[test2.b > UInt32(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// single table predicate parts of ON condition should be pushed to right
input
@@ -2155,7 +2144,7 @@ mod tests {
\n TableScan: test\
\n Projection: test2.a, test2.b, test2.c\
\n TableScan: test2, full_filters=[test2.c > UInt32(4)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// single table predicate parts of ON condition should be pushed to left
input
@@ -2198,7 +2187,7 @@ mod tests {
\n TableScan: test, full_filters=[test.a > UInt32(1)]\
\n Projection: test2.a, test2.b, test2.c\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// single table predicate parts of ON condition should not be pushed
@@ -2236,7 +2225,7 @@ mod tests {
);
let expected = &format!("{plan:?}");
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
struct PushDownProvider {
@@ -2295,7 +2284,7 @@ mod tests {
let expected = "\
TableScan: test, full_filters=[a = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -2306,7 +2295,7 @@ mod tests {
let expected = "\
Filter: a = Int64(1)\
\n TableScan: test, partial_filters=[a = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -2314,7 +2303,7 @@ mod tests {
let plan =
table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
- let optimised_plan = PushDownFilter::new()
+ let optimized_plan = PushDownFilter::new()
.try_optimize(&plan, &OptimizerContext::new())
.expect("failed to optimize plan")
.unwrap();
@@ -2325,7 +2314,7 @@ mod tests {
// Optimizing the same plan multiple times should produce the same plan
// each time.
- assert_optimized_plan_eq(&optimised_plan, expected)
+ assert_optimized_plan_eq(optimized_plan, expected)
}
#[test]
@@ -2336,7 +2325,7 @@ mod tests {
let expected = "\
Filter: a = Int64(1)\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -2365,7 +2354,7 @@ mod tests {
\n Filter: a = Int64(10) AND b > Int64(11)\
\n TableScan: test projection=[a], partial_filters=[a =
Int64(10), b > Int64(11)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -2396,7 +2385,7 @@ Projection: a, b
"#
.trim();
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -2424,7 +2413,7 @@ Projection: a, b
\n TableScan: test, full_filters=[test.a > Int64(10), test.c >
Int64(10)]\
";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -2456,7 +2445,7 @@ Projection: a, b
\n TableScan: test, full_filters=[test.a > Int64(10), test.c >
Int64(10)]\
";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -2481,7 +2470,7 @@ Projection: a, b
Projection: test.a AS b, test.c AS d\
\n TableScan: test, full_filters=[test.a > Int64(10), test.c >
Int64(10)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// predicate on join key in filter expression should be pushed down to
both inputs
@@ -2521,7 +2510,7 @@ Projection: a, b
\n TableScan: test, full_filters=[test.a > UInt32(1)]\
\n Projection: test2.b AS d\
\n TableScan: test2, full_filters=[test2.b > UInt32(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -2550,7 +2539,7 @@ Projection: a, b
Projection: test.a AS b, test.c\
\n TableScan: test, full_filters=[test.a IN ([UInt32(1),
UInt32(2), UInt32(3), UInt32(4)])]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -2582,7 +2571,7 @@ Projection: a, b
\n Projection: test.a AS b, test.c\
\n TableScan: test, full_filters=[test.a IN ([UInt32(1),
UInt32(2), UInt32(3), UInt32(4)])]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -2618,7 +2607,7 @@ Projection: a, b
\n Subquery:\
\n Projection: sq.c\
\n TableScan: sq";
- assert_optimized_plan_eq(&plan, expected_after)
+ assert_optimized_plan_eq(plan, expected_after)
}
#[test]
@@ -2651,7 +2640,7 @@ Projection: a, b
\n Projection: Int64(0) AS a\
\n Filter: Int64(0) = Int64(1)\
\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected_after)
+ assert_optimized_plan_eq(plan, expected_after)
}
#[test]
@@ -2679,14 +2668,14 @@ Projection: a, b
\n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c <
UInt32(10)]\
\n Projection: test1.a AS d, test1.a AS e\
\n TableScan: test1";
- assert_optimized_plan_eq_with_rewrite_predicate(&plan, expected)?;
+ assert_optimized_plan_eq_with_rewrite_predicate(plan.clone(),
expected)?;
// Originally global state which can help to avoid duplicate Filters
been generated and pushed down.
// Now the global state is removed. Need to double confirm that avoid
duplicate Filters.
let optimized_plan = PushDownFilter::new()
.try_optimize(&plan, &OptimizerContext::new())?
.expect("failed to optimize plan");
- assert_optimized_plan_eq(&optimized_plan, expected)
+ assert_optimized_plan_eq(optimized_plan, expected)
}
#[test]
@@ -2727,7 +2716,7 @@ Projection: a, b
\n TableScan: test1, full_filters=[test1.b > UInt32(1)]\
\n Projection: test2.a, test2.b\
\n TableScan: test2, full_filters=[test2.b > UInt32(2)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -2768,7 +2757,7 @@ Projection: a, b
\n TableScan: test1, full_filters=[test1.b > UInt32(1)]\
\n Projection: test2.a, test2.b\
\n TableScan: test2, full_filters=[test2.b > UInt32(2)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -2814,7 +2803,7 @@ Projection: a, b
\n TableScan: test1\
\n Projection: test2.a, test2.b\
\n TableScan: test2, full_filters=[test2.b > UInt32(2)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -2859,7 +2848,7 @@ Projection: a, b
\n TableScan: test1, full_filters=[test1.b > UInt32(1)]\
\n Projection: test2.a, test2.b\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[derive(Debug)]
@@ -2919,7 +2908,7 @@ Projection: a, b
\n Projection: test1.a, SUM(test1.b), TestScalarUDF() + Int32(1)
AS r\
\n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\
\n TableScan: test1, full_filters=[test1.a > Int32(5)]";
- assert_optimized_plan_eq(&plan, expected_after)
+ assert_optimized_plan_eq(plan, expected_after)
}
#[test]
@@ -2965,6 +2954,6 @@ Projection: a, b
\n Inner Join: test1.a = test2.a\
\n TableScan: test1\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
}
diff --git a/datafusion/optimizer/src/push_down_limit.rs
b/datafusion/optimizer/src/push_down_limit.rs
index cca6c3fd9b..6f1d7bf97c 100644
--- a/datafusion/optimizer/src/push_down_limit.rs
+++ b/datafusion/optimizer/src/push_down_limit.rs
@@ -285,7 +285,7 @@ mod test {
max,
};
- fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) ->
Result<()> {
+ fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) ->
Result<()> {
assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan,
expected)
}
@@ -304,7 +304,7 @@ mod test {
\n Limit: skip=0, fetch=1000\
\n TableScan: test, fetch=1000";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -322,7 +322,7 @@ mod test {
let expected = "Limit: skip=0, fetch=10\
\n TableScan: test, fetch=10";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -339,7 +339,7 @@ mod test {
\n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\
\n TableScan: test";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -359,7 +359,7 @@ mod test {
\n Limit: skip=0, fetch=1000\
\n TableScan: test, fetch=1000";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -376,7 +376,7 @@ mod test {
\n Sort: test.a, fetch=10\
\n TableScan: test";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -393,7 +393,7 @@ mod test {
\n Sort: test.a, fetch=15\
\n TableScan: test";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -412,7 +412,7 @@ mod test {
\n Limit: skip=0, fetch=1000\
\n TableScan: test, fetch=1000";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -427,7 +427,7 @@ mod test {
let expected = "Limit: skip=10, fetch=None\
\n TableScan: test";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -445,7 +445,7 @@ mod test {
\n Limit: skip=10, fetch=1000\
\n TableScan: test, fetch=1010";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -462,7 +462,7 @@ mod test {
\n Limit: skip=10, fetch=990\
\n TableScan: test, fetch=1000";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -479,7 +479,7 @@ mod test {
\n Limit: skip=10, fetch=1000\
\n TableScan: test, fetch=1010";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -495,7 +495,7 @@ mod test {
let expected = "Limit: skip=10, fetch=10\
\n TableScan: test, fetch=20";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -512,7 +512,7 @@ mod test {
\n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\
\n TableScan: test";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -532,7 +532,7 @@ mod test {
\n Limit: skip=0, fetch=1010\
\n TableScan: test, fetch=1010";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -556,7 +556,7 @@ mod test {
\n TableScan: test\
\n TableScan: test2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -580,7 +580,7 @@ mod test {
\n TableScan: test\
\n TableScan: test2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -609,7 +609,7 @@ mod test {
\n Projection: test2.a\
\n TableScan: test2";
- assert_optimized_plan_equal(&outer_query, expected)
+ assert_optimized_plan_equal(outer_query, expected)
}
#[test]
@@ -638,7 +638,7 @@ mod test {
\n Projection: test2.a\
\n TableScan: test2";
- assert_optimized_plan_equal(&outer_query, expected)
+ assert_optimized_plan_equal(outer_query, expected)
}
#[test]
@@ -664,7 +664,7 @@ mod test {
\n Limit: skip=0, fetch=1000\
\n TableScan: test2, fetch=1000";
- assert_optimized_plan_equal(&plan, expected)?;
+ assert_optimized_plan_equal(plan, expected)?;
let plan = LogicalPlanBuilder::from(table_scan_1.clone())
.join(
@@ -683,7 +683,7 @@ mod test {
\n Limit: skip=0, fetch=1000\
\n TableScan: test2, fetch=1000";
- assert_optimized_plan_equal(&plan, expected)?;
+ assert_optimized_plan_equal(plan, expected)?;
let plan = LogicalPlanBuilder::from(table_scan_1.clone())
.join(
@@ -702,7 +702,7 @@ mod test {
\n Limit: skip=0, fetch=1000\
\n TableScan: test2, fetch=1000";
- assert_optimized_plan_equal(&plan, expected)?;
+ assert_optimized_plan_equal(plan, expected)?;
let plan = LogicalPlanBuilder::from(table_scan_1.clone())
.join(
@@ -720,7 +720,7 @@ mod test {
\n TableScan: test, fetch=1000\
\n TableScan: test2";
- assert_optimized_plan_equal(&plan, expected)?;
+ assert_optimized_plan_equal(plan, expected)?;
let plan = LogicalPlanBuilder::from(table_scan_1.clone())
.join(
@@ -738,7 +738,7 @@ mod test {
\n TableScan: test, fetch=1000\
\n TableScan: test2";
- assert_optimized_plan_equal(&plan, expected)?;
+ assert_optimized_plan_equal(plan, expected)?;
let plan = LogicalPlanBuilder::from(table_scan_1.clone())
.join(
@@ -756,7 +756,7 @@ mod test {
\n Limit: skip=0, fetch=1000\
\n TableScan: test2, fetch=1000";
- assert_optimized_plan_equal(&plan, expected)?;
+ assert_optimized_plan_equal(plan, expected)?;
let plan = LogicalPlanBuilder::from(table_scan_1)
.join(
@@ -774,7 +774,7 @@ mod test {
\n Limit: skip=0, fetch=1000\
\n TableScan: test2, fetch=1000";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -799,7 +799,7 @@ mod test {
\n TableScan: test, fetch=1000\
\n TableScan: test2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -824,7 +824,7 @@ mod test {
\n TableScan: test, fetch=1010\
\n TableScan: test2";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -849,7 +849,7 @@ mod test {
\n Limit: skip=0, fetch=1000\
\n TableScan: test2, fetch=1000";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -874,7 +874,7 @@ mod test {
\n Limit: skip=0, fetch=1010\
\n TableScan: test2, fetch=1010";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -894,7 +894,7 @@ mod test {
\n Limit: skip=0, fetch=1000\
\n TableScan: test2, fetch=1000";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -914,7 +914,7 @@ mod test {
\n Limit: skip=0, fetch=2000\
\n TableScan: test2, fetch=2000";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -929,7 +929,7 @@ mod test {
let expected = "Limit: skip=1000, fetch=0\
\n TableScan: test, fetch=0";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -944,7 +944,7 @@ mod test {
let expected = "Limit: skip=1000, fetch=0\
\n TableScan: test, fetch=0";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -961,6 +961,6 @@ mod test {
\n Limit: skip=1000, fetch=0\
\n TableScan: test, fetch=0";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
}
diff --git a/datafusion/optimizer/src/push_down_projection.rs
b/datafusion/optimizer/src/push_down_projection.rs
index ae57ed9e5a..2f578094b3 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -24,7 +24,7 @@ mod tests {
use crate::optimize_projections::OptimizeProjections;
use crate::optimizer::Optimizer;
use crate::test::*;
- use crate::OptimizerContext;
+ use crate::{OptimizerContext, OptimizerRule};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Column, DFSchema, Result};
use datafusion_expr::builder::table_scan_with_filters;
@@ -48,7 +48,7 @@ mod tests {
let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]]\
\n TableScan: test projection=[b]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -62,7 +62,7 @@ mod tests {
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.b)]]\
\n TableScan: test projection=[b, c]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -78,7 +78,7 @@ mod tests {
\n SubqueryAlias: a\
\n TableScan: test projection=[b, c]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -95,7 +95,7 @@ mod tests {
\n Filter: test.c > Int32(1)\
\n TableScan: test projection=[b, c]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -120,7 +120,7 @@ mod tests {
Aggregate: groupBy=[[]], aggr=[[MAX(m4.tag.one) AS tag.one]]\
\n TableScan: m4 projection=[tag.one]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -134,7 +134,7 @@ mod tests {
let expected = "Projection: test.a, test.c, test.b\
\n TableScan: test projection=[a, b, c]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -144,7 +144,7 @@ mod tests {
let plan = table_scan(Some("test"), &schema, Some(vec![1, 0,
2]))?.build()?;
let expected = "TableScan: test projection=[b, a, c]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -157,7 +157,7 @@ mod tests {
let expected = "Projection: test.a, test.b\
\n TableScan: test projection=[b, a]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -170,7 +170,7 @@ mod tests {
let expected = "Projection: test.c, test.b, test.a\
\n TableScan: test projection=[a, b, c]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -192,7 +192,7 @@ mod tests {
\n Filter: test.c > Int32(1)\
\n Projection: test.c, test.b, test.a\
\n TableScan: test projection=[a, b, c]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -212,7 +212,7 @@ mod tests {
\n TableScan: test projection=[a, b]\
\n TableScan: test2 projection=[c1]";
- let optimized_plan = optimize(&plan)?;
+ let optimized_plan = optimize(plan)?;
let formatted_plan = format!("{optimized_plan:?}");
assert_eq!(formatted_plan, expected);
@@ -264,7 +264,7 @@ mod tests {
\n TableScan: test projection=[a, b]\
\n TableScan: test2 projection=[c1]";
- let optimized_plan = optimize(&plan)?;
+ let optimized_plan = optimize(plan)?;
let formatted_plan = format!("{optimized_plan:?}");
assert_eq!(formatted_plan, expected);
@@ -314,7 +314,7 @@ mod tests {
\n TableScan: test projection=[a, b]\
\n TableScan: test2 projection=[a]";
- let optimized_plan = optimize(&plan)?;
+ let optimized_plan = optimize(plan)?;
let formatted_plan = format!("{optimized_plan:?}");
assert_eq!(formatted_plan, expected);
@@ -358,7 +358,7 @@ mod tests {
let expected = "Projection: CAST(test.c AS Float64)\
\n TableScan: test projection=[c]";
- assert_optimized_plan_eq(&projection, expected)
+ assert_optimized_plan_eq(projection, expected)
}
#[test]
@@ -374,7 +374,7 @@ mod tests {
let expected = "TableScan: test projection=[a, b]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -395,7 +395,7 @@ mod tests {
let expected = "TableScan: test projection=[a, b]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -415,7 +415,7 @@ mod tests {
\n Projection: test.c, test.a\
\n TableScan: test projection=[a, c]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -424,7 +424,7 @@ mod tests {
let plan = LogicalPlanBuilder::from(table_scan).build()?;
// should expand projection to all columns without projection
let expected = "TableScan: test projection=[a, b, c]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -435,7 +435,7 @@ mod tests {
.build()?;
let expected = "Projection: Int64(1), Int64(2)\
\n TableScan: test projection=[]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// tests that it removes unused columns in projections
@@ -454,14 +454,14 @@ mod tests {
assert_fields_eq(&plan, vec!["c", "MAX(test.a)"]);
- let plan = optimize(&plan).expect("failed to optimize plan");
+ let plan = optimize(plan).expect("failed to optimize plan");
let expected = "\
Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.a)]]\
\n Filter: test.c > Int32(1)\
\n Projection: test.c, test.a\
\n TableScan: test projection=[a, c]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// tests that it removes un-needed projections
@@ -483,7 +483,7 @@ mod tests {
Projection: Int32(1) AS a\
\n TableScan: test projection=[]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -512,7 +512,7 @@ mod tests {
Projection: Int32(1) AS a\
\n TableScan: test projection=[], full_filters=[b = Int32(1)]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
/// tests that optimizing twice yields same plan
@@ -525,9 +525,9 @@ mod tests {
.project(vec![lit(1).alias("a")])?
.build()?;
- let optimized_plan1 = optimize(&plan).expect("failed to optimize
plan");
+ let optimized_plan1 = optimize(plan).expect("failed to optimize plan");
let optimized_plan2 =
- optimize(&optimized_plan1).expect("failed to optimize plan");
+ optimize(optimized_plan1.clone()).expect("failed to optimize
plan");
let formatted_plan1 = format!("{optimized_plan1:?}");
let formatted_plan2 = format!("{optimized_plan2:?}");
@@ -556,7 +556,7 @@ mod tests {
\n Aggregate: groupBy=[[test.a, test.c]], aggr=[[MAX(test.b)]]\
\n TableScan: test projection=[a, b, c]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -582,7 +582,7 @@ mod tests {
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b),
COUNT(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\
\n TableScan: test projection=[a, b, c]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -599,7 +599,7 @@ mod tests {
\n Distinct:\
\n TableScan: test projection=[a, b]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
#[test]
@@ -638,25 +638,23 @@ mod tests {
\n WindowAggr: windowExpr=[[MAX(test.a) PARTITION BY [test.b]
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\
\n TableScan: test projection=[a, b]";
- assert_optimized_plan_eq(&plan, expected)
+ assert_optimized_plan_eq(plan, expected)
}
- fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) ->
Result<()> {
+ fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) ->
Result<()> {
let optimized_plan = optimize(plan).expect("failed to optimize plan");
let formatted_plan = format!("{optimized_plan:?}");
assert_eq!(formatted_plan, expected);
Ok(())
}
- fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
+ fn optimize(plan: LogicalPlan) -> Result<LogicalPlan> {
let optimizer =
Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]);
- let optimized_plan = optimizer
- .optimize_recursively(
- optimizer.rules.first().unwrap(),
- plan,
- &OptimizerContext::new(),
- )?
- .unwrap_or_else(|| plan.clone());
+ let optimized_plan =
+ optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
+
Ok(optimized_plan)
}
+
+ fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
}
diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs
b/datafusion/optimizer/src/replace_distinct_aggregate.rs
index 752915be69..f464506057 100644
--- a/datafusion/optimizer/src/replace_distinct_aggregate.rs
+++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs
@@ -172,7 +172,7 @@ mod tests {
assert_optimized_plan_eq(
Arc::new(ReplaceDistinctWithAggregate::new()),
- &plan,
+ plan,
expected,
)
}
@@ -195,7 +195,7 @@ mod tests {
assert_optimized_plan_eq(
Arc::new(ReplaceDistinctWithAggregate::new()),
- &plan,
+ plan,
expected,
)
}
diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs
b/datafusion/optimizer/src/scalar_subquery_to_join.rs
index a2c4eabcaa..a8999f9c1d 100644
--- a/datafusion/optimizer/src/scalar_subquery_to_join.rs
+++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs
@@ -429,7 +429,7 @@ mod tests {
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
- &plan,
+ plan,
expected,
);
Ok(())
@@ -485,7 +485,7 @@ mod tests {
\n TableScan: lineitem [l_orderkey:Int64,
l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64,
l_extendedprice:Float64]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
- &plan,
+ plan,
expected,
);
Ok(())
@@ -523,7 +523,7 @@ mod tests {
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
- &plan,
+ plan,
expected,
);
Ok(())
@@ -559,7 +559,7 @@ mod tests {
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64,
o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
- &plan,
+ plan,
expected,
);
Ok(())
@@ -593,7 +593,7 @@ mod tests {
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
- &plan,
+ plan,
expected,
);
Ok(())
@@ -732,7 +732,7 @@ mod tests {
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
- &plan,
+ plan,
expected,
);
Ok(())
@@ -798,7 +798,7 @@ mod tests {
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
- &plan,
+ plan,
expected,
);
Ok(())
@@ -837,7 +837,7 @@ mod tests {
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
- &plan,
+ plan,
expected,
);
Ok(())
@@ -877,7 +877,7 @@ mod tests {
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
- &plan,
+ plan,
expected,
);
Ok(())
@@ -910,7 +910,7 @@ mod tests {
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
- &plan,
+ plan,
expected,
);
Ok(())
@@ -942,7 +942,7 @@ mod tests {
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
- &plan,
+ plan,
expected,
);
Ok(())
@@ -973,7 +973,7 @@ mod tests {
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
- &plan,
+ plan,
expected,
);
Ok(())
@@ -1030,7 +1030,7 @@ mod tests {
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
- &plan,
+ plan,
expected,
);
Ok(())
@@ -1079,7 +1079,7 @@ mod tests {
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
- &plan,
+ plan,
expected,
);
Ok(())
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs
b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 076bf4e242..602994a9e3 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -313,7 +313,7 @@ mod tests {
min, sum, AggregateFunction,
};
- fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) ->
Result<()> {
+ fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) ->
Result<()> {
assert_optimized_plan_eq_display_indent(
Arc::new(SingleDistinctToGroupBy::new()),
plan,
@@ -335,7 +335,7 @@ mod tests {
"Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]]
[MAX(test.b):UInt32;N]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -352,7 +352,7 @@ mod tests {
\n Aggregate: groupBy=[[test.b AS alias1]],
aggr=[[]] [alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
// Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
@@ -373,7 +373,7 @@ mod tests {
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a),
(test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N,
COUNT(DISTINCT test.c):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
// Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
@@ -391,7 +391,7 @@ mod tests {
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]],
aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT
test.c):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
// Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
@@ -410,7 +410,7 @@ mod tests {
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]],
aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT
test.c):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -426,7 +426,7 @@ mod tests {
\n Aggregate: groupBy=[[Int32(2) * test.b AS
alias1]], aggr=[[]] [alias1:Int32]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -443,7 +443,7 @@ mod tests {
\n Aggregate: groupBy=[[test.a, test.b AS
alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -461,7 +461,7 @@ mod tests {
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT
test.b), COUNT(DISTINCT test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N,
COUNT(DISTINCT test.c):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -490,7 +490,7 @@ mod tests {
\n Aggregate: groupBy=[[test.a, test.b AS
alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -508,7 +508,7 @@ mod tests {
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT
test.b), COUNT(test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N,
COUNT(test.c):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -525,7 +525,7 @@ mod tests {
\n Aggregate: groupBy=[[test.a + Int32(1) AS
group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32,
alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -555,7 +555,7 @@ mod tests {
\n Aggregate: groupBy=[[test.a, test.b AS
alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32,
alias2:UInt64;N]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -574,7 +574,7 @@ mod tests {
\n Aggregate: groupBy=[[test.a, test.b AS
alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32,
alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -593,7 +593,7 @@ mod tests {
\n Aggregate: groupBy=[[test.c, test.b AS
alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32,
alias2:UInt32;N]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -616,7 +616,7 @@ mod tests {
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a)
FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32,
SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT
test.b):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -639,7 +639,7 @@ mod tests {
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a),
COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32,
SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a >
Int32(5)):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -662,7 +662,7 @@ mod tests {
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a)
ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY
[test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -685,7 +685,7 @@ mod tests {
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a),
COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N,
COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
#[test]
@@ -708,6 +708,6 @@ mod tests {
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a),
COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]]
[c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a >
Int32(5)) ORDER BY [test.a]:Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal(plan, expected)
}
}
diff --git a/datafusion/optimizer/src/test/mod.rs
b/datafusion/optimizer/src/test/mod.rs
index e691fe9a53..cafda8359a 100644
--- a/datafusion/optimizer/src/test/mod.rs
+++ b/datafusion/optimizer/src/test/mod.rs
@@ -16,7 +16,7 @@
// under the License.
use crate::analyzer::{Analyzer, AnalyzerRule};
-use crate::optimizer::{assert_schema_is_the_same, Optimizer};
+use crate::optimizer::Optimizer;
use crate::{OptimizerContext, OptimizerRule};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::config::ConfigOptions;
@@ -150,22 +150,19 @@ pub fn assert_analyzer_check_err(
}
}
}
+
+fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
+
pub fn assert_optimized_plan_eq(
rule: Arc<dyn OptimizerRule + Send + Sync>,
- plan: &LogicalPlan,
+ plan: LogicalPlan,
expected: &str,
) -> Result<()> {
- let optimizer = Optimizer::with_rules(vec![rule.clone()]);
- let optimized_plan = optimizer
- .optimize_recursively(
- optimizer.rules.first().unwrap(),
- plan,
- &OptimizerContext::new(),
- )?
- .unwrap_or_else(|| plan.clone());
+ // Apply the rule once
+ let opt_context = OptimizerContext::new().with_max_passes(1);
- // Ensure schemas always match after an optimization
- assert_schema_is_the_same(rule.name(), plan, &optimized_plan)?;
+ let optimizer = Optimizer::with_rules(vec![rule.clone()]);
+ let optimized_plan = optimizer.optimize(plan, &opt_context, observe)?;
let formatted_plan = format!("{optimized_plan:?}");
assert_eq!(formatted_plan, expected);
@@ -174,7 +171,7 @@ pub fn assert_optimized_plan_eq(
pub fn assert_optimized_plan_eq_with_rules(
rules: Vec<Arc<dyn OptimizerRule + Send + Sync>>,
- plan: &LogicalPlan,
+ plan: LogicalPlan,
expected: &str,
) -> Result<()> {
fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
@@ -187,58 +184,44 @@ pub fn assert_optimized_plan_eq_with_rules(
.expect("failed to optimize plan");
let formatted_plan = format!("{optimized_plan:?}");
assert_eq!(formatted_plan, expected);
- assert_eq!(plan.schema(), optimized_plan.schema());
Ok(())
}
pub fn assert_optimized_plan_eq_display_indent(
rule: Arc<dyn OptimizerRule + Send + Sync>,
- plan: &LogicalPlan,
+ plan: LogicalPlan,
expected: &str,
) {
let optimizer = Optimizer::with_rules(vec![rule]);
let optimized_plan = optimizer
- .optimize_recursively(
- optimizer.rules.first().unwrap(),
- plan,
- &OptimizerContext::new(),
- )
- .expect("failed to optimize plan")
- .unwrap_or_else(|| plan.clone());
+ .optimize(plan, &OptimizerContext::new(), observe)
+ .expect("failed to optimize plan");
let formatted_plan = optimized_plan.display_indent_schema().to_string();
assert_eq!(formatted_plan, expected);
}
pub fn assert_multi_rules_optimized_plan_eq_display_indent(
rules: Vec<Arc<dyn OptimizerRule + Send + Sync>>,
- plan: &LogicalPlan,
+ plan: LogicalPlan,
expected: &str,
) {
let optimizer = Optimizer::with_rules(rules);
- let mut optimized_plan = plan.clone();
- for rule in &optimizer.rules {
- optimized_plan = optimizer
- .optimize_recursively(rule, &optimized_plan,
&OptimizerContext::new())
- .expect("failed to optimize plan")
- .unwrap_or_else(|| optimized_plan.clone());
- }
+ let optimized_plan = optimizer
+ .optimize(plan, &OptimizerContext::new(), observe)
+ .expect("failed to optimize plan");
let formatted_plan = optimized_plan.display_indent_schema().to_string();
assert_eq!(formatted_plan, expected);
}
pub fn assert_optimizer_err(
rule: Arc<dyn OptimizerRule + Send + Sync>,
- plan: &LogicalPlan,
+ plan: LogicalPlan,
expected: &str,
) {
let optimizer = Optimizer::with_rules(vec![rule]);
- let res = optimizer.optimize_recursively(
- optimizer.rules.first().unwrap(),
- plan,
- &OptimizerContext::new(),
- );
+ let res = optimizer.optimize(plan, &OptimizerContext::new(), observe);
match res {
- Ok(plan) => assert_eq!(format!("{}", plan.unwrap().display_indent()),
"An error"),
+ Ok(plan) => assert_eq!(format!("{}", plan.display_indent()), "An
error"),
Err(ref e) => {
let actual = format!("{e}");
if expected.is_empty() || !actual.contains(expected) {
@@ -250,16 +233,11 @@ pub fn assert_optimizer_err(
pub fn assert_optimization_skipped(
rule: Arc<dyn OptimizerRule + Send + Sync>,
- plan: &LogicalPlan,
+ plan: LogicalPlan,
) -> Result<()> {
let optimizer = Optimizer::with_rules(vec![rule]);
- let new_plan = optimizer
- .optimize_recursively(
- optimizer.rules.first().unwrap(),
- plan,
- &OptimizerContext::new(),
- )?
- .unwrap_or_else(|| plan.clone());
+ let new_plan = optimizer.optimize(plan.clone(), &OptimizerContext::new(),
observe)?;
+
assert_eq!(
format!("{}", plan.display_indent()),
format!("{}", new_plan.display_indent())
diff --git a/datafusion/optimizer/tests/optimizer_integration.rs
b/datafusion/optimizer/tests/optimizer_integration.rs
index 61d2535930..01db5e817c 100644
--- a/datafusion/optimizer/tests/optimizer_integration.rs
+++ b/datafusion/optimizer/tests/optimizer_integration.rs
@@ -25,7 +25,7 @@ use datafusion_common::{plan_err, Result};
use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource,
WindowUDF};
use datafusion_optimizer::analyzer::Analyzer;
use datafusion_optimizer::optimizer::Optimizer;
-use datafusion_optimizer::{OptimizerConfig, OptimizerContext};
+use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule};
use datafusion_sql::planner::{ContextProvider, SqlToRel};
use datafusion_sql::sqlparser::ast::Statement;
use datafusion_sql::sqlparser::dialect::GenericDialect;
@@ -315,9 +315,11 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
let optimizer = Optimizer::new();
// analyze and optimize the logical plan
let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?;
- optimizer.optimize(&plan, &config, |_, _| {})
+ optimizer.optimize(plan, &config, observe)
}
+fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
+
#[derive(Default)]
struct MyContextProvider {
options: ConfigOptions,
diff --git a/datafusion/sqllogictest/test_files/join.slt
b/datafusion/sqllogictest/test_files/join.slt
index da9b4168e7..135ab80754 100644
--- a/datafusion/sqllogictest/test_files/join.slt
+++ b/datafusion/sqllogictest/test_files/join.slt
@@ -587,7 +587,7 @@ FROM t1
----
11 11 11
-# subsequent inner join
+# subsequent inner join
query III rowsort
SELECT t1.t1_id, t2.t2_id, t3.t3_id
FROM t1