This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 4c898b4572 feat(substrait): modular substrait producer (#13931)
4c898b4572 is described below
commit 4c898b45720efed56f15f8030e8ca2c1e5f6ec1a
Author: Victor Barua <[email protected]>
AuthorDate: Sun Jan 5 10:26:09 2025 -0800
feat(substrait): modular substrait producer (#13931)
* feat(substrait): modular substrait producer
* refactor(substrait): simplify col_ref_offset handling in producer
* refactor(substrait): remove column offset tracking from producer
* docs(substrait): document SubstraitProducer
* refactor: minor cleanup
* feature: remove unused SubstraitPlanningState
BREAKING CHANGE: SubstraitPlanningState is no longer available
* refactor: cargo fmt
* refactor(substrait): consume_ -> handle_
* refactor(substrait): expand match blocks
* refactor: DefaultSubstraitProducer only needs serializer_registry
* refactor: remove unnecessary warning suppression
* fix(substrait): route expr conversion through handle_expr
* cargo fmt
---
datafusion/substrait/src/logical_plan/consumer.rs | 3 +
datafusion/substrait/src/logical_plan/mod.rs | 1 -
datafusion/substrait/src/logical_plan/producer.rs | 2260 +++++++++++---------
datafusion/substrait/src/logical_plan/state.rs | 63 -
.../tests/cases/roundtrip_logical_plan.rs | 15 +
5 files changed, 1295 insertions(+), 1047 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index 0ee87afe32..9623f12c88 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -114,6 +114,9 @@ use substrait::proto::{
/// This trait is used to consume Substrait plans, converting them into
DataFusion Logical Plans.
/// It can be implemented by users to allow for custom handling of relations,
expressions, etc.
///
+/// Combined with the [crate::logical_plan::producer::SubstraitProducer] this
allows for fully
+/// customizable Substrait serde.
+///
/// # Example Usage
///
/// ```
diff --git a/datafusion/substrait/src/logical_plan/mod.rs
b/datafusion/substrait/src/logical_plan/mod.rs
index 9e2fa9fa49..6f8b8e493f 100644
--- a/datafusion/substrait/src/logical_plan/mod.rs
+++ b/datafusion/substrait/src/logical_plan/mod.rs
@@ -17,4 +17,3 @@
pub mod consumer;
pub mod producer;
-pub mod state;
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index b73d246e19..e501ddf5c6 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -22,7 +22,11 @@ use std::sync::Arc;
use substrait::proto::expression_reference::ExprType;
use datafusion::arrow::datatypes::{Field, IntervalUnit};
-use datafusion::logical_expr::{Distinct, Like, Partitioning, TryCast,
WindowFrameUnits};
+use datafusion::logical_expr::{
+ Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit,
+ Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias,
TableScan,
+ TryCast, Union, Values, Window, WindowFrameUnits,
+};
use datafusion::{
arrow::datatypes::{DataType, TimeUnit},
error::{DataFusionError, Result},
@@ -43,11 +47,12 @@ use datafusion::arrow::array::{Array, GenericListArray,
OffsetSizeTrait};
use datafusion::arrow::temporal_conversions::NANOSECONDS;
use datafusion::common::{
exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err,
- substrait_err, DFSchema, DFSchemaRef, ToDFSchema,
+ substrait_err, Column, DFSchema, DFSchemaRef, ToDFSchema,
};
-#[allow(unused_imports)]
+use datafusion::execution::registry::SerializerRegistry;
+use datafusion::execution::SessionState;
use datafusion::logical_expr::expr::{
- Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, Sort,
WindowFunction,
+ Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery,
WindowFunction,
};
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan,
Operator};
use datafusion::prelude::Expr;
@@ -63,6 +68,7 @@ use substrait::proto::expression::literal::{
};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
+use substrait::proto::expression::ScalarFunction;
use substrait::proto::read_rel::VirtualTable;
use substrait::proto::rel_common::EmitKind;
use substrait::proto::rel_common::EmitKind::Emit;
@@ -84,8 +90,7 @@ use substrait::{
window_function::bound::Kind as BoundKind,
window_function::Bound,
FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment,
RexType,
- ScalarFunction, SingularOrList, Subquery,
- WindowFunction as SubstraitWindowFunction,
+ SingularOrList, WindowFunction as SubstraitWindowFunction,
},
function_argument::ArgType,
join_rel, plan_rel, r#type,
@@ -101,14 +106,329 @@ use substrait::{
version,
};
-use super::state::SubstraitPlanningState;
+/// This trait is used to produce Substrait plans, converting them from
DataFusion Logical Plans.
+/// It can be implemented by users to allow for custom handling of relations,
expressions, etc.
+///
+/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this
allows for fully
+/// customizable Substrait serde.
+///
+/// # Example Usage
+///
+/// ```
+/// # use std::sync::Arc;
+/// # use substrait::proto::{Expression, Rel};
+/// # use substrait::proto::rel::RelType;
+/// # use datafusion::common::DFSchemaRef;
+/// # use datafusion::error::Result;
+/// # use datafusion::execution::SessionState;
+/// # use datafusion::logical_expr::{Between, Extension, Projection};
+/// # use datafusion_substrait::extensions::Extensions;
+/// # use datafusion_substrait::logical_plan::producer::{from_projection,
SubstraitProducer};
+///
+/// struct CustomSubstraitProducer {
+/// extensions: Extensions,
+/// state: Arc<SessionState>,
+/// }
+///
+/// impl SubstraitProducer for CustomSubstraitProducer {
+///
+/// fn register_function(&mut self, signature: String) -> u32 {
+/// self.extensions.register_function(signature)
+/// }
+///
+/// fn get_extensions(self) -> Extensions {
+/// self.extensions
+/// }
+///
+/// // You can set additional metadata on the Rels you produce
+/// fn handle_projection(&mut self, plan: &Projection) -> Result<Box<Rel>>
{
+/// let mut rel = from_projection(self, plan)?;
+/// match rel.rel_type {
+/// Some(RelType::Project(mut project)) => {
+/// let mut project = project.clone();
+/// // set common metadata or advanced extension
+/// project.common = None;
+/// project.advanced_extension = None;
+/// Ok(Box::new(Rel {
+/// rel_type: Some(RelType::Project(project)),
+/// }))
+/// }
+/// rel_type => Ok(Box::new(Rel { rel_type })),
+/// }
+/// }
+///
+/// // You can tweak how you convert expressions for your target system
+/// fn handle_between(&mut self, between: &Between, schema: &DFSchemaRef)
-> Result<Expression> {
+/// // add your own encoding for Between
+/// todo!()
+/// }
+///
+/// // You can fully control how you convert UserDefinedLogicalNodes into
Substrait
+/// fn handle_extension(&mut self, _plan: &Extension) -> Result<Box<Rel>> {
+/// // implement your own serializer into Substrait
+/// todo!()
+/// }
+/// }
+/// ```
+pub trait SubstraitProducer: Send + Sync + Sized {
+ /// Within a Substrait plan, functions are referenced using function
anchors that are stored at
+ /// the top level of the [Plan] within
+ ///
[ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction)
+ /// messages.
+ ///
+ /// When given a function signature, this method should return the
existing anchor for it if
+ /// there is one. Otherwise, it should generate a new anchor.
+ fn register_function(&mut self, signature: String) -> u32;
+
+ /// Consume the producer to generate the [Extensions] for the Substrait
plan based on the
+ /// functions that have been registered
+ fn get_extensions(self) -> Extensions;
+
+ // Logical Plan Methods
+ // There is one method per LogicalPlan to allow for easy overriding of
producer behaviour.
+ // These methods have default implementations calling the common handler
code, to allow for users
+ // to re-use common handling logic.
+
+ fn handle_plan(&mut self, plan: &LogicalPlan) -> Result<Box<Rel>> {
+ to_substrait_rel(self, plan)
+ }
+
+ fn handle_projection(&mut self, plan: &Projection) -> Result<Box<Rel>> {
+ from_projection(self, plan)
+ }
+
+ fn handle_filter(&mut self, plan: &Filter) -> Result<Box<Rel>> {
+ from_filter(self, plan)
+ }
+
+ fn handle_window(&mut self, plan: &Window) -> Result<Box<Rel>> {
+ from_window(self, plan)
+ }
+
+ fn handle_aggregate(&mut self, plan: &Aggregate) -> Result<Box<Rel>> {
+ from_aggregate(self, plan)
+ }
+
+ fn handle_sort(&mut self, plan: &Sort) -> Result<Box<Rel>> {
+ from_sort(self, plan)
+ }
+
+ fn handle_join(&mut self, plan: &Join) -> Result<Box<Rel>> {
+ from_join(self, plan)
+ }
+
+ fn handle_repartition(&mut self, plan: &Repartition) -> Result<Box<Rel>> {
+ from_repartition(self, plan)
+ }
+
+ fn handle_union(&mut self, plan: &Union) -> Result<Box<Rel>> {
+ from_union(self, plan)
+ }
+
+ fn handle_table_scan(&mut self, plan: &TableScan) -> Result<Box<Rel>> {
+ from_table_scan(self, plan)
+ }
+
+ fn handle_empty_relation(&mut self, plan: &EmptyRelation) ->
Result<Box<Rel>> {
+ from_empty_relation(plan)
+ }
+
+ fn handle_subquery_alias(&mut self, plan: &SubqueryAlias) ->
Result<Box<Rel>> {
+ from_subquery_alias(self, plan)
+ }
+
+ fn handle_limit(&mut self, plan: &Limit) -> Result<Box<Rel>> {
+ from_limit(self, plan)
+ }
+
+ fn handle_values(&mut self, plan: &Values) -> Result<Box<Rel>> {
+ from_values(self, plan)
+ }
+
+ fn handle_distinct(&mut self, plan: &Distinct) -> Result<Box<Rel>> {
+ from_distinct(self, plan)
+ }
+
+ fn handle_extension(&mut self, _plan: &Extension) -> Result<Box<Rel>> {
+ substrait_err!("Specify handling for LogicalPlan::Extension by
implementing the SubstraitProducer trait")
+ }
+
+ // Expression Methods
+ // There is one method per DataFusion Expr to allow for easy overriding of
producer behaviour
+ // These methods have default implementations calling the common handler
code, to allow for users
+ // to re-use common handling logic.
+
+ fn handle_expr(&mut self, expr: &Expr, schema: &DFSchemaRef) ->
Result<Expression> {
+ to_substrait_rex(self, expr, schema)
+ }
+
+ fn handle_alias(
+ &mut self,
+ alias: &Alias,
+ schema: &DFSchemaRef,
+ ) -> Result<Expression> {
+ from_alias(self, alias, schema)
+ }
+
+ fn handle_column(
+ &mut self,
+ column: &Column,
+ schema: &DFSchemaRef,
+ ) -> Result<Expression> {
+ from_column(column, schema)
+ }
+
+ fn handle_literal(&mut self, value: &ScalarValue) -> Result<Expression> {
+ from_literal(self, value)
+ }
+
+ fn handle_binary_expr(
+ &mut self,
+ expr: &BinaryExpr,
+ schema: &DFSchemaRef,
+ ) -> Result<Expression> {
+ from_binary_expr(self, expr, schema)
+ }
+
+ fn handle_like(&mut self, like: &Like, schema: &DFSchemaRef) ->
Result<Expression> {
+ from_like(self, like, schema)
+ }
+
+ /// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown,
IsNotTrue, IsNotFalse, IsNotUnknown, Negative
+ fn handle_unary_expr(
+ &mut self,
+ expr: &Expr,
+ schema: &DFSchemaRef,
+ ) -> Result<Expression> {
+ from_unary_expr(self, expr, schema)
+ }
+
+ fn handle_between(
+ &mut self,
+ between: &Between,
+ schema: &DFSchemaRef,
+ ) -> Result<Expression> {
+ from_between(self, between, schema)
+ }
+
+ fn handle_case(&mut self, case: &Case, schema: &DFSchemaRef) ->
Result<Expression> {
+ from_case(self, case, schema)
+ }
+
+ fn handle_cast(&mut self, cast: &Cast, schema: &DFSchemaRef) ->
Result<Expression> {
+ from_cast(self, cast, schema)
+ }
+
+ fn handle_try_cast(
+ &mut self,
+ cast: &TryCast,
+ schema: &DFSchemaRef,
+ ) -> Result<Expression> {
+ from_try_cast(self, cast, schema)
+ }
+
+ fn handle_scalar_function(
+ &mut self,
+ scalar_fn: &expr::ScalarFunction,
+ schema: &DFSchemaRef,
+ ) -> Result<Expression> {
+ from_scalar_function(self, scalar_fn, schema)
+ }
+
+ fn handle_aggregate_function(
+ &mut self,
+ agg_fn: &expr::AggregateFunction,
+ schema: &DFSchemaRef,
+ ) -> Result<Measure> {
+ from_aggregate_function(self, agg_fn, schema)
+ }
+
+ fn handle_window_function(
+ &mut self,
+ window_fn: &WindowFunction,
+ schema: &DFSchemaRef,
+ ) -> Result<Expression> {
+ from_window_function(self, window_fn, schema)
+ }
+
+ fn handle_in_list(
+ &mut self,
+ in_list: &InList,
+ schema: &DFSchemaRef,
+ ) -> Result<Expression> {
+ from_in_list(self, in_list, schema)
+ }
+
+ fn handle_in_subquery(
+ &mut self,
+ in_subquery: &InSubquery,
+ schema: &DFSchemaRef,
+ ) -> Result<Expression> {
+ from_in_subquery(self, in_subquery, schema)
+ }
+}
+
+struct DefaultSubstraitProducer<'a> {
+ extensions: Extensions,
+ serializer_registry: &'a dyn SerializerRegistry,
+}
+
+impl<'a> DefaultSubstraitProducer<'a> {
+ pub fn new(state: &'a SessionState) -> Self {
+ DefaultSubstraitProducer {
+ extensions: Extensions::default(),
+ serializer_registry: state.serializer_registry().as_ref(),
+ }
+ }
+}
+
+impl SubstraitProducer for DefaultSubstraitProducer<'_> {
+ fn register_function(&mut self, fn_name: String) -> u32 {
+ self.extensions.register_function(fn_name)
+ }
+
+ fn get_extensions(self) -> Extensions {
+ self.extensions
+ }
+
+ fn handle_extension(&mut self, plan: &Extension) -> Result<Box<Rel>> {
+ let extension_bytes = self
+ .serializer_registry
+ .serialize_logical_plan(plan.node.as_ref())?;
+ let detail = ProtoAny {
+ type_url: plan.node.name().to_string(),
+ value: extension_bytes.into(),
+ };
+ let mut inputs_rel = plan
+ .node
+ .inputs()
+ .into_iter()
+ .map(|plan| self.handle_plan(plan))
+ .collect::<Result<Vec<_>>>()?;
+ let rel_type = match inputs_rel.len() {
+ 0 => RelType::ExtensionLeaf(ExtensionLeafRel {
+ common: None,
+ detail: Some(detail),
+ }),
+ 1 => RelType::ExtensionSingle(Box::new(ExtensionSingleRel {
+ common: None,
+ detail: Some(detail),
+ input: Some(inputs_rel.pop().unwrap()),
+ })),
+ _ => RelType::ExtensionMulti(ExtensionMultiRel {
+ common: None,
+ detail: Some(detail),
+ inputs: inputs_rel.into_iter().map(|r| *r).collect(),
+ }),
+ };
+ Ok(Box::new(Rel {
+ rel_type: Some(rel_type),
+ }))
+ }
+}
/// Convert DataFusion LogicalPlan to Substrait Plan
-pub fn to_substrait_plan(
- plan: &LogicalPlan,
- state: &dyn SubstraitPlanningState,
-) -> Result<Box<Plan>> {
- let mut extensions = Extensions::default();
+pub fn to_substrait_plan(plan: &LogicalPlan, state: &SessionState) ->
Result<Box<Plan>> {
// Parse relation nodes
// Generate PlanRel(s)
// Note: Only 1 relation tree is currently supported
@@ -117,14 +437,16 @@ pub fn to_substrait_plan(
let plan = Arc::new(ExpandWildcardRule::new())
.analyze(plan.clone(), &ConfigOptions::default())?;
+ let mut producer: DefaultSubstraitProducer =
DefaultSubstraitProducer::new(state);
let plan_rels = vec![PlanRel {
rel_type: Some(plan_rel::RelType::Root(RelRoot {
- input: Some(*to_substrait_rel(&plan, state, &mut extensions)?),
+ input: Some(*producer.handle_plan(&plan)?),
names: to_substrait_named_struct(plan.schema())?.names,
})),
}];
// Return parsed plan
+ let extensions = producer.get_extensions();
Ok(Box::new(Plan {
version: Some(version::version_with_producer("datafusion")),
extension_uris: vec![],
@@ -150,20 +472,13 @@ pub fn to_substrait_plan(
pub fn to_substrait_extended_expr(
exprs: &[(&Expr, &Field)],
schema: &DFSchemaRef,
- state: &dyn SubstraitPlanningState,
+ state: &SessionState,
) -> Result<Box<ExtendedExpression>> {
- let mut extensions = Extensions::default();
-
+ let mut producer = DefaultSubstraitProducer::new(state);
let substrait_exprs = exprs
.iter()
.map(|(expr, field)| {
- let substrait_expr = to_substrait_rex(
- state,
- expr,
- schema,
- /*col_ref_offset=*/ 0,
- &mut extensions,
- )?;
+ let substrait_expr = producer.handle_expr(expr, schema)?;
let mut output_names = Vec::new();
flatten_names(field, false, &mut output_names)?;
Ok(ExpressionReference {
@@ -174,6 +489,7 @@ pub fn to_substrait_extended_expr(
.collect::<Result<Vec<_>>>()?;
let substrait_schema = to_substrait_named_struct(schema)?;
+ let extensions = producer.get_extensions();
Ok(Box::new(ExtendedExpression {
advanced_extensions: None,
expected_type_urls: vec![],
@@ -185,257 +501,303 @@ pub fn to_substrait_extended_expr(
}))
}
-/// Convert DataFusion LogicalPlan to Substrait Rel
-#[allow(deprecated)]
pub fn to_substrait_rel(
+ producer: &mut impl SubstraitProducer,
plan: &LogicalPlan,
- state: &dyn SubstraitPlanningState,
- extensions: &mut Extensions,
) -> Result<Box<Rel>> {
match plan {
- LogicalPlan::TableScan(scan) => {
- let projection = scan.projection.as_ref().map(|p| {
- p.iter()
- .map(|i| StructItem {
- field: *i as i32,
- child: None,
- })
- .collect()
- });
+ LogicalPlan::Projection(plan) => producer.handle_projection(plan),
+ LogicalPlan::Filter(plan) => producer.handle_filter(plan),
+ LogicalPlan::Window(plan) => producer.handle_window(plan),
+ LogicalPlan::Aggregate(plan) => producer.handle_aggregate(plan),
+ LogicalPlan::Sort(plan) => producer.handle_sort(plan),
+ LogicalPlan::Join(plan) => producer.handle_join(plan),
+ LogicalPlan::Repartition(plan) => producer.handle_repartition(plan),
+ LogicalPlan::Union(plan) => producer.handle_union(plan),
+ LogicalPlan::TableScan(plan) => producer.handle_table_scan(plan),
+ LogicalPlan::EmptyRelation(plan) =>
producer.handle_empty_relation(plan),
+ LogicalPlan::Subquery(plan) => not_impl_err!("Unsupported plan type:
{plan:?}")?,
+ LogicalPlan::SubqueryAlias(plan) =>
producer.handle_subquery_alias(plan),
+ LogicalPlan::Limit(plan) => producer.handle_limit(plan),
+ LogicalPlan::Statement(plan) => not_impl_err!("Unsupported plan type:
{plan:?}")?,
+ LogicalPlan::Values(plan) => producer.handle_values(plan),
+ LogicalPlan::Explain(plan) => not_impl_err!("Unsupported plan type:
{plan:?}")?,
+ LogicalPlan::Analyze(plan) => not_impl_err!("Unsupported plan type:
{plan:?}")?,
+ LogicalPlan::Extension(plan) => producer.handle_extension(plan),
+ LogicalPlan::Distinct(plan) => producer.handle_distinct(plan),
+ LogicalPlan::Dml(plan) => not_impl_err!("Unsupported plan type:
{plan:?}")?,
+ LogicalPlan::Ddl(plan) => not_impl_err!("Unsupported plan type:
{plan:?}")?,
+ LogicalPlan::Copy(plan) => not_impl_err!("Unsupported plan type:
{plan:?}")?,
+ LogicalPlan::DescribeTable(plan) => {
+ not_impl_err!("Unsupported plan type: {plan:?}")?
+ }
+ LogicalPlan::Unnest(plan) => not_impl_err!("Unsupported plan type:
{plan:?}")?,
+ LogicalPlan::RecursiveQuery(plan) => {
+ not_impl_err!("Unsupported plan type: {plan:?}")?
+ }
+ }
+}
- let projection = projection.map(|struct_items| MaskExpression {
- select: Some(StructSelect { struct_items }),
- maintain_singular_struct: false,
- });
+pub fn from_table_scan(
+ _producer: &mut impl SubstraitProducer,
+ scan: &TableScan,
+) -> Result<Box<Rel>> {
+ let projection = scan.projection.as_ref().map(|p| {
+ p.iter()
+ .map(|i| StructItem {
+ field: *i as i32,
+ child: None,
+ })
+ .collect()
+ });
+
+ let projection = projection.map(|struct_items| MaskExpression {
+ select: Some(StructSelect { struct_items }),
+ maintain_singular_struct: false,
+ });
+
+ let table_schema = scan.source.schema().to_dfschema_ref()?;
+ let base_schema = to_substrait_named_struct(&table_schema)?;
+
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Read(Box::new(ReadRel {
+ common: None,
+ base_schema: Some(base_schema),
+ filter: None,
+ best_effort_filter: None,
+ projection,
+ advanced_extension: None,
+ read_type: Some(ReadType::NamedTable(NamedTable {
+ names: scan.table_name.to_vec(),
+ advanced_extension: None,
+ })),
+ }))),
+ }))
+}
- let table_schema = scan.source.schema().to_dfschema_ref()?;
- let base_schema = to_substrait_named_struct(&table_schema)?;
+pub fn from_empty_relation(e: &EmptyRelation) -> Result<Box<Rel>> {
+ if e.produce_one_row {
+ return not_impl_err!("Producing a row from empty relation is
unsupported");
+ }
+ #[allow(deprecated)]
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Read(Box::new(ReadRel {
+ common: None,
+ base_schema: Some(to_substrait_named_struct(&e.schema)?),
+ filter: None,
+ best_effort_filter: None,
+ projection: None,
+ advanced_extension: None,
+ read_type: Some(ReadType::VirtualTable(VirtualTable {
+ values: vec![],
+ expressions: vec![],
+ })),
+ }))),
+ }))
+}
- Ok(Box::new(Rel {
- rel_type: Some(RelType::Read(Box::new(ReadRel {
- common: None,
- base_schema: Some(base_schema),
- filter: None,
- best_effort_filter: None,
- projection,
- advanced_extension: None,
- read_type: Some(ReadType::NamedTable(NamedTable {
- names: scan.table_name.to_vec(),
- advanced_extension: None,
- })),
- }))),
- }))
- }
- LogicalPlan::EmptyRelation(e) => {
- if e.produce_one_row {
- return not_impl_err!(
- "Producing a row from empty relation is unsupported"
- );
- }
- Ok(Box::new(Rel {
- rel_type: Some(RelType::Read(Box::new(ReadRel {
- common: None,
- base_schema: Some(to_substrait_named_struct(&e.schema)?),
- filter: None,
- best_effort_filter: None,
- projection: None,
- advanced_extension: None,
- read_type: Some(ReadType::VirtualTable(VirtualTable {
- values: vec![],
- expressions: vec![],
- })),
- }))),
- }))
- }
- LogicalPlan::Values(v) => {
- let values = v
- .values
+pub fn from_values(
+ producer: &mut impl SubstraitProducer,
+ v: &Values,
+) -> Result<Box<Rel>> {
+ let values = v
+ .values
+ .iter()
+ .map(|row| {
+ let fields = row
.iter()
- .map(|row| {
- let fields = row
- .iter()
- .map(|v| match v {
- Expr::Literal(sv) => to_substrait_literal(sv,
extensions),
- Expr::Alias(alias) => match alias.expr.as_ref() {
- // The schema gives us the names, so we can
skip aliases
- Expr::Literal(sv) => to_substrait_literal(sv,
extensions),
- _ => Err(substrait_datafusion_err!(
+ .map(|v| match v {
+ Expr::Literal(sv) => to_substrait_literal(producer, sv),
+ Expr::Alias(alias) => match alias.expr.as_ref() {
+ // The schema gives us the names, so we can skip
aliases
+ Expr::Literal(sv) => to_substrait_literal(producer,
sv),
+ _ => Err(substrait_datafusion_err!(
"Only literal types can be aliased in
Virtual Tables, got: {}", alias.expr.variant_name()
)),
- },
- _ => Err(substrait_datafusion_err!(
+ },
+ _ => Err(substrait_datafusion_err!(
"Only literal types and aliases are supported
in Virtual Tables, got: {}", v.variant_name()
)),
- })
- .collect::<Result<_>>()?;
- Ok(Struct { fields })
})
.collect::<Result<_>>()?;
- Ok(Box::new(Rel {
- rel_type: Some(RelType::Read(Box::new(ReadRel {
- common: None,
- base_schema: Some(to_substrait_named_struct(&v.schema)?),
- filter: None,
- best_effort_filter: None,
- projection: None,
- advanced_extension: None,
- read_type: Some(ReadType::VirtualTable(VirtualTable {
- values,
- expressions: vec![],
- })),
- }))),
- }))
- }
- LogicalPlan::Projection(p) => {
- let expressions = p
- .expr
- .iter()
- .map(|e| to_substrait_rex(state, e, p.input.schema(), 0,
extensions))
- .collect::<Result<Vec<_>>>()?;
+ Ok(Struct { fields })
+ })
+ .collect::<Result<_>>()?;
+ #[allow(deprecated)]
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Read(Box::new(ReadRel {
+ common: None,
+ base_schema: Some(to_substrait_named_struct(&v.schema)?),
+ filter: None,
+ best_effort_filter: None,
+ projection: None,
+ advanced_extension: None,
+ read_type: Some(ReadType::VirtualTable(VirtualTable {
+ values,
+ expressions: vec![],
+ })),
+ }))),
+ }))
+}
- let emit_kind = create_project_remapping(
- expressions.len(),
- p.input.as_ref().schema().fields().len(),
- );
- let common = RelCommon {
- emit_kind: Some(emit_kind),
- hint: None,
- advanced_extension: None,
- };
+pub fn from_projection(
+ producer: &mut impl SubstraitProducer,
+ p: &Projection,
+) -> Result<Box<Rel>> {
+ let expressions = p
+ .expr
+ .iter()
+ .map(|e| producer.handle_expr(e, p.input.schema()))
+ .collect::<Result<Vec<_>>>()?;
- Ok(Box::new(Rel {
- rel_type: Some(RelType::Project(Box::new(ProjectRel {
- common: Some(common),
- input: Some(to_substrait_rel(p.input.as_ref(), state,
extensions)?),
- expressions,
- advanced_extension: None,
- }))),
- }))
- }
- LogicalPlan::Filter(filter) => {
- let input = to_substrait_rel(filter.input.as_ref(), state,
extensions)?;
- let filter_expr = to_substrait_rex(
- state,
- &filter.predicate,
- filter.input.schema(),
- 0,
- extensions,
- )?;
- Ok(Box::new(Rel {
- rel_type: Some(RelType::Filter(Box::new(FilterRel {
- common: None,
- input: Some(input),
- condition: Some(Box::new(filter_expr)),
- advanced_extension: None,
- }))),
- }))
- }
- LogicalPlan::Limit(limit) => {
- let input = to_substrait_rel(limit.input.as_ref(), state,
extensions)?;
- let empty_schema = Arc::new(DFSchema::empty());
- let offset_mode = limit
- .skip
- .as_ref()
- .map(|expr| {
- to_substrait_rex(state, expr.as_ref(), &empty_schema, 0,
extensions)
- })
- .transpose()?
- .map(Box::new)
- .map(fetch_rel::OffsetMode::OffsetExpr);
- let count_mode = limit
- .fetch
- .as_ref()
- .map(|expr| {
- to_substrait_rex(state, expr.as_ref(), &empty_schema, 0,
extensions)
- })
- .transpose()?
- .map(Box::new)
- .map(fetch_rel::CountMode::CountExpr);
+ let emit_kind = create_project_remapping(
+ expressions.len(),
+ p.input.as_ref().schema().fields().len(),
+ );
+ let common = RelCommon {
+ emit_kind: Some(emit_kind),
+ hint: None,
+ advanced_extension: None,
+ };
+
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Project(Box::new(ProjectRel {
+ common: Some(common),
+ input: Some(producer.handle_plan(p.input.as_ref())?),
+ expressions,
+ advanced_extension: None,
+ }))),
+ }))
+}
+
+pub fn from_filter(
+ producer: &mut impl SubstraitProducer,
+ filter: &Filter,
+) -> Result<Box<Rel>> {
+ let input = producer.handle_plan(filter.input.as_ref())?;
+ let filter_expr = producer.handle_expr(&filter.predicate,
filter.input.schema())?;
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Filter(Box::new(FilterRel {
+ common: None,
+ input: Some(input),
+ condition: Some(Box::new(filter_expr)),
+ advanced_extension: None,
+ }))),
+ }))
+}
+
+pub fn from_limit(
+ producer: &mut impl SubstraitProducer,
+ limit: &Limit,
+) -> Result<Box<Rel>> {
+ let input = producer.handle_plan(limit.input.as_ref())?;
+ let empty_schema = Arc::new(DFSchema::empty());
+ let offset_mode = limit
+ .skip
+ .as_ref()
+ .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema))
+ .transpose()?
+ .map(Box::new)
+ .map(fetch_rel::OffsetMode::OffsetExpr);
+ let count_mode = limit
+ .fetch
+ .as_ref()
+ .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema))
+ .transpose()?
+ .map(Box::new)
+ .map(fetch_rel::CountMode::CountExpr);
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Fetch(Box::new(FetchRel {
+ common: None,
+ input: Some(input),
+ offset_mode,
+ count_mode,
+ advanced_extension: None,
+ }))),
+ }))
+}
+
+pub fn from_sort(producer: &mut impl SubstraitProducer, sort: &Sort) ->
Result<Box<Rel>> {
+ let Sort { expr, input, fetch } = sort;
+ let sort_fields = expr
+ .iter()
+ .map(|e| substrait_sort_field(producer, e, input.schema()))
+ .collect::<Result<Vec<_>>>()?;
+
+ let input = producer.handle_plan(input.as_ref())?;
+
+ let sort_rel = Box::new(Rel {
+ rel_type: Some(RelType::Sort(Box::new(SortRel {
+ common: None,
+ input: Some(input),
+ sorts: sort_fields,
+ advanced_extension: None,
+ }))),
+ });
+
+ match fetch {
+ Some(amount) => {
+ let count_mode =
+ Some(fetch_rel::CountMode::CountExpr(Box::new(Expression {
+ rex_type: Some(RexType::Literal(Literal {
+ nullable: false,
+ type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
+ literal_type: Some(LiteralType::I64(*amount as i64)),
+ })),
+ })));
Ok(Box::new(Rel {
rel_type: Some(RelType::Fetch(Box::new(FetchRel {
common: None,
- input: Some(input),
- offset_mode,
+ input: Some(sort_rel),
+ offset_mode: None,
count_mode,
advanced_extension: None,
}))),
}))
}
- LogicalPlan::Sort(datafusion::logical_expr::Sort { expr, input, fetch
}) => {
- let sort_fields = expr
- .iter()
- .map(|e| substrait_sort_field(state, e, input.schema(),
extensions))
- .collect::<Result<Vec<_>>>()?;
+ None => Ok(sort_rel),
+ }
+}
- let input = to_substrait_rel(input.as_ref(), state, extensions)?;
+pub fn from_aggregate(
+ producer: &mut impl SubstraitProducer,
+ agg: &Aggregate,
+) -> Result<Box<Rel>> {
+ let input = producer.handle_plan(agg.input.as_ref())?;
+ let (grouping_expressions, groupings) =
+ to_substrait_groupings(producer, &agg.group_expr, agg.input.schema())?;
+ let measures = agg
+ .aggr_expr
+ .iter()
+ .map(|e| to_substrait_agg_measure(producer, e, agg.input.schema()))
+ .collect::<Result<Vec<_>>>()?;
- let sort_rel = Box::new(Rel {
- rel_type: Some(RelType::Sort(Box::new(SortRel {
- common: None,
- input: Some(input),
- sorts: sort_fields,
- advanced_extension: None,
- }))),
- });
-
- match fetch {
- Some(amount) => {
- let count_mode =
-
Some(fetch_rel::CountMode::CountExpr(Box::new(Expression {
- rex_type: Some(RexType::Literal(Literal {
- nullable: false,
- type_variation_reference:
DEFAULT_TYPE_VARIATION_REF,
- literal_type: Some(LiteralType::I64(*amount as
i64)),
- })),
- })));
- Ok(Box::new(Rel {
- rel_type: Some(RelType::Fetch(Box::new(FetchRel {
- common: None,
- input: Some(sort_rel),
- offset_mode: None,
- count_mode,
- advanced_extension: None,
- }))),
- }))
- }
- None => Ok(sort_rel),
- }
- }
- LogicalPlan::Aggregate(agg) => {
- let input = to_substrait_rel(agg.input.as_ref(), state,
extensions)?;
- let (grouping_expressions, groupings) = to_substrait_groupings(
- state,
- &agg.group_expr,
- agg.input.schema(),
- extensions,
- )?;
- let measures = agg
- .aggr_expr
- .iter()
- .map(|e| {
- to_substrait_agg_measure(state, e, agg.input.schema(),
extensions)
- })
- .collect::<Result<Vec<_>>>()?;
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
+ common: None,
+ input: Some(input),
+ grouping_expressions,
+ groupings,
+ measures,
+ advanced_extension: None,
+ }))),
+ }))
+}
- Ok(Box::new(Rel {
- rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
- common: None,
- input: Some(input),
- grouping_expressions,
- groupings,
- measures,
- advanced_extension: None,
- }))),
- }))
- }
- LogicalPlan::Distinct(Distinct::All(plan)) => {
+pub fn from_distinct(
+ producer: &mut impl SubstraitProducer,
+ distinct: &Distinct,
+) -> Result<Box<Rel>> {
+ match distinct {
+ Distinct::All(plan) => {
// Use Substrait's AggregateRel with empty measures to represent
`select distinct`
- let input = to_substrait_rel(plan.as_ref(), state, extensions)?;
+ let input = producer.handle_plan(plan.as_ref())?;
// Get grouping keys from the input relation's number of output
fields
let grouping = (0..plan.schema().fields().len())
.map(substrait_field_ref)
.collect::<Result<Vec<_>>>()?;
+ #[allow(deprecated)]
Ok(Box::new(Rel {
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
common: None,
@@ -450,220 +812,176 @@ pub fn to_substrait_rel(
}))),
}))
}
- LogicalPlan::Join(join) => {
- let left = to_substrait_rel(join.left.as_ref(), state,
extensions)?;
- let right = to_substrait_rel(join.right.as_ref(), state,
extensions)?;
- let join_type = to_substrait_jointype(join.join_type);
- // we only support basic joins so return an error for anything not
yet supported
- match join.join_constraint {
- JoinConstraint::On => {}
- JoinConstraint::Using => {
- return not_impl_err!("join constraint: `using`")
- }
- }
- // parse filter if exists
- let in_join_schema = join.left.schema().join(join.right.schema())?;
- let join_filter = match &join.filter {
- Some(filter) => Some(to_substrait_rex(
- state,
- filter,
- &Arc::new(in_join_schema),
- 0,
- extensions,
- )?),
- None => None,
- };
+ Distinct::On(_) => not_impl_err!("Cannot convert Distinct::On"),
+ }
+}
- // map the left and right columns to binary expressions in the
form `l = r`
- // build a single expression for the ON condition, such as `l.a =
r.a AND l.b = r.b`
- let eq_op = if join.null_equals_null {
- Operator::IsNotDistinctFrom
- } else {
- Operator::Eq
- };
- let join_on = to_substrait_join_expr(
- state,
- &join.on,
- eq_op,
- join.left.schema(),
- join.right.schema(),
- extensions,
- )?;
-
- // create conjunction between `join_on` and `join_filter` to embed
all join conditions,
- // whether equal or non-equal in a single expression
- let join_expr = match &join_on {
- Some(on_expr) => match &join_filter {
- Some(filter) => Some(Box::new(make_binary_op_scalar_func(
- on_expr,
- filter,
- Operator::And,
- extensions,
- ))),
- None => join_on.map(Box::new), // the join expression will
only contain `join_on` if filter doesn't exist
- },
- None => match &join_filter {
- Some(_) => join_filter.map(Box::new), // the join
expression will only contain `join_filter` if the `on` condition doesn't exist
- None => None,
- },
- };
+pub fn from_join(producer: &mut impl SubstraitProducer, join: &Join) ->
Result<Box<Rel>> {
+ let left = producer.handle_plan(join.left.as_ref())?;
+ let right = producer.handle_plan(join.right.as_ref())?;
+ let join_type = to_substrait_jointype(join.join_type);
+ // we only support basic joins so return an error for anything not yet
supported
+ match join.join_constraint {
+ JoinConstraint::On => {}
+ JoinConstraint::Using => return not_impl_err!("join constraint:
`using`"),
+ }
+ let in_join_schema =
Arc::new(join.left.schema().join(join.right.schema())?);
- Ok(Box::new(Rel {
- rel_type: Some(RelType::Join(Box::new(JoinRel {
- common: None,
- left: Some(left),
- right: Some(right),
- r#type: join_type as i32,
- expression: join_expr,
- post_join_filter: None,
- advanced_extension: None,
- }))),
- }))
- }
- LogicalPlan::SubqueryAlias(alias) => {
- // Do nothing if encounters SubqueryAlias
- // since there is no corresponding relation type in Substrait
- to_substrait_rel(alias.input.as_ref(), state, extensions)
- }
- LogicalPlan::Union(union) => {
- let input_rels = union
- .inputs
- .iter()
- .map(|input| to_substrait_rel(input.as_ref(), state,
extensions))
- .collect::<Result<Vec<_>>>()?
- .into_iter()
- .map(|ptr| *ptr)
- .collect();
- Ok(Box::new(Rel {
- rel_type: Some(RelType::Set(SetRel {
- common: None,
- inputs: input_rels,
- op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT
gets translated to AGGREGATION + UNION ALL
- advanced_extension: None,
- })),
- }))
- }
- LogicalPlan::Window(window) => {
- let input = to_substrait_rel(window.input.as_ref(), state,
extensions)?;
+ // convert filter if present
+ let join_filter = match &join.filter {
+ Some(filter) => Some(producer.handle_expr(filter, &in_join_schema)?),
+ None => None,
+ };
- // create a field reference for each input field
- let mut expressions = (0..window.input.schema().fields().len())
- .map(substrait_field_ref)
- .collect::<Result<Vec<_>>>()?;
+ // map the left and right columns to binary expressions in the form `l = r`
+ // build a single expression for the ON condition, such as `l.a = r.a AND
l.b = r.b`
+ let eq_op = if join.null_equals_null {
+ Operator::IsNotDistinctFrom
+ } else {
+ Operator::Eq
+ };
+ let join_on = to_substrait_join_expr(producer, &join.on, eq_op,
&in_join_schema)?;
+
+ // create conjunction between `join_on` and `join_filter` to embed all
join conditions,
+ // whether equal or non-equal in a single expression
+ let join_expr = match &join_on {
+ Some(on_expr) => match &join_filter {
+ Some(filter) => Some(Box::new(make_binary_op_scalar_func(
+ producer,
+ on_expr,
+ filter,
+ Operator::And,
+ ))),
+ None => join_on.map(Box::new), // the join expression will only
contain `join_on` if filter doesn't exist
+ },
+ None => match &join_filter {
+ Some(_) => join_filter.map(Box::new), // the join expression will
only contain `join_filter` if the `on` condition doesn't exist
+ None => None,
+ },
+ };
- // process and add each window function expression
- for expr in &window.window_expr {
- expressions.push(to_substrait_rex(
- state,
- expr,
- window.input.schema(),
- 0,
- extensions,
- )?);
- }
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Join(Box::new(JoinRel {
+ common: None,
+ left: Some(left),
+ right: Some(right),
+ r#type: join_type as i32,
+ expression: join_expr,
+ post_join_filter: None,
+ advanced_extension: None,
+ }))),
+ }))
+}
- let emit_kind = create_project_remapping(
- expressions.len(),
- window.input.schema().fields().len(),
- );
- let common = RelCommon {
- emit_kind: Some(emit_kind),
- hint: None,
- advanced_extension: None,
- };
- let project_rel = Box::new(ProjectRel {
- common: Some(common),
- input: Some(input),
- expressions,
- advanced_extension: None,
- });
+pub fn from_subquery_alias(
+ producer: &mut impl SubstraitProducer,
+ alias: &SubqueryAlias,
+) -> Result<Box<Rel>> {
+ // Do nothing if encounters SubqueryAlias
+ // since there is no corresponding relation type in Substrait
+ producer.handle_plan(alias.input.as_ref())
+}
- Ok(Box::new(Rel {
- rel_type: Some(RelType::Project(project_rel)),
- }))
+pub fn from_union(
+ producer: &mut impl SubstraitProducer,
+ union: &Union,
+) -> Result<Box<Rel>> {
+ let input_rels = union
+ .inputs
+ .iter()
+ .map(|input| producer.handle_plan(input.as_ref()))
+ .collect::<Result<Vec<_>>>()?
+ .into_iter()
+ .map(|ptr| *ptr)
+ .collect();
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Set(SetRel {
+ common: None,
+ inputs: input_rels,
+ op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets
translated to AGGREGATION + UNION ALL
+ advanced_extension: None,
+ })),
+ }))
+}
+
+pub fn from_window(
+ producer: &mut impl SubstraitProducer,
+ window: &Window,
+) -> Result<Box<Rel>> {
+ let input = producer.handle_plan(window.input.as_ref())?;
+
+ // create a field reference for each input field
+ let mut expressions = (0..window.input.schema().fields().len())
+ .map(substrait_field_ref)
+ .collect::<Result<Vec<_>>>()?;
+
+ // process and add each window function expression
+ for expr in &window.window_expr {
+ expressions.push(producer.handle_expr(expr, window.input.schema())?);
+ }
+
+ let emit_kind =
+ create_project_remapping(expressions.len(),
window.input.schema().fields().len());
+ let common = RelCommon {
+ emit_kind: Some(emit_kind),
+ hint: None,
+ advanced_extension: None,
+ };
+ let project_rel = Box::new(ProjectRel {
+ common: Some(common),
+ input: Some(input),
+ expressions,
+ advanced_extension: None,
+ });
+
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Project(project_rel)),
+ }))
+}
+
+pub fn from_repartition(
+ producer: &mut impl SubstraitProducer,
+ repartition: &Repartition,
+) -> Result<Box<Rel>> {
+ let input = producer.handle_plan(repartition.input.as_ref())?;
+ let partition_count = match repartition.partitioning_scheme {
+ Partitioning::RoundRobinBatch(num) => num,
+ Partitioning::Hash(_, num) => num,
+ Partitioning::DistributeBy(_) => {
+ return not_impl_err!(
+ "Physical plan does not support DistributeBy partitioning"
+ )
}
- LogicalPlan::Repartition(repartition) => {
- let input = to_substrait_rel(repartition.input.as_ref(), state,
extensions)?;
- let partition_count = match repartition.partitioning_scheme {
- Partitioning::RoundRobinBatch(num) => num,
- Partitioning::Hash(_, num) => num,
- Partitioning::DistributeBy(_) => {
- return not_impl_err!(
- "Physical plan does not support DistributeBy
partitioning"
- )
- }
- };
- // ref:
https://substrait.io/relations/physical_relations/#exchange-types
- let exchange_kind = match &repartition.partitioning_scheme {
- Partitioning::RoundRobinBatch(_) => {
- ExchangeKind::RoundRobin(RoundRobin::default())
- }
- Partitioning::Hash(exprs, _) => {
- let fields = exprs
- .iter()
- .map(|e| {
- try_to_substrait_field_reference(
- e,
- repartition.input.schema(),
- )
- })
- .collect::<Result<Vec<_>>>()?;
- ExchangeKind::ScatterByFields(ScatterFields { fields })
- }
- Partitioning::DistributeBy(_) => {
- return not_impl_err!(
- "Physical plan does not support DistributeBy
partitioning"
- )
- }
- };
- let exchange_rel = ExchangeRel {
- common: None,
- input: Some(input),
- exchange_kind: Some(exchange_kind),
- advanced_extension: None,
- partition_count: partition_count as i32,
- targets: vec![],
- };
- Ok(Box::new(Rel {
- rel_type: Some(RelType::Exchange(Box::new(exchange_rel))),
- }))
+ };
+ // ref: https://substrait.io/relations/physical_relations/#exchange-types
+ let exchange_kind = match &repartition.partitioning_scheme {
+ Partitioning::RoundRobinBatch(_) => {
+ ExchangeKind::RoundRobin(RoundRobin::default())
}
- LogicalPlan::Extension(extension_plan) => {
- let extension_bytes = state
- .serializer_registry()
- .serialize_logical_plan(extension_plan.node.as_ref())?;
- let detail = ProtoAny {
- type_url: extension_plan.node.name().to_string(),
- value: extension_bytes.into(),
- };
- let mut inputs_rel = extension_plan
- .node
- .inputs()
- .into_iter()
- .map(|plan| to_substrait_rel(plan, state, extensions))
+ Partitioning::Hash(exprs, _) => {
+ let fields = exprs
+ .iter()
+ .map(|e| try_to_substrait_field_reference(e,
repartition.input.schema()))
.collect::<Result<Vec<_>>>()?;
- let rel_type = match inputs_rel.len() {
- 0 => RelType::ExtensionLeaf(ExtensionLeafRel {
- common: None,
- detail: Some(detail),
- }),
- 1 => RelType::ExtensionSingle(Box::new(ExtensionSingleRel {
- common: None,
- detail: Some(detail),
- input: Some(inputs_rel.pop().unwrap()),
- })),
- _ => RelType::ExtensionMulti(ExtensionMultiRel {
- common: None,
- detail: Some(detail),
- inputs: inputs_rel.into_iter().map(|r| *r).collect(),
- }),
- };
- Ok(Box::new(Rel {
- rel_type: Some(rel_type),
- }))
+ ExchangeKind::ScatterByFields(ScatterFields { fields })
}
- _ => not_impl_err!("Unsupported operator: {plan}"),
- }
+ Partitioning::DistributeBy(_) => {
+ return not_impl_err!(
+ "Physical plan does not support DistributeBy partitioning"
+ )
+ }
+ };
+ let exchange_rel = ExchangeRel {
+ common: None,
+ input: Some(input),
+ exchange_kind: Some(exchange_kind),
+ advanced_extension: None,
+ partition_count: partition_count as i32,
+ targets: vec![],
+ };
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Exchange(Box::new(exchange_rel))),
+ }))
}
/// By default, a Substrait Project outputs all input fields followed by all
expressions.
@@ -730,32 +1048,23 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) ->
Result<NamedStruct> {
}
fn to_substrait_join_expr(
- state: &dyn SubstraitPlanningState,
+ producer: &mut impl SubstraitProducer,
join_conditions: &Vec<(Expr, Expr)>,
eq_op: Operator,
- left_schema: &DFSchemaRef,
- right_schema: &DFSchemaRef,
- extensions: &mut Extensions,
+ join_schema: &DFSchemaRef,
) -> Result<Option<Expression>> {
// Only support AND conjunction for each binary expression in join
conditions
let mut exprs: Vec<Expression> = vec![];
for (left, right) in join_conditions {
- // Parse left
- let l = to_substrait_rex(state, left, left_schema, 0, extensions)?;
- // Parse right
- let r = to_substrait_rex(
- state,
- right,
- right_schema,
- left_schema.fields().len(), // offset to return the correct index
- extensions,
- )?;
+ let l = producer.handle_expr(left, join_schema)?;
+ let r = producer.handle_expr(right, join_schema)?;
// AND with existing expression
- exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extensions));
+ exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op));
}
+
let join_expr: Option<Expression> =
exprs.into_iter().reduce(|acc: Expression, e: Expression| {
- make_binary_op_scalar_func(&acc, &e, Operator::And, extensions)
+ make_binary_op_scalar_func(producer, &acc, &e, Operator::And)
});
Ok(join_expr)
}
@@ -811,23 +1120,22 @@ pub fn operator_to_name(op: Operator) -> &'static str {
}
}
-#[allow(deprecated)]
pub fn parse_flat_grouping_exprs(
- state: &dyn SubstraitPlanningState,
+ producer: &mut impl SubstraitProducer,
exprs: &[Expr],
schema: &DFSchemaRef,
- extensions: &mut Extensions,
ref_group_exprs: &mut Vec<Expression>,
) -> Result<Grouping> {
let mut expression_references = vec![];
let mut grouping_expressions = vec![];
for e in exprs {
- let rex = to_substrait_rex(state, e, schema, 0, extensions)?;
+ let rex = producer.handle_expr(e, schema)?;
grouping_expressions.push(rex.clone());
ref_group_exprs.push(rex);
expression_references.push((ref_group_exprs.len() - 1) as u32);
}
+ #[allow(deprecated)]
Ok(Grouping {
grouping_expressions,
expression_references,
@@ -835,10 +1143,9 @@ pub fn parse_flat_grouping_exprs(
}
pub fn to_substrait_groupings(
- state: &dyn SubstraitPlanningState,
+ producer: &mut impl SubstraitProducer,
exprs: &[Expr],
schema: &DFSchemaRef,
- extensions: &mut Extensions,
) -> Result<(Vec<Expression>, Vec<Grouping>)> {
let mut ref_group_exprs = vec![];
let groupings = match exprs.len() {
@@ -851,10 +1158,9 @@ pub fn to_substrait_groupings(
.iter()
.map(|set| {
parse_flat_grouping_exprs(
- state,
+ producer,
set,
schema,
- extensions,
&mut ref_group_exprs,
)
})
@@ -869,10 +1175,9 @@ pub fn to_substrait_groupings(
.rev()
.map(|set| {
parse_flat_grouping_exprs(
- state,
+ producer,
set,
schema,
- extensions,
&mut ref_group_exprs,
)
})
@@ -880,66 +1185,81 @@ pub fn to_substrait_groupings(
}
},
_ => Ok(vec![parse_flat_grouping_exprs(
- state,
+ producer,
exprs,
schema,
- extensions,
&mut ref_group_exprs,
)?]),
},
_ => Ok(vec![parse_flat_grouping_exprs(
- state,
+ producer,
exprs,
schema,
- extensions,
&mut ref_group_exprs,
)?]),
}?;
Ok((ref_group_exprs, groupings))
}
-#[allow(deprecated)]
+pub fn from_aggregate_function(
+ producer: &mut impl SubstraitProducer,
+ agg_fn: &expr::AggregateFunction,
+ schema: &DFSchemaRef,
+) -> Result<Measure> {
+ let expr::AggregateFunction {
+ func,
+ args,
+ distinct,
+ filter,
+ order_by,
+ null_treatment: _null_treatment,
+ } = agg_fn;
+ let sorts = if let Some(order_by) = order_by {
+ order_by
+ .iter()
+ .map(|expr| to_substrait_sort_field(producer, expr, schema))
+ .collect::<Result<Vec<_>>>()?
+ } else {
+ vec![]
+ };
+ let mut arguments: Vec<FunctionArgument> = vec![];
+ for arg in args {
+ arguments.push(FunctionArgument {
+ arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)),
+ });
+ }
+ let function_anchor = producer.register_function(func.name().to_string());
+ #[allow(deprecated)]
+ Ok(Measure {
+ measure: Some(AggregateFunction {
+ function_reference: function_anchor,
+ arguments,
+ sorts,
+ output_type: None,
+ invocation: match distinct {
+ true => AggregationInvocation::Distinct as i32,
+ false => AggregationInvocation::All as i32,
+ },
+ phase: AggregationPhase::Unspecified as i32,
+ args: vec![],
+ options: vec![],
+ }),
+ filter: match filter {
+ Some(f) => Some(producer.handle_expr(f, schema)?),
+ None => None,
+ },
+ })
+}
+
pub fn to_substrait_agg_measure(
- state: &dyn SubstraitPlanningState,
+ producer: &mut impl SubstraitProducer,
expr: &Expr,
schema: &DFSchemaRef,
- extensions: &mut Extensions,
) -> Result<Measure> {
match expr {
- Expr::AggregateFunction(expr::AggregateFunction { func, args,
distinct, filter, order_by, null_treatment: _, }) => {
- let sorts = if let Some(order_by) = order_by {
- order_by.iter().map(|expr|
to_substrait_sort_field(state, expr, schema,
extensions)).collect::<Result<Vec<_>>>()?
- } else {
- vec![]
- };
- let mut arguments: Vec<FunctionArgument> = vec![];
- for arg in args {
- arguments.push(FunctionArgument { arg_type:
Some(ArgType::Value(to_substrait_rex(state, arg, schema, 0, extensions)?)) });
- }
- let function_anchor =
extensions.register_function(func.name().to_string());
- Ok(Measure {
- measure: Some(AggregateFunction {
- function_reference: function_anchor,
- arguments,
- sorts,
- output_type: None,
- invocation: match distinct {
- true => AggregationInvocation::Distinct as i32,
- false => AggregationInvocation::All as i32,
- },
- phase: AggregationPhase::Unspecified as i32,
- args: vec![],
- options: vec![],
- }),
- filter: match filter {
- Some(f) => Some(to_substrait_rex(state, f, schema,
0, extensions)?),
- None => None
- }
- })
-
- }
- Expr::Alias(Alias{expr,..})=> {
- to_substrait_agg_measure(state, expr, schema, extensions)
+ Expr::AggregateFunction(agg_fn) => from_aggregate_function(producer,
agg_fn, schema),
+ Expr::Alias(Alias { expr, .. }) => {
+ to_substrait_agg_measure(producer, expr, schema)
}
_ => internal_err!(
"Expression must be compatible with aggregation. Unsupported
expression: {:?}. ExpressionType: {:?}",
@@ -951,10 +1271,9 @@ pub fn to_substrait_agg_measure(
/// Converts sort expression to corresponding substrait `SortField`
fn to_substrait_sort_field(
- state: &dyn SubstraitPlanningState,
- sort: &Sort,
+ producer: &mut impl SubstraitProducer,
+ sort: &expr::Sort,
schema: &DFSchemaRef,
- extensions: &mut Extensions,
) -> Result<SortField> {
let sort_kind = match (sort.asc, sort.nulls_first) {
(true, true) => SortDirection::AscNullsFirst,
@@ -963,20 +1282,20 @@ fn to_substrait_sort_field(
(false, false) => SortDirection::DescNullsLast,
};
Ok(SortField {
- expr: Some(to_substrait_rex(state, &sort.expr, schema, 0,
extensions)?),
+ expr: Some(producer.handle_expr(&sort.expr, schema)?),
sort_kind: Some(SortKind::Direction(sort_kind.into())),
})
}
/// Return Substrait scalar function with two arguments
-#[allow(deprecated)]
pub fn make_binary_op_scalar_func(
+ producer: &mut impl SubstraitProducer,
lhs: &Expression,
rhs: &Expression,
op: Operator,
- extensions: &mut Extensions,
) -> Expression {
- let function_anchor =
extensions.register_function(operator_to_name(op).to_string());
+ let function_anchor =
producer.register_function(operator_to_name(op).to_string());
+ #[allow(deprecated)]
Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
@@ -998,450 +1317,431 @@ pub fn make_binary_op_scalar_func(
/// Convert DataFusion Expr to Substrait Rex
///
/// # Arguments
-///
-/// * `expr` - DataFusion expression to be parse into a Substrait expression
-/// * `schema` - DataFusion input schema for looking up field qualifiers
-/// * `col_ref_offset` - Offset for calculating Substrait field reference
indices.
-/// This should only be set by caller with more than one
input relations i.e. Join.
-/// Substrait expects one set of indices when joining two
relations.
-/// Let's say `left` and `right` have `m` and `n` columns,
respectively. The `right`
-/// relation will have column indices from `0` to `n-1`,
however, Substrait will expect
-/// the `right` indices to be offset by the `left`. This
means Substrait will expect to
-/// evaluate the join condition expression on indices [0
.. n-1, n .. n+m-1]. For example:
-/// ```SELECT *
-/// FROM t1
-/// JOIN t2
-/// ON t1.c1 = t2.c0;```
-/// where t1 consists of columns [c0, c1, c2], and t2 =
columns [c0, c1]
-/// the join condition should become
-/// `col_ref(1) = col_ref(3 + 0)`
-/// , where `3` is the number of `left` columns
(`col_ref_offset`) and `0` is the index
-/// of the join key column from `right`
-/// * `extensions` - Substrait extension info. Contains registered function
information
-#[allow(deprecated)]
+/// * `producer` - SubstraitProducer implementation which the handles the
actual conversion
+/// * `expr` - DataFusion expression to convert into a Substrait expression
+/// * `schema` - DataFusion input schema for looking up columns
pub fn to_substrait_rex(
- state: &dyn SubstraitPlanningState,
+ producer: &mut impl SubstraitProducer,
expr: &Expr,
schema: &DFSchemaRef,
- col_ref_offset: usize,
- extensions: &mut Extensions,
) -> Result<Expression> {
match expr {
- Expr::InList(InList {
- expr,
- list,
- negated,
- }) => {
- let substrait_list = list
- .iter()
- .map(|x| to_substrait_rex(state, x, schema, col_ref_offset,
extensions))
- .collect::<Result<Vec<Expression>>>()?;
- let substrait_expr =
- to_substrait_rex(state, expr, schema, col_ref_offset,
extensions)?;
-
- let substrait_or_list = Expression {
- rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList
{
- value: Some(Box::new(substrait_expr)),
- options: substrait_list,
- }))),
- };
-
- if *negated {
- let function_anchor =
extensions.register_function("not".to_string());
-
- Ok(Expression {
- rex_type: Some(RexType::ScalarFunction(ScalarFunction {
- function_reference: function_anchor,
- arguments: vec![FunctionArgument {
- arg_type: Some(ArgType::Value(substrait_or_list)),
- }],
- output_type: None,
- args: vec![],
- options: vec![],
- })),
- })
- } else {
- Ok(substrait_or_list)
- }
+ Expr::Alias(expr) => producer.handle_alias(expr, schema),
+ Expr::Column(expr) => producer.handle_column(expr, schema),
+ Expr::ScalarVariable(_, _) => {
+ not_impl_err!("Cannot convert {expr:?} to Substrait")
}
- Expr::ScalarFunction(fun) => {
- let mut arguments: Vec<FunctionArgument> = vec![];
- for arg in &fun.args {
- arguments.push(FunctionArgument {
- arg_type: Some(ArgType::Value(to_substrait_rex(
- state,
- arg,
- schema,
- col_ref_offset,
- extensions,
- )?)),
- });
- }
-
- let function_anchor =
extensions.register_function(fun.name().to_string());
- Ok(Expression {
- rex_type: Some(RexType::ScalarFunction(ScalarFunction {
- function_reference: function_anchor,
- arguments,
- output_type: None,
- args: vec![],
- options: vec![],
- })),
- })
+ Expr::Literal(expr) => producer.handle_literal(expr),
+ Expr::BinaryExpr(expr) => producer.handle_binary_expr(expr, schema),
+ Expr::Like(expr) => producer.handle_like(expr, schema),
+ Expr::SimilarTo(_) => not_impl_err!("Cannot convert {expr:?} to
Substrait"),
+ Expr::Not(_) => producer.handle_unary_expr(expr, schema),
+ Expr::IsNotNull(_) => producer.handle_unary_expr(expr, schema),
+ Expr::IsNull(_) => producer.handle_unary_expr(expr, schema),
+ Expr::IsTrue(_) => producer.handle_unary_expr(expr, schema),
+ Expr::IsFalse(_) => producer.handle_unary_expr(expr, schema),
+ Expr::IsUnknown(_) => producer.handle_unary_expr(expr, schema),
+ Expr::IsNotTrue(_) => producer.handle_unary_expr(expr, schema),
+ Expr::IsNotFalse(_) => producer.handle_unary_expr(expr, schema),
+ Expr::IsNotUnknown(_) => producer.handle_unary_expr(expr, schema),
+ Expr::Negative(_) => producer.handle_unary_expr(expr, schema),
+ Expr::Between(expr) => producer.handle_between(expr, schema),
+ Expr::Case(expr) => producer.handle_case(expr, schema),
+ Expr::Cast(expr) => producer.handle_cast(expr, schema),
+ Expr::TryCast(expr) => producer.handle_try_cast(expr, schema),
+ Expr::ScalarFunction(expr) => producer.handle_scalar_function(expr,
schema),
+ Expr::AggregateFunction(_) => {
+ internal_err!(
+ "AggregateFunction should only be encountered as part of a
LogicalPlan::Aggregate"
+ )
}
- Expr::Between(Between {
- expr,
- negated,
- low,
- high,
- }) => {
- if *negated {
- // `expr NOT BETWEEN low AND high` can be translated into
(expr < low OR high < expr)
- let substrait_expr =
- to_substrait_rex(state, expr, schema, col_ref_offset,
extensions)?;
- let substrait_low =
- to_substrait_rex(state, low, schema, col_ref_offset,
extensions)?;
- let substrait_high =
- to_substrait_rex(state, high, schema, col_ref_offset,
extensions)?;
-
- let l_expr = make_binary_op_scalar_func(
- &substrait_expr,
- &substrait_low,
- Operator::Lt,
- extensions,
- );
- let r_expr = make_binary_op_scalar_func(
- &substrait_high,
- &substrait_expr,
- Operator::Lt,
- extensions,
- );
-
- Ok(make_binary_op_scalar_func(
- &l_expr,
- &r_expr,
- Operator::Or,
- extensions,
- ))
- } else {
- // `expr BETWEEN low AND high` can be translated into (low <=
expr AND expr <= high)
- let substrait_expr =
- to_substrait_rex(state, expr, schema, col_ref_offset,
extensions)?;
- let substrait_low =
- to_substrait_rex(state, low, schema, col_ref_offset,
extensions)?;
- let substrait_high =
- to_substrait_rex(state, high, schema, col_ref_offset,
extensions)?;
-
- let l_expr = make_binary_op_scalar_func(
- &substrait_low,
- &substrait_expr,
- Operator::LtEq,
- extensions,
- );
- let r_expr = make_binary_op_scalar_func(
- &substrait_expr,
- &substrait_high,
- Operator::LtEq,
- extensions,
- );
-
- Ok(make_binary_op_scalar_func(
- &l_expr,
- &r_expr,
- Operator::And,
- extensions,
- ))
- }
+ Expr::WindowFunction(expr) => producer.handle_window_function(expr,
schema),
+ Expr::InList(expr) => producer.handle_in_list(expr, schema),
+ Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to
Substrait"),
+ Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema),
+ Expr::ScalarSubquery(expr) => {
+ not_impl_err!("Cannot convert {expr:?} to Substrait")
}
- Expr::Column(col) => {
- let index = schema.index_of_column(col)?;
- substrait_field_ref(index + col_ref_offset)
+ Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to
Substrait"),
+ Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to
Substrait"),
+ Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to
Substrait"),
+ Expr::OuterReferenceColumn(_, _) => {
+ not_impl_err!("Cannot convert {expr:?} to Substrait")
}
- Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
- let l = to_substrait_rex(state, left, schema, col_ref_offset,
extensions)?;
- let r = to_substrait_rex(state, right, schema, col_ref_offset,
extensions)?;
+ Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to
Substrait"),
+ }
+}
- Ok(make_binary_op_scalar_func(&l, &r, *op, extensions))
- }
- Expr::Case(Case {
- expr,
- when_then_expr,
- else_expr,
- }) => {
- let mut ifs: Vec<IfClause> = vec![];
- // Parse base
- if let Some(e) = expr {
- // Base expression exists
- ifs.push(IfClause {
- r#if: Some(to_substrait_rex(
- state,
- e,
- schema,
- col_ref_offset,
- extensions,
- )?),
- then: None,
- });
- }
- // Parse `when`s
- for (r#if, then) in when_then_expr {
- ifs.push(IfClause {
- r#if: Some(to_substrait_rex(
- state,
- r#if,
- schema,
- col_ref_offset,
- extensions,
- )?),
- then: Some(to_substrait_rex(
- state,
- then,
- schema,
- col_ref_offset,
- extensions,
- )?),
- });
- }
+pub fn from_in_list(
+ producer: &mut impl SubstraitProducer,
+ in_list: &InList,
+ schema: &DFSchemaRef,
+) -> Result<Expression> {
+ let InList {
+ expr,
+ list,
+ negated,
+ } = in_list;
+ let substrait_list = list
+ .iter()
+ .map(|x| producer.handle_expr(x, schema))
+ .collect::<Result<Vec<Expression>>>()?;
+ let substrait_expr = producer.handle_expr(expr, schema)?;
+
+ let substrait_or_list = Expression {
+ rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList {
+ value: Some(Box::new(substrait_expr)),
+ options: substrait_list,
+ }))),
+ };
- // Parse outer `else`
- let r#else: Option<Box<Expression>> = match else_expr {
- Some(e) => Some(Box::new(to_substrait_rex(
- state,
- e,
- schema,
- col_ref_offset,
- extensions,
- )?)),
- None => None,
- };
+ if *negated {
+ let function_anchor = producer.register_function("not".to_string());
- Ok(Expression {
- rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else
}))),
- })
- }
- Expr::Cast(Cast { expr, data_type }) => Ok(Expression {
- rex_type: Some(RexType::Cast(Box::new(
- substrait::proto::expression::Cast {
- r#type: Some(to_substrait_type(data_type, true)?),
- input: Some(Box::new(to_substrait_rex(
- state,
- expr,
- schema,
- col_ref_offset,
- extensions,
- )?)),
- failure_behavior: FailureBehavior::ThrowException.into(),
- },
- ))),
- }),
- Expr::TryCast(TryCast { expr, data_type }) => Ok(Expression {
- rex_type: Some(RexType::Cast(Box::new(
- substrait::proto::expression::Cast {
- r#type: Some(to_substrait_type(data_type, true)?),
- input: Some(Box::new(to_substrait_rex(
- state,
- expr,
- schema,
- col_ref_offset,
- extensions,
- )?)),
- failure_behavior: FailureBehavior::ReturnNull.into(),
- },
- ))),
- }),
- Expr::Literal(value) => to_substrait_literal_expr(value, extensions),
- Expr::Alias(Alias { expr, .. }) => {
- to_substrait_rex(state, expr, schema, col_ref_offset, extensions)
- }
- Expr::WindowFunction(WindowFunction {
- fun,
- args,
- partition_by,
- order_by,
- window_frame,
- null_treatment: _,
- }) => {
- // function reference
- let function_anchor =
extensions.register_function(fun.to_string());
- // arguments
- let mut arguments: Vec<FunctionArgument> = vec![];
- for arg in args {
- arguments.push(FunctionArgument {
- arg_type: Some(ArgType::Value(to_substrait_rex(
- state,
- arg,
- schema,
- col_ref_offset,
- extensions,
- )?)),
- });
- }
- // partition by expressions
- let partition_by = partition_by
- .iter()
- .map(|e| to_substrait_rex(state, e, schema, col_ref_offset,
extensions))
- .collect::<Result<Vec<_>>>()?;
- // order by expressions
- let order_by = order_by
- .iter()
- .map(|e| substrait_sort_field(state, e, schema, extensions))
- .collect::<Result<Vec<_>>>()?;
- // window frame
- let bounds = to_substrait_bounds(window_frame)?;
- let bound_type = to_substrait_bound_type(window_frame)?;
- Ok(make_substrait_window_function(
- function_anchor,
- arguments,
- partition_by,
- order_by,
- bounds,
- bound_type,
- ))
- }
- Expr::Like(Like {
- negated,
- expr,
- pattern,
- escape_char,
- case_insensitive,
- }) => make_substrait_like_expr(
- state,
- *case_insensitive,
- *negated,
- expr,
- pattern,
- *escape_char,
- schema,
- col_ref_offset,
- extensions,
- ),
- Expr::InSubquery(InSubquery {
- expr,
- subquery,
- negated,
- }) => {
- let substrait_expr =
- to_substrait_rex(state, expr, schema, col_ref_offset,
extensions)?;
-
- let subquery_plan =
- to_substrait_rel(subquery.subquery.as_ref(), state,
extensions)?;
-
- let substrait_subquery = Expression {
- rex_type: Some(RexType::Subquery(Box::new(Subquery {
- subquery_type: Some(
-
substrait::proto::expression::subquery::SubqueryType::InPredicate(
- Box::new(InPredicate {
- needles: (vec![substrait_expr]),
- haystack: Some(subquery_plan),
- }),
- ),
+ #[allow(deprecated)]
+ Ok(Expression {
+ rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+ function_reference: function_anchor,
+ arguments: vec![FunctionArgument {
+ arg_type: Some(ArgType::Value(substrait_or_list)),
+ }],
+ output_type: None,
+ args: vec![],
+ options: vec![],
+ })),
+ })
+ } else {
+ Ok(substrait_or_list)
+ }
+}
+
+pub fn from_scalar_function(
+ producer: &mut impl SubstraitProducer,
+ fun: &expr::ScalarFunction,
+ schema: &DFSchemaRef,
+) -> Result<Expression> {
+ let mut arguments: Vec<FunctionArgument> = vec![];
+ for arg in &fun.args {
+ arguments.push(FunctionArgument {
+ arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)),
+ });
+ }
+
+ let function_anchor = producer.register_function(fun.name().to_string());
+ #[allow(deprecated)]
+ Ok(Expression {
+ rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+ function_reference: function_anchor,
+ arguments,
+ output_type: None,
+ options: vec![],
+ args: vec![],
+ })),
+ })
+}
+
+pub fn from_between(
+ producer: &mut impl SubstraitProducer,
+ between: &Between,
+ schema: &DFSchemaRef,
+) -> Result<Expression> {
+ let Between {
+ expr,
+ negated,
+ low,
+ high,
+ } = between;
+ if *negated {
+ // `expr NOT BETWEEN low AND high` can be translated into (expr < low
OR high < expr)
+ let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?;
+ let substrait_low = producer.handle_expr(low.as_ref(), schema)?;
+ let substrait_high = producer.handle_expr(high.as_ref(), schema)?;
+
+ let l_expr = make_binary_op_scalar_func(
+ producer,
+ &substrait_expr,
+ &substrait_low,
+ Operator::Lt,
+ );
+ let r_expr = make_binary_op_scalar_func(
+ producer,
+ &substrait_high,
+ &substrait_expr,
+ Operator::Lt,
+ );
+
+ Ok(make_binary_op_scalar_func(
+ producer,
+ &l_expr,
+ &r_expr,
+ Operator::Or,
+ ))
+ } else {
+ // `expr BETWEEN low AND high` can be translated into (low <= expr AND
expr <= high)
+ let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?;
+ let substrait_low = producer.handle_expr(low.as_ref(), schema)?;
+ let substrait_high = producer.handle_expr(high.as_ref(), schema)?;
+
+ let l_expr = make_binary_op_scalar_func(
+ producer,
+ &substrait_low,
+ &substrait_expr,
+ Operator::LtEq,
+ );
+ let r_expr = make_binary_op_scalar_func(
+ producer,
+ &substrait_expr,
+ &substrait_high,
+ Operator::LtEq,
+ );
+
+ Ok(make_binary_op_scalar_func(
+ producer,
+ &l_expr,
+ &r_expr,
+ Operator::And,
+ ))
+ }
+}
+
+/// Converts a column reference into a Substrait field reference.
+///
+/// The column is resolved to its index within `schema`. Resolution failures
+/// are propagated as errors.
+pub fn from_column(col: &Column, schema: &DFSchemaRef) -> Result<Expression> {
+    schema.index_of_column(col).and_then(substrait_field_ref)
+}
+
+/// Converts a DataFusion [`BinaryExpr`] into a Substrait scalar function call
+/// named after its operator.
+pub fn from_binary_expr(
+    producer: &mut impl SubstraitProducer,
+    expr: &BinaryExpr,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    // The left operand is converted before the right one.
+    let lhs = producer.handle_expr(&expr.left, schema)?;
+    let rhs = producer.handle_expr(&expr.right, schema)?;
+    Ok(make_binary_op_scalar_func(producer, &lhs, &rhs, expr.op))
+}
+
+/// Converts a DataFusion `CASE` expression into a Substrait `IfThen`.
+///
+/// When the `CASE` has a base operand (`CASE <expr> WHEN ...`), that operand
+/// is emitted as a leading [`IfClause`] whose `then` is `None`. Each
+/// `WHEN ... THEN ...` pair follows as its own clause, and the optional
+/// `ELSE` becomes the `IfThen`'s else branch.
+pub fn from_case(
+    producer: &mut impl SubstraitProducer,
+    case: &Case,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    let Case {
+        expr: base,
+        when_then_expr: branches,
+        else_expr: fallback,
+    } = case;
+
+    let mut ifs: Vec<IfClause> = Vec::new();
+
+    if let Some(base) = base {
+        let base = producer.handle_expr(base, schema)?;
+        ifs.push(IfClause {
+            r#if: Some(base),
+            then: None,
+        });
+    }
+
+    for (when, then) in branches {
+        let when = producer.handle_expr(when, schema)?;
+        let then = producer.handle_expr(then, schema)?;
+        ifs.push(IfClause {
+            r#if: Some(when),
+            then: Some(then),
+        });
+    }
+
+    let r#else = fallback
+        .as_ref()
+        .map(|e| producer.handle_expr(e, schema).map(Box::new))
+        .transpose()?;
+
+    Ok(Expression {
+        rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))),
+    })
+}
+
+/// Converts a DataFusion [`Cast`] into a Substrait cast expression with
+/// failure behavior `ThrowException`.
+///
+/// The target type is always produced as nullable.
+pub fn from_cast(
+    producer: &mut impl SubstraitProducer,
+    cast: &Cast,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    let target_type = to_substrait_type(&cast.data_type, true)?;
+    let input = producer.handle_expr(&cast.expr, schema)?;
+    let substrait_cast = substrait::proto::expression::Cast {
+        r#type: Some(target_type),
+        input: Some(Box::new(input)),
+        failure_behavior: FailureBehavior::ThrowException.into(),
+    };
+    Ok(Expression {
+        rex_type: Some(RexType::Cast(Box::new(substrait_cast))),
+    })
+}
+
+/// Converts a DataFusion [`TryCast`] into a Substrait cast expression with
+/// failure behavior `ReturnNull`.
+///
+/// The target type is always produced as nullable.
+pub fn from_try_cast(
+    producer: &mut impl SubstraitProducer,
+    cast: &TryCast,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    let target_type = to_substrait_type(&cast.data_type, true)?;
+    let input = producer.handle_expr(&cast.expr, schema)?;
+    let substrait_cast = substrait::proto::expression::Cast {
+        r#type: Some(target_type),
+        input: Some(Box::new(input)),
+        failure_behavior: FailureBehavior::ReturnNull.into(),
+    };
+    Ok(Expression {
+        rex_type: Some(RexType::Cast(Box::new(substrait_cast))),
+    })
+}
+
+/// Converts a DataFusion [`ScalarValue`] into a Substrait literal expression.
+pub fn from_literal(
+    producer: &mut impl SubstraitProducer,
+    value: &ScalarValue,
+) -> Result<Expression> {
+    let literal_expr = to_substrait_literal_expr(producer, value)?;
+    Ok(literal_expr)
+}
+
+/// Converts an aliased expression by converting only the wrapped expression.
+///
+/// The alias itself does not appear in the produced Substrait expression.
+pub fn from_alias(
+    producer: &mut impl SubstraitProducer,
+    alias: &Alias,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    let Alias { expr: inner, .. } = alias;
+    producer.handle_expr(inner, schema)
+}
+
+/// Converts a DataFusion window function call into a Substrait window
+/// function expression.
+///
+/// The function name is registered first. The arguments, `PARTITION BY`
+/// expressions, `ORDER BY` sort fields and window frame bounds are then
+/// converted, in that order. The window's null treatment is not encoded.
+pub fn from_window_function(
+    producer: &mut impl SubstraitProducer,
+    window_fn: &WindowFunction,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    let WindowFunction {
+        fun,
+        args,
+        partition_by,
+        order_by,
+        window_frame,
+        null_treatment: _,
+    } = window_fn;
+
+    let function_anchor = producer.register_function(fun.to_string());
+
+    let arguments = args
+        .iter()
+        .map(|arg| {
+            producer.handle_expr(arg, schema).map(|value| FunctionArgument {
+                arg_type: Some(ArgType::Value(value)),
+            })
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    let partition_exprs = partition_by
+        .iter()
+        .map(|partition_expr| producer.handle_expr(partition_expr, schema))
+        .collect::<Result<Vec<_>>>()?;
+
+    let sort_fields = order_by
+        .iter()
+        .map(|sort| substrait_sort_field(producer, sort, schema))
+        .collect::<Result<Vec<_>>>()?;
+
+    let frame_bounds = to_substrait_bounds(window_frame)?;
+    let frame_bound_type = to_substrait_bound_type(window_frame)?;
+
+    Ok(make_substrait_window_function(
+        function_anchor,
+        arguments,
+        partition_exprs,
+        sort_fields,
+        frame_bounds,
+        frame_bound_type,
+    ))
+}
+
+/// Converts a DataFusion [`Like`] expression into Substrait.
+///
+/// Delegates to `make_substrait_like_expr`, forwarding the case-sensitivity,
+/// negation and escape-character settings of the `LIKE`.
+pub fn from_like(
+    producer: &mut impl SubstraitProducer,
+    like: &Like,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    make_substrait_like_expr(
+        producer,
+        like.case_insensitive,
+        like.negated,
+        &like.expr,
+        &like.pattern,
+        like.escape_char,
+        schema,
+    )
+}
+
+pub fn from_in_subquery(
+ producer: &mut impl SubstraitProducer,
+ subquery: &InSubquery,
+ schema: &DFSchemaRef,
+) -> Result<Expression> {
+ let InSubquery {
+ expr,
+ subquery,
+ negated,
+ } = subquery;
+ let substrait_expr = producer.handle_expr(expr, schema)?;
+
+ let subquery_plan = producer.handle_plan(subquery.subquery.as_ref())?;
+
+ let substrait_subquery = Expression {
+ rex_type: Some(RexType::Subquery(Box::new(
+ substrait::proto::expression::Subquery {
+ subquery_type: Some(
+
substrait::proto::expression::subquery::SubqueryType::InPredicate(
+ Box::new(InPredicate {
+ needles: (vec![substrait_expr]),
+ haystack: Some(subquery_plan),
+ }),
),
- }))),
- };
- if *negated {
- let function_anchor =
extensions.register_function("not".to_string());
-
- Ok(Expression {
- rex_type: Some(RexType::ScalarFunction(ScalarFunction {
- function_reference: function_anchor,
- arguments: vec![FunctionArgument {
- arg_type: Some(ArgType::Value(substrait_subquery)),
- }],
- output_type: None,
- args: vec![],
- options: vec![],
- })),
- })
- } else {
- Ok(substrait_subquery)
- }
- }
- Expr::Not(arg) => to_substrait_unary_scalar_fn(
- state,
- "not",
- arg,
- schema,
- col_ref_offset,
- extensions,
- ),
- Expr::IsNull(arg) => to_substrait_unary_scalar_fn(
- state,
- "is_null",
- arg,
- schema,
- col_ref_offset,
- extensions,
- ),
- Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn(
- state,
- "is_not_null",
- arg,
- schema,
- col_ref_offset,
- extensions,
- ),
- Expr::IsTrue(arg) => to_substrait_unary_scalar_fn(
- state,
- "is_true",
- arg,
- schema,
- col_ref_offset,
- extensions,
- ),
- Expr::IsFalse(arg) => to_substrait_unary_scalar_fn(
- state,
- "is_false",
- arg,
- schema,
- col_ref_offset,
- extensions,
- ),
- Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn(
- state,
- "is_unknown",
- arg,
- schema,
- col_ref_offset,
- extensions,
- ),
- Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn(
- state,
- "is_not_true",
- arg,
- schema,
- col_ref_offset,
- extensions,
- ),
- Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn(
- state,
- "is_not_false",
- arg,
- schema,
- col_ref_offset,
- extensions,
- ),
- Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn(
- state,
- "is_not_unknown",
- arg,
- schema,
- col_ref_offset,
- extensions,
- ),
- Expr::Negative(arg) => to_substrait_unary_scalar_fn(
- state,
- "negate",
- arg,
- schema,
- col_ref_offset,
- extensions,
- ),
- _ => {
- not_impl_err!("Unsupported expression: {expr:?}")
- }
+ ),
+ },
+ ))),
+ };
+ if *negated {
+ let function_anchor = producer.register_function("not".to_string());
+
+ #[allow(deprecated)]
+ Ok(Expression {
+ rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+ function_reference: function_anchor,
+ arguments: vec![FunctionArgument {
+ arg_type: Some(ArgType::Value(substrait_subquery)),
+ }],
+ output_type: None,
+ args: vec![],
+ options: vec![],
+ })),
+ })
+ } else {
+ Ok(substrait_subquery)
}
}
+
+/// Converts a unary DataFusion expression into a single-argument Substrait
+/// scalar function call.
+///
+/// The supported variants are:
+/// - `NOT`
+/// - `IS [NOT] NULL`
+/// - `IS [NOT] TRUE`, `IS [NOT] FALSE`, `IS [NOT] UNKNOWN`
+/// - `Expr::Negative`, emitted as `negate`
+///
+/// Any other expression yields a not-implemented error.
+pub fn from_unary_expr(
+    producer: &mut impl SubstraitProducer,
+    expr: &Expr,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    let (fn_name, operand) = match expr {
+        Expr::Not(inner) => ("not", inner),
+        Expr::IsNull(inner) => ("is_null", inner),
+        Expr::IsNotNull(inner) => ("is_not_null", inner),
+        Expr::IsTrue(inner) => ("is_true", inner),
+        Expr::IsFalse(inner) => ("is_false", inner),
+        Expr::IsUnknown(inner) => ("is_unknown", inner),
+        Expr::IsNotTrue(inner) => ("is_not_true", inner),
+        Expr::IsNotFalse(inner) => ("is_not_false", inner),
+        Expr::IsNotUnknown(inner) => ("is_not_unknown", inner),
+        Expr::Negative(inner) => ("negate", inner),
+        expr => return not_impl_err!("Unsupported expression: {expr:?}"),
+    };
+    to_substrait_unary_scalar_fn(producer, fn_name, operand, schema)
+}
+
fn to_substrait_type(dt: &DataType, nullable: bool) ->
Result<substrait::proto::Type> {
let nullability = if nullable {
r#type::Nullability::Nullable as i32
@@ -1700,7 +2000,6 @@ fn to_substrait_type(dt: &DataType, nullable: bool) ->
Result<substrait::proto::
}
}
-#[allow(deprecated)]
fn make_substrait_window_function(
function_reference: u32,
arguments: Vec<FunctionArgument>,
@@ -1709,6 +2008,7 @@ fn make_substrait_window_function(
bounds: (Bound, Bound),
bounds_type: BoundsType,
) -> Expression {
+ #[allow(deprecated)]
Expression {
rex_type: Some(RexType::WindowFunction(SubstraitWindowFunction {
function_reference,
@@ -1727,29 +2027,25 @@ fn make_substrait_window_function(
}
}
-#[allow(deprecated)]
-#[allow(clippy::too_many_arguments)]
fn make_substrait_like_expr(
- state: &dyn SubstraitPlanningState,
+ producer: &mut impl SubstraitProducer,
ignore_case: bool,
negated: bool,
expr: &Expr,
pattern: &Expr,
escape_char: Option<char>,
schema: &DFSchemaRef,
- col_ref_offset: usize,
- extensions: &mut Extensions,
) -> Result<Expression> {
let function_anchor = if ignore_case {
- extensions.register_function("ilike".to_string())
+ producer.register_function("ilike".to_string())
} else {
- extensions.register_function("like".to_string())
+ producer.register_function("like".to_string())
};
- let expr = to_substrait_rex(state, expr, schema, col_ref_offset,
extensions)?;
- let pattern = to_substrait_rex(state, pattern, schema, col_ref_offset,
extensions)?;
+ let expr = producer.handle_expr(expr, schema)?;
+ let pattern = producer.handle_expr(pattern, schema)?;
let escape_char = to_substrait_literal_expr(
+ producer,
&ScalarValue::Utf8(escape_char.map(|c| c.to_string())),
- extensions,
)?;
let arguments = vec![
FunctionArgument {
@@ -1763,6 +2059,7 @@ fn make_substrait_like_expr(
},
];
+ #[allow(deprecated)]
let substrait_like = Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
@@ -1774,8 +2071,9 @@ fn make_substrait_like_expr(
};
if negated {
- let function_anchor = extensions.register_function("not".to_string());
+ let function_anchor = producer.register_function("not".to_string());
+ #[allow(deprecated)]
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
@@ -1847,8 +2145,8 @@ fn to_substrait_bounds(window_frame: &WindowFrame) ->
Result<(Bound, Bound)> {
}
fn to_substrait_literal(
+ producer: &mut impl SubstraitProducer,
value: &ScalarValue,
- extensions: &mut Extensions,
) -> Result<Literal> {
if value.is_null() {
return Ok(Literal {
@@ -2026,11 +2324,11 @@ fn to_substrait_literal(
DECIMAL_128_TYPE_VARIATION_REF,
),
ScalarValue::List(l) => (
- convert_array_to_literal_list(l, extensions)?,
+ convert_array_to_literal_list(producer, l)?,
DEFAULT_CONTAINER_TYPE_VARIATION_REF,
),
ScalarValue::LargeList(l) => (
- convert_array_to_literal_list(l, extensions)?,
+ convert_array_to_literal_list(producer, l)?,
LARGE_CONTAINER_TYPE_VARIATION_REF,
),
ScalarValue::Map(m) => {
@@ -2047,16 +2345,16 @@ fn to_substrait_literal(
let keys = (0..m.keys().len())
.map(|i| {
to_substrait_literal(
+ producer,
&ScalarValue::try_from_array(&m.keys(), i)?,
- extensions,
)
})
.collect::<Result<Vec<_>>>()?;
let values = (0..m.values().len())
.map(|i| {
to_substrait_literal(
+ producer,
&ScalarValue::try_from_array(&m.values(), i)?,
- extensions,
)
})
.collect::<Result<Vec<_>>>()?;
@@ -2082,8 +2380,8 @@ fn to_substrait_literal(
.iter()
.map(|col| {
to_substrait_literal(
+ producer,
&ScalarValue::try_from_array(col, 0)?,
- extensions,
)
})
.collect::<Result<Vec<_>>>()?,
@@ -2104,8 +2402,8 @@ fn to_substrait_literal(
}
fn convert_array_to_literal_list<T: OffsetSizeTrait>(
+ producer: &mut impl SubstraitProducer,
array: &GenericListArray<T>,
- extensions: &mut Extensions,
) -> Result<LiteralType> {
assert_eq!(array.len(), 1);
let nested_array = array.value(0);
@@ -2113,8 +2411,8 @@ fn convert_array_to_literal_list<T: OffsetSizeTrait>(
let values = (0..nested_array.len())
.map(|i| {
to_substrait_literal(
+ producer,
&ScalarValue::try_from_array(&nested_array, i)?,
- extensions,
)
})
.collect::<Result<Vec<_>>>()?;
@@ -2133,10 +2431,10 @@ fn convert_array_to_literal_list<T: OffsetSizeTrait>(
}
fn to_substrait_literal_expr(
+ producer: &mut impl SubstraitProducer,
value: &ScalarValue,
- extensions: &mut Extensions,
) -> Result<Expression> {
- let literal = to_substrait_literal(value, extensions)?;
+ let literal = to_substrait_literal(producer, value)?;
Ok(Expression {
rex_type: Some(RexType::Literal(literal)),
})
@@ -2144,16 +2442,13 @@ fn to_substrait_literal_expr(
/// Util to generate substrait [RexType::ScalarFunction] with one argument
fn to_substrait_unary_scalar_fn(
- state: &dyn SubstraitPlanningState,
+ producer: &mut impl SubstraitProducer,
fn_name: &str,
arg: &Expr,
schema: &DFSchemaRef,
- col_ref_offset: usize,
- extensions: &mut Extensions,
) -> Result<Expression> {
- let function_anchor = extensions.register_function(fn_name.to_string());
- let substrait_expr =
- to_substrait_rex(state, arg, schema, col_ref_offset, extensions)?;
+ let function_anchor = producer.register_function(fn_name.to_string());
+ let substrait_expr = producer.handle_expr(arg, schema)?;
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
@@ -2194,17 +2489,16 @@ fn try_to_substrait_field_reference(
}
fn substrait_sort_field(
- state: &dyn SubstraitPlanningState,
- sort: &Sort,
+ producer: &mut impl SubstraitProducer,
+ sort: &SortExpr,
schema: &DFSchemaRef,
- extensions: &mut Extensions,
) -> Result<SortField> {
- let Sort {
+ let SortExpr {
expr,
asc,
nulls_first,
} = sort;
- let e = to_substrait_rex(state, expr, schema, 0, extensions)?;
+ let e = producer.handle_expr(expr, schema)?;
let d = match (asc, nulls_first) {
(true, true) => SortDirection::AscNullsFirst,
(true, false) => SortDirection::AscNullsLast,
@@ -2380,9 +2674,9 @@ mod test {
fn round_trip_literal(scalar: ScalarValue) -> Result<()> {
println!("Checking round trip of {scalar:?}");
-
- let mut extensions = Extensions::default();
- let substrait_literal = to_substrait_literal(&scalar, &mut
extensions)?;
+ let state = SessionContext::default().state();
+ let mut producer = DefaultSubstraitProducer::new(&state);
+ let substrait_literal = to_substrait_literal(&mut producer, &scalar)?;
let roundtrip_scalar =
from_substrait_literal_without_names(&test_consumer(),
&substrait_literal)?;
assert_eq!(scalar, roundtrip_scalar);
diff --git a/datafusion/substrait/src/logical_plan/state.rs
b/datafusion/substrait/src/logical_plan/state.rs
deleted file mode 100644
index 0bd749c110..0000000000
--- a/datafusion/substrait/src/logical_plan/state.rs
+++ /dev/null
@@ -1,63 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use std::sync::Arc;
-
-use async_trait::async_trait;
-use datafusion::{
- catalog::TableProvider,
- error::{DataFusionError, Result},
- execution::{registry::SerializerRegistry, FunctionRegistry, SessionState},
- sql::TableReference,
-};
-
-/// This trait provides the context needed to transform a substrait plan into a
-/// [`datafusion::logical_expr::LogicalPlan`] (via
[`super::consumer::from_substrait_plan`])
-/// and back again into a substrait plan (via
[`super::producer::to_substrait_plan`]).
-///
-/// The context is declared as a trait to decouple the substrait plan encoder /
-/// decoder from the [`SessionState`], potentially allowing users to define
-/// their own slimmer context just for serializing and deserializing substrait.
-///
-/// [`SessionState`] implements this trait.
-#[async_trait]
-pub trait SubstraitPlanningState: Sync + Send + FunctionRegistry {
- /// Return [SerializerRegistry] for extensions
- fn serializer_registry(&self) -> &Arc<dyn SerializerRegistry>;
-
- async fn table(
- &self,
- reference: &TableReference,
- ) -> Result<Option<Arc<dyn TableProvider>>>;
-}
-
-#[async_trait]
-impl SubstraitPlanningState for SessionState {
- fn serializer_registry(&self) -> &Arc<dyn SerializerRegistry> {
- self.serializer_registry()
- }
-
- async fn table(
- &self,
- reference: &TableReference,
- ) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
- let table = reference.table().to_string();
- let schema = self.schema_for_ref(reference.clone())?;
- let table_provider = schema.table(&table).await?;
- Ok(table_provider)
- }
-}
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 383fe44be5..7045729493 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -571,6 +571,21 @@ async fn roundtrip_self_implicit_cross_join() ->
Result<()> {
roundtrip("SELECT left.a left_a, left.b, right.a right_a, right.c FROM
data AS left, data AS right").await
}
+/// A self join must come back with each input wrapped in its own
+/// `SubqueryAlias` (`left` / `right`), with the projection and join keys
+/// rewritten to those aliases.
+///
+/// The expected plan uses DataFusion's indented display: two extra spaces per
+/// nesting level.
+#[tokio::test]
+async fn self_join_introduces_aliases() -> Result<()> {
+    assert_expected_plan(
+        "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b",
+        "Projection: left.b, right.c\
+        \n  Inner Join: left.b = right.b\
+        \n    SubqueryAlias: left\
+        \n      TableScan: data projection=[b]\
+        \n    SubqueryAlias: right\
+        \n      TableScan: data projection=[b, c]",
+        false,
+    )
+    .await
+}
+
#[tokio::test]
async fn roundtrip_arithmetic_ops() -> Result<()> {
roundtrip("SELECT a - a FROM data").await?;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]