This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 5a0ea0bbad fix(substrait): remove optimize calls from substrait
consumer (#12800)
5a0ea0bbad is described below
commit 5a0ea0bbad3de73ded192bf32080094cee5db9f3
Author: Tornike Gurgenidze <[email protected]>
AuthorDate: Tue Oct 15 21:13:54 2024 +0400
fix(substrait): remove optimize calls from substrait consumer (#12800)
* fix(substrait): remove optimize calls from substrait consumer
* fix(substrait): fix schema comparison in ensure_schema_compatability
* fix(substrait): correctly apply read projections
* fix(substrait): nits
* fix(substrait): split schema validation and apply_projection
* fix(substrait): return an error when apply_projection is called with
something other than a TableScan
* fix(substrait): clippy errors
---
datafusion/substrait/src/lib.rs | 1 +
datafusion/substrait/src/logical_plan/consumer.rs | 163 ++++++++++++---------
.../substrait/tests/cases/consumer_integration.rs | 122 +++++++--------
datafusion/substrait/tests/cases/function_test.rs | 2 +-
datafusion/substrait/tests/cases/logical_plans.rs | 4 +-
.../tests/cases/roundtrip_logical_plan.rs | 4 +-
.../substrait/tests/cases/substrait_validations.rs | 15 +-
7 files changed, 170 insertions(+), 141 deletions(-)
diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs
index 0b1c796553..a6f7c033f9 100644
--- a/datafusion/substrait/src/lib.rs
+++ b/datafusion/substrait/src/lib.rs
@@ -68,6 +68,7 @@
//!
//! // Receive a substrait protobuf from somewhere, and turn it into a
LogicalPlan
//! let logical_round_trip =
logical_plan::consumer::from_substrait_plan(&ctx, &substrait_plan).await?;
+//! let logical_round_trip = ctx.state().optimize(&logical_round_trip)?;
//! assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip));
//! # Ok(())
//! # }
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 8884807749..3cafd64321 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -55,7 +55,6 @@ use crate::variation_const::{
use datafusion::arrow::array::{new_empty_array, AsArray};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::dataframe::DataFrame;
-use datafusion::logical_expr::builder::project;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{
col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder,
Partitioning,
@@ -69,7 +68,7 @@ use datafusion::{
prelude::{Column, SessionContext},
scalar::ScalarValue,
};
-use std::collections::{HashMap, HashSet};
+use std::collections::HashSet;
use std::sync::Arc;
use substrait::proto::exchange_rel::ExchangeKind;
use
substrait::proto::expression::literal::interval_day_to_second::PrecisionMode;
@@ -227,7 +226,6 @@ pub async fn from_substrait_plan(
// Nothing to do if the schema is already
equivalent
return Ok(plan);
}
-
match plan {
// If the last node of the plan produces
expressions, bake the renames into those expressions.
// This isn't necessary for correctness, but helps
with roundtrip tests.
@@ -327,12 +325,11 @@ pub async fn from_substrait_extended_expr(
})
}
-/// parse projection
-pub fn extract_projection(
- t: LogicalPlan,
- projection: &::core::option::Option<expression::MaskExpression>,
-) -> Result<LogicalPlan> {
- match projection {
+pub fn apply_masking(
+ schema: DFSchema,
+ mask_expression: &::core::option::Option<expression::MaskExpression>,
+) -> Result<DFSchema> {
+ match mask_expression {
Some(MaskExpression { select, .. }) => match &select.as_ref() {
Some(projection) => {
let column_indices: Vec<usize> = projection
@@ -340,41 +337,23 @@ pub fn extract_projection(
.iter()
.map(|item| item.field as usize)
.collect();
- match t {
- LogicalPlan::TableScan(mut scan) => {
- let fields = column_indices
- .iter()
- .map(|i| scan.projected_schema.qualified_field(*i))
- .map(|(qualifier, field)| {
- (qualifier.cloned(), Arc::new(field.clone()))
- })
- .collect();
- scan.projection = Some(column_indices);
- scan.projected_schema = DFSchemaRef::new(
- DFSchema::new_with_metadata(fields,
HashMap::new())?,
- );
- Ok(LogicalPlan::TableScan(scan))
- }
- LogicalPlan::Projection(projection) => {
- // create another Projection around the Projection to
handle the field masking
- let fields: Vec<Expr> = column_indices
- .into_iter()
- .map(|i| {
- let (qualifier, field) =
- projection.schema.qualified_field(i);
- let column =
- Column::new(qualifier.cloned(),
field.name());
- Expr::Column(column)
- })
- .collect();
- project(LogicalPlan::Projection(projection), fields)
- }
- _ => plan_err!("unexpected plan for table"),
- }
+
+ let fields = column_indices
+ .iter()
+ .map(|i| schema.qualified_field(*i))
+ .map(|(qualifier, field)| {
+ (qualifier.cloned(), Arc::new(field.clone()))
+ })
+ .collect();
+
+ Ok(DFSchema::new_with_metadata(
+ fields,
+ schema.metadata().clone(),
+ )?)
}
- _ => Ok(t),
+ None => Ok(schema),
},
- _ => Ok(t),
+ None => Ok(schema),
}
}
@@ -777,14 +756,20 @@ pub async fn from_substrait_rel(
},
};
+ let t = ctx.table(table_reference.clone()).await?;
+
let substrait_schema =
from_substrait_named_struct(named_struct, extensions)?
- .replace_qualifier(table_reference.clone());
+ .replace_qualifier(table_reference);
- let t = ctx.table(table_reference.clone()).await?;
- let t = ensure_schema_compatability(t, substrait_schema)?;
- let t = t.into_optimized_plan()?;
- extract_projection(t, &read.projection)
+ ensure_schema_compatability(
+ t.schema().to_owned(),
+ substrait_schema.clone(),
+ )?;
+
+ let substrait_schema = apply_masking(substrait_schema,
&read.projection)?;
+
+ apply_projection(t, substrait_schema)
}
Some(ReadType::VirtualTable(vt)) => {
let base_schema = read.base_schema.as_ref().ok_or_else(|| {
@@ -835,6 +820,10 @@ pub async fn from_substrait_rel(
}))
}
Some(ReadType::LocalFiles(lf)) => {
+ let named_struct = read.base_schema.as_ref().ok_or_else(|| {
+ substrait_datafusion_err!("No base schema provided for
LocalFiles")
+ })?;
+
fn extract_filename(name: &str) -> Option<String> {
let corrected_url =
if name.starts_with("file://") &&
!name.starts_with("file:///") {
@@ -865,9 +854,20 @@ pub async fn from_substrait_rel(
let name = filename.unwrap();
// directly use unwrap here since we could determine it is a
valid one
let table_reference = TableReference::Bare { table:
name.into() };
- let t = ctx.table(table_reference).await?;
- let t = t.into_optimized_plan()?;
- extract_projection(t, &read.projection)
+ let t = ctx.table(table_reference.clone()).await?;
+
+ let substrait_schema =
+ from_substrait_named_struct(named_struct, extensions)?
+ .replace_qualifier(table_reference);
+
+ ensure_schema_compatability(
+ t.schema().to_owned(),
+ substrait_schema.clone(),
+ )?;
+
+ let substrait_schema = apply_masking(substrait_schema,
&read.projection)?;
+
+ apply_projection(t, substrait_schema)
}
_ => not_impl_err!("Unsupported ReadType: {:?}",
&read.as_ref().read_type),
},
@@ -995,30 +995,61 @@ pub async fn from_substrait_rel(
/// 1. All fields present in the Substrait schema are present in the
DataFusion schema. The
/// DataFusion schema may have MORE fields, but not the other way around.
/// 2. All fields are compatible. See [`ensure_field_compatability`] for
details
-///
-/// This function returns a DataFrame with fields adjusted if necessary in the
event that the
-/// Substrait schema is a subset of the DataFusion schema.
fn ensure_schema_compatability(
- table: DataFrame,
+ table_schema: DFSchema,
substrait_schema: DFSchema,
-) -> Result<DataFrame> {
- let df_schema = table.schema().to_owned().strip_qualifiers();
- if df_schema.logically_equivalent_names_and_types(&substrait_schema) {
- return Ok(table);
- }
- let selected_columns = substrait_schema
+) -> Result<()> {
+ substrait_schema
.strip_qualifiers()
.fields()
.iter()
- .map(|substrait_field| {
+ .try_for_each(|substrait_field| {
let df_field =
- df_schema.field_with_unqualified_name(substrait_field.name())?;
- ensure_field_compatability(df_field, substrait_field)?;
- Ok(col(format!("\"{}\"", df_field.name())))
+
table_schema.field_with_unqualified_name(substrait_field.name())?;
+ ensure_field_compatability(df_field, substrait_field)
})
- .collect::<Result<_>>()?;
+}
+
+/// This function returns a DataFrame with fields adjusted if necessary in the
event that the
+/// Substrait schema is a subset of the DataFusion schema.
+fn apply_projection(table: DataFrame, substrait_schema: DFSchema) ->
Result<LogicalPlan> {
+ let df_schema = table.schema().to_owned();
- table.select(selected_columns)
+ let t = table.into_unoptimized_plan();
+
+ if df_schema.logically_equivalent_names_and_types(&substrait_schema) {
+ return Ok(t);
+ }
+
+ match t {
+ LogicalPlan::TableScan(mut scan) => {
+ let column_indices: Vec<usize> = substrait_schema
+ .strip_qualifiers()
+ .fields()
+ .iter()
+ .map(|substrait_field| {
+ Ok(df_schema
+ .index_of_column_by_name(None,
substrait_field.name().as_str())
+ .unwrap())
+ })
+ .collect::<Result<_>>()?;
+
+ let fields = column_indices
+ .iter()
+ .map(|i| df_schema.qualified_field(*i))
+ .map(|(qualifier, field)| (qualifier.cloned(),
Arc::new(field.clone())))
+ .collect();
+
+ scan.projected_schema =
DFSchemaRef::new(DFSchema::new_with_metadata(
+ fields,
+ df_schema.metadata().clone(),
+ )?);
+ scan.projection = Some(column_indices);
+
+ Ok(LogicalPlan::TableScan(scan))
+ }
+ _ => plan_err!("DataFrame passed to apply_projection must be a
TableScan"),
+ }
}
/// Ensures that the given Substrait field is compatible with the given
DataFusion field
diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs
index b1cc763050..fffa29df1d 100644
--- a/datafusion/substrait/tests/cases/consumer_integration.rs
+++ b/datafusion/substrait/tests/cases/consumer_integration.rs
@@ -55,7 +55,7 @@ mod tests {
\n Aggregate: groupBy=[[LINEITEM.L_RETURNFLAG,
LINEITEM.L_LINESTATUS]], aggr=[[sum(LINEITEM.L_QUANTITY),
sum(LINEITEM.L_EXTENDEDPRICE), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) -
LINEITEM.L_DISCOUNT), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) -
LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX), avg(LINEITEM.L_QUANTITY),
avg(LINEITEM.L_EXTENDEDPRICE), avg(LINEITEM.L_DISCOUNT), count(Int64(1))]]\
\n Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS,
LINEITEM.L_QUANTITY, LINEITEM.L_EXTENDEDPRICE, LINEITEM.L_EXTENDEDPRICE *
(CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT),
LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) -
LINEITEM.L_DISCOUNT) * (CAST(Int32(1) AS Decimal128(15, 2)) + LINEITEM.L_TAX),
LINEITEM.L_DISCOUNT\
\n Filter: LINEITEM.L_SHIPDATE <= Date32(\"1998-12-01\") -
IntervalDayTime(\"IntervalDayTime { days: 0, milliseconds: 10368000 }\")\
- \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY,
L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX,
L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE,
L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]"
+ \n TableScan: LINEITEM"
);
Ok(())
}
@@ -76,19 +76,19 @@ mod tests {
\n CrossJoin:\
\n CrossJoin:\
\n CrossJoin:\
- \n TableScan: PARTSUPP
projection=[PS_PARTKEY, PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\
- \n TableScan: SUPPLIER projection=[S_SUPPKEY,
S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\
- \n TableScan: NATION projection=[N_NATIONKEY,
N_NAME, N_REGIONKEY, N_COMMENT]\
- \n TableScan: REGION projection=[R_REGIONKEY,
R_NAME, R_COMMENT]\
+ \n TableScan: PARTSUPP\
+ \n TableScan: SUPPLIER\
+ \n TableScan: NATION\
+ \n TableScan: REGION\
\n CrossJoin:\
\n CrossJoin:\
\n CrossJoin:\
\n CrossJoin:\
- \n TableScan: PART projection=[P_PARTKEY, P_NAME,
P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]\
- \n TableScan: SUPPLIER projection=[S_SUPPKEY,
S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\
- \n TableScan: PARTSUPP projection=[PS_PARTKEY,
PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\
- \n TableScan: NATION projection=[N_NATIONKEY, N_NAME,
N_REGIONKEY, N_COMMENT]\
- \n TableScan: REGION projection=[R_REGIONKEY, R_NAME,
R_COMMENT]"
+ \n TableScan: PART\
+ \n TableScan: SUPPLIER\
+ \n TableScan: PARTSUPP\
+ \n TableScan: NATION\
+ \n TableScan: REGION"
);
Ok(())
}
@@ -107,9 +107,9 @@ mod tests {
\n Filter: CUSTOMER.C_MKTSEGMENT = Utf8(\"BUILDING\")
AND CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY =
ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1995-03-15\") AS Date32)
AND LINEITEM.L_SHIPDATE > CAST(Utf8(\"1995-03-15\") AS Date32)\
\n CrossJoin:\
\n CrossJoin:\
- \n TableScan: LINEITEM projection=[L_ORDERKEY,
L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT,
L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE,
L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\
- \n TableScan: CUSTOMER projection=[C_CUSTKEY,
C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\
- \n TableScan: ORDERS projection=[O_ORDERKEY,
O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK,
O_SHIPPRIORITY, O_COMMENT]"
+ \n TableScan: LINEITEM\
+ \n TableScan: CUSTOMER\
+ \n TableScan: ORDERS"
);
Ok(())
}
@@ -126,8 +126,8 @@ mod tests {
\n Filter: ORDERS.O_ORDERDATE >= CAST(Utf8(\"1993-07-01\")
AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1993-10-01\") AS Date32) AND
EXISTS (<subquery>)\
\n Subquery:\
\n Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_ORDERKEY
AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE\
- \n TableScan: LINEITEM projection=[L_ORDERKEY,
L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT,
L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE,
L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\
- \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY,
O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK,
O_SHIPPRIORITY, O_COMMENT]"
+ \n TableScan: LINEITEM\
+ \n TableScan: ORDERS"
);
Ok(())
}
@@ -147,12 +147,12 @@ mod tests {
\n CrossJoin:\
\n CrossJoin:\
\n CrossJoin:\
- \n TableScan: CUSTOMER projection=[C_CUSTKEY,
C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\
- \n TableScan: ORDERS projection=[O_ORDERKEY,
O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK,
O_SHIPPRIORITY, O_COMMENT]\
- \n TableScan: LINEITEM projection=[L_ORDERKEY,
L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT,
L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE,
L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\
- \n TableScan: SUPPLIER projection=[S_SUPPKEY,
S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\
- \n TableScan: NATION projection=[N_NATIONKEY, N_NAME,
N_REGIONKEY, N_COMMENT]\
- \n TableScan: REGION projection=[R_REGIONKEY, R_NAME,
R_COMMENT]"
+ \n TableScan: CUSTOMER\
+ \n TableScan: ORDERS\
+ \n TableScan: LINEITEM\
+ \n TableScan: SUPPLIER\
+ \n TableScan: NATION\
+ \n TableScan: REGION"
);
Ok(())
}
@@ -165,7 +165,7 @@ mod tests {
"Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE *
LINEITEM.L_DISCOUNT) AS REVENUE]]\
\n Projection: LINEITEM.L_EXTENDEDPRICE * LINEITEM.L_DISCOUNT\
\n Filter: LINEITEM.L_SHIPDATE >= CAST(Utf8(\"1994-01-01\") AS
Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-01-01\") AS Date32) AND
LINEITEM.L_DISCOUNT >= Decimal128(Some(5),3,2) AND LINEITEM.L_DISCOUNT <=
Decimal128(Some(7),3,2) AND LINEITEM.L_QUANTITY < CAST(Int32(24) AS
Decimal128(15, 2))\
- \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY,
L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX,
L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE,
L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]"
+ \n TableScan: LINEITEM"
);
Ok(())
}
@@ -209,10 +209,10 @@ mod tests {
\n CrossJoin:\
\n CrossJoin:\
\n CrossJoin:\
- \n TableScan: CUSTOMER projection=[C_CUSTKEY,
C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\
- \n TableScan: ORDERS projection=[O_ORDERKEY,
O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK,
O_SHIPPRIORITY, O_COMMENT]\
- \n TableScan: LINEITEM projection=[L_ORDERKEY,
L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT,
L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE,
L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\
- \n TableScan: NATION projection=[N_NATIONKEY,
N_NAME, N_REGIONKEY, N_COMMENT]"
+ \n TableScan: CUSTOMER\
+ \n TableScan: ORDERS\
+ \n TableScan: LINEITEM\
+ \n TableScan: NATION"
);
Ok(())
}
@@ -232,17 +232,17 @@ mod tests {
\n Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY
AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME =
Utf8(\"JAPAN\")\
\n CrossJoin:\
\n CrossJoin:\
- \n TableScan: PARTSUPP projection=[PS_PARTKEY,
PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\
- \n TableScan: SUPPLIER projection=[S_SUPPKEY,
S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\
- \n TableScan: NATION projection=[N_NATIONKEY,
N_NAME, N_REGIONKEY, N_COMMENT]\
+ \n TableScan: PARTSUPP\
+ \n TableScan: SUPPLIER\
+ \n TableScan: NATION\
\n Aggregate: groupBy=[[PARTSUPP.PS_PARTKEY]],
aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]]\
\n Projection: PARTSUPP.PS_PARTKEY, PARTSUPP.PS_SUPPLYCOST
* CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0))\
\n Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND
SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"JAPAN\")\
\n CrossJoin:\
\n CrossJoin:\
- \n TableScan: PARTSUPP projection=[PS_PARTKEY,
PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\
- \n TableScan: SUPPLIER projection=[S_SUPPKEY,
S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\
- \n TableScan: NATION projection=[N_NATIONKEY, N_NAME,
N_REGIONKEY, N_COMMENT]"
+ \n TableScan: PARTSUPP\
+ \n TableScan: SUPPLIER\
+ \n TableScan: NATION"
);
Ok(())
}
@@ -258,8 +258,8 @@ mod tests {
\n Projection: LINEITEM.L_SHIPMODE, CASE WHEN
ORDERS.O_ORDERPRIORITY = Utf8(\"1-URGENT\") OR ORDERS.O_ORDERPRIORITY =
Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END, CASE WHEN
ORDERS.O_ORDERPRIORITY != Utf8(\"1-URGENT\") AND ORDERS.O_ORDERPRIORITY !=
Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END\
\n Filter: ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND
(LINEITEM.L_SHIPMODE = CAST(Utf8(\"MAIL\") AS Utf8) OR LINEITEM.L_SHIPMODE =
CAST(Utf8(\"SHIP\") AS Utf8)) AND LINEITEM.L_COMMITDATE <
LINEITEM.L_RECEIPTDATE AND LINEITEM.L_SHIPDATE < LINEITEM.L_COMMITDATE AND
LINEITEM.L_RECEIPTDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND
LINEITEM.L_RECEIPTDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\
\n CrossJoin:\
- \n TableScan: ORDERS projection=[O_ORDERKEY, O_CUSTKEY,
O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK,
O_SHIPPRIORITY, O_COMMENT]\
- \n TableScan: LINEITEM projection=[L_ORDERKEY,
L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT,
L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE,
L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]"
+ \n TableScan: ORDERS\
+ \n TableScan: LINEITEM"
);
Ok(())
}
@@ -277,8 +277,8 @@ mod tests {
\n Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY]],
aggr=[[count(ORDERS.O_ORDERKEY)]]\
\n Projection: CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY\
\n Left Join: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY
Filter: NOT ORDERS.O_COMMENT LIKE CAST(Utf8(\"%special%requests%\") AS Utf8)\
- \n TableScan: CUSTOMER projection=[C_CUSTKEY,
C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\
- \n TableScan: ORDERS projection=[O_ORDERKEY,
O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK,
O_SHIPPRIORITY, O_COMMENT]"
+ \n TableScan: CUSTOMER\
+ \n TableScan: ORDERS"
);
Ok(())
}
@@ -293,8 +293,8 @@ mod tests {
\n Projection: CASE WHEN PART.P_TYPE LIKE CAST(Utf8(\"PROMO%\")
AS Utf8) THEN LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) -
LINEITEM.L_DISCOUNT) ELSE Decimal128(Some(0),19,4) END,
LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) -
LINEITEM.L_DISCOUNT)\
\n Filter: LINEITEM.L_PARTKEY = PART.P_PARTKEY AND
LINEITEM.L_SHIPDATE >= Date32(\"1995-09-01\") AND LINEITEM.L_SHIPDATE <
CAST(Utf8(\"1995-10-01\") AS Date32)\
\n CrossJoin:\
- \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY,
L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX,
L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE,
L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\
- \n TableScan: PART projection=[P_PARTKEY, P_NAME, P_MFGR,
P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]"
+ \n TableScan: LINEITEM\
+ \n TableScan: PART"
);
Ok(())
}
@@ -320,10 +320,10 @@ mod tests {
\n Subquery:\
\n Projection: SUPPLIER.S_SUPPKEY\
\n Filter: SUPPLIER.S_COMMENT LIKE
CAST(Utf8(\"%Customer%Complaints%\") AS Utf8)\
- \n TableScan: SUPPLIER projection=[S_SUPPKEY,
S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\
+ \n TableScan: SUPPLIER\
\n CrossJoin:\
- \n TableScan: PARTSUPP projection=[PS_PARTKEY,
PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\
- \n TableScan: PART projection=[P_PARTKEY, P_NAME,
P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]"
+ \n TableScan: PARTSUPP\
+ \n TableScan: PART"
);
Ok(())
}
@@ -352,12 +352,12 @@ mod tests {
\n Filter: sum(LINEITEM.L_QUANTITY) >
CAST(Int32(300) AS Decimal128(15, 2))\
\n Aggregate: groupBy=[[LINEITEM.L_ORDERKEY]],
aggr=[[sum(LINEITEM.L_QUANTITY)]]\
\n Projection: LINEITEM.L_ORDERKEY,
LINEITEM.L_QUANTITY\
- \n TableScan: LINEITEM
projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY,
L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE,
L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\
+ \n TableScan: LINEITEM\
\n CrossJoin:\
\n CrossJoin:\
- \n TableScan: CUSTOMER projection=[C_CUSTKEY,
C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\
- \n TableScan: ORDERS projection=[O_ORDERKEY,
O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK,
O_SHIPPRIORITY, O_COMMENT]\
- \n TableScan: LINEITEM projection=[L_ORDERKEY,
L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT,
L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE,
L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]"
+ \n TableScan: CUSTOMER\
+ \n TableScan: ORDERS\
+ \n TableScan: LINEITEM"
);
Ok(())
}
@@ -370,8 +370,8 @@ mod tests {
\n Projection: LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS
Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\
\n Filter: PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND
= Utf8(\"Brand#12\") AND (PART.P_CONTAINER = CAST(Utf8(\"SM CASE\") AS Utf8) OR
PART.P_CONTAINER = CAST(Utf8(\"SM BOX\") AS Utf8) OR PART.P_CONTAINER =
CAST(Utf8(\"SM PACK\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM PKG\") AS
Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(1) AS Decimal128(15, 2)) AND
LINEITEM.L_QUANTITY <= CAST(Int32(1) + Int32(10) AS Decimal128(15, 2)) AND
PART.P_SIZE >= Int32(1) AND PART. [...]
\n CrossJoin:\
- \n TableScan: LINEITEM projection=[L_ORDERKEY, L_PARTKEY,
L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT, L_TAX,
L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE,
L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\
- \n TableScan: PART projection=[P_PARTKEY, P_NAME, P_MFGR,
P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]"
+ \n TableScan: LINEITEM\
+ \n TableScan: PART"
);
Ok(())
}
@@ -390,17 +390,17 @@ mod tests {
\n Subquery:\
\n Projection: PART.P_PARTKEY\
\n Filter: PART.P_NAME LIKE CAST(Utf8(\"forest%\")
AS Utf8)\
- \n TableScan: PART projection=[P_PARTKEY, P_NAME,
P_MFGR, P_BRAND, P_TYPE, P_SIZE, P_CONTAINER, P_RETAILPRICE, P_COMMENT]\
+ \n TableScan: PART\
\n Subquery:\
\n Projection: Decimal128(Some(5),2,1) *
sum(LINEITEM.L_QUANTITY)\
\n Aggregate: groupBy=[[]],
aggr=[[sum(LINEITEM.L_QUANTITY)]]\
\n Projection: LINEITEM.L_QUANTITY\
\n Filter: LINEITEM.L_PARTKEY =
LINEITEM.L_ORDERKEY AND LINEITEM.L_SUPPKEY = LINEITEM.L_PARTKEY AND
LINEITEM.L_SHIPDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND
LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\
- \n TableScan: LINEITEM
projection=[L_ORDERKEY, L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY,
L_EXTENDEDPRICE, L_DISCOUNT, L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE,
L_COMMITDATE, L_RECEIPTDATE, L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\
- \n TableScan: PARTSUPP projection=[PS_PARTKEY,
PS_SUPPKEY, PS_AVAILQTY, PS_SUPPLYCOST, PS_COMMENT]\
+ \n TableScan: LINEITEM\
+ \n TableScan: PARTSUPP\
\n CrossJoin:\
- \n TableScan: SUPPLIER projection=[S_SUPPKEY, S_NAME,
S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\
- \n TableScan: NATION projection=[N_NATIONKEY, N_NAME,
N_REGIONKEY, N_COMMENT]"
+ \n TableScan: SUPPLIER\
+ \n TableScan: NATION"
);
Ok(())
}
@@ -418,17 +418,17 @@ mod tests {
\n Filter: SUPPLIER.S_SUPPKEY = LINEITEM.L_SUPPKEY AND
ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND ORDERS.O_ORDERSTATUS = Utf8(\"F\")
AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE AND EXISTS (<subquery>) AND
NOT EXISTS (<subquery>) AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND
NATION.N_NAME = Utf8(\"SAUDI ARABIA\")\
\n Subquery:\
\n Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND
LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS\
- \n TableScan: LINEITEM projection=[L_ORDERKEY,
L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT,
L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE,
L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\
+ \n TableScan: LINEITEM\
\n Subquery:\
\n Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND
LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS AND LINEITEM.L_RECEIPTDATE >
LINEITEM.L_COMMITDATE\
- \n TableScan: LINEITEM projection=[L_ORDERKEY,
L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT,
L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE,
L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\
+ \n TableScan: LINEITEM\
\n CrossJoin:\
\n CrossJoin:\
\n CrossJoin:\
- \n TableScan: SUPPLIER projection=[S_SUPPKEY,
S_NAME, S_ADDRESS, S_NATIONKEY, S_PHONE, S_ACCTBAL, S_COMMENT]\
- \n TableScan: LINEITEM projection=[L_ORDERKEY,
L_PARTKEY, L_SUPPKEY, L_LINENUMBER, L_QUANTITY, L_EXTENDEDPRICE, L_DISCOUNT,
L_TAX, L_RETURNFLAG, L_LINESTATUS, L_SHIPDATE, L_COMMITDATE, L_RECEIPTDATE,
L_SHIPINSTRUCT, L_SHIPMODE, L_COMMENT]\
- \n TableScan: ORDERS projection=[O_ORDERKEY,
O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK,
O_SHIPPRIORITY, O_COMMENT]\
- \n TableScan: NATION projection=[N_NATIONKEY, N_NAME,
N_REGIONKEY, N_COMMENT]"
+ \n TableScan: SUPPLIER\
+ \n TableScan: LINEITEM\
+ \n TableScan: ORDERS\
+ \n TableScan: NATION"
);
Ok(())
}
@@ -447,11 +447,11 @@ mod tests {
\n Aggregate: groupBy=[[]],
aggr=[[avg(CUSTOMER.C_ACCTBAL)]]\
\n Projection: CUSTOMER.C_ACCTBAL\
\n Filter: CUSTOMER.C_ACCTBAL >
Decimal128(Some(0),3,2) AND (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) =
CAST(Utf8(\"13\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) =
CAST(Utf8(\"31\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) =
CAST(Utf8(\"23\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) =
CAST(Utf8(\"29\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) =
CAST(Utf8(\"30\") AS Utf8) OR substr(CUSTOMER.C_P [...]
- \n TableScan: CUSTOMER projection=[C_CUSTKEY,
C_NAME, C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]\
+ \n TableScan: CUSTOMER\
\n Subquery:\
\n Filter: ORDERS.O_CUSTKEY = ORDERS.O_ORDERKEY\
- \n TableScan: ORDERS projection=[O_ORDERKEY,
O_CUSTKEY, O_ORDERSTATUS, O_TOTALPRICE, O_ORDERDATE, O_ORDERPRIORITY, O_CLERK,
O_SHIPPRIORITY, O_COMMENT]\
- \n TableScan: CUSTOMER projection=[C_CUSTKEY, C_NAME,
C_ADDRESS, C_NATIONKEY, C_PHONE, C_ACCTBAL, C_MKTSEGMENT, C_COMMENT]"
+ \n TableScan: ORDERS\
+ \n TableScan: CUSTOMER"
);
Ok(())
}
diff --git a/datafusion/substrait/tests/cases/function_test.rs b/datafusion/substrait/tests/cases/function_test.rs
index 5806b55d84..b136b0af19 100644
--- a/datafusion/substrait/tests/cases/function_test.rs
+++ b/datafusion/substrait/tests/cases/function_test.rs
@@ -37,7 +37,7 @@ mod tests {
plan_str,
"Projection: nation.n_name\
\n Filter: contains(nation.n_name, Utf8(\"IA\"))\
- \n TableScan: nation projection=[n_nationkey, n_name,
n_regionkey, n_comment]"
+ \n TableScan: nation"
);
Ok(())
}
diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs
index 6794b32838..f4e34af35d 100644
--- a/datafusion/substrait/tests/cases/logical_plans.rs
+++ b/datafusion/substrait/tests/cases/logical_plans.rs
@@ -43,7 +43,7 @@ mod tests {
assert_eq!(
format!("{}", plan),
"Projection: NOT DATA.D AS EXPR$0\
- \n TableScan: DATA projection=[D]"
+ \n TableScan: DATA"
);
Ok(())
}
@@ -69,7 +69,7 @@ mod tests {
format!("{}", plan),
"Projection: sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY
[DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS
LEAD_EXPR\
\n WindowAggr: windowExpr=[[sum(DATA.D) PARTITION BY [DATA.PART]
ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED
FOLLOWING]]\
- \n TableScan: DATA projection=[D, PART, ORD]"
+ \n TableScan: DATA"
);
Ok(())
}
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 75881a421d..d60fe90388 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -472,12 +472,12 @@ async fn roundtrip_inlist_5() -> Result<()> {
\n Subquery:\
\n Projection: data2.a\
\n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\
- \n TableScan: data2 projection=[a, b, c, d, e, f]\
+ \n TableScan: data2\
\n TableScan: data projection=[a, f], partial_filters=[data.f =
Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN
(<subquery>)]\
\n Subquery:\
\n Projection: data2.a\
\n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\
- \n TableScan: data2 projection=[a, b, c, d, e, f]",
+ \n TableScan: data2",
true).await
}
diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs
index cb1fb67fc0..5ae586afe5 100644
--- a/datafusion/substrait/tests/cases/substrait_validations.rs
+++ b/datafusion/substrait/tests/cases/substrait_validations.rs
@@ -70,7 +70,7 @@ mod tests {
assert_eq!(
format!("{}", plan),
"Projection: DATA.a, DATA.b\
- \n TableScan: DATA projection=[a, b]"
+ \n TableScan: DATA"
);
Ok(())
}
@@ -91,8 +91,7 @@ mod tests {
assert_eq!(
format!("{}", plan),
"Projection: DATA.a, DATA.b\
- \n Projection: DATA.a, DATA.b\
- \n TableScan: DATA projection=[b, a]"
+ \n TableScan: DATA projection=[a, b]"
);
Ok(())
}
@@ -102,12 +101,12 @@ mod tests {
let proto_plan = read_json(
"tests/testdata/test_plans/simple_select_with_mask.substrait.json",
);
- // the DataFusion schema { b, a, c, d } contains the Substrait
schema { a, b, c }
+ // the DataFusion schema { d, a, c, b } contains the Substrait
schema { a, b, c }
let df_schema = vec![
- ("b", DataType::Int32, true),
+ ("d", DataType::Int32, true),
("a", DataType::Int32, false),
("c", DataType::Int32, false),
- ("d", DataType::Int32, false),
+ ("b", DataType::Int32, false),
];
let ctx = generate_context_with_table("DATA", df_schema)?;
let plan = from_substrait_plan(&ctx, &proto_plan).await?;
@@ -115,9 +114,7 @@ mod tests {
assert_eq!(
format!("{}", plan),
"Projection: DATA.a, DATA.b\
- \n Projection: DATA.a, DATA.b\
- \n Projection: DATA.a, DATA.b, DATA.c\
- \n TableScan: DATA projection=[b, a, c]"
+ \n TableScan: DATA projection=[a, b]"
);
Ok(())
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]