This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new b99400e142 feat(substrait): modular substrait consumer (#13803)
b99400e142 is described below
commit b99400e142a92c6e580fc5364196294c1eb1c91b
Author: Victor Barua <[email protected]>
AuthorDate: Sat Dec 21 06:20:25 2024 -0800
feat(substrait): modular substrait consumer (#13803)
* feat(substrait): modular substrait consumer
* feat(substrait): include Extension Rel handlers in default consumer
Include SerializerRegistry based handlers for Extension Relations in the
DefaultSubstraitConsumer
* refactor(substrait): _selection -> _field_reference
* refactor(substrait): remove SubstraitPlanningState usage from consumer
* refactor: get_state() -> get_function_registry()
* docs: elide imports from example
* test: simplify test
* refactor: remove Arc from DefaultSubstraitConsumer
* doc: add ticket for API improvements
* doc: link DefaultSubstraitConsumer to from_substrait_plan
* refactor: remove redundant Extensions parsing
---
datafusion/substrait/src/logical_plan/consumer.rs | 2373 ++++++++++++--------
datafusion/substrait/src/logical_plan/producer.rs | 22 +-
.../tests/cases/roundtrip_logical_plan.rs | 3 +-
datafusion/substrait/tests/utils.rs | 39 +-
4 files changed, 1442 insertions(+), 995 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index 9f98fdace6..9aa3f00804 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -21,19 +21,19 @@ use datafusion::arrow::array::{GenericListArray, MapArray};
use datafusion::arrow::datatypes::{
DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
};
-use datafusion::common::plan_err;
use datafusion::common::{
- not_impl_err, plan_datafusion_err, substrait_datafusion_err,
substrait_err, DFSchema,
- DFSchemaRef,
+ not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err,
+ substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
};
use datafusion::datasource::provider_as_source;
use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};
use datafusion::logical_expr::{
- Aggregate, BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable,
LogicalPlan,
- Operator, Projection, SortExpr, TryCast, Values,
+ Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable,
Extension,
+ LogicalPlan, Operator, Projection, SortExpr, Subquery, TryCast, Values,
};
use substrait::proto::aggregate_rel::Grouping;
+use substrait::proto::expression as substrait_expression;
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
use substrait::proto::expression_reference::ExprType;
use url::Url;
@@ -53,14 +53,17 @@ use crate::variation_const::{
TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF,
TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF,
};
+use async_trait::async_trait;
use datafusion::arrow::array::{new_empty_array, AsArray};
use datafusion::arrow::temporal_conversions::NANOSECONDS;
+use datafusion::catalog::TableProvider;
use datafusion::common::scalar::ScalarStructBuilder;
+use datafusion::execution::{FunctionRegistry, SessionState};
use datafusion::logical_expr::builder::project;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{
- col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder,
Partitioning,
- Repartition, Subquery, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
+ col, expr, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
Repartition,
+ WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion::prelude::{lit, JoinType};
use datafusion::sql::TableReference;
@@ -70,17 +73,21 @@ use datafusion::{
};
use std::collections::HashSet;
use std::sync::Arc;
+use substrait::proto;
use substrait::proto::exchange_rel::ExchangeKind;
use substrait::proto::expression::cast::FailureBehavior::ReturnNull;
use substrait::proto::expression::literal::user_defined::Val;
use substrait::proto::expression::literal::{
interval_day_to_second, IntervalCompound, IntervalDayToSecond,
IntervalYearToMonth,
- UserDefined,
};
use substrait::proto::expression::subquery::SubqueryType;
-use substrait::proto::expression::{FieldReference, Literal, ScalarFunction};
+use substrait::proto::expression::{
+ Enum, FieldReference, IfThen, Literal, MultiOrList, Nested, ScalarFunction,
+ SingularOrList, SwitchExpression, WindowFunction,
+};
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
use substrait::proto::rel_common::{Emit, EmitKind};
+use substrait::proto::set_rel::SetOp;
use substrait::proto::{
aggregate_function::AggregationInvocation,
expression::{
@@ -90,17 +97,469 @@ use substrait::proto::{
window_function::bound::Kind as BoundKind, window_function::Bound,
window_function::BoundsType, MaskExpression, RexType,
},
+ fetch_rel,
function_argument::ArgType,
join_rel, plan_rel, r#type,
read_rel::ReadType,
rel::RelType,
- rel_common, set_rel,
+ rel_common,
sort_field::{SortDirection, SortKind::*},
- AggregateFunction, Expression, NamedStruct, Plan, Rel, RelCommon, Type,
+ AggregateFunction, AggregateRel, ConsistentPartitionWindowRel, CrossRel,
ExchangeRel,
+ Expression, ExtendedExpression, ExtensionLeafRel, ExtensionMultiRel,
+ ExtensionSingleRel, FetchRel, FilterRel, FunctionArgument, JoinRel,
NamedStruct,
+ Plan, ProjectRel, ReadRel, Rel, RelCommon, SetRel, SortField, SortRel,
Type,
};
-use substrait::proto::{fetch_rel, ExtendedExpression, FunctionArgument,
SortField};
-use super::state::SubstraitPlanningState;
+#[async_trait]
+/// This trait is used to consume Substrait plans, converting them into
DataFusion Logical Plans.
+/// It can be implemented by users to allow for custom handling of relations,
expressions, etc.
+///
+/// # Example Usage
+///
+/// ```
+/// # use async_trait::async_trait;
+/// # use datafusion::catalog::TableProvider;
+/// # use datafusion::common::{not_impl_err, substrait_err, DFSchema,
ScalarValue, TableReference};
+/// # use datafusion::error::Result;
+/// # use datafusion::execution::{FunctionRegistry, SessionState};
+/// # use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
+/// # use std::sync::Arc;
+/// # use substrait::proto;
+/// # use substrait::proto::{ExtensionLeafRel, FilterRel, ProjectRel};
+/// # use datafusion::arrow::datatypes::DataType;
+/// # use datafusion::logical_expr::expr::ScalarFunction;
+/// # use datafusion_substrait::extensions::Extensions;
+/// # use datafusion_substrait::logical_plan::consumer::{
+/// # from_project_rel, from_substrait_rel, from_substrait_rex,
SubstraitConsumer
+/// # };
+///
+/// struct CustomSubstraitConsumer {
+/// extensions: Arc<Extensions>,
+/// state: Arc<SessionState>,
+/// }
+///
+/// #[async_trait]
+/// impl SubstraitConsumer for CustomSubstraitConsumer {
+/// async fn resolve_table_ref(
+/// &self,
+/// table_ref: &TableReference,
+/// ) -> Result<Option<Arc<dyn TableProvider>>> {
+/// let table = table_ref.table().to_string();
+/// let schema = self.state.schema_for_ref(table_ref.clone())?;
+/// let table_provider = schema.table(&table).await?;
+/// Ok(table_provider)
+/// }
+///
+/// fn get_extensions(&self) -> &Extensions {
+/// self.extensions.as_ref()
+/// }
+///
+/// fn get_function_registry(&self) -> &impl FunctionRegistry {
+/// self.state.as_ref()
+/// }
+///
+/// // You can reuse existing consumer code to assist in handling advanced
extensions
+/// async fn consume_project(&self, rel: &ProjectRel) ->
Result<LogicalPlan> {
+/// let df_plan = from_project_rel(self, rel).await?;
+/// if let Some(advanced_extension) = rel.advanced_extension.as_ref() {
+/// not_impl_err!(
+/// "decode and handle an advanced extension: {:?}",
+/// advanced_extension
+/// )
+/// } else {
+/// Ok(df_plan)
+/// }
+/// }
+///
+/// // You can implement a fully custom consumer method if you need
special handling
+/// async fn consume_filter(&self, rel: &FilterRel) -> Result<LogicalPlan>
{
+/// let input = from_substrait_rel(self,
rel.input.as_ref().unwrap()).await?;
+/// let expression =
+/// from_substrait_rex(self, rel.condition.as_ref().unwrap(),
input.schema())
+/// .await?;
+/// // though this one is quite boring
+/// LogicalPlanBuilder::from(input).filter(expression)?.build()
+/// }
+///
+/// // You can add handlers for extension relations
+/// async fn consume_extension_leaf(
+/// &self,
+/// rel: &ExtensionLeafRel,
+/// ) -> Result<LogicalPlan> {
+/// not_impl_err!(
+/// "handle protobuf Any {} as you need",
+/// rel.detail.as_ref().unwrap().type_url
+/// )
+/// }
+///
+/// // and handlers for user-defined types
+/// fn consume_user_defined_type(&self, typ: &proto::r#type::UserDefined)
-> Result<DataType> {
+/// let type_string =
self.extensions.types.get(&typ.type_reference).unwrap();
+/// match type_string.as_str() {
+/// "u!foo" => not_impl_err!("handle foo conversion"),
+/// "u!bar" => not_impl_err!("handle bar conversion"),
+/// _ => substrait_err!("unexpected type")
+/// }
+/// }
+///
+/// // and user-defined literals
+/// fn consume_user_defined_literal(&self, literal:
&proto::expression::literal::UserDefined) -> Result<ScalarValue> {
+/// let type_string =
self.extensions.types.get(&literal.type_reference).unwrap();
+/// match type_string.as_str() {
+/// "u!foo" => not_impl_err!("handle foo conversion"),
+/// "u!bar" => not_impl_err!("handle bar conversion"),
+/// _ => substrait_err!("unexpected type")
+/// }
+/// }
+/// }
+/// ```
+///
+pub trait SubstraitConsumer: Send + Sync + Sized {
+ async fn resolve_table_ref(
+ &self,
+ table_ref: &TableReference,
+ ) -> Result<Option<Arc<dyn TableProvider>>>;
+
+ // TODO: Remove these two methods
+ // Ideally, the abstract consumer should not place any constraints on
implementations.
+ // The functionality for which the Extensions and FunctionRegistry is
needed should be abstracted
+ // out into methods on the trait. As an example, resolve_table_ref
is such a method.
+ // See: https://github.com/apache/datafusion/issues/13863
+ fn get_extensions(&self) -> &Extensions;
+ fn get_function_registry(&self) -> &impl FunctionRegistry;
+
+ // Relation Methods
+ // There is one method per Substrait relation to allow for easy overriding
of consumer behaviour.
+ // These methods have default implementations calling the common handler
code, to allow for users
+ // to re-use common handling logic.
+
+ async fn consume_read(&self, rel: &ReadRel) -> Result<LogicalPlan> {
+ from_read_rel(self, rel).await
+ }
+
+ async fn consume_filter(&self, rel: &FilterRel) -> Result<LogicalPlan> {
+ from_filter_rel(self, rel).await
+ }
+
+ async fn consume_fetch(&self, rel: &FetchRel) -> Result<LogicalPlan> {
+ from_fetch_rel(self, rel).await
+ }
+
+ async fn consume_aggregate(&self, rel: &AggregateRel) ->
Result<LogicalPlan> {
+ from_aggregate_rel(self, rel).await
+ }
+
+ async fn consume_sort(&self, rel: &SortRel) -> Result<LogicalPlan> {
+ from_sort_rel(self, rel).await
+ }
+
+ async fn consume_join(&self, rel: &JoinRel) -> Result<LogicalPlan> {
+ from_join_rel(self, rel).await
+ }
+
+ async fn consume_project(&self, rel: &ProjectRel) -> Result<LogicalPlan> {
+ from_project_rel(self, rel).await
+ }
+
+ async fn consume_set(&self, rel: &SetRel) -> Result<LogicalPlan> {
+ from_set_rel(self, rel).await
+ }
+
+ async fn consume_cross(&self, rel: &CrossRel) -> Result<LogicalPlan> {
+ from_cross_rel(self, rel).await
+ }
+
+ async fn consume_consistent_partition_window(
+ &self,
+ _rel: &ConsistentPartitionWindowRel,
+ ) -> Result<LogicalPlan> {
+ not_impl_err!("Consistent Partition Window Rel not supported")
+ }
+
+ async fn consume_exchange(&self, rel: &ExchangeRel) -> Result<LogicalPlan>
{
+ from_exchange_rel(self, rel).await
+ }
+
+ // Expression Methods
+ // There is one method per Substrait expression to allow for easy
overriding of consumer behaviour
+ // These methods have default implementations calling the common handler
code, to allow for users
+ // to re-use common handling logic.
+
+ async fn consume_literal(&self, expr: &Literal) -> Result<Expr> {
+ from_literal(self, expr).await
+ }
+
+ async fn consume_field_reference(
+ &self,
+ expr: &FieldReference,
+ input_schema: &DFSchema,
+ ) -> Result<Expr> {
+ from_field_reference(self, expr, input_schema).await
+ }
+
+ async fn consume_scalar_function(
+ &self,
+ expr: &ScalarFunction,
+ input_schema: &DFSchema,
+ ) -> Result<Expr> {
+ from_scalar_function(self, expr, input_schema).await
+ }
+
+ async fn consume_window_function(
+ &self,
+ expr: &WindowFunction,
+ input_schema: &DFSchema,
+ ) -> Result<Expr> {
+ from_window_function(self, expr, input_schema).await
+ }
+
+ async fn consume_if_then(
+ &self,
+ expr: &IfThen,
+ input_schema: &DFSchema,
+ ) -> Result<Expr> {
+ from_if_then(self, expr, input_schema).await
+ }
+
+ async fn consume_switch(
+ &self,
+ _expr: &SwitchExpression,
+ _input_schema: &DFSchema,
+ ) -> Result<Expr> {
+ not_impl_err!("Switch expression not supported")
+ }
+
+ async fn consume_singular_or_list(
+ &self,
+ expr: &SingularOrList,
+ input_schema: &DFSchema,
+ ) -> Result<Expr> {
+ from_singular_or_list(self, expr, input_schema).await
+ }
+
+ async fn consume_multi_or_list(
+ &self,
+ _expr: &MultiOrList,
+ _input_schema: &DFSchema,
+ ) -> Result<Expr> {
+ not_impl_err!("Multi Or List expression not supported")
+ }
+
+ async fn consume_cast(
+ &self,
+ expr: &substrait_expression::Cast,
+ input_schema: &DFSchema,
+ ) -> Result<Expr> {
+ from_cast(self, expr, input_schema).await
+ }
+
+ async fn consume_subquery(
+ &self,
+ expr: &substrait_expression::Subquery,
+ input_schema: &DFSchema,
+ ) -> Result<Expr> {
+ from_subquery(self, expr, input_schema).await
+ }
+
+ async fn consume_nested(
+ &self,
+ _expr: &Nested,
+ _input_schema: &DFSchema,
+ ) -> Result<Expr> {
+ not_impl_err!("Nested expression not supported")
+ }
+
+ async fn consume_enum(&self, _expr: &Enum, _input_schema: &DFSchema) ->
Result<Expr> {
+ not_impl_err!("Enum expression not supported")
+ }
+
+ // User-Defined Functionality
+
+ // The details of extension relations, and how to handle them, are fully
up to users to specify.
+ // The following methods allow users to customize the consumer behaviour
+
+ async fn consume_extension_leaf(
+ &self,
+ rel: &ExtensionLeafRel,
+ ) -> Result<LogicalPlan> {
+ if let Some(detail) = rel.detail.as_ref() {
+ return substrait_err!(
+ "Missing handler for ExtensionLeafRel: {}",
+ detail.type_url
+ );
+ }
+ substrait_err!("Missing handler for ExtensionLeafRel")
+ }
+
+ async fn consume_extension_single(
+ &self,
+ rel: &ExtensionSingleRel,
+ ) -> Result<LogicalPlan> {
+ if let Some(detail) = rel.detail.as_ref() {
+ return substrait_err!(
+ "Missing handler for ExtensionSingleRel: {}",
+ detail.type_url
+ );
+ }
+ substrait_err!("Missing handler for ExtensionSingleRel")
+ }
+
+ async fn consume_extension_multi(
+ &self,
+ rel: &ExtensionMultiRel,
+ ) -> Result<LogicalPlan> {
+ if let Some(detail) = rel.detail.as_ref() {
+ return substrait_err!(
+ "Missing handler for ExtensionMultiRel: {}",
+ detail.type_url
+ );
+ }
+ substrait_err!("Missing handler for ExtensionMultiRel")
+ }
+
+ // Users can bring their own types to Substrait which require custom
handling
+
+ fn consume_user_defined_type(
+ &self,
+ user_defined_type: &r#type::UserDefined,
+ ) -> Result<DataType> {
+ substrait_err!(
+ "Missing handler for user-defined type: {}",
+ user_defined_type.type_reference
+ )
+ }
+
+ fn consume_user_defined_literal(
+ &self,
+ user_defined_literal: &proto::expression::literal::UserDefined,
+ ) -> Result<ScalarValue> {
+ substrait_err!(
+ "Missing handler for user-defined literals {}",
+ user_defined_literal.type_reference
+ )
+ }
+}
+
+/// Convert Substrait Rel to DataFusion LogicalPlan
+#[async_recursion]
+pub async fn from_substrait_rel(
+ consumer: &impl SubstraitConsumer,
+ relation: &Rel,
+) -> Result<LogicalPlan> {
+ let plan: Result<LogicalPlan> = match &relation.rel_type {
+ Some(rel_type) => match rel_type {
+ RelType::Read(rel) => consumer.consume_read(rel).await,
+ RelType::Filter(rel) => consumer.consume_filter(rel).await,
+ RelType::Fetch(rel) => consumer.consume_fetch(rel).await,
+ RelType::Aggregate(rel) => consumer.consume_aggregate(rel).await,
+ RelType::Sort(rel) => consumer.consume_sort(rel).await,
+ RelType::Join(rel) => consumer.consume_join(rel).await,
+ RelType::Project(rel) => consumer.consume_project(rel).await,
+ RelType::Set(rel) => consumer.consume_set(rel).await,
+ RelType::ExtensionSingle(rel) =>
consumer.consume_extension_single(rel).await,
+ RelType::ExtensionMulti(rel) =>
consumer.consume_extension_multi(rel).await,
+ RelType::ExtensionLeaf(rel) =>
consumer.consume_extension_leaf(rel).await,
+ RelType::Cross(rel) => consumer.consume_cross(rel).await,
+ RelType::Window(rel) => {
+ consumer.consume_consistent_partition_window(rel).await
+ }
+ RelType::Exchange(rel) => consumer.consume_exchange(rel).await,
+ rt => not_impl_err!("{rt:?} rel not supported yet"),
+ },
+ None => return substrait_err!("rel must set rel_type"),
+ };
+ apply_emit_kind(retrieve_rel_common(relation), plan?)
+}
+
+/// Default SubstraitConsumer for converting standard Substrait without
user-defined extensions.
+///
+/// Used as the consumer in [from_substrait_plan]
+pub struct DefaultSubstraitConsumer<'a> {
+ extensions: &'a Extensions,
+ state: &'a SessionState,
+}
+
+impl<'a> DefaultSubstraitConsumer<'a> {
+ pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self {
+ DefaultSubstraitConsumer { extensions, state }
+ }
+}
+
+#[async_trait]
+impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
+ async fn resolve_table_ref(
+ &self,
+ table_ref: &TableReference,
+ ) -> Result<Option<Arc<dyn TableProvider>>> {
+ let table = table_ref.table().to_string();
+ let schema = self.state.schema_for_ref(table_ref.clone())?;
+ let table_provider = schema.table(&table).await?;
+ Ok(table_provider)
+ }
+
+ fn get_extensions(&self) -> &Extensions {
+ self.extensions
+ }
+
+ fn get_function_registry(&self) -> &impl FunctionRegistry {
+ self.state
+ }
+
+ async fn consume_extension_leaf(
+ &self,
+ rel: &ExtensionLeafRel,
+ ) -> Result<LogicalPlan> {
+ let Some(ext_detail) = &rel.detail else {
+ return substrait_err!("Unexpected empty detail in
ExtensionLeafRel");
+ };
+ let plan = self
+ .state
+ .serializer_registry()
+ .deserialize_logical_plan(&ext_detail.type_url,
&ext_detail.value)?;
+ Ok(LogicalPlan::Extension(Extension { node: plan }))
+ }
+
+ async fn consume_extension_single(
+ &self,
+ rel: &ExtensionSingleRel,
+ ) -> Result<LogicalPlan> {
+ let Some(ext_detail) = &rel.detail else {
+ return substrait_err!("Unexpected empty detail in
ExtensionSingleRel");
+ };
+ let plan = self
+ .state
+ .serializer_registry()
+ .deserialize_logical_plan(&ext_detail.type_url,
&ext_detail.value)?;
+ let Some(input_rel) = &rel.input else {
+ return substrait_err!(
+ "ExtensionSingleRel missing input rel, try using
ExtensionLeafRel instead"
+ );
+ };
+ let input_plan = from_substrait_rel(self, input_rel).await?;
+ let plan = plan.with_exprs_and_inputs(plan.expressions(),
vec![input_plan])?;
+ Ok(LogicalPlan::Extension(Extension { node: plan }))
+ }
+
+ async fn consume_extension_multi(
+ &self,
+ rel: &ExtensionMultiRel,
+ ) -> Result<LogicalPlan> {
+ let Some(ext_detail) = &rel.detail else {
+ return substrait_err!("Unexpected empty detail in
ExtensionMultiRel");
+ };
+ let plan = self
+ .state
+ .serializer_registry()
+ .deserialize_logical_plan(&ext_detail.type_url,
&ext_detail.value)?;
+ let mut inputs = Vec::with_capacity(rel.inputs.len());
+ for input in &rel.inputs {
+ let input_plan = from_substrait_rel(self, input).await?;
+ inputs.push(input_plan);
+ }
+ let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
+ Ok(LogicalPlan::Extension(Extension { node: plan }))
+ }
+}
// Substrait PrecisionTimestampTz indicates that the timestamp is relative to
UTC, which
// is the same as the expectation for any non-empty timezone in DF, so any
non-empty timezone
@@ -202,16 +661,15 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality(
}
async fn union_rels(
+ consumer: &impl SubstraitConsumer,
rels: &[Rel],
- state: &dyn SubstraitPlanningState,
- extensions: &Extensions,
is_all: bool,
) -> Result<LogicalPlan> {
let mut union_builder = Ok(LogicalPlanBuilder::from(
- from_substrait_rel(state, &rels[0], extensions).await?,
+ from_substrait_rel(consumer, &rels[0]).await?,
));
for input in &rels[1..] {
- let rel_plan = from_substrait_rel(state, input, extensions).await?;
+ let rel_plan = from_substrait_rel(consumer, input).await?;
union_builder = if is_all {
union_builder?.union(rel_plan)
@@ -223,17 +681,16 @@ async fn union_rels(
}
async fn intersect_rels(
+ consumer: &impl SubstraitConsumer,
rels: &[Rel],
- state: &dyn SubstraitPlanningState,
- extensions: &Extensions,
is_all: bool,
) -> Result<LogicalPlan> {
- let mut rel = from_substrait_rel(state, &rels[0], extensions).await?;
+ let mut rel = from_substrait_rel(consumer, &rels[0]).await?;
for input in &rels[1..] {
rel = LogicalPlanBuilder::intersect(
rel,
- from_substrait_rel(state, input, extensions).await?,
+ from_substrait_rel(consumer, input).await?,
is_all,
)?
}
@@ -242,17 +699,16 @@ async fn intersect_rels(
}
async fn except_rels(
+ consumer: &impl SubstraitConsumer,
rels: &[Rel],
- state: &dyn SubstraitPlanningState,
- extensions: &Extensions,
is_all: bool,
) -> Result<LogicalPlan> {
- let mut rel = from_substrait_rel(state, &rels[0], extensions).await?;
+ let mut rel = from_substrait_rel(consumer, &rels[0]).await?;
for input in &rels[1..] {
rel = LogicalPlanBuilder::except(
rel,
- from_substrait_rel(state, input, extensions).await?,
+ from_substrait_rel(consumer, input).await?,
is_all,
)?
}
@@ -262,7 +718,7 @@ async fn except_rels(
/// Convert Substrait Plan to DataFusion LogicalPlan
pub async fn from_substrait_plan(
- state: &dyn SubstraitPlanningState,
+ state: &SessionState,
plan: &Plan,
) -> Result<LogicalPlan> {
// Register function extension
@@ -271,16 +727,27 @@ pub async fn from_substrait_plan(
return not_impl_err!("Type variation extensions are not supported");
}
- // Parse relations
+ let consumer = DefaultSubstraitConsumer {
+ extensions: &extensions,
+ state,
+ };
+ from_substrait_plan_with_consumer(&consumer, plan).await
+}
+
+/// Convert Substrait Plan to DataFusion LogicalPlan using the given consumer
+pub async fn from_substrait_plan_with_consumer(
+ consumer: &impl SubstraitConsumer,
+ plan: &Plan,
+) -> Result<LogicalPlan> {
match plan.relations.len() {
1 => {
match plan.relations[0].rel_type.as_ref() {
Some(rt) => match rt {
plan_rel::RelType::Rel(rel) => {
- Ok(from_substrait_rel(state, rel, &extensions).await?)
+ Ok(from_substrait_rel(consumer, rel).await?)
},
plan_rel::RelType::Root(root) => {
- let plan = from_substrait_rel(state,
root.input.as_ref().unwrap(), &extensions).await?;
+ let plan = from_substrait_rel(consumer,
root.input.as_ref().unwrap()).await?;
if root.names.is_empty() {
// Backwards compatibility for plans missing names
return Ok(plan);
@@ -341,7 +808,7 @@ pub struct ExprContainer {
/// between systems. This is often useful for scenarios like pushdown where
filter
/// expressions need to be sent to remote systems.
pub async fn from_substrait_extended_expr(
- state: &dyn SubstraitPlanningState,
+ state: &SessionState,
extended_expr: &ExtendedExpression,
) -> Result<ExprContainer> {
// Register function extension
@@ -350,8 +817,13 @@ pub async fn from_substrait_extended_expr(
return not_impl_err!("Type variation extensions are not supported");
}
+ let consumer = DefaultSubstraitConsumer {
+ extensions: &extensions,
+ state,
+ };
+
let input_schema = DFSchemaRef::new(match &extended_expr.base_schema {
- Some(base_schema) => from_substrait_named_struct(base_schema,
&extensions),
+ Some(base_schema) => from_substrait_named_struct(&consumer,
base_schema),
None => {
plan_err!("required property `base_schema` missing from Substrait
ExtendedExpression message")
}
@@ -369,8 +841,7 @@ pub async fn from_substrait_extended_expr(
plan_err!("required property `expr_type` missing from
Substrait ExpressionReference message")
}
}?;
- let expr =
- from_substrait_rex(state, scalar_expr, &input_schema,
&extensions).await?;
+ let expr = from_substrait_rex(&consumer, scalar_expr,
&input_schema).await?;
let (output_type, expected_nullability) =
expr.data_type_and_nullable(&input_schema)?;
let output_field = Field::new("", output_type, expected_nullability);
@@ -557,583 +1028,498 @@ fn make_renamed_schema(
)
}
-/// Convert Substrait Rel to DataFusion DataFrame
-#[allow(deprecated)]
#[async_recursion]
-pub async fn from_substrait_rel(
- state: &dyn SubstraitPlanningState,
- rel: &Rel,
- extensions: &Extensions,
+pub async fn from_project_rel(
+ consumer: &impl SubstraitConsumer,
+ p: &ProjectRel,
) -> Result<LogicalPlan> {
- let plan: Result<LogicalPlan> = match &rel.rel_type {
- Some(RelType::Project(p)) => {
- if let Some(input) = p.input.as_ref() {
- let mut input = LogicalPlanBuilder::from(
- from_substrait_rel(state, input, extensions).await?,
- );
- let original_schema = input.schema().clone();
-
- // Ensure that all expressions have a unique display name, so
that
- // validate_unique_names does not fail when constructing the
project.
- let mut name_tracker = NameTracker::new();
-
- // By default, a Substrait Project emits all inputs fields
followed by all expressions.
- // We build the explicit expressions first, and then the input
expressions to avoid
- // adding aliases to the explicit expressions (as part of
ensuring unique names).
- //
- // This is helpful for plan visualization and tests, because
when DataFusion produces
- // Substrait Projects it adds an output mapping that excludes
all input columns
- // leaving only explicit expressions.
-
- let mut explicit_exprs: Vec<Expr> = vec![];
- for expr in &p.expressions {
- let e = from_substrait_rex(
- state,
- expr,
- input.clone().schema(),
- extensions,
- )
- .await?;
- // if the expression is WindowFunction, wrap in a Window
relation
- if let Expr::WindowFunction(_) = &e {
- // Adding the same expression here and in the project
below
- // works because the project's builder uses
columnize_expr(..)
- // to transform it into a column reference
- input = input.window(vec![e.clone()])?
- }
-
explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?);
- }
+ if let Some(input) = p.input.as_ref() {
+ let mut input =
+ LogicalPlanBuilder::from(from_substrait_rel(consumer,
input).await?);
+ let original_schema = input.schema().clone();
+
+ // Ensure that all expressions have a unique display name, so that
+ // validate_unique_names does not fail when constructing the project.
+ let mut name_tracker = NameTracker::new();
+
+ // By default, a Substrait Project emits all inputs fields followed by
all expressions.
+ // We build the explicit expressions first, and then the input
expressions to avoid
+ // adding aliases to the explicit expressions (as part of ensuring
unique names).
+ //
+ // This is helpful for plan visualization and tests, because when
DataFusion produces
+ // Substrait Projects it adds an output mapping that excludes all
input columns
+ // leaving only explicit expressions.
+
+ let mut explicit_exprs: Vec<Expr> = vec![];
+ for expr in &p.expressions {
+ let e = from_substrait_rex(consumer, expr,
input.clone().schema()).await?;
+ // if the expression is WindowFunction, wrap in a Window relation
+ if let Expr::WindowFunction(_) = &e {
+ // Adding the same expression here and in the project below
+ // works because the project's builder uses columnize_expr(..)
+ // to transform it into a column reference
+ input = input.window(vec![e.clone()])?
+ }
+ explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?);
+ }
- let mut final_exprs: Vec<Expr> = vec![];
- for index in 0..original_schema.fields().len() {
- let e = Expr::Column(Column::from(
- original_schema.qualified_field(index),
- ));
- final_exprs.push(name_tracker.get_uniquely_named_expr(e)?);
- }
- final_exprs.append(&mut explicit_exprs);
+ let mut final_exprs: Vec<Expr> = vec![];
+ for index in 0..original_schema.fields().len() {
+ let e =
Expr::Column(Column::from(original_schema.qualified_field(index)));
+ final_exprs.push(name_tracker.get_uniquely_named_expr(e)?);
+ }
+ final_exprs.append(&mut explicit_exprs);
+ input.project(final_exprs)?.build()
+ } else {
+ not_impl_err!("Projection without an input is not supported")
+ }
+}
- input.project(final_exprs)?.build()
- } else {
- not_impl_err!("Projection without an input is not supported")
- }
+#[async_recursion]
+pub async fn from_filter_rel(
+ consumer: &impl SubstraitConsumer,
+ filter: &FilterRel,
+) -> Result<LogicalPlan> {
+ if let Some(input) = filter.input.as_ref() {
+ let input = LogicalPlanBuilder::from(from_substrait_rel(consumer,
input).await?);
+ if let Some(condition) = filter.condition.as_ref() {
+ let expr = from_substrait_rex(consumer, condition,
input.schema()).await?;
+ input.filter(expr)?.build()
+ } else {
+ not_impl_err!("Filter without an condition is not valid")
}
- Some(RelType::Filter(filter)) => {
- if let Some(input) = filter.input.as_ref() {
- let input = LogicalPlanBuilder::from(
- from_substrait_rel(state, input, extensions).await?,
- );
- if let Some(condition) = filter.condition.as_ref() {
- let expr =
- from_substrait_rex(state, condition, input.schema(),
extensions)
- .await?;
- input.filter(expr)?.build()
- } else {
- not_impl_err!("Filter without an condition is not valid")
- }
- } else {
- not_impl_err!("Filter without an input is not valid")
+ } else {
+ not_impl_err!("Filter without an input is not valid")
+ }
+}
+
+#[async_recursion]
+pub async fn from_fetch_rel(
+ consumer: &impl SubstraitConsumer,
+ fetch: &FetchRel,
+) -> Result<LogicalPlan> {
+ if let Some(input) = fetch.input.as_ref() {
+ let input = LogicalPlanBuilder::from(from_substrait_rel(consumer,
input).await?);
+ let empty_schema = DFSchemaRef::new(DFSchema::empty());
+ let offset = match &fetch.offset_mode {
+ Some(fetch_rel::OffsetMode::Offset(offset)) => Some(lit(*offset)),
+ Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => {
+ Some(from_substrait_rex(consumer, expr, &empty_schema).await?)
}
- }
- Some(RelType::Fetch(fetch)) => {
- if let Some(input) = fetch.input.as_ref() {
- let input = LogicalPlanBuilder::from(
- from_substrait_rel(state, input, extensions).await?,
- );
- let empty_schema = DFSchemaRef::new(DFSchema::empty());
- let offset = match &fetch.offset_mode {
- Some(fetch_rel::OffsetMode::Offset(offset)) =>
Some(lit(*offset)),
- Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => Some(
- from_substrait_rex(state, expr, &empty_schema,
extensions)
- .await?,
- ),
- None => None,
- };
- let count = match &fetch.count_mode {
- Some(fetch_rel::CountMode::Count(count)) => {
- // -1 means that ALL records should be returned,
equivalent to None
- (*count != -1).then(|| lit(*count))
- }
- Some(fetch_rel::CountMode::CountExpr(expr)) => Some(
- from_substrait_rex(state, expr, &empty_schema,
extensions)
- .await?,
- ),
- None => None,
- };
- input.limit_by_expr(offset, count)?.build()
- } else {
- not_impl_err!("Fetch without an input is not valid")
+ None => None,
+ };
+ let count = match &fetch.count_mode {
+ Some(fetch_rel::CountMode::Count(count)) => {
+ // -1 means that ALL records should be returned, equivalent to
None
+ (*count != -1).then(|| lit(*count))
}
- }
- Some(RelType::Sort(sort)) => {
- if let Some(input) = sort.input.as_ref() {
- let input = LogicalPlanBuilder::from(
- from_substrait_rel(state, input, extensions).await?,
- );
- let sorts =
- from_substrait_sorts(state, &sort.sorts, input.schema(),
extensions)
- .await?;
- input.sort(sorts)?.build()
- } else {
- not_impl_err!("Sort without an input is not valid")
+ Some(fetch_rel::CountMode::CountExpr(expr)) => {
+ Some(from_substrait_rex(consumer, expr, &empty_schema).await?)
}
+ None => None,
+ };
+ input.limit_by_expr(offset, count)?.build()
+ } else {
+ not_impl_err!("Fetch without an input is not valid")
+ }
+}
+
+pub async fn from_sort_rel(
+ consumer: &impl SubstraitConsumer,
+ sort: &SortRel,
+) -> Result<LogicalPlan> {
+ if let Some(input) = sort.input.as_ref() {
+ let input = LogicalPlanBuilder::from(from_substrait_rel(consumer,
input).await?);
+ let sorts = from_substrait_sorts(consumer, &sort.sorts,
input.schema()).await?;
+ input.sort(sorts)?.build()
+ } else {
+ not_impl_err!("Sort without an input is not valid")
+ }
+}
+
+pub async fn from_aggregate_rel(
+ consumer: &impl SubstraitConsumer,
+ agg: &AggregateRel,
+) -> Result<LogicalPlan> {
+ if let Some(input) = agg.input.as_ref() {
+ let input = LogicalPlanBuilder::from(from_substrait_rel(consumer,
input).await?);
+ let mut ref_group_exprs = vec![];
+
+ for e in &agg.grouping_expressions {
+ let x = from_substrait_rex(consumer, e, input.schema()).await?;
+ ref_group_exprs.push(x);
}
- Some(RelType::Aggregate(agg)) => {
- if let Some(input) = agg.input.as_ref() {
- let input = LogicalPlanBuilder::from(
- from_substrait_rel(state, input, extensions).await?,
- );
- let mut ref_group_exprs = vec![];
- for e in &agg.grouping_expressions {
- let x =
- from_substrait_rex(state, e, input.schema(),
extensions).await?;
- ref_group_exprs.push(x);
+ let mut group_exprs = vec![];
+ let mut aggr_exprs = vec![];
+
+ match agg.groupings.len() {
+ 1 => {
+ group_exprs.extend_from_slice(
+ &from_substrait_grouping(
+ consumer,
+ &agg.groupings[0],
+ &ref_group_exprs,
+ input.schema(),
+ )
+ .await?,
+ );
+ }
+ _ => {
+ let mut grouping_sets = vec![];
+ for grouping in &agg.groupings {
+ let grouping_set = from_substrait_grouping(
+ consumer,
+ grouping,
+ &ref_group_exprs,
+ input.schema(),
+ )
+ .await?;
+ grouping_sets.push(grouping_set);
}
+ // Single-element grouping expression of type
Expr::GroupingSet.
+ // Note that GroupingSet::Rollup would become
GroupingSet::GroupingSets, when
+ // parsed by the producer and consumer, since Substrait does
not have a type dedicated
+ // to ROLLUP. Only vector of Groupings (grouping sets) is
available.
+ group_exprs
+
.push(Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)));
+ }
+ };
- let mut group_exprs = vec![];
- let mut aggr_exprs = vec![];
-
- match agg.groupings.len() {
- 1 => {
- group_exprs.extend_from_slice(
- &from_substrait_grouping(
- state,
- &agg.groupings[0],
- &ref_group_exprs,
- input.schema(),
- extensions,
- )
- .await?,
- );
- }
- _ => {
- let mut grouping_sets = vec![];
- for grouping in &agg.groupings {
- let grouping_set = from_substrait_grouping(
- state,
- grouping,
- &ref_group_exprs,
- input.schema(),
- extensions,
- )
- .await?;
- grouping_sets.push(grouping_set);
+ for m in &agg.measures {
+ let filter = match &m.filter {
+ Some(fil) => Some(Box::new(
+ from_substrait_rex(consumer, fil, input.schema()).await?,
+ )),
+ None => None,
+ };
+ let agg_func = match &m.measure {
+ Some(f) => {
+ let distinct = match f.invocation {
+ _ if f.invocation == AggregationInvocation::Distinct
as i32 => {
+ true
}
- // Single-element grouping expression of type
Expr::GroupingSet.
- // Note that GroupingSet::Rollup would become
GroupingSet::GroupingSets, when
- // parsed by the producer and consumer, since
Substrait does not have a type dedicated
- // to ROLLUP. Only vector of Groupings (grouping sets)
is available.
-
group_exprs.push(Expr::GroupingSet(GroupingSet::GroupingSets(
- grouping_sets,
- )));
- }
- };
-
- for m in &agg.measures {
- let filter = match &m.filter {
- Some(fil) => Some(Box::new(
- from_substrait_rex(state, fil, input.schema(),
extensions)
- .await?,
- )),
- None => None,
+ _ if f.invocation == AggregationInvocation::All as i32
=> false,
+ _ => false,
};
- let agg_func = match &m.measure {
- Some(f) => {
- let distinct = match f.invocation {
- _ if f.invocation
- == AggregationInvocation::Distinct as i32
=>
- {
- true
- }
- _ if f.invocation
- == AggregationInvocation::All as i32 =>
- {
- false
- }
- _ => false,
- };
- let order_by = if !f.sorts.is_empty() {
- Some(
- from_substrait_sorts(
- state,
- &f.sorts,
- input.schema(),
- extensions,
- )
- .await?,
- )
- } else {
- None
- };
-
- from_substrait_agg_func(
- state,
- f,
- input.schema(),
- extensions,
- filter,
- order_by,
- distinct,
- )
- .await
- }
- None => not_impl_err!(
- "Aggregate without aggregate function is not
supported"
- ),
+ let order_by = if !f.sorts.is_empty() {
+ Some(
+ from_substrait_sorts(consumer, &f.sorts,
input.schema())
+ .await?,
+ )
+ } else {
+ None
};
- aggr_exprs.push(agg_func?.as_ref().clone());
- }
- input.aggregate(group_exprs, aggr_exprs)?.build()
- } else {
- not_impl_err!("Aggregate without an input is not valid")
- }
- }
- Some(RelType::Join(join)) => {
- if join.post_join_filter.is_some() {
- return not_impl_err!(
- "JoinRel with post_join_filter is not yet supported"
- );
- }
- let left: LogicalPlanBuilder = LogicalPlanBuilder::from(
- from_substrait_rel(state, join.left.as_ref().unwrap(),
extensions)
- .await?,
- );
- let right = LogicalPlanBuilder::from(
- from_substrait_rel(state, join.right.as_ref().unwrap(),
extensions)
- .await?,
- );
- let (left, right) = requalify_sides_if_needed(left, right)?;
-
- let join_type = from_substrait_jointype(join.r#type)?;
- // The join condition expression needs full input schema and not
the output schema from join since we lose columns from
- // certain join types such as semi and anti joins
- let in_join_schema = left.schema().join(right.schema())?;
-
- // If join expression exists, parse the `on` condition expression,
build join and return
- // Otherwise, build join with only the filter, without join keys
- match &join.expression.as_ref() {
- Some(expr) => {
- let on = from_substrait_rex(state, expr, &in_join_schema,
extensions)
- .await?;
- // The join expression can contain both equal and
non-equal ops.
- // As of datafusion 31.0.0, the equal and non equal join
conditions are in separate fields.
- // So we extract each part as follows:
- // - If an Eq or IsNotDistinctFrom op is encountered, add
the left column, right column and is_null_equal_nulls to `join_ons` vector
- // - Otherwise we add the expression to join_filter (use
conjunction if filter already exists)
- let (join_ons, nulls_equal_nulls, join_filter) =
-
split_eq_and_noneq_join_predicate_with_nulls_equality(&on);
- let (left_cols, right_cols): (Vec<_>, Vec<_>) =
- itertools::multiunzip(join_ons);
- left.join_detailed(
- right.build()?,
- join_type,
- (left_cols, right_cols),
- join_filter,
- nulls_equal_nulls,
- )?
- .build()
+ from_substrait_agg_func(
+ consumer,
+ f,
+ input.schema(),
+ filter,
+ order_by,
+ distinct,
+ )
+ .await
}
None => {
- let on: Vec<String> = vec![];
- left.join_detailed(
- right.build()?,
- join_type,
- (on.clone(), on),
- None,
- false,
- )?
- .build()
+ not_impl_err!("Aggregate without aggregate function is not
supported")
}
- }
+ };
+ aggr_exprs.push(agg_func?.as_ref().clone());
}
- Some(RelType::Cross(cross)) => {
- let left = LogicalPlanBuilder::from(
- from_substrait_rel(state, cross.left.as_ref().unwrap(),
extensions)
- .await?,
- );
- let right = LogicalPlanBuilder::from(
- from_substrait_rel(state, cross.right.as_ref().unwrap(),
extensions)
- .await?,
- );
- let (left, right) = requalify_sides_if_needed(left, right)?;
- left.cross_join(right.build()?)?.build()
+ input.aggregate(group_exprs, aggr_exprs)?.build()
+ } else {
+ not_impl_err!("Aggregate without an input is not valid")
+ }
+}
+
+pub async fn from_join_rel(
+ consumer: &impl SubstraitConsumer,
+ join: &JoinRel,
+) -> Result<LogicalPlan> {
+ if join.post_join_filter.is_some() {
+ return not_impl_err!("JoinRel with post_join_filter is not yet
supported");
+ }
+
+ let left: LogicalPlanBuilder = LogicalPlanBuilder::from(
+ from_substrait_rel(consumer, join.left.as_ref().unwrap()).await?,
+ );
+ let right = LogicalPlanBuilder::from(
+ from_substrait_rel(consumer, join.right.as_ref().unwrap()).await?,
+ );
+ let (left, right) = requalify_sides_if_needed(left, right)?;
+
+ let join_type = from_substrait_jointype(join.r#type)?;
+ // The join condition expression needs full input schema and not the
output schema from join since we lose columns from
+ // certain join types such as semi and anti joins
+ let in_join_schema = left.schema().join(right.schema())?;
+
+ // If join expression exists, parse the `on` condition expression, build
join and return
+ // Otherwise, build join with only the filter, without join keys
+ match &join.expression.as_ref() {
+ Some(expr) => {
+ let on = from_substrait_rex(consumer, expr,
&in_join_schema).await?;
+ // The join expression can contain both equal and non-equal ops.
+ // As of datafusion 31.0.0, the equal and non equal join
conditions are in separate fields.
+ // So we extract each part as follows:
+ // - If an Eq or IsNotDistinctFrom op is encountered, add the left
column, right column and is_null_equal_nulls to `join_ons` vector
+ // - Otherwise we add the expression to join_filter (use
conjunction if filter already exists)
+ let (join_ons, nulls_equal_nulls, join_filter) =
+ split_eq_and_noneq_join_predicate_with_nulls_equality(&on);
+ let (left_cols, right_cols): (Vec<_>, Vec<_>) =
+ itertools::multiunzip(join_ons);
+ left.join_detailed(
+ right.build()?,
+ join_type,
+ (left_cols, right_cols),
+ join_filter,
+ nulls_equal_nulls,
+ )?
+ .build()
}
- Some(RelType::Read(read)) => {
- async fn read_with_schema(
- state: &dyn SubstraitPlanningState,
- table_ref: TableReference,
- schema: DFSchema,
- projection: &Option<MaskExpression>,
- ) -> Result<LogicalPlan> {
- let schema = schema.replace_qualifier(table_ref.clone());
-
- let plan = {
- let provider = match state.table(&table_ref).await? {
- Some(ref provider) => Arc::clone(provider),
- _ => return plan_err!("No table named '{table_ref}'"),
- };
+ None => {
+ let on: Vec<String> = vec![];
+ left.join_detailed(right.build()?, join_type, (on.clone(), on),
None, false)?
+ .build()
+ }
+ }
+}
- LogicalPlanBuilder::scan(
- table_ref,
- provider_as_source(Arc::clone(&provider)),
- None,
- )?
- .build()?
- };
+pub async fn from_cross_rel(
+ consumer: &impl SubstraitConsumer,
+ cross: &CrossRel,
+) -> Result<LogicalPlan> {
+ let left = LogicalPlanBuilder::from(
+ from_substrait_rel(consumer, cross.left.as_ref().unwrap()).await?,
+ );
+ let right = LogicalPlanBuilder::from(
+ from_substrait_rel(consumer, cross.right.as_ref().unwrap()).await?,
+ );
+ let (left, right) = requalify_sides_if_needed(left, right)?;
+ left.cross_join(right.build()?)?.build()
+}
+
+#[allow(deprecated)]
+pub async fn from_read_rel(
+ consumer: &impl SubstraitConsumer,
+ read: &ReadRel,
+) -> Result<LogicalPlan> {
+ async fn read_with_schema(
+ consumer: &impl SubstraitConsumer,
+ table_ref: TableReference,
+ schema: DFSchema,
+ projection: &Option<MaskExpression>,
+ ) -> Result<LogicalPlan> {
+ let schema = schema.replace_qualifier(table_ref.clone());
+
+ let plan = {
+ let provider = match consumer.resolve_table_ref(&table_ref).await?
{
+ Some(ref provider) => Arc::clone(provider),
+ _ => return plan_err!("No table named '{table_ref}'"),
+ };
- ensure_schema_compatability(plan.schema(), schema.clone())?;
+ LogicalPlanBuilder::scan(
+ table_ref,
+ provider_as_source(Arc::clone(&provider)),
+ None,
+ )?
+ .build()?
+ };
- let schema = apply_masking(schema, projection)?;
+ ensure_schema_compatability(plan.schema(), schema.clone())?;
- apply_projection(plan, schema)
- }
+ let schema = apply_masking(schema, projection)?;
- let named_struct = read.base_schema.as_ref().ok_or_else(|| {
- substrait_datafusion_err!("No base schema provided for Read
Relation")
- })?;
+ apply_projection(plan, schema)
+ }
- let substrait_schema = from_substrait_named_struct(named_struct,
extensions)?;
+ let named_struct = read.base_schema.as_ref().ok_or_else(|| {
+ substrait_datafusion_err!("No base schema provided for Read Relation")
+ })?;
- match &read.as_ref().read_type {
- Some(ReadType::NamedTable(nt)) => {
- let table_reference = match nt.names.len() {
- 0 => {
- return plan_err!("No table name found in
NamedTable");
- }
- 1 => TableReference::Bare {
- table: nt.names[0].clone().into(),
- },
- 2 => TableReference::Partial {
- schema: nt.names[0].clone().into(),
- table: nt.names[1].clone().into(),
- },
- _ => TableReference::Full {
- catalog: nt.names[0].clone().into(),
- schema: nt.names[1].clone().into(),
- table: nt.names[2].clone().into(),
- },
- };
+ let substrait_schema = from_substrait_named_struct(consumer,
named_struct)?;
- read_with_schema(
- state,
- table_reference,
- substrait_schema,
- &read.projection,
- )
- .await
+ match &read.read_type {
+ Some(ReadType::NamedTable(nt)) => {
+ let table_reference = match nt.names.len() {
+ 0 => {
+ return plan_err!("No table name found in NamedTable");
}
- Some(ReadType::VirtualTable(vt)) => {
- if vt.values.is_empty() {
- return Ok(LogicalPlan::EmptyRelation(EmptyRelation {
- produce_one_row: false,
- schema: DFSchemaRef::new(substrait_schema),
- }));
- }
+ 1 => TableReference::Bare {
+ table: nt.names[0].clone().into(),
+ },
+ 2 => TableReference::Partial {
+ schema: nt.names[0].clone().into(),
+ table: nt.names[1].clone().into(),
+ },
+ _ => TableReference::Full {
+ catalog: nt.names[0].clone().into(),
+ schema: nt.names[1].clone().into(),
+ table: nt.names[2].clone().into(),
+ },
+ };
- let values = vt
- .values
- .iter()
- .map(|row| {
- let mut name_idx = 0;
- let lits = row
- .fields
- .iter()
- .map(|lit| {
- name_idx += 1; // top-level names are provided
through schema
- Ok(Expr::Literal(from_substrait_literal(
- lit,
- extensions,
- &named_struct.names,
- &mut name_idx,
- )?))
- })
- .collect::<Result<_>>()?;
- if name_idx != named_struct.names.len() {
- return substrait_err!(
+ read_with_schema(
+ consumer,
+ table_reference,
+ substrait_schema,
+ &read.projection,
+ )
+ .await
+ }
+ Some(ReadType::VirtualTable(vt)) => {
+ if vt.values.is_empty() {
+ return Ok(LogicalPlan::EmptyRelation(EmptyRelation {
+ produce_one_row: false,
+ schema: DFSchemaRef::new(substrait_schema),
+ }));
+ }
+
+ let values = vt
+ .values
+ .iter()
+ .map(|row| {
+ let mut name_idx = 0;
+ let lits = row
+ .fields
+ .iter()
+ .map(|lit| {
+ name_idx += 1; // top-level names are provided
through schema
+ Ok(Expr::Literal(from_substrait_literal(
+ consumer,
+ lit,
+ &named_struct.names,
+ &mut name_idx,
+ )?))
+ })
+ .collect::<Result<_>>()?;
+ if name_idx != named_struct.names.len() {
+ return substrait_err!(
"Names list must match exactly to nested
schema, but found {} uses for {} names",
name_idx,
named_struct.names.len()
);
- }
- Ok(lits)
- })
- .collect::<Result<_>>()?;
+ }
+ Ok(lits)
+ })
+ .collect::<Result<_>>()?;
- Ok(LogicalPlan::Values(Values {
- schema: DFSchemaRef::new(substrait_schema),
- values,
- }))
- }
- Some(ReadType::LocalFiles(lf)) => {
- fn extract_filename(name: &str) -> Option<String> {
- let corrected_url = if name.starts_with("file://")
- && !name.starts_with("file:///")
- {
- name.replacen("file://", "file:///", 1)
- } else {
- name.to_string()
- };
+ Ok(LogicalPlan::Values(Values {
+ schema: DFSchemaRef::new(substrait_schema),
+ values,
+ }))
+ }
+ Some(ReadType::LocalFiles(lf)) => {
+ fn extract_filename(name: &str) -> Option<String> {
+ let corrected_url =
+ if name.starts_with("file://") &&
!name.starts_with("file:///") {
+ name.replacen("file://", "file:///", 1)
+ } else {
+ name.to_string()
+ };
- Url::parse(&corrected_url).ok().and_then(|url| {
- let path = url.path();
- std::path::Path::new(path)
- .file_name()
- .map(|filename|
filename.to_string_lossy().to_string())
- })
- }
+ Url::parse(&corrected_url).ok().and_then(|url| {
+ let path = url.path();
+ std::path::Path::new(path)
+ .file_name()
+ .map(|filename| filename.to_string_lossy().to_string())
+ })
+ }
- // we could use the file name to check the original table
provider
- // TODO: currently does not support multiple local files
- let filename: Option<String> =
- lf.items.first().and_then(|x| match
x.path_type.as_ref() {
- Some(UriFile(name)) => extract_filename(name),
- _ => None,
- });
+ // we could use the file name to check the original table provider
+ // TODO: currently does not support multiple local files
+ let filename: Option<String> =
+ lf.items.first().and_then(|x| match x.path_type.as_ref() {
+ Some(UriFile(name)) => extract_filename(name),
+ _ => None,
+ });
- if lf.items.len() > 1 || filename.is_none() {
- return not_impl_err!("Only single file reads are
supported");
- }
- let name = filename.unwrap();
- // directly use unwrap here since we could determine it is
a valid one
- let table_reference = TableReference::Bare { table:
name.into() };
-
- read_with_schema(
- state,
- table_reference,
- substrait_schema,
- &read.projection,
- )
- .await
- }
- _ => {
- not_impl_err!("Unsupported ReadType: {:?}",
&read.as_ref().read_type)
- }
- }
- }
- Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) {
- Ok(set_op) => {
- if set.inputs.len() < 2 {
- substrait_err!("Set operation requires at least two
inputs")
- } else {
- match set_op {
- set_rel::SetOp::UnionAll => {
- union_rels(&set.inputs, state, extensions,
true).await
- }
- set_rel::SetOp::UnionDistinct => {
- union_rels(&set.inputs, state, extensions,
false).await
- }
- set_rel::SetOp::IntersectionPrimary => {
- LogicalPlanBuilder::intersect(
- from_substrait_rel(state, &set.inputs[0],
extensions)
- .await?,
- union_rels(&set.inputs[1..], state,
extensions, true)
- .await?,
- false,
- )
- }
- set_rel::SetOp::IntersectionMultiset => {
- intersect_rels(&set.inputs, state, extensions,
false).await
- }
- set_rel::SetOp::IntersectionMultisetAll => {
- intersect_rels(&set.inputs, state, extensions,
true).await
- }
- set_rel::SetOp::MinusPrimary => {
- except_rels(&set.inputs, state, extensions,
false).await
- }
- set_rel::SetOp::MinusPrimaryAll => {
- except_rels(&set.inputs, state, extensions,
true).await
- }
- _ => not_impl_err!("Unsupported set operator:
{set_op:?}"),
- }
- }
+ if lf.items.len() > 1 || filename.is_none() {
+ return not_impl_err!("Only single file reads are supported");
}
- Err(e) => not_impl_err!("Invalid set operation type {}: {e}",
set.op),
- },
- Some(RelType::ExtensionLeaf(extension)) => {
- let Some(ext_detail) = &extension.detail else {
- return substrait_err!("Unexpected empty detail in
ExtensionLeafRel");
- };
- let plan = state
- .serializer_registry()
- .deserialize_logical_plan(&ext_detail.type_url,
&ext_detail.value)?;
- Ok(LogicalPlan::Extension(Extension { node: plan }))
+ let name = filename.unwrap();
+ // directly use unwrap here since we could determine it is a valid
one
+ let table_reference = TableReference::Bare { table: name.into() };
+
+ read_with_schema(
+ consumer,
+ table_reference,
+ substrait_schema,
+ &read.projection,
+ )
+ .await
}
- Some(RelType::ExtensionSingle(extension)) => {
- let Some(ext_detail) = &extension.detail else {
- return substrait_err!("Unexpected empty detail in
ExtensionSingleRel");
- };
- let plan = state
- .serializer_registry()
- .deserialize_logical_plan(&ext_detail.type_url,
&ext_detail.value)?;
- let Some(input_rel) = &extension.input else {
- return substrait_err!(
- "ExtensionSingleRel doesn't contains input rel. Try use
ExtensionLeafRel instead"
- );
- };
- let input_plan = from_substrait_rel(state, input_rel,
extensions).await?;
- let plan =
- plan.with_exprs_and_inputs(plan.expressions(),
vec![input_plan])?;
- Ok(LogicalPlan::Extension(Extension { node: plan }))
+ _ => {
+ not_impl_err!("Unsupported ReadType: {:?}", read.read_type)
}
- Some(RelType::ExtensionMulti(extension)) => {
- let Some(ext_detail) = &extension.detail else {
- return substrait_err!("Unexpected empty detail in
ExtensionSingleRel");
- };
- let plan = state
- .serializer_registry()
- .deserialize_logical_plan(&ext_detail.type_url,
&ext_detail.value)?;
- let mut inputs = Vec::with_capacity(extension.inputs.len());
- for input in &extension.inputs {
- let input_plan = from_substrait_rel(state, input,
extensions).await?;
- inputs.push(input_plan);
- }
- let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
- Ok(LogicalPlan::Extension(Extension { node: plan }))
+ }
+}
+
+pub async fn from_set_rel(
+ consumer: &impl SubstraitConsumer,
+ set: &SetRel,
+) -> Result<LogicalPlan> {
+ if set.inputs.len() < 2 {
+ substrait_err!("Set operation requires at least two inputs")
+ } else {
+ match set.op() {
+ SetOp::UnionAll => union_rels(consumer, &set.inputs, true).await,
+ SetOp::UnionDistinct => union_rels(consumer, &set.inputs,
false).await,
+ SetOp::IntersectionPrimary => LogicalPlanBuilder::intersect(
+ from_substrait_rel(consumer, &set.inputs[0]).await?,
+ union_rels(consumer, &set.inputs[1..], true).await?,
+ false,
+ ),
+ SetOp::IntersectionMultiset => {
+ intersect_rels(consumer, &set.inputs, false).await
+ }
+ SetOp::IntersectionMultisetAll => {
+ intersect_rels(consumer, &set.inputs, true).await
+ }
+ SetOp::MinusPrimary => except_rels(consumer, &set.inputs,
false).await,
+ SetOp::MinusPrimaryAll => except_rels(consumer, &set.inputs,
true).await,
+ set_op => not_impl_err!("Unsupported set operator: {set_op:?}"),
}
- Some(RelType::Exchange(exchange)) => {
- let Some(input) = exchange.input.as_ref() else {
- return substrait_err!("Unexpected empty input in ExchangeRel");
- };
- let input = Arc::new(from_substrait_rel(state, input,
extensions).await?);
+ }
+}
- let Some(exchange_kind) = &exchange.exchange_kind else {
- return substrait_err!("Unexpected empty input in ExchangeRel");
- };
+pub async fn from_exchange_rel(
+ consumer: &impl SubstraitConsumer,
+ exchange: &ExchangeRel,
+) -> Result<LogicalPlan> {
+ let Some(input) = exchange.input.as_ref() else {
+ return substrait_err!("Unexpected empty input in ExchangeRel");
+ };
+ let input = Arc::new(from_substrait_rel(consumer, input).await?);
- // ref:
https://substrait.io/relations/physical_relations/#exchange-types
- let partitioning_scheme = match exchange_kind {
- ExchangeKind::ScatterByFields(scatter_fields) => {
- let mut partition_columns = vec![];
- let input_schema = input.schema();
- for field_ref in &scatter_fields.fields {
- let column =
- from_substrait_field_reference(field_ref,
input_schema)?;
- partition_columns.push(column);
- }
- Partitioning::Hash(
- partition_columns,
- exchange.partition_count as usize,
- )
- }
- ExchangeKind::RoundRobin(_) => {
- Partitioning::RoundRobinBatch(exchange.partition_count as
usize)
- }
- ExchangeKind::SingleTarget(_)
- | ExchangeKind::MultiTarget(_)
- | ExchangeKind::Broadcast(_) => {
- return not_impl_err!("Unsupported exchange kind:
{exchange_kind:?}");
- }
- };
- Ok(LogicalPlan::Repartition(Repartition {
- input,
- partitioning_scheme,
- }))
+ let Some(exchange_kind) = &exchange.exchange_kind else {
+ return substrait_err!("Unexpected empty input in ExchangeRel");
+ };
+
+ // ref: https://substrait.io/relations/physical_relations/#exchange-types
+ let partitioning_scheme = match exchange_kind {
+ ExchangeKind::ScatterByFields(scatter_fields) => {
+ let mut partition_columns = vec![];
+ let input_schema = input.schema();
+ for field_ref in &scatter_fields.fields {
+ let column = from_substrait_field_reference(field_ref,
input_schema)?;
+ partition_columns.push(column);
+ }
+ Partitioning::Hash(partition_columns, exchange.partition_count as
usize)
+ }
+ ExchangeKind::RoundRobin(_) => {
+ Partitioning::RoundRobinBatch(exchange.partition_count as usize)
+ }
+ ExchangeKind::SingleTarget(_)
+ | ExchangeKind::MultiTarget(_)
+ | ExchangeKind::Broadcast(_) => {
+ return not_impl_err!("Unsupported exchange kind:
{exchange_kind:?}");
}
- _ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type),
};
- apply_emit_kind(retrieve_rel_common(rel), plan?)
+ Ok(LogicalPlan::Repartition(Repartition {
+ input,
+ partitioning_scheme,
+ }))
}
fn retrieve_rel_common(rel: &Rel) -> Option<&RelCommon> {
@@ -1384,7 +1770,7 @@ fn compatible_nullabilities(
}
/// (Re)qualify the sides of a join if needed, i.e. if the columns from one
side would otherwise
-/// conflict with the columns from the other.
+/// conflict with the columns from the other.
/// Substrait doesn't currently allow specifying aliases, neither for columns
nor for tables. For
/// Substrait the names don't matter since it only refers to columns by
indices, however DataFusion
/// requires columns to be uniquely identifiable, in some places (see e.g.
DFSchema::check_names).
@@ -1430,16 +1816,14 @@ fn from_substrait_jointype(join_type: i32) ->
Result<JoinType> {
/// Convert Substrait Sorts to DataFusion Exprs
pub async fn from_substrait_sorts(
- state: &dyn SubstraitPlanningState,
+ consumer: &impl SubstraitConsumer,
substrait_sorts: &Vec<SortField>,
input_schema: &DFSchema,
- extensions: &Extensions,
) -> Result<Vec<Sort>> {
let mut sorts: Vec<Sort> = vec![];
for s in substrait_sorts {
let expr =
- from_substrait_rex(state, s.expr.as_ref().unwrap(), input_schema,
extensions)
- .await?;
+ from_substrait_rex(consumer, s.expr.as_ref().unwrap(),
input_schema).await?;
let asc_nullfirst = match &s.sort_kind {
Some(k) => match k {
Direction(d) => {
@@ -1480,15 +1864,13 @@ pub async fn from_substrait_sorts(
/// Convert Substrait Expressions to DataFusion Exprs
pub async fn from_substrait_rex_vec(
- state: &dyn SubstraitPlanningState,
+ consumer: &impl SubstraitConsumer,
exprs: &Vec<Expression>,
input_schema: &DFSchema,
- extensions: &Extensions,
) -> Result<Vec<Expr>> {
let mut expressions: Vec<Expr> = vec![];
for expr in exprs {
- let expression =
- from_substrait_rex(state, expr, input_schema, extensions).await?;
+ let expression = from_substrait_rex(consumer, expr,
input_schema).await?;
expressions.push(expression);
}
Ok(expressions)
@@ -1496,16 +1878,15 @@ pub async fn from_substrait_rex_vec(
/// Convert Substrait FunctionArguments to DataFusion Exprs
pub async fn from_substrait_func_args(
- state: &dyn SubstraitPlanningState,
+ consumer: &impl SubstraitConsumer,
arguments: &Vec<FunctionArgument>,
input_schema: &DFSchema,
- extensions: &Extensions,
) -> Result<Vec<Expr>> {
let mut args: Vec<Expr> = vec![];
for arg in arguments {
let arg_expr = match &arg.arg_type {
Some(ArgType::Value(e)) => {
- from_substrait_rex(state, e, input_schema, extensions).await
+ from_substrait_rex(consumer, e, input_schema).await
}
_ => not_impl_err!("Function argument non-Value type not
supported"),
};
@@ -1516,370 +1897,416 @@ pub async fn from_substrait_func_args(
/// Convert Substrait AggregateFunction to DataFusion Expr
pub async fn from_substrait_agg_func(
- state: &dyn SubstraitPlanningState,
+ consumer: &impl SubstraitConsumer,
f: &AggregateFunction,
input_schema: &DFSchema,
- extensions: &Extensions,
filter: Option<Box<Expr>>,
order_by: Option<Vec<SortExpr>>,
distinct: bool,
) -> Result<Arc<Expr>> {
- let args =
- from_substrait_func_args(state, &f.arguments, input_schema,
extensions).await?;
-
- let Some(function_name) = extensions.functions.get(&f.function_reference)
else {
+ let Some(fn_signature) = consumer
+ .get_extensions()
+ .functions
+ .get(&f.function_reference)
+ else {
return plan_err!(
"Aggregate function not registered: function anchor = {:?}",
f.function_reference
);
};
- let function_name = substrait_fun_name(function_name);
- // try udaf first, then built-in aggr fn.
- if let Ok(fun) = state.udaf(function_name) {
- // deal with situation that count(*) got no arguments
- let args = if fun.name() == "count" && args.is_empty() {
- vec![Expr::Literal(ScalarValue::Int64(Some(1)))]
- } else {
- args
- };
-
- Ok(Arc::new(Expr::AggregateFunction(
- expr::AggregateFunction::new_udf(fun, args, distinct, filter,
order_by, None),
- )))
- } else {
- not_impl_err!(
+ let fn_name = substrait_fun_name(fn_signature);
+ let udaf = consumer.get_function_registry().udaf(fn_name);
+ let udaf = udaf.map_err(|_| {
+ not_impl_datafusion_err!(
"Aggregate function {} is not supported: function anchor = {:?}",
- function_name,
+ fn_signature,
f.function_reference
)
- }
+ })?;
+
+ let args = from_substrait_func_args(consumer, &f.arguments,
input_schema).await?;
+
+ // deal with situation that count(*) got no arguments
+ let args = if udaf.name() == "count" && args.is_empty() {
+ vec![Expr::Literal(ScalarValue::Int64(Some(1)))]
+ } else {
+ args
+ };
+
+ Ok(Arc::new(Expr::AggregateFunction(
+ expr::AggregateFunction::new_udf(udaf, args, distinct, filter,
order_by, None),
+ )))
}
/// Convert Substrait Rex to DataFusion Expr
-#[async_recursion]
pub async fn from_substrait_rex(
- state: &dyn SubstraitPlanningState,
- e: &Expression,
+ consumer: &impl SubstraitConsumer,
+ expression: &Expression,
input_schema: &DFSchema,
- extensions: &Extensions,
) -> Result<Expr> {
- match &e.rex_type {
- Some(RexType::SingularOrList(s)) => {
- let substrait_expr = s.value.as_ref().unwrap();
- let substrait_list = s.options.as_ref();
- Ok(Expr::InList(InList {
- expr: Box::new(
- from_substrait_rex(state, substrait_expr, input_schema,
extensions)
- .await?,
- ),
- list: from_substrait_rex_vec(
- state,
- substrait_list,
- input_schema,
- extensions,
- )
- .await?,
- negated: false,
- }))
- }
- Some(RexType::Selection(field_ref)) => {
- Ok(from_substrait_field_reference(field_ref, input_schema)?)
- }
- Some(RexType::IfThen(if_then)) => {
- // Parse `ifs`
- // If the first element does not have a `then` part, then we can
assume it's a base expression
- let mut when_then_expr: Vec<(Box<Expr>, Box<Expr>)> = vec![];
- let mut expr = None;
- for (i, if_expr) in if_then.ifs.iter().enumerate() {
- if i == 0 {
- // Check if the first element is type base expression
- if if_expr.then.is_none() {
- expr = Some(Box::new(
- from_substrait_rex(
- state,
- if_expr.r#if.as_ref().unwrap(),
- input_schema,
- extensions,
- )
- .await?,
- ));
- continue;
- }
- }
- when_then_expr.push((
- Box::new(
- from_substrait_rex(
- state,
- if_expr.r#if.as_ref().unwrap(),
- input_schema,
- extensions,
- )
- .await?,
- ),
- Box::new(
- from_substrait_rex(
- state,
- if_expr.then.as_ref().unwrap(),
- input_schema,
- extensions,
- )
- .await?,
- ),
- ));
+ match &expression.rex_type {
+ Some(t) => match t {
+ RexType::Literal(expr) => consumer.consume_literal(expr).await,
+ RexType::Selection(expr) => {
+ consumer.consume_field_reference(expr, input_schema).await
+ }
+ RexType::ScalarFunction(expr) => {
+ consumer.consume_scalar_function(expr, input_schema).await
+ }
+ RexType::WindowFunction(expr) => {
+ consumer.consume_window_function(expr, input_schema).await
+ }
+ RexType::IfThen(expr) => consumer.consume_if_then(expr,
input_schema).await,
+ RexType::SwitchExpression(expr) => {
+ consumer.consume_switch(expr, input_schema).await
+ }
+ RexType::SingularOrList(expr) => {
+ consumer.consume_singular_or_list(expr, input_schema).await
}
- // Parse `else`
- let else_expr = match &if_then.r#else {
- Some(e) => Some(Box::new(
- from_substrait_rex(state, e, input_schema,
extensions).await?,
- )),
- None => None,
- };
- Ok(Expr::Case(Case {
- expr,
- when_then_expr,
- else_expr,
- }))
- }
- Some(RexType::ScalarFunction(f)) => {
- let Some(fn_name) =
extensions.functions.get(&f.function_reference) else {
- return plan_err!(
- "Scalar function not found: function reference = {:?}",
- f.function_reference
- );
- };
- let fn_name = substrait_fun_name(fn_name);
- let args =
- from_substrait_func_args(state, &f.arguments, input_schema,
extensions)
- .await?;
+ RexType::MultiOrList(expr) => {
+ consumer.consume_multi_or_list(expr, input_schema).await
+ }
- // try to first match the requested function into registered udfs,
then built-in ops
- // and finally built-in expressions
- if let Ok(func) = state.udf(fn_name) {
- Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
- func.to_owned(),
- args,
- )))
- } else if let Some(op) = name_to_op(fn_name) {
- if f.arguments.len() < 2 {
- return not_impl_err!(
- "Expect at least two arguments for binary operator
{op:?}, the provided number of operators is {:?}",
- f.arguments.len()
- );
- }
- // Some expressions are binary in DataFusion but take in a
variadic number of args in Substrait.
- // In those cases we iterate through all the arguments,
applying the binary expression against them all
- let combined_expr = args
- .into_iter()
- .fold(None, |combined_expr: Option<Expr>, arg: Expr| {
- Some(match combined_expr {
- Some(expr) => Expr::BinaryExpr(BinaryExpr {
- left: Box::new(expr),
- op,
- right: Box::new(arg),
- }),
- None => arg,
- })
- })
- .unwrap();
+ RexType::Cast(expr) => {
+ consumer.consume_cast(expr.as_ref(), input_schema).await
+ }
- Ok(combined_expr)
- } else if let Some(builder) =
BuiltinExprBuilder::try_from_name(fn_name) {
- builder.build(state, f, input_schema, extensions).await
- } else {
- not_impl_err!("Unsupported function name: {fn_name:?}")
+ RexType::Subquery(expr) => {
+ consumer.consume_subquery(expr.as_ref(), input_schema).await
}
- }
- Some(RexType::Literal(lit)) => {
- let scalar_value = from_substrait_literal_without_names(lit,
extensions)?;
- Ok(Expr::Literal(scalar_value))
- }
- Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() {
- Some(output_type) => {
- let input_expr = Box::new(
+ RexType::Nested(expr) => consumer.consume_nested(expr,
input_schema).await,
+ RexType::Enum(expr) => consumer.consume_enum(expr,
input_schema).await,
+ },
+ None => substrait_err!("Expression must set rex_type: {:?}",
expression),
+ }
+}
+
+pub async fn from_singular_or_list(
+ consumer: &impl SubstraitConsumer,
+ expr: &SingularOrList,
+ input_schema: &DFSchema,
+) -> Result<Expr> {
+ let substrait_expr = expr.value.as_ref().unwrap();
+ let substrait_list = expr.options.as_ref();
+ Ok(Expr::InList(InList {
+ expr: Box::new(from_substrait_rex(consumer, substrait_expr,
input_schema).await?),
+ list: from_substrait_rex_vec(consumer, substrait_list,
input_schema).await?,
+ negated: false,
+ }))
+}
+
+pub async fn from_field_reference(
+ _consumer: &impl SubstraitConsumer,
+ field_ref: &FieldReference,
+ input_schema: &DFSchema,
+) -> Result<Expr> {
+ from_substrait_field_reference(field_ref, input_schema)
+}
+
+pub async fn from_if_then(
+ consumer: &impl SubstraitConsumer,
+ if_then: &IfThen,
+ input_schema: &DFSchema,
+) -> Result<Expr> {
+ // Parse `ifs`
+ // If the first element does not have a `then` part, then we can assume
it's a base expression
+ let mut when_then_expr: Vec<(Box<Expr>, Box<Expr>)> = vec![];
+ let mut expr = None;
+ for (i, if_expr) in if_then.ifs.iter().enumerate() {
+ if i == 0 {
+ // Check if the first element is type base expression
+ if if_expr.then.is_none() {
+ expr = Some(Box::new(
from_substrait_rex(
- state,
- cast.as_ref().input.as_ref().unwrap().as_ref(),
+ consumer,
+ if_expr.r#if.as_ref().unwrap(),
input_schema,
- extensions,
)
.await?,
- );
- let data_type =
- from_substrait_type_without_names(output_type,
extensions)?;
- if cast.failure_behavior() == ReturnNull {
- Ok(Expr::TryCast(TryCast::new(input_expr, data_type)))
- } else {
- Ok(Expr::Cast(Cast::new(input_expr, data_type)))
- }
+ ));
+ continue;
}
- None => substrait_err!("Cast expression without output type is not
allowed"),
- },
- Some(RexType::WindowFunction(window)) => {
- let Some(fn_name) =
extensions.functions.get(&window.function_reference)
- else {
- return plan_err!(
- "Window function not found: function reference = {:?}",
- window.function_reference
- );
- };
- let fn_name = substrait_fun_name(fn_name);
-
- // check udwf first, then udaf, then built-in window and aggregate
functions
- let fun = if let Ok(udwf) = state.udwf(fn_name) {
- Ok(WindowFunctionDefinition::WindowUDF(udwf))
- } else if let Ok(udaf) = state.udaf(fn_name) {
- Ok(WindowFunctionDefinition::AggregateUDF(udaf))
- } else {
- not_impl_err!(
- "Window function {} is not supported: function anchor =
{:?}",
- fn_name,
- window.function_reference
+ }
+ when_then_expr.push((
+ Box::new(
+ from_substrait_rex(
+ consumer,
+ if_expr.r#if.as_ref().unwrap(),
+ input_schema,
)
- }?;
-
- let order_by =
- from_substrait_sorts(state, &window.sorts, input_schema,
extensions)
- .await?;
-
- let bound_units =
- match BoundsType::try_from(window.bounds_type).map_err(|e| {
- plan_datafusion_err!("Invalid bound type {}: {e}",
window.bounds_type)
- })? {
- BoundsType::Rows => WindowFrameUnits::Rows,
- BoundsType::Range => WindowFrameUnits::Range,
- BoundsType::Unspecified => {
- // If the plan does not specify the bounds type, then
we use a simple logic to determine the units
- // If there is no `ORDER BY`, then by default, the
frame counts each row from the lower up to upper boundary
- // If there is `ORDER BY`, then by default, each frame
is a range starting from unbounded preceding to current row
- if order_by.is_empty() {
- WindowFrameUnits::Rows
- } else {
- WindowFrameUnits::Range
- }
- }
- };
- Ok(Expr::WindowFunction(expr::WindowFunction {
- fun,
- args: from_substrait_func_args(
- state,
- &window.arguments,
+ .await?,
+ ),
+ Box::new(
+ from_substrait_rex(
+ consumer,
+ if_expr.then.as_ref().unwrap(),
input_schema,
- extensions,
)
.await?,
- partition_by: from_substrait_rex_vec(
- state,
- &window.partitions,
+ ),
+ ));
+ }
+ // Parse `else`
+ let else_expr = match &if_then.r#else {
+ Some(e) => Some(Box::new(
+ from_substrait_rex(consumer, e, input_schema).await?,
+ )),
+ None => None,
+ };
+ Ok(Expr::Case(Case {
+ expr,
+ when_then_expr,
+ else_expr,
+ }))
+}
+
+pub async fn from_scalar_function(
+ consumer: &impl SubstraitConsumer,
+ f: &ScalarFunction,
+ input_schema: &DFSchema,
+) -> Result<Expr> {
+ let Some(fn_signature) = consumer
+ .get_extensions()
+ .functions
+ .get(&f.function_reference)
+ else {
+ return plan_err!(
+ "Scalar function not found: function reference = {:?}",
+ f.function_reference
+ );
+ };
+ let fn_name = substrait_fun_name(fn_signature);
+ let args = from_substrait_func_args(consumer, &f.arguments,
input_schema).await?;
+
+ // try to first match the requested function into registered udfs, then
built-in ops
+ // and finally built-in expressions
+ if let Ok(func) = consumer.get_function_registry().udf(fn_name) {
+ Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
+ func.to_owned(),
+ args,
+ )))
+ } else if let Some(op) = name_to_op(fn_name) {
+ if f.arguments.len() < 2 {
+ return not_impl_err!(
+ "Expect at least two arguments for binary operator
{op:?}, the provided number of operators is {:?}",
+ f.arguments.len()
+ );
+ }
+ // Some expressions are binary in DataFusion but take in a variadic
number of args in Substrait.
+ // In those cases we iterate through all the arguments, applying the
binary expression against them all
+ let combined_expr = args
+ .into_iter()
+ .fold(None, |combined_expr: Option<Expr>, arg: Expr| {
+ Some(match combined_expr {
+ Some(expr) => Expr::BinaryExpr(BinaryExpr {
+ left: Box::new(expr),
+ op,
+ right: Box::new(arg),
+ }),
+ None => arg,
+ })
+ })
+ .unwrap();
+
+ Ok(combined_expr)
+ } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) {
+ builder.build(consumer, f, input_schema).await
+ } else {
+ not_impl_err!("Unsupported function name: {fn_name:?}")
+ }
+}
+
+pub async fn from_literal(
+ consumer: &impl SubstraitConsumer,
+ expr: &Literal,
+) -> Result<Expr> {
+ let scalar_value = from_substrait_literal_without_names(consumer, expr)?;
+ Ok(Expr::Literal(scalar_value))
+}
+
+pub async fn from_cast(
+ consumer: &impl SubstraitConsumer,
+ cast: &substrait_expression::Cast,
+ input_schema: &DFSchema,
+) -> Result<Expr> {
+ match cast.r#type.as_ref() {
+ Some(output_type) => {
+ let input_expr = Box::new(
+ from_substrait_rex(
+ consumer,
+ cast.input.as_ref().unwrap().as_ref(),
input_schema,
- extensions,
)
.await?,
- order_by,
- window_frame:
datafusion::logical_expr::WindowFrame::new_bounds(
- bound_units,
- from_substrait_bound(&window.lower_bound, true)?,
- from_substrait_bound(&window.upper_bound, false)?,
- ),
- null_treatment: None,
- }))
+ );
+ let data_type = from_substrait_type_without_names(consumer,
output_type)?;
+ if cast.failure_behavior() == ReturnNull {
+ Ok(Expr::TryCast(TryCast::new(input_expr, data_type)))
+ } else {
+ Ok(Expr::Cast(Cast::new(input_expr, data_type)))
+ }
}
- Some(RexType::Subquery(subquery)) => match
&subquery.as_ref().subquery_type {
- Some(subquery_type) => match subquery_type {
- SubqueryType::InPredicate(in_predicate) => {
- if in_predicate.needles.len() != 1 {
- substrait_err!("InPredicate Subquery type must have
exactly one Needle expression")
- } else {
- let needle_expr = &in_predicate.needles[0];
- let haystack_expr = &in_predicate.haystack;
- if let Some(haystack_expr) = haystack_expr {
- let haystack_expr =
- from_substrait_rel(state, haystack_expr,
extensions)
- .await?;
- let outer_refs = haystack_expr.all_out_ref_exprs();
- Ok(Expr::InSubquery(InSubquery {
- expr: Box::new(
- from_substrait_rex(
- state,
- needle_expr,
- input_schema,
- extensions,
- )
+ None => substrait_err!("Cast expression without output type is not
allowed"),
+ }
+}
+
+pub async fn from_window_function(
+ consumer: &impl SubstraitConsumer,
+ window: &WindowFunction,
+ input_schema: &DFSchema,
+) -> Result<Expr> {
+ let Some(fn_signature) = consumer
+ .get_extensions()
+ .functions
+ .get(&window.function_reference)
+ else {
+ return plan_err!(
+ "Window function not found: function reference = {:?}",
+ window.function_reference
+ );
+ };
+ let fn_name = substrait_fun_name(fn_signature);
+
+ // check udwf first, then udaf, then built-in window and aggregate
functions
+ let fun = if let Ok(udwf) = consumer.get_function_registry().udwf(fn_name)
{
+ Ok(WindowFunctionDefinition::WindowUDF(udwf))
+ } else if let Ok(udaf) = consumer.get_function_registry().udaf(fn_name) {
+ Ok(WindowFunctionDefinition::AggregateUDF(udaf))
+ } else {
+ not_impl_err!(
+ "Window function {} is not supported: function anchor = {:?}",
+ fn_name,
+ window.function_reference
+ )
+ }?;
+
+ let order_by = from_substrait_sorts(consumer, &window.sorts,
input_schema).await?;
+
+ let bound_units = match
BoundsType::try_from(window.bounds_type).map_err(|e| {
+ plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type)
+ })? {
+ BoundsType::Rows => WindowFrameUnits::Rows,
+ BoundsType::Range => WindowFrameUnits::Range,
+ BoundsType::Unspecified => {
+ // If the plan does not specify the bounds type, then we use a
simple logic to determine the units
+ // If there is no `ORDER BY`, then by default, the frame counts
each row from the lower up to upper boundary
+ // If there is `ORDER BY`, then by default, each frame is a range
starting from unbounded preceding to current row
+ if order_by.is_empty() {
+ WindowFrameUnits::Rows
+ } else {
+ WindowFrameUnits::Range
+ }
+ }
+ };
+ Ok(Expr::WindowFunction(expr::WindowFunction {
+ fun,
+ args: from_substrait_func_args(consumer, &window.arguments,
input_schema).await?,
+ partition_by: from_substrait_rex_vec(consumer, &window.partitions,
input_schema)
+ .await?,
+ order_by,
+ window_frame: datafusion::logical_expr::WindowFrame::new_bounds(
+ bound_units,
+ from_substrait_bound(&window.lower_bound, true)?,
+ from_substrait_bound(&window.upper_bound, false)?,
+ ),
+ null_treatment: None,
+ }))
+}
+
+pub async fn from_subquery(
+ consumer: &impl SubstraitConsumer,
+ subquery: &substrait_expression::Subquery,
+ input_schema: &DFSchema,
+) -> Result<Expr> {
+ match &subquery.subquery_type {
+ Some(subquery_type) => match subquery_type {
+ SubqueryType::InPredicate(in_predicate) => {
+ if in_predicate.needles.len() != 1 {
+ substrait_err!("InPredicate Subquery type must have
exactly one Needle expression")
+ } else {
+ let needle_expr = &in_predicate.needles[0];
+ let haystack_expr = &in_predicate.haystack;
+ if let Some(haystack_expr) = haystack_expr {
+ let haystack_expr =
+ from_substrait_rel(consumer, haystack_expr).await?;
+ let outer_refs = haystack_expr.all_out_ref_exprs();
+ Ok(Expr::InSubquery(InSubquery {
+ expr: Box::new(
+ from_substrait_rex(consumer, needle_expr,
input_schema)
.await?,
- ),
- subquery: Subquery {
- subquery: Arc::new(haystack_expr),
- outer_ref_columns: outer_refs,
- },
- negated: false,
- }))
- } else {
- substrait_err!("InPredicate Subquery type must
have a Haystack expression")
- }
+ ),
+ subquery: Subquery {
+ subquery: Arc::new(haystack_expr),
+ outer_ref_columns: outer_refs,
+ },
+ negated: false,
+ }))
+ } else {
+ substrait_err!(
+ "InPredicate Subquery type must have a Haystack
expression"
+ )
}
}
- SubqueryType::Scalar(query) => {
- let plan = from_substrait_rel(
- state,
- &(query.input.clone()).unwrap_or_default(),
- extensions,
- )
- .await?;
- let outer_ref_columns = plan.all_out_ref_exprs();
- Ok(Expr::ScalarSubquery(Subquery {
- subquery: Arc::new(plan),
- outer_ref_columns,
- }))
- }
- SubqueryType::SetPredicate(predicate) => {
- match predicate.predicate_op() {
- // exist
- PredicateOp::Exists => {
- let relation = &predicate.tuples;
- let plan = from_substrait_rel(
- state,
- &relation.clone().unwrap_or_default(),
- extensions,
- )
- .await?;
- let outer_ref_columns = plan.all_out_ref_exprs();
- Ok(Expr::Exists(Exists::new(
- Subquery {
- subquery: Arc::new(plan),
- outer_ref_columns,
- },
- false,
- )))
- }
- other_type => substrait_err!(
- "unimplemented type {:?} for set predicate",
- other_type
- ),
+ }
+ SubqueryType::Scalar(query) => {
+ let plan = from_substrait_rel(
+ consumer,
+ &(query.input.clone()).unwrap_or_default(),
+ )
+ .await?;
+ let outer_ref_columns = plan.all_out_ref_exprs();
+ Ok(Expr::ScalarSubquery(Subquery {
+ subquery: Arc::new(plan),
+ outer_ref_columns,
+ }))
+ }
+ SubqueryType::SetPredicate(predicate) => {
+ match predicate.predicate_op() {
+ // exist
+ PredicateOp::Exists => {
+ let relation = &predicate.tuples;
+ let plan = from_substrait_rel(
+ consumer,
+ &relation.clone().unwrap_or_default(),
+ )
+ .await?;
+ let outer_ref_columns = plan.all_out_ref_exprs();
+ Ok(Expr::Exists(Exists::new(
+ Subquery {
+ subquery: Arc::new(plan),
+ outer_ref_columns,
+ },
+ false,
+ )))
}
+ other_type => substrait_err!(
+ "unimplemented type {:?} for set predicate",
+ other_type
+ ),
}
- other_type => {
- substrait_err!("Subquery type {:?} not implemented",
other_type)
- }
- },
- None => {
- substrait_err!("Subquery expression without SubqueryType is
not allowed")
+ }
+ other_type => {
+ substrait_err!("Subquery type {:?} not implemented",
other_type)
}
},
- _ => not_impl_err!("unsupported rex_type"),
+ None => {
+ substrait_err!("Subquery expression without SubqueryType is not
allowed")
+ }
}
}
pub(crate) fn from_substrait_type_without_names(
+ consumer: &impl SubstraitConsumer,
dt: &Type,
- extensions: &Extensions,
) -> Result<DataType> {
- from_substrait_type(dt, extensions, &[], &mut 0)
+ from_substrait_type(consumer, dt, &[], &mut 0)
}
fn from_substrait_type(
+ consumer: &impl SubstraitConsumer,
dt: &Type,
- extensions: &Extensions,
dfs_names: &[String],
name_idx: &mut usize,
) -> Result<DataType> {
@@ -1992,7 +2419,7 @@ fn from_substrait_type(
substrait_datafusion_err!("List type must have inner type")
})?;
let field = Arc::new(Field::new_list_field(
- from_substrait_type(inner_type, extensions, dfs_names,
name_idx)?,
+ from_substrait_type(consumer, inner_type, dfs_names,
name_idx)?,
// We ignore Substrait's nullability here to match
to_substrait_literal
// which always creates nullable lists
true,
@@ -2014,12 +2441,12 @@ fn from_substrait_type(
})?;
let key_field = Arc::new(Field::new(
"key",
- from_substrait_type(key_type, extensions, dfs_names,
name_idx)?,
+ from_substrait_type(consumer, key_type, dfs_names,
name_idx)?,
false,
));
let value_field = Arc::new(Field::new(
"value",
- from_substrait_type(value_type, extensions, dfs_names,
name_idx)?,
+ from_substrait_type(consumer, value_type, dfs_names,
name_idx)?,
true,
));
Ok(DataType::Map(
@@ -2050,42 +2477,48 @@ fn from_substrait_type(
Ok(DataType::Interval(IntervalUnit::MonthDayNano))
}
r#type::Kind::UserDefined(u) => {
- if let Some(name) = extensions.types.get(&u.type_reference) {
+ if let Ok(data_type) = consumer.consume_user_defined_type(u) {
+ return Ok(data_type);
+ }
+
+ // TODO: remove the code below once the producer has been
updated
+ if let Some(name) =
consumer.get_extensions().types.get(&u.type_reference)
+ {
#[allow(deprecated)]
- match name.as_ref() {
- // Kept for backwards compatibility, producers should
use IntervalCompound instead
- INTERVAL_MONTH_DAY_NANO_TYPE_NAME =>
Ok(DataType::Interval(IntervalUnit::MonthDayNano)),
+ match name.as_ref() {
+ // Kept for backwards compatibility, producers
should use IntervalCompound instead
+ INTERVAL_MONTH_DAY_NANO_TYPE_NAME =>
Ok(DataType::Interval(IntervalUnit::MonthDayNano)),
_ => not_impl_err!(
"Unsupported Substrait user defined type with
ref {} and variation {}",
u.type_reference,
u.type_variation_reference
),
- }
+ }
} else {
#[allow(deprecated)]
- match u.type_reference {
- // Kept for backwards compatibility, producers should
use IntervalYear instead
- INTERVAL_YEAR_MONTH_TYPE_REF => {
- Ok(DataType::Interval(IntervalUnit::YearMonth))
- }
- // Kept for backwards compatibility, producers should
use IntervalDay instead
- INTERVAL_DAY_TIME_TYPE_REF => {
- Ok(DataType::Interval(IntervalUnit::DayTime))
- }
- // Kept for backwards compatibility, producers should
use IntervalCompound instead
- INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
- Ok(DataType::Interval(IntervalUnit::MonthDayNano))
- }
- _ => not_impl_err!(
+ match u.type_reference {
+ // Kept for backwards compatibility, producers
should use IntervalYear instead
+ INTERVAL_YEAR_MONTH_TYPE_REF => {
+ Ok(DataType::Interval(IntervalUnit::YearMonth))
+ }
+ // Kept for backwards compatibility, producers
should use IntervalDay instead
+ INTERVAL_DAY_TIME_TYPE_REF => {
+ Ok(DataType::Interval(IntervalUnit::DayTime))
+ }
+ // Kept for backwards compatibility, producers
should use IntervalCompound instead
+ INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
+
Ok(DataType::Interval(IntervalUnit::MonthDayNano))
+ }
+ _ => not_impl_err!(
"Unsupported Substrait user defined type with ref {}
and variation {}",
u.type_reference,
u.type_variation_reference
),
- }
+ }
}
}
r#type::Kind::Struct(s) =>
Ok(DataType::Struct(from_substrait_struct_type(
- s, extensions, dfs_names, name_idx,
+ consumer, s, dfs_names, name_idx,
)?)),
r#type::Kind::Varchar(_) => Ok(DataType::Utf8),
r#type::Kind::FixedChar(_) => Ok(DataType::Utf8),
@@ -2096,8 +2529,8 @@ fn from_substrait_type(
}
fn from_substrait_struct_type(
+ consumer: &impl SubstraitConsumer,
s: &r#type::Struct,
- extensions: &Extensions,
dfs_names: &[String],
name_idx: &mut usize,
) -> Result<Fields> {
@@ -2105,7 +2538,7 @@ fn from_substrait_struct_type(
for (i, f) in s.types.iter().enumerate() {
let field = Field::new(
next_struct_field_name(i, dfs_names, name_idx)?,
- from_substrait_type(f, extensions, dfs_names, name_idx)?,
+ from_substrait_type(consumer, f, dfs_names, name_idx)?,
true, // We assume everything to be nullable since that's easier
than ensuring it matches
);
fields.push(field);
@@ -2133,15 +2566,15 @@ fn next_struct_field_name(
/// Convert Substrait NamedStruct to DataFusion DFSchemaRef
pub fn from_substrait_named_struct(
+ consumer: &impl SubstraitConsumer,
base_schema: &NamedStruct,
- extensions: &Extensions,
) -> Result<DFSchema> {
let mut name_idx = 0;
let fields = from_substrait_struct_type(
+ consumer,
base_schema.r#struct.as_ref().ok_or_else(|| {
substrait_datafusion_err!("Named struct must contain a struct")
})?,
- extensions,
&base_schema.names,
&mut name_idx,
);
@@ -2202,15 +2635,15 @@ fn from_substrait_bound(
}
pub(crate) fn from_substrait_literal_without_names(
+ consumer: &impl SubstraitConsumer,
lit: &Literal,
- extensions: &Extensions,
) -> Result<ScalarValue> {
- from_substrait_literal(lit, extensions, &vec![], &mut 0)
+ from_substrait_literal(consumer, lit, &vec![], &mut 0)
}
fn from_substrait_literal(
+ consumer: &impl SubstraitConsumer,
lit: &Literal,
- extensions: &Extensions,
dfs_names: &Vec<String>,
name_idx: &mut usize,
) -> Result<ScalarValue> {
@@ -2346,12 +2779,7 @@ fn from_substrait_literal(
.iter()
.map(|el| {
element_name_idx = *name_idx;
- from_substrait_literal(
- el,
- extensions,
- dfs_names,
- &mut element_name_idx,
- )
+ from_substrait_literal(consumer, el, dfs_names, &mut
element_name_idx)
})
.collect::<Result<Vec<_>>>()?;
*name_idx = element_name_idx;
@@ -2375,8 +2803,8 @@ fn from_substrait_literal(
}
Some(LiteralType::EmptyList(l)) => {
let element_type = from_substrait_type(
+ consumer,
l.r#type.clone().unwrap().as_ref(),
- extensions,
dfs_names,
name_idx,
)?;
@@ -2402,14 +2830,14 @@ fn from_substrait_literal(
.map(|kv| {
entry_name_idx = *name_idx;
let key_sv = from_substrait_literal(
+ consumer,
kv.key.as_ref().unwrap(),
- extensions,
dfs_names,
&mut entry_name_idx,
)?;
let value_sv = from_substrait_literal(
+ consumer,
kv.value.as_ref().unwrap(),
- extensions,
dfs_names,
&mut entry_name_idx,
)?;
@@ -2447,8 +2875,8 @@ fn from_substrait_literal(
Some(v) => Ok(v),
_ => plan_err!("Missing value type for empty map"),
}?;
- let key_type = from_substrait_type(key, extensions, dfs_names,
name_idx)?;
- let value_type = from_substrait_type(value, extensions, dfs_names,
name_idx)?;
+ let key_type = from_substrait_type(consumer, key, dfs_names,
name_idx)?;
+ let value_type = from_substrait_type(consumer, value, dfs_names,
name_idx)?;
// new_empty_array on a MapType creates a too empty array
// We want it to contain an empty struct array to align with an
empty MapBuilder one
@@ -2474,7 +2902,7 @@ fn from_substrait_literal(
let mut builder = ScalarStructBuilder::new();
for (i, field) in s.fields.iter().enumerate() {
let name = next_struct_field_name(i, dfs_names, name_idx)?;
- let sv = from_substrait_literal(field, extensions, dfs_names,
name_idx)?;
+ let sv = from_substrait_literal(consumer, field, dfs_names,
name_idx)?;
// We assume everything to be nullable, since Arrow's strict
about things matching
// and it's hard to match otherwise.
builder = builder.with_scalar(Field::new(name, sv.data_type(),
true), sv);
@@ -2482,7 +2910,7 @@ fn from_substrait_literal(
builder.build()?
}
Some(LiteralType::Null(ntype)) => {
- from_substrait_null(ntype, extensions, dfs_names, name_idx)?
+ from_substrait_null(consumer, ntype, dfs_names, name_idx)?
}
Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond {
days,
@@ -2546,9 +2974,15 @@ fn from_substrait_literal(
},
Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())),
Some(LiteralType::UserDefined(user_defined)) => {
+ if let Ok(value) =
consumer.consume_user_defined_literal(user_defined) {
+ return Ok(value);
+ }
+
+ // TODO: remove the code below once the producer has been updated
+
// Helper function to prevent duplicating this code - can be
inlined once the non-extension path is removed
let interval_month_day_nano =
- |user_defined: &UserDefined| -> Result<ScalarValue> {
+ |user_defined: &proto::expression::literal::UserDefined| ->
Result<ScalarValue> {
let Some(Val::Value(raw_val)) = user_defined.val.as_ref()
else {
return substrait_err!("Interval month day nano value
is empty");
};
@@ -2572,7 +3006,11 @@ fn from_substrait_literal(
)))
};
- if let Some(name) =
extensions.types.get(&user_defined.type_reference) {
+ if let Some(name) = consumer
+ .get_extensions()
+ .types
+ .get(&user_defined.type_reference)
+ {
match name.as_ref() {
// Kept for backwards compatibility - producers should use
IntervalCompound instead
#[allow(deprecated)]
@@ -2645,8 +3083,8 @@ fn from_substrait_literal(
}
fn from_substrait_null(
+ consumer: &impl SubstraitConsumer,
null_type: &Type,
- extensions: &Extensions,
dfs_names: &[String],
name_idx: &mut usize,
) -> Result<ScalarValue> {
@@ -2764,8 +3202,8 @@ fn from_substrait_null(
r#type::Kind::List(l) => {
let field = Field::new_list_field(
from_substrait_type(
+ consumer,
l.r#type.clone().unwrap().as_ref(),
- extensions,
dfs_names,
name_idx,
)?,
@@ -2792,9 +3230,9 @@ fn from_substrait_null(
})?;
let key_type =
- from_substrait_type(key_type, extensions, dfs_names,
name_idx)?;
+ from_substrait_type(consumer, key_type, dfs_names,
name_idx)?;
let value_type =
- from_substrait_type(value_type, extensions, dfs_names,
name_idx)?;
+ from_substrait_type(consumer, value_type, dfs_names,
name_idx)?;
let entries_field = Arc::new(Field::new_struct(
"entries",
vec![
@@ -2808,7 +3246,7 @@ fn from_substrait_null(
}
r#type::Kind::Struct(s) => {
let fields =
- from_substrait_struct_type(s, extensions, dfs_names,
name_idx)?;
+ from_substrait_struct_type(consumer, s, dfs_names,
name_idx)?;
Ok(ScalarStructBuilder::new_null(fields))
}
_ => not_impl_err!("Unsupported Substrait type for null:
{kind:?}"),
@@ -2820,16 +3258,15 @@ fn from_substrait_null(
#[allow(deprecated)]
async fn from_substrait_grouping(
- state: &dyn SubstraitPlanningState,
+ consumer: &impl SubstraitConsumer,
grouping: &Grouping,
expressions: &[Expr],
input_schema: &DFSchemaRef,
- extensions: &Extensions,
) -> Result<Vec<Expr>> {
let mut group_exprs = vec![];
if !grouping.grouping_expressions.is_empty() {
for e in &grouping.grouping_expressions {
- let expr = from_substrait_rex(state, e, input_schema,
extensions).await?;
+ let expr = from_substrait_rex(consumer, e, input_schema).await?;
group_exprs.push(expr);
}
return Ok(group_exprs);
@@ -2882,29 +3319,17 @@ impl BuiltinExprBuilder {
pub async fn build(
self,
- state: &dyn SubstraitPlanningState,
+ consumer: &impl SubstraitConsumer,
f: &ScalarFunction,
input_schema: &DFSchema,
- extensions: &Extensions,
) -> Result<Expr> {
match self.expr_name.as_str() {
- "like" => {
- Self::build_like_expr(state, false, f, input_schema,
extensions).await
- }
- "ilike" => {
- Self::build_like_expr(state, true, f, input_schema,
extensions).await
- }
+ "like" => Self::build_like_expr(consumer, false, f,
input_schema).await,
+ "ilike" => Self::build_like_expr(consumer, true, f,
input_schema).await,
"not" | "negative" | "negate" | "is_null" | "is_not_null" |
"is_true"
| "is_false" | "is_not_true" | "is_not_false" | "is_unknown"
| "is_not_unknown" => {
- Self::build_unary_expr(
- state,
- &self.expr_name,
- f,
- input_schema,
- extensions,
- )
- .await
+ Self::build_unary_expr(consumer, &self.expr_name, f,
input_schema).await
}
_ => {
not_impl_err!("Unsupported builtin expression: {}",
self.expr_name)
@@ -2913,11 +3338,10 @@ impl BuiltinExprBuilder {
}
async fn build_unary_expr(
- state: &dyn SubstraitPlanningState,
+ consumer: &impl SubstraitConsumer,
fn_name: &str,
f: &ScalarFunction,
input_schema: &DFSchema,
- extensions: &Extensions,
) -> Result<Expr> {
if f.arguments.len() != 1 {
return substrait_err!("Expect one argument for {fn_name} expr");
@@ -2925,8 +3349,7 @@ impl BuiltinExprBuilder {
let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type
else {
return substrait_err!("Invalid arguments type for {fn_name} expr");
};
- let arg =
- from_substrait_rex(state, expr_substrait, input_schema,
extensions).await?;
+ let arg = from_substrait_rex(consumer, expr_substrait,
input_schema).await?;
let arg = Box::new(arg);
let expr = match fn_name {
@@ -2947,11 +3370,10 @@ impl BuiltinExprBuilder {
}
async fn build_like_expr(
- state: &dyn SubstraitPlanningState,
+ consumer: &impl SubstraitConsumer,
case_insensitive: bool,
f: &ScalarFunction,
input_schema: &DFSchema,
- extensions: &Extensions,
) -> Result<Expr> {
let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" };
if f.arguments.len() != 2 && f.arguments.len() != 3 {
@@ -2961,14 +3383,12 @@ impl BuiltinExprBuilder {
let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type
else {
return substrait_err!("Invalid arguments type for `{fn_name}`
expr");
};
- let expr =
- from_substrait_rex(state, expr_substrait, input_schema,
extensions).await?;
+ let expr = from_substrait_rex(consumer, expr_substrait,
input_schema).await?;
let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type
else {
return substrait_err!("Invalid arguments type for `{fn_name}`
expr");
};
let pattern =
- from_substrait_rex(state, pattern_substrait, input_schema,
extensions)
- .await?;
+ from_substrait_rex(consumer, pattern_substrait,
input_schema).await?;
// Default case: escape character is Literal(Utf8(None))
let escape_char = if f.arguments.len() == 3 {
@@ -2977,13 +3397,8 @@ impl BuiltinExprBuilder {
return substrait_err!("Invalid arguments type for `{fn_name}`
expr");
};
- let escape_char_expr = from_substrait_rex(
- state,
- escape_char_substrait,
- input_schema,
- extensions,
- )
- .await?;
+ let escape_char_expr =
+ from_substrait_rex(consumer, escape_char_substrait,
input_schema).await?;
match escape_char_expr {
Expr::Literal(ScalarValue::Utf8(escape_char_string)) => {
@@ -3013,16 +3428,29 @@ impl BuiltinExprBuilder {
#[cfg(test)]
mod test {
use crate::extensions::Extensions;
- use crate::logical_plan::consumer::from_substrait_literal_without_names;
+ use crate::logical_plan::consumer::{
+ from_substrait_literal_without_names, DefaultSubstraitConsumer,
+ };
use arrow_buffer::IntervalMonthDayNano;
use datafusion::error::Result;
+ use datafusion::execution::SessionState;
+ use datafusion::prelude::SessionContext;
use datafusion::scalar::ScalarValue;
+ use std::sync::OnceLock;
use substrait::proto::expression::literal::{
interval_day_to_second, IntervalCompound, IntervalDayToSecond,
IntervalYearToMonth, LiteralType,
};
use substrait::proto::expression::Literal;
+ static TEST_SESSION_STATE: OnceLock<SessionState> = OnceLock::new();
+ static TEST_EXTENSIONS: OnceLock<Extensions> = OnceLock::new();
+ fn test_consumer() -> DefaultSubstraitConsumer<'static> {
+ let extensions = TEST_EXTENSIONS.get_or_init(Extensions::default);
+ let state = TEST_SESSION_STATE.get_or_init(||
SessionContext::default().state());
+ DefaultSubstraitConsumer::new(extensions, state)
+ }
+
#[test]
fn interval_compound_different_precision() -> Result<()> {
// DF producer (and thus roundtrip) always uses precision = 9,
@@ -3046,8 +3474,9 @@ mod test {
})),
};
+ let consumer = test_consumer();
assert_eq!(
- from_substrait_literal_without_names(&substrait,
&Extensions::default())?,
+ from_substrait_literal_without_names(&consumer, &substrait)?,
ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {
months: 14,
days: 3,
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 375cb734f5..5191a620b4 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -2211,11 +2211,11 @@ fn substrait_field_ref(index: usize) -> Result<Expression> {
#[cfg(test)]
mod test {
-
use super::*;
use crate::logical_plan::consumer::{
from_substrait_extended_expr, from_substrait_literal_without_names,
from_substrait_named_struct, from_substrait_type_without_names,
+ DefaultSubstraitConsumer,
};
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use datafusion::arrow::array::{
@@ -2224,7 +2224,17 @@ mod test {
use datafusion::arrow::datatypes::{Field, Fields, Schema};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::common::DFSchema;
- use datafusion::execution::SessionStateBuilder;
+ use datafusion::execution::{SessionState, SessionStateBuilder};
+ use datafusion::prelude::SessionContext;
+ use std::sync::OnceLock;
+
+ static TEST_SESSION_STATE: OnceLock<SessionState> = OnceLock::new();
+ static TEST_EXTENSIONS: OnceLock<Extensions> = OnceLock::new();
+ fn test_consumer() -> DefaultSubstraitConsumer<'static> {
+ let extensions = TEST_EXTENSIONS.get_or_init(Extensions::default);
+ let state = TEST_SESSION_STATE.get_or_init(||
SessionContext::default().state());
+ DefaultSubstraitConsumer::new(extensions, state)
+ }
#[test]
fn round_trip_literals() -> Result<()> {
@@ -2350,7 +2360,7 @@ mod test {
let mut extensions = Extensions::default();
let substrait_literal = to_substrait_literal(&scalar, &mut
extensions)?;
let roundtrip_scalar =
- from_substrait_literal_without_names(&substrait_literal,
&extensions)?;
+ from_substrait_literal_without_names(&test_consumer(),
&substrait_literal)?;
assert_eq!(scalar, roundtrip_scalar);
Ok(())
}
@@ -2429,8 +2439,8 @@ mod test {
// As DataFusion doesn't consider nullability as a property of the
type, but field,
// it doesn't matter if we set nullability to true or false here.
let substrait = to_substrait_type(&dt, true)?;
- let roundtrip_dt =
- from_substrait_type_without_names(&substrait,
&Extensions::default())?;
+ let consumer = test_consumer();
+ let roundtrip_dt = from_substrait_type_without_names(&consumer,
&substrait)?;
assert_eq!(dt, roundtrip_dt);
Ok(())
}
@@ -2481,7 +2491,7 @@ mod test {
);
let roundtrip_schema =
- from_substrait_named_struct(&named_struct,
&Extensions::default())?;
+ from_substrait_named_struct(&test_consumer(), &named_struct)?;
assert_eq!(schema.as_ref(), &roundtrip_schema);
Ok(())
}
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 1291bbd6a2..1ce0eec1b2 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -30,6 +30,7 @@ use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef};
use datafusion::error::Result;
use datafusion::execution::registry::SerializerRegistry;
use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::logical_expr::{
Extension, LogicalPlan, PartitionEvaluator, Repartition,
UserDefinedLogicalNode,
Values, Volatility,
@@ -38,8 +39,6 @@ use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLI
use datafusion::prelude::*;
use std::hash::Hash;
use std::sync::Arc;
-
-use datafusion::execution::session_state::SessionStateBuilder;
use substrait::proto::extensions::simple_extension_declaration::MappingType;
use substrait::proto::rel::RelType;
use substrait::proto::{plan_rel, Plan, Rel};
diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs
index 00cbfb0c41..b9e5e0e525 100644
--- a/datafusion/substrait/tests/utils.rs
+++ b/datafusion/substrait/tests/utils.rs
@@ -24,7 +24,9 @@ pub mod test {
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_substrait::extensions::Extensions;
- use
datafusion_substrait::logical_plan::consumer::from_substrait_named_struct;
+ use datafusion_substrait::logical_plan::consumer::{
+ from_substrait_named_struct, DefaultSubstraitConsumer,
SubstraitConsumer,
+ };
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
@@ -50,7 +52,18 @@ pub mod test {
ctx: SessionContext,
plan: &Plan,
) -> Result<SessionContext> {
- let schemas = TestSchemaCollector::collect_schemas(plan)?;
+ let extensions = Extensions::default();
+ let state = ctx.state();
+ let consumer = DefaultSubstraitConsumer::new(&extensions, &state);
+ add_plan_schemas_to_ctx_with_consumer(&consumer, ctx, plan)
+ }
+
+ fn add_plan_schemas_to_ctx_with_consumer(
+ consumer: &impl SubstraitConsumer,
+ ctx: SessionContext,
+ plan: &Plan,
+ ) -> Result<SessionContext> {
+ let schemas = TestSchemaCollector::collect_schemas(consumer, plan)?;
let mut schema_map: HashMap<TableReference, Arc<dyn TableProvider>> =
HashMap::new();
for (table_reference, table) in schemas.into_iter() {
@@ -71,21 +84,24 @@ pub mod test {
Ok(ctx)
}
- pub struct TestSchemaCollector {
+ pub struct TestSchemaCollector<'a, T: SubstraitConsumer> {
+ consumer: &'a T,
schemas: Vec<(TableReference, Arc<dyn TableProvider>)>,
}
- impl TestSchemaCollector {
- fn new() -> Self {
+ impl<'a, T: SubstraitConsumer> TestSchemaCollector<'a, T> {
+ fn new(consumer: &'a T) -> Self {
TestSchemaCollector {
schemas: Vec::new(),
+ consumer,
}
}
fn collect_schemas(
+ consumer: &'a T,
plan: &Plan,
) -> Result<Vec<(TableReference, Arc<dyn TableProvider>)>> {
- let mut schema_collector = Self::new();
+ let mut schema_collector = Self::new(consumer);
for plan_rel in plan.relations.iter() {
let rel_type = plan_rel
@@ -132,15 +148,8 @@ pub mod test {
"No base schema found for NamedTable: {}",
table_reference
))?;
- let empty_extensions = Extensions {
- functions: Default::default(),
- types: Default::default(),
- type_variations: Default::default(),
- };
-
- let df_schema =
- from_substrait_named_struct(substrait_schema,
&empty_extensions)?
- .replace_qualifier(table_reference.clone());
+ let df_schema = from_substrait_named_struct(self.consumer,
substrait_schema)?
+ .replace_qualifier(table_reference.clone());
let table = EmptyTable::new(df_schema.inner().clone());
self.schemas.push((table_reference, Arc::new(table)));
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]