This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new c1f6269  Use SessionContext to parse Expr protobuf (#2024)
c1f6269 is described below

commit c1f6269a3138a22b50679d73609645ed1511e5de
Author: Dan Harris <[email protected]>
AuthorDate: Sun Mar 20 06:22:04 2022 -0400

    Use SessionContext to parse Expr protobuf (#2024)
    
    * Use ExecutionContext to parse Expr protobuf
    
    * Fixes after merge
    
    * linting
---
 ballista/rust/core/src/serde/logical_plan/mod.rs  |  29 +-
 ballista/rust/core/src/serde/mod.rs               |   9 +-
 ballista/rust/core/src/serde/physical_plan/mod.rs |  10 +-
 datafusion-proto/proto/datafusion.proto           |  16 +
 datafusion-proto/src/from_proto.rs                | 691 ++++++++++++----------
 datafusion-proto/src/lib.rs                       | 161 +++--
 datafusion-proto/src/to_proto.rs                  |  24 +-
 7 files changed, 573 insertions(+), 367 deletions(-)

diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs
index bfc254a..fa155ce 100644
--- a/ballista/rust/core/src/serde/logical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/logical_plan/mod.rs
@@ -38,6 +38,7 @@ use datafusion::logical_plan::{
 };
 use datafusion::prelude::SessionContext;
 
+use datafusion_proto::from_proto::parse_expr;
 use prost::bytes::BufMut;
 use prost::Message;
 use protobuf::listing_table_scan_node::FileFormatType;
@@ -95,10 +96,11 @@ impl AsLogicalPlan for LogicalPlanNode {
                         .values_list
                         .chunks_exact(n_cols)
                         .map(|r| {
-                            r.iter().map(|v| v.try_into()).collect::<Result<
+                            r.iter().map(|expr| parse_expr(expr, 
ctx)).collect::<Result<
                                 Vec<_>,
                                 datafusion_proto::from_proto::Error,
-                            >>()
+                            >>(
+                            )
                         })
                         .collect::<Result<Vec<_>, _>>()
                         .map_err(|e| e.into())
@@ -113,7 +115,7 @@ impl AsLogicalPlan for LogicalPlanNode {
                 let x: Vec<Expr> = projection
                     .expr
                     .iter()
-                    .map(|expr| expr.try_into())
+                    .map(|expr| parse_expr(expr, ctx))
                     .collect::<Result<Vec<_>, _>>()?;
                 LogicalPlanBuilder::from(input)
                     .project_with_alias(
@@ -133,10 +135,12 @@ impl AsLogicalPlan for LogicalPlanNode {
                 let expr: Expr = selection
                     .expr
                     .as_ref()
+                    .map(|expr| parse_expr(expr, ctx))
+                    .transpose()?
                     .ok_or_else(|| {
                         BallistaError::General("expression required".to_string())
-                    })?
-                    .try_into()?;
+                    })?;
+                // .try_into()?;
                 LogicalPlanBuilder::from(input)
                     .filter(expr)?
                     .build()
@@ -148,7 +152,7 @@ impl AsLogicalPlan for LogicalPlanNode {
                 let window_expr = window
                     .window_expr
                     .iter()
-                    .map(|expr| expr.try_into())
+                    .map(|expr| parse_expr(expr, ctx))
                     .collect::<Result<Vec<Expr>, _>>()?;
                 LogicalPlanBuilder::from(input)
                     .window(window_expr)?
@@ -161,12 +165,12 @@ impl AsLogicalPlan for LogicalPlanNode {
                 let group_expr = aggregate
                     .group_expr
                     .iter()
-                    .map(|expr| expr.try_into())
+                    .map(|expr| parse_expr(expr, ctx))
                     .collect::<Result<Vec<Expr>, _>>()?;
                 let aggr_expr = aggregate
                     .aggr_expr
                     .iter()
-                    .map(|expr| expr.try_into())
+                    .map(|expr| parse_expr(expr, ctx))
                     .collect::<Result<Vec<Expr>, _>>()?;
                 LogicalPlanBuilder::from(input)
                     .aggregate(group_expr, aggr_expr)?
@@ -189,7 +193,7 @@ impl AsLogicalPlan for LogicalPlanNode {
                 let filters = scan
                     .filters
                     .iter()
-                    .map(|e| e.try_into())
+                    .map(|expr| parse_expr(expr, ctx))
                     .collect::<Result<Vec<_>, _>>()?;
 
                 let file_format: Arc<dyn FileFormat> =
@@ -260,7 +264,7 @@ impl AsLogicalPlan for LogicalPlanNode {
                 let sort_expr: Vec<Expr> = sort
                     .expr
                     .iter()
-                    .map(|expr| expr.try_into())
+                    .map(|expr| parse_expr(expr, ctx))
                     .collect::<Result<Vec<Expr>, _>>()?;
                 LogicalPlanBuilder::from(input)
                     .sort(sort_expr)?
@@ -285,7 +289,7 @@ impl AsLogicalPlan for LogicalPlanNode {
                     }) => Partitioning::Hash(
                         pb_hash_expr
                             .iter()
-                            .map(|pb_expr| pb_expr.try_into())
+                            .map(|expr| parse_expr(expr, ctx))
                             .collect::<Result<Vec<_>, _>>()?,
                         partition_count as usize,
                     ),
@@ -416,7 +420,8 @@ impl AsLogicalPlan for LogicalPlanNode {
                     .map(|i| i.try_into_logical_plan(ctx, extension_codec))
                     .collect::<Result<_, BallistaError>>()?;
 
-                let extension_node = extension_codec.try_decode(node, 
&input_plans)?;
+                let extension_node =
+                    extension_codec.try_decode(node, &input_plans, ctx)?;
                 Ok(LogicalPlan::Extension(extension_node))
             }
         }
diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs
index cc1bbb4..19ff22a 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -84,6 +84,7 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync {
         &self,
         buf: &[u8],
         inputs: &[LogicalPlan],
+        ctx: &SessionContext,
     ) -> Result<Extension, BallistaError>;
 
     fn try_encode(
@@ -101,6 +102,7 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec {
         &self,
         _buf: &[u8],
         _inputs: &[LogicalPlan],
+        _ctx: &SessionContext,
     ) -> Result<Extension, BallistaError> {
         Err(BallistaError::NotImplemented(
             "LogicalExtensionCodec is not provided".to_string(),
@@ -147,6 +149,7 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync {
         &self,
         buf: &[u8],
         inputs: &[Arc<dyn ExecutionPlan>],
+        ctx: &SessionContext,
     ) -> Result<Arc<dyn ExecutionPlan>, BallistaError>;
 
     fn try_encode(
@@ -164,6 +167,7 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec {
         &self,
         _buf: &[u8],
         _inputs: &[Arc<dyn ExecutionPlan>],
+        _ctx: &SessionContext,
     ) -> Result<Arc<dyn ExecutionPlan>, BallistaError> {
         Err(BallistaError::NotImplemented(
             "PhysicalExtensionCodec is not provided".to_string(),
@@ -360,6 +364,7 @@ mod tests {
     use prost::Message;
     use std::any::Any;
 
+    use datafusion_proto::from_proto::parse_expr;
     use std::convert::TryInto;
     use std::fmt;
     use std::fmt::{Debug, Formatter};
@@ -596,6 +601,7 @@ mod tests {
             &self,
             buf: &[u8],
             inputs: &[LogicalPlan],
+            ctx: &SessionContext,
         ) -> Result<Extension, BallistaError> {
             if let Some((input, _)) = inputs.split_first() {
                 let proto = TopKPlanProto::decode(buf).map_err(|e| {
@@ -609,7 +615,7 @@ mod tests {
                     let node = TopKPlanNode::new(
                         proto.k as usize,
                         input.clone(),
-                        expr.try_into()?,
+                        parse_expr(expr, ctx)?,
                     );
 
                     Ok(Extension {
@@ -653,6 +659,7 @@ mod tests {
             &self,
             buf: &[u8],
             inputs: &[Arc<dyn ExecutionPlan>],
+            _ctx: &SessionContext,
         ) -> Result<Arc<dyn ExecutionPlan>, BallistaError> {
             if let Some((input, _)) = inputs.split_first() {
                 let proto = TopKExecProto::decode(buf).map_err(|e| {
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs
index 9c15122..936de0e 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -57,6 +57,7 @@ use datafusion::physical_plan::{
     AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr,
 };
 use datafusion::prelude::SessionContext;
+use datafusion_proto::from_proto::parse_expr;
 use prost::bytes::BufMut;
 use prost::Message;
 use std::convert::TryInto;
@@ -133,7 +134,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 let predicate = scan
                     .pruning_predicate
                     .as_ref()
-                    .map(|expr| expr.try_into())
+                    .map(|expr| parse_expr(expr, ctx))
                     .transpose()?;
                 Ok(Arc::new(ParquetExec::new(
                     decode_scan_config(scan.base_conf.as_ref().unwrap(), ctx)?,
@@ -483,8 +484,11 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     .map(|i| i.try_into_physical_plan(ctx, extension_codec))
                     .collect::<Result<_, BallistaError>>()?;
 
-                let extension_node =
-                    extension_codec.try_decode(extension.node.as_slice(), 
&inputs)?;
+                let extension_node = extension_codec.try_decode(
+                    extension.node.as_slice(),
+                    &inputs,
+                    ctx,
+                )?;
 
                 Ok(extension_node)
             }
diff --git a/datafusion-proto/proto/datafusion.proto b/datafusion-proto/proto/datafusion.proto
index a0c5c2f..4ee3298 100644
--- a/datafusion-proto/proto/datafusion.proto
+++ b/datafusion-proto/proto/datafusion.proto
@@ -77,6 +77,12 @@ message LogicalExprNode {
 
     // window expressions
     WindowExprNode window_expr = 18;
+
+    // AggregateUDF expressions
+    AggregateUDFExprNode aggregate_udf_expr = 19;
+
+    // Scalar UDF expressions
+    ScalarUDFExprNode scalar_udf_expr = 20;
   }
 }
 
@@ -208,6 +214,16 @@ message AggregateExprNode {
   repeated LogicalExprNode expr = 2;
 }
 
+message AggregateUDFExprNode {
+  string fun_name = 1;
+  repeated LogicalExprNode args = 2;
+}
+
+message ScalarUDFExprNode {
+  string fun_name = 1;
+  repeated LogicalExprNode args = 2;
+}
+
 enum BuiltInWindowFunction {
   ROW_NUMBER = 0;
   RANK = 1;
diff --git a/datafusion-proto/src/from_proto.rs b/datafusion-proto/src/from_proto.rs
index e7a5def..8789129 100644
--- a/datafusion-proto/src/from_proto.rs
+++ b/datafusion-proto/src/from_proto.rs
@@ -16,6 +16,8 @@
 // under the License.
 
 use crate::protobuf;
+use datafusion::prelude::{bit_length, SessionContext};
+use datafusion::sql::planner::ContextProvider;
 use datafusion::{
     arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit, 
UnionMode},
     error::DataFusionError,
@@ -832,331 +834,365 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
     }
 }
 
-impl TryFrom<&protobuf::LogicalExprNode> for Expr {
-    type Error = Error;
-
-    fn try_from(expr: &protobuf::LogicalExprNode) -> Result<Self, Self::Error> 
{
-        use datafusion::physical_plan::window_functions;
-        use protobuf::{logical_expr_node::ExprType, window_expr_node, 
ScalarFunction};
-
-        let expr_type = expr
-            .expr_type
-            .as_ref()
-            .ok_or_else(|| Error::required("expr_type"))?;
-
-        match expr_type {
-            ExprType::BinaryExpr(binary_expr) => Ok(Self::BinaryExpr {
-                left: Box::new(binary_expr.l.as_deref().required("l")?),
-                op: from_proto_binary_op(&binary_expr.op)?,
-                right: Box::new(binary_expr.r.as_deref().required("r")?),
-            }),
-            ExprType::Column(column) => Ok(Self::Column(column.into())),
-            ExprType::Literal(literal) => {
-                let scalar_value: ScalarValue = literal.try_into()?;
-                Ok(Self::Literal(scalar_value))
-            }
-            ExprType::WindowExpr(expr) => {
-                let window_function = expr
-                    .window_function
-                    .as_ref()
-                    .ok_or_else(|| Error::required("window_function"))?;
-                let partition_by = expr
-                    .partition_by
-                    .iter()
-                    .map(|e| e.try_into())
-                    .collect::<Result<Vec<_>, _>>()?;
-                let order_by = expr
-                    .order_by
-                    .iter()
-                    .map(|e| e.try_into())
-                    .collect::<Result<Vec<_>, _>>()?;
-                let window_frame = expr
-                    .window_frame
-                    .as_ref()
-                    .map::<Result<WindowFrame, _>, _>(|e| match e {
-                        window_expr_node::WindowFrame::Frame(frame) => {
-                            let window_frame: WindowFrame = 
frame.clone().try_into()?;
-                            if WindowFrameUnits::Range == window_frame.units
-                                && order_by.len() != 1
-                            {
-                                Err(proto_error("With window frame of type RANGE, the order by expression must be of length 1"))
-                            } else {
-                                Ok(window_frame)
-                            }
+pub fn parse_expr(
+    proto: &protobuf::LogicalExprNode,
+    ctx: &SessionContext,
+) -> Result<Expr, Error> {
+    use datafusion::physical_plan::window_functions;
+    use protobuf::{logical_expr_node::ExprType, window_expr_node, 
ScalarFunction};
+
+    let expr_type = proto
+        .expr_type
+        .as_ref()
+        .ok_or_else(|| Error::required("expr_type"))?;
+
+    match expr_type {
+        ExprType::BinaryExpr(binary_expr) => Ok(Expr::BinaryExpr {
+            left: Box::new(parse_required_expr(&binary_expr.l, ctx, "l")?),
+            op: from_proto_binary_op(&binary_expr.op)?,
+            right: Box::new(parse_required_expr(&binary_expr.r, ctx, "r")?),
+        }),
+        ExprType::Column(column) => Ok(Expr::Column(column.into())),
+        ExprType::Literal(literal) => {
+            let scalar_value: ScalarValue = literal.try_into()?;
+            Ok(Expr::Literal(scalar_value))
+        }
+        ExprType::WindowExpr(expr) => {
+            let window_function = expr
+                .window_function
+                .as_ref()
+                .ok_or_else(|| Error::required("window_function"))?;
+            let partition_by = expr
+                .partition_by
+                .iter()
+                .map(|e| parse_expr(e, ctx))
+                .collect::<Result<Vec<_>, _>>()?;
+            let order_by = expr
+                .order_by
+                .iter()
+                .map(|e| parse_expr(e, ctx))
+                .collect::<Result<Vec<_>, _>>()?;
+            let window_frame = expr
+                .window_frame
+                .as_ref()
+                .map::<Result<WindowFrame, _>, _>(|e| match e {
+                    window_expr_node::WindowFrame::Frame(frame) => {
+                        let window_frame: WindowFrame = 
frame.clone().try_into()?;
+                        if WindowFrameUnits::Range == window_frame.units
+                            && order_by.len() != 1
+                        {
+                            Err(proto_error("With window frame of type RANGE, the order by expression must be of length 1"))
+                        } else {
+                            Ok(window_frame)
                         }
-                    })
-                    .transpose()?;
-
-                match window_function {
-                    window_expr_node::WindowFunction::AggrFunction(i) => {
-                        let aggr_function =
-                            protobuf::AggregateFunction::try_from(i)?.into();
-
-                        Ok(Self::WindowFunction {
-                            fun: 
window_functions::WindowFunction::AggregateFunction(
-                                aggr_function,
-                            ),
-                            args: vec![expr.expr.as_deref().required("expr")?],
-                            partition_by,
-                            order_by,
-                            window_frame,
-                        })
-                    }
-                    window_expr_node::WindowFunction::BuiltInFunction(i) => {
-                        let built_in_function =
-                            protobuf::BuiltInWindowFunction::from_i32(*i)
-                                .ok_or_else(|| {
-                                    Error::unknown("BuiltInWindowFunction", *i)
-                                })?
-                                .into();
-
-                        Ok(Self::WindowFunction {
-                            fun: 
window_functions::WindowFunction::BuiltInWindowFunction(
-                                built_in_function,
-                            ),
-                            args: vec![expr.expr.as_deref().required("expr")?],
-                            partition_by,
-                            order_by,
-                            window_frame,
-                        })
                     }
-                }
-            }
-            ExprType::AggregateExpr(expr) => {
-                let fun =
-                    
protobuf::AggregateFunction::try_from(&expr.aggr_function)?.into();
-
-                Ok(Self::AggregateFunction {
-                    fun,
-                    args: expr
-                        .expr
-                        .iter()
-                        .map(|e| e.try_into())
-                        .collect::<Result<Vec<_>, _>>()?,
-                    distinct: false, //TODO
                 })
-            }
-            ExprType::Alias(alias) => Ok(Self::Alias(
-                Box::new(alias.expr.as_deref().required("expr")?),
-                alias.alias.clone(),
-            )),
-            ExprType::IsNullExpr(is_null) => Ok(Self::IsNull(Box::new(
-                is_null.expr.as_deref().required("expr")?,
-            ))),
-            ExprType::IsNotNullExpr(is_not_null) => 
Ok(Self::IsNotNull(Box::new(
-                is_not_null.expr.as_deref().required("expr")?,
-            ))),
-            ExprType::NotExpr(not) => {
-                Ok(Self::Not(Box::new(not.expr.as_deref().required("expr")?)))
-            }
-            ExprType::Between(between) => Ok(Self::Between {
-                expr: Box::new(between.expr.as_deref().required("expr")?),
-                negated: between.negated,
-                low: Box::new(between.low.as_deref().required("low")?),
-                high: Box::new(between.high.as_deref().required("high")?),
-            }),
-            ExprType::Case(case) => {
-                let when_then_expr = case
-                    .when_then_expr
-                    .iter()
-                    .map(|e| {
-                        let when_expr = 
e.when_expr.as_ref().required("when_expr")?;
-                        let then_expr = 
e.then_expr.as_ref().required("then_expr")?;
-                        Ok((Box::new(when_expr), Box::new(then_expr)))
+                .transpose()?;
+
+            match window_function {
+                window_expr_node::WindowFunction::AggrFunction(i) => {
+                    let aggr_function = 
protobuf::AggregateFunction::try_from(i)?.into();
+
+                    Ok(Expr::WindowFunction {
+                        fun: 
window_functions::WindowFunction::AggregateFunction(
+                            aggr_function,
+                        ),
+                        args: vec![parse_required_expr(&expr.expr, ctx, 
"expr")?],
+                        partition_by,
+                        order_by,
+                        window_frame,
                     })
-                    .collect::<Result<Vec<(Box<Expr>, Box<Expr>)>, Error>>()?;
-                Ok(Self::Case {
-                    expr: parse_optional_expr(&case.expr)?.map(Box::new),
-                    when_then_expr,
-                    else_expr: 
parse_optional_expr(&case.else_expr)?.map(Box::new),
-                })
-            }
-            ExprType::Cast(cast) => {
-                let expr = Box::new(cast.expr.as_deref().required("expr")?);
-                let data_type = 
cast.arrow_type.as_ref().required("arrow_type")?;
-                Ok(Self::Cast { expr, data_type })
-            }
-            ExprType::TryCast(cast) => {
-                let expr = Box::new(cast.expr.as_deref().required("expr")?);
-                let data_type = 
cast.arrow_type.as_ref().required("arrow_type")?;
-                Ok(Self::TryCast { expr, data_type })
+                }
+                window_expr_node::WindowFunction::BuiltInFunction(i) => {
+                    let built_in_function = 
protobuf::BuiltInWindowFunction::from_i32(*i)
+                        .ok_or_else(|| Error::unknown("BuiltInWindowFunction", 
*i))?
+                        .into();
+
+                    Ok(Expr::WindowFunction {
+                        fun: 
window_functions::WindowFunction::BuiltInWindowFunction(
+                            built_in_function,
+                        ),
+                        args: vec![parse_required_expr(&expr.expr, ctx, 
"expr")?],
+                        partition_by,
+                        order_by,
+                        window_frame,
+                    })
+                }
             }
-            ExprType::Sort(sort) => Ok(Self::Sort {
-                expr: Box::new(sort.expr.as_deref().required("expr")?),
-                asc: sort.asc,
-                nulls_first: sort.nulls_first,
-            }),
-            ExprType::Negative(negative) => Ok(Self::Negative(Box::new(
-                negative.expr.as_deref().required("expr")?,
-            ))),
-            ExprType::InList(in_list) => Ok(Self::InList {
-                expr: Box::new(in_list.expr.as_deref().required("expr")?),
-                list: in_list
-                    .list
+        }
+        ExprType::AggregateExpr(expr) => {
+            let fun = 
protobuf::AggregateFunction::try_from(&expr.aggr_function)?.into();
+
+            Ok(Expr::AggregateFunction {
+                fun,
+                args: expr
+                    .expr
                     .iter()
-                    .map(|expr| expr.try_into())
+                    .map(|e| parse_expr(e, ctx))
                     .collect::<Result<Vec<_>, _>>()?,
-                negated: in_list.negated,
-            }),
-            ExprType::Wildcard(_) => Ok(Self::Wildcard),
-            ExprType::ScalarFunction(expr) => {
-                let scalar_function = 
protobuf::ScalarFunction::from_i32(expr.fun)
-                    .ok_or_else(|| Error::unknown("ScalarFunction", 
expr.fun))?;
-                let args = &expr.args;
-
-                match scalar_function {
-                    ScalarFunction::Asin => Ok(asin((&args[0]).try_into()?)),
-                    ScalarFunction::Acos => Ok(acos((&args[0]).try_into()?)),
-                    ScalarFunction::Array => Ok(array(
-                        args.to_owned()
-                            .iter()
-                            .map(|e| e.try_into())
-                            .collect::<Result<Vec<_>, _>>()?,
-                    )),
-                    ScalarFunction::Sqrt => Ok(sqrt((&args[0]).try_into()?)),
-                    ScalarFunction::Sin => Ok(sin((&args[0]).try_into()?)),
-                    ScalarFunction::Cos => Ok(cos((&args[0]).try_into()?)),
-                    ScalarFunction::Tan => Ok(tan((&args[0]).try_into()?)),
-                    ScalarFunction::Atan => Ok(atan((&args[0]).try_into()?)),
-                    ScalarFunction::Exp => Ok(exp((&args[0]).try_into()?)),
-                    ScalarFunction::Log2 => Ok(log2((&args[0]).try_into()?)),
-                    ScalarFunction::Ln => Ok(ln((&args[0]).try_into()?)),
-                    ScalarFunction::Log10 => Ok(log10((&args[0]).try_into()?)),
-                    ScalarFunction::Floor => Ok(floor((&args[0]).try_into()?)),
-                    ScalarFunction::Ceil => Ok(ceil((&args[0]).try_into()?)),
-                    ScalarFunction::Round => Ok(round((&args[0]).try_into()?)),
-                    ScalarFunction::Trunc => Ok(trunc((&args[0]).try_into()?)),
-                    ScalarFunction::Abs => Ok(abs((&args[0]).try_into()?)),
-                    ScalarFunction::Signum => 
Ok(signum((&args[0]).try_into()?)),
-                    ScalarFunction::OctetLength => {
-                        Ok(octet_length((&args[0]).try_into()?))
-                    }
-                    ScalarFunction::Lower => Ok(lower((&args[0]).try_into()?)),
-                    ScalarFunction::Upper => Ok(upper((&args[0]).try_into()?)),
-                    ScalarFunction::Trim => Ok(trim((&args[0]).try_into()?)),
-                    ScalarFunction::Ltrim => Ok(ltrim((&args[0]).try_into()?)),
-                    ScalarFunction::Rtrim => Ok(rtrim((&args[0]).try_into()?)),
-                    ScalarFunction::DatePart => {
-                        Ok(date_part((&args[0]).try_into()?, 
(&args[1]).try_into()?))
-                    }
-                    ScalarFunction::DateTrunc => {
-                        Ok(date_trunc((&args[0]).try_into()?, 
(&args[1]).try_into()?))
-                    }
-                    ScalarFunction::Sha224 => 
Ok(sha224((&args[0]).try_into()?)),
-                    ScalarFunction::Sha256 => 
Ok(sha256((&args[0]).try_into()?)),
-                    ScalarFunction::Sha384 => 
Ok(sha384((&args[0]).try_into()?)),
-                    ScalarFunction::Sha512 => 
Ok(sha512((&args[0]).try_into()?)),
-                    ScalarFunction::Md5 => Ok(md5((&args[0]).try_into()?)),
-                    ScalarFunction::NullIf => 
Ok(nullif((&args[0]).try_into()?)),
-                    ScalarFunction::Digest => {
-                        Ok(digest((&args[0]).try_into()?, 
(&args[1]).try_into()?))
-                    }
-                    ScalarFunction::Ascii => Ok(ascii((&args[0]).try_into()?)),
-                    ScalarFunction::BitLength => Ok((&args[0]).try_into()?),
-                    ScalarFunction::CharacterLength => {
-                        Ok(character_length((&args[0]).try_into()?))
-                    }
-                    ScalarFunction::Chr => Ok(chr((&args[0]).try_into()?)),
-                    ScalarFunction::InitCap => 
Ok(ascii((&args[0]).try_into()?)),
-                    ScalarFunction::Left => {
-                        Ok(left((&args[0]).try_into()?, 
(&args[1]).try_into()?))
-                    }
-                    ScalarFunction::Random => Ok(random()),
-                    ScalarFunction::Repeat => {
-                        Ok(repeat((&args[0]).try_into()?, 
(&args[1]).try_into()?))
-                    }
-                    ScalarFunction::Replace => Ok(replace(
-                        (&args[0]).try_into()?,
-                        (&args[1]).try_into()?,
-                        (&args[2]).try_into()?,
-                    )),
-                    ScalarFunction::Reverse => 
Ok(reverse((&args[0]).try_into()?)),
-                    ScalarFunction::Right => {
-                        Ok(right((&args[0]).try_into()?, 
(&args[1]).try_into()?))
-                    }
-                    ScalarFunction::Concat => Ok(concat_expr(
-                        args.to_owned()
-                            .iter()
-                            .map(|e| e.try_into())
-                            .collect::<Result<Vec<_>, _>>()?,
-                    )),
-                    ScalarFunction::ConcatWithSeparator => Ok(concat_ws_expr(
-                        args.to_owned()
-                            .iter()
-                            .map(|e| e.try_into())
-                            .collect::<Result<Vec<_>, _>>()?,
-                    )),
-                    ScalarFunction::Lpad => Ok(lpad(
-                        args.to_owned()
-                            .iter()
-                            .map(|e| e.try_into())
-                            .collect::<Result<Vec<_>, _>>()?,
-                    )),
-                    ScalarFunction::Rpad => Ok(rpad(
-                        args.to_owned()
-                            .iter()
-                            .map(|e| e.try_into())
-                            .collect::<Result<Vec<_>, _>>()?,
-                    )),
-                    ScalarFunction::RegexpReplace => Ok(regexp_replace(
-                        args.to_owned()
-                            .iter()
-                            .map(|e| e.try_into())
-                            .collect::<Result<Vec<_>, _>>()?,
-                    )),
-                    ScalarFunction::RegexpMatch => Ok(regexp_match(
-                        args.to_owned()
-                            .iter()
-                            .map(|e| e.try_into())
-                            .collect::<Result<Vec<_>, _>>()?,
-                    )),
-                    ScalarFunction::Btrim => Ok(btrim(
-                        args.to_owned()
-                            .iter()
-                            .map(|e| e.try_into())
-                            .collect::<Result<Vec<_>, _>>()?,
-                    )),
-                    ScalarFunction::SplitPart => Ok(split_part(
-                        (&args[0]).try_into()?,
-                        (&args[1]).try_into()?,
-                        (&args[2]).try_into()?,
-                    )),
-                    ScalarFunction::StartsWith => {
-                        Ok(starts_with((&args[0]).try_into()?, 
(&args[1]).try_into()?))
-                    }
-                    ScalarFunction::Strpos => {
-                        Ok(strpos((&args[0]).try_into()?, 
(&args[1]).try_into()?))
-                    }
-                    ScalarFunction::Substr => {
-                        Ok(substr((&args[0]).try_into()?, 
(&args[1]).try_into()?))
-                    }
-                    ScalarFunction::ToHex => 
Ok(to_hex((&args[0]).try_into()?)),
-                    ScalarFunction::ToTimestampMillis => {
-                        Ok(to_timestamp_millis((&args[0]).try_into()?))
-                    }
-                    ScalarFunction::ToTimestampMicros => {
-                        Ok(to_timestamp_micros((&args[0]).try_into()?))
-                    }
-                    ScalarFunction::ToTimestampSeconds => {
-                        Ok(to_timestamp_seconds((&args[0]).try_into()?))
-                    }
-                    ScalarFunction::Now => Ok(now_expr(
-                        args.to_owned()
-                            .iter()
-                            .map(|e| e.try_into())
-                            .collect::<Result<Vec<_>, _>>()?,
-                    )),
-                    ScalarFunction::Translate => Ok(translate(
-                        (&args[0]).try_into()?,
-                        (&args[1]).try_into()?,
-                        (&args[2]).try_into()?,
-                    )),
-                    _ => Err(proto_error(
-                        "Protobuf deserialization error: Unsupported scalar function",
-                    )),
+                distinct: false, //TODO
+            })
+        }
+        ExprType::Alias(alias) => Ok(Expr::Alias(
+            Box::new(parse_required_expr(&alias.expr, ctx, "expr")?),
+            alias.alias.clone(),
+        )),
+        ExprType::IsNullExpr(is_null) => Ok(Expr::IsNull(Box::new(parse_required_expr(
+            &is_null.expr,
+            ctx,
+            "expr",
+        )?))),
+        ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::IsNotNull(Box::new(
+            parse_required_expr(&is_not_null.expr, ctx, "expr")?,
+        ))),
+        ExprType::NotExpr(not) => Ok(Expr::Not(Box::new(parse_required_expr(
+            &not.expr, ctx, "expr",
+        )?))),
+        ExprType::Between(between) => Ok(Expr::Between {
+            expr: Box::new(parse_required_expr(&between.expr, ctx, "expr")?),
+            negated: between.negated,
+            low: Box::new(parse_required_expr(&between.low, ctx, "expr")?),
+            high: Box::new(parse_required_expr(&between.high, ctx, "expr")?),
+        }),
+        ExprType::Case(case) => {
+            let when_then_expr = case
+                .when_then_expr
+                .iter()
+                .map(|e| {
+                    let when_expr =
+                        parse_required_expr_inner(&e.when_expr, ctx, "when_expr")?;
+                    let then_expr =
+                        parse_required_expr_inner(&e.then_expr, ctx, "then_expr")?;
+                    Ok((Box::new(when_expr), Box::new(then_expr)))
+                })
+                .collect::<Result<Vec<(Box<Expr>, Box<Expr>)>, Error>>()?;
+            Ok(Expr::Case {
+                expr: parse_optional_expr(&case.expr, ctx)?.map(Box::new),
+                when_then_expr,
+                else_expr: parse_optional_expr(&case.else_expr, ctx)?.map(Box::new),
+            })
+        }
+        ExprType::Cast(cast) => {
+            let expr = Box::new(parse_required_expr(&cast.expr, ctx, "expr")?);
+            let data_type = cast.arrow_type.as_ref().required("arrow_type")?;
+            Ok(Expr::Cast { expr, data_type })
+        }
+        ExprType::TryCast(cast) => {
+            let expr = Box::new(parse_required_expr(&cast.expr, ctx, "expr")?);
+            let data_type = cast.arrow_type.as_ref().required("arrow_type")?;
+            Ok(Expr::TryCast { expr, data_type })
+        }
+        ExprType::Sort(sort) => Ok(Expr::Sort {
+            expr: Box::new(parse_required_expr(&sort.expr, ctx, "expr")?),
+            asc: sort.asc,
+            nulls_first: sort.nulls_first,
+        }),
+        ExprType::Negative(negative) => Ok(Expr::Negative(Box::new(
+            parse_required_expr(&negative.expr, ctx, "expr")?,
+        ))),
+        ExprType::InList(in_list) => Ok(Expr::InList {
+            expr: Box::new(parse_required_expr(&in_list.expr, ctx, "expr")?),
+            list: in_list
+                .list
+                .iter()
+                .map(|expr| parse_expr(expr, ctx))
+                .collect::<Result<Vec<_>, _>>()?,
+            negated: in_list.negated,
+        }),
+        ExprType::Wildcard(_) => Ok(Expr::Wildcard),
+        ExprType::ScalarFunction(expr) => {
+            let scalar_function = protobuf::ScalarFunction::from_i32(expr.fun)
+                .ok_or_else(|| Error::unknown("ScalarFunction", expr.fun))?;
+            let args = &expr.args;
+
+            match scalar_function {
+                ScalarFunction::Asin => Ok(asin(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Acos => Ok(acos(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Array => Ok(array(
+                    args.to_owned()
+                        .iter()
+                        .map(|expr| parse_expr(expr, ctx))
+                        .collect::<Result<Vec<_>, _>>()?,
+                )),
+                ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Tan => Ok(tan(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Log2 => Ok(log2(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Ln => Ok(ln(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Log10 => Ok(log10(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Floor => Ok(floor(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Round => Ok(round(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Trunc => Ok(trunc(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Abs => Ok(abs(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Signum => Ok(signum(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::OctetLength => {
+                    Ok(octet_length(parse_expr(&args[0], ctx)?))
+                }
+                ScalarFunction::Lower => Ok(lower(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Upper => Ok(upper(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Trim => Ok(trim(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Ltrim => Ok(ltrim(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Rtrim => Ok(rtrim(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::DatePart => Ok(date_part(
+                    parse_expr(&args[0], ctx)?,
+                    parse_expr(&args[1], ctx)?,
+                )),
+                ScalarFunction::DateTrunc => Ok(date_trunc(
+                    parse_expr(&args[0], ctx)?,
+                    parse_expr(&args[1], ctx)?,
+                )),
+                ScalarFunction::Sha224 => Ok(sha224(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Sha256 => Ok(sha256(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Sha384 => Ok(sha384(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Sha512 => Ok(sha512(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::NullIf => Ok(nullif(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Digest => Ok(digest(
+                    parse_expr(&args[0], ctx)?,
+                    parse_expr(&args[1], ctx)?,
+                )),
+                ScalarFunction::Ascii => Ok(ascii(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::BitLength => Ok(bit_length(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::CharacterLength => {
+                    Ok(character_length(parse_expr(&args[0], ctx)?))
+                }
+                ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::InitCap => Ok(ascii(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Left => {
+                    Ok(left(parse_expr(&args[0], ctx)?, parse_expr(&args[1], ctx)?))
+                }
+                ScalarFunction::Random => Ok(random()),
+                ScalarFunction::Repeat => Ok(repeat(
+                    parse_expr(&args[0], ctx)?,
+                    parse_expr(&args[1], ctx)?,
+                )),
+                ScalarFunction::Replace => Ok(replace(
+                    parse_expr(&args[0], ctx)?,
+                    parse_expr(&args[1], ctx)?,
+                    parse_expr(&args[2], ctx)?,
+                )),
+                ScalarFunction::Reverse => Ok(reverse(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::Right => Ok(right(
+                    parse_expr(&args[0], ctx)?,
+                    parse_expr(&args[1], ctx)?,
+                )),
+                ScalarFunction::Concat => Ok(concat_expr(
+                    args.to_owned()
+                        .iter()
+                        .map(|expr| parse_expr(expr, ctx))
+                        .collect::<Result<Vec<_>, _>>()?,
+                )),
+                ScalarFunction::ConcatWithSeparator => Ok(concat_ws_expr(
+                    args.to_owned()
+                        .iter()
+                        .map(|expr| parse_expr(expr, ctx))
+                        .collect::<Result<Vec<_>, _>>()?,
+                )),
+                ScalarFunction::Lpad => Ok(lpad(
+                    args.to_owned()
+                        .iter()
+                        .map(|expr| parse_expr(expr, ctx))
+                        .collect::<Result<Vec<_>, _>>()?,
+                )),
+                ScalarFunction::Rpad => Ok(rpad(
+                    args.to_owned()
+                        .iter()
+                        .map(|expr| parse_expr(expr, ctx))
+                        .collect::<Result<Vec<_>, _>>()?,
+                )),
+                ScalarFunction::RegexpReplace => Ok(regexp_replace(
+                    args.to_owned()
+                        .iter()
+                        .map(|expr| parse_expr(expr, ctx))
+                        .collect::<Result<Vec<_>, _>>()?,
+                )),
+                ScalarFunction::RegexpMatch => Ok(regexp_match(
+                    args.to_owned()
+                        .iter()
+                        .map(|expr| parse_expr(expr, ctx))
+                        .collect::<Result<Vec<_>, _>>()?,
+                )),
+                ScalarFunction::Btrim => Ok(btrim(
+                    args.to_owned()
+                        .iter()
+                        .map(|expr| parse_expr(expr, ctx))
+                        .collect::<Result<Vec<_>, _>>()?,
+                )),
+                ScalarFunction::SplitPart => Ok(split_part(
+                    parse_expr(&args[0], ctx)?,
+                    parse_expr(&args[1], ctx)?,
+                    parse_expr(&args[2], ctx)?,
+                )),
+                ScalarFunction::StartsWith => Ok(starts_with(
+                    parse_expr(&args[0], ctx)?,
+                    parse_expr(&args[1], ctx)?,
+                )),
+                ScalarFunction::Strpos => Ok(strpos(
+                    parse_expr(&args[0], ctx)?,
+                    parse_expr(&args[1], ctx)?,
+                )),
+                ScalarFunction::Substr => Ok(substr(
+                    parse_expr(&args[0], ctx)?,
+                    parse_expr(&args[1], ctx)?,
+                )),
+                ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], ctx)?)),
+                ScalarFunction::ToTimestampMillis => {
+                    Ok(to_timestamp_millis(parse_expr(&args[0], ctx)?))
+                }
+                ScalarFunction::ToTimestampMicros => {
+                    Ok(to_timestamp_micros(parse_expr(&args[0], ctx)?))
+                }
+                ScalarFunction::ToTimestampSeconds => {
+                    Ok(to_timestamp_seconds(parse_expr(&args[0], ctx)?))
                 }
+                ScalarFunction::Now => Ok(now_expr(
+                    args.to_owned()
+                        .iter()
+                        .map(|expr| parse_expr(expr, ctx))
+                        .collect::<Result<Vec<_>, _>>()?,
+                )),
+                ScalarFunction::Translate => Ok(translate(
+                    parse_expr(&args[0], ctx)?,
+                    parse_expr(&args[1], ctx)?,
+                    parse_expr(&args[2], ctx)?,
+                )),
+                _ => Err(proto_error(
+                    "Protobuf deserialization error: Unsupported scalar function",
+                )),
             }
         }
+        ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args }) => {
+            let scalar_fn = ctx
+                .state
+                .lock()
+                .get_function_meta(fun_name.as_str()).ok_or_else(|| Error::General(format!("invalid aggregate function message, function {} is not registered in the ExecutionContext", fun_name)))?;
+
+            Ok(Expr::ScalarUDF {
+                fun: scalar_fn,
+                args: args
+                    .iter()
+                    .map(|expr| parse_expr(expr, ctx))
+                    .collect::<Result<Vec<_>, Error>>()?,
+            })
+        }
+        ExprType::AggregateUdfExpr(protobuf::AggregateUdfExprNode { fun_name, args }) => {
+            let agg_fn = ctx
+                .state
+                .lock()
+                .get_aggregate_meta(fun_name.as_str()).ok_or_else(|| Error::General(format!("invalid aggregate function message, function {} is not registered in the ExecutionContext", fun_name)))?;
+
+            Ok(Expr::AggregateUDF {
+                fun: agg_fn,
+                args: args
+                    .iter()
+                    .map(|expr| parse_expr(expr, ctx))
+                    .collect::<Result<Vec<_>, Error>>()?,
+            })
+        }
     }
 }
 
@@ -1425,13 +1461,36 @@ fn from_proto_binary_op(op: &str) -> Result<Operator, Error> {
 
 fn parse_optional_expr(
     p: &Option<Box<protobuf::LogicalExprNode>>,
+    ctx: &SessionContext,
 ) -> Result<Option<Expr>, Error> {
     match p {
-        Some(expr) => expr.as_ref().try_into().map(Some),
+        Some(expr) => parse_expr(expr.as_ref(), ctx).map(Some),
         None => Ok(None),
     }
 }
 
+fn parse_required_expr(
+    p: &Option<Box<protobuf::LogicalExprNode>>,
+    ctx: &SessionContext,
+    field: impl Into<String>,
+) -> Result<Expr, Error> {
+    match p {
+        Some(expr) => parse_expr(expr.as_ref(), ctx),
+        None => Err(Error::required(field)),
+    }
+}
+
+fn parse_required_expr_inner(
+    p: &Option<protobuf::LogicalExprNode>,
+    ctx: &SessionContext,
+    field: impl Into<String>,
+) -> Result<Expr, Error> {
+    match p {
+        Some(expr) => parse_expr(expr, ctx),
+        None => Err(Error::required(field)),
+    }
+}
+
 fn proto_error<S: Into<String>>(message: S) -> Error {
     Error::General(message.into())
 }
diff --git a/datafusion-proto/src/lib.rs b/datafusion-proto/src/lib.rs
index b880f8e..0688215 100644
--- a/datafusion-proto/src/lib.rs
+++ b/datafusion-proto/src/lib.rs
@@ -26,6 +26,12 @@ pub mod to_proto;
 
 #[cfg(test)]
 mod roundtrip_tests {
+    use super::from_proto::parse_expr;
+    use super::protobuf;
+    use datafusion::arrow::array::ArrayRef;
+    use datafusion::logical_plan::create_udaf;
+    use datafusion::physical_plan::functions::{make_scalar_function, Volatility};
+    use datafusion::physical_plan::Accumulator;
     use datafusion::{
         arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit, UnionMode},
         logical_plan::{col, Expr},
@@ -33,14 +39,15 @@ mod roundtrip_tests {
         prelude::*,
         scalar::ScalarValue,
     };
+    use std::sync::Arc;
 
-    // Given a DataFusion type, convert it to protobuf and back, using debug formatting to test
+    // Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test
     // equality.
-    macro_rules! roundtrip_test {
-        ($initial_struct:ident, $proto_type:ty, $struct_type:ty) => {
-            let proto: $proto_type = (&$initial_struct).try_into().unwrap();
+    macro_rules! roundtrip_expr_test {
+        ($initial_struct:ident, $ctx:ident) => {
+            let proto: protobuf::LogicalExprNode = (&$initial_struct).try_into().unwrap();
 
-            let round_trip: $struct_type = (&proto).try_into().unwrap();
+            let round_trip: Expr = parse_expr(&proto, &$ctx).unwrap();
 
             assert_eq!(
                 format!("{:?}", $initial_struct),
@@ -575,95 +582,102 @@ mod roundtrip_tests {
 
     #[test]
     fn roundtrip_not() {
-        let test_expr = Expr::Not(Box::new(Expr::Literal((1.0).into())));
+        let test_expr = Expr::Not(Box::new(lit(1.0_f32)));
 
-        roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+        let ctx = SessionContext::new();
+        roundtrip_expr_test!(test_expr, ctx);
     }
 
     #[test]
     fn roundtrip_is_null() {
         let test_expr = Expr::IsNull(Box::new(col("id")));
 
-        roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+        let ctx = SessionContext::new();
+        roundtrip_expr_test!(test_expr, ctx);
     }
 
     #[test]
     fn roundtrip_is_not_null() {
         let test_expr = Expr::IsNotNull(Box::new(col("id")));
 
-        roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+        let ctx = SessionContext::new();
+        roundtrip_expr_test!(test_expr, ctx);
     }
 
     #[test]
     fn roundtrip_between() {
         let test_expr = Expr::Between {
-            expr: Box::new(Expr::Literal((1.0).into())),
+            expr: Box::new(lit(1.0_f32)),
             negated: true,
-            low: Box::new(Expr::Literal((2.0).into())),
-            high: Box::new(Expr::Literal((3.0).into())),
+            low: Box::new(lit(2.0_f32)),
+            high: Box::new(lit(3.0_f32)),
         };
 
-        roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+        let ctx = SessionContext::new();
+        roundtrip_expr_test!(test_expr, ctx);
     }
 
     #[test]
     fn roundtrip_case() {
         let test_expr = Expr::Case {
-            expr: Some(Box::new(Expr::Literal((1.0).into()))),
-            when_then_expr: vec![(
-                Box::new(Expr::Literal((2.0).into())),
-                Box::new(Expr::Literal((3.0).into())),
-            )],
-            else_expr: Some(Box::new(Expr::Literal((4.0).into()))),
+            expr: Some(Box::new(lit(1.0_f32))),
+            when_then_expr: vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))],
+            else_expr: Some(Box::new(lit(4.0_f32))),
         };
 
-        roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+        let ctx = SessionContext::new();
+        roundtrip_expr_test!(test_expr, ctx);
     }
 
     #[test]
     fn roundtrip_cast() {
         let test_expr = Expr::Cast {
-            expr: Box::new(Expr::Literal((1.0).into())),
+            expr: Box::new(lit(1.0_f32)),
             data_type: DataType::Boolean,
         };
 
-        roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+        let ctx = SessionContext::new();
+        roundtrip_expr_test!(test_expr, ctx);
     }
 
     #[test]
     fn roundtrip_sort_expr() {
         let test_expr = Expr::Sort {
-            expr: Box::new(Expr::Literal((1.0).into())),
+            expr: Box::new(lit(1.0_f32)),
             asc: true,
             nulls_first: true,
         };
 
-        roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+        let ctx = SessionContext::new();
+        roundtrip_expr_test!(test_expr, ctx);
     }
 
     #[test]
     fn roundtrip_negative() {
-        let test_expr = Expr::Negative(Box::new(Expr::Literal((1.0).into())));
+        let test_expr = Expr::Negative(Box::new(lit(1.0_f32)));
 
-        roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+        let ctx = SessionContext::new();
+        roundtrip_expr_test!(test_expr, ctx);
     }
 
     #[test]
     fn roundtrip_inlist() {
         let test_expr = Expr::InList {
-            expr: Box::new(Expr::Literal((1.0).into())),
-            list: vec![Expr::Literal((2.0).into())],
+            expr: Box::new(lit(1.0_f32)),
+            list: vec![lit(2.0_f32)],
             negated: true,
         };
 
-        roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+        let ctx = SessionContext::new();
+        roundtrip_expr_test!(test_expr, ctx);
     }
 
     #[test]
     fn roundtrip_wildcard() {
         let test_expr = Expr::Wildcard;
 
-        roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+        let ctx = SessionContext::new();
+        roundtrip_expr_test!(test_expr, ctx);
     }
 
     #[test]
@@ -672,17 +686,98 @@ mod roundtrip_tests {
             fun: Sqrt,
             args: vec![col("col")],
         };
-        roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+        let ctx = SessionContext::new();
+        roundtrip_expr_test!(test_expr, ctx);
     }
 
     #[test]
     fn roundtrip_approx_percentile_cont() {
         let test_expr = Expr::AggregateFunction {
             fun: aggregates::AggregateFunction::ApproxPercentileCont,
-            args: vec![col("bananas"), lit(0.42)],
+            args: vec![col("bananas"), lit(0.42_f32)],
             distinct: false,
         };
 
-        roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+        let ctx = SessionContext::new();
+        roundtrip_expr_test!(test_expr, ctx);
+    }
+
+    #[test]
+    fn roundtrip_aggregate_udf() {
+        #[derive(Debug)]
+        struct Dummy {}
+
+        impl Accumulator for Dummy {
+            fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
+                Ok(vec![])
+            }
+
+            fn update_batch(
+                &mut self,
+                _values: &[ArrayRef],
+            ) -> datafusion::error::Result<()> {
+                Ok(())
+            }
+
+            fn merge_batch(
+                &mut self,
+                _states: &[ArrayRef],
+            ) -> datafusion::error::Result<()> {
+                Ok(())
+            }
+
+            fn evaluate(&self) -> datafusion::error::Result<ScalarValue> {
+                Ok(ScalarValue::Float64(None))
+            }
+        }
+
+        let dummy_agg = create_udaf(
+            // the name; used to represent it in plan descriptions and in the registry, to use in SQL.
+            "dummy_agg",
+            // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type.
+            DataType::Float64,
+            // the return type; DataFusion expects this to match the type returned by `evaluate`.
+            Arc::new(DataType::Float64),
+            Volatility::Immutable,
+            // This is the accumulator factory; DataFusion uses it to create new accumulators.
+            Arc::new(|| Ok(Box::new(Dummy {}))),
+            // This is the description of the state. `state()` must match the types here.
+            Arc::new(vec![DataType::Float64, DataType::UInt32]),
+        );
+
+        let test_expr = Expr::AggregateUDF {
+            fun: Arc::new(dummy_agg.clone()),
+            args: vec![lit(1.0_f64)],
+        };
+
+        let mut ctx = SessionContext::new();
+        ctx.register_udaf(dummy_agg);
+
+        roundtrip_expr_test!(test_expr, ctx);
+    }
+
+    #[test]
+    fn roundtrip_scalar_udf() {
+        let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef);
+
+        let scalar_fn = make_scalar_function(fn_impl);
+
+        let udf = create_udf(
+            "dummy",
+            vec![DataType::Utf8],
+            Arc::new(DataType::Utf8),
+            Volatility::Immutable,
+            scalar_fn,
+        );
+
+        let test_expr = Expr::ScalarUDF {
+            fun: Arc::new(udf.clone()),
+            args: vec![lit("")],
+        };
+
+        let mut ctx = SessionContext::new();
+        ctx.register_udf(udf);
+
+        roundtrip_expr_test!(test_expr, ctx);
     }
 }
diff --git a/datafusion-proto/src/to_proto.rs b/datafusion-proto/src/to_proto.rs
index 29c533a..753e59c 100644
--- a/datafusion-proto/src/to_proto.rs
+++ b/datafusion-proto/src/to_proto.rs
@@ -20,6 +20,7 @@
 //! processes.
 
 use crate::protobuf;
+
 use datafusion::{
     arrow::datatypes::{
         DataType, Field, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionMode,
@@ -523,8 +524,27 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                     )),
                 }
             }
-            Expr::ScalarUDF { .. } => unimplemented!(),
-            Expr::AggregateUDF { .. } => unimplemented!(),
+            Expr::ScalarUDF { fun, args } => Self {
+                expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode {
+                    fun_name: fun.name.clone(),
+                    args: args
+                        .iter()
+                        .map(|expr| expr.try_into())
+                        .collect::<Result<Vec<_>, Error>>()?,
+                })),
+            },
+            Expr::AggregateUDF { fun, args } => Self {
+                expr_type: Some(ExprType::AggregateUdfExpr(
+                    protobuf::AggregateUdfExprNode {
+                        fun_name: fun.name.clone(),
+                        args: args.iter().map(|expr| expr.try_into()).collect::<Result<
+                            Vec<_>,
+                            Error,
+                        >>(
+                        )?,
+                    },
+                )),
+            },
             Expr::Not(expr) => {
                 let expr = Box::new(protobuf::Not {
                     expr: Some(Box::new(expr.as_ref().try_into()?)),

Reply via email to