This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 102f879e60 Migrate-substrait-tests-to-insta, part2 (#15480)
102f879e60 is described below

commit 102f879e60d82c69157b6b21265f8a78ec3a4370
Author: Tommy shu <[email protected]>
AuthorDate: Sun Mar 30 06:42:12 2025 -0400

    Migrate-substrait-tests-to-insta, part2 (#15480)
    
    * add `cargo insta` to dev dependencies
    
    * migrate `consumer_intergration.rs` tests to `insta`
    
    * Revert "migrate `consumer_intergration.rs` tests to `insta`"
    
    This reverts commit c3be2ebfeaeb5afff841810e38eff657f1b86a3f.
    
    * migrate `consumer_integration.rs` to `insta` inline snapshot
    
    * migrate logical plans tests to use `insta` snapshots
    
    * migrate emit_kind_tests to use `insta` snapshots
    
    * migrate function_test to use `insta` snapshots for assertions
    
    * migrate substrait_validations tests to use insta snapshots, missing `insta` mapping to `assert!`
    
    * revert `handle_emit_as_project_without_volatile_exprs` back to `assert_eq!` and remove `format!` for `assert_snapshot!`
    
    * migrate function and validation tests to use plan directly in assert_snapshot!
    
    * migrate serialize tests to use insta snapshots for assertions
    
    * migrate logical_plans test to use insta snapshots for assertions
    
    * WIP
    
    * migrate `assert_expected_plan_substrait`
    
    * refactor tests to use assert_and_generate_plan and assert_snapshot! for improved clarity and consistency
    
    * remove println!
    
    * migrate tests to use generate_plan_from_sql for improved clarity
---
 .../tests/cases/roundtrip_logical_plan.rs          | 335 +++++++++++++--------
 1 file changed, 204 insertions(+), 131 deletions(-)

diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs 
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index f989d05c80..36ee78fe5d 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -37,6 +37,7 @@ use datafusion::logical_expr::{
 };
 use 
datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST;
 use datafusion::prelude::*;
+use insta::assert_snapshot;
 use std::hash::Hash;
 use std::sync::Arc;
 use substrait::proto::extensions::simple_extension_declaration::MappingType;
@@ -188,13 +189,16 @@ async fn simple_select() -> Result<()> {
 
 #[tokio::test]
 async fn wildcard_select() -> Result<()> {
-    assert_expected_plan_unoptimized(
-        "SELECT * FROM data",
-        "Projection: data.a, data.b, data.c, data.d, data.e, data.f\
-        \n  TableScan: data",
-        true,
-    )
-    .await
+    let plan = generate_plan_from_sql("SELECT * FROM data", true, 
false).await?;
+
+    assert_snapshot!(
+    plan,
+    @r#"
+    Projection: data.a, data.b, data.c, data.d, data.e, data.f
+      TableScan: data
+    "#
+    );
+    Ok(())
 }
 
 #[tokio::test]
@@ -299,24 +303,42 @@ async fn aggregate_grouping_sets() -> Result<()> {
 
 #[tokio::test]
 async fn aggregate_grouping_rollup() -> Result<()> {
-    assert_expected_plan(
+    let plan = generate_plan_from_sql(
         "SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)",
-        "Projection: data.a, data.c, data.e, avg(data.b)\
-        \n  Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), 
(data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\
-        \n    TableScan: data projection=[a, b, c, e]",
-        true
-    ).await
+        true,
+        true,
+    )
+    .await?;
+
+    assert_snapshot!(
+    plan,
+    @r#"
+        Projection: data.a, data.c, data.e, avg(data.b)
+          Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), 
(data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]
+            TableScan: data projection=[a, b, c, e]
+        "#
+    );
+    Ok(())
 }
 
 #[tokio::test]
 async fn multilayer_aggregate() -> Result<()> {
-    assert_expected_plan(
+    let plan = generate_plan_from_sql(
         "SELECT a, sum(partial_count_b) FROM (SELECT a, count(b) as 
partial_count_b FROM data GROUP BY a) GROUP BY a",
-        "Aggregate: groupBy=[[data.a]], aggr=[[sum(count(data.b)) AS 
sum(partial_count_b)]]\
-        \n  Aggregate: groupBy=[[data.a]], aggr=[[count(data.b)]]\
-        \n    TableScan: data projection=[a, b]",
-        true
-    ).await
+        true,
+        true,
+    )
+    .await?;
+
+    assert_snapshot!(
+    plan,
+    @r#"
+    Aggregate: groupBy=[[data.a]], aggr=[[sum(count(data.b)) AS 
sum(partial_count_b)]]
+      Aggregate: groupBy=[[data.a]], aggr=[[count(data.b)]]
+        TableScan: data projection=[a, b]
+    "#
+    );
+    Ok(())
 }
 
 #[tokio::test]
@@ -454,13 +476,21 @@ async fn try_cast_decimal_to_string() -> Result<()> {
 
 #[tokio::test]
 async fn aggregate_case() -> Result<()> {
-    assert_expected_plan(
+    let plan = generate_plan_from_sql(
         "SELECT sum(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data",
-        "Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN 
Int64(1) ELSE Int64(NULL) END) AS sum(CASE WHEN data.a > Int64(0) THEN Int64(1) 
ELSE NULL END)]]\
-         \n  TableScan: data projection=[a]",
-        true
+        true,
+        true,
     )
-        .await
+    .await?;
+
+    assert_snapshot!(
+    plan,
+    @r#"
+    Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN 
Int64(1) ELSE Int64(NULL) END) AS sum(CASE WHEN data.a > Int64(0) THEN Int64(1) 
ELSE NULL END)]]
+      TableScan: data projection=[a]
+    "#
+    );
+    Ok(())
 }
 
 #[tokio::test]
@@ -493,18 +523,27 @@ async fn roundtrip_inlist_4() -> Result<()> {
 #[tokio::test]
 async fn roundtrip_inlist_5() -> Result<()> {
     // on roundtrip there is an additional projection during TableScan which 
includes all column of the table,
-    // using assert_expected_plan here as a workaround
-    assert_expected_plan(
+    // using generate_plan_from_sql and assert_snapshot! here as a workaround
+    let plan = generate_plan_from_sql(
         "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT 
data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))",
+        true,
+        true,
+    )
+    .await?;
 
-        "Projection: data.a, data.f\
-        \n  Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = 
Utf8(\"c\") OR data2.mark\
-        \n    LeftMark Join: data.a = data2.a\
-        \n      TableScan: data projection=[a, f]\
-        \n      Projection: data2.a\
-        \n        Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR 
data2.f = Utf8(\"d\")\
-        \n          TableScan: data2 projection=[a, f], 
partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = 
Utf8(\"d\")]",
-    true).await
+    assert_snapshot!(
+    plan,
+    @r#"
+    Projection: data.a, data.f
+      Filter: data.f = Utf8("a") OR data.f = Utf8("b") OR data.f = Utf8("c") 
OR data2.mark
+        LeftMark Join: data.a = data2.a
+          TableScan: data projection=[a, f]
+          Projection: data2.a
+            Filter: data2.f = Utf8("b") OR data2.f = Utf8("c") OR data2.f = 
Utf8("d")
+              TableScan: data2 projection=[a, f], partial_filters=[data2.f = 
Utf8("b") OR data2.f = Utf8("c") OR data2.f = Utf8("d")]
+    "#
+            );
+    Ok(())
 }
 
 #[tokio::test]
@@ -535,27 +574,44 @@ async fn roundtrip_non_equi_join() -> Result<()> {
 
 #[tokio::test]
 async fn roundtrip_exists_filter() -> Result<()> {
-    assert_expected_plan(
+    let plan = generate_plan_from_sql(
         "SELECT b FROM data d1 WHERE EXISTS (SELECT * FROM data2 d2 WHERE d2.a 
= d1.a AND d2.e != d1.e)",
-        "Projection: data.b\
-        \n  LeftSemi Join: data.a = data2.a Filter: data2.e != CAST(data.e AS 
Int64)\
-        \n    TableScan: data projection=[a, b, e]\
-        \n    TableScan: data2 projection=[a, e]",
-        false // "d1" vs "data" field qualifier
-    ).await
+        false,
+        true,
+    )
+    .await?;
+
+    assert_snapshot!(
+    plan,
+    @r#"
+    Projection: data.b
+      LeftSemi Join: data.a = data2.a Filter: data2.e != CAST(data.e AS Int64)
+        TableScan: data projection=[a, b, e]
+        TableScan: data2 projection=[a, e]
+    "#
+            );
+    Ok(())
 }
 
 #[tokio::test]
 async fn inner_join() -> Result<()> {
-    assert_expected_plan(
+    let plan = generate_plan_from_sql(
         "SELECT data.a FROM data JOIN data2 ON data.a = data2.a",
-        "Projection: data.a\
-         \n  Inner Join: data.a = data2.a\
-         \n    TableScan: data projection=[a]\
-         \n    TableScan: data2 projection=[a]",
+        true,
         true,
     )
-    .await
+    .await?;
+
+    assert_snapshot!(
+    plan,
+    @r#"
+    Projection: data.a
+      Inner Join: data.a = data2.a
+        TableScan: data projection=[a]
+        TableScan: data2 projection=[a]
+    "#
+            );
+    Ok(())
 }
 
 #[tokio::test]
@@ -592,17 +648,25 @@ async fn roundtrip_self_implicit_cross_join() -> 
Result<()> {
 
 #[tokio::test]
 async fn self_join_introduces_aliases() -> Result<()> {
-    assert_expected_plan(
+    let plan = generate_plan_from_sql(
         "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b",
-        "Projection: left.b, right.c\
-        \n  Inner Join: left.b = right.b\
-        \n    SubqueryAlias: left\
-        \n      TableScan: data projection=[b]\
-        \n    SubqueryAlias: right\
-        \n      TableScan: data projection=[b, c]",
         false,
+        true,
     )
-    .await
+    .await?;
+
+    assert_snapshot!(
+    plan,
+    @r#"
+    Projection: left.b, right.c
+      Inner Join: left.b = right.b
+        SubqueryAlias: left
+          TableScan: data projection=[b]
+        SubqueryAlias: right
+          TableScan: data projection=[b, c]
+    "#
+            );
+    Ok(())
 }
 
 #[tokio::test]
@@ -747,12 +811,15 @@ async fn aggregate_wo_projection_consume() -> Result<()> {
     let proto_plan =
         
read_json("tests/testdata/test_plans/aggregate_no_project.substrait.json");
 
-    assert_expected_plan_substrait(
-        proto_plan,
-        "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\
-        \n  TableScan: data projection=[a]",
-    )
-    .await
+    let plan = generate_plan_from_substrait(proto_plan).await?;
+    assert_snapshot!(
+    plan,
+    @r#"
+            Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]
+              TableScan: data projection=[a]
+            "#
+        );
+    Ok(())
 }
 
 #[tokio::test]
@@ -760,12 +827,15 @@ async fn 
aggregate_wo_projection_group_expression_ref_consume() -> Result<()> {
     let proto_plan =
         
read_json("tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json");
 
-    assert_expected_plan_substrait(
-        proto_plan,
-        "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\
-        \n  TableScan: data projection=[a]",
-    )
-    .await
+    let plan = generate_plan_from_substrait(proto_plan).await?;
+    assert_snapshot!(
+    plan,
+    @r#"
+            Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]
+              TableScan: data projection=[a]
+            "#
+        );
+    Ok(())
 }
 
 #[tokio::test]
@@ -773,12 +843,15 @@ async fn aggregate_wo_projection_sorted_consume() -> 
Result<()> {
     let proto_plan =
         
read_json("tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json");
 
-    assert_expected_plan_substrait(
-        proto_plan,
-        "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) ORDER BY [data.a 
DESC NULLS FIRST] AS countA]]\
-        \n  TableScan: data projection=[a]",
-    )
-    .await
+    let plan = generate_plan_from_substrait(proto_plan).await?;
+    assert_snapshot!(
+    plan,
+    @r#"
+    Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) ORDER BY [data.a DESC 
NULLS FIRST] AS countA]]
+      TableScan: data projection=[a]
+    "#
+            );
+    Ok(())
 }
 
 #[tokio::test]
@@ -986,19 +1059,27 @@ async fn roundtrip_literal_list() -> Result<()> {
 
 #[tokio::test]
 async fn roundtrip_literal_struct() -> Result<()> {
-    assert_expected_plan(
+    let plan = generate_plan_from_sql(
         "SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data",
-        "Projection: Struct({c0:1,c1:true,c2:}) AS 
struct(Int64(1),Boolean(true),NULL)\
-        \n  TableScan: data projection=[]",
-        false, // "Struct(..)" vs "struct(..)"
+        false,
+        true,
     )
-    .await
+    .await?;
+
+    assert_snapshot!(
+    plan,
+    @r#"
+    Projection: Struct({c0:1,c1:true,c2:}) AS 
struct(Int64(1),Boolean(true),NULL)
+      TableScan: data projection=[]
+    "#
+            );
+    Ok(())
 }
 
 #[tokio::test]
 async fn roundtrip_values() -> Result<()> {
     // TODO: would be nice to have a struct inside the LargeList, but 
arrow_cast doesn't support that currently
-    assert_expected_plan(
+    let plan = generate_plan_from_sql(
         "VALUES \
             (\
                 1, \
@@ -1009,17 +1090,18 @@ async fn roundtrip_values() -> Result<()> {
                 [STRUCT(STRUCT('a' AS string_field) AS struct_field), 
STRUCT(STRUCT('b' AS string_field) AS struct_field)]\
             ), \
             (NULL, NULL, NULL, NULL, NULL, NULL)",
-        "Values: \
-            (\
-                Int64(1), \
-                Utf8(\"a\"), \
-                List([[-213.1, , 5.5, 2.0, 1.0], []]), \
-                LargeList([1, 2, 3]), \
-                Struct({c0:true,int_field:1,c2:}), \
-                List([{struct_field: {string_field: a}}, {struct_field: 
{string_field: b}}])\
-            ), \
-            (Int64(NULL), Utf8(NULL), List(), LargeList(), 
Struct({c0:,int_field:,c2:}), List())",
-    true).await
+        true,
+        true,
+    )
+    .await?;
+
+    assert_snapshot!(
+    plan,
+    @r#"
+    Values: (Int64(1), Utf8("a"), List([[-213.1, , 5.5, 2.0, 1.0], []]), 
LargeList([1, 2, 3]), Struct({c0:true,int_field:1,c2:}), List([{struct_field: 
{string_field: a}}, {struct_field: {string_field: b}}])), (Int64(NULL), 
Utf8(NULL), List(), LargeList(), Struct({c0:,int_field:,c2:}), List())
+    "#
+            );
+    Ok(())
 }
 
 #[tokio::test]
@@ -1061,14 +1143,22 @@ async fn duplicate_column() -> Result<()> {
     // only. DataFusion however, is strict about not having duplicate column 
names appear in the plan.
     // This test confirms that we generate aliases for columns in the plan 
which would otherwise have
     // colliding names.
-    assert_expected_plan(
+    let plan = generate_plan_from_sql(
         "SELECT a + 1 as sum_a, a + 1 as sum_a_2 FROM data",
-        "Projection: data.a + Int64(1) AS sum_a, data.a + Int64(1) AS data.a + 
Int64(1)__temp__0 AS sum_a_2\
-            \n  Projection: data.a + Int64(1)\
-            \n    TableScan: data projection=[a]",
+        true,
         true,
     )
-    .await
+    .await?;
+
+    assert_snapshot!(
+    plan,
+    @r#"
+    Projection: data.a + Int64(1) AS sum_a, data.a + Int64(1) AS data.a + 
Int64(1)__temp__0 AS sum_a_2
+      Projection: data.a + Int64(1)
+        TableScan: data projection=[a]
+    "#
+        );
+    Ok(())
 }
 
 /// Construct a plan that cast columns. Only those SQL types are supported for 
now.
@@ -1374,30 +1464,32 @@ async fn assert_read_filter_count(
     Ok(())
 }
 
-async fn assert_expected_plan_unoptimized(
+async fn generate_plan_from_sql(
     sql: &str,
-    expected_plan_str: &str,
     assert_schema: bool,
-) -> Result<()> {
+    optimized: bool,
+) -> Result<LogicalPlan> {
     let ctx = create_context().await?;
-    let df = ctx.sql(sql).await?;
-    let plan = df.into_unoptimized_plan();
-    let proto = to_substrait_plan(&plan, &ctx.state())?;
-    let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
-
-    println!("{plan}");
-    println!("{plan2}");
+    let df: DataFrame = ctx.sql(sql).await?;
 
-    println!("{proto:?}");
+    let plan = if optimized {
+        df.into_optimized_plan()?
+    } else {
+        df.into_unoptimized_plan()
+    };
+    let proto = to_substrait_plan(&plan, &ctx.state())?;
+    let plan2 = if optimized {
+        let temp = from_substrait_plan(&ctx.state(), &proto).await?;
+        ctx.state().optimize(&temp)?
+    } else {
+        from_substrait_plan(&ctx.state(), &proto).await?
+    };
 
     if assert_schema {
         assert_eq!(plan.schema(), plan2.schema());
     }
 
-    let plan2str = format!("{plan2}");
-    assert_eq!(expected_plan_str, &plan2str);
-
-    Ok(())
+    Ok(plan2)
 }
 
 async fn assert_expected_plan(
@@ -1412,11 +1504,6 @@ async fn assert_expected_plan(
     let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
     let plan2 = ctx.state().optimize(&plan2)?;
 
-    println!("{plan}");
-    println!("{plan2}");
-
-    println!("{proto:?}");
-
     if assert_schema {
         assert_eq!(plan.schema(), plan2.schema());
     }
@@ -1427,20 +1514,14 @@ async fn assert_expected_plan(
     Ok(())
 }
 
-async fn assert_expected_plan_substrait(
-    substrait_plan: Plan,
-    expected_plan_str: &str,
-) -> Result<()> {
+async fn generate_plan_from_substrait(substrait_plan: Plan) -> 
Result<LogicalPlan> {
     let ctx = create_context().await?;
 
     let plan = from_substrait_plan(&ctx.state(), &substrait_plan).await?;
 
     let plan = ctx.state().optimize(&plan)?;
 
-    let planstr = format!("{plan}");
-    assert_eq!(planstr, expected_plan_str);
-
-    Ok(())
+    Ok(plan)
 }
 
 async fn assert_substrait_sql(substrait_plan: Plan, sql: &str) -> Result<()> {
@@ -1491,9 +1572,6 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: 
&str) -> Result<()> {
     let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx.state())?;
     let plan = from_substrait_plan(&ctx.state(), &proto).await?;
 
-    println!("{plan_with_alias}");
-    println!("{plan}");
-
     let plan1str = format!("{plan_with_alias}");
     let plan2str = format!("{plan}");
     assert_eq!(plan1str, plan2str);
@@ -1510,11 +1588,6 @@ async fn roundtrip_logical_plan_with_ctx(
     let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
     let plan2 = ctx.state().optimize(&plan2)?;
 
-    println!("{plan}");
-    println!("{plan2}");
-
-    println!("{proto:?}");
-
     let plan1str = format!("{plan}");
     let plan2str = format!("{plan2}");
     assert_eq!(plan1str, plan2str);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to