This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 27627546d1 Support Substrait's VirtualTables (#10531)
27627546d1 is described below

commit 27627546d1e58ac14e25819241c655c3d18807b3
Author: Arttu <[email protected]>
AuthorDate: Tue May 28 15:18:45 2024 +0200

    Support Substrait's VirtualTables (#10531)
    
    * Add support for Substrait VirtualTables
    
    Adds support for Substrait's VirtualTables, i.e. tables whose data is baked
    into the Substrait plan instead of being read from a source.
    
    Adds conversion in both directions (Substrait -> DataFusion and
    DataFusion -> Substrait) and a roundtrip test.
    
    * Fix clippy warnings
    
    * Add support for empty relations
    
    * Fix consuming Structs inside Lists and Structs
    
    Also adds roundtrip schema assertions where possible
    
    * Rename from_substrait_struct -> from_substrait_struct_type for clarity
    
    * Add DataType::LargeList to to_substrait_named_struct
    
    * cargo fmt --all
    
    * Add validation that the names list matches the schema exactly
    
    * Add a LargeList to the VALUES test
---
 datafusion/substrait/src/logical_plan/consumer.rs  | 191 +++++++++++++++++----
 datafusion/substrait/src/logical_plan/producer.rs  | 134 ++++++++++++++-
 .../tests/cases/roundtrip_logical_plan.rs          |  98 ++++++++---
 3 files changed, 361 insertions(+), 62 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs 
b/datafusion/substrait/src/logical_plan/consumer.rs
index d6c60ebdde..abebc68123 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -16,15 +16,17 @@
 // under the License.
 
 use async_recursion::async_recursion;
-use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
+use datafusion::arrow::datatypes::{
+    DataType, Field, Fields, IntervalUnit, Schema, TimeUnit,
+};
 use datafusion::common::{
     not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, 
DFSchemaRef,
 };
 
 use datafusion::execution::FunctionRegistry;
 use datafusion::logical_expr::{
-    aggregate_function, expr::find_df_window_func, BinaryExpr, Case, Expr, 
LogicalPlan,
-    Operator, ScalarUDF,
+    aggregate_function, expr::find_df_window_func, BinaryExpr, Case, 
EmptyRelation, Expr,
+    LogicalPlan, Operator, ScalarUDF, Values,
 };
 use datafusion::logical_expr::{
     expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
@@ -58,7 +60,7 @@ use substrait::proto::{
     rel::RelType,
     set_rel,
     sort_field::{SortDirection, SortKind::*},
-    AggregateFunction, Expression, Plan, Rel, Type,
+    AggregateFunction, Expression, NamedStruct, Plan, Rel, Type,
 };
 use substrait::proto::{FunctionArgument, SortField};
 
@@ -509,7 +511,51 @@ pub async fn from_substrait_rel(
                     _ => Ok(t),
                 }
             }
-            _ => not_impl_err!("Only NamedTable reads are supported"),
+            Some(ReadType::VirtualTable(vt)) => {
+                let base_schema = read.base_schema.as_ref().ok_or_else(|| {
+                    substrait_datafusion_err!("No base schema provided for 
Virtual Table")
+                })?;
+
+                let schema = from_substrait_named_struct(base_schema)?;
+
+                if vt.values.is_empty() {
+                    return Ok(LogicalPlan::EmptyRelation(EmptyRelation {
+                        produce_one_row: false,
+                        schema,
+                    }));
+                }
+
+                let values = vt
+                    .values
+                    .iter()
+                    .map(|row| {
+                        let mut name_idx = 0;
+                        let lits = row
+                            .fields
+                            .iter()
+                            .map(|lit| {
+                                name_idx += 1; // top-level names are provided 
through schema
+                                Ok(Expr::Literal(from_substrait_literal(
+                                    lit,
+                                    &base_schema.names,
+                                    &mut name_idx,
+                                )?))
+                            })
+                            .collect::<Result<_>>()?;
+                        if name_idx != base_schema.names.len() {
+                            return substrait_err!(
+                                "Names list must match exactly to nested 
schema, but found {} uses for {} names",
+                                name_idx,
+                                base_schema.names.len()
+                            );
+                        }
+                        Ok(lits)
+                    })
+                    .collect::<Result<_>>()?;
+
+                Ok(LogicalPlan::Values(Values { schema, values }))
+            }
+            _ => not_impl_err!("Only NamedTable and VirtualTable reads are 
supported"),
         },
         Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) {
             Ok(set_op) => match set_op {
@@ -948,7 +994,7 @@ pub async fn from_substrait_rex(
             }
         }
         Some(RexType::Literal(lit)) => {
-            let scalar_value = from_substrait_literal(lit)?;
+            let scalar_value = from_substrait_literal_without_names(lit)?;
             Ok(Arc::new(Expr::Literal(scalar_value)))
         }
         Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() {
@@ -964,9 +1010,9 @@ pub async fn from_substrait_rex(
                     .as_ref()
                     .clone(),
                 ),
-                from_substrait_type(output_type)?,
+                from_substrait_type_without_names(output_type)?,
             )))),
-            None => substrait_err!("Cast experssion without output type is not 
allowed"),
+            None => substrait_err!("Cast expression without output type is not 
allowed"),
         },
         Some(RexType::WindowFunction(window)) => {
             let fun = match extensions.get(&window.function_reference) {
@@ -1062,7 +1108,15 @@ pub async fn from_substrait_rex(
     }
 }
 
-pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> 
Result<DataType> {
+pub(crate) fn from_substrait_type_without_names(dt: &Type) -> Result<DataType> 
{
+    from_substrait_type(dt, &[], &mut 0)
+}
+
+fn from_substrait_type(
+    dt: &Type,
+    dfs_names: &[String],
+    name_idx: &mut usize,
+) -> Result<DataType> {
     match &dt.kind {
         Some(s_kind) => match s_kind {
             r#type::Kind::Bool(_) => Ok(DataType::Boolean),
@@ -1142,7 +1196,7 @@ pub(crate) fn from_substrait_type(dt: 
&substrait::proto::Type) -> Result<DataTyp
                     substrait_datafusion_err!("List type must have inner type")
                 })?;
                 let field = Arc::new(Field::new_list_field(
-                    from_substrait_type(inner_type)?,
+                    from_substrait_type(inner_type, dfs_names, name_idx)?,
                     is_substrait_type_nullable(inner_type)?,
                 ));
                 match list.type_variation_reference {
@@ -1182,24 +1236,69 @@ pub(crate) fn from_substrait_type(dt: 
&substrait::proto::Type) -> Result<DataTyp
                     ),
                 }
             },
-            r#type::Kind::Struct(s) => {
-                let mut fields = vec![];
-                for (i, f) in s.types.iter().enumerate() {
-                    let field = Field::new(
-                        &format!("c{i}"),
-                        from_substrait_type(f)?,
-                        is_substrait_type_nullable(f)?,
-                    );
-                    fields.push(field);
-                }
-                Ok(DataType::Struct(fields.into()))
-            }
+            r#type::Kind::Struct(s) => 
Ok(DataType::Struct(from_substrait_struct_type(
+                s, dfs_names, name_idx,
+            )?)),
             _ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"),
         },
         _ => not_impl_err!("`None` Substrait kind is not supported"),
     }
 }
 
+fn from_substrait_struct_type(
+    s: &r#type::Struct,
+    dfs_names: &[String],
+    name_idx: &mut usize,
+) -> Result<Fields> {
+    let mut fields = vec![];
+    for (i, f) in s.types.iter().enumerate() {
+        let field = Field::new(
+            next_struct_field_name(i, dfs_names, name_idx)?,
+            from_substrait_type(f, dfs_names, name_idx)?,
+            is_substrait_type_nullable(f)?,
+        );
+        fields.push(field);
+    }
+    Ok(fields.into())
+}
+
+fn next_struct_field_name(
+    i: usize,
+    dfs_names: &[String],
+    name_idx: &mut usize,
+) -> Result<String> {
+    if dfs_names.is_empty() {
+        // If names are not given, create dummy names
+        // c0, c1, ... align with e.g. SqlToRel::create_named_struct
+        Ok(format!("c{i}"))
+    } else {
+        let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| {
+            substrait_datafusion_err!("Named schema must contain names for all 
fields")
+        })?;
+        *name_idx += 1;
+        Ok(name)
+    }
+}
+
+fn from_substrait_named_struct(base_schema: &NamedStruct) -> 
Result<DFSchemaRef> {
+    let mut name_idx = 0;
+    let fields = from_substrait_struct_type(
+        base_schema.r#struct.as_ref().ok_or_else(|| {
+            substrait_datafusion_err!("Named struct must contain a struct")
+        })?,
+        &base_schema.names,
+        &mut name_idx,
+    );
+    if name_idx != base_schema.names.len() {
+        return substrait_err!(
+                                "Names list must match exactly to nested 
schema, but found {} uses for {} names",
+                                name_idx,
+                                base_schema.names.len()
+                            );
+    }
+    Ok(DFSchemaRef::new(DFSchema::try_from(Schema::new(fields?))?))
+}
+
 fn is_substrait_type_nullable(dtype: &Type) -> Result<bool> {
     fn is_nullable(nullability: i32) -> bool {
         nullability != substrait::proto::r#type::Nullability::Required as i32
@@ -1277,7 +1376,15 @@ fn from_substrait_bound(
     }
 }
 
-pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
+pub(crate) fn from_substrait_literal_without_names(lit: &Literal) -> 
Result<ScalarValue> {
+    from_substrait_literal(lit, &vec![], &mut 0)
+}
+
+fn from_substrait_literal(
+    lit: &Literal,
+    dfs_names: &Vec<String>,
+    name_idx: &mut usize,
+) -> Result<ScalarValue> {
     let scalar_value = match &lit.literal_type {
         Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)),
         Some(LiteralType::I8(n)) => match lit.type_variation_reference {
@@ -1359,7 +1466,7 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> 
Result<ScalarValue> {
             let elements = l
                 .values
                 .iter()
-                .map(from_substrait_literal)
+                .map(|el| from_substrait_literal(el, dfs_names, name_idx))
                 .collect::<Result<Vec<_>>>()?;
             if elements.is_empty() {
                 return substrait_err!(
@@ -1381,7 +1488,11 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> 
Result<ScalarValue> {
             }
         }
         Some(LiteralType::EmptyList(l)) => {
-            let element_type = 
from_substrait_type(l.r#type.clone().unwrap().as_ref())?;
+            let element_type = from_substrait_type(
+                l.r#type.clone().unwrap().as_ref(),
+                dfs_names,
+                name_idx,
+            )?;
             match lit.type_variation_reference {
                 DEFAULT_CONTAINER_TYPE_REF => {
                     ScalarValue::List(ScalarValue::new_list(&[], 
&element_type))
@@ -1397,16 +1508,16 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> 
Result<ScalarValue> {
         Some(LiteralType::Struct(s)) => {
             let mut builder = ScalarStructBuilder::new();
             for (i, field) in s.fields.iter().enumerate() {
-                let sv = from_substrait_literal(field)?;
-                // c0, c1, ... align with e.g. SqlToRel::create_named_struct
-                builder = builder.with_scalar(
-                    Field::new(&format!("c{i}"), sv.data_type(), 
field.nullable),
-                    sv,
-                );
+                let name = next_struct_field_name(i, dfs_names, name_idx)?;
+                let sv = from_substrait_literal(field, dfs_names, name_idx)?;
+                builder = builder
+                    .with_scalar(Field::new(name, sv.data_type(), 
field.nullable), sv);
             }
             builder.build()?
         }
-        Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
+        Some(LiteralType::Null(ntype)) => {
+            from_substrait_null(ntype, dfs_names, name_idx)?
+        }
         Some(LiteralType::UserDefined(user_defined)) => {
             match user_defined.type_reference {
                 INTERVAL_YEAR_MONTH_TYPE_REF => {
@@ -1461,7 +1572,11 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> 
Result<ScalarValue> {
     Ok(scalar_value)
 }
 
-fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
+fn from_substrait_null(
+    null_type: &Type,
+    dfs_names: &[String],
+    name_idx: &mut usize,
+) -> Result<ScalarValue> {
     if let Some(kind) = &null_type.kind {
         match kind {
             r#type::Kind::Bool(_) => Ok(ScalarValue::Boolean(None)),
@@ -1539,7 +1654,11 @@ fn from_substrait_null(null_type: &Type) -> 
Result<ScalarValue> {
             )),
             r#type::Kind::List(l) => {
                 let field = Field::new_list_field(
-                    from_substrait_type(l.r#type.clone().unwrap().as_ref())?,
+                    from_substrait_type(
+                        l.r#type.clone().unwrap().as_ref(),
+                        dfs_names,
+                        name_idx,
+                    )?,
                     true,
                 );
                 match l.type_variation_reference {
@@ -1554,6 +1673,10 @@ fn from_substrait_null(null_type: &Type) -> 
Result<ScalarValue> {
                     ),
                 }
             }
+            r#type::Kind::Struct(s) => {
+                let fields = from_substrait_struct_type(s, dfs_names, 
name_idx)?;
+                Ok(ScalarStructBuilder::new_null(fields))
+            }
             _ => not_impl_err!("Unsupported Substrait type for null: 
{kind:?}"),
         }
     } else {
diff --git a/datafusion/substrait/src/logical_plan/producer.rs 
b/datafusion/substrait/src/logical_plan/producer.rs
index 592b40db59..4dd8226366 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use itertools::Itertools;
 use std::collections::HashMap;
 use std::ops::Deref;
 use std::sync::Arc;
@@ -32,7 +33,9 @@ use datafusion::{
 };
 
 use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait};
-use datafusion::common::{exec_err, internal_err, not_impl_err};
+use datafusion::common::{
+    exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err,
+};
 use datafusion::common::{substrait_err, DFSchemaRef};
 #[allow(unused_imports)]
 use datafusion::logical_expr::aggregate_function;
@@ -50,6 +53,7 @@ use substrait::proto::expression::literal::{List, Struct};
 use substrait::proto::expression::subquery::InPredicate;
 use substrait::proto::expression::window_function::BoundsType;
 use substrait::proto::r#type::{parameter, Parameter};
+use substrait::proto::read_rel::VirtualTable;
 use substrait::proto::{CrossRel, ExchangeRel};
 use substrait::{
     proto::{
@@ -174,6 +178,62 @@ pub fn to_substrait_rel(
                 }))),
             }))
         }
+        LogicalPlan::EmptyRelation(e) => {
+            if e.produce_one_row {
+                return not_impl_err!(
+                    "Producing a row from empty relation is unsupported"
+                );
+            }
+            Ok(Box::new(Rel {
+                rel_type: Some(RelType::Read(Box::new(ReadRel {
+                    common: None,
+                    base_schema: Some(to_substrait_named_struct(&e.schema)?),
+                    filter: None,
+                    best_effort_filter: None,
+                    projection: None,
+                    advanced_extension: None,
+                    read_type: Some(ReadType::VirtualTable(VirtualTable {
+                        values: vec![],
+                    })),
+                }))),
+            }))
+        }
+        LogicalPlan::Values(v) => {
+            let values = v
+                .values
+                .iter()
+                .map(|row| {
+                    let fields = row
+                        .iter()
+                        .map(|v| match v {
+                            Expr::Literal(sv) => to_substrait_literal(sv),
+                            Expr::Alias(alias) => match alias.expr.as_ref() {
+                                // The schema gives us the names, so we can 
skip aliases
+                                Expr::Literal(sv) => to_substrait_literal(sv),
+                                _ => Err(substrait_datafusion_err!(
+                                    "Only literal types can be aliased in 
Virtual Tables, got: {}", alias.expr.variant_name()
+                                )),
+                            },
+                            _ => Err(substrait_datafusion_err!(
+                                "Only literal types and aliases are supported 
in Virtual Tables, got: {}", v.variant_name()
+                            )),
+                        })
+                        .collect::<Result<_>>()?;
+                    Ok(Struct { fields })
+                })
+                .collect::<Result<_>>()?;
+            Ok(Box::new(Rel {
+                rel_type: Some(RelType::Read(Box::new(ReadRel {
+                    common: None,
+                    base_schema: Some(to_substrait_named_struct(&v.schema)?),
+                    filter: None,
+                    best_effort_filter: None,
+                    projection: None,
+                    advanced_extension: None,
+                    read_type: Some(ReadType::VirtualTable(VirtualTable { 
values })),
+                }))),
+            }))
+        }
         LogicalPlan::Projection(p) => {
             let expressions = p
                 .expr
@@ -519,6 +579,63 @@ pub fn to_substrait_rel(
     }
 }
 
+fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result<NamedStruct> {
+    // Substrait wants a list of all field names, including nested fields from 
structs,
+    // also from within e.g. lists and maps. However, it does not want the 
list and map field names
+    // themselves - only proper structs fields are considered to have useful 
names.
+    fn names_dfs(dtype: &DataType) -> Result<Vec<String>> {
+        match dtype {
+            DataType::Struct(fields) => {
+                let mut names = Vec::new();
+                for field in fields {
+                    names.push(field.name().to_string());
+                    names.extend(names_dfs(field.data_type())?);
+                }
+                Ok(names)
+            }
+            DataType::List(l) => names_dfs(l.data_type()),
+            DataType::LargeList(l) => names_dfs(l.data_type()),
+            DataType::Map(m, _) => match m.data_type() {
+                DataType::Struct(key_and_value) if key_and_value.len() == 2 => 
{
+                    let key_names =
+                        names_dfs(key_and_value.first().unwrap().data_type())?;
+                    let value_names =
+                        names_dfs(key_and_value.last().unwrap().data_type())?;
+                    Ok([key_names, value_names].concat())
+                }
+                _ => plan_err!("Map fields must contain a Struct with exactly 
2 fields"),
+            },
+            _ => Ok(Vec::new()),
+        }
+    }
+
+    let names = schema
+        .fields()
+        .iter()
+        .map(|f| {
+            let mut names = vec![f.name().to_string()];
+            names.extend(names_dfs(f.data_type())?);
+            Ok(names)
+        })
+        .flatten_ok()
+        .collect::<Result<_>>()?;
+
+    let field_types = r#type::Struct {
+        types: schema
+            .fields()
+            .iter()
+            .map(|f| to_substrait_type(f.data_type(), f.is_nullable()))
+            .collect::<Result<_>>()?,
+        type_variation_reference: DEFAULT_TYPE_REF,
+        nullability: r#type::Nullability::Unspecified as i32,
+    };
+
+    Ok(NamedStruct {
+        names,
+        r#struct: Some(field_types),
+    })
+}
+
 fn to_substrait_join_expr(
     ctx: &SessionContext,
     join_conditions: &Vec<(Expr, Expr)>,
@@ -2042,7 +2159,9 @@ fn substrait_field_ref(index: usize) -> 
Result<Expression> {
 
 #[cfg(test)]
 mod test {
-    use crate::logical_plan::consumer::{from_substrait_literal, 
from_substrait_type};
+    use crate::logical_plan::consumer::{
+        from_substrait_literal_without_names, 
from_substrait_type_without_names,
+    };
     use datafusion::arrow::array::GenericListArray;
     use datafusion::arrow::datatypes::Field;
     use datafusion::common::scalar::ScalarStructBuilder;
@@ -2115,11 +2234,12 @@ mod test {
         let c2 = Field::new("c2", DataType::Utf8, true);
         round_trip_literal(
             ScalarStructBuilder::new()
-                .with_scalar(c0, ScalarValue::Boolean(Some(true)))
-                .with_scalar(c1, ScalarValue::Int32(Some(1)))
-                .with_scalar(c2, ScalarValue::Utf8(None))
+                .with_scalar(c0.to_owned(), ScalarValue::Boolean(Some(true)))
+                .with_scalar(c1.to_owned(), ScalarValue::Int32(Some(1)))
+                .with_scalar(c2.to_owned(), ScalarValue::Utf8(None))
                 .build()?,
         )?;
+        round_trip_literal(ScalarStructBuilder::new_null(vec![c0, c1, c2]))?;
 
         Ok(())
     }
@@ -2128,7 +2248,7 @@ mod test {
         println!("Checking round trip of {scalar:?}");
 
         let substrait_literal = to_substrait_literal(&scalar)?;
-        let roundtrip_scalar = from_substrait_literal(&substrait_literal)?;
+        let roundtrip_scalar = 
from_substrait_literal_without_names(&substrait_literal)?;
         assert_eq!(scalar, roundtrip_scalar);
         Ok(())
     }
@@ -2186,7 +2306,7 @@ mod test {
         // As DataFusion doesn't consider nullability as a property of the 
type, but field,
         // it doesn't matter if we set nullability to true or false here.
         let substrait = to_substrait_type(&dt, true)?;
-        let roundtrip_dt = from_substrait_type(&substrait)?;
+        let roundtrip_dt = from_substrait_type_without_names(&substrait)?;
         assert_eq!(dt, roundtrip_dt);
         Ok(())
     }
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs 
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index de989001df..5490819b08 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -235,7 +235,8 @@ async fn aggregate_grouping_rollup() -> Result<()> {
     assert_expected_plan(
         "SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)",
         "Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), 
(data.a, data.c), (data.a), ())]], aggr=[[AVG(data.b)]]\
-        \n  TableScan: data projection=[a, b, c, e]"
+        \n  TableScan: data projection=[a, b, c, e]",
+        true
     ).await
 }
 
@@ -368,6 +369,7 @@ async fn aggregate_case() -> Result<()> {
         "SELECT SUM(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data",
         "Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN data.a > Int64(0) THEN 
Int64(1) ELSE Int64(NULL) END)]]\
          \n  TableScan: data projection=[a]",
+        false // NULL vs Int64(NULL)
     )
         .await
 }
@@ -414,7 +416,8 @@ async fn roundtrip_inlist_5() -> Result<()> {
     \n    Subquery:\
     \n      Projection: data2.a\
     \n        Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\
-    \n          TableScan: data2 projection=[a, b, c, d, e, f]").await
+    \n          TableScan: data2 projection=[a, b, c, d, e, f]",
+    true).await
 }
 
 #[tokio::test]
@@ -450,7 +453,8 @@ async fn roundtrip_exists_filter() -> Result<()> {
         "Projection: data.b\
         \n  LeftSemi Join: data.a = data2.a Filter: data2.e != CAST(data.e AS 
Int64)\
         \n    TableScan: data projection=[a, b, e]\
-        \n    TableScan: data2 projection=[a, e]"
+        \n    TableScan: data2 projection=[a, e]",
+        false // "d1" vs "data" field qualifier
     ).await
 }
 
@@ -462,6 +466,7 @@ async fn inner_join() -> Result<()> {
          \n  Inner Join: data.a = data2.a\
          \n    TableScan: data projection=[a]\
          \n    TableScan: data2 projection=[a]",
+        true,
     )
     .await
 }
@@ -592,6 +597,7 @@ async fn simple_intersect() -> Result<()> {
          \n      Aggregate: groupBy=[[data.a]], aggr=[[]]\
          \n        TableScan: data projection=[a]\
          \n      TableScan: data2 projection=[a]",
+        false // COUNT(*) vs COUNT(Int64(1))
     )
         .await
 }
@@ -606,6 +612,7 @@ async fn simple_intersect_table_reuse() -> Result<()> {
          \n      Aggregate: groupBy=[[data.a]], aggr=[[]]\
          \n        TableScan: data projection=[a]\
          \n      TableScan: data projection=[a]",
+        false // COUNT(*) vs COUNT(Int64(1))
     )
         .await
 }
@@ -633,6 +640,7 @@ async fn roundtrip_inner_join_table_reuse_zero_index() -> 
Result<()> {
          \n  Inner Join: data.a = data.a\
          \n    TableScan: data projection=[a, b]\
          \n    TableScan: data projection=[a, c]",
+        false, // "d1" vs "data" field qualifier
     )
     .await
 }
@@ -645,6 +653,7 @@ async fn roundtrip_inner_join_table_reuse_non_zero_index() 
-> Result<()> {
          \n  Inner Join: data.b = data.b\
          \n    TableScan: data projection=[b]\
          \n    TableScan: data projection=[b, c]",
+        false, // "d1" vs "data" field qualifier
     )
     .await
 }
@@ -689,6 +698,7 @@ async fn roundtrip_literal_list() -> Result<()> {
         "SELECT [[1,2,3], [], NULL, [NULL]] FROM data",
         "Projection: List([[1, 2, 3], [], , []])\
         \n  TableScan: data projection=[]",
+        false, // "List(..)" vs "make_array(..)"
     )
     .await
 }
@@ -699,10 +709,45 @@ async fn roundtrip_literal_struct() -> Result<()> {
         "SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data",
         "Projection: Struct({c0:1,c1:true,c2:})\
         \n  TableScan: data projection=[]",
+        false, // "Struct(..)" vs "struct(..)"
     )
     .await
 }
 
+#[tokio::test]
+async fn roundtrip_values() -> Result<()> {
+    // TODO: would be nice to have a struct inside the LargeList, but 
arrow_cast doesn't support that currently
+    let values = "(\
+                1, \
+                'a', \
+                [[-213.1, NULL, 5.5, 2.0, 1.0], []], \
+                arrow_cast([1,2,3], 'LargeList(Int64)'), \
+                STRUCT(true, 1 AS int_field, CAST(NULL AS STRING)), \
+                [STRUCT(STRUCT('a' AS string_field) AS struct_field)]\
+            )";
+
+    // Test LogicalPlan::Values
+    assert_expected_plan(
+        format!("VALUES \
+            {values}, \
+            (NULL, NULL, NULL, NULL, NULL, NULL)").as_str(),
+        "Values: \
+            (\
+                Int64(1), \
+                Utf8(\"a\"), \
+                List([[-213.1, , 5.5, 2.0, 1.0], []]), \
+                LargeList([1, 2, 3]), \
+                Struct({c0:true,int_field:1,c2:}), \
+                List([{struct_field: {string_field: a}}])\
+            ), \
+            (Int64(NULL), Utf8(NULL), List(), LargeList(), 
Struct({c0:,int_field:,c2:}), List())",
+    true)
+        .await?;
+
+    // Test LogicalPlan::EmptyRelation
+    roundtrip(format!("SELECT * FROM (VALUES {values}) LIMIT 
0").as_str()).await
+}
+
 /// Construct a plan that cast columns. Only those SQL types are supported for 
now.
 #[tokio::test]
 async fn new_test_grammar() -> Result<()> {
@@ -918,31 +963,47 @@ async fn verify_post_join_filter_value(proto: Box<Plan>) 
-> Result<()> {
     Ok(())
 }
 
-async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> 
Result<()> {
+async fn assert_expected_plan(
+    sql: &str,
+    expected_plan_str: &str,
+    assert_schema: bool,
+) -> Result<()> {
     let ctx = create_context().await?;
     let df = ctx.sql(sql).await?;
     let plan = df.into_optimized_plan()?;
     let proto = to_substrait_plan(&plan, &ctx)?;
     let plan2 = from_substrait_plan(&ctx, &proto).await?;
     let plan2 = ctx.state().optimize(&plan2)?;
+
+    println!("{plan:#?}");
+    println!("{plan2:#?}");
+
+    println!("{proto:?}");
+
     let plan2str = format!("{plan2:?}");
     assert_eq!(expected_plan_str, &plan2str);
+
+    if assert_schema {
+        assert_eq!(plan.schema(), plan2.schema());
+    }
     Ok(())
 }
 
 async fn roundtrip_fill_na(sql: &str) -> Result<()> {
     let ctx = create_context().await?;
     let df = ctx.sql(sql).await?;
-    let plan1 = df.into_optimized_plan()?;
-    let proto = to_substrait_plan(&plan1, &ctx)?;
+    let plan = df.into_optimized_plan()?;
+    let proto = to_substrait_plan(&plan, &ctx)?;
     let plan2 = from_substrait_plan(&ctx, &proto).await?;
     let plan2 = ctx.state().optimize(&plan2)?;
 
     // Format plan string and replace all None's with 0
-    let plan1str = format!("{plan1:?}").replace("None", "0");
+    let plan1str = format!("{plan:?}").replace("None", "0");
     let plan2str = format!("{plan2:?}").replace("None", "0");
 
     assert_eq!(plan1str, plan2str);
+
+    assert_eq!(plan.schema(), plan2.schema());
     Ok(())
 }
 
@@ -966,6 +1027,8 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: 
&str) -> Result<()> {
     let plan1str = format!("{plan_with_alias:?}");
     let plan2str = format!("{plan:?}");
     assert_eq!(plan1str, plan2str);
+
+    assert_eq!(plan_with_alias.schema(), plan.schema());
     Ok(())
 }
 
@@ -979,9 +1042,13 @@ async fn roundtrip_with_ctx(sql: &str, ctx: 
SessionContext) -> Result<()> {
     println!("{plan:#?}");
     println!("{plan2:#?}");
 
+    println!("{proto:?}");
+
     let plan1str = format!("{plan:?}");
     let plan2str = format!("{plan2:?}");
     assert_eq!(plan1str, plan2str);
+
+    assert_eq!(plan.schema(), plan2.schema());
     Ok(())
 }
 
@@ -1004,25 +1071,14 @@ async fn roundtrip_verify_post_join_filter(sql: &str) 
-> Result<()> {
     let plan2str = format!("{plan2:?}");
     assert_eq!(plan1str, plan2str);
 
+    assert_eq!(plan.schema(), plan2.schema());
+
     // verify that the join filters are None
     verify_post_join_filter_value(proto).await
 }
 
 async fn roundtrip_all_types(sql: &str) -> Result<()> {
-    let ctx = create_all_type_context().await?;
-    let df = ctx.sql(sql).await?;
-    let plan = df.into_optimized_plan()?;
-    let proto = to_substrait_plan(&plan, &ctx)?;
-    let plan2 = from_substrait_plan(&ctx, &proto).await?;
-    let plan2 = ctx.state().optimize(&plan2)?;
-
-    println!("{plan:#?}");
-    println!("{plan2:#?}");
-
-    let plan1str = format!("{plan:?}");
-    let plan2str = format!("{plan2:?}");
-    assert_eq!(plan1str, plan2str);
-    Ok(())
+    roundtrip_with_ctx(sql, create_all_type_context().await?).await
 }
 
 async fn function_extension_info(sql: &str) -> Result<(Vec<String>, Vec<u32>)> 
{


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to