This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ff404cd29c Migrate Optimizer tests to insta, part2 (#15884)
ff404cd29c is described below
commit ff404cd29c1ef52d857c8003fe2c199769e0c6d3
Author: Tommy shu <[email protected]>
AuthorDate: Mon Apr 28 21:53:48 2025 -0400
Migrate Optimizer tests to insta, part2 (#15884)
* migrate tests in `replace_distinct_aggregate.rs`
* migrate tests in `replace_distinct_aggregate.rs`
* migrate tests in `push_down_limit.rs`
* migrate tests in `eliminate_duplicated_expr.rs`
* migrate tests in `eliminate_filter.rs`
* migrate tests in `eliminate_group_by_constant.rs` to insta
* migrate tests in `eliminate_join.rs` to use snapshot assertions
* migrate tests in `eliminate_nested_union.rs` to use snapshot assertions
* migrate tests in `eliminate_outer_join.rs` to use snapshot assertions
* migrate tests in `filter_null_join_keys.rs` to use snapshot assertions
* fix Type inference
* fix macro to use crate path for OptimizerContext and Optimizer
* clean up
---
.../optimizer/src/eliminate_duplicated_expr.rs | 39 +-
datafusion/optimizer/src/eliminate_filter.rs | 63 +--
.../optimizer/src/eliminate_group_by_constant.rs | 121 ++----
datafusion/optimizer/src/eliminate_join.rs | 21 +-
datafusion/optimizer/src/eliminate_nested_union.rs | 174 ++++----
datafusion/optimizer/src/eliminate_outer_join.rs | 80 ++--
datafusion/optimizer/src/filter_null_join_keys.rs | 142 +++---
datafusion/optimizer/src/push_down_limit.rs | 482 +++++++++++++--------
.../optimizer/src/replace_distinct_aggregate.rs | 54 ++-
datafusion/optimizer/src/test/mod.rs | 18 +
10 files changed, 700 insertions(+), 494 deletions(-)
diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs
b/datafusion/optimizer/src/eliminate_duplicated_expr.rs
index 4669500920..6a5b29062e 100644
--- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs
+++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs
@@ -118,16 +118,23 @@ impl OptimizerRule for EliminateDuplicatedExpr {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::assert_optimized_plan_eq_snapshot;
use crate::test::*;
use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder};
use std::sync::Arc;
- fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) ->
Result<()> {
- crate::test::assert_optimized_plan_eq(
- Arc::new(EliminateDuplicatedExpr::new()),
- plan,
- expected,
- )
+ macro_rules! assert_optimized_plan_equal {
+ (
+ $plan:expr,
+ @ $expected:literal $(,)?
+ ) => {{
+ let rule: Arc<dyn crate::OptimizerRule + Send + Sync> =
Arc::new(EliminateDuplicatedExpr::new());
+ assert_optimized_plan_eq_snapshot!(
+ rule,
+ $plan,
+ @ $expected,
+ )
+ }};
}
#[test]
@@ -137,10 +144,12 @@ mod tests {
.sort_by(vec![col("a"), col("a"), col("b"), col("c")])?
.limit(5, Some(10))?
.build()?;
- let expected = "Limit: skip=5, fetch=10\
- \n Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC
NULLS LAST\
- \n TableScan: test";
- assert_optimized_plan_eq(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Limit: skip=5, fetch=10
+ Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS
LAST
+ TableScan: test
+ ")
}
#[test]
@@ -156,9 +165,11 @@ mod tests {
.sort(sort_exprs)?
.limit(5, Some(10))?
.build()?;
- let expected = "Limit: skip=5, fetch=10\
- \n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\
- \n TableScan: test";
- assert_optimized_plan_eq(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Limit: skip=5, fetch=10
+ Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST
+ TableScan: test
+ ")
}
}
diff --git a/datafusion/optimizer/src/eliminate_filter.rs
b/datafusion/optimizer/src/eliminate_filter.rs
index 4ed2ac8ba1..db2136e5e4 100644
--- a/datafusion/optimizer/src/eliminate_filter.rs
+++ b/datafusion/optimizer/src/eliminate_filter.rs
@@ -81,17 +81,26 @@ impl OptimizerRule for EliminateFilter {
mod tests {
use std::sync::Arc;
+ use crate::assert_optimized_plan_eq_snapshot;
use datafusion_common::{Result, ScalarValue};
- use datafusion_expr::{
- col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan,
- };
+ use datafusion_expr::{col, lit, logical_plan::builder::LogicalPlanBuilder,
Expr};
use crate::eliminate_filter::EliminateFilter;
use crate::test::*;
use datafusion_expr::test::function_stub::sum;
- fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) ->
Result<()> {
- assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan,
expected)
+ macro_rules! assert_optimized_plan_equal {
+ (
+ $plan:expr,
+ @ $expected:literal $(,)?
+ ) => {{
+ let rule: Arc<dyn crate::OptimizerRule + Send + Sync> =
Arc::new(EliminateFilter::new());
+ assert_optimized_plan_eq_snapshot!(
+ rule,
+ $plan,
+ @ $expected,
+ )
+ }};
}
#[test]
@@ -105,8 +114,7 @@ mod tests {
.build()?;
// No aggregate / scan / limit
- let expected = "EmptyRelation";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @"EmptyRelation")
}
#[test]
@@ -120,8 +128,7 @@ mod tests {
.build()?;
// No aggregate / scan / limit
- let expected = "EmptyRelation";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @"EmptyRelation")
}
#[test]
@@ -139,11 +146,12 @@ mod tests {
.build()?;
// Left side is removed
- let expected = "Union\
- \n EmptyRelation\
- \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
- \n TableScan: test";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Union
+ EmptyRelation
+ Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
+ TableScan: test
+ ")
}
#[test]
@@ -156,9 +164,10 @@ mod tests {
.filter(filter_expr)?
.build()?;
- let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
- \n TableScan: test";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
+ TableScan: test
+ ")
}
#[test]
@@ -176,12 +185,13 @@ mod tests {
.build()?;
// Filter is removed
- let expected = "Union\
- \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
- \n TableScan: test\
- \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
- \n TableScan: test";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Union
+ Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
+ TableScan: test
+ Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
+ TableScan: test
+ ")
}
#[test]
@@ -202,8 +212,9 @@ mod tests {
.build()?;
// Filter is removed
- let expected = "Projection: test.a\
- \n EmptyRelation";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Projection: test.a
+ EmptyRelation
+ ")
}
}
diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs
b/datafusion/optimizer/src/eliminate_group_by_constant.rs
index 7e252d6dce..bd5e691020 100644
--- a/datafusion/optimizer/src/eliminate_group_by_constant.rs
+++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs
@@ -115,6 +115,7 @@ fn is_constant_expression(expr: &Expr) -> bool {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::assert_optimized_plan_eq_snapshot;
use crate::test::*;
use arrow::datatypes::DataType;
@@ -129,6 +130,20 @@ mod tests {
use std::sync::Arc;
+ macro_rules! assert_optimized_plan_equal {
+ (
+ $plan:expr,
+ @ $expected:literal $(,)?
+ ) => {{
+ let rule: Arc<dyn crate::OptimizerRule + Send + Sync> =
Arc::new(EliminateGroupByConstant::new());
+ assert_optimized_plan_eq_snapshot!(
+ rule,
+ $plan,
+ @ $expected,
+ )
+ }};
+ }
+
#[derive(Debug)]
struct ScalarUDFMock {
signature: Signature,
@@ -167,17 +182,11 @@ mod tests {
.aggregate(vec![col("a"), lit(1u32)], vec![count(col("c"))])?
.build()?;
- let expected = "\
- Projection: test.a, UInt32(1), count(test.c)\
- \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
- \n TableScan: test\
- ";
-
- assert_optimized_plan_eq(
- Arc::new(EliminateGroupByConstant::new()),
- plan,
- expected,
- )
+ assert_optimized_plan_equal!(plan, @r"
+ Projection: test.a, UInt32(1), count(test.c)
+ Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
+ TableScan: test
+ ")
}
#[test]
@@ -187,17 +196,11 @@ mod tests {
.aggregate(vec![lit("test"), lit(123u32)], vec![count(col("c"))])?
.build()?;
- let expected = "\
- Projection: Utf8(\"test\"), UInt32(123), count(test.c)\
- \n Aggregate: groupBy=[[]], aggr=[[count(test.c)]]\
- \n TableScan: test\
- ";
-
- assert_optimized_plan_eq(
- Arc::new(EliminateGroupByConstant::new()),
- plan,
- expected,
- )
+ assert_optimized_plan_equal!(plan, @r#"
+ Projection: Utf8("test"), UInt32(123), count(test.c)
+ Aggregate: groupBy=[[]], aggr=[[count(test.c)]]
+ TableScan: test
+ "#)
}
#[test]
@@ -207,16 +210,10 @@ mod tests {
.aggregate(vec![col("a"), col("b")], vec![count(col("c"))])?
.build()?;
- let expected = "\
- Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]\
- \n TableScan: test\
- ";
-
- assert_optimized_plan_eq(
- Arc::new(EliminateGroupByConstant::new()),
- plan,
- expected,
- )
+ assert_optimized_plan_equal!(plan, @r"
+ Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]
+ TableScan: test
+ ")
}
#[test]
@@ -226,16 +223,10 @@ mod tests {
.aggregate(vec![lit(123u32)], Vec::<Expr>::new())?
.build()?;
- let expected = "\
- Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]\
- \n TableScan: test\
- ";
-
- assert_optimized_plan_eq(
- Arc::new(EliminateGroupByConstant::new()),
- plan,
- expected,
- )
+ assert_optimized_plan_equal!(plan, @r"
+ Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]
+ TableScan: test
+ ")
}
#[test]
@@ -248,17 +239,11 @@ mod tests {
)?
.build()?;
- let expected = "\
- Projection: UInt32(123) AS const, test.a, count(test.c)\
- \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
- \n TableScan: test\
- ";
-
- assert_optimized_plan_eq(
- Arc::new(EliminateGroupByConstant::new()),
- plan,
- expected,
- )
+ assert_optimized_plan_equal!(plan, @r"
+ Projection: UInt32(123) AS const, test.a, count(test.c)
+ Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
+ TableScan: test
+ ")
}
#[test]
@@ -273,17 +258,11 @@ mod tests {
.aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])?
.build()?;
- let expected = "\
- Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)\
- \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
- \n TableScan: test\
- ";
-
- assert_optimized_plan_eq(
- Arc::new(EliminateGroupByConstant::new()),
- plan,
- expected,
- )
+ assert_optimized_plan_equal!(plan, @r"
+ Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)
+ Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
+ TableScan: test
+ ")
}
#[test]
@@ -298,15 +277,9 @@ mod tests {
.aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])?
.build()?;
- let expected = "\
- Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]],
aggr=[[count(test.c)]]\
- \n TableScan: test\
- ";
-
- assert_optimized_plan_eq(
- Arc::new(EliminateGroupByConstant::new()),
- plan,
- expected,
- )
+ assert_optimized_plan_equal!(plan, @r"
+ Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]],
aggr=[[count(test.c)]]
+ TableScan: test
+ ")
}
}
diff --git a/datafusion/optimizer/src/eliminate_join.rs
b/datafusion/optimizer/src/eliminate_join.rs
index 789235595d..bac82a2ee1 100644
--- a/datafusion/optimizer/src/eliminate_join.rs
+++ b/datafusion/optimizer/src/eliminate_join.rs
@@ -74,15 +74,25 @@ impl OptimizerRule for EliminateJoin {
#[cfg(test)]
mod tests {
+ use crate::assert_optimized_plan_eq_snapshot;
use crate::eliminate_join::EliminateJoin;
- use crate::test::*;
use datafusion_common::Result;
use datafusion_expr::JoinType::Inner;
- use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder,
LogicalPlan};
+ use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
use std::sync::Arc;
- fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) ->
Result<()> {
- assert_optimized_plan_eq(Arc::new(EliminateJoin::new()), plan,
expected)
+ macro_rules! assert_optimized_plan_equal {
+ (
+ $plan:expr,
+ @$expected:literal $(,)?
+ ) => {{
+ let rule: Arc<dyn crate::OptimizerRule + Send + Sync> =
Arc::new(EliminateJoin::new());
+ assert_optimized_plan_eq_snapshot!(
+ rule,
+ $plan,
+ @ $expected,
+ )
+ }};
}
#[test]
@@ -95,7 +105,6 @@ mod tests {
)?
.build()?;
- let expected = "EmptyRelation";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @"EmptyRelation")
}
}
diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs
b/datafusion/optimizer/src/eliminate_nested_union.rs
index 94da08243d..fe835afbaa 100644
--- a/datafusion/optimizer/src/eliminate_nested_union.rs
+++ b/datafusion/optimizer/src/eliminate_nested_union.rs
@@ -116,7 +116,7 @@ mod tests {
use super::*;
use crate::analyzer::type_coercion::TypeCoercion;
use crate::analyzer::Analyzer;
- use crate::test::*;
+ use crate::assert_optimized_plan_eq_snapshot;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{col, logical_plan::table_scan};
@@ -129,15 +129,21 @@ mod tests {
])
}
- fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) ->
Result<()> {
- let options = ConfigOptions::default();
- let analyzed_plan =
Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
- .execute_and_check(plan, &options, |_, _| {})?;
- assert_optimized_plan_eq(
- Arc::new(EliminateNestedUnion::new()),
- analyzed_plan,
- expected,
- )
+ macro_rules! assert_optimized_plan_equal {
+ (
+ $plan:expr,
+ @ $expected:literal $(,)?
+ ) => {{
+ let options = ConfigOptions::default();
+ let analyzed_plan =
Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
+ .execute_and_check($plan, &options, |_, _| {})?;
+ let rule: Arc<dyn crate::OptimizerRule + Send + Sync> =
Arc::new(EliminateNestedUnion::new());
+ assert_optimized_plan_eq_snapshot!(
+ rule,
+ analyzed_plan,
+ @ $expected,
+ )
+ }};
}
#[test]
@@ -146,11 +152,11 @@ mod tests {
let plan = plan_builder.clone().union(plan_builder.build()?)?.build()?;
- let expected = "\
- Union\
- \n TableScan: table\
- \n TableScan: table";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Union
+ TableScan: table
+ TableScan: table
+ ")
}
#[test]
@@ -162,11 +168,12 @@ mod tests {
.union_distinct(plan_builder.build()?)?
.build()?;
- let expected = "Distinct:\
- \n Union\
- \n TableScan: table\
- \n TableScan: table";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Distinct:
+ Union
+ TableScan: table
+ TableScan: table
+ ")
}
#[test]
@@ -180,13 +187,13 @@ mod tests {
.union(plan_builder.build()?)?
.build()?;
- let expected = "\
- Union\
- \n TableScan: table\
- \n TableScan: table\
- \n TableScan: table\
- \n TableScan: table";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Union
+ TableScan: table
+ TableScan: table
+ TableScan: table
+ TableScan: table
+ ")
}
#[test]
@@ -200,14 +207,15 @@ mod tests {
.union(plan_builder.build()?)?
.build()?;
- let expected = "Union\
- \n Distinct:\
- \n Union\
- \n TableScan: table\
- \n TableScan: table\
- \n TableScan: table\
- \n TableScan: table";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Union
+ Distinct:
+ Union
+ TableScan: table
+ TableScan: table
+ TableScan: table
+ TableScan: table
+ ")
}
#[test]
@@ -222,14 +230,15 @@ mod tests {
.union_distinct(plan_builder.build()?)?
.build()?;
- let expected = "Distinct:\
- \n Union\
- \n TableScan: table\
- \n TableScan: table\
- \n TableScan: table\
- \n TableScan: table\
- \n TableScan: table";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Distinct:
+ Union
+ TableScan: table
+ TableScan: table
+ TableScan: table
+ TableScan: table
+ TableScan: table
+ ")
}
#[test]
@@ -243,13 +252,14 @@ mod tests {
.union_distinct(plan_builder.build()?)?
.build()?;
- let expected = "Distinct:\
- \n Union\
- \n TableScan: table\
- \n TableScan: table\
- \n TableScan: table\
- \n TableScan: table";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Distinct:
+ Union
+ TableScan: table
+ TableScan: table
+ TableScan: table
+ TableScan: table
+ ")
}
// We don't need to use project_with_column_index in logical optimizer,
@@ -273,13 +283,14 @@ mod tests {
)?
.build()?;
- let expected = "Union\
- \n TableScan: table\
- \n Projection: table.id AS id, table.key, table.value\
- \n TableScan: table\
- \n Projection: table.id AS id, table.key, table.value\
- \n TableScan: table";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Union
+ TableScan: table
+ Projection: table.id AS id, table.key, table.value
+ TableScan: table
+ Projection: table.id AS id, table.key, table.value
+ TableScan: table
+ ")
}
#[test]
@@ -301,14 +312,15 @@ mod tests {
)?
.build()?;
- let expected = "Distinct:\
- \n Union\
- \n TableScan: table\
- \n Projection: table.id AS id, table.key, table.value\
- \n TableScan: table\
- \n Projection: table.id AS id, table.key, table.value\
- \n TableScan: table";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Distinct:
+ Union
+ TableScan: table
+ Projection: table.id AS id, table.key, table.value
+ TableScan: table
+ Projection: table.id AS id, table.key, table.value
+ TableScan: table
+ ")
}
#[test]
@@ -348,13 +360,14 @@ mod tests {
.union(table_3.build()?)?
.build()?;
- let expected = "Union\
- \n TableScan: table_1\
- \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key,
CAST(table_1.value AS Float64) AS value\
- \n TableScan: table_1\
- \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key,
CAST(table_1.value AS Float64) AS value\
- \n TableScan: table_1";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Union
+ TableScan: table_1
+ Projection: CAST(table_1.id AS Int64) AS id, table_1.key,
CAST(table_1.value AS Float64) AS value
+ TableScan: table_1
+ Projection: CAST(table_1.id AS Int64) AS id, table_1.key,
CAST(table_1.value AS Float64) AS value
+ TableScan: table_1
+ ")
}
#[test]
@@ -394,13 +407,14 @@ mod tests {
.union_distinct(table_3.build()?)?
.build()?;
- let expected = "Distinct:\
- \n Union\
- \n TableScan: table_1\
- \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key,
CAST(table_1.value AS Float64) AS value\
- \n TableScan: table_1\
- \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key,
CAST(table_1.value AS Float64) AS value\
- \n TableScan: table_1";
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Distinct:
+ Union
+ TableScan: table_1
+ Projection: CAST(table_1.id AS Int64) AS id, table_1.key,
CAST(table_1.value AS Float64) AS value
+ TableScan: table_1
+ Projection: CAST(table_1.id AS Int64) AS id, table_1.key,
CAST(table_1.value AS Float64) AS value
+ TableScan: table_1
+ ")
}
}
diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs
b/datafusion/optimizer/src/eliminate_outer_join.rs
index 1ecb32ca2a..704a9e7e53 100644
--- a/datafusion/optimizer/src/eliminate_outer_join.rs
+++ b/datafusion/optimizer/src/eliminate_outer_join.rs
@@ -304,6 +304,7 @@ fn extract_non_nullable_columns(
#[cfg(test)]
mod tests {
use super::*;
+ use crate::assert_optimized_plan_eq_snapshot;
use crate::test::*;
use arrow::datatypes::DataType;
use datafusion_expr::{
@@ -313,8 +314,18 @@ mod tests {
Operator::{And, Or},
};
- fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) ->
Result<()> {
- assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan,
expected)
+ macro_rules! assert_optimized_plan_equal {
+ (
+ $plan:expr,
+ @$expected:literal $(,)?
+ ) => {{
+ let rule: Arc<dyn crate::OptimizerRule + Send + Sync> =
Arc::new(EliminateOuterJoin::new());
+ assert_optimized_plan_eq_snapshot!(
+ rule,
+ $plan,
+ @ $expected,
+ )
+ }};
}
#[test]
@@ -332,12 +343,13 @@ mod tests {
)?
.filter(col("t2.b").is_null())?
.build()?;
- let expected = "\
- Filter: t2.b IS NULL\
- \n Left Join: t1.a = t2.a\
- \n TableScan: t1\
- \n TableScan: t2";
- assert_optimized_plan_equal(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Filter: t2.b IS NULL
+ Left Join: t1.a = t2.a
+ TableScan: t1
+ TableScan: t2
+ ")
}
#[test]
@@ -355,12 +367,13 @@ mod tests {
)?
.filter(col("t2.b").is_not_null())?
.build()?;
- let expected = "\
- Filter: t2.b IS NOT NULL\
- \n Inner Join: t1.a = t2.a\
- \n TableScan: t1\
- \n TableScan: t2";
- assert_optimized_plan_equal(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Filter: t2.b IS NOT NULL
+ Inner Join: t1.a = t2.a
+ TableScan: t1
+ TableScan: t2
+ ")
}
#[test]
@@ -382,12 +395,13 @@ mod tests {
col("t1.c").lt(lit(20u32)),
))?
.build()?;
- let expected = "\
- Filter: t1.b > UInt32(10) OR t1.c < UInt32(20)\
- \n Inner Join: t1.a = t2.a\
- \n TableScan: t1\
- \n TableScan: t2";
- assert_optimized_plan_equal(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Filter: t1.b > UInt32(10) OR t1.c < UInt32(20)
+ Inner Join: t1.a = t2.a
+ TableScan: t1
+ TableScan: t2
+ ")
}
#[test]
@@ -409,12 +423,13 @@ mod tests {
col("t2.c").lt(lit(20u32)),
))?
.build()?;
- let expected = "\
- Filter: t1.b > UInt32(10) AND t2.c < UInt32(20)\
- \n Inner Join: t1.a = t2.a\
- \n TableScan: t1\
- \n TableScan: t2";
- assert_optimized_plan_equal(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Filter: t1.b > UInt32(10) AND t2.c < UInt32(20)
+ Inner Join: t1.a = t2.a
+ TableScan: t1
+ TableScan: t2
+ ")
}
#[test]
@@ -436,11 +451,12 @@ mod tests {
try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)),
))?
.build()?;
- let expected = "\
- Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) <
UInt32(20)\
- \n Inner Join: t1.a = t2.a\
- \n TableScan: t1\
- \n TableScan: t2";
- assert_optimized_plan_equal(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) <
UInt32(20)
+ Inner Join: t1.a = t2.a
+ TableScan: t1
+ TableScan: t2
+ ")
}
}
diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs
b/datafusion/optimizer/src/filter_null_join_keys.rs
index 2e7a751ca4..314b439cb5 100644
--- a/datafusion/optimizer/src/filter_null_join_keys.rs
+++ b/datafusion/optimizer/src/filter_null_join_keys.rs
@@ -107,35 +107,49 @@ fn create_not_null_predicate(filters: Vec<Expr>) -> Expr {
#[cfg(test)]
mod tests {
use super::*;
- use crate::test::assert_optimized_plan_eq;
+ use crate::assert_optimized_plan_eq_snapshot;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Column;
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{col, lit, JoinType, LogicalPlanBuilder};
- fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) ->
Result<()> {
- assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan,
expected)
+ macro_rules! assert_optimized_plan_equal {
+ (
+ $plan:expr,
+ @ $expected:literal $(,)?
+ ) => {{
+ let rule: Arc<dyn crate::OptimizerRule + Send + Sync> =
Arc::new(FilterNullJoinKeys {});
+ assert_optimized_plan_eq_snapshot!(
+ rule,
+ $plan,
+ @ $expected,
+ )
+ }};
}
#[test]
fn left_nullable() -> Result<()> {
let (t1, t2) = test_tables()?;
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id",
JoinType::Inner)?;
- let expected = "Inner Join: t1.optional_id = t2.id\
- \n Filter: t1.optional_id IS NOT NULL\
- \n TableScan: t1\
- \n TableScan: t2";
- assert_optimized_plan_equal(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Inner Join: t1.optional_id = t2.id
+ Filter: t1.optional_id IS NOT NULL
+ TableScan: t1
+ TableScan: t2
+ ")
}
#[test]
fn left_nullable_left_join() -> Result<()> {
let (t1, t2) = test_tables()?;
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id",
JoinType::Left)?;
- let expected = "Left Join: t1.optional_id = t2.id\
- \n TableScan: t1\
- \n TableScan: t2";
- assert_optimized_plan_equal(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Left Join: t1.optional_id = t2.id
+ TableScan: t1
+ TableScan: t2
+ ")
}
#[test]
@@ -144,22 +158,26 @@ mod tests {
// Note: order of tables is reversed
let plan =
build_plan(t_right, t_left, "t2.id", "t1.optional_id",
JoinType::Left)?;
- let expected = "Left Join: t2.id = t1.optional_id\
- \n TableScan: t2\
- \n Filter: t1.optional_id IS NOT NULL\
- \n TableScan: t1";
- assert_optimized_plan_equal(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Left Join: t2.id = t1.optional_id
+ TableScan: t2
+ Filter: t1.optional_id IS NOT NULL
+ TableScan: t1
+ ")
}
#[test]
fn left_nullable_on_condition_reversed() -> Result<()> {
let (t1, t2) = test_tables()?;
let plan = build_plan(t1, t2, "t2.id", "t1.optional_id",
JoinType::Inner)?;
- let expected = "Inner Join: t1.optional_id = t2.id\
- \n Filter: t1.optional_id IS NOT NULL\
- \n TableScan: t1\
- \n TableScan: t2";
- assert_optimized_plan_equal(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Inner Join: t1.optional_id = t2.id
+ Filter: t1.optional_id IS NOT NULL
+ TableScan: t1
+ TableScan: t2
+ ")
}
#[test]
@@ -189,14 +207,16 @@ mod tests {
None,
)?
.build()?;
- let expected = "Inner Join: t3.t1_id = t1.id, t3.t2_id = t2.id\
- \n Filter: t3.t1_id IS NOT NULL AND t3.t2_id IS NOT NULL\
- \n TableScan: t3\
- \n Inner Join: t1.optional_id = t2.id\
- \n Filter: t1.optional_id IS NOT NULL\
- \n TableScan: t1\
- \n TableScan: t2";
- assert_optimized_plan_equal(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Inner Join: t3.t1_id = t1.id, t3.t2_id = t2.id
+ Filter: t3.t1_id IS NOT NULL AND t3.t2_id IS NOT NULL
+ TableScan: t3
+ Inner Join: t1.optional_id = t2.id
+ Filter: t1.optional_id IS NOT NULL
+ TableScan: t1
+ TableScan: t2
+ ")
}
#[test]
@@ -213,11 +233,13 @@ mod tests {
None,
)?
.build()?;
- let expected = "Inner Join: t1.optional_id + UInt32(1) = t2.id +
UInt32(1)\
- \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\
- \n TableScan: t1\
- \n TableScan: t2";
- assert_optimized_plan_equal(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Inner Join: t1.optional_id + UInt32(1) = t2.id + UInt32(1)
+ Filter: t1.optional_id + UInt32(1) IS NOT NULL
+ TableScan: t1
+ TableScan: t2
+ ")
}
#[test]
@@ -234,11 +256,13 @@ mod tests {
None,
)?
.build()?;
- let expected = "Inner Join: t1.id + UInt32(1) = t2.optional_id +
UInt32(1)\
- \n TableScan: t1\
- \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\
- \n TableScan: t2";
- assert_optimized_plan_equal(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Inner Join: t1.id + UInt32(1) = t2.optional_id + UInt32(1)
+ TableScan: t1
+ Filter: t2.optional_id + UInt32(1) IS NOT NULL
+ TableScan: t2
+ ")
}
#[test]
@@ -255,13 +279,14 @@ mod tests {
None,
)?
.build()?;
- let expected =
- "Inner Join: t1.optional_id + UInt32(1) = t2.optional_id +
UInt32(1)\
- \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\
- \n TableScan: t1\
- \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\
- \n TableScan: t2";
- assert_optimized_plan_equal(plan, expected)
+
+ assert_optimized_plan_equal!(plan, @r"
+ Inner Join: t1.optional_id + UInt32(1) = t2.optional_id + UInt32(1)
+ Filter: t1.optional_id + UInt32(1) IS NOT NULL
+ TableScan: t1
+ Filter: t2.optional_id + UInt32(1) IS NOT NULL
+ TableScan: t2
+ ")
}
#[test]
@@ -283,13 +308,22 @@ mod tests {
None,
)?
.build()?;
- let expected = "Inner Join: t1.optional_id = t2.optional_id\
- \n Filter: t1.optional_id IS NOT NULL\
- \n TableScan: t1\
- \n Filter: t2.optional_id IS NOT NULL\
- \n TableScan: t2";
- assert_optimized_plan_equal(plan_from_cols, expected)?;
- assert_optimized_plan_equal(plan_from_exprs, expected)
+
+ assert_optimized_plan_equal!(plan_from_cols, @r"
+ Inner Join: t1.optional_id = t2.optional_id
+ Filter: t1.optional_id IS NOT NULL
+ TableScan: t1
+ Filter: t2.optional_id IS NOT NULL
+ TableScan: t2
+ ")?;
+
+ assert_optimized_plan_equal!(plan_from_exprs, @r"
+ Inner Join: t1.optional_id = t2.optional_id
+ Filter: t1.optional_id IS NOT NULL
+ TableScan: t1
+ Filter: t2.optional_id IS NOT NULL
+ TableScan: t2
+ ")
}
fn build_plan(
diff --git a/datafusion/optimizer/src/push_down_limit.rs
b/datafusion/optimizer/src/push_down_limit.rs
index 1e9ef16bde..0ed4e05d85 100644
--- a/datafusion/optimizer/src/push_down_limit.rs
+++ b/datafusion/optimizer/src/push_down_limit.rs
@@ -276,6 +276,7 @@ mod test {
use std::vec;
use super::*;
+ use crate::assert_optimized_plan_eq_snapshot;
use crate::test::*;
use datafusion_common::DFSchemaRef;
@@ -285,8 +286,18 @@ mod test {
};
use datafusion_functions_aggregate::expr_fn::max;
- fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) ->
Result<()> {
- assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan,
expected)
+ macro_rules! assert_optimized_plan_equal {
+ (
+ $plan:expr,
+ @ $expected:literal $(,)?
+ ) => {{
+ let rule: Arc<dyn crate::OptimizerRule + Send + Sync> =
Arc::new(PushDownLimit::new());
+ assert_optimized_plan_eq_snapshot!(
+ rule,
+ $plan,
+ @ $expected,
+ )
+ }};
}
#[derive(Debug, PartialEq, Eq, Hash)]
@@ -408,12 +419,15 @@ mod test {
.limit(0, Some(1000))?
.build()?;
- let expected = "Limit: skip=0, fetch=1000\
- \n NoopPlan\
- \n Limit: skip=0, fetch=1000\
- \n TableScan: test, fetch=1000";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=0, fetch=1000
+ NoopPlan
+ Limit: skip=0, fetch=1000
+ TableScan: test, fetch=1000
+ "
+ )
}
#[test]
@@ -430,12 +444,15 @@ mod test {
.limit(10, Some(1000))?
.build()?;
- let expected = "Limit: skip=10, fetch=1000\
- \n NoopPlan\
- \n Limit: skip=0, fetch=1010\
- \n TableScan: test, fetch=1010";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=10, fetch=1000
+ NoopPlan
+ Limit: skip=0, fetch=1010
+ TableScan: test, fetch=1010
+ "
+ )
}
#[test]
@@ -453,12 +470,15 @@ mod test {
.limit(20, Some(500))?
.build()?;
- let expected = "Limit: skip=30, fetch=500\
- \n NoopPlan\
- \n Limit: skip=0, fetch=530\
- \n TableScan: test, fetch=530";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=30, fetch=500
+ NoopPlan
+ Limit: skip=0, fetch=530
+ TableScan: test, fetch=530
+ "
+ )
}
#[test]
@@ -475,14 +495,17 @@ mod test {
.limit(0, Some(1000))?
.build()?;
- let expected = "Limit: skip=0, fetch=1000\
- \n NoopPlan\
- \n Limit: skip=0, fetch=1000\
- \n TableScan: test, fetch=1000\
- \n Limit: skip=0, fetch=1000\
- \n TableScan: test, fetch=1000";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=0, fetch=1000
+ NoopPlan
+ Limit: skip=0, fetch=1000
+ TableScan: test, fetch=1000
+ Limit: skip=0, fetch=1000
+ TableScan: test, fetch=1000
+ "
+ )
}
#[test]
@@ -499,11 +522,14 @@ mod test {
.limit(0, Some(1000))?
.build()?;
- let expected = "Limit: skip=0, fetch=1000\
- \n NoLimitNoopPlan\
- \n TableScan: test";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=0, fetch=1000
+ NoLimitNoopPlan
+ TableScan: test
+ "
+ )
}
#[test]
@@ -517,11 +543,14 @@ mod test {
// Should push the limit down to table provider
// When it has a select
- let expected = "Projection: test.a\
- \n Limit: skip=0, fetch=1000\
- \n TableScan: test, fetch=1000";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Projection: test.a
+ Limit: skip=0, fetch=1000
+ TableScan: test, fetch=1000
+ "
+ )
}
#[test]
@@ -536,10 +565,13 @@ mod test {
// Should push down the smallest limit
// Towards table scan
// This rule doesn't replace multiple limits
- let expected = "Limit: skip=0, fetch=10\
- \n TableScan: test, fetch=10";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=0, fetch=10
+ TableScan: test, fetch=10
+ "
+ )
}
#[test]
@@ -552,11 +584,14 @@ mod test {
.build()?;
// Limit should *not* push down aggregate node
- let expected = "Limit: skip=0, fetch=1000\
- \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\
- \n TableScan: test";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=0, fetch=1000
+ Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]
+ TableScan: test
+ "
+ )
}
#[test]
@@ -569,14 +604,17 @@ mod test {
.build()?;
// Limit should push down through union
- let expected = "Limit: skip=0, fetch=1000\
- \n Union\
- \n Limit: skip=0, fetch=1000\
- \n TableScan: test, fetch=1000\
- \n Limit: skip=0, fetch=1000\
- \n TableScan: test, fetch=1000";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=0, fetch=1000
+ Union
+ Limit: skip=0, fetch=1000
+ TableScan: test, fetch=1000
+ Limit: skip=0, fetch=1000
+ TableScan: test, fetch=1000
+ "
+ )
}
#[test]
@@ -589,11 +627,14 @@ mod test {
.build()?;
// Should push down limit to sort
- let expected = "Limit: skip=0, fetch=10\
- \n Sort: test.a ASC NULLS LAST, fetch=10\
- \n TableScan: test";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=0, fetch=10
+ Sort: test.a ASC NULLS LAST, fetch=10
+ TableScan: test
+ "
+ )
}
#[test]
@@ -606,11 +647,14 @@ mod test {
.build()?;
// Should push down limit to sort
- let expected = "Limit: skip=5, fetch=10\
- \n Sort: test.a ASC NULLS LAST, fetch=15\
- \n TableScan: test";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=5, fetch=10
+ Sort: test.a ASC NULLS LAST, fetch=15
+ TableScan: test
+ "
+ )
}
#[test]
@@ -624,12 +668,15 @@ mod test {
.build()?;
// Limit should use deeper LIMIT 1000, but Limit 10 shouldn't push down aggregation
- let expected = "Limit: skip=0, fetch=10\
- \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\
- \n Limit: skip=0, fetch=1000\
- \n TableScan: test, fetch=1000";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=0, fetch=10
+ Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]
+ Limit: skip=0, fetch=1000
+ TableScan: test, fetch=1000
+ "
+ )
}
#[test]
@@ -641,10 +688,13 @@ mod test {
// Should not push any limit down to table provider
// When it has a select
- let expected = "Limit: skip=10, fetch=None\
- \n TableScan: test";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=10, fetch=None
+ TableScan: test
+ "
+ )
}
#[test]
@@ -658,11 +708,14 @@ mod test {
// Should push the limit down to table provider
// When it has a select
- let expected = "Projection: test.a\
- \n Limit: skip=10, fetch=1000\
- \n TableScan: test, fetch=1010";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Projection: test.a
+ Limit: skip=10, fetch=1000
+ TableScan: test, fetch=1010
+ "
+ )
}
#[test]
@@ -675,11 +728,14 @@ mod test {
.limit(10, None)?
.build()?;
- let expected = "Projection: test.a\
- \n Limit: skip=10, fetch=990\
- \n TableScan: test, fetch=1000";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Projection: test.a
+ Limit: skip=10, fetch=990
+ TableScan: test, fetch=1000
+ "
+ )
}
#[test]
@@ -692,11 +748,14 @@ mod test {
.limit(0, Some(1000))?
.build()?;
- let expected = "Projection: test.a\
- \n Limit: skip=10, fetch=1000\
- \n TableScan: test, fetch=1010";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Projection: test.a
+ Limit: skip=10, fetch=1000
+ TableScan: test, fetch=1010
+ "
+ )
}
#[test]
@@ -709,10 +768,13 @@ mod test {
.limit(0, Some(10))?
.build()?;
- let expected = "Limit: skip=10, fetch=10\
- \n TableScan: test, fetch=20";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=10, fetch=10
+ TableScan: test, fetch=20
+ "
+ )
}
#[test]
@@ -725,11 +787,14 @@ mod test {
.build()?;
// Limit should *not* push down aggregate node
- let expected = "Limit: skip=10, fetch=1000\
- \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\
- \n TableScan: test";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=10, fetch=1000
+ Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]
+ TableScan: test
+ "
+ )
}
#[test]
@@ -742,14 +807,17 @@ mod test {
.build()?;
// Limit should push down through union
- let expected = "Limit: skip=10, fetch=1000\
- \n Union\
- \n Limit: skip=0, fetch=1010\
- \n TableScan: test, fetch=1010\
- \n Limit: skip=0, fetch=1010\
- \n TableScan: test, fetch=1010";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=10, fetch=1000
+ Union
+ Limit: skip=0, fetch=1010
+ TableScan: test, fetch=1010
+ Limit: skip=0, fetch=1010
+ TableScan: test, fetch=1010
+ "
+ )
}
#[test]
@@ -768,12 +836,15 @@ mod test {
.build()?;
// Limit pushdown Not supported in Join
- let expected = "Limit: skip=10, fetch=1000\
- \n Inner Join: test.a = test2.a\
- \n TableScan: test\
- \n TableScan: test2";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=10, fetch=1000
+ Inner Join: test.a = test2.a
+ TableScan: test
+ TableScan: test2
+ "
+ )
}
#[test]
@@ -792,12 +863,15 @@ mod test {
.build()?;
// Limit pushdown Not supported in Join
- let expected = "Limit: skip=10, fetch=1000\
- \n Inner Join: test.a = test2.a\
- \n TableScan: test\
- \n TableScan: test2";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=10, fetch=1000
+ Inner Join: test.a = test2.a
+ TableScan: test
+ TableScan: test2
+ "
+ )
}
#[test]
@@ -817,16 +891,19 @@ mod test {
.build()?;
// Limit pushdown Not supported in sub_query
- let expected = "Limit: skip=10, fetch=100\
- \n Filter: EXISTS (<subquery>)\
- \n Subquery:\
- \n Filter: test1.a = test1.a\
- \n Projection: test1.a\
- \n TableScan: test1\
- \n Projection: test2.a\
- \n TableScan: test2";
-
- assert_optimized_plan_equal(outer_query, expected)
+ assert_optimized_plan_equal!(
+ outer_query,
+ @r"
+ Limit: skip=10, fetch=100
+ Filter: EXISTS (<subquery>)
+ Subquery:
+ Filter: test1.a = test1.a
+ Projection: test1.a
+ TableScan: test1
+ Projection: test2.a
+ TableScan: test2
+ "
+ )
}
#[test]
@@ -846,16 +923,19 @@ mod test {
.build()?;
// Limit pushdown Not supported in sub_query
- let expected = "Limit: skip=10, fetch=100\
- \n Filter: EXISTS (<subquery>)\
- \n Subquery:\
- \n Filter: test1.a = test1.a\
- \n Projection: test1.a\
- \n TableScan: test1\
- \n Projection: test2.a\
- \n TableScan: test2";
-
- assert_optimized_plan_equal(outer_query, expected)
+ assert_optimized_plan_equal!(
+ outer_query,
+ @r"
+ Limit: skip=10, fetch=100
+ Filter: EXISTS (<subquery>)
+ Subquery:
+ Filter: test1.a = test1.a
+ Projection: test1.a
+ TableScan: test1
+ Projection: test2.a
+ TableScan: test2
+ "
+ )
}
#[test]
@@ -874,13 +954,16 @@ mod test {
.build()?;
// Limit pushdown Not supported in Join
- let expected = "Limit: skip=10, fetch=1000\
- \n Left Join: test.a = test2.a\
- \n Limit: skip=0, fetch=1010\
- \n TableScan: test, fetch=1010\
- \n TableScan: test2";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=10, fetch=1000
+ Left Join: test.a = test2.a
+ Limit: skip=0, fetch=1010
+ TableScan: test, fetch=1010
+ TableScan: test2
+ "
+ )
}
#[test]
@@ -899,13 +982,16 @@ mod test {
.build()?;
// Limit pushdown Not supported in Join
- let expected = "Limit: skip=0, fetch=1000\
- \n Right Join: test.a = test2.a\
- \n TableScan: test\
- \n Limit: skip=0, fetch=1000\
- \n TableScan: test2, fetch=1000";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=0, fetch=1000
+ Right Join: test.a = test2.a
+ TableScan: test
+ Limit: skip=0, fetch=1000
+ TableScan: test2, fetch=1000
+ "
+ )
}
#[test]
@@ -924,13 +1010,16 @@ mod test {
.build()?;
// Limit pushdown with offset supported in right outer join
- let expected = "Limit: skip=10, fetch=1000\
- \n Right Join: test.a = test2.a\
- \n TableScan: test\
- \n Limit: skip=0, fetch=1010\
- \n TableScan: test2, fetch=1010";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=10, fetch=1000
+ Right Join: test.a = test2.a
+ TableScan: test
+ Limit: skip=0, fetch=1010
+ TableScan: test2, fetch=1010
+ "
+ )
}
#[test]
@@ -943,14 +1032,17 @@ mod test {
.limit(0, Some(1000))?
.build()?;
- let expected = "Limit: skip=0, fetch=1000\
- \n Cross Join: \
- \n Limit: skip=0, fetch=1000\
- \n TableScan: test, fetch=1000\
- \n Limit: skip=0, fetch=1000\
- \n TableScan: test2, fetch=1000";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=0, fetch=1000
+ Cross Join:
+ Limit: skip=0, fetch=1000
+ TableScan: test, fetch=1000
+ Limit: skip=0, fetch=1000
+ TableScan: test2, fetch=1000
+ "
+ )
}
#[test]
@@ -963,14 +1055,17 @@ mod test {
.limit(1000, Some(1000))?
.build()?;
- let expected = "Limit: skip=1000, fetch=1000\
- \n Cross Join: \
- \n Limit: skip=0, fetch=2000\
- \n TableScan: test, fetch=2000\
- \n Limit: skip=0, fetch=2000\
- \n TableScan: test2, fetch=2000";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=1000, fetch=1000
+ Cross Join:
+ Limit: skip=0, fetch=2000
+ TableScan: test, fetch=2000
+ Limit: skip=0, fetch=2000
+ TableScan: test2, fetch=2000
+ "
+ )
}
#[test]
@@ -982,10 +1077,13 @@ mod test {
.limit(1000, None)?
.build()?;
- let expected = "Limit: skip=1000, fetch=0\
- \n TableScan: test, fetch=0";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=1000, fetch=0
+ TableScan: test, fetch=0
+ "
+ )
}
#[test]
@@ -997,10 +1095,13 @@ mod test {
.limit(1000, None)?
.build()?;
- let expected = "Limit: skip=1000, fetch=0\
- \n TableScan: test, fetch=0";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ Limit: skip=1000, fetch=0
+ TableScan: test, fetch=0
+ "
+ )
}
#[test]
@@ -1013,10 +1114,13 @@ mod test {
.limit(1000, None)?
.build()?;
- let expected = "SubqueryAlias: a\
- \n Limit: skip=1000, fetch=0\
- \n TableScan: test, fetch=0";
-
- assert_optimized_plan_equal(plan, expected)
+ assert_optimized_plan_equal!(
+ plan,
+ @r"
+ SubqueryAlias: a
+ Limit: skip=1000, fetch=0
+ TableScan: test, fetch=0
+ "
+ )
}
}
diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs
index 48b2828faf..c7c9d03a51 100644
--- a/datafusion/optimizer/src/replace_distinct_aggregate.rs
+++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs
@@ -186,21 +186,26 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
mod tests {
use std::sync::Arc;
+ use crate::assert_optimized_plan_eq_snapshot;
use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
use crate::test::*;
use datafusion_common::Result;
- use datafusion_expr::{
- col, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan,
- };
+ use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder, Expr};
use datafusion_functions_aggregate::sum::sum;
- fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
- assert_optimized_plan_eq(
- Arc::new(ReplaceDistinctWithAggregate::new()),
- plan.clone(),
- expected,
- )
+ macro_rules! assert_optimized_plan_equal {
+ (
+ $plan:expr,
+ @ $expected:literal $(,)?
+ ) => {{
+ let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ReplaceDistinctWithAggregate::new());
+ assert_optimized_plan_eq_snapshot!(
+ rule,
+ $plan,
+ @ $expected,
+ )
+ }};
}
#[test]
@@ -212,8 +217,11 @@ mod tests {
.distinct()?
.build()?;
- let expected = "Projection: test.c\n Aggregate: groupBy=[[test.c]], aggr=[[]]\n TableScan: test";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Projection: test.c
+ Aggregate: groupBy=[[test.c]], aggr=[[]]
+ TableScan: test
+ ")
}
#[test]
@@ -225,9 +233,11 @@ mod tests {
.distinct()?
.build()?;
- let expected =
- "Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n TableScan: test";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Projection: test.a, test.b
+ Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
+ TableScan: test
+ ")
}
#[test]
@@ -238,8 +248,11 @@ mod tests {
.distinct()?
.build()?;
- let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n TableScan: test";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
+ Projection: test.a, test.b
+ TableScan: test
+ ")
}
#[test]
@@ -251,8 +264,11 @@ mod tests {
.distinct()?
.build()?;
- let expected =
- "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]\n TableScan: test";
- assert_optimized_plan_equal(&plan, expected)
+ assert_optimized_plan_equal!(plan, @r"
+ Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]
+ Projection: test.a, test.b
+ Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]
+ TableScan: test
+ ")
}
}
diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs
index 94d07a0791..5927ffea5a 100644
--- a/datafusion/optimizer/src/test/mod.rs
+++ b/datafusion/optimizer/src/test/mod.rs
@@ -181,6 +181,24 @@ pub fn assert_optimized_plan_eq(
Ok(())
}
+#[macro_export]
+macro_rules! assert_optimized_plan_eq_snapshot {
+ (
+ $rule:expr,
+ $plan:expr,
+ @ $expected:literal $(,)?
+ ) => {{
+ // Apply the rule once
+ let opt_context = $crate::OptimizerContext::new().with_max_passes(1);
+
+ let optimizer = $crate::Optimizer::with_rules(vec![Arc::clone(&$rule)]);
+ let optimized_plan = optimizer.optimize($plan, &opt_context, |_, _| {})?;
+ insta::assert_snapshot!(optimized_plan, @ $expected);
+
+ Ok::<(), datafusion_common::DataFusionError>(())
+ }};
+}
+
fn generate_optimized_plan_with_rules(
rules: Vec<Arc<dyn OptimizerRule + Send + Sync>>,
plan: LogicalPlan,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]