This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 0d27fcb04 Add datafusion-substrait crate (#4543)
0d27fcb04 is described below

commit 0d27fcb04b71693adf570346507ca6282f8cba71
Author: Andy Grove <[email protected]>
AuthorDate: Thu Jan 12 15:18:23 2023 -0700

    Add datafusion-substrait crate (#4543)
    
    * Initial commit
    
    * initial commit
    
    * failing test
    
    * table scan projection
    
    * closer
    
    * test passes, with some hacks
    
    * use DataFrame (#2)
    
    * update README
    
    * update dependency
    
    * code cleanup (#3)
    
    * Add support for Filter operator and BinaryOp expressions (#4)
    
    * GitHub action (#5)
    
    * Split code into producer and consumer modules (#6)
    
    * Support more functions and scalar types (#7)
    
    * Use substrait 0.1 and datafusion 8.0 (#8)
    
    * use substrait 0.1
    
    * use datafusion 8.0
    
    * update datafusion to 10.0 and substrait to 0.2 (#11)
    
    * Add basic join support (#12)
    
    * Added fetch support (#23)
    
    Added fetch to consumer
    
    Added limit to producer
    
    Added unit tests for limit
    
    Added roundtrip_fill_none() for testing when None input can be converted to 0
    
    Update src/consumer.rs
    
    Co-authored-by: Andy Grove <[email protected]>
    
    Co-authored-by: Andy Grove <[email protected]>
    
    * Upgrade to DataFusion 13.0.0 (#25)
    
    * Add sort consumer and producer (#24)
    
    Add consumer
    
    Add producer and test
    
    Modified error string
    
    * Add serializer/deserializer (#26)
    
    * Add plan and function extension support (#27)
    
    * Add plan and function extension support
    
    * Removed unwraps
    
    * Implement GROUP BY (#28)
    
    * Add consumer, producer and tests for aggregate relation
    
    Change function extension registration from absolute to relative anchor
    (reference)
    
    Remove operator to/from reference
    
    * Fixed function registration bug
    
    * Add test
    
    * Addressed PR comments
    
    * Changed field reference from mask to direct reference (#29)
    
    * Changed field reference from masked reference to direct reference
    
    * Handle unsupported case (struct with child)
    
    * Handle SubqueryAlias (#30)
    
    Fixed aggregate function register bug
    
    * Add support for SELECT DISTINCT (#31)
    
    Add test case
    
    * Implement BETWEEN (#32)
    
    * Add case (#33)
    
    * Implement CASE WHEN
    
    * Add more case to test
    
    * Addressed comments
    
    * feat: support explicit catalog/schema names in ReadRel (#34)
    
    * feat: support explicit catalog/schema names in ReadRel
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    * fix: use re-exported expr crate
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    * move files to subfolder
    
    * RAT
    
    * remove rust.yaml
    
    * revert .gitignore changes
    
    * tomlfmt
    
    * tomlfmt
    
    Signed-off-by: Ruihang Xia <[email protected]>
    Co-authored-by: Daniël Heres <[email protected]>
    Co-authored-by: JanKaul <[email protected]>
    Co-authored-by: nseekhao <[email protected]>
    Co-authored-by: Ruihang Xia <[email protected]>
---
 datafusion/substrait/Cargo.toml              |  32 ++
 datafusion/substrait/README.md               |  34 ++
 datafusion/substrait/src/consumer.rs         | 544 ++++++++++++++++++++++++++
 datafusion/substrait/src/lib.rs              |  20 +
 datafusion/substrait/src/producer.rs         | 557 +++++++++++++++++++++++++++
 datafusion/substrait/src/serializer.rs       |  57 +++
 datafusion/substrait/tests/roundtrip.rs      | 273 +++++++++++++
 datafusion/substrait/tests/serialize.rs      |  62 +++
 datafusion/substrait/tests/testdata/data.csv |   3 +
 9 files changed, 1582 insertions(+)

diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml
new file mode 100644
index 000000000..b8c0e56d2
--- /dev/null
+++ b/datafusion/substrait/Cargo.toml
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "datafusion-substrait"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+async-recursion = "1.0"
+datafusion = "13.0"
+prost = "0.9"
+prost-types = "0.9"
+substrait = "0.2"
+tokio = "1.17"
+
+[build-dependencies]
+prost-build = { version = "0.9" }
diff --git a/datafusion/substrait/README.md b/datafusion/substrait/README.md
new file mode 100644
index 000000000..9f21d514a
--- /dev/null
+++ b/datafusion/substrait/README.md
@@ -0,0 +1,34 @@
+<!---
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing,
+  software distributed under the License is distributed on an
+  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+  KIND, either express or implied.  See the License for the
+  specific language governing permissions and limitations
+  under the License.
+-->
+
+# DataFusion + Substrait
+
+[Substrait](https://substrait.io/) provides a cross-language serialization format for relational algebra, based on
+protocol buffers.
+
+This repository provides a Substrait producer and consumer for DataFusion:
+
+- The producer converts a DataFusion logical plan into a Substrait protobuf.
+- The consumer converts a Substrait protobuf into a DataFusion logical plan.
+
+Potential uses of this crate:
+
+- Replace the current [DataFusion protobuf definition](https://github.com/apache/arrow-datafusion/blob/master/datafusion-proto/proto/datafusion.proto) used in Ballista for passing query plan fragments to executors
+- Make it easier to pass query plans over FFI boundaries, such as from Python to Rust
+- Allow Apache Calcite query plans to be executed in DataFusion
diff --git a/datafusion/substrait/src/consumer.rs 
b/datafusion/substrait/src/consumer.rs
new file mode 100644
index 000000000..c747a30a6
--- /dev/null
+++ b/datafusion/substrait/src/consumer.rs
@@ -0,0 +1,544 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use async_recursion::async_recursion;
+use datafusion::common::{DFField, DFSchema, DFSchemaRef};
+use datafusion::logical_expr::{LogicalPlan, aggregate_function};
+use datafusion::logical_plan::build_join_schema;
+use datafusion::prelude::JoinType;
+use datafusion::{
+    error::{DataFusionError, Result},
+    logical_plan::{Expr, Operator},
+    optimizer::utils::split_conjunction,
+    prelude::{Column, DataFrame, SessionContext},
+    scalar::ScalarValue,
+};
+
+use datafusion::sql::TableReference;
+use substrait::protobuf::{
+    aggregate_function::AggregationInvocation,
+    expression::{
+        field_reference::ReferenceType::DirectReference,
+        literal::LiteralType,
+        MaskExpression,
+        reference_segment::ReferenceType::StructField,
+        RexType,
+    },
+    extensions::simple_extension_declaration::MappingType,
+    function_argument::ArgType,
+    read_rel::ReadType,
+    rel::RelType,
+    sort_field::{SortKind::*, SortDirection},
+    AggregateFunction, Expression, Plan, Rel,
+};
+
+use std::collections::HashMap;
+use std::str::FromStr;
+use std::sync::Arc;
+
+pub fn name_to_op(name: &str) -> Result<Operator> {
+    match name {
+        "equal" => Ok(Operator::Eq),
+        "not_equal" => Ok(Operator::NotEq),
+        "lt" => Ok(Operator::Lt),
+        "lte" => Ok(Operator::LtEq),
+        "gt" => Ok(Operator::Gt),
+        "gte" => Ok(Operator::GtEq),
+        "add" => Ok(Operator::Plus),
+        "subtract" => Ok(Operator::Minus),
+        "multiply" => Ok(Operator::Multiply),
+        "divide" => Ok(Operator::Divide),
+        "mod" => Ok(Operator::Modulo),
+        "and" => Ok(Operator::And),
+        "or" => Ok(Operator::Or),
+        "like" => Ok(Operator::Like),
+        "not_like" => Ok(Operator::NotLike),
+        "is_distinct_from" => Ok(Operator::IsDistinctFrom),
+        "is_not_distinct_from" => Ok(Operator::IsNotDistinctFrom),
+        "regex_match" => Ok(Operator::RegexMatch),
+        "regex_imatch" => Ok(Operator::RegexIMatch),
+        "regex_not_match" => Ok(Operator::RegexNotMatch),
+        "regex_not_imatch" => Ok(Operator::RegexNotIMatch),
+        "bitwise_and" => Ok(Operator::BitwiseAnd),
+        "bitwise_or" => Ok(Operator::BitwiseOr),
+        "str_concat" => Ok(Operator::StringConcat),
+        "bitwise_xor" => Ok(Operator::BitwiseXor),
+        "bitwise_shift_right" => Ok(Operator::BitwiseShiftRight),
+        "bitwise_shift_left" => Ok(Operator::BitwiseShiftLeft),
+        _ => Err(DataFusionError::NotImplemented(format!(
+            "Unsupported function name: {:?}",
+            name
+        ))),
+    }
+}
+
+/// Convert Substrait Plan to DataFusion DataFrame
+pub async fn from_substrait_plan(ctx: &mut SessionContext, plan: &Plan) -> 
Result<Arc<DataFrame>> {
+    // Register function extension
+    let function_extension = plan.extensions
+        .iter()
+        .map(|e| match &e.mapping_type {
+            Some(ext) => {
+                match ext {
+                    MappingType::ExtensionFunction(ext_f) => 
Ok((ext_f.function_anchor, &ext_f.name)),
+                    _ => 
Err(DataFusionError::NotImplemented(format!("Extension type not supported: 
{:?}", ext)))
+                }
+            }
+            None => Err(DataFusionError::NotImplemented("Cannot parse empty 
extension".to_string()))
+        })
+        .collect::<Result<HashMap<_, _>>>()?;
+    // Parse relations
+    match plan.relations.len() {
+        1 => {
+            match plan.relations[0].rel_type.as_ref() {
+                Some(rt) => match rt {
+                    substrait::protobuf::plan_rel::RelType::Rel(rel) => {
+                        Ok(from_substrait_rel(ctx, &rel, 
&function_extension).await?)
+                    },
+                    substrait::protobuf::plan_rel::RelType::Root(root) => {
+                        Ok(from_substrait_rel(ctx, 
&root.input.as_ref().unwrap(), &function_extension).await?)
+                    }
+                },
+                None => Err(DataFusionError::Internal("Cannot parse plan 
relation: None".to_string()))
+            }
+            
+        },
+        _ => Err(DataFusionError::NotImplemented(format!(
+            "Substrait plan with more than 1 relation trees not supported. 
Number of relation trees: {:?}",
+            plan.relations.len()
+        )))
+    }
+}
+
+/// Convert Substrait Rel to DataFusion DataFrame
+#[async_recursion]
+pub async fn from_substrait_rel(ctx: &mut SessionContext, rel: &Rel, 
extensions: &HashMap<u32, &String>) -> Result<Arc<DataFrame>> {
+    match &rel.rel_type {
+        Some(RelType::Project(p)) => {
+            if let Some(input) = p.input.as_ref() {
+                let input = from_substrait_rel(ctx, input, extensions).await?;
+                let mut exprs: Vec<Expr> = vec![];
+                for e in &p.expressions {
+                    let x = from_substrait_rex(e, &input.schema(), 
extensions).await?;
+                    exprs.push(x.as_ref().clone());
+                }
+                input.select(exprs)
+            } else {
+                Err(DataFusionError::NotImplemented(
+                    "Projection without an input is not supported".to_string(),
+                ))
+            }
+        }
+        Some(RelType::Filter(filter)) => {
+            if let Some(input) = filter.input.as_ref() {
+                let input = from_substrait_rel(ctx, input, extensions).await?;
+                if let Some(condition) = filter.condition.as_ref() {
+                    let expr = from_substrait_rex(condition, &input.schema(), 
extensions).await?;
+                    input.filter(expr.as_ref().clone())
+                } else {
+                    Err(DataFusionError::NotImplemented(
+                        "Filter without an condition is not valid".to_string(),
+                    ))
+                }
+            } else {
+                Err(DataFusionError::NotImplemented(
+                    "Filter without an input is not valid".to_string(),
+                ))
+            }
+        }
+        Some(RelType::Fetch(fetch)) => {
+            if let Some(input) = fetch.input.as_ref() {
+                let input = from_substrait_rel(ctx, input, extensions).await?;
+                let offset = fetch.offset as usize;
+                let count = fetch.count as usize;
+                input.limit(offset, Some(count))
+            } else {
+                Err(DataFusionError::NotImplemented(
+                    "Fetch without an input is not valid".to_string(),
+                ))
+            }
+        }
+        Some(RelType::Sort(sort)) => {
+            if let Some(input) = sort.input.as_ref() {
+                let input = from_substrait_rel(ctx, input, extensions).await?;
+                let mut sorts: Vec<Expr> = vec![];
+                for s in &sort.sorts {
+                    let expr = from_substrait_rex(&s.expr.as_ref().unwrap(), 
&input.schema(), extensions).await?;
+                    let asc_nullfirst = match &s.sort_kind {
+                        Some(k) => match k {
+                            Direction(d) => {
+                                let direction : SortDirection = unsafe {
+                                    ::std::mem::transmute(*d)
+                                };
+                                match direction {
+                                    SortDirection::AscNullsFirst => Ok((true, 
true)),
+                                    SortDirection::AscNullsLast => Ok((true, 
false)),
+                                    SortDirection::DescNullsFirst => 
Ok((false, true)),
+                                    SortDirection::DescNullsLast => Ok((false, 
false)),
+                                    SortDirection::Clustered => {
+                                        Err(DataFusionError::NotImplemented(
+                                            "Sort with direction clustered is 
not yet supported".to_string(),
+                                        ))  
+                                    },
+                                    SortDirection::Unspecified => {
+                                        Err(DataFusionError::NotImplemented(
+                                            "Unspecified sort direction is 
invalid".to_string(),
+                                        ))  
+                                    }
+                                }
+                            }
+                            ComparisonFunctionReference(_) => {
+                                Err(DataFusionError::NotImplemented(
+                                    "Sort using comparison function reference 
is not supported".to_string(),
+                                ))
+                            },
+                        },
+                        None => {
+                            Err(DataFusionError::NotImplemented(
+                                "Sort without sort kind is 
invalid".to_string(),
+                            ))
+                        },
+                    };
+                    let (asc, nulls_first) = asc_nullfirst.unwrap();
+                    sorts.push(Expr::Sort { expr: 
Box::new(expr.as_ref().clone()), asc: asc, nulls_first: nulls_first });
+                }
+                input.sort(sorts)
+            } else {
+                Err(DataFusionError::NotImplemented(
+                    "Sort without an input is not valid".to_string(),
+                ))
+            }
+        }
+        Some(RelType::Aggregate(agg)) => {
+            if let Some(input) = agg.input.as_ref() {
+                let input = from_substrait_rel(ctx, input, extensions).await?;
+                let mut group_expr = vec![];
+                let mut aggr_expr = vec![];
+
+                let groupings = match agg.groupings.len() {
+                    1 => { Ok(&agg.groupings[0]) },
+                    _ => {
+                        Err(DataFusionError::NotImplemented(
+                            "Aggregate with multiple grouping sets is not 
supported".to_string(),
+                        ))
+                    }
+                };
+
+                for e in &groupings?.grouping_expressions {
+                    let x = from_substrait_rex(&e, &input.schema(), 
extensions).await?;
+                    group_expr.push(x.as_ref().clone());
+                }
+
+                for m in &agg.measures {
+                    let filter = match &m.filter {
+                        Some(fil) => Some(Box::new(from_substrait_rex(fil, 
&input.schema(), extensions).await?.as_ref().clone())),
+                        None => None
+                    };
+                    let agg_func = match &m.measure {
+                        Some(f) => {
+                            let distinct = match f.invocation  {
+                                _ if f.invocation == 
AggregationInvocation::Distinct as i32 => true,
+                                _ if f.invocation == 
AggregationInvocation::All as i32 => false,
+                                _ => false
+                            };
+                            from_substrait_agg_func(&f, &input.schema(), 
extensions, filter, distinct).await
+                        },
+                        None => Err(DataFusionError::NotImplemented(
+                            "Aggregate without aggregate function is not 
supported".to_string(),
+                        )),
+                    };
+                    aggr_expr.push(agg_func?.as_ref().clone());
+                }
+
+                input.aggregate(group_expr, aggr_expr)
+            } else {
+                Err(DataFusionError::NotImplemented(
+                    "Aggregate without an input is not valid".to_string(),
+                ))
+            }
+        }
+        Some(RelType::Join(join)) => {
+            let left = from_substrait_rel(ctx, &join.left.as_ref().unwrap(), 
extensions).await?;
+            let right = from_substrait_rel(ctx, &join.right.as_ref().unwrap(), 
extensions).await?;
+            let join_type = match join.r#type {
+                1 => JoinType::Inner,
+                2 => JoinType::Left,
+                3 => JoinType::Right,
+                4 => JoinType::Full,
+                5 => JoinType::Anti,
+                6 => JoinType::Semi,
+                _ => return Err(DataFusionError::Internal("invalid join 
type".to_string())),
+            };
+            let mut predicates = vec![];
+            let schema = build_join_schema(&left.schema(), &right.schema(), 
&JoinType::Inner)?;
+            let on = from_substrait_rex(&join.expression.as_ref().unwrap(), 
&schema, extensions).await?;
+            split_conjunction(&on, &mut predicates);
+            let pairs = predicates
+                .iter()
+                .map(|p| match p {
+                    Expr::BinaryExpr {
+                        left,
+                        op: Operator::Eq,
+                        right,
+                    } => match (left.as_ref(), right.as_ref()) {
+                        (Expr::Column(l), Expr::Column(r)) => 
Ok((l.flat_name(), r.flat_name())),
+                        _ => {
+                            return Err(DataFusionError::Internal(
+                                "invalid join condition".to_string(),
+                            ))
+                        }
+                    },
+                    _ => {
+                        return Err(DataFusionError::Internal(
+                            "invalid join condition".to_string(),
+                        ))
+                    }
+                })
+                .collect::<Result<Vec<_>>>()?;
+            let left_cols: Vec<&str> = pairs.iter().map(|(l, _)| 
l.as_str()).collect();
+            let right_cols: Vec<&str> = pairs.iter().map(|(_, r)| 
r.as_str()).collect();
+            left.join(right, join_type, &left_cols, &right_cols, None)
+        }
+        Some(RelType::Read(read)) => match &read.as_ref().read_type {
+            Some(ReadType::NamedTable(nt)) => {
+                let table_reference = match nt.names.len() {
+                    0 => {
+                        return Err(DataFusionError::Internal(
+                            "No table name found in NamedTable".to_string(),
+                        ));
+                    }
+                    1 => TableReference::Bare {
+                        table: &nt.names[0],
+                    },
+                    2 => TableReference::Partial {
+                        schema: &nt.names[0],
+                        table: &nt.names[1],
+                    },
+                    _ => TableReference::Full {
+                        catalog: &nt.names[0],
+                        schema: &nt.names[1],
+                        table: &nt.names[2],
+                    },
+                };
+                let t = ctx.table(table_reference)?;
+                match &read.projection {
+                    Some(MaskExpression { select, .. }) => match 
&select.as_ref() {
+                        Some(projection) => {
+                            let column_indices: Vec<usize> = projection
+                                .struct_items
+                                .iter()
+                                .map(|item| item.field as usize)
+                                .collect();
+                            match t.to_logical_plan()? {
+                                LogicalPlan::TableScan(scan) => {
+                                    let mut scan = scan.clone();
+                                    let fields: Vec<DFField> = column_indices
+                                        .iter()
+                                        .map(|i| 
scan.projected_schema.field(*i).clone())
+                                        .collect();
+                                    scan.projection = Some(column_indices);
+                                    scan.projected_schema = DFSchemaRef::new(
+                                        DFSchema::new_with_metadata(fields, 
HashMap::new())?,
+                                    );
+                                    let plan = LogicalPlan::TableScan(scan);
+                                    
Ok(Arc::new(DataFrame::new(ctx.state.clone(), &plan)))
+                                }
+                                _ => Err(DataFusionError::Internal(
+                                    "unexpected plan for table".to_string(),
+                                )),
+                            }
+                        }
+                        _ => Ok(t),
+                    },
+                    _ => Ok(t),
+                }
+            }
+            _ => Err(DataFusionError::NotImplemented(
+                "Only NamedTable reads are supported".to_string(),
+            )),
+        },
+        _ => Err(DataFusionError::NotImplemented(format!(
+            "Unsupported RelType: {:?}",
+            rel.rel_type
+        ))),
+    }
+}
+
+/// Convert Substrait AggregateFunction to DataFusion Expr
+pub async fn from_substrait_agg_func(
+    f: &AggregateFunction,
+    input_schema: &DFSchema,
+    extensions: &HashMap<u32, &String>,
+    filter: Option<Box<Expr>>,
+    distinct: bool
+) -> Result<Arc<Expr>> {
+    let mut args: Vec<Expr> = vec![];
+    for arg in &f.arguments {
+        let arg_expr = match &arg.arg_type {
+            Some(ArgType::Value(e)) => from_substrait_rex(e, input_schema, 
extensions).await,
+            _ => Err(DataFusionError::NotImplemented(
+                    "Aggregated function argument non-Value type not 
supported".to_string(),
+                ))
+        };
+        args.push(arg_expr?.as_ref().clone());
+    }
+
+    let fun = match extensions.get(&f.function_reference) {
+        Some(function_name) => 
aggregate_function::AggregateFunction::from_str(function_name),
+        None => Err(DataFusionError::NotImplemented(format!(
+                "Aggregated function not found: function anchor = {:?}",
+                f.function_reference
+            )
+        ))
+    };
+
+    Ok(
+        Arc::new(
+            Expr::AggregateFunction {
+                fun: fun.unwrap(),
+                args: args,
+                distinct: distinct,
+                filter: filter
+            }
+        )
+    )
+}
+
+/// Convert Substrait Rex to DataFusion Expr
+#[async_recursion]
+pub async fn from_substrait_rex(e: &Expression, input_schema: &DFSchema, 
extensions: &HashMap<u32, &String>) -> Result<Arc<Expr>> {
+    match &e.rex_type {
+        Some(RexType::Selection(field_ref)) => match &field_ref.reference_type 
{
+            Some(DirectReference(direct)) => match 
&direct.reference_type.as_ref() {
+                Some(StructField(x)) => match &x.child.as_ref() {
+                    Some(_) => Err(DataFusionError::NotImplemented(
+                        "Direct reference StructField with child is not 
supported".to_string(),
+                    )),
+                    None => Ok(Arc::new(Expr::Column(Column {
+                        relation: None,
+                        name: input_schema
+                            .field(x.field as usize)
+                            .name()
+                            .to_string(),
+                    }))),
+                },
+                _ => Err(DataFusionError::NotImplemented(
+                    "Direct reference with types other than StructField is not 
supported".to_string(),
+                )),
+            },
+            _ => Err(DataFusionError::NotImplemented(
+                "unsupported field ref type".to_string(),
+            )),
+        },
+        Some(RexType::IfThen(if_then)) => {
+            // Parse `ifs`
+            // If the first element does not have a `then` part, then we can 
assume it's a base expression
+            let mut when_then_expr: Vec<(Box<Expr>, Box<Expr>)> = vec![];
+            let mut expr = None;
+            for (i, if_expr) in if_then.ifs.iter().enumerate() {
+                if i == 0 {
+                    // Check if the first element is type base expression
+                    if if_expr.then.is_none() {
+                        expr = 
Some(Box::new(from_substrait_rex(&if_expr.r#if.as_ref().unwrap(), input_schema, 
extensions).await?.as_ref().clone()));
+                        continue;
+                    }
+                }
+                when_then_expr.push(
+                    (
+                        
Box::new(from_substrait_rex(&if_expr.r#if.as_ref().unwrap(), input_schema, 
extensions).await?.as_ref().clone()),
+                        
Box::new(from_substrait_rex(&if_expr.then.as_ref().unwrap(), input_schema, 
extensions).await?.as_ref().clone())
+                    ),
+                );
+            }
+            // Parse `else`
+            let else_expr = match &if_then.r#else {
+                Some(e) => Some(Box::new(
+                                                from_substrait_rex(&e, 
input_schema, extensions).await?.as_ref().clone(),
+                                            )),
+                None => None
+            };
+            Ok(Arc::new(Expr::Case { expr: expr, when_then_expr: 
when_then_expr, else_expr: else_expr }))
+        },
+        Some(RexType::ScalarFunction(f)) => {
+            assert!(f.arguments.len() == 2);
+            let op = match extensions.get(&f.function_reference) {
+                    Some(fname) => name_to_op(fname),
+                    None => Err(DataFusionError::NotImplemented(format!(
+                        "Aggregated function not found: function reference = 
{:?}",
+                        f.function_reference
+                    )
+                ))
+            };
+            match (&f.arguments[0].arg_type, &f.arguments[1].arg_type) {
+                (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
+                    Ok(Arc::new(Expr::BinaryExpr {
+                        left: Box::new(from_substrait_rex(l, input_schema, 
extensions).await?.as_ref().clone()),
+                        op: op?,
+                        right: Box::new(
+                            from_substrait_rex(r, input_schema, 
extensions).await?.as_ref().clone(),
+                        ),
+                    }))
+                }
+                (l, r) => Err(DataFusionError::NotImplemented(format!(
+                    "Invalid arguments for binary expression: {:?} and {:?}",
+                    l, r
+                ))),
+            }
+        }
+        Some(RexType::Literal(lit)) => match &lit.literal_type {
+            Some(LiteralType::I8(n)) => {
+                Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8)))))
+            }
+            Some(LiteralType::I16(n)) => {
+                Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as 
i16)))))
+            }
+            Some(LiteralType::I32(n)) => {
+                Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n as 
i32)))))
+            }
+            Some(LiteralType::I64(n)) => {
+                Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n as 
i64)))))
+            }
+            Some(LiteralType::Boolean(b)) => {
+                Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b)))))
+            }
+            Some(LiteralType::Date(d)) => {
+                Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d)))))
+            }
+            Some(LiteralType::Fp32(f)) => {
+                Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f)))))
+            }
+            Some(LiteralType::Fp64(f)) => {
+                Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f)))))
+            }
+            Some(LiteralType::String(s)) => 
Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(
+                Some(s.clone()),
+            )))),
+            Some(LiteralType::Binary(b)) => 
Ok(Arc::new(Expr::Literal(ScalarValue::Binary(Some(
+                b.clone(),
+            ))))),
+            _ => {
+                return Err(DataFusionError::NotImplemented(format!(
+                    "Unsupported literal_type: {:?}",
+                    lit.literal_type
+                )))
+            }
+        },
+        _ => Err(DataFusionError::NotImplemented(
+            "unsupported rex_type".to_string(),
+        )),
+    }
+}
diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs
new file mode 100644
index 000000000..07b8e7add
--- /dev/null
+++ b/datafusion/substrait/src/lib.rs
@@ -0,0 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+pub mod consumer;
+pub mod producer;
+pub mod serializer;
diff --git a/datafusion/substrait/src/producer.rs 
b/datafusion/substrait/src/producer.rs
new file mode 100644
index 000000000..78532046b
--- /dev/null
+++ b/datafusion/substrait/src/producer.rs
@@ -0,0 +1,557 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+
+use datafusion::{
+    error::{DataFusionError, Result},
+    logical_plan::{DFSchemaRef, Expr, JoinConstraint, LogicalPlan, Operator},
+    prelude::JoinType,
+    scalar::ScalarValue,
+};
+
+use substrait::protobuf::{
+    aggregate_function::AggregationInvocation,
+    aggregate_rel::{Grouping, Measure},
+    expression::{
+        field_reference::ReferenceType,
+        if_then::IfClause,
+        literal::LiteralType,
+        mask_expression::{StructItem, StructSelect},
+        reference_segment,
+        FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, 
RexType, ScalarFunction,
+    },
+    extensions::{self, simple_extension_declaration::{MappingType, 
ExtensionFunction}},
+    function_argument::ArgType,
+    plan_rel,
+    read_rel::{NamedTable, ReadType},
+    rel::RelType,
+    sort_field::{
+        SortDirection,
+        SortKind,
+    },
+    AggregateRel, Expression, FetchRel, FilterRel, FunctionArgument, JoinRel, 
NamedStruct, ProjectRel, ReadRel, SortField, SortRel,
+    PlanRel,
+    Plan, Rel, RelRoot, AggregateFunction,
+};
+
+/// Converts a DataFusion [`LogicalPlan`] into a Substrait [`Plan`].
+///
+/// The plan is emitted as a single root relation whose output names are the
+/// field names of the plan's schema. Functions registered while converting
+/// the relation tree are returned in the `extensions` field of the result.
+///
+/// # Errors
+///
+/// Propagates any error returned by [`to_substrait_rel`].
+pub fn to_substrait_plan(plan: &LogicalPlan) -> Result<Box<Plan>> {
+    // Registry threaded through the whole conversion: the declarations in
+    // registration order, plus a name -> anchor lookup table.
+    let mut extension_info: (
+        Vec<extensions::SimpleExtensionDeclaration>,
+        HashMap<String, u32>,
+    ) = (Vec::new(), HashMap::new());
+
+    // Only one relation tree is currently supported.
+    let root_rel = to_substrait_rel(plan, &mut extension_info)?;
+    let root = RelRoot {
+        input: Some(*root_rel),
+        names: plan.schema().field_names(),
+    };
+
+    Ok(Box::new(Plan {
+        extension_uris: Vec::new(),
+        extensions: extension_info.0,
+        relations: vec![PlanRel {
+            rel_type: Some(plan_rel::RelType::Root(root)),
+        }],
+        advanced_extensions: None,
+        expected_type_urls: Vec::new(),
+    }))
+}
+
+/// Convert DataFusion LogicalPlan to Substrait Rel
+pub fn to_substrait_rel(plan: &LogicalPlan, extension_info: &mut 
(Vec<extensions::SimpleExtensionDeclaration>, HashMap<String, u32>)) -> 
Result<Box<Rel>> {
+    match plan {
+        LogicalPlan::TableScan(scan) => {
+            let projection = scan.projection.as_ref().map(|p| {
+                p.iter()
+                    .map(|i| StructItem {
+                        field: *i as i32,
+                        child: None,
+                    })
+                    .collect()
+            });
+
+            if let Some(struct_items) = projection {
+                Ok(Box::new(Rel {
+                    rel_type: Some(RelType::Read(Box::new(ReadRel {
+                        common: None,
+                        base_schema: Some(NamedStruct {
+                            names: scan
+                                .projected_schema
+                                .fields()
+                                .iter()
+                                .map(|f| f.name().to_owned())
+                                .collect(),
+                            r#struct: None,
+                        }),
+                        filter: None,
+                        projection: Some(MaskExpression {
+                            select: Some(StructSelect { struct_items }),
+                            maintain_singular_struct: false,
+                        }),
+                        advanced_extension: None,
+                        read_type: Some(ReadType::NamedTable(NamedTable {
+                            names: vec![scan.table_name.clone()],
+                            advanced_extension: None,
+                        })),
+                    }))),
+                }))
+            } else {
+                Err(DataFusionError::NotImplemented(
+                    "TableScan without projection is not 
supported".to_string(),
+                ))
+            }
+        }
+        LogicalPlan::Projection(p) => {
+            let expressions = p
+                .expr
+                .iter()
+                .map(|e| to_substrait_rex(e, p.input.schema(), extension_info))
+                .collect::<Result<Vec<_>>>()?;
+            Ok(Box::new(Rel {
+                rel_type: Some(RelType::Project(Box::new(ProjectRel {
+                    common: None,
+                    input: Some(to_substrait_rel(p.input.as_ref(), 
extension_info)?),
+                    expressions,
+                    advanced_extension: None,
+                }))),
+            }))
+        }
+        LogicalPlan::Filter(filter) => {
+            let input = to_substrait_rel(filter.input.as_ref(), 
extension_info)?;
+            let filter_expr = to_substrait_rex(&filter.predicate, 
filter.input.schema(), extension_info)?;
+            Ok(Box::new(Rel {
+                rel_type: Some(RelType::Filter(Box::new(FilterRel {
+                    common: None,
+                    input: Some(input),
+                    condition: Some(Box::new(filter_expr)),
+                    advanced_extension: None,
+                }))),
+            }))
+        }
+        LogicalPlan::Limit(limit) => {
+            let input = to_substrait_rel(limit.input.as_ref(), 
extension_info)?;
+            let limit_fetch = match limit.fetch {
+                Some(count) => count,
+                None => 0,
+            };
+            Ok(Box::new(Rel {
+                rel_type: Some(RelType::Fetch(Box::new(FetchRel {
+                    common: None,
+                    input: Some(input),
+                    offset: limit.skip as i64,
+                    count: limit_fetch as i64,
+                    advanced_extension: None,
+                }))),
+            }))
+        }
+        LogicalPlan::Sort(sort) => {
+            let input = to_substrait_rel(sort.input.as_ref(), extension_info)?;
+            let sort_fields = sort
+                .expr
+                .iter()
+                .map(|e| substrait_sort_field(e, sort.input.schema(), 
extension_info))
+                .collect::<Result<Vec<_>>>()?;
+            Ok(Box::new(Rel {
+                rel_type: Some(RelType::Sort(Box::new(SortRel {
+                    common: None,
+                    input: Some(input),
+                    sorts: sort_fields,
+                    advanced_extension: None,
+                }))),
+            }))
+        }
+        LogicalPlan::Aggregate(agg) => {
+            let input = to_substrait_rel(agg.input.as_ref(), extension_info)?;
+            // Translate aggregate expression to Substrait's groupings 
(repeated repeated Expression)
+            let grouping = agg
+                .group_expr
+                .iter()
+                .map(|e| to_substrait_rex(e, agg.input.schema(), 
extension_info))
+                .collect::<Result<Vec<_>>>()?;
+            let measures = agg
+                .aggr_expr
+                .iter()
+                .map(|e| to_substrait_agg_measure(e, agg.input.schema(), 
extension_info))
+                .collect::<Result<Vec<_>>>()?;
+            
+            Ok(Box::new(Rel {
+                rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
+                    common: None,
+                    input: Some(input),
+                    groupings: vec![Grouping { grouping_expressions: grouping 
}], //groupings, 
+                    measures: measures,
+                    advanced_extension: None,
+                }))),
+            }))
+        }
+        LogicalPlan::Distinct(distinct) => {
+            // Use Substrait's AggregateRel with empty measures to represent 
`select distinct`
+            let input = to_substrait_rel(distinct.input.as_ref(), 
extension_info)?;
+            // Get grouping keys from the input relation's number of output 
fields
+            let grouping = (0..distinct.input.schema().fields().len())
+                .map(|x: usize| substrait_field_ref(x))
+                .collect::<Result<Vec<_>>>()?;
+
+            Ok(Box::new(Rel {
+                rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
+                    common: None,
+                    input: Some(input),
+                    groupings: vec![Grouping { grouping_expressions: grouping 
}],
+                    measures: vec![],
+                    advanced_extension: None,
+                }))),
+            }))
+        }
+        LogicalPlan::Join(join) => {
+            let left = to_substrait_rel(join.left.as_ref(), extension_info)?;
+            let right = to_substrait_rel(join.right.as_ref(), extension_info)?;
+            let join_type = match join.join_type {
+                JoinType::Inner => 1,
+                JoinType::Left => 2,
+                JoinType::Right => 3,
+                JoinType::Full => 4,
+                JoinType::Anti => 5,
+                JoinType::Semi => 6,
+            };
+            // we only support basic joins so return an error for anything not 
yet supported
+            if join.null_equals_null {
+                return Err(DataFusionError::NotImplemented(
+                    "join null_equals_null".to_string(),
+                ));
+            }
+            if join.filter.is_some() {
+                return Err(DataFusionError::NotImplemented("join 
filter".to_string()));
+            }
+            match join.join_constraint {
+                JoinConstraint::On => {}
+                _ => {
+                    return Err(DataFusionError::NotImplemented(
+                        "join constraint".to_string(),
+                    ))
+                }
+            }
+            // map the left and right columns to binary expressions in the 
form `l = r`
+            let join_expression: Vec<Expr> = join
+                .on
+                .iter()
+                .map(|(l, r)| 
Expr::Column(l.clone()).eq(Expr::Column(r.clone())))
+                .collect();
+            // build a single expression for the ON condition, such as `l.a = 
r.a AND l.b = r.b`
+            let join_expression = join_expression
+                .into_iter()
+                .reduce(|acc: Expr, expr: Expr| acc.and(expr));
+            if let Some(e) = join_expression {
+                Ok(Box::new(Rel {
+                    rel_type: Some(RelType::Join(Box::new(JoinRel {
+                        common: None,
+                        left: Some(left),
+                        right: Some(right),
+                        r#type: join_type,
+                        expression: Some(Box::new(to_substrait_rex(&e, 
&join.schema, extension_info)?)),
+                        post_join_filter: None,
+                        advanced_extension: None,
+                    }))),
+                }))
+            } else {
+                Err(DataFusionError::NotImplemented(
+                    "Empty join condition".to_string(),
+                ))
+            }
+        }
+        LogicalPlan::SubqueryAlias(alias) => {
+            // Do nothing if encounters SubqueryAlias
+            // since there is no corresponding relation type in Substrait
+            to_substrait_rel(alias.input.as_ref(), extension_info)
+        }
+        _ => Err(DataFusionError::NotImplemented(format!(
+            "Unsupported operator: {:?}",
+            plan
+        ))),
+    }
+}
+
+/// Returns the function name used to represent a DataFusion binary
+/// [`Operator`] in the produced plan.
+///
+/// The match is exhaustive, so adding a new `Operator` variant forces this
+/// table to be updated.
+pub fn operator_to_name(op: Operator) -> &'static str {
+    match op {
+        Operator::Eq => "equal",
+        Operator::NotEq => "not_equal",
+        Operator::Lt => "lt",
+        Operator::LtEq => "lte",
+        Operator::Gt => "gt",
+        Operator::GtEq => "gte",
+        Operator::Plus => "add",
+        // Previously misspelled "substract"; Substrait's arithmetic
+        // extension names this function "subtract".
+        Operator::Minus => "subtract",
+        Operator::Multiply => "multiply",
+        Operator::Divide => "divide",
+        Operator::Modulo => "mod",
+        Operator::And => "and",
+        Operator::Or => "or",
+        Operator::Like => "like",
+        Operator::NotLike => "not_like",
+        Operator::IsDistinctFrom => "is_distinct_from",
+        Operator::IsNotDistinctFrom => "is_not_distinct_from",
+        Operator::RegexMatch => "regex_match",
+        Operator::RegexIMatch => "regex_imatch",
+        Operator::RegexNotMatch => "regex_not_match",
+        Operator::RegexNotIMatch => "regex_not_imatch",
+        Operator::BitwiseAnd => "bitwise_and",
+        Operator::BitwiseOr => "bitwise_or",
+        Operator::StringConcat => "str_concat",
+        Operator::BitwiseXor => "bitwise_xor",
+        Operator::BitwiseShiftRight => "bitwise_shift_right",
+        Operator::BitwiseShiftLeft => "bitwise_shift_left",
+    }
+}
+
+/// Converts a DataFusion aggregate expression into a Substrait [`Measure`].
+///
+/// The aggregate's arguments are converted first, then the lower-cased
+/// function name is registered in `extension_info`, and finally the optional
+/// filter expression is converted.
+///
+/// # Errors
+///
+/// Returns `DataFusionError::Internal` when `expr` is not an
+/// `Expr::AggregateFunction`, and propagates expression conversion errors.
+pub fn to_substrait_agg_measure(
+    expr: &Expr,
+    schema: &DFSchemaRef,
+    extension_info: &mut (
+        Vec<extensions::SimpleExtensionDeclaration>,
+        HashMap<String, u32>,
+    ),
+) -> Result<Measure> {
+    let (fun, args, distinct, filter) = match expr {
+        Expr::AggregateFunction {
+            fun,
+            args,
+            distinct,
+            filter,
+        } => (fun, args, distinct, filter),
+        _ => {
+            return Err(DataFusionError::Internal(format!(
+                "Expression must be compatible with aggregation. Unsupported expression: {:?}",
+                expr
+            )))
+        }
+    };
+
+    let arguments = args
+        .iter()
+        .map(|arg| {
+            let value = to_substrait_rex(arg, schema, extension_info)?;
+            Ok(FunctionArgument {
+                arg_type: Some(ArgType::Value(value)),
+            })
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    let function_reference =
+        _register_function(fun.to_string().to_lowercase(), extension_info);
+
+    let invocation = if *distinct {
+        AggregationInvocation::Distinct
+    } else {
+        AggregationInvocation::All
+    };
+
+    let filter = filter
+        .as_ref()
+        .map(|f| to_substrait_rex(f, schema, extension_info))
+        .transpose()?;
+
+    Ok(Measure {
+        measure: Some(AggregateFunction {
+            function_reference,
+            arguments,
+            sorts: vec![],
+            output_type: None,
+            invocation: invocation as i32,
+            phase: substrait::protobuf::AggregationPhase::Unspecified as i32,
+            args: vec![],
+        }),
+        filter,
+    })
+}
+
+/// Returns the anchor for `function_name`, registering the function first if
+/// it has not been seen yet.
+///
+/// Names are lower-cased before lookup. To prevent ambiguous references
+/// between scalar and aggregate functions, anchors are plan-relative
+/// identifiers starting from 0 and assigned in registration order. The
+/// consumer is responsible for restoring the `<function_anchor, function_name>`
+/// mapping from the extension declarations recorded here.
+fn _register_function(
+    function_name: String,
+    extension_info: &mut (
+        Vec<extensions::SimpleExtensionDeclaration>,
+        HashMap<String, u32>,
+    ),
+) -> u32 {
+    let (declarations, anchors_by_name) = extension_info;
+    let key = function_name.to_lowercase();
+
+    // Already registered: reuse the existing anchor.
+    if let Some(existing) = anchors_by_name.get(&key) {
+        return *existing;
+    }
+
+    // First occurrence: the next free anchor is the current table size.
+    let anchor = anchors_by_name.len() as u32;
+    anchors_by_name.insert(key.clone(), anchor);
+    declarations.push(extensions::SimpleExtensionDeclaration {
+        mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
+            extension_uri_reference: u32::MAX,
+            function_anchor: anchor,
+            name: key,
+        })),
+    });
+    anchor
+}
+
+/// Builds a two-argument Substrait scalar function call for `op`.
+///
+/// The operator's function name is registered in `extension_info`, and the
+/// resulting anchor is used as the function reference. `lhs` and `rhs` are
+/// cloned into the first and second argument slots respectively.
+pub fn make_binary_op_scalar_func(
+    lhs: &Expression,
+    rhs: &Expression,
+    op: Operator,
+    extension_info: &mut (
+        Vec<extensions::SimpleExtensionDeclaration>,
+        HashMap<String, u32>,
+    ),
+) -> Expression {
+    let name = operator_to_name(op).to_string().to_lowercase();
+    let function_reference = _register_function(name, extension_info);
+
+    let arguments = [lhs, rhs]
+        .iter()
+        .map(|operand| FunctionArgument {
+            arg_type: Some(ArgType::Value((*operand).clone())),
+        })
+        .collect();
+
+    Expression {
+        rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+            function_reference,
+            arguments,
+            output_type: None,
+            args: vec![],
+        })),
+    }
+}
+
+/// Convert DataFusion Expr to Substrait Rex
+pub fn to_substrait_rex(expr: &Expr, schema: &DFSchemaRef, extension_info: 
&mut (Vec<extensions::SimpleExtensionDeclaration>, HashMap<String, u32>)) -> 
Result<Expression> {
+    match expr {
+        Expr::Between { expr, negated, low, high } => {
+            if *negated {
+                // `expr NOT BETWEEN low AND high` can be translated into 
(expr < low OR high < expr)
+                let substrait_expr = to_substrait_rex(expr, schema, 
extension_info)?;
+                let substrait_low = to_substrait_rex(low, schema, 
extension_info)?;
+                let substrait_high = to_substrait_rex(high, schema, 
extension_info)?;
+
+                let l_expr = make_binary_op_scalar_func(&substrait_expr, 
&substrait_low, Operator::Lt, extension_info);
+                let r_expr = make_binary_op_scalar_func(&substrait_high, 
&substrait_expr, Operator::Lt, extension_info);
+
+                Ok(make_binary_op_scalar_func(&l_expr, &r_expr, Operator::Or, 
extension_info))
+            } else {
+                // `expr BETWEEN low AND high` can be translated into (low <= 
expr AND expr <= high)
+                let substrait_expr = to_substrait_rex(expr, schema, 
extension_info)?;
+                let substrait_low = to_substrait_rex(low, schema, 
extension_info)?;
+                let substrait_high = to_substrait_rex(high, schema, 
extension_info)?;
+
+                let l_expr = make_binary_op_scalar_func(&substrait_low, 
&substrait_expr, Operator::LtEq, extension_info);
+                let r_expr = make_binary_op_scalar_func(&substrait_expr, 
&substrait_high, Operator::LtEq, extension_info);
+
+                Ok(make_binary_op_scalar_func(&l_expr, &r_expr, Operator::And, 
extension_info))
+            }
+        }
+        Expr::Column(col) => {
+            let index = schema.index_of_column(&col)?;
+            substrait_field_ref(index)
+        }
+        Expr::BinaryExpr { left, op, right } => {
+            let l = to_substrait_rex(left, schema, extension_info)?;
+            let r = to_substrait_rex(right, schema, extension_info)?;
+
+            Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info))
+        }
+        Expr::Case { expr, when_then_expr, else_expr } => {
+            let mut ifs: Vec<IfClause> = vec![];
+            // Parse base
+            if let Some(e) = expr { // Base expression exists
+                ifs.push(IfClause {
+                    r#if: Some(to_substrait_rex(e, schema, extension_info)?),
+                    then: None,
+                });
+            }
+            // Parse `when`s
+            for (r#if, then) in when_then_expr {
+                ifs.push(IfClause {
+                    r#if: Some(to_substrait_rex(r#if, schema, 
extension_info)?),
+                    then: Some(to_substrait_rex(then, schema, 
extension_info)?),
+                });
+            }
+
+            // Parse outer `else`
+            let r#else: Option<Box<Expression>> = match else_expr {
+                Some(e) => Some(Box::new(to_substrait_rex(e, schema, 
extension_info)?)),
+                None => None,
+            };
+            
+            Ok(Expression {
+                rex_type: Some(RexType::IfThen(Box::new(IfThen {
+                    ifs: ifs,
+                    r#else: r#else
+                }))),
+            })
+        }
+        Expr::Literal(value) => {
+            let literal_type = match value {
+                ScalarValue::Int8(Some(n)) => Some(LiteralType::I8(*n as i32)),
+                ScalarValue::Int16(Some(n)) => Some(LiteralType::I16(*n as 
i32)),
+                ScalarValue::Int32(Some(n)) => Some(LiteralType::I32(*n)),
+                ScalarValue::Int64(Some(n)) => Some(LiteralType::I64(*n)),
+                ScalarValue::Boolean(Some(b)) => 
Some(LiteralType::Boolean(*b)),
+                ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
+                ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
+                ScalarValue::Utf8(Some(s)) => 
Some(LiteralType::String(s.clone())),
+                ScalarValue::LargeUtf8(Some(s)) => 
Some(LiteralType::String(s.clone())),
+                ScalarValue::Binary(Some(b)) => 
Some(LiteralType::Binary(b.clone())),
+                ScalarValue::LargeBinary(Some(b)) => 
Some(LiteralType::Binary(b.clone())),
+                ScalarValue::Date32(Some(d)) => Some(LiteralType::Date(*d)),
+                _ => {
+                    return Err(DataFusionError::NotImplemented(format!(
+                        "Unsupported literal: {:?}",
+                        value
+                    )))
+                }
+            };
+            Ok(Expression {
+                rex_type: Some(RexType::Literal(Literal {
+                    nullable: true,
+                    type_variation_reference: 0,
+                    literal_type,
+                })),
+            })
+        }
+        Expr::Alias(expr, _alias) => {
+            to_substrait_rex(expr, schema, extension_info)
+        }
+        _ => Err(DataFusionError::NotImplemented(format!(
+            "Unsupported expression: {:?}",
+            expr
+        ))),
+    }
+}
+
+/// Converts a DataFusion `Expr::Sort` into a Substrait [`SortField`].
+///
+/// The sort key is converted with [`to_substrait_rex`]; the `asc` and
+/// `nulls_first` flags select one of the four `SortDirection` values.
+///
+/// # Errors
+///
+/// Returns `DataFusionError::NotImplemented` when `expr` is not a sort
+/// expression, and propagates errors from converting the sort key.
+fn substrait_sort_field(
+    expr: &Expr,
+    schema: &DFSchemaRef,
+    extension_info: &mut (
+        Vec<extensions::SimpleExtensionDeclaration>,
+        HashMap<String, u32>,
+    ),
+) -> Result<SortField> {
+    let (key, ascending, nulls_first) = match expr {
+        Expr::Sort {
+            expr,
+            asc,
+            nulls_first,
+        } => (expr, *asc, *nulls_first),
+        _ => {
+            return Err(DataFusionError::NotImplemented(format!(
+                "Expecting sort expression but got {:?}",
+                expr
+            )))
+        }
+    };
+
+    let key = to_substrait_rex(key, schema, extension_info)?;
+    let direction = if ascending {
+        if nulls_first {
+            SortDirection::AscNullsFirst
+        } else {
+            SortDirection::AscNullsLast
+        }
+    } else if nulls_first {
+        SortDirection::DescNullsFirst
+    } else {
+        SortDirection::DescNullsLast
+    };
+
+    Ok(SortField {
+        expr: Some(key),
+        sort_kind: Some(SortKind::Direction(direction as i32)),
+    })
+}
+
+/// Builds a Substrait direct field reference to the struct field at `index`.
+///
+/// This function itself always returns `Ok`; the `Result` return type lets it
+/// be used directly where a fallible conversion is expected.
+fn substrait_field_ref(index: usize) -> Result<Expression> {
+    let struct_field = reference_segment::StructField {
+        field: index as i32,
+        child: None,
+    };
+    let segment = ReferenceSegment {
+        reference_type: Some(reference_segment::ReferenceType::StructField(
+            Box::new(struct_field),
+        )),
+    };
+    let field_reference = FieldReference {
+        reference_type: Some(ReferenceType::DirectReference(segment)),
+        root_type: None,
+    };
+    Ok(Expression {
+        rex_type: Some(RexType::Selection(Box::new(field_reference))),
+    })
+}
diff --git a/datafusion/substrait/src/serializer.rs 
b/datafusion/substrait/src/serializer.rs
new file mode 100644
index 000000000..7f52077f1
--- /dev/null
+++ b/datafusion/substrait/src/serializer.rs
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::producer;
+
+use datafusion::error::Result;
+use datafusion::prelude::*;
+
+use prost::Message;
+use substrait::protobuf::Plan;
+
+use std::fs::OpenOptions;
+use std::io::{Write, Read};
+
+/// Plans `sql` with `ctx`, converts the resulting logical plan to a Substrait
+/// plan and writes its protobuf encoding to the file at `path`.
+///
+/// The file is created if it does not exist and truncated if it does, so the
+/// file always contains exactly one encoded plan.
+///
+/// # Errors
+///
+/// Returns an error if the query cannot be planned, the plan cannot be
+/// converted or encoded, or the file cannot be opened or written.
+pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> {
+    let df = ctx.sql(sql).await?;
+    let plan = df.to_logical_plan()?;
+    let proto = producer::to_substrait_plan(&plan)?;
+
+    let mut protobuf_out = Vec::<u8>::new();
+    // Report an encoding failure to the caller instead of panicking.
+    proto.encode(&mut protobuf_out).map_err(|e| {
+        datafusion::error::DataFusionError::Internal(format!(
+            "Failed to encode Substrait plan: {}",
+            e
+        ))
+    })?;
+
+    // Truncate on open: without it, overwriting a longer existing file would
+    // leave stale trailing bytes behind and corrupt the encoded plan.
+    let mut file = OpenOptions::new()
+        .create(true)
+        .write(true)
+        .truncate(true)
+        .open(path)?;
+    file.write_all(&protobuf_out)?;
+    Ok(())
+}
+
+/// Reads the file at `path` and decodes its contents as a protobuf-encoded
+/// Substrait [`Plan`].
+///
+/// # Errors
+///
+/// Returns an error if the file cannot be opened or read, or if its contents
+/// are not a valid encoded plan.
+pub async fn deserialize(path: &str) -> Result<Box<Plan>> {
+    let mut protobuf_in = Vec::<u8>::new();
+
+    let mut file = OpenOptions::new().read(true).open(path)?;
+    file.read_to_end(&mut protobuf_in)?;
+
+    // The file is external input, so a decode failure is reported as an
+    // error rather than a panic.
+    let proto = Plan::decode(protobuf_in.as_slice()).map_err(|e| {
+        datafusion::error::DataFusionError::Execution(format!(
+            "Failed to decode Substrait plan: {}",
+            e
+        ))
+    })?;
+
+    Ok(Box::new(proto))
+}
+
+
diff --git a/datafusion/substrait/tests/roundtrip.rs 
b/datafusion/substrait/tests/roundtrip.rs
new file mode 100644
index 000000000..21a3a5f29
--- /dev/null
+++ b/datafusion/substrait/tests/roundtrip.rs
@@ -0,0 +1,273 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_substrait::consumer;
+use datafusion_substrait::producer;
+
+/// Round-trip tests: SQL -> logical plan -> Substrait plan -> logical plan.
+///
+/// Plans are compared through their `Debug` formatting.
+#[cfg(test)]
+mod tests {
+
+    use crate::{consumer::from_substrait_plan, producer::to_substrait_plan};
+    use datafusion::error::Result;
+    use datafusion::prelude::*;
+    use substrait::protobuf::extensions::simple_extension_declaration::MappingType;
+
+    #[tokio::test]
+    async fn simple_select() -> Result<()> {
+        roundtrip("SELECT a, b FROM data").await
+    }
+
+    #[tokio::test]
+    async fn wildcard_select() -> Result<()> {
+        roundtrip("SELECT * FROM data").await
+    }
+
+    #[tokio::test]
+    async fn select_with_filter() -> Result<()> {
+        roundtrip("SELECT * FROM data WHERE a > 1").await
+    }
+
+    /// Checks that each function used in the filter is registered once in the
+    /// plan's extensions, even when it appears several times in the query.
+    #[tokio::test]
+    async fn select_with_reused_functions() -> Result<()> {
+        let sql = "SELECT * FROM data WHERE a > 1 AND a < 10 AND b > 0";
+        roundtrip(sql).await?;
+        let (mut function_names, mut function_anchors) =
+            function_extension_info(sql).await?;
+        function_names.sort();
+        function_anchors.sort();
+
+        assert_eq!(function_names, ["and", "gt", "lt"]);
+        assert_eq!(function_anchors, [0, 1, 2]);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn select_with_filter_date() -> Result<()> {
+        roundtrip("SELECT * FROM data WHERE c > CAST('2020-01-01' AS DATE)").await
+    }
+
+    #[tokio::test]
+    async fn select_with_filter_bool_expr() -> Result<()> {
+        roundtrip("SELECT * FROM data WHERE d AND a > 1").await
+    }
+
+    #[tokio::test]
+    async fn select_with_limit() -> Result<()> {
+        roundtrip_fill_na("SELECT * FROM data LIMIT 100").await
+    }
+
+    #[tokio::test]
+    async fn select_with_limit_offset() -> Result<()> {
+        roundtrip("SELECT * FROM data LIMIT 200 OFFSET 10").await
+    }
+
+    #[tokio::test]
+    async fn simple_aggregate() -> Result<()> {
+        roundtrip("SELECT a, sum(b) FROM data GROUP BY a").await
+    }
+
+    #[tokio::test]
+    async fn aggregate_distinct_with_having() -> Result<()> {
+        roundtrip(
+            "SELECT a, count(distinct b) FROM data GROUP BY a, c HAVING count(b) > 100",
+        )
+        .await
+    }
+
+    #[tokio::test]
+    async fn aggregate_multiple_keys() -> Result<()> {
+        roundtrip("SELECT a, c, avg(b) FROM data GROUP BY a, c").await
+    }
+
+    #[tokio::test]
+    async fn simple_distinct() -> Result<()> {
+        test_alias(
+            // `SELECT *` is used to add `projection` at the root
+            "SELECT * FROM (SELECT distinct a FROM data)",
+            "SELECT a FROM data GROUP BY a",
+        )
+        .await
+    }
+
+    #[tokio::test]
+    async fn select_distinct_two_fields() -> Result<()> {
+        test_alias(
+            // `SELECT *` is used to add `projection` at the root
+            "SELECT * FROM (SELECT distinct a, b FROM data)",
+            "SELECT a, b FROM data GROUP BY a, b",
+        )
+        .await
+    }
+
+    #[tokio::test]
+    async fn simple_alias() -> Result<()> {
+        test_alias(
+            "SELECT d1.a, d1.b FROM data d1",
+            "SELECT a, b FROM data",
+        )
+        .await
+    }
+
+    #[tokio::test]
+    async fn two_table_alias() -> Result<()> {
+        test_alias(
+            "SELECT d1.a FROM data d1 JOIN data2 d2 ON d1.a = d2.a",
+            "SELECT data.a FROM data JOIN data2 ON data.a = data2.a",
+        )
+        .await
+    }
+
+    #[tokio::test]
+    async fn between_integers() -> Result<()> {
+        test_alias(
+            "SELECT * FROM data WHERE a BETWEEN 2 AND 6",
+            "SELECT * FROM data WHERE a >= 2 AND a <= 6",
+        )
+        .await
+    }
+
+    #[tokio::test]
+    async fn not_between_integers() -> Result<()> {
+        test_alias(
+            "SELECT * FROM data WHERE a NOT BETWEEN 2 AND 6",
+            "SELECT * FROM data WHERE a < 2 OR a > 6",
+        )
+        .await
+    }
+
+    #[tokio::test]
+    async fn case_without_base_expression() -> Result<()> {
+        roundtrip(
+            "SELECT (CASE WHEN a >= 0 THEN 'positive' ELSE 'negative' END) FROM data",
+        )
+        .await
+    }
+
+    #[tokio::test]
+    async fn case_with_base_expression() -> Result<()> {
+        roundtrip("SELECT (CASE a
+                            WHEN 0 THEN 'zero'
+                            WHEN 1 THEN 'one'
+                            ELSE 'other'
+                           END) FROM data").await
+    }
+
+    #[tokio::test]
+    async fn roundtrip_inner_join() -> Result<()> {
+        roundtrip("SELECT data.a FROM data JOIN data2 ON data.a = data2.a").await
+    }
+
+    #[tokio::test]
+    async fn inner_join() -> Result<()> {
+        assert_expected_plan(
+            "SELECT data.a FROM data JOIN data2 ON data.a = data2.a",
+            "Projection: data.a\
+            \n  Inner Join: data.a = data2.a\
+            \n    TableScan: data projection=[a]\
+            \n    TableScan: data2 projection=[a]",
+        )
+        .await
+    }
+
+    /// Round-trips `sql` through Substrait and asserts that the `Debug`
+    /// formatting of the resulting plan equals `expected_plan_str`.
+    async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> {
+        let mut ctx = create_context().await?;
+        let df = ctx.sql(sql).await?;
+        let plan = df.to_logical_plan()?;
+        let proto = to_substrait_plan(&plan)?;
+        let df = from_substrait_plan(&mut ctx, &proto).await?;
+        let plan2 = df.to_logical_plan()?;
+        let plan2str = format!("{:?}", plan2);
+        assert_eq!(expected_plan_str, &plan2str);
+        Ok(())
+    }
+
+    /// Like `roundtrip`, but replaces every `None` in both formatted plans
+    /// with `0` before comparing them.
+    async fn roundtrip_fill_na(sql: &str) -> Result<()> {
+        let mut ctx = create_context().await?;
+        let df = ctx.sql(sql).await?;
+        let plan1 = df.to_logical_plan()?;
+        let proto = to_substrait_plan(&plan1)?;
+
+        let df = from_substrait_plan(&mut ctx, &proto).await?;
+        let plan2 = df.to_logical_plan()?;
+
+        // Format plan string and replace all None's with 0
+        let plan1str = format!("{:?}", plan1).replace("None", "0");
+        let plan2str = format!("{:?}", plan2).replace("None", "0");
+
+        assert_eq!(plan1str, plan2str);
+        Ok(())
+    }
+
+    /// Round-trips both queries through Substrait and asserts that the two
+    /// resulting plans have the same `Debug` formatting.
+    async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> {
+        // Since we ignore the SubqueryAlias in the producer, the result should be
+        // the same as producing a Substrait plan from the same query without aliases
+        // sql_with_alias -> substrait -> logical plan = sql_no_alias -> substrait -> logical plan
+        let mut ctx = create_context().await?;
+
+        let df_a = ctx.sql(sql_with_alias).await?;
+        let proto_a = to_substrait_plan(&df_a.to_logical_plan()?)?;
+        let plan_with_alias = from_substrait_plan(&mut ctx, &proto_a)
+            .await?
+            .to_logical_plan()?;
+
+        let df = ctx.sql(sql_no_alias).await?;
+        let proto = to_substrait_plan(&df.to_logical_plan()?)?;
+        let plan = from_substrait_plan(&mut ctx, &proto)
+            .await?
+            .to_logical_plan()?;
+
+        println!("{:#?}", plan_with_alias);
+        println!("{:#?}", plan);
+
+        let plan1str = format!("{:?}", plan_with_alias);
+        let plan2str = format!("{:?}", plan);
+        assert_eq!(plan1str, plan2str);
+        Ok(())
+    }
+
+    /// Plans `sql`, converts the plan to Substrait and back, and asserts that
+    /// the `Debug` formatting of the plan is unchanged by the round trip.
+    async fn roundtrip(sql: &str) -> Result<()> {
+        let mut ctx = create_context().await?;
+        let df = ctx.sql(sql).await?;
+        let plan = df.to_logical_plan()?;
+        let proto = to_substrait_plan(&plan)?;
+
+        let df = from_substrait_plan(&mut ctx, &proto).await?;
+        let plan2 = df.to_logical_plan()?;
+
+        println!("{:#?}", plan);
+        println!("{:#?}", plan2);
+
+        let plan1str = format!("{:?}", plan);
+        let plan2str = format!("{:?}", plan2);
+        assert_eq!(plan1str, plan2str);
+        Ok(())
+    }
+
+    /// Produces the Substrait plan for `sql` and returns the names and anchors
+    /// of the function extensions it declares, in declaration order.
+    async fn function_extension_info(sql: &str) -> Result<(Vec<String>, Vec<u32>)> {
+        let ctx = create_context().await?;
+        let df = ctx.sql(sql).await?;
+        let plan = df.to_logical_plan()?;
+        let proto = to_substrait_plan(&plan)?;
+
+        let mut function_names: Vec<String> = vec![];
+        let mut function_anchors: Vec<u32> = vec![];
+        for e in &proto.extensions {
+            let (function_anchor, function_name) =
+                match e.mapping_type.as_ref().unwrap() {
+                    MappingType::ExtensionFunction(ext_f) => {
+                        (ext_f.function_anchor, &ext_f.name)
+                    }
+                    _ => unreachable!(
+                        "Producer does not generate a non-function extension"
+                    ),
+                };
+            function_names.push(function_name.to_string());
+            function_anchors.push(function_anchor);
+        }
+
+        Ok((function_names, function_anchors))
+    }
+
+    /// Builds a session context with the tables `data` and `data2`, both
+    /// backed by `tests/testdata/data.csv`.
+    async fn create_context() -> Result<SessionContext> {
+        let ctx = SessionContext::new();
+        ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new())
+            .await?;
+        ctx.register_csv("data2", "tests/testdata/data.csv", CsvReadOptions::new())
+            .await?;
+        Ok(ctx)
+    }
+}
diff --git a/datafusion/substrait/tests/serialize.rs b/datafusion/substrait/tests/serialize.rs
new file mode 100644
index 000000000..505c4f5f4
--- /dev/null
+++ b/datafusion/substrait/tests/serialize.rs
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Tests for writing a Substrait plan to disk and reading it back.
+#[cfg(test)]
+mod tests {
+
+    use datafusion_substrait::consumer::from_substrait_plan;
+    use datafusion_substrait::serializer;
+
+    use datafusion::error::Result;
+    use datafusion::prelude::*;
+
+    use std::fs;
+
+    /// Serializes a simple query to a file, reads it back, and checks that the
+    /// plan rebuilt from the file formats the same as the directly planned one.
+    #[tokio::test]
+    async fn serialize_simple_select() -> Result<()> {
+        let mut ctx = create_context().await?;
+        let path = "tests/simple_select.bin";
+        let sql = "SELECT a, b FROM data";
+        // Test reference
+        let df_ref = ctx.sql(sql).await?;
+        let plan_ref = df_ref.to_logical_plan()?;
+        // Test
+        // Write substrait plan to file
+        serializer::serialize(sql, &ctx, path).await?;
+        // Read substrait plan from file
+        let proto = serializer::deserialize(path).await;
+        // Delete the test binary file right after reading it, before any step
+        // that can fail, so a failing decode or assertion does not leave it behind
+        fs::remove_file(path)?;
+        let proto = proto?;
+        // Check plan equality
+        let df = from_substrait_plan(&mut ctx, &proto).await?;
+        let plan = df.to_logical_plan()?;
+        let plan_str_ref = format!("{:?}", plan_ref);
+        let plan_str = format!("{:?}", plan);
+        assert_eq!(plan_str_ref, plan_str);
+
+        Ok(())
+    }
+
+    /// Builds a session context with the tables `data` and `data2`, both
+    /// backed by `tests/testdata/data.csv`.
+    async fn create_context() -> Result<SessionContext> {
+        let ctx = SessionContext::new();
+        ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new())
+            .await?;
+        ctx.register_csv("data2", "tests/testdata/data.csv", CsvReadOptions::new())
+            .await?;
+        Ok(ctx)
+    }
+}
\ No newline at end of file
diff --git a/datafusion/substrait/tests/testdata/data.csv b/datafusion/substrait/tests/testdata/data.csv
new file mode 100644
index 000000000..4394789bc
--- /dev/null
+++ b/datafusion/substrait/tests/testdata/data.csv
@@ -0,0 +1,3 @@
+a,b,c,d
+1,2,2020-01-01,false
+3,4,2020-01-01,true
\ No newline at end of file

Reply via email to