This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 27627546d1 Support Substrait's VirtualTables (#10531)
27627546d1 is described below
commit 27627546d1e58ac14e25819241c655c3d18807b3
Author: Arttu <[email protected]>
AuthorDate: Tue May 28 15:18:45 2024 +0200
Support Substrait's VirtualTables (#10531)
* Add support for Substrait VirtualTables
Adds support for Substrait's VirtualTables, i.e. tables with data baked
into the Substrait plan instead of being read from a source.
Adds conversion in both directions (Substrait -> DataFusion and DataFusion ->
Substrait)
and a roundtrip test.
* fix clippy
* Add support for empty relations
* Fix consuming Structs inside Lists and Structs
Also adds roundtrip schema assertions for cases where possible
* Rename from_substrait_struct -> from_substrait_struct_type for clarity
* Add DataType::LargeList to to_substrait_named_struct
* cargo fmt --all
* Add validation that names list matches schema exactly
* Add a LargeList into VALUES test
---
datafusion/substrait/src/logical_plan/consumer.rs | 191 +++++++++++++++++----
datafusion/substrait/src/logical_plan/producer.rs | 134 ++++++++++++++-
.../tests/cases/roundtrip_logical_plan.rs | 98 ++++++++---
3 files changed, 361 insertions(+), 62 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index d6c60ebdde..abebc68123 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -16,15 +16,17 @@
// under the License.
use async_recursion::async_recursion;
-use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
+use datafusion::arrow::datatypes::{
+ DataType, Field, Fields, IntervalUnit, Schema, TimeUnit,
+};
use datafusion::common::{
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema,
DFSchemaRef,
};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
- aggregate_function, expr::find_df_window_func, BinaryExpr, Case, Expr,
LogicalPlan,
- Operator, ScalarUDF,
+ aggregate_function, expr::find_df_window_func, BinaryExpr, Case,
EmptyRelation, Expr,
+ LogicalPlan, Operator, ScalarUDF, Values,
};
use datafusion::logical_expr::{
expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
@@ -58,7 +60,7 @@ use substrait::proto::{
rel::RelType,
set_rel,
sort_field::{SortDirection, SortKind::*},
- AggregateFunction, Expression, Plan, Rel, Type,
+ AggregateFunction, Expression, NamedStruct, Plan, Rel, Type,
};
use substrait::proto::{FunctionArgument, SortField};
@@ -509,7 +511,51 @@ pub async fn from_substrait_rel(
_ => Ok(t),
}
}
- _ => not_impl_err!("Only NamedTable reads are supported"),
+ Some(ReadType::VirtualTable(vt)) => {
+ let base_schema = read.base_schema.as_ref().ok_or_else(|| {
+ substrait_datafusion_err!("No base schema provided for
Virtual Table")
+ })?;
+
+ let schema = from_substrait_named_struct(base_schema)?;
+
+ if vt.values.is_empty() {
+ return Ok(LogicalPlan::EmptyRelation(EmptyRelation {
+ produce_one_row: false,
+ schema,
+ }));
+ }
+
+ let values = vt
+ .values
+ .iter()
+ .map(|row| {
+ let mut name_idx = 0;
+ let lits = row
+ .fields
+ .iter()
+ .map(|lit| {
+ name_idx += 1; // top-level names are provided
through schema
+ Ok(Expr::Literal(from_substrait_literal(
+ lit,
+ &base_schema.names,
+ &mut name_idx,
+ )?))
+ })
+ .collect::<Result<_>>()?;
+ if name_idx != base_schema.names.len() {
+ return substrait_err!(
+ "Names list must match exactly to nested
schema, but found {} uses for {} names",
+ name_idx,
+ base_schema.names.len()
+ );
+ }
+ Ok(lits)
+ })
+ .collect::<Result<_>>()?;
+
+ Ok(LogicalPlan::Values(Values { schema, values }))
+ }
+ _ => not_impl_err!("Only NamedTable and VirtualTable reads are
supported"),
},
Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) {
Ok(set_op) => match set_op {
@@ -948,7 +994,7 @@ pub async fn from_substrait_rex(
}
}
Some(RexType::Literal(lit)) => {
- let scalar_value = from_substrait_literal(lit)?;
+ let scalar_value = from_substrait_literal_without_names(lit)?;
Ok(Arc::new(Expr::Literal(scalar_value)))
}
Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() {
@@ -964,9 +1010,9 @@ pub async fn from_substrait_rex(
.as_ref()
.clone(),
),
- from_substrait_type(output_type)?,
+ from_substrait_type_without_names(output_type)?,
)))),
- None => substrait_err!("Cast experssion without output type is not
allowed"),
+ None => substrait_err!("Cast expression without output type is not
allowed"),
},
Some(RexType::WindowFunction(window)) => {
let fun = match extensions.get(&window.function_reference) {
@@ -1062,7 +1108,15 @@ pub async fn from_substrait_rex(
}
}
-pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) ->
Result<DataType> {
+pub(crate) fn from_substrait_type_without_names(dt: &Type) -> Result<DataType>
{
+ from_substrait_type(dt, &[], &mut 0)
+}
+
+fn from_substrait_type(
+ dt: &Type,
+ dfs_names: &[String],
+ name_idx: &mut usize,
+) -> Result<DataType> {
match &dt.kind {
Some(s_kind) => match s_kind {
r#type::Kind::Bool(_) => Ok(DataType::Boolean),
@@ -1142,7 +1196,7 @@ pub(crate) fn from_substrait_type(dt:
&substrait::proto::Type) -> Result<DataTyp
substrait_datafusion_err!("List type must have inner type")
})?;
let field = Arc::new(Field::new_list_field(
- from_substrait_type(inner_type)?,
+ from_substrait_type(inner_type, dfs_names, name_idx)?,
is_substrait_type_nullable(inner_type)?,
));
match list.type_variation_reference {
@@ -1182,24 +1236,69 @@ pub(crate) fn from_substrait_type(dt:
&substrait::proto::Type) -> Result<DataTyp
),
}
},
- r#type::Kind::Struct(s) => {
- let mut fields = vec![];
- for (i, f) in s.types.iter().enumerate() {
- let field = Field::new(
- &format!("c{i}"),
- from_substrait_type(f)?,
- is_substrait_type_nullable(f)?,
- );
- fields.push(field);
- }
- Ok(DataType::Struct(fields.into()))
- }
+ r#type::Kind::Struct(s) =>
Ok(DataType::Struct(from_substrait_struct_type(
+ s, dfs_names, name_idx,
+ )?)),
_ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"),
},
_ => not_impl_err!("`None` Substrait kind is not supported"),
}
}
+fn from_substrait_struct_type(
+ s: &r#type::Struct,
+ dfs_names: &[String],
+ name_idx: &mut usize,
+) -> Result<Fields> {
+ let mut fields = vec![];
+ for (i, f) in s.types.iter().enumerate() {
+ let field = Field::new(
+ next_struct_field_name(i, dfs_names, name_idx)?,
+ from_substrait_type(f, dfs_names, name_idx)?,
+ is_substrait_type_nullable(f)?,
+ );
+ fields.push(field);
+ }
+ Ok(fields.into())
+}
+
+fn next_struct_field_name(
+ i: usize,
+ dfs_names: &[String],
+ name_idx: &mut usize,
+) -> Result<String> {
+ if dfs_names.is_empty() {
+ // If names are not given, create dummy names
+ // c0, c1, ... align with e.g. SqlToRel::create_named_struct
+ Ok(format!("c{i}"))
+ } else {
+ let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| {
+ substrait_datafusion_err!("Named schema must contain names for all
fields")
+ })?;
+ *name_idx += 1;
+ Ok(name)
+ }
+}
+
+fn from_substrait_named_struct(base_schema: &NamedStruct) ->
Result<DFSchemaRef> {
+ let mut name_idx = 0;
+ let fields = from_substrait_struct_type(
+ base_schema.r#struct.as_ref().ok_or_else(|| {
+ substrait_datafusion_err!("Named struct must contain a struct")
+ })?,
+ &base_schema.names,
+ &mut name_idx,
+ );
+ if name_idx != base_schema.names.len() {
+ return substrait_err!(
+ "Names list must match exactly to nested
schema, but found {} uses for {} names",
+ name_idx,
+ base_schema.names.len()
+ );
+ }
+ Ok(DFSchemaRef::new(DFSchema::try_from(Schema::new(fields?))?))
+}
+
fn is_substrait_type_nullable(dtype: &Type) -> Result<bool> {
fn is_nullable(nullability: i32) -> bool {
nullability != substrait::proto::r#type::Nullability::Required as i32
@@ -1277,7 +1376,15 @@ fn from_substrait_bound(
}
}
-pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
+pub(crate) fn from_substrait_literal_without_names(lit: &Literal) ->
Result<ScalarValue> {
+ from_substrait_literal(lit, &vec![], &mut 0)
+}
+
+fn from_substrait_literal(
+ lit: &Literal,
+ dfs_names: &Vec<String>,
+ name_idx: &mut usize,
+) -> Result<ScalarValue> {
let scalar_value = match &lit.literal_type {
Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)),
Some(LiteralType::I8(n)) => match lit.type_variation_reference {
@@ -1359,7 +1466,7 @@ pub(crate) fn from_substrait_literal(lit: &Literal) ->
Result<ScalarValue> {
let elements = l
.values
.iter()
- .map(from_substrait_literal)
+ .map(|el| from_substrait_literal(el, dfs_names, name_idx))
.collect::<Result<Vec<_>>>()?;
if elements.is_empty() {
return substrait_err!(
@@ -1381,7 +1488,11 @@ pub(crate) fn from_substrait_literal(lit: &Literal) ->
Result<ScalarValue> {
}
}
Some(LiteralType::EmptyList(l)) => {
- let element_type =
from_substrait_type(l.r#type.clone().unwrap().as_ref())?;
+ let element_type = from_substrait_type(
+ l.r#type.clone().unwrap().as_ref(),
+ dfs_names,
+ name_idx,
+ )?;
match lit.type_variation_reference {
DEFAULT_CONTAINER_TYPE_REF => {
ScalarValue::List(ScalarValue::new_list(&[],
&element_type))
@@ -1397,16 +1508,16 @@ pub(crate) fn from_substrait_literal(lit: &Literal) ->
Result<ScalarValue> {
Some(LiteralType::Struct(s)) => {
let mut builder = ScalarStructBuilder::new();
for (i, field) in s.fields.iter().enumerate() {
- let sv = from_substrait_literal(field)?;
- // c0, c1, ... align with e.g. SqlToRel::create_named_struct
- builder = builder.with_scalar(
- Field::new(&format!("c{i}"), sv.data_type(),
field.nullable),
- sv,
- );
+ let name = next_struct_field_name(i, dfs_names, name_idx)?;
+ let sv = from_substrait_literal(field, dfs_names, name_idx)?;
+ builder = builder
+ .with_scalar(Field::new(name, sv.data_type(),
field.nullable), sv);
}
builder.build()?
}
- Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
+ Some(LiteralType::Null(ntype)) => {
+ from_substrait_null(ntype, dfs_names, name_idx)?
+ }
Some(LiteralType::UserDefined(user_defined)) => {
match user_defined.type_reference {
INTERVAL_YEAR_MONTH_TYPE_REF => {
@@ -1461,7 +1572,11 @@ pub(crate) fn from_substrait_literal(lit: &Literal) ->
Result<ScalarValue> {
Ok(scalar_value)
}
-fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
+fn from_substrait_null(
+ null_type: &Type,
+ dfs_names: &[String],
+ name_idx: &mut usize,
+) -> Result<ScalarValue> {
if let Some(kind) = &null_type.kind {
match kind {
r#type::Kind::Bool(_) => Ok(ScalarValue::Boolean(None)),
@@ -1539,7 +1654,11 @@ fn from_substrait_null(null_type: &Type) ->
Result<ScalarValue> {
)),
r#type::Kind::List(l) => {
let field = Field::new_list_field(
- from_substrait_type(l.r#type.clone().unwrap().as_ref())?,
+ from_substrait_type(
+ l.r#type.clone().unwrap().as_ref(),
+ dfs_names,
+ name_idx,
+ )?,
true,
);
match l.type_variation_reference {
@@ -1554,6 +1673,10 @@ fn from_substrait_null(null_type: &Type) ->
Result<ScalarValue> {
),
}
}
+ r#type::Kind::Struct(s) => {
+ let fields = from_substrait_struct_type(s, dfs_names,
name_idx)?;
+ Ok(ScalarStructBuilder::new_null(fields))
+ }
_ => not_impl_err!("Unsupported Substrait type for null:
{kind:?}"),
}
} else {
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index 592b40db59..4dd8226366 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use itertools::Itertools;
use std::collections::HashMap;
use std::ops::Deref;
use std::sync::Arc;
@@ -32,7 +33,9 @@ use datafusion::{
};
use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait};
-use datafusion::common::{exec_err, internal_err, not_impl_err};
+use datafusion::common::{
+ exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err,
+};
use datafusion::common::{substrait_err, DFSchemaRef};
#[allow(unused_imports)]
use datafusion::logical_expr::aggregate_function;
@@ -50,6 +53,7 @@ use substrait::proto::expression::literal::{List, Struct};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::r#type::{parameter, Parameter};
+use substrait::proto::read_rel::VirtualTable;
use substrait::proto::{CrossRel, ExchangeRel};
use substrait::{
proto::{
@@ -174,6 +178,62 @@ pub fn to_substrait_rel(
}))),
}))
}
+ LogicalPlan::EmptyRelation(e) => {
+ if e.produce_one_row {
+ return not_impl_err!(
+ "Producing a row from empty relation is unsupported"
+ );
+ }
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Read(Box::new(ReadRel {
+ common: None,
+ base_schema: Some(to_substrait_named_struct(&e.schema)?),
+ filter: None,
+ best_effort_filter: None,
+ projection: None,
+ advanced_extension: None,
+ read_type: Some(ReadType::VirtualTable(VirtualTable {
+ values: vec![],
+ })),
+ }))),
+ }))
+ }
+ LogicalPlan::Values(v) => {
+ let values = v
+ .values
+ .iter()
+ .map(|row| {
+ let fields = row
+ .iter()
+ .map(|v| match v {
+ Expr::Literal(sv) => to_substrait_literal(sv),
+ Expr::Alias(alias) => match alias.expr.as_ref() {
+ // The schema gives us the names, so we can
skip aliases
+ Expr::Literal(sv) => to_substrait_literal(sv),
+ _ => Err(substrait_datafusion_err!(
+ "Only literal types can be aliased in
Virtual Tables, got: {}", alias.expr.variant_name()
+ )),
+ },
+ _ => Err(substrait_datafusion_err!(
+ "Only literal types and aliases are supported
in Virtual Tables, got: {}", v.variant_name()
+ )),
+ })
+ .collect::<Result<_>>()?;
+ Ok(Struct { fields })
+ })
+ .collect::<Result<_>>()?;
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Read(Box::new(ReadRel {
+ common: None,
+ base_schema: Some(to_substrait_named_struct(&v.schema)?),
+ filter: None,
+ best_effort_filter: None,
+ projection: None,
+ advanced_extension: None,
+ read_type: Some(ReadType::VirtualTable(VirtualTable {
values })),
+ }))),
+ }))
+ }
LogicalPlan::Projection(p) => {
let expressions = p
.expr
@@ -519,6 +579,63 @@ pub fn to_substrait_rel(
}
}
+fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result<NamedStruct> {
+ // Substrait wants a list of all field names, including nested fields from
structs,
+ // also from within e.g. lists and maps. However, it does not want the
list and map field names
+ // themselves - only proper structs fields are considered to have useful
names.
+ fn names_dfs(dtype: &DataType) -> Result<Vec<String>> {
+ match dtype {
+ DataType::Struct(fields) => {
+ let mut names = Vec::new();
+ for field in fields {
+ names.push(field.name().to_string());
+ names.extend(names_dfs(field.data_type())?);
+ }
+ Ok(names)
+ }
+ DataType::List(l) => names_dfs(l.data_type()),
+ DataType::LargeList(l) => names_dfs(l.data_type()),
+ DataType::Map(m, _) => match m.data_type() {
+ DataType::Struct(key_and_value) if key_and_value.len() == 2 =>
{
+ let key_names =
+ names_dfs(key_and_value.first().unwrap().data_type())?;
+ let value_names =
+ names_dfs(key_and_value.last().unwrap().data_type())?;
+ Ok([key_names, value_names].concat())
+ }
+ _ => plan_err!("Map fields must contain a Struct with exactly
2 fields"),
+ },
+ _ => Ok(Vec::new()),
+ }
+ }
+
+ let names = schema
+ .fields()
+ .iter()
+ .map(|f| {
+ let mut names = vec![f.name().to_string()];
+ names.extend(names_dfs(f.data_type())?);
+ Ok(names)
+ })
+ .flatten_ok()
+ .collect::<Result<_>>()?;
+
+ let field_types = r#type::Struct {
+ types: schema
+ .fields()
+ .iter()
+ .map(|f| to_substrait_type(f.data_type(), f.is_nullable()))
+ .collect::<Result<_>>()?,
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: r#type::Nullability::Unspecified as i32,
+ };
+
+ Ok(NamedStruct {
+ names,
+ r#struct: Some(field_types),
+ })
+}
+
fn to_substrait_join_expr(
ctx: &SessionContext,
join_conditions: &Vec<(Expr, Expr)>,
@@ -2042,7 +2159,9 @@ fn substrait_field_ref(index: usize) ->
Result<Expression> {
#[cfg(test)]
mod test {
- use crate::logical_plan::consumer::{from_substrait_literal,
from_substrait_type};
+ use crate::logical_plan::consumer::{
+ from_substrait_literal_without_names,
from_substrait_type_without_names,
+ };
use datafusion::arrow::array::GenericListArray;
use datafusion::arrow::datatypes::Field;
use datafusion::common::scalar::ScalarStructBuilder;
@@ -2115,11 +2234,12 @@ mod test {
let c2 = Field::new("c2", DataType::Utf8, true);
round_trip_literal(
ScalarStructBuilder::new()
- .with_scalar(c0, ScalarValue::Boolean(Some(true)))
- .with_scalar(c1, ScalarValue::Int32(Some(1)))
- .with_scalar(c2, ScalarValue::Utf8(None))
+ .with_scalar(c0.to_owned(), ScalarValue::Boolean(Some(true)))
+ .with_scalar(c1.to_owned(), ScalarValue::Int32(Some(1)))
+ .with_scalar(c2.to_owned(), ScalarValue::Utf8(None))
.build()?,
)?;
+ round_trip_literal(ScalarStructBuilder::new_null(vec![c0, c1, c2]))?;
Ok(())
}
@@ -2128,7 +2248,7 @@ mod test {
println!("Checking round trip of {scalar:?}");
let substrait_literal = to_substrait_literal(&scalar)?;
- let roundtrip_scalar = from_substrait_literal(&substrait_literal)?;
+ let roundtrip_scalar =
from_substrait_literal_without_names(&substrait_literal)?;
assert_eq!(scalar, roundtrip_scalar);
Ok(())
}
@@ -2186,7 +2306,7 @@ mod test {
// As DataFusion doesn't consider nullability as a property of the
type, but field,
// it doesn't matter if we set nullability to true or false here.
let substrait = to_substrait_type(&dt, true)?;
- let roundtrip_dt = from_substrait_type(&substrait)?;
+ let roundtrip_dt = from_substrait_type_without_names(&substrait)?;
assert_eq!(dt, roundtrip_dt);
Ok(())
}
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index de989001df..5490819b08 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -235,7 +235,8 @@ async fn aggregate_grouping_rollup() -> Result<()> {
assert_expected_plan(
"SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)",
"Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e),
(data.a, data.c), (data.a), ())]], aggr=[[AVG(data.b)]]\
- \n TableScan: data projection=[a, b, c, e]"
+ \n TableScan: data projection=[a, b, c, e]",
+ true
).await
}
@@ -368,6 +369,7 @@ async fn aggregate_case() -> Result<()> {
"SELECT SUM(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data",
"Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN data.a > Int64(0) THEN
Int64(1) ELSE Int64(NULL) END)]]\
\n TableScan: data projection=[a]",
+ false // NULL vs Int64(NULL)
)
.await
}
@@ -414,7 +416,8 @@ async fn roundtrip_inlist_5() -> Result<()> {
\n Subquery:\
\n Projection: data2.a\
\n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\
- \n TableScan: data2 projection=[a, b, c, d, e, f]").await
+ \n TableScan: data2 projection=[a, b, c, d, e, f]",
+ true).await
}
#[tokio::test]
@@ -450,7 +453,8 @@ async fn roundtrip_exists_filter() -> Result<()> {
"Projection: data.b\
\n LeftSemi Join: data.a = data2.a Filter: data2.e != CAST(data.e AS
Int64)\
\n TableScan: data projection=[a, b, e]\
- \n TableScan: data2 projection=[a, e]"
+ \n TableScan: data2 projection=[a, e]",
+ false // "d1" vs "data" field qualifier
).await
}
@@ -462,6 +466,7 @@ async fn inner_join() -> Result<()> {
\n Inner Join: data.a = data2.a\
\n TableScan: data projection=[a]\
\n TableScan: data2 projection=[a]",
+ true,
)
.await
}
@@ -592,6 +597,7 @@ async fn simple_intersect() -> Result<()> {
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
\n TableScan: data2 projection=[a]",
+ false // COUNT(*) vs COUNT(Int64(1))
)
.await
}
@@ -606,6 +612,7 @@ async fn simple_intersect_table_reuse() -> Result<()> {
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
\n TableScan: data projection=[a]",
+ false // COUNT(*) vs COUNT(Int64(1))
)
.await
}
@@ -633,6 +640,7 @@ async fn roundtrip_inner_join_table_reuse_zero_index() ->
Result<()> {
\n Inner Join: data.a = data.a\
\n TableScan: data projection=[a, b]\
\n TableScan: data projection=[a, c]",
+ false, // "d1" vs "data" field qualifier
)
.await
}
@@ -645,6 +653,7 @@ async fn roundtrip_inner_join_table_reuse_non_zero_index()
-> Result<()> {
\n Inner Join: data.b = data.b\
\n TableScan: data projection=[b]\
\n TableScan: data projection=[b, c]",
+ false, // "d1" vs "data" field qualifier
)
.await
}
@@ -689,6 +698,7 @@ async fn roundtrip_literal_list() -> Result<()> {
"SELECT [[1,2,3], [], NULL, [NULL]] FROM data",
"Projection: List([[1, 2, 3], [], , []])\
\n TableScan: data projection=[]",
+ false, // "List(..)" vs "make_array(..)"
)
.await
}
@@ -699,10 +709,45 @@ async fn roundtrip_literal_struct() -> Result<()> {
"SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data",
"Projection: Struct({c0:1,c1:true,c2:})\
\n TableScan: data projection=[]",
+ false, // "Struct(..)" vs "struct(..)"
)
.await
}
+#[tokio::test]
+async fn roundtrip_values() -> Result<()> {
+ // TODO: would be nice to have a struct inside the LargeList, but
arrow_cast doesn't support that currently
+ let values = "(\
+ 1, \
+ 'a', \
+ [[-213.1, NULL, 5.5, 2.0, 1.0], []], \
+ arrow_cast([1,2,3], 'LargeList(Int64)'), \
+ STRUCT(true, 1 AS int_field, CAST(NULL AS STRING)), \
+ [STRUCT(STRUCT('a' AS string_field) AS struct_field)]\
+ )";
+
+ // Test LogicalPlan::Values
+ assert_expected_plan(
+ format!("VALUES \
+ {values}, \
+ (NULL, NULL, NULL, NULL, NULL, NULL)").as_str(),
+ "Values: \
+ (\
+ Int64(1), \
+ Utf8(\"a\"), \
+ List([[-213.1, , 5.5, 2.0, 1.0], []]), \
+ LargeList([1, 2, 3]), \
+ Struct({c0:true,int_field:1,c2:}), \
+ List([{struct_field: {string_field: a}}])\
+ ), \
+ (Int64(NULL), Utf8(NULL), List(), LargeList(),
Struct({c0:,int_field:,c2:}), List())",
+ true)
+ .await?;
+
+ // Test LogicalPlan::EmptyRelation
+ roundtrip(format!("SELECT * FROM (VALUES {values}) LIMIT
0").as_str()).await
+}
+
/// Construct a plan that cast columns. Only those SQL types are supported for
now.
#[tokio::test]
async fn new_test_grammar() -> Result<()> {
@@ -918,31 +963,47 @@ async fn verify_post_join_filter_value(proto: Box<Plan>)
-> Result<()> {
Ok(())
}
-async fn assert_expected_plan(sql: &str, expected_plan_str: &str) ->
Result<()> {
+async fn assert_expected_plan(
+ sql: &str,
+ expected_plan_str: &str,
+ assert_schema: bool,
+) -> Result<()> {
let ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
let proto = to_substrait_plan(&plan, &ctx)?;
let plan2 = from_substrait_plan(&ctx, &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;
+
+ println!("{plan:#?}");
+ println!("{plan2:#?}");
+
+ println!("{proto:?}");
+
let plan2str = format!("{plan2:?}");
assert_eq!(expected_plan_str, &plan2str);
+
+ if assert_schema {
+ assert_eq!(plan.schema(), plan2.schema());
+ }
Ok(())
}
async fn roundtrip_fill_na(sql: &str) -> Result<()> {
let ctx = create_context().await?;
let df = ctx.sql(sql).await?;
- let plan1 = df.into_optimized_plan()?;
- let proto = to_substrait_plan(&plan1, &ctx)?;
+ let plan = df.into_optimized_plan()?;
+ let proto = to_substrait_plan(&plan, &ctx)?;
let plan2 = from_substrait_plan(&ctx, &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;
// Format plan string and replace all None's with 0
- let plan1str = format!("{plan1:?}").replace("None", "0");
+ let plan1str = format!("{plan:?}").replace("None", "0");
let plan2str = format!("{plan2:?}").replace("None", "0");
assert_eq!(plan1str, plan2str);
+
+ assert_eq!(plan.schema(), plan2.schema());
Ok(())
}
@@ -966,6 +1027,8 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias:
&str) -> Result<()> {
let plan1str = format!("{plan_with_alias:?}");
let plan2str = format!("{plan:?}");
assert_eq!(plan1str, plan2str);
+
+ assert_eq!(plan_with_alias.schema(), plan.schema());
Ok(())
}
@@ -979,9 +1042,13 @@ async fn roundtrip_with_ctx(sql: &str, ctx:
SessionContext) -> Result<()> {
println!("{plan:#?}");
println!("{plan2:#?}");
+ println!("{proto:?}");
+
let plan1str = format!("{plan:?}");
let plan2str = format!("{plan2:?}");
assert_eq!(plan1str, plan2str);
+
+ assert_eq!(plan.schema(), plan2.schema());
Ok(())
}
@@ -1004,25 +1071,14 @@ async fn roundtrip_verify_post_join_filter(sql: &str)
-> Result<()> {
let plan2str = format!("{plan2:?}");
assert_eq!(plan1str, plan2str);
+ assert_eq!(plan.schema(), plan2.schema());
+
// verify that the join filters are None
verify_post_join_filter_value(proto).await
}
async fn roundtrip_all_types(sql: &str) -> Result<()> {
- let ctx = create_all_type_context().await?;
- let df = ctx.sql(sql).await?;
- let plan = df.into_optimized_plan()?;
- let proto = to_substrait_plan(&plan, &ctx)?;
- let plan2 = from_substrait_plan(&ctx, &proto).await?;
- let plan2 = ctx.state().optimize(&plan2)?;
-
- println!("{plan:#?}");
- println!("{plan2:#?}");
-
- let plan1str = format!("{plan:?}");
- let plan2str = format!("{plan2:?}");
- assert_eq!(plan1str, plan2str);
- Ok(())
+ roundtrip_with_ctx(sql, create_all_type_context().await?).await
}
async fn function_extension_info(sql: &str) -> Result<(Vec<String>, Vec<u32>)>
{
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]