This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 102f879e60 Migrate-substrait-tests-to-insta, part2 (#15480)
102f879e60 is described below
commit 102f879e60d82c69157b6b21265f8a78ec3a4370
Author: Tommy shu <[email protected]>
AuthorDate: Sun Mar 30 06:42:12 2025 -0400
Migrate-substrait-tests-to-insta, part2 (#15480)
* add `cargo insta` to dev dependencies
* migrate `consumer_intergration.rs` tests to `insta`
* Revert "migrate `consumer_intergration.rs` tests to `insta`"
This reverts commit c3be2ebfeaeb5afff841810e38eff657f1b86a3f.
* migrate `consumer_integration.rs` to `insta` inline snapshot
* migrate logical plans tests to use `insta` snapshots
* migrate emit_kind_tests to use `insta` snapshots
* migrate function_test to use `insta` snapshots for assertions
* migrate substrait_validations tests to use insta snapshots, missing
`insta` mapping to `assert!`
* revert `handle_emit_as_project_without_volatile_exprs` back to
`assert_eq!` and remove `format!` for `assert_snapshot!`
* migrate function and validation tests to use plan directly in
assert_snapshot!
* migrate serialize tests to use insta snapshots for assertions
* migrate logical_plans test to use insta snapshots for assertions
* WIP
* migrate `assert_expected_plan_substrait`
* refactor tests to use assert_and_generate_plan and assert_snapshot! for
improved clarity and consistency
* remove println!
* migrate tests to use generate_plan_from_sql for improved clarity
---
.../tests/cases/roundtrip_logical_plan.rs | 335 +++++++++++++--------
1 file changed, 204 insertions(+), 131 deletions(-)
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index f989d05c80..36ee78fe5d 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -37,6 +37,7 @@ use datafusion::logical_expr::{
};
use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST;
use datafusion::prelude::*;
+use insta::assert_snapshot;
use std::hash::Hash;
use std::sync::Arc;
use substrait::proto::extensions::simple_extension_declaration::MappingType;
@@ -188,13 +189,16 @@ async fn simple_select() -> Result<()> {
#[tokio::test]
async fn wildcard_select() -> Result<()> {
- assert_expected_plan_unoptimized(
- "SELECT * FROM data",
- "Projection: data.a, data.b, data.c, data.d, data.e, data.f\
- \n TableScan: data",
- true,
- )
- .await
+ let plan = generate_plan_from_sql("SELECT * FROM data", true,
false).await?;
+
+ assert_snapshot!(
+ plan,
+ @r#"
+ Projection: data.a, data.b, data.c, data.d, data.e, data.f
+ TableScan: data
+ "#
+ );
+ Ok(())
}
#[tokio::test]
@@ -299,24 +303,42 @@ async fn aggregate_grouping_sets() -> Result<()> {
#[tokio::test]
async fn aggregate_grouping_rollup() -> Result<()> {
- assert_expected_plan(
+ let plan = generate_plan_from_sql(
"SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)",
- "Projection: data.a, data.c, data.e, avg(data.b)\
- \n Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e),
(data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\
- \n TableScan: data projection=[a, b, c, e]",
- true
- ).await
+ true,
+ true,
+ )
+ .await?;
+
+ assert_snapshot!(
+ plan,
+ @r#"
+ Projection: data.a, data.c, data.e, avg(data.b)
+ Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e),
(data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]
+ TableScan: data projection=[a, b, c, e]
+ "#
+ );
+ Ok(())
}
#[tokio::test]
async fn multilayer_aggregate() -> Result<()> {
- assert_expected_plan(
+ let plan = generate_plan_from_sql(
"SELECT a, sum(partial_count_b) FROM (SELECT a, count(b) as
partial_count_b FROM data GROUP BY a) GROUP BY a",
- "Aggregate: groupBy=[[data.a]], aggr=[[sum(count(data.b)) AS
sum(partial_count_b)]]\
- \n Aggregate: groupBy=[[data.a]], aggr=[[count(data.b)]]\
- \n TableScan: data projection=[a, b]",
- true
- ).await
+ true,
+ true,
+ )
+ .await?;
+
+ assert_snapshot!(
+ plan,
+ @r#"
+ Aggregate: groupBy=[[data.a]], aggr=[[sum(count(data.b)) AS
sum(partial_count_b)]]
+ Aggregate: groupBy=[[data.a]], aggr=[[count(data.b)]]
+ TableScan: data projection=[a, b]
+ "#
+ );
+ Ok(())
}
#[tokio::test]
@@ -454,13 +476,21 @@ async fn try_cast_decimal_to_string() -> Result<()> {
#[tokio::test]
async fn aggregate_case() -> Result<()> {
- assert_expected_plan(
+ let plan = generate_plan_from_sql(
"SELECT sum(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data",
- "Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN
Int64(1) ELSE Int64(NULL) END) AS sum(CASE WHEN data.a > Int64(0) THEN Int64(1)
ELSE NULL END)]]\
- \n TableScan: data projection=[a]",
- true
+ true,
+ true,
)
- .await
+ .await?;
+
+ assert_snapshot!(
+ plan,
+ @r#"
+ Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN
Int64(1) ELSE Int64(NULL) END) AS sum(CASE WHEN data.a > Int64(0) THEN Int64(1)
ELSE NULL END)]]
+ TableScan: data projection=[a]
+ "#
+ );
+ Ok(())
}
#[tokio::test]
@@ -493,18 +523,27 @@ async fn roundtrip_inlist_4() -> Result<()> {
#[tokio::test]
async fn roundtrip_inlist_5() -> Result<()> {
// on roundtrip there is an additional projection during TableScan which
includes all column of the table,
- // using assert_expected_plan here as a workaround
- assert_expected_plan(
+ // using assert_and_generate_plan and assert_snapshot! here as a workaround
+ let plan = generate_plan_from_sql(
"SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT
data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))",
+ true,
+ true,
+ )
+ .await?;
- "Projection: data.a, data.f\
- \n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f =
Utf8(\"c\") OR data2.mark\
- \n LeftMark Join: data.a = data2.a\
- \n TableScan: data projection=[a, f]\
- \n Projection: data2.a\
- \n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR
data2.f = Utf8(\"d\")\
- \n TableScan: data2 projection=[a, f],
partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f =
Utf8(\"d\")]",
- true).await
+ assert_snapshot!(
+ plan,
+ @r#"
+ Projection: data.a, data.f
+ Filter: data.f = Utf8("a") OR data.f = Utf8("b") OR data.f = Utf8("c")
OR data2.mark
+ LeftMark Join: data.a = data2.a
+ TableScan: data projection=[a, f]
+ Projection: data2.a
+ Filter: data2.f = Utf8("b") OR data2.f = Utf8("c") OR data2.f =
Utf8("d")
+ TableScan: data2 projection=[a, f], partial_filters=[data2.f =
Utf8("b") OR data2.f = Utf8("c") OR data2.f = Utf8("d")]
+ "#
+ );
+ Ok(())
}
#[tokio::test]
@@ -535,27 +574,44 @@ async fn roundtrip_non_equi_join() -> Result<()> {
#[tokio::test]
async fn roundtrip_exists_filter() -> Result<()> {
- assert_expected_plan(
+ let plan = generate_plan_from_sql(
"SELECT b FROM data d1 WHERE EXISTS (SELECT * FROM data2 d2 WHERE d2.a
= d1.a AND d2.e != d1.e)",
- "Projection: data.b\
- \n LeftSemi Join: data.a = data2.a Filter: data2.e != CAST(data.e AS
Int64)\
- \n TableScan: data projection=[a, b, e]\
- \n TableScan: data2 projection=[a, e]",
- false // "d1" vs "data" field qualifier
- ).await
+ false,
+ true,
+ )
+ .await?;
+
+ assert_snapshot!(
+ plan,
+ @r#"
+ Projection: data.b
+ LeftSemi Join: data.a = data2.a Filter: data2.e != CAST(data.e AS Int64)
+ TableScan: data projection=[a, b, e]
+ TableScan: data2 projection=[a, e]
+ "#
+ );
+ Ok(())
}
#[tokio::test]
async fn inner_join() -> Result<()> {
- assert_expected_plan(
+ let plan = generate_plan_from_sql(
"SELECT data.a FROM data JOIN data2 ON data.a = data2.a",
- "Projection: data.a\
- \n Inner Join: data.a = data2.a\
- \n TableScan: data projection=[a]\
- \n TableScan: data2 projection=[a]",
+ true,
true,
)
- .await
+ .await?;
+
+ assert_snapshot!(
+ plan,
+ @r#"
+ Projection: data.a
+ Inner Join: data.a = data2.a
+ TableScan: data projection=[a]
+ TableScan: data2 projection=[a]
+ "#
+ );
+ Ok(())
}
#[tokio::test]
@@ -592,17 +648,25 @@ async fn roundtrip_self_implicit_cross_join() ->
Result<()> {
#[tokio::test]
async fn self_join_introduces_aliases() -> Result<()> {
- assert_expected_plan(
+ let plan = generate_plan_from_sql(
"SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b",
- "Projection: left.b, right.c\
- \n Inner Join: left.b = right.b\
- \n SubqueryAlias: left\
- \n TableScan: data projection=[b]\
- \n SubqueryAlias: right\
- \n TableScan: data projection=[b, c]",
false,
+ true,
)
- .await
+ .await?;
+
+ assert_snapshot!(
+ plan,
+ @r#"
+ Projection: left.b, right.c
+ Inner Join: left.b = right.b
+ SubqueryAlias: left
+ TableScan: data projection=[b]
+ SubqueryAlias: right
+ TableScan: data projection=[b, c]
+ "#
+ );
+ Ok(())
}
#[tokio::test]
@@ -747,12 +811,15 @@ async fn aggregate_wo_projection_consume() -> Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/aggregate_no_project.substrait.json");
- assert_expected_plan_substrait(
- proto_plan,
- "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\
- \n TableScan: data projection=[a]",
- )
- .await
+ let plan = generate_plan_from_substrait(proto_plan).await?;
+ assert_snapshot!(
+ plan,
+ @r#"
+ Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]
+ TableScan: data projection=[a]
+ "#
+ );
+ Ok(())
}
#[tokio::test]
@@ -760,12 +827,15 @@ async fn
aggregate_wo_projection_group_expression_ref_consume() -> Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json");
- assert_expected_plan_substrait(
- proto_plan,
- "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\
- \n TableScan: data projection=[a]",
- )
- .await
+ let plan = generate_plan_from_substrait(proto_plan).await?;
+ assert_snapshot!(
+ plan,
+ @r#"
+ Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]
+ TableScan: data projection=[a]
+ "#
+ );
+ Ok(())
}
#[tokio::test]
@@ -773,12 +843,15 @@ async fn aggregate_wo_projection_sorted_consume() ->
Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json");
- assert_expected_plan_substrait(
- proto_plan,
- "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) ORDER BY [data.a
DESC NULLS FIRST] AS countA]]\
- \n TableScan: data projection=[a]",
- )
- .await
+ let plan = generate_plan_from_substrait(proto_plan).await?;
+ assert_snapshot!(
+ plan,
+ @r#"
+ Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) ORDER BY [data.a DESC
NULLS FIRST] AS countA]]
+ TableScan: data projection=[a]
+ "#
+ );
+ Ok(())
}
#[tokio::test]
@@ -986,19 +1059,27 @@ async fn roundtrip_literal_list() -> Result<()> {
#[tokio::test]
async fn roundtrip_literal_struct() -> Result<()> {
- assert_expected_plan(
+ let plan = generate_plan_from_sql(
"SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data",
- "Projection: Struct({c0:1,c1:true,c2:}) AS
struct(Int64(1),Boolean(true),NULL)\
- \n TableScan: data projection=[]",
- false, // "Struct(..)" vs "struct(..)"
+ false,
+ true,
)
- .await
+ .await?;
+
+ assert_snapshot!(
+ plan,
+ @r#"
+ Projection: Struct({c0:1,c1:true,c2:}) AS
struct(Int64(1),Boolean(true),NULL)
+ TableScan: data projection=[]
+ "#
+ );
+ Ok(())
}
#[tokio::test]
async fn roundtrip_values() -> Result<()> {
// TODO: would be nice to have a struct inside the LargeList, but
arrow_cast doesn't support that currently
- assert_expected_plan(
+ let plan = generate_plan_from_sql(
"VALUES \
(\
1, \
@@ -1009,17 +1090,18 @@ async fn roundtrip_values() -> Result<()> {
[STRUCT(STRUCT('a' AS string_field) AS struct_field),
STRUCT(STRUCT('b' AS string_field) AS struct_field)]\
), \
(NULL, NULL, NULL, NULL, NULL, NULL)",
- "Values: \
- (\
- Int64(1), \
- Utf8(\"a\"), \
- List([[-213.1, , 5.5, 2.0, 1.0], []]), \
- LargeList([1, 2, 3]), \
- Struct({c0:true,int_field:1,c2:}), \
- List([{struct_field: {string_field: a}}, {struct_field:
{string_field: b}}])\
- ), \
- (Int64(NULL), Utf8(NULL), List(), LargeList(),
Struct({c0:,int_field:,c2:}), List())",
- true).await
+ true,
+ true,
+ )
+ .await?;
+
+ assert_snapshot!(
+ plan,
+ @r#"
+ Values: (Int64(1), Utf8("a"), List([[-213.1, , 5.5, 2.0, 1.0], []]),
LargeList([1, 2, 3]), Struct({c0:true,int_field:1,c2:}), List([{struct_field:
{string_field: a}}, {struct_field: {string_field: b}}])), (Int64(NULL),
Utf8(NULL), List(), LargeList(), Struct({c0:,int_field:,c2:}), List())
+ "#
+ );
+ Ok(())
}
#[tokio::test]
@@ -1061,14 +1143,22 @@ async fn duplicate_column() -> Result<()> {
// only. DataFusion however, is strict about not having duplicate column
names appear in the plan.
// This test confirms that we generate aliases for columns in the plan
which would otherwise have
// colliding names.
- assert_expected_plan(
+ let plan = generate_plan_from_sql(
"SELECT a + 1 as sum_a, a + 1 as sum_a_2 FROM data",
- "Projection: data.a + Int64(1) AS sum_a, data.a + Int64(1) AS data.a +
Int64(1)__temp__0 AS sum_a_2\
- \n Projection: data.a + Int64(1)\
- \n TableScan: data projection=[a]",
+ true,
true,
)
- .await
+ .await?;
+
+ assert_snapshot!(
+ plan,
+ @r#"
+ Projection: data.a + Int64(1) AS sum_a, data.a + Int64(1) AS data.a +
Int64(1)__temp__0 AS sum_a_2
+ Projection: data.a + Int64(1)
+ TableScan: data projection=[a]
+ "#
+ );
+ Ok(())
}
/// Construct a plan that cast columns. Only those SQL types are supported for
now.
@@ -1374,30 +1464,32 @@ async fn assert_read_filter_count(
Ok(())
}
-async fn assert_expected_plan_unoptimized(
+async fn generate_plan_from_sql(
sql: &str,
- expected_plan_str: &str,
assert_schema: bool,
-) -> Result<()> {
+ optimized: bool,
+) -> Result<LogicalPlan> {
let ctx = create_context().await?;
- let df = ctx.sql(sql).await?;
- let plan = df.into_unoptimized_plan();
- let proto = to_substrait_plan(&plan, &ctx.state())?;
- let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
-
- println!("{plan}");
- println!("{plan2}");
+ let df: DataFrame = ctx.sql(sql).await?;
- println!("{proto:?}");
+ let plan = if optimized {
+ df.into_optimized_plan()?
+ } else {
+ df.into_unoptimized_plan()
+ };
+ let proto = to_substrait_plan(&plan, &ctx.state())?;
+ let plan2 = if optimized {
+ let temp = from_substrait_plan(&ctx.state(), &proto).await?;
+ ctx.state().optimize(&temp)?
+ } else {
+ from_substrait_plan(&ctx.state(), &proto).await?
+ };
if assert_schema {
assert_eq!(plan.schema(), plan2.schema());
}
- let plan2str = format!("{plan2}");
- assert_eq!(expected_plan_str, &plan2str);
-
- Ok(())
+ Ok(plan2)
}
async fn assert_expected_plan(
@@ -1412,11 +1504,6 @@ async fn assert_expected_plan(
let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;
- println!("{plan}");
- println!("{plan2}");
-
- println!("{proto:?}");
-
if assert_schema {
assert_eq!(plan.schema(), plan2.schema());
}
@@ -1427,20 +1514,14 @@ async fn assert_expected_plan(
Ok(())
}
-async fn assert_expected_plan_substrait(
- substrait_plan: Plan,
- expected_plan_str: &str,
-) -> Result<()> {
+async fn generate_plan_from_substrait(substrait_plan: Plan) ->
Result<LogicalPlan> {
let ctx = create_context().await?;
let plan = from_substrait_plan(&ctx.state(), &substrait_plan).await?;
let plan = ctx.state().optimize(&plan)?;
- let planstr = format!("{plan}");
- assert_eq!(planstr, expected_plan_str);
-
- Ok(())
+ Ok(plan)
}
async fn assert_substrait_sql(substrait_plan: Plan, sql: &str) -> Result<()> {
@@ -1491,9 +1572,6 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias:
&str) -> Result<()> {
let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx.state())?;
let plan = from_substrait_plan(&ctx.state(), &proto).await?;
- println!("{plan_with_alias}");
- println!("{plan}");
-
let plan1str = format!("{plan_with_alias}");
let plan2str = format!("{plan}");
assert_eq!(plan1str, plan2str);
@@ -1510,11 +1588,6 @@ async fn roundtrip_logical_plan_with_ctx(
let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;
- println!("{plan}");
- println!("{plan2}");
-
- println!("{proto:?}");
-
let plan1str = format!("{plan}");
let plan2str = format!("{plan2}");
assert_eq!(plan1str, plan2str);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]