This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 0d27fcb04 Add datafusion-substrait crate (#4543)
0d27fcb04 is described below
commit 0d27fcb04b71693adf570346507ca6282f8cba71
Author: Andy Grove <[email protected]>
AuthorDate: Thu Jan 12 15:18:23 2023 -0700
Add datafusion-substrait crate (#4543)
* Initial commit
* initial commit
* failing test
* table scan projection
* closer
* test passes, with some hacks
* use DataFrame (#2)
* update README
* update dependency
* code cleanup (#3)
* Add support for Filter operator and BinaryOp expressions (#4)
* GitHub action (#5)
* Split code into producer and consumer modules (#6)
* Support more functions and scalar types (#7)
* Use substrait 0.1 and datafusion 8.0 (#8)
* use substrait 0.1
* use datafusion 8.0
* update datafusion to 10.0 and substrait to 0.2 (#11)
* Add basic join support (#12)
* Added fetch support (#23)
Added fetch to consumer
Added limit to producer
Added unit tests for limit
Added roundtrip_fill_none() for testing when None input can be converted to 0
Update src/consumer.rs
Co-authored-by: Andy Grove <[email protected]>
Co-authored-by: Andy Grove <[email protected]>
* Upgrade to DataFusion 13.0.0 (#25)
* Add sort consumer and producer (#24)
Add consumer
Add producer and test
Modified error string
* Add serializer/deserializer (#26)
* Add plan and function extension support (#27)
* Add plan and function extension support
* Removed unwraps
* Implement GROUP BY (#28)
* Add consumer, producer and tests for aggregate relation
Change function extension registration from absolute to relative anchor
(reference)
Remove operator to/from reference
* Fixed function registration bug
* Add test
* Addressed PR comments
* Changed field reference from mask to direct reference (#29)
* Changed field reference from masked reference to direct reference
* Handle unsupported case (struct with child)
* Handle SubqueryAlias (#30)
Fixed aggregate function register bug
* Add support for SELECT DISTINCT (#31)
Add test case
* Implement BETWEEN (#32)
* Add case (#33)
* Implement CASE WHEN
* Add more case to test
* Addressed comments
* feat: support explicit catalog/schema names in ReadRel (#34)
* feat: support explicit catalog/schema names in ReadRel
Signed-off-by: Ruihang Xia <[email protected]>
* fix: use re-exported expr crate
Signed-off-by: Ruihang Xia <[email protected]>
Signed-off-by: Ruihang Xia <[email protected]>
* move files to subfolder
* RAT
* remove rust.yaml
* revert .gitignore changes
* tomlfmt
* tomlfmt
Signed-off-by: Ruihang Xia <[email protected]>
Co-authored-by: Daniël Heres <[email protected]>
Co-authored-by: JanKaul <[email protected]>
Co-authored-by: nseekhao <[email protected]>
Co-authored-by: Ruihang Xia <[email protected]>
---
datafusion/substrait/Cargo.toml | 32 ++
datafusion/substrait/README.md | 34 ++
datafusion/substrait/src/consumer.rs | 544 ++++++++++++++++++++++++++
datafusion/substrait/src/lib.rs | 20 +
datafusion/substrait/src/producer.rs | 557 +++++++++++++++++++++++++++
datafusion/substrait/src/serializer.rs | 57 +++
datafusion/substrait/tests/roundtrip.rs | 273 +++++++++++++
datafusion/substrait/tests/serialize.rs | 62 +++
datafusion/substrait/tests/testdata/data.csv | 3 +
9 files changed, 1582 insertions(+)
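
Note on the consumer's function resolution: in this commit, `from_substrait_plan` collects each extension declaration's function anchor and name into a map, and `name_to_op` then translates the name into a DataFusion `Operator`. A minimal, dependency-free sketch of that two-step lookup follows, for orientation only. The `Op` enum and the `resolve_operator` helper are hypothetical stand-ins introduced for this illustration (the crate uses `datafusion::logical_plan::Operator` and performs the lookup inline), and only a handful of function names are covered.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for DataFusion's `Operator`; only a few variants are modelled.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Op {
    Eq,
    NotEq,
    Lt,
    And,
}

/// Mirrors the shape of `name_to_op` in consumer.rs for a small subset of names.
fn name_to_op(name: &str) -> Result<Op, String> {
    match name {
        "equal" => Ok(Op::Eq),
        "not_equal" => Ok(Op::NotEq),
        "lt" => Ok(Op::Lt),
        "and" => Ok(Op::And),
        other => Err(format!("Unsupported function name: {:?}", other)),
    }
}

/// Hypothetical helper: resolve the function anchor to a name first,
/// then map that name to an operator.
fn resolve_operator(extensions: &HashMap<u32, String>, anchor: u32) -> Result<Op, String> {
    let name = extensions
        .get(&anchor)
        .ok_or_else(|| format!("Function not found: function reference = {:?}", anchor))?;
    name_to_op(name)
}
```

With anchor 0 registered as "equal", `resolve_operator(&extensions, 0)` yields `Ok(Op::Eq)`; an unregistered anchor or an unknown function name yields an `Err`.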
diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml
new file mode 100644
index 000000000..b8c0e56d2
--- /dev/null
+++ b/datafusion/substrait/Cargo.toml
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "datafusion-substrait"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+async-recursion = "1.0"
+datafusion = "13.0"
+prost = "0.9"
+prost-types = "0.9"
+substrait = "0.2"
+tokio = "1.17"
+
+[build-dependencies]
+prost-build = { version = "0.9" }
diff --git a/datafusion/substrait/README.md b/datafusion/substrait/README.md
new file mode 100644
index 000000000..9f21d514a
--- /dev/null
+++ b/datafusion/substrait/README.md
@@ -0,0 +1,34 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# DataFusion + Substrait
+
+[Substrait](https://substrait.io/) provides a cross-language serialization format for relational algebra, based on
+protocol buffers.
+
+This repository provides a Substrait producer and consumer for DataFusion:
+
+- The producer converts a DataFusion logical plan into a Substrait protobuf.
+- The consumer converts a Substrait protobuf into a DataFusion logical plan.
+
+Potential uses of this crate:
+
+- Replace the current [DataFusion protobuf definition](https://github.com/apache/arrow-datafusion/blob/master/datafusion-proto/proto/datafusion.proto) used in Ballista for passing query plan fragments to executors
+- Make it easier to pass query plans over FFI boundaries, such as from Python to Rust
+- Allow Apache Calcite query plans to be executed in DataFusion
diff --git a/datafusion/substrait/src/consumer.rs b/datafusion/substrait/src/consumer.rs
new file mode 100644
index 000000000..c747a30a6
--- /dev/null
+++ b/datafusion/substrait/src/consumer.rs
@@ -0,0 +1,544 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use async_recursion::async_recursion;
+use datafusion::common::{DFField, DFSchema, DFSchemaRef};
+use datafusion::logical_expr::{LogicalPlan, aggregate_function};
+use datafusion::logical_plan::build_join_schema;
+use datafusion::prelude::JoinType;
+use datafusion::{
+ error::{DataFusionError, Result},
+ logical_plan::{Expr, Operator},
+ optimizer::utils::split_conjunction,
+ prelude::{Column, DataFrame, SessionContext},
+ scalar::ScalarValue,
+};
+
+use datafusion::sql::TableReference;
+use substrait::protobuf::{
+ aggregate_function::AggregationInvocation,
+ expression::{
+ field_reference::ReferenceType::DirectReference,
+ literal::LiteralType,
+ MaskExpression,
+ reference_segment::ReferenceType::StructField,
+ RexType,
+ },
+ extensions::simple_extension_declaration::MappingType,
+ function_argument::ArgType,
+ read_rel::ReadType,
+ rel::RelType,
+ sort_field::{SortKind::*, SortDirection},
+ AggregateFunction, Expression, Plan, Rel,
+};
+
+use std::collections::HashMap;
+use std::str::FromStr;
+use std::sync::Arc;
+
+pub fn name_to_op(name: &str) -> Result<Operator> {
+ match name {
+ "equal" => Ok(Operator::Eq),
+ "not_equal" => Ok(Operator::NotEq),
+ "lt" => Ok(Operator::Lt),
+ "lte" => Ok(Operator::LtEq),
+ "gt" => Ok(Operator::Gt),
+ "gte" => Ok(Operator::GtEq),
+ "add" => Ok(Operator::Plus),
+ "subtract" => Ok(Operator::Minus),
+ "multiply" => Ok(Operator::Multiply),
+ "divide" => Ok(Operator::Divide),
+ "mod" => Ok(Operator::Modulo),
+ "and" => Ok(Operator::And),
+ "or" => Ok(Operator::Or),
+ "like" => Ok(Operator::Like),
+ "not_like" => Ok(Operator::NotLike),
+ "is_distinct_from" => Ok(Operator::IsDistinctFrom),
+ "is_not_distinct_from" => Ok(Operator::IsNotDistinctFrom),
+ "regex_match" => Ok(Operator::RegexMatch),
+ "regex_imatch" => Ok(Operator::RegexIMatch),
+ "regex_not_match" => Ok(Operator::RegexNotMatch),
+ "regex_not_imatch" => Ok(Operator::RegexNotIMatch),
+ "bitwise_and" => Ok(Operator::BitwiseAnd),
+ "bitwise_or" => Ok(Operator::BitwiseOr),
+ "str_concat" => Ok(Operator::StringConcat),
+ "bitwise_xor" => Ok(Operator::BitwiseXor),
+ "bitwise_shift_right" => Ok(Operator::BitwiseShiftRight),
+ "bitwise_shift_left" => Ok(Operator::BitwiseShiftLeft),
+ _ => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported function name: {:?}",
+ name
+ ))),
+ }
+}
+
+/// Convert Substrait Plan to DataFusion DataFrame
+pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> Result<Arc<DataFrame>> {
+ // Register function extension
+ let function_extension = plan.extensions
+ .iter()
+ .map(|e| match &e.mapping_type {
+ Some(ext) => {
+ match ext {
+ MappingType::ExtensionFunction(ext_f) => Ok((ext_f.function_anchor, &ext_f.name)),
+ _ => Err(DataFusionError::NotImplemented(format!("Extension type not supported: {:?}", ext)))
+ }
+ }
+ None => Err(DataFusionError::NotImplemented("Cannot parse empty extension".to_string()))
+ })
+ .collect::<Result<HashMap<_, _>>>()?;
+ // Parse relations
+ match plan.relations.len() {
+ 1 => {
+ match plan.relations[0].rel_type.as_ref() {
+ Some(rt) => match rt {
+ substrait::protobuf::plan_rel::RelType::Rel(rel) => {
+ Ok(from_substrait_rel(ctx, &rel, &function_extension).await?)
+ },
+ substrait::protobuf::plan_rel::RelType::Root(root) => {
+ Ok(from_substrait_rel(ctx, &root.input.as_ref().unwrap(), &function_extension).await?)
+ }
+ },
+ None => Err(DataFusionError::Internal("Cannot parse plan relation: None".to_string()))
+ }
+
+ },
+ _ => Err(DataFusionError::NotImplemented(format!(
+ "Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}",
+ plan.relations.len()
+ )))
+ }
+}
+
+/// Convert Substrait Rel to DataFusion DataFrame
+#[async_recursion]
+pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, extensions: &HashMap<u32, &String>) -> Result<Arc<DataFrame>> {
+ match &rel.rel_type {
+ Some(RelType::Project(p)) => {
+ if let Some(input) = p.input.as_ref() {
+ let input = from_substrait_rel(ctx, input, extensions).await?;
+ let mut exprs: Vec<Expr> = vec![];
+ for e in &p.expressions {
+ let x = from_substrait_rex(e, &input.schema(), extensions).await?;
+ exprs.push(x.as_ref().clone());
+ }
+ input.select(exprs)
+ } else {
+ Err(DataFusionError::NotImplemented(
+ "Projection without an input is not supported".to_string(),
+ ))
+ }
+ }
+ Some(RelType::Filter(filter)) => {
+ if let Some(input) = filter.input.as_ref() {
+ let input = from_substrait_rel(ctx, input, extensions).await?;
+ if let Some(condition) = filter.condition.as_ref() {
+ let expr = from_substrait_rex(condition, &input.schema(), extensions).await?;
+ input.filter(expr.as_ref().clone())
+ } else {
+ Err(DataFusionError::NotImplemented(
+ "Filter without a condition is not valid".to_string(),
+ ))
+ }
+ } else {
+ Err(DataFusionError::NotImplemented(
+ "Filter without an input is not valid".to_string(),
+ ))
+ }
+ }
+ Some(RelType::Fetch(fetch)) => {
+ if let Some(input) = fetch.input.as_ref() {
+ let input = from_substrait_rel(ctx, input, extensions).await?;
+ let offset = fetch.offset as usize;
+ let count = fetch.count as usize;
+ input.limit(offset, Some(count))
+ } else {
+ Err(DataFusionError::NotImplemented(
+ "Fetch without an input is not valid".to_string(),
+ ))
+ }
+ }
+ Some(RelType::Sort(sort)) => {
+ if let Some(input) = sort.input.as_ref() {
+ let input = from_substrait_rel(ctx, input, extensions).await?;
+ let mut sorts: Vec<Expr> = vec![];
+ for s in &sort.sorts {
+ let expr = from_substrait_rex(&s.expr.as_ref().unwrap(), &input.schema(), extensions).await?;
+ let asc_nullfirst = match &s.sort_kind {
+ Some(k) => match k {
+ Direction(d) => {
+ let direction : SortDirection = unsafe {
+ ::std::mem::transmute(*d)
+ };
+ match direction {
+ SortDirection::AscNullsFirst => Ok((true, true)),
+ SortDirection::AscNullsLast => Ok((true, false)),
+ SortDirection::DescNullsFirst => Ok((false, true)),
+ SortDirection::DescNullsLast => Ok((false, false)),
+ SortDirection::Clustered => {
+ Err(DataFusionError::NotImplemented(
+ "Sort with direction clustered is not yet supported".to_string(),
+ ))
+ },
+ SortDirection::Unspecified => {
+ Err(DataFusionError::NotImplemented(
+ "Unspecified sort direction is invalid".to_string(),
+ ))
+ }
+ }
+ }
+ ComparisonFunctionReference(_) => {
+ Err(DataFusionError::NotImplemented(
+ "Sort using comparison function reference is not supported".to_string(),
+ ))
+ },
+ },
+ None => {
+ Err(DataFusionError::NotImplemented(
+ "Sort without sort kind is invalid".to_string(),
+ ))
+ },
+ };
+ let (asc, nulls_first) = asc_nullfirst.unwrap();
+ sorts.push(Expr::Sort { expr: Box::new(expr.as_ref().clone()), asc: asc, nulls_first: nulls_first });
+ }
+ input.sort(sorts)
+ } else {
+ Err(DataFusionError::NotImplemented(
+ "Sort without an input is not valid".to_string(),
+ ))
+ }
+ }
+ Some(RelType::Aggregate(agg)) => {
+ if let Some(input) = agg.input.as_ref() {
+ let input = from_substrait_rel(ctx, input, extensions).await?;
+ let mut group_expr = vec![];
+ let mut aggr_expr = vec![];
+
+ let groupings = match agg.groupings.len() {
+ 1 => { Ok(&agg.groupings[0]) },
+ _ => {
+ Err(DataFusionError::NotImplemented(
+ "Aggregate with multiple grouping sets is not supported".to_string(),
+ ))
+ }
+ };
+
+ for e in &groupings?.grouping_expressions {
+ let x = from_substrait_rex(&e, &input.schema(), extensions).await?;
+ group_expr.push(x.as_ref().clone());
+ }
+
+ for m in &agg.measures {
+ let filter = match &m.filter {
+ Some(fil) => Some(Box::new(from_substrait_rex(fil, &input.schema(), extensions).await?.as_ref().clone())),
+ None => None
+ };
+ let agg_func = match &m.measure {
+ Some(f) => {
+ let distinct = match f.invocation {
+ _ if f.invocation == AggregationInvocation::Distinct as i32 => true,
+ _ if f.invocation == AggregationInvocation::All as i32 => false,
+ _ => false
+ };
+ from_substrait_agg_func(&f, &input.schema(), extensions, filter, distinct).await
+ },
+ None => Err(DataFusionError::NotImplemented(
+ "Aggregate without aggregate function is not supported".to_string(),
+ )),
+ };
+ aggr_expr.push(agg_func?.as_ref().clone());
+ }
+
+ input.aggregate(group_expr, aggr_expr)
+ } else {
+ Err(DataFusionError::NotImplemented(
+ "Aggregate without an input is not valid".to_string(),
+ ))
+ }
+ }
+ Some(RelType::Join(join)) => {
+ let left = from_substrait_rel(ctx, &join.left.as_ref().unwrap(), extensions).await?;
+ let right = from_substrait_rel(ctx, &join.right.as_ref().unwrap(), extensions).await?;
+ let join_type = match join.r#type {
+ 1 => JoinType::Inner,
+ 2 => JoinType::Left,
+ 3 => JoinType::Right,
+ 4 => JoinType::Full,
+ 5 => JoinType::Anti,
+ 6 => JoinType::Semi,
+ _ => return Err(DataFusionError::Internal("invalid join type".to_string())),
+ };
+ let mut predicates = vec![];
+ let schema = build_join_schema(&left.schema(), &right.schema(), &JoinType::Inner)?;
+ let on = from_substrait_rex(&join.expression.as_ref().unwrap(), &schema, extensions).await?;
+ split_conjunction(&on, &mut predicates);
+ let pairs = predicates
+ .iter()
+ .map(|p| match p {
+ Expr::BinaryExpr {
+ left,
+ op: Operator::Eq,
+ right,
+ } => match (left.as_ref(), right.as_ref()) {
+ (Expr::Column(l), Expr::Column(r)) => Ok((l.flat_name(), r.flat_name())),
+ _ => {
+ return Err(DataFusionError::Internal(
+ "invalid join condition".to_string(),
+ ))
+ }
+ },
+ _ => {
+ return Err(DataFusionError::Internal(
+ "invalid join condition".to_string(),
+ ))
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let left_cols: Vec<&str> = pairs.iter().map(|(l, _)| l.as_str()).collect();
+ let right_cols: Vec<&str> = pairs.iter().map(|(_, r)| r.as_str()).collect();
+ left.join(right, join_type, &left_cols, &right_cols, None)
+ }
+ Some(RelType::Read(read)) => match &read.as_ref().read_type {
+ Some(ReadType::NamedTable(nt)) => {
+ let table_reference = match nt.names.len() {
+ 0 => {
+ return Err(DataFusionError::Internal(
+ "No table name found in NamedTable".to_string(),
+ ));
+ }
+ 1 => TableReference::Bare {
+ table: &nt.names[0],
+ },
+ 2 => TableReference::Partial {
+ schema: &nt.names[0],
+ table: &nt.names[1],
+ },
+ _ => TableReference::Full {
+ catalog: &nt.names[0],
+ schema: &nt.names[1],
+ table: &nt.names[2],
+ },
+ };
+ let t = ctx.table(table_reference)?;
+ match &read.projection {
+ Some(MaskExpression { select, .. }) => match &select.as_ref() {
+ Some(projection) => {
+ let column_indices: Vec<usize> = projection
+ .struct_items
+ .iter()
+ .map(|item| item.field as usize)
+ .collect();
+ match t.to_logical_plan()? {
+ LogicalPlan::TableScan(scan) => {
+ let mut scan = scan.clone();
+ let fields: Vec<DFField> = column_indices
+ .iter()
+ .map(|i| scan.projected_schema.field(*i).clone())
+ .collect();
+ scan.projection = Some(column_indices);
+ scan.projected_schema = DFSchemaRef::new(
+ DFSchema::new_with_metadata(fields, HashMap::new())?,
+ );
+ let plan = LogicalPlan::TableScan(scan);
+ Ok(Arc::new(DataFrame::new(ctx.state.clone(), &plan)))
+ }
+ _ => Err(DataFusionError::Internal(
+ "unexpected plan for table".to_string(),
+ )),
+ }
+ }
+ _ => Ok(t),
+ },
+ _ => Ok(t),
+ }
+ }
+ _ => Err(DataFusionError::NotImplemented(
+ "Only NamedTable reads are supported".to_string(),
+ )),
+ },
+ _ => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported RelType: {:?}",
+ rel.rel_type
+ ))),
+ }
+}
+
+/// Convert Substrait AggregateFunction to DataFusion Expr
+pub async fn from_substrait_agg_func(
+ f: &AggregateFunction,
+ input_schema: &DFSchema,
+ extensions: &HashMap<u32, &String>,
+ filter: Option<Box<Expr>>,
+ distinct: bool
+) -> Result<Arc<Expr>> {
+ let mut args: Vec<Expr> = vec![];
+ for arg in &f.arguments {
+ let arg_expr = match &arg.arg_type {
+ Some(ArgType::Value(e)) => from_substrait_rex(e, input_schema, extensions).await,
+ _ => Err(DataFusionError::NotImplemented(
+ "Aggregated function argument non-Value type not supported".to_string(),
+ ))
+ };
+ args.push(arg_expr?.as_ref().clone());
+ }
+
+ let fun = match extensions.get(&f.function_reference) {
+ Some(function_name) => aggregate_function::AggregateFunction::from_str(function_name),
+ None => Err(DataFusionError::NotImplemented(format!(
+ "Aggregated function not found: function anchor = {:?}",
+ f.function_reference
+ )
+ ))
+ };
+
+ Ok(
+ Arc::new(
+ Expr::AggregateFunction {
+ fun: fun.unwrap(),
+ args: args,
+ distinct: distinct,
+ filter: filter
+ }
+ )
+ )
+}
+
+/// Convert Substrait Rex to DataFusion Expr
+#[async_recursion]
+pub async fn from_substrait_rex(e: &Expression, input_schema: &DFSchema, extensions: &HashMap<u32, &String>) -> Result<Arc<Expr>> {
+ match &e.rex_type {
+ Some(RexType::Selection(field_ref)) => match &field_ref.reference_type {
+ Some(DirectReference(direct)) => match &direct.reference_type.as_ref() {
+ Some(StructField(x)) => match &x.child.as_ref() {
+ Some(_) => Err(DataFusionError::NotImplemented(
+ "Direct reference StructField with child is not supported".to_string(),
+ )),
+ None => Ok(Arc::new(Expr::Column(Column {
+ relation: None,
+ name: input_schema
+ .field(x.field as usize)
+ .name()
+ .to_string(),
+ }))),
+ },
+ _ => Err(DataFusionError::NotImplemented(
+ "Direct reference with types other than StructField is not supported".to_string(),
+ )),
+ },
+ _ => Err(DataFusionError::NotImplemented(
+ "unsupported field ref type".to_string(),
+ )),
+ },
+ Some(RexType::IfThen(if_then)) => {
+ // Parse `ifs`
+ // If the first element does not have a `then` part, then we can assume it's a base expression
+ let mut when_then_expr: Vec<(Box<Expr>, Box<Expr>)> = vec![];
+ let mut expr = None;
+ for (i, if_expr) in if_then.ifs.iter().enumerate() {
+ if i == 0 {
+ // Check if the first element is type base expression
+ if if_expr.then.is_none() {
+ expr = Some(Box::new(from_substrait_rex(&if_expr.r#if.as_ref().unwrap(), input_schema, extensions).await?.as_ref().clone()));
+ continue;
+ }
+ }
+ when_then_expr.push(
+ (
+ Box::new(from_substrait_rex(&if_expr.r#if.as_ref().unwrap(), input_schema, extensions).await?.as_ref().clone()),
+ Box::new(from_substrait_rex(&if_expr.then.as_ref().unwrap(), input_schema, extensions).await?.as_ref().clone())
+ ),
+ );
+ }
+ // Parse `else`
+ let else_expr = match &if_then.r#else {
+ Some(e) => Some(Box::new(
+ from_substrait_rex(&e, input_schema, extensions).await?.as_ref().clone(),
+ )),
+ None => None
+ };
+ Ok(Arc::new(Expr::Case { expr: expr, when_then_expr: when_then_expr, else_expr: else_expr }))
+ },
+ Some(RexType::ScalarFunction(f)) => {
+ assert!(f.arguments.len() == 2);
+ let op = match extensions.get(&f.function_reference) {
+ Some(fname) => name_to_op(fname),
+ None => Err(DataFusionError::NotImplemented(format!(
+ "Aggregated function not found: function reference = {:?}",
+ f.function_reference
+ )
+ ))
+ };
+ match (&f.arguments[0].arg_type, &f.arguments[1].arg_type) {
+ (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
+ Ok(Arc::new(Expr::BinaryExpr {
+ left: Box::new(from_substrait_rex(l, input_schema, extensions).await?.as_ref().clone()),
+ op: op?,
+ right: Box::new(
+ from_substrait_rex(r, input_schema, extensions).await?.as_ref().clone(),
+ ),
+ }))
+ }
+ (l, r) => Err(DataFusionError::NotImplemented(format!(
+ "Invalid arguments for binary expression: {:?} and {:?}",
+ l, r
+ ))),
+ }
+ }
+ Some(RexType::Literal(lit)) => match &lit.literal_type {
+ Some(LiteralType::I8(n)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8)))))
+ }
+ Some(LiteralType::I16(n)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16)))))
+ }
+ Some(LiteralType::I32(n)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n as i32)))))
+ }
+ Some(LiteralType::I64(n)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n as i64)))))
+ }
+ Some(LiteralType::Boolean(b)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b)))))
+ }
+ Some(LiteralType::Date(d)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d)))))
+ }
+ Some(LiteralType::Fp32(f)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f)))))
+ }
+ Some(LiteralType::Fp64(f)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f)))))
+ }
+ Some(LiteralType::String(s)) => Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(
+ Some(s.clone()),
+ )))),
+ Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(ScalarValue::Binary(Some(
+ b.clone(),
+ ))))),
+ _ => {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Unsupported literal_type: {:?}",
+ lit.literal_type
+ )))
+ }
+ },
+ _ => Err(DataFusionError::NotImplemented(
+ "unsupported rex_type".to_string(),
+ )),
+ }
+}
diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs
new file mode 100644
index 000000000..07b8e7add
--- /dev/null
+++ b/datafusion/substrait/src/lib.rs
@@ -0,0 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+pub mod consumer;
+pub mod producer;
+pub mod serializer;
diff --git a/datafusion/substrait/src/producer.rs b/datafusion/substrait/src/producer.rs
new file mode 100644
index 000000000..78532046b
--- /dev/null
+++ b/datafusion/substrait/src/producer.rs
@@ -0,0 +1,557 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+
+use datafusion::{
+ error::{DataFusionError, Result},
+ logical_plan::{DFSchemaRef, Expr, JoinConstraint, LogicalPlan, Operator},
+ prelude::JoinType,
+ scalar::ScalarValue,
+};
+
+use substrait::protobuf::{
+ aggregate_function::AggregationInvocation,
+ aggregate_rel::{Grouping, Measure},
+ expression::{
+ field_reference::ReferenceType,
+ if_then::IfClause,
+ literal::LiteralType,
+ mask_expression::{StructItem, StructSelect},
+ reference_segment,
+ FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, ScalarFunction,
+ },
+ extensions::{self, simple_extension_declaration::{MappingType, ExtensionFunction}},
+ function_argument::ArgType,
+ plan_rel,
+ read_rel::{NamedTable, ReadType},
+ rel::RelType,
+ sort_field::{
+ SortDirection,
+ SortKind,
+ },
+ AggregateRel, Expression, FetchRel, FilterRel, FunctionArgument, JoinRel, NamedStruct, ProjectRel, ReadRel, SortField, SortRel,
+ PlanRel,
+ Plan, Rel, RelRoot, AggregateFunction,
+};
+
+/// Convert DataFusion LogicalPlan to Substrait Plan
+pub fn to_substrait_plan(plan: &LogicalPlan) -> Result<Box<Plan>> {
+ // Parse relation nodes
+ let mut extension_info: (Vec<extensions::SimpleExtensionDeclaration>, HashMap<String, u32>) = (vec![], HashMap::new());
+ // Generate PlanRel(s)
+ // Note: Only 1 relation tree is currently supported
+ let plan_rels = vec![PlanRel {
+ rel_type: Some(plan_rel::RelType::Root(
+ RelRoot {
+ input: Some(*to_substrait_rel(plan, &mut extension_info)?),
+ names: plan.schema().field_names(),
+ }
+ ))
+ }];
+
+ let (function_extensions, _) = extension_info;
+
+ // Return parsed plan
+ Ok(Box::new(Plan {
+ extension_uris: vec![],
+ extensions: function_extensions,
+ relations: plan_rels,
+ advanced_extensions: None,
+ expected_type_urls: vec![],
+ }))
+
+}
+
+/// Convert DataFusion LogicalPlan to Substrait Rel
+pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut (Vec<extensions::SimpleExtensionDeclaration>, HashMap<String, u32>)) -> Result<Box<Rel>> {
+ match plan {
+ LogicalPlan::TableScan(scan) => {
+ let projection = scan.projection.as_ref().map(|p| {
+ p.iter()
+ .map(|i| StructItem {
+ field: *i as i32,
+ child: None,
+ })
+ .collect()
+ });
+
+ if let Some(struct_items) = projection {
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Read(Box::new(ReadRel {
+ common: None,
+ base_schema: Some(NamedStruct {
+ names: scan
+ .projected_schema
+ .fields()
+ .iter()
+ .map(|f| f.name().to_owned())
+ .collect(),
+ r#struct: None,
+ }),
+ filter: None,
+ projection: Some(MaskExpression {
+ select: Some(StructSelect { struct_items }),
+ maintain_singular_struct: false,
+ }),
+ advanced_extension: None,
+ read_type: Some(ReadType::NamedTable(NamedTable {
+ names: vec![scan.table_name.clone()],
+ advanced_extension: None,
+ })),
+ }))),
+ }))
+ } else {
+ Err(DataFusionError::NotImplemented(
+ "TableScan without projection is not supported".to_string(),
+ ))
+ }
+ }
+ LogicalPlan::Projection(p) => {
+ let expressions = p
+ .expr
+ .iter()
+ .map(|e| to_substrait_rex(e, p.input.schema(), extension_info))
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Project(Box::new(ProjectRel {
+ common: None,
+ input: Some(to_substrait_rel(p.input.as_ref(), extension_info)?),
+ expressions,
+ advanced_extension: None,
+ }))),
+ }))
+ }
+ LogicalPlan::Filter(filter) => {
+ let input = to_substrait_rel(filter.input.as_ref(), extension_info)?;
+ let filter_expr = to_substrait_rex(&filter.predicate, filter.input.schema(), extension_info)?;
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Filter(Box::new(FilterRel {
+ common: None,
+ input: Some(input),
+ condition: Some(Box::new(filter_expr)),
+ advanced_extension: None,
+ }))),
+ }))
+ }
+ LogicalPlan::Limit(limit) => {
+ let input = to_substrait_rel(limit.input.as_ref(), extension_info)?;
+ let limit_fetch = match limit.fetch {
+ Some(count) => count,
+ None => 0,
+ };
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Fetch(Box::new(FetchRel {
+ common: None,
+ input: Some(input),
+ offset: limit.skip as i64,
+ count: limit_fetch as i64,
+ advanced_extension: None,
+ }))),
+ }))
+ }
+ LogicalPlan::Sort(sort) => {
+ let input = to_substrait_rel(sort.input.as_ref(), extension_info)?;
+ let sort_fields = sort
+ .expr
+ .iter()
+ .map(|e| substrait_sort_field(e, sort.input.schema(), extension_info))
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Sort(Box::new(SortRel {
+ common: None,
+ input: Some(input),
+ sorts: sort_fields,
+ advanced_extension: None,
+ }))),
+ }))
+ }
+ LogicalPlan::Aggregate(agg) => {
+ let input = to_substrait_rel(agg.input.as_ref(), extension_info)?;
+ // Translate aggregate expression to Substrait's groupings (repeated repeated Expression)
+ let grouping = agg
+ .group_expr
+ .iter()
+ .map(|e| to_substrait_rex(e, agg.input.schema(), extension_info))
+ .collect::<Result<Vec<_>>>()?;
+ let measures = agg
+ .aggr_expr
+ .iter()
+ .map(|e| to_substrait_agg_measure(e, agg.input.schema(), extension_info))
+ .collect::<Result<Vec<_>>>()?;
+
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
+ common: None,
+ input: Some(input),
+ groupings: vec![Grouping { grouping_expressions: grouping }], //groupings,
+ measures: measures,
+ advanced_extension: None,
+ }))),
+ }))
+ }
+ LogicalPlan::Distinct(distinct) => {
+ // Use Substrait's AggregateRel with empty measures to represent `select distinct`
+ let input = to_substrait_rel(distinct.input.as_ref(), extension_info)?;
+ // Get grouping keys from the input relation's number of output fields
+ let grouping = (0..distinct.input.schema().fields().len())
+ .map(|x: usize| substrait_field_ref(x))
+ .collect::<Result<Vec<_>>>()?;
+
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
+ common: None,
+ input: Some(input),
+ groupings: vec![Grouping { grouping_expressions: grouping
}],
+ measures: vec![],
+ advanced_extension: None,
+ }))),
+ }))
+ }
+ LogicalPlan::Join(join) => {
+ let left = to_substrait_rel(join.left.as_ref(), extension_info)?;
+ let right = to_substrait_rel(join.right.as_ref(), extension_info)?;
+ let join_type = match join.join_type {
+ JoinType::Inner => 1,
+ JoinType::Left => 2,
+ JoinType::Right => 3,
+ JoinType::Full => 4,
+ JoinType::Anti => 5,
+ JoinType::Semi => 6,
+ };
+ // we only support basic joins so return an error for anything not yet supported
+ if join.null_equals_null {
+ return Err(DataFusionError::NotImplemented(
+ "join null_equals_null".to_string(),
+ ));
+ }
+ if join.filter.is_some() {
+ return Err(DataFusionError::NotImplemented("join filter".to_string()));
+ }
+ match join.join_constraint {
+ JoinConstraint::On => {}
+ _ => {
+ return Err(DataFusionError::NotImplemented(
+ "join constraint".to_string(),
+ ))
+ }
+ }
+ // map the left and right columns to binary expressions in the form `l = r`
+ let join_expression: Vec<Expr> = join
+ .on
+ .iter()
+ .map(|(l, r)|
Expr::Column(l.clone()).eq(Expr::Column(r.clone())))
+ .collect();
+ // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b`
+ let join_expression = join_expression
+ .into_iter()
+ .reduce(|acc: Expr, expr: Expr| acc.and(expr));
+ if let Some(e) = join_expression {
+ Ok(Box::new(Rel {
+ rel_type: Some(RelType::Join(Box::new(JoinRel {
+ common: None,
+ left: Some(left),
+ right: Some(right),
+ r#type: join_type,
+ expression: Some(Box::new(to_substrait_rex(&e,
&join.schema, extension_info)?)),
+ post_join_filter: None,
+ advanced_extension: None,
+ }))),
+ }))
+ } else {
+ Err(DataFusionError::NotImplemented(
+ "Empty join condition".to_string(),
+ ))
+ }
+ }
+ LogicalPlan::SubqueryAlias(alias) => {
+ // Pass through to the input when a SubqueryAlias is encountered,
+ // since there is no corresponding relation type in Substrait
+ to_substrait_rel(alias.input.as_ref(), extension_info)
+ }
+ _ => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported operator: {:?}",
+ plan
+ ))),
+ }
+}
+
+pub fn operator_to_name(op: Operator) -> &'static str {
+ match op {
+ Operator::Eq => "equal",
+ Operator::NotEq => "not_equal",
+ Operator::Lt => "lt",
+ Operator::LtEq => "lte",
+ Operator::Gt => "gt",
+ Operator::GtEq => "gte",
+ Operator::Plus => "add",
+ Operator::Minus => "substract",
+ Operator::Multiply => "multiply",
+ Operator::Divide => "divide",
+ Operator::Modulo => "mod",
+ Operator::And => "and",
+ Operator::Or => "or",
+ Operator::Like => "like",
+ Operator::NotLike => "not_like",
+ Operator::IsDistinctFrom => "is_distinct_from",
+ Operator::IsNotDistinctFrom => "is_not_distinct_from",
+ Operator::RegexMatch => "regex_match",
+ Operator::RegexIMatch => "regex_imatch",
+ Operator::RegexNotMatch => "regex_not_match",
+ Operator::RegexNotIMatch => "regex_not_imatch",
+ Operator::BitwiseAnd => "bitwise_and",
+ Operator::BitwiseOr => "bitwise_or",
+ Operator::StringConcat => "str_concat",
+ Operator::BitwiseXor => "bitwise_xor",
+ Operator::BitwiseShiftRight => "bitwise_shift_right",
+ Operator::BitwiseShiftLeft => "bitwise_shift_left",
+ }
+}
+
+pub fn to_substrait_agg_measure(expr: &Expr, schema: &DFSchemaRef,
extension_info: &mut (Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>)) -> Result<Measure> {
+ match expr {
+ Expr::AggregateFunction { fun, args, distinct, filter } => {
+ let mut arguments: Vec<FunctionArgument> = vec![];
+ for arg in args {
+ arguments.push(FunctionArgument { arg_type:
Some(ArgType::Value(to_substrait_rex(arg, schema, extension_info)?)) });
+ }
+ let function_name = fun.to_string().to_lowercase();
+ let function_anchor = _register_function(function_name,
extension_info);
+ Ok(Measure {
+ measure: Some(AggregateFunction {
+ function_reference: function_anchor,
+ arguments: arguments,
+ sorts: vec![],
+ output_type: None,
+ invocation: match distinct {
+ true => AggregationInvocation::Distinct as i32,
+ false => AggregationInvocation::All as i32,
+ },
+ phase: substrait::protobuf::AggregationPhase::Unspecified
as i32,
+ args: vec![],
+ }),
+ filter: match filter {
+ Some(f) => Some(to_substrait_rex(f, schema,
extension_info)?),
+ None => None
+ }
+ })
+ },
+ _ => Err(DataFusionError::Internal(format!(
+ "Expression must be compatible with aggregation. Unsupported expression: {:?}",
+ expr
+ ))),
+ }
+}
+
+fn _register_function(function_name: String, extension_info: &mut
(Vec<extensions::SimpleExtensionDeclaration>, HashMap<String, u32>)) -> u32 {
+ let (function_extensions, function_set) = extension_info;
+ let function_name = function_name.to_lowercase();
+ // To prevent ambiguous references between ScalarFunctions and AggregateFunctions,
+ // a plan-relative identifier starting from 0 is used as the function_anchor.
+ // The consumer is responsible for correctly registering <function_anchor,function_name>
+ // mapping info stored in the extensions by the producer.
+ let function_anchor = match function_set.get(&function_name) {
+ Some(function_anchor) => {
+ // Function has been registered
+ *function_anchor
+ },
+ None => {
+ // Function has NOT been registered
+ let function_anchor = function_set.len() as u32;
+ function_set.insert(function_name.clone(), function_anchor);
+
+ let function_extension = ExtensionFunction {
+ extension_uri_reference: u32::MAX,
+ function_anchor: function_anchor,
+ name: function_name,
+ };
+ let simple_extension = extensions::SimpleExtensionDeclaration {
+ mapping_type:
Some(MappingType::ExtensionFunction(function_extension)),
+ };
+ function_extensions.push(simple_extension);
+ function_anchor
+ }
+ };
+
+ // Return function anchor
+ function_anchor
+
+}
+
+/// Return Substrait scalar function with two arguments
+pub fn make_binary_op_scalar_func(lhs: &Expression, rhs: &Expression, op:
Operator, extension_info: &mut (Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>)) -> Expression {
+ let function_name = operator_to_name(op).to_string().to_lowercase();
+ let function_anchor = _register_function(function_name, extension_info);
+ Expression {
+ rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+ function_reference: function_anchor,
+ arguments: vec![
+ FunctionArgument {
+ arg_type: Some(ArgType::Value(lhs.clone())),
+ },
+ FunctionArgument {
+ arg_type: Some(ArgType::Value(rhs.clone())),
+ },
+ ],
+ output_type: None,
+ args: vec![],
+ })),
+ }
+}
+
+/// Convert DataFusion Expr to Substrait Rex
+pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef, extension_info:
&mut (Vec<extensions::SimpleExtensionDeclaration>, HashMap<String, u32>)) ->
Result<Expression> {
+ match expr {
+ Expr::Between { expr, negated, low, high } => {
+ if *negated {
+ // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr)
+ let substrait_expr = to_substrait_rex(expr, schema,
extension_info)?;
+ let substrait_low = to_substrait_rex(low, schema,
extension_info)?;
+ let substrait_high = to_substrait_rex(high, schema,
extension_info)?;
+
+ let l_expr = make_binary_op_scalar_func(&substrait_expr,
&substrait_low, Operator::Lt, extension_info);
+ let r_expr = make_binary_op_scalar_func(&substrait_high,
&substrait_expr, Operator::Lt, extension_info);
+
+ Ok(make_binary_op_scalar_func(&l_expr, &r_expr, Operator::Or,
extension_info))
+ } else {
+ // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high)
+ let substrait_expr = to_substrait_rex(expr, schema,
extension_info)?;
+ let substrait_low = to_substrait_rex(low, schema,
extension_info)?;
+ let substrait_high = to_substrait_rex(high, schema,
extension_info)?;
+
+ let l_expr = make_binary_op_scalar_func(&substrait_low,
&substrait_expr, Operator::LtEq, extension_info);
+ let r_expr = make_binary_op_scalar_func(&substrait_expr,
&substrait_high, Operator::LtEq, extension_info);
+
+ Ok(make_binary_op_scalar_func(&l_expr, &r_expr, Operator::And,
extension_info))
+ }
+ }
+ Expr::Column(col) => {
+ let index = schema.index_of_column(&col)?;
+ substrait_field_ref(index)
+ }
+ Expr::BinaryExpr { left, op, right } => {
+ let l = to_substrait_rex(left, schema, extension_info)?;
+ let r = to_substrait_rex(right, schema, extension_info)?;
+
+ Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info))
+ }
+ Expr::Case { expr, when_then_expr, else_expr } => {
+ let mut ifs: Vec<IfClause> = vec![];
+ // Parse base
+ if let Some(e) = expr { // Base expression exists
+ ifs.push(IfClause {
+ r#if: Some(to_substrait_rex(e, schema, extension_info)?),
+ then: None,
+ });
+ }
+ // Parse `when`s
+ for (r#if, then) in when_then_expr {
+ ifs.push(IfClause {
+ r#if: Some(to_substrait_rex(r#if, schema,
extension_info)?),
+ then: Some(to_substrait_rex(then, schema,
extension_info)?),
+ });
+ }
+
+ // Parse outer `else`
+ let r#else: Option<Box<Expression>> = match else_expr {
+ Some(e) => Some(Box::new(to_substrait_rex(e, schema,
extension_info)?)),
+ None => None,
+ };
+
+ Ok(Expression {
+ rex_type: Some(RexType::IfThen(Box::new(IfThen {
+ ifs: ifs,
+ r#else: r#else
+ }))),
+ })
+ }
+ Expr::Literal(value) => {
+ let literal_type = match value {
+ ScalarValue::Int8(Some(n)) => Some(LiteralType::I8(*n as i32)),
+ ScalarValue::Int16(Some(n)) => Some(LiteralType::I16(*n as
i32)),
+ ScalarValue::Int32(Some(n)) => Some(LiteralType::I32(*n)),
+ ScalarValue::Int64(Some(n)) => Some(LiteralType::I64(*n)),
+ ScalarValue::Boolean(Some(b)) =>
Some(LiteralType::Boolean(*b)),
+ ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
+ ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
+ ScalarValue::Utf8(Some(s)) =>
Some(LiteralType::String(s.clone())),
+ ScalarValue::LargeUtf8(Some(s)) =>
Some(LiteralType::String(s.clone())),
+ ScalarValue::Binary(Some(b)) =>
Some(LiteralType::Binary(b.clone())),
+ ScalarValue::LargeBinary(Some(b)) =>
Some(LiteralType::Binary(b.clone())),
+ ScalarValue::Date32(Some(d)) => Some(LiteralType::Date(*d)),
+ _ => {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Unsupported literal: {:?}",
+ value
+ )))
+ }
+ };
+ Ok(Expression {
+ rex_type: Some(RexType::Literal(Literal {
+ nullable: true,
+ type_variation_reference: 0,
+ literal_type,
+ })),
+ })
+ }
+ Expr::Alias(expr, _alias) => {
+ to_substrait_rex(expr, schema, extension_info)
+ }
+ _ => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported expression: {:?}",
+ expr
+ ))),
+ }
+}
+
+fn substrait_sort_field(expr: &Expr, schema: &DFSchemaRef, extension_info:
&mut (Vec<extensions::SimpleExtensionDeclaration>, HashMap<String, u32>)) ->
Result<SortField> {
+ match expr {
+ Expr::Sort { expr, asc, nulls_first } => {
+ let e = to_substrait_rex(expr, schema, extension_info)?;
+ let d = match (asc, nulls_first) {
+ (true, true) => SortDirection::AscNullsFirst,
+ (true, false) => SortDirection::AscNullsLast,
+ (false, true) => SortDirection::DescNullsFirst,
+ (false, false) => SortDirection::DescNullsLast,
+ };
+ Ok(SortField {
+ expr: Some(e),
+ sort_kind: Some(SortKind::Direction(d as i32)),
+ })
+ },
+ _ => Err(DataFusionError::NotImplemented(format!(
+ "Expecting sort expression but got {:?}",
+ expr
+ ))),
+ }
+}
+
+fn substrait_field_ref(index: usize) -> Result<Expression> {
+ Ok(Expression {
+ rex_type: Some(RexType::Selection(Box::new(FieldReference {
+ reference_type:
Some(ReferenceType::DirectReference(ReferenceSegment {
+ reference_type:
Some(reference_segment::ReferenceType::StructField(
+ Box::new(reference_segment::StructField {
+ field: index as i32,
+ child: None,
+ }),
+ )),
+ })),
+ root_type: None,
+ }))),
+ })
+}
diff --git a/datafusion/substrait/src/serializer.rs
b/datafusion/substrait/src/serializer.rs
new file mode 100644
index 000000000..7f52077f1
--- /dev/null
+++ b/datafusion/substrait/src/serializer.rs
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::producer;
+
+use datafusion::error::Result;
+use datafusion::prelude::*;
+
+use prost::Message;
+use substrait::protobuf::Plan;
+
+use std::fs::OpenOptions;
+use std::io::{Write, Read};
+
+pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) ->
Result<()> {
+ let df = ctx.sql(sql).await?;
+ let plan = df.to_logical_plan()?;
+ let proto = producer::to_substrait_plan(&plan)?;
+
+ let mut protobuf_out = Vec::<u8>::new();
+ proto.encode(&mut protobuf_out).unwrap();
+ let mut file = OpenOptions::new()
+ .create(true)
+ .write(true)
+ .open(path)?;
+ file.write_all(&protobuf_out)?;
+ Ok(())
+}
+
+pub async fn deserialize(path: &str) -> Result<Box<Plan>> {
+ let mut protobuf_in = Vec::<u8>::new();
+
+ let mut file = OpenOptions::new()
+ .read(true)
+ .open(path)?;
+
+ file.read_to_end(&mut protobuf_in)?;
+ let proto = Message::decode(&*protobuf_in).unwrap();
+
+ Ok(Box::new(proto))
+}
+
+
diff --git a/datafusion/substrait/tests/roundtrip.rs
b/datafusion/substrait/tests/roundtrip.rs
new file mode 100644
index 000000000..21a3a5f29
--- /dev/null
+++ b/datafusion/substrait/tests/roundtrip.rs
@@ -0,0 +1,273 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_substrait::consumer;
+use datafusion_substrait::producer;
+
+#[cfg(test)]
+mod tests {
+
+ use crate::{consumer::from_substrait_plan, producer::to_substrait_plan};
+ use datafusion::error::Result;
+ use datafusion::prelude::*;
+ use
substrait::protobuf::extensions::simple_extension_declaration::MappingType;
+
+ #[tokio::test]
+ async fn simple_select() -> Result<()> {
+ roundtrip("SELECT a, b FROM data").await
+ }
+
+ #[tokio::test]
+ async fn wildcard_select() -> Result<()> {
+ roundtrip("SELECT * FROM data").await
+ }
+
+ #[tokio::test]
+ async fn select_with_filter() -> Result<()> {
+ roundtrip("SELECT * FROM data WHERE a > 1").await
+ }
+
+ #[tokio::test]
+ async fn select_with_reused_functions() -> Result<()> {
+ let sql = "SELECT * FROM data WHERE a > 1 AND a < 10 AND b > 0";
+ roundtrip(sql).await?;
+ let (mut function_names, mut function_anchors) =
function_extension_info(sql).await?;
+ function_names.sort();
+ function_anchors.sort();
+
+ assert_eq!(function_names, ["and", "gt", "lt"]);
+ assert_eq!(function_anchors, [0, 1, 2]);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn select_with_filter_date() -> Result<()> {
+ roundtrip("SELECT * FROM data WHERE c > CAST('2020-01-01' AS DATE)").await
+ }
+
+ #[tokio::test]
+ async fn select_with_filter_bool_expr() -> Result<()> {
+ roundtrip("SELECT * FROM data WHERE d AND a > 1").await
+ }
+
+ #[tokio::test]
+ async fn select_with_limit() -> Result<()> {
+ roundtrip_fill_na("SELECT * FROM data LIMIT 100").await
+ }
+
+ #[tokio::test]
+ async fn select_with_limit_offset() -> Result<()> {
+ roundtrip("SELECT * FROM data LIMIT 200 OFFSET 10").await
+ }
+
+ #[tokio::test]
+ async fn simple_aggregate() -> Result<()> {
+ roundtrip("SELECT a, sum(b) FROM data GROUP BY a").await
+ }
+
+ #[tokio::test]
+ async fn aggregate_distinct_with_having() -> Result<()> {
+ roundtrip("SELECT a, count(distinct b) FROM data GROUP BY a, c HAVING count(b) > 100").await
+ }
+
+ #[tokio::test]
+ async fn aggregate_multiple_keys() -> Result<()> {
+ roundtrip("SELECT a, c, avg(b) FROM data GROUP BY a, c").await
+ }
+
+ #[tokio::test]
+ async fn simple_distinct() -> Result<()> {
+ test_alias(
+ "SELECT * FROM (SELECT distinct a FROM data)", // `SELECT *` is used to add `projection` at the root
+ "SELECT a FROM data GROUP BY a",
+ ).await
+ }
+
+ #[tokio::test]
+ async fn select_distinct_two_fields() -> Result<()> {
+ test_alias(
+ "SELECT * FROM (SELECT distinct a, b FROM data)", // `SELECT *` is used to add `projection` at the root
+ "SELECT a, b FROM data GROUP BY a, b",
+ ).await
+ }
+
+ #[tokio::test]
+ async fn simple_alias() -> Result<()> {
+ test_alias(
+ "SELECT d1.a, d1.b FROM data d1",
+ "SELECT a, b FROM data",
+ ).await
+ }
+
+ #[tokio::test]
+ async fn two_table_alias() -> Result<()> {
+ test_alias(
+ "SELECT d1.a FROM data d1 JOIN data2 d2 ON d1.a = d2.a",
+ "SELECT data.a FROM data JOIN data2 ON data.a = data2.a",
+ )
+ .await
+ }
+
+ #[tokio::test]
+ async fn between_integers() -> Result<()> {
+ test_alias(
+ "SELECT * FROM data WHERE a BETWEEN 2 AND 6",
+ "SELECT * FROM data WHERE a >= 2 AND a <= 6"
+ )
+ .await
+ }
+
+ #[tokio::test]
+ async fn not_between_integers() -> Result<()> {
+ test_alias(
+ "SELECT * FROM data WHERE a NOT BETWEEN 2 AND 6",
+ "SELECT * FROM data WHERE a < 2 OR a > 6"
+ )
+ .await
+ }
+
+ #[tokio::test]
+ async fn case_without_base_expression() -> Result<()> {
+ roundtrip("SELECT (CASE WHEN a >= 0 THEN 'positive' ELSE 'negative' END) FROM data").await
+ }
+
+ #[tokio::test]
+ async fn case_with_base_expression() -> Result<()> {
+ roundtrip("SELECT (CASE a
+ WHEN 0 THEN 'zero'
+ WHEN 1 THEN 'one'
+ ELSE 'other'
+ END) FROM data").await
+ }
+
+ #[tokio::test]
+ async fn roundtrip_inner_join() -> Result<()> {
+ roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await
+ }
+
+ #[tokio::test]
+ async fn inner_join() -> Result<()> {
+ assert_expected_plan(
+ "SELECT data.a FROM data JOIN data2 ON data.a = data2.a",
+ "Projection: data.a\
+ \n Inner Join: data.a = data2.a\
+ \n TableScan: data projection=[a]\
+ \n TableScan: data2 projection=[a]",
+ )
+ .await
+ }
+
+ async fn assert_expected_plan(sql: &str, expected_plan_str: &str) ->
Result<()> {
+ let mut ctx = create_context().await?;
+ let df = ctx.sql(sql).await?;
+ let plan = df.to_logical_plan()?;
+ let proto = to_substrait_plan(&plan)?;
+ let df = from_substrait_plan(&mut ctx, &proto).await?;
+ let plan2 = df.to_logical_plan()?;
+ let plan2str = format!("{:?}", plan2);
+ assert_eq!(expected_plan_str, &plan2str);
+ Ok(())
+ }
+
+ async fn roundtrip_fill_na(sql: &str) -> Result<()> {
+ let mut ctx = create_context().await?;
+ let df = ctx.sql(sql).await?;
+ let plan1 = df.to_logical_plan()?;
+ let proto = to_substrait_plan(&plan1)?;
+
+ let df = from_substrait_plan(&mut ctx, &proto).await?;
+ let plan2 = df.to_logical_plan()?;
+
+ // Format plan string and replace all None's with 0
+ let plan1str = format!("{:?}", plan1).replace("None", "0");
+ let plan2str = format!("{:?}", plan2).replace("None", "0");
+
+ assert_eq!(plan1str, plan2str);
+ Ok(())
+ }
+
+ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) ->
Result<()> {
+ // Since we ignore the SubqueryAlias in the producer, the result should be
+ // the same as producing a Substrait plan from the same query without aliases
+ // sql_with_alias -> substrait -> logical plan = sql_no_alias -> substrait -> logical plan
+ let mut ctx = create_context().await?;
+
+ let df_a = ctx.sql(sql_with_alias).await?;
+ let proto_a = to_substrait_plan(&df_a.to_logical_plan()?)?;
+ let plan_with_alias = from_substrait_plan(&mut ctx,
&proto_a).await?.to_logical_plan()?;
+
+ let df = ctx.sql(sql_no_alias).await?;
+ let proto = to_substrait_plan(&df.to_logical_plan()?)?;
+ let plan = from_substrait_plan(&mut ctx,
&proto).await?.to_logical_plan()?;
+
+ println!("{:#?}", plan_with_alias);
+ println!("{:#?}", plan);
+
+ let plan1str = format!("{:?}", plan_with_alias);
+ let plan2str = format!("{:?}", plan);
+ assert_eq!(plan1str, plan2str);
+ Ok(())
+ }
+
+ async fn roundtrip(sql: &str) -> Result<()> {
+ let mut ctx = create_context().await?;
+ let df = ctx.sql(sql).await?;
+ let plan = df.to_logical_plan()?;
+ let proto = to_substrait_plan(&plan)?;
+
+ let df = from_substrait_plan(&mut ctx, &proto).await?;
+ let plan2 = df.to_logical_plan()?;
+
+ println!("{:#?}", plan);
+ println!("{:#?}", plan2);
+
+ let plan1str = format!("{:?}", plan);
+ let plan2str = format!("{:?}", plan2);
+ assert_eq!(plan1str, plan2str);
+ Ok(())
+ }
+
+ async fn function_extension_info(sql: &str) -> Result<(Vec<String>,
Vec<u32>)> {
+ let ctx = create_context().await?;
+ let df = ctx.sql(sql).await?;
+ let plan = df.to_logical_plan()?;
+ let proto = to_substrait_plan(&plan)?;
+
+ let mut function_names: Vec<String> = vec![];
+ let mut function_anchors: Vec<u32> = vec![];
+ for e in &proto.extensions {
+ let (function_anchor, function_name) = match
e.mapping_type.as_ref().unwrap() {
+ MappingType::ExtensionFunction(ext_f) =>
(ext_f.function_anchor, &ext_f.name),
+ _ => unreachable!("Producer does not generate a non-function extension")
+ };
+ function_names.push(function_name.to_string());
+ function_anchors.push(function_anchor);
+ }
+
+ Ok((function_names, function_anchors))
+ }
+
+ async fn create_context() -> Result<SessionContext> {
+ let ctx = SessionContext::new();
+ ctx.register_csv("data", "tests/testdata/data.csv",
CsvReadOptions::new())
+ .await?;
+ ctx.register_csv("data2", "tests/testdata/data.csv",
CsvReadOptions::new())
+ .await?;
+ Ok(ctx)
+ }
+}
diff --git a/datafusion/substrait/tests/serialize.rs
b/datafusion/substrait/tests/serialize.rs
new file mode 100644
index 000000000..505c4f5f4
--- /dev/null
+++ b/datafusion/substrait/tests/serialize.rs
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[cfg(test)]
+mod tests {
+
+ use datafusion_substrait::consumer::from_substrait_plan;
+ use datafusion_substrait::serializer;
+
+ use datafusion::error::Result;
+ use datafusion::prelude::*;
+
+ use std::fs;
+
+ #[tokio::test]
+ async fn serialize_simple_select() -> Result<()> {
+ let mut ctx = create_context().await?;
+ let path = "tests/simple_select.bin";
+ let sql = "SELECT a, b FROM data";
+ // Test reference
+ let df_ref = ctx.sql(sql).await?;
+ let plan_ref = df_ref.to_logical_plan()?;
+ // Test
+ // Write substrait plan to file
+ serializer::serialize(sql, &ctx, &path).await?;
+ // Read substrait plan from file
+ let proto = serializer::deserialize(path).await?;
+ // Check plan equality
+ let df = from_substrait_plan(&mut ctx, &proto).await?;
+ let plan = df.to_logical_plan()?;
+ let plan_str_ref = format!("{:?}", plan_ref);
+ let plan_str = format!("{:?}", plan);
+ assert_eq!(plan_str_ref, plan_str);
+ // Delete test binary file
+ fs::remove_file(path)?;
+
+ Ok(())
+ }
+
+ async fn create_context() -> Result<SessionContext> {
+ let ctx = SessionContext::new();
+ ctx.register_csv("data", "tests/testdata/data.csv",
CsvReadOptions::new())
+ .await?;
+ ctx.register_csv("data2", "tests/testdata/data.csv",
CsvReadOptions::new())
+ .await?;
+ Ok(ctx)
+ }
+}
\ No newline at end of file
diff --git a/datafusion/substrait/tests/testdata/data.csv
b/datafusion/substrait/tests/testdata/data.csv
new file mode 100644
index 000000000..4394789bc
--- /dev/null
+++ b/datafusion/substrait/tests/testdata/data.csv
@@ -0,0 +1,3 @@
+a,b,c,d
+1,2,2020-01-01,false
+3,4,2020-01-01,true
\ No newline at end of file
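
The producer's `_register_function` in the patch above assigns each distinct function name a plan-relative anchor starting from 0 and reuses that anchor on later references, which is the behavior the `select_with_reused_functions` test checks. A minimal standalone sketch of that idea follows. `FunctionRegistry` and its fields are hypothetical names used only for illustration and are not part of the crate, and the Substrait declaration types are replaced here with plain `(anchor, name)` tuples.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the producer's `extension_info` pair:
/// a list of declarations plus a name -> anchor lookup table.
#[derive(Default)]
pub struct FunctionRegistry {
    /// (anchor, name) pairs in registration order; stands in for the
    /// list of SimpleExtensionDeclaration values in the real producer.
    pub declarations: Vec<(u32, String)>,
    /// Lowercased function name -> plan-relative anchor.
    pub anchors: HashMap<String, u32>,
}

impl FunctionRegistry {
    /// Return the anchor for `name`, registering it on first use.
    pub fn register(&mut self, name: &str) -> u32 {
        let name = name.to_lowercase();
        if let Some(anchor) = self.anchors.get(&name) {
            // Already registered: reuse the existing anchor.
            return *anchor;
        }
        // Not registered yet: the next anchor is the current table size,
        // so anchors are dense and start at 0.
        let anchor = self.anchors.len() as u32;
        self.anchors.insert(name.clone(), anchor);
        self.declarations.push((anchor, name));
        anchor
    }
}
```

Under these assumptions, registering "gt", "and", "gt" and "lt" in that order produces three declarations, and the second "gt" gets the same anchor as the first.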