This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new c1f6269 Use SessionContext to parse Expr protobuf (#2024)
c1f6269 is described below
commit c1f6269a3138a22b50679d73609645ed1511e5de
Author: Dan Harris <[email protected]>
AuthorDate: Sun Mar 20 06:22:04 2022 -0400
Use SessionContext to parse Expr protobuf (#2024)
* Use ExecutionContext to parse Expr protobuf
* Fixes after merge
* linting
---
ballista/rust/core/src/serde/logical_plan/mod.rs | 29 +-
ballista/rust/core/src/serde/mod.rs | 9 +-
ballista/rust/core/src/serde/physical_plan/mod.rs | 10 +-
datafusion-proto/proto/datafusion.proto | 16 +
datafusion-proto/src/from_proto.rs | 691 ++++++++++++----------
datafusion-proto/src/lib.rs | 161 +++--
datafusion-proto/src/to_proto.rs | 24 +-
7 files changed, 573 insertions(+), 367 deletions(-)
diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs
b/ballista/rust/core/src/serde/logical_plan/mod.rs
index bfc254a..fa155ce 100644
--- a/ballista/rust/core/src/serde/logical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/logical_plan/mod.rs
@@ -38,6 +38,7 @@ use datafusion::logical_plan::{
};
use datafusion::prelude::SessionContext;
+use datafusion_proto::from_proto::parse_expr;
use prost::bytes::BufMut;
use prost::Message;
use protobuf::listing_table_scan_node::FileFormatType;
@@ -95,10 +96,11 @@ impl AsLogicalPlan for LogicalPlanNode {
.values_list
.chunks_exact(n_cols)
.map(|r| {
- r.iter().map(|v| v.try_into()).collect::<Result<
+ r.iter().map(|expr| parse_expr(expr,
ctx)).collect::<Result<
Vec<_>,
datafusion_proto::from_proto::Error,
- >>()
+ >>(
+ )
})
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into())
@@ -113,7 +115,7 @@ impl AsLogicalPlan for LogicalPlanNode {
let x: Vec<Expr> = projection
.expr
.iter()
- .map(|expr| expr.try_into())
+ .map(|expr| parse_expr(expr, ctx))
.collect::<Result<Vec<_>, _>>()?;
LogicalPlanBuilder::from(input)
.project_with_alias(
@@ -133,10 +135,12 @@ impl AsLogicalPlan for LogicalPlanNode {
let expr: Expr = selection
.expr
.as_ref()
+ .map(|expr| parse_expr(expr, ctx))
+ .transpose()?
.ok_or_else(|| {
BallistaError::General("expression
required".to_string())
- })?
- .try_into()?;
+ })?;
+ // .try_into()?;
LogicalPlanBuilder::from(input)
.filter(expr)?
.build()
@@ -148,7 +152,7 @@ impl AsLogicalPlan for LogicalPlanNode {
let window_expr = window
.window_expr
.iter()
- .map(|expr| expr.try_into())
+ .map(|expr| parse_expr(expr, ctx))
.collect::<Result<Vec<Expr>, _>>()?;
LogicalPlanBuilder::from(input)
.window(window_expr)?
@@ -161,12 +165,12 @@ impl AsLogicalPlan for LogicalPlanNode {
let group_expr = aggregate
.group_expr
.iter()
- .map(|expr| expr.try_into())
+ .map(|expr| parse_expr(expr, ctx))
.collect::<Result<Vec<Expr>, _>>()?;
let aggr_expr = aggregate
.aggr_expr
.iter()
- .map(|expr| expr.try_into())
+ .map(|expr| parse_expr(expr, ctx))
.collect::<Result<Vec<Expr>, _>>()?;
LogicalPlanBuilder::from(input)
.aggregate(group_expr, aggr_expr)?
@@ -189,7 +193,7 @@ impl AsLogicalPlan for LogicalPlanNode {
let filters = scan
.filters
.iter()
- .map(|e| e.try_into())
+ .map(|expr| parse_expr(expr, ctx))
.collect::<Result<Vec<_>, _>>()?;
let file_format: Arc<dyn FileFormat> =
@@ -260,7 +264,7 @@ impl AsLogicalPlan for LogicalPlanNode {
let sort_expr: Vec<Expr> = sort
.expr
.iter()
- .map(|expr| expr.try_into())
+ .map(|expr| parse_expr(expr, ctx))
.collect::<Result<Vec<Expr>, _>>()?;
LogicalPlanBuilder::from(input)
.sort(sort_expr)?
@@ -285,7 +289,7 @@ impl AsLogicalPlan for LogicalPlanNode {
}) => Partitioning::Hash(
pb_hash_expr
.iter()
- .map(|pb_expr| pb_expr.try_into())
+ .map(|expr| parse_expr(expr, ctx))
.collect::<Result<Vec<_>, _>>()?,
partition_count as usize,
),
@@ -416,7 +420,8 @@ impl AsLogicalPlan for LogicalPlanNode {
.map(|i| i.try_into_logical_plan(ctx, extension_codec))
.collect::<Result<_, BallistaError>>()?;
- let extension_node = extension_codec.try_decode(node,
&input_plans)?;
+ let extension_node =
+ extension_codec.try_decode(node, &input_plans, ctx)?;
Ok(LogicalPlan::Extension(extension_node))
}
}
diff --git a/ballista/rust/core/src/serde/mod.rs
b/ballista/rust/core/src/serde/mod.rs
index cc1bbb4..19ff22a 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -84,6 +84,7 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync {
&self,
buf: &[u8],
inputs: &[LogicalPlan],
+ ctx: &SessionContext,
) -> Result<Extension, BallistaError>;
fn try_encode(
@@ -101,6 +102,7 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec
{
&self,
_buf: &[u8],
_inputs: &[LogicalPlan],
+ _ctx: &SessionContext,
) -> Result<Extension, BallistaError> {
Err(BallistaError::NotImplemented(
"LogicalExtensionCodec is not provided".to_string(),
@@ -147,6 +149,7 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync {
&self,
buf: &[u8],
inputs: &[Arc<dyn ExecutionPlan>],
+ ctx: &SessionContext,
) -> Result<Arc<dyn ExecutionPlan>, BallistaError>;
fn try_encode(
@@ -164,6 +167,7 @@ impl PhysicalExtensionCodec for
DefaultPhysicalExtensionCodec {
&self,
_buf: &[u8],
_inputs: &[Arc<dyn ExecutionPlan>],
+ _ctx: &SessionContext,
) -> Result<Arc<dyn ExecutionPlan>, BallistaError> {
Err(BallistaError::NotImplemented(
"PhysicalExtensionCodec is not provided".to_string(),
@@ -360,6 +364,7 @@ mod tests {
use prost::Message;
use std::any::Any;
+ use datafusion_proto::from_proto::parse_expr;
use std::convert::TryInto;
use std::fmt;
use std::fmt::{Debug, Formatter};
@@ -596,6 +601,7 @@ mod tests {
&self,
buf: &[u8],
inputs: &[LogicalPlan],
+ ctx: &SessionContext,
) -> Result<Extension, BallistaError> {
if let Some((input, _)) = inputs.split_first() {
let proto = TopKPlanProto::decode(buf).map_err(|e| {
@@ -609,7 +615,7 @@ mod tests {
let node = TopKPlanNode::new(
proto.k as usize,
input.clone(),
- expr.try_into()?,
+ parse_expr(expr, ctx)?,
);
Ok(Extension {
@@ -653,6 +659,7 @@ mod tests {
&self,
buf: &[u8],
inputs: &[Arc<dyn ExecutionPlan>],
+ _ctx: &SessionContext,
) -> Result<Arc<dyn ExecutionPlan>, BallistaError> {
if let Some((input, _)) = inputs.split_first() {
let proto = TopKExecProto::decode(buf).map_err(|e| {
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs
b/ballista/rust/core/src/serde/physical_plan/mod.rs
index 9c15122..936de0e 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -57,6 +57,7 @@ use datafusion::physical_plan::{
AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr,
};
use datafusion::prelude::SessionContext;
+use datafusion_proto::from_proto::parse_expr;
use prost::bytes::BufMut;
use prost::Message;
use std::convert::TryInto;
@@ -133,7 +134,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
let predicate = scan
.pruning_predicate
.as_ref()
- .map(|expr| expr.try_into())
+ .map(|expr| parse_expr(expr, ctx))
.transpose()?;
Ok(Arc::new(ParquetExec::new(
decode_scan_config(scan.base_conf.as_ref().unwrap(), ctx)?,
@@ -483,8 +484,11 @@ impl AsExecutionPlan for PhysicalPlanNode {
.map(|i| i.try_into_physical_plan(ctx, extension_codec))
.collect::<Result<_, BallistaError>>()?;
- let extension_node =
- extension_codec.try_decode(extension.node.as_slice(),
&inputs)?;
+ let extension_node = extension_codec.try_decode(
+ extension.node.as_slice(),
+ &inputs,
+ ctx,
+ )?;
Ok(extension_node)
}
diff --git a/datafusion-proto/proto/datafusion.proto
b/datafusion-proto/proto/datafusion.proto
index a0c5c2f..4ee3298 100644
--- a/datafusion-proto/proto/datafusion.proto
+++ b/datafusion-proto/proto/datafusion.proto
@@ -77,6 +77,12 @@ message LogicalExprNode {
// window expressions
WindowExprNode window_expr = 18;
+
+ // AggregateUDF expressions
+ AggregateUDFExprNode aggregate_udf_expr = 19;
+
+ // Scalar UDF expressions
+ ScalarUDFExprNode scalar_udf_expr = 20;
}
}
@@ -208,6 +214,16 @@ message AggregateExprNode {
repeated LogicalExprNode expr = 2;
}
+message AggregateUDFExprNode {
+ string fun_name = 1;
+ repeated LogicalExprNode args = 2;
+}
+
+message ScalarUDFExprNode {
+ string fun_name = 1;
+ repeated LogicalExprNode args = 2;
+}
+
enum BuiltInWindowFunction {
ROW_NUMBER = 0;
RANK = 1;
diff --git a/datafusion-proto/src/from_proto.rs
b/datafusion-proto/src/from_proto.rs
index e7a5def..8789129 100644
--- a/datafusion-proto/src/from_proto.rs
+++ b/datafusion-proto/src/from_proto.rs
@@ -16,6 +16,8 @@
// under the License.
use crate::protobuf;
+use datafusion::prelude::{bit_length, SessionContext};
+use datafusion::sql::planner::ContextProvider;
use datafusion::{
arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit,
UnionMode},
error::DataFusionError,
@@ -832,331 +834,365 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
}
}
-impl TryFrom<&protobuf::LogicalExprNode> for Expr {
- type Error = Error;
-
- fn try_from(expr: &protobuf::LogicalExprNode) -> Result<Self, Self::Error>
{
- use datafusion::physical_plan::window_functions;
- use protobuf::{logical_expr_node::ExprType, window_expr_node,
ScalarFunction};
-
- let expr_type = expr
- .expr_type
- .as_ref()
- .ok_or_else(|| Error::required("expr_type"))?;
-
- match expr_type {
- ExprType::BinaryExpr(binary_expr) => Ok(Self::BinaryExpr {
- left: Box::new(binary_expr.l.as_deref().required("l")?),
- op: from_proto_binary_op(&binary_expr.op)?,
- right: Box::new(binary_expr.r.as_deref().required("r")?),
- }),
- ExprType::Column(column) => Ok(Self::Column(column.into())),
- ExprType::Literal(literal) => {
- let scalar_value: ScalarValue = literal.try_into()?;
- Ok(Self::Literal(scalar_value))
- }
- ExprType::WindowExpr(expr) => {
- let window_function = expr
- .window_function
- .as_ref()
- .ok_or_else(|| Error::required("window_function"))?;
- let partition_by = expr
- .partition_by
- .iter()
- .map(|e| e.try_into())
- .collect::<Result<Vec<_>, _>>()?;
- let order_by = expr
- .order_by
- .iter()
- .map(|e| e.try_into())
- .collect::<Result<Vec<_>, _>>()?;
- let window_frame = expr
- .window_frame
- .as_ref()
- .map::<Result<WindowFrame, _>, _>(|e| match e {
- window_expr_node::WindowFrame::Frame(frame) => {
- let window_frame: WindowFrame =
frame.clone().try_into()?;
- if WindowFrameUnits::Range == window_frame.units
- && order_by.len() != 1
- {
- Err(proto_error("With window frame of type
RANGE, the order by expression must be of length 1"))
- } else {
- Ok(window_frame)
- }
+pub fn parse_expr(
+ proto: &protobuf::LogicalExprNode,
+ ctx: &SessionContext,
+) -> Result<Expr, Error> {
+ use datafusion::physical_plan::window_functions;
+ use protobuf::{logical_expr_node::ExprType, window_expr_node,
ScalarFunction};
+
+ let expr_type = proto
+ .expr_type
+ .as_ref()
+ .ok_or_else(|| Error::required("expr_type"))?;
+
+ match expr_type {
+ ExprType::BinaryExpr(binary_expr) => Ok(Expr::BinaryExpr {
+ left: Box::new(parse_required_expr(&binary_expr.l, ctx, "l")?),
+ op: from_proto_binary_op(&binary_expr.op)?,
+ right: Box::new(parse_required_expr(&binary_expr.r, ctx, "r")?),
+ }),
+ ExprType::Column(column) => Ok(Expr::Column(column.into())),
+ ExprType::Literal(literal) => {
+ let scalar_value: ScalarValue = literal.try_into()?;
+ Ok(Expr::Literal(scalar_value))
+ }
+ ExprType::WindowExpr(expr) => {
+ let window_function = expr
+ .window_function
+ .as_ref()
+ .ok_or_else(|| Error::required("window_function"))?;
+ let partition_by = expr
+ .partition_by
+ .iter()
+ .map(|e| parse_expr(e, ctx))
+ .collect::<Result<Vec<_>, _>>()?;
+ let order_by = expr
+ .order_by
+ .iter()
+ .map(|e| parse_expr(e, ctx))
+ .collect::<Result<Vec<_>, _>>()?;
+ let window_frame = expr
+ .window_frame
+ .as_ref()
+ .map::<Result<WindowFrame, _>, _>(|e| match e {
+ window_expr_node::WindowFrame::Frame(frame) => {
+ let window_frame: WindowFrame =
frame.clone().try_into()?;
+ if WindowFrameUnits::Range == window_frame.units
+ && order_by.len() != 1
+ {
+ Err(proto_error("With window frame of type RANGE,
the order by expression must be of length 1"))
+ } else {
+ Ok(window_frame)
}
- })
- .transpose()?;
-
- match window_function {
- window_expr_node::WindowFunction::AggrFunction(i) => {
- let aggr_function =
- protobuf::AggregateFunction::try_from(i)?.into();
-
- Ok(Self::WindowFunction {
- fun:
window_functions::WindowFunction::AggregateFunction(
- aggr_function,
- ),
- args: vec![expr.expr.as_deref().required("expr")?],
- partition_by,
- order_by,
- window_frame,
- })
- }
- window_expr_node::WindowFunction::BuiltInFunction(i) => {
- let built_in_function =
- protobuf::BuiltInWindowFunction::from_i32(*i)
- .ok_or_else(|| {
- Error::unknown("BuiltInWindowFunction", *i)
- })?
- .into();
-
- Ok(Self::WindowFunction {
- fun:
window_functions::WindowFunction::BuiltInWindowFunction(
- built_in_function,
- ),
- args: vec![expr.expr.as_deref().required("expr")?],
- partition_by,
- order_by,
- window_frame,
- })
}
- }
- }
- ExprType::AggregateExpr(expr) => {
- let fun =
-
protobuf::AggregateFunction::try_from(&expr.aggr_function)?.into();
-
- Ok(Self::AggregateFunction {
- fun,
- args: expr
- .expr
- .iter()
- .map(|e| e.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- distinct: false, //TODO
})
- }
- ExprType::Alias(alias) => Ok(Self::Alias(
- Box::new(alias.expr.as_deref().required("expr")?),
- alias.alias.clone(),
- )),
- ExprType::IsNullExpr(is_null) => Ok(Self::IsNull(Box::new(
- is_null.expr.as_deref().required("expr")?,
- ))),
- ExprType::IsNotNullExpr(is_not_null) =>
Ok(Self::IsNotNull(Box::new(
- is_not_null.expr.as_deref().required("expr")?,
- ))),
- ExprType::NotExpr(not) => {
- Ok(Self::Not(Box::new(not.expr.as_deref().required("expr")?)))
- }
- ExprType::Between(between) => Ok(Self::Between {
- expr: Box::new(between.expr.as_deref().required("expr")?),
- negated: between.negated,
- low: Box::new(between.low.as_deref().required("low")?),
- high: Box::new(between.high.as_deref().required("high")?),
- }),
- ExprType::Case(case) => {
- let when_then_expr = case
- .when_then_expr
- .iter()
- .map(|e| {
- let when_expr =
e.when_expr.as_ref().required("when_expr")?;
- let then_expr =
e.then_expr.as_ref().required("then_expr")?;
- Ok((Box::new(when_expr), Box::new(then_expr)))
+ .transpose()?;
+
+ match window_function {
+ window_expr_node::WindowFunction::AggrFunction(i) => {
+ let aggr_function =
protobuf::AggregateFunction::try_from(i)?.into();
+
+ Ok(Expr::WindowFunction {
+ fun:
window_functions::WindowFunction::AggregateFunction(
+ aggr_function,
+ ),
+ args: vec![parse_required_expr(&expr.expr, ctx,
"expr")?],
+ partition_by,
+ order_by,
+ window_frame,
})
- .collect::<Result<Vec<(Box<Expr>, Box<Expr>)>, Error>>()?;
- Ok(Self::Case {
- expr: parse_optional_expr(&case.expr)?.map(Box::new),
- when_then_expr,
- else_expr:
parse_optional_expr(&case.else_expr)?.map(Box::new),
- })
- }
- ExprType::Cast(cast) => {
- let expr = Box::new(cast.expr.as_deref().required("expr")?);
- let data_type =
cast.arrow_type.as_ref().required("arrow_type")?;
- Ok(Self::Cast { expr, data_type })
- }
- ExprType::TryCast(cast) => {
- let expr = Box::new(cast.expr.as_deref().required("expr")?);
- let data_type =
cast.arrow_type.as_ref().required("arrow_type")?;
- Ok(Self::TryCast { expr, data_type })
+ }
+ window_expr_node::WindowFunction::BuiltInFunction(i) => {
+ let built_in_function =
protobuf::BuiltInWindowFunction::from_i32(*i)
+ .ok_or_else(|| Error::unknown("BuiltInWindowFunction",
*i))?
+ .into();
+
+ Ok(Expr::WindowFunction {
+ fun:
window_functions::WindowFunction::BuiltInWindowFunction(
+ built_in_function,
+ ),
+ args: vec![parse_required_expr(&expr.expr, ctx,
"expr")?],
+ partition_by,
+ order_by,
+ window_frame,
+ })
+ }
}
- ExprType::Sort(sort) => Ok(Self::Sort {
- expr: Box::new(sort.expr.as_deref().required("expr")?),
- asc: sort.asc,
- nulls_first: sort.nulls_first,
- }),
- ExprType::Negative(negative) => Ok(Self::Negative(Box::new(
- negative.expr.as_deref().required("expr")?,
- ))),
- ExprType::InList(in_list) => Ok(Self::InList {
- expr: Box::new(in_list.expr.as_deref().required("expr")?),
- list: in_list
- .list
+ }
+ ExprType::AggregateExpr(expr) => {
+ let fun =
protobuf::AggregateFunction::try_from(&expr.aggr_function)?.into();
+
+ Ok(Expr::AggregateFunction {
+ fun,
+ args: expr
+ .expr
.iter()
- .map(|expr| expr.try_into())
+ .map(|e| parse_expr(e, ctx))
.collect::<Result<Vec<_>, _>>()?,
- negated: in_list.negated,
- }),
- ExprType::Wildcard(_) => Ok(Self::Wildcard),
- ExprType::ScalarFunction(expr) => {
- let scalar_function =
protobuf::ScalarFunction::from_i32(expr.fun)
- .ok_or_else(|| Error::unknown("ScalarFunction",
expr.fun))?;
- let args = &expr.args;
-
- match scalar_function {
- ScalarFunction::Asin => Ok(asin((&args[0]).try_into()?)),
- ScalarFunction::Acos => Ok(acos((&args[0]).try_into()?)),
- ScalarFunction::Array => Ok(array(
- args.to_owned()
- .iter()
- .map(|e| e.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- )),
- ScalarFunction::Sqrt => Ok(sqrt((&args[0]).try_into()?)),
- ScalarFunction::Sin => Ok(sin((&args[0]).try_into()?)),
- ScalarFunction::Cos => Ok(cos((&args[0]).try_into()?)),
- ScalarFunction::Tan => Ok(tan((&args[0]).try_into()?)),
- ScalarFunction::Atan => Ok(atan((&args[0]).try_into()?)),
- ScalarFunction::Exp => Ok(exp((&args[0]).try_into()?)),
- ScalarFunction::Log2 => Ok(log2((&args[0]).try_into()?)),
- ScalarFunction::Ln => Ok(ln((&args[0]).try_into()?)),
- ScalarFunction::Log10 => Ok(log10((&args[0]).try_into()?)),
- ScalarFunction::Floor => Ok(floor((&args[0]).try_into()?)),
- ScalarFunction::Ceil => Ok(ceil((&args[0]).try_into()?)),
- ScalarFunction::Round => Ok(round((&args[0]).try_into()?)),
- ScalarFunction::Trunc => Ok(trunc((&args[0]).try_into()?)),
- ScalarFunction::Abs => Ok(abs((&args[0]).try_into()?)),
- ScalarFunction::Signum =>
Ok(signum((&args[0]).try_into()?)),
- ScalarFunction::OctetLength => {
- Ok(octet_length((&args[0]).try_into()?))
- }
- ScalarFunction::Lower => Ok(lower((&args[0]).try_into()?)),
- ScalarFunction::Upper => Ok(upper((&args[0]).try_into()?)),
- ScalarFunction::Trim => Ok(trim((&args[0]).try_into()?)),
- ScalarFunction::Ltrim => Ok(ltrim((&args[0]).try_into()?)),
- ScalarFunction::Rtrim => Ok(rtrim((&args[0]).try_into()?)),
- ScalarFunction::DatePart => {
- Ok(date_part((&args[0]).try_into()?,
(&args[1]).try_into()?))
- }
- ScalarFunction::DateTrunc => {
- Ok(date_trunc((&args[0]).try_into()?,
(&args[1]).try_into()?))
- }
- ScalarFunction::Sha224 =>
Ok(sha224((&args[0]).try_into()?)),
- ScalarFunction::Sha256 =>
Ok(sha256((&args[0]).try_into()?)),
- ScalarFunction::Sha384 =>
Ok(sha384((&args[0]).try_into()?)),
- ScalarFunction::Sha512 =>
Ok(sha512((&args[0]).try_into()?)),
- ScalarFunction::Md5 => Ok(md5((&args[0]).try_into()?)),
- ScalarFunction::NullIf =>
Ok(nullif((&args[0]).try_into()?)),
- ScalarFunction::Digest => {
- Ok(digest((&args[0]).try_into()?,
(&args[1]).try_into()?))
- }
- ScalarFunction::Ascii => Ok(ascii((&args[0]).try_into()?)),
- ScalarFunction::BitLength => Ok((&args[0]).try_into()?),
- ScalarFunction::CharacterLength => {
- Ok(character_length((&args[0]).try_into()?))
- }
- ScalarFunction::Chr => Ok(chr((&args[0]).try_into()?)),
- ScalarFunction::InitCap =>
Ok(ascii((&args[0]).try_into()?)),
- ScalarFunction::Left => {
- Ok(left((&args[0]).try_into()?,
(&args[1]).try_into()?))
- }
- ScalarFunction::Random => Ok(random()),
- ScalarFunction::Repeat => {
- Ok(repeat((&args[0]).try_into()?,
(&args[1]).try_into()?))
- }
- ScalarFunction::Replace => Ok(replace(
- (&args[0]).try_into()?,
- (&args[1]).try_into()?,
- (&args[2]).try_into()?,
- )),
- ScalarFunction::Reverse =>
Ok(reverse((&args[0]).try_into()?)),
- ScalarFunction::Right => {
- Ok(right((&args[0]).try_into()?,
(&args[1]).try_into()?))
- }
- ScalarFunction::Concat => Ok(concat_expr(
- args.to_owned()
- .iter()
- .map(|e| e.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- )),
- ScalarFunction::ConcatWithSeparator => Ok(concat_ws_expr(
- args.to_owned()
- .iter()
- .map(|e| e.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- )),
- ScalarFunction::Lpad => Ok(lpad(
- args.to_owned()
- .iter()
- .map(|e| e.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- )),
- ScalarFunction::Rpad => Ok(rpad(
- args.to_owned()
- .iter()
- .map(|e| e.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- )),
- ScalarFunction::RegexpReplace => Ok(regexp_replace(
- args.to_owned()
- .iter()
- .map(|e| e.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- )),
- ScalarFunction::RegexpMatch => Ok(regexp_match(
- args.to_owned()
- .iter()
- .map(|e| e.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- )),
- ScalarFunction::Btrim => Ok(btrim(
- args.to_owned()
- .iter()
- .map(|e| e.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- )),
- ScalarFunction::SplitPart => Ok(split_part(
- (&args[0]).try_into()?,
- (&args[1]).try_into()?,
- (&args[2]).try_into()?,
- )),
- ScalarFunction::StartsWith => {
- Ok(starts_with((&args[0]).try_into()?,
(&args[1]).try_into()?))
- }
- ScalarFunction::Strpos => {
- Ok(strpos((&args[0]).try_into()?,
(&args[1]).try_into()?))
- }
- ScalarFunction::Substr => {
- Ok(substr((&args[0]).try_into()?,
(&args[1]).try_into()?))
- }
- ScalarFunction::ToHex =>
Ok(to_hex((&args[0]).try_into()?)),
- ScalarFunction::ToTimestampMillis => {
- Ok(to_timestamp_millis((&args[0]).try_into()?))
- }
- ScalarFunction::ToTimestampMicros => {
- Ok(to_timestamp_micros((&args[0]).try_into()?))
- }
- ScalarFunction::ToTimestampSeconds => {
- Ok(to_timestamp_seconds((&args[0]).try_into()?))
- }
- ScalarFunction::Now => Ok(now_expr(
- args.to_owned()
- .iter()
- .map(|e| e.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- )),
- ScalarFunction::Translate => Ok(translate(
- (&args[0]).try_into()?,
- (&args[1]).try_into()?,
- (&args[2]).try_into()?,
- )),
- _ => Err(proto_error(
- "Protobuf deserialization error: Unsupported scalar
function",
- )),
+ distinct: false, //TODO
+ })
+ }
+ ExprType::Alias(alias) => Ok(Expr::Alias(
+ Box::new(parse_required_expr(&alias.expr, ctx, "expr")?),
+ alias.alias.clone(),
+ )),
+ ExprType::IsNullExpr(is_null) =>
Ok(Expr::IsNull(Box::new(parse_required_expr(
+ &is_null.expr,
+ ctx,
+ "expr",
+ )?))),
+ ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::IsNotNull(Box::new(
+ parse_required_expr(&is_not_null.expr, ctx, "expr")?,
+ ))),
+ ExprType::NotExpr(not) => Ok(Expr::Not(Box::new(parse_required_expr(
+ &not.expr, ctx, "expr",
+ )?))),
+ ExprType::Between(between) => Ok(Expr::Between {
+ expr: Box::new(parse_required_expr(&between.expr, ctx, "expr")?),
+ negated: between.negated,
+ low: Box::new(parse_required_expr(&between.low, ctx, "expr")?),
+ high: Box::new(parse_required_expr(&between.high, ctx, "expr")?),
+ }),
+ ExprType::Case(case) => {
+ let when_then_expr = case
+ .when_then_expr
+ .iter()
+ .map(|e| {
+ let when_expr =
+ parse_required_expr_inner(&e.when_expr, ctx,
"when_expr")?;
+ let then_expr =
+ parse_required_expr_inner(&e.then_expr, ctx,
"then_expr")?;
+ Ok((Box::new(when_expr), Box::new(then_expr)))
+ })
+ .collect::<Result<Vec<(Box<Expr>, Box<Expr>)>, Error>>()?;
+ Ok(Expr::Case {
+ expr: parse_optional_expr(&case.expr, ctx)?.map(Box::new),
+ when_then_expr,
+ else_expr: parse_optional_expr(&case.else_expr,
ctx)?.map(Box::new),
+ })
+ }
+ ExprType::Cast(cast) => {
+ let expr = Box::new(parse_required_expr(&cast.expr, ctx, "expr")?);
+ let data_type = cast.arrow_type.as_ref().required("arrow_type")?;
+ Ok(Expr::Cast { expr, data_type })
+ }
+ ExprType::TryCast(cast) => {
+ let expr = Box::new(parse_required_expr(&cast.expr, ctx, "expr")?);
+ let data_type = cast.arrow_type.as_ref().required("arrow_type")?;
+ Ok(Expr::TryCast { expr, data_type })
+ }
+ ExprType::Sort(sort) => Ok(Expr::Sort {
+ expr: Box::new(parse_required_expr(&sort.expr, ctx, "expr")?),
+ asc: sort.asc,
+ nulls_first: sort.nulls_first,
+ }),
+ ExprType::Negative(negative) => Ok(Expr::Negative(Box::new(
+ parse_required_expr(&negative.expr, ctx, "expr")?,
+ ))),
+ ExprType::InList(in_list) => Ok(Expr::InList {
+ expr: Box::new(parse_required_expr(&in_list.expr, ctx, "expr")?),
+ list: in_list
+ .list
+ .iter()
+ .map(|expr| parse_expr(expr, ctx))
+ .collect::<Result<Vec<_>, _>>()?,
+ negated: in_list.negated,
+ }),
+ ExprType::Wildcard(_) => Ok(Expr::Wildcard),
+ ExprType::ScalarFunction(expr) => {
+ let scalar_function = protobuf::ScalarFunction::from_i32(expr.fun)
+ .ok_or_else(|| Error::unknown("ScalarFunction", expr.fun))?;
+ let args = &expr.args;
+
+ match scalar_function {
+ ScalarFunction::Asin => Ok(asin(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Acos => Ok(acos(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Array => Ok(array(
+ args.to_owned()
+ .iter()
+ .map(|expr| parse_expr(expr, ctx))
+ .collect::<Result<Vec<_>, _>>()?,
+ )),
+ ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Tan => Ok(tan(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Log2 => Ok(log2(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Ln => Ok(ln(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Log10 => Ok(log10(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Floor => Ok(floor(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Round => Ok(round(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Trunc => Ok(trunc(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Abs => Ok(abs(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Signum => Ok(signum(parse_expr(&args[0],
ctx)?)),
+ ScalarFunction::OctetLength => {
+ Ok(octet_length(parse_expr(&args[0], ctx)?))
+ }
+ ScalarFunction::Lower => Ok(lower(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Upper => Ok(upper(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Trim => Ok(trim(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Ltrim => Ok(ltrim(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::Rtrim => Ok(rtrim(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::DatePart => Ok(date_part(
+ parse_expr(&args[0], ctx)?,
+ parse_expr(&args[1], ctx)?,
+ )),
+ ScalarFunction::DateTrunc => Ok(date_trunc(
+ parse_expr(&args[0], ctx)?,
+ parse_expr(&args[1], ctx)?,
+ )),
+ ScalarFunction::Sha224 => Ok(sha224(parse_expr(&args[0],
ctx)?)),
+ ScalarFunction::Sha256 => Ok(sha256(parse_expr(&args[0],
ctx)?)),
+ ScalarFunction::Sha384 => Ok(sha384(parse_expr(&args[0],
ctx)?)),
+ ScalarFunction::Sha512 => Ok(sha512(parse_expr(&args[0],
ctx)?)),
+ ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::NullIf => Ok(nullif(parse_expr(&args[0],
ctx)?)),
+ ScalarFunction::Digest => Ok(digest(
+ parse_expr(&args[0], ctx)?,
+ parse_expr(&args[1], ctx)?,
+ )),
+ ScalarFunction::Ascii => Ok(ascii(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::BitLength =>
Ok(bit_length(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::CharacterLength => {
+ Ok(character_length(parse_expr(&args[0], ctx)?))
+ }
+ ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], ctx)?)),
+ ScalarFunction::InitCap => Ok(ascii(parse_expr(&args[0],
ctx)?)),
+ ScalarFunction::Left => {
+ Ok(left(parse_expr(&args[0], ctx)?, parse_expr(&args[1],
ctx)?))
+ }
+ ScalarFunction::Random => Ok(random()),
+ ScalarFunction::Repeat => Ok(repeat(
+ parse_expr(&args[0], ctx)?,
+ parse_expr(&args[1], ctx)?,
+ )),
+ ScalarFunction::Replace => Ok(replace(
+ parse_expr(&args[0], ctx)?,
+ parse_expr(&args[1], ctx)?,
+ parse_expr(&args[2], ctx)?,
+ )),
+ ScalarFunction::Reverse => Ok(reverse(parse_expr(&args[0],
ctx)?)),
+ ScalarFunction::Right => Ok(right(
+ parse_expr(&args[0], ctx)?,
+ parse_expr(&args[1], ctx)?,
+ )),
+ ScalarFunction::Concat => Ok(concat_expr(
+ args.to_owned()
+ .iter()
+ .map(|expr| parse_expr(expr, ctx))
+ .collect::<Result<Vec<_>, _>>()?,
+ )),
+ ScalarFunction::ConcatWithSeparator => Ok(concat_ws_expr(
+ args.to_owned()
+ .iter()
+ .map(|expr| parse_expr(expr, ctx))
+ .collect::<Result<Vec<_>, _>>()?,
+ )),
+ ScalarFunction::Lpad => Ok(lpad(
+ args.to_owned()
+ .iter()
+ .map(|expr| parse_expr(expr, ctx))
+ .collect::<Result<Vec<_>, _>>()?,
+ )),
+ ScalarFunction::Rpad => Ok(rpad(
+ args.to_owned()
+ .iter()
+ .map(|expr| parse_expr(expr, ctx))
+ .collect::<Result<Vec<_>, _>>()?,
+ )),
+ ScalarFunction::RegexpReplace => Ok(regexp_replace(
+ args.to_owned()
+ .iter()
+ .map(|expr| parse_expr(expr, ctx))
+ .collect::<Result<Vec<_>, _>>()?,
+ )),
+ ScalarFunction::RegexpMatch => Ok(regexp_match(
+ args.to_owned()
+ .iter()
+ .map(|expr| parse_expr(expr, ctx))
+ .collect::<Result<Vec<_>, _>>()?,
+ )),
+ ScalarFunction::Btrim => Ok(btrim(
+ args.to_owned()
+ .iter()
+ .map(|expr| parse_expr(expr, ctx))
+ .collect::<Result<Vec<_>, _>>()?,
+ )),
+ ScalarFunction::SplitPart => Ok(split_part(
+ parse_expr(&args[0], ctx)?,
+ parse_expr(&args[1], ctx)?,
+ parse_expr(&args[2], ctx)?,
+ )),
+ ScalarFunction::StartsWith => Ok(starts_with(
+ parse_expr(&args[0], ctx)?,
+ parse_expr(&args[1], ctx)?,
+ )),
+ ScalarFunction::Strpos => Ok(strpos(
+ parse_expr(&args[0], ctx)?,
+ parse_expr(&args[1], ctx)?,
+ )),
+ ScalarFunction::Substr => Ok(substr(
+ parse_expr(&args[0], ctx)?,
+ parse_expr(&args[1], ctx)?,
+ )),
+ ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0],
ctx)?)),
+ ScalarFunction::ToTimestampMillis => {
+ Ok(to_timestamp_millis(parse_expr(&args[0], ctx)?))
+ }
+ ScalarFunction::ToTimestampMicros => {
+ Ok(to_timestamp_micros(parse_expr(&args[0], ctx)?))
+ }
+ ScalarFunction::ToTimestampSeconds => {
+ Ok(to_timestamp_seconds(parse_expr(&args[0], ctx)?))
}
+ ScalarFunction::Now => Ok(now_expr(
+ args.to_owned()
+ .iter()
+ .map(|expr| parse_expr(expr, ctx))
+ .collect::<Result<Vec<_>, _>>()?,
+ )),
+ ScalarFunction::Translate => Ok(translate(
+ parse_expr(&args[0], ctx)?,
+ parse_expr(&args[1], ctx)?,
+ parse_expr(&args[2], ctx)?,
+ )),
+ _ => Err(proto_error(
+ "Protobuf deserialization error: Unsupported scalar
function",
+ )),
}
}
+ ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args
}) => {
+ let scalar_fn = ctx
+ .state
+ .lock()
+ .get_function_meta(fun_name.as_str()).ok_or_else(||
Error::General(format!("invalid aggregate function message, function {} is not
registered in the ExecutionContext", fun_name)))?;
+
+ Ok(Expr::ScalarUDF {
+ fun: scalar_fn,
+ args: args
+ .iter()
+ .map(|expr| parse_expr(expr, ctx))
+ .collect::<Result<Vec<_>, Error>>()?,
+ })
+ }
+ ExprType::AggregateUdfExpr(protobuf::AggregateUdfExprNode { fun_name,
args }) => {
+ let agg_fn = ctx
+ .state
+ .lock()
+ .get_aggregate_meta(fun_name.as_str()).ok_or_else(||
Error::General(format!("invalid aggregate function message, function {} is not
registered in the ExecutionContext", fun_name)))?;
+
+ Ok(Expr::AggregateUDF {
+ fun: agg_fn,
+ args: args
+ .iter()
+ .map(|expr| parse_expr(expr, ctx))
+ .collect::<Result<Vec<_>, Error>>()?,
+ })
+ }
}
}
@@ -1425,13 +1461,36 @@ fn from_proto_binary_op(op: &str) -> Result<Operator,
Error> {
fn parse_optional_expr(
p: &Option<Box<protobuf::LogicalExprNode>>,
+ ctx: &SessionContext,
) -> Result<Option<Expr>, Error> {
match p {
- Some(expr) => expr.as_ref().try_into().map(Some),
+ Some(expr) => parse_expr(expr.as_ref(), ctx).map(Some),
None => Ok(None),
}
}
+fn parse_required_expr(
+ p: &Option<Box<protobuf::LogicalExprNode>>,
+ ctx: &SessionContext,
+ field: impl Into<String>,
+) -> Result<Expr, Error> {
+ match p {
+ Some(expr) => parse_expr(expr.as_ref(), ctx),
+ None => Err(Error::required(field)),
+ }
+}
+
+fn parse_required_expr_inner(
+ p: &Option<protobuf::LogicalExprNode>,
+ ctx: &SessionContext,
+ field: impl Into<String>,
+) -> Result<Expr, Error> {
+ match p {
+ Some(expr) => parse_expr(expr, ctx),
+ None => Err(Error::required(field)),
+ }
+}
+
fn proto_error<S: Into<String>>(message: S) -> Error {
Error::General(message.into())
}
diff --git a/datafusion-proto/src/lib.rs b/datafusion-proto/src/lib.rs
index b880f8e..0688215 100644
--- a/datafusion-proto/src/lib.rs
+++ b/datafusion-proto/src/lib.rs
@@ -26,6 +26,12 @@ pub mod to_proto;
#[cfg(test)]
mod roundtrip_tests {
+ use super::from_proto::parse_expr;
+ use super::protobuf;
+ use datafusion::arrow::array::ArrayRef;
+ use datafusion::logical_plan::create_udaf;
+ use datafusion::physical_plan::functions::{make_scalar_function,
Volatility};
+ use datafusion::physical_plan::Accumulator;
use datafusion::{
arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit, UnionMode},
logical_plan::{col, Expr},
@@ -33,14 +39,15 @@ mod roundtrip_tests {
prelude::*,
scalar::ScalarValue,
};
+ use std::sync::Arc;
- // Given a DataFusion type, convert it to protobuf and back, using debug
formatting to test
+ // Given a DataFusion logical Expr, convert it to protobuf and back, using
debug formatting to test
// equality.
- macro_rules! roundtrip_test {
- ($initial_struct:ident, $proto_type:ty, $struct_type:ty) => {
- let proto: $proto_type = (&$initial_struct).try_into().unwrap();
+ macro_rules! roundtrip_expr_test {
+ ($initial_struct:ident, $ctx:ident) => {
+ let proto: protobuf::LogicalExprNode =
(&$initial_struct).try_into().unwrap();
- let round_trip: $struct_type = (&proto).try_into().unwrap();
+ let round_trip: Expr = parse_expr(&proto, &$ctx).unwrap();
assert_eq!(
format!("{:?}", $initial_struct),
@@ -575,95 +582,102 @@ mod roundtrip_tests {
#[test]
fn roundtrip_not() {
- let test_expr = Expr::Not(Box::new(Expr::Literal((1.0).into())));
+ let test_expr = Expr::Not(Box::new(lit(1.0_f32)));
- roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+ let ctx = SessionContext::new();
+ roundtrip_expr_test!(test_expr, ctx);
}
#[test]
fn roundtrip_is_null() {
let test_expr = Expr::IsNull(Box::new(col("id")));
- roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+ let ctx = SessionContext::new();
+ roundtrip_expr_test!(test_expr, ctx);
}
#[test]
fn roundtrip_is_not_null() {
let test_expr = Expr::IsNotNull(Box::new(col("id")));
- roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+ let ctx = SessionContext::new();
+ roundtrip_expr_test!(test_expr, ctx);
}
#[test]
fn roundtrip_between() {
let test_expr = Expr::Between {
- expr: Box::new(Expr::Literal((1.0).into())),
+ expr: Box::new(lit(1.0_f32)),
negated: true,
- low: Box::new(Expr::Literal((2.0).into())),
- high: Box::new(Expr::Literal((3.0).into())),
+ low: Box::new(lit(2.0_f32)),
+ high: Box::new(lit(3.0_f32)),
};
- roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+ let ctx = SessionContext::new();
+ roundtrip_expr_test!(test_expr, ctx);
}
#[test]
fn roundtrip_case() {
let test_expr = Expr::Case {
- expr: Some(Box::new(Expr::Literal((1.0).into()))),
- when_then_expr: vec![(
- Box::new(Expr::Literal((2.0).into())),
- Box::new(Expr::Literal((3.0).into())),
- )],
- else_expr: Some(Box::new(Expr::Literal((4.0).into()))),
+ expr: Some(Box::new(lit(1.0_f32))),
+ when_then_expr: vec![(Box::new(lit(2.0_f32)),
Box::new(lit(3.0_f32)))],
+ else_expr: Some(Box::new(lit(4.0_f32))),
};
- roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+ let ctx = SessionContext::new();
+ roundtrip_expr_test!(test_expr, ctx);
}
#[test]
fn roundtrip_cast() {
let test_expr = Expr::Cast {
- expr: Box::new(Expr::Literal((1.0).into())),
+ expr: Box::new(lit(1.0_f32)),
data_type: DataType::Boolean,
};
- roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+ let ctx = SessionContext::new();
+ roundtrip_expr_test!(test_expr, ctx);
}
#[test]
fn roundtrip_sort_expr() {
let test_expr = Expr::Sort {
- expr: Box::new(Expr::Literal((1.0).into())),
+ expr: Box::new(lit(1.0_f32)),
asc: true,
nulls_first: true,
};
- roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+ let ctx = SessionContext::new();
+ roundtrip_expr_test!(test_expr, ctx);
}
#[test]
fn roundtrip_negative() {
- let test_expr = Expr::Negative(Box::new(Expr::Literal((1.0).into())));
+ let test_expr = Expr::Negative(Box::new(lit(1.0_f32)));
- roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+ let ctx = SessionContext::new();
+ roundtrip_expr_test!(test_expr, ctx);
}
#[test]
fn roundtrip_inlist() {
let test_expr = Expr::InList {
- expr: Box::new(Expr::Literal((1.0).into())),
- list: vec![Expr::Literal((2.0).into())],
+ expr: Box::new(lit(1.0_f32)),
+ list: vec![lit(2.0_f32)],
negated: true,
};
- roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+ let ctx = SessionContext::new();
+ roundtrip_expr_test!(test_expr, ctx);
}
#[test]
fn roundtrip_wildcard() {
let test_expr = Expr::Wildcard;
- roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+ let ctx = SessionContext::new();
+ roundtrip_expr_test!(test_expr, ctx);
}
#[test]
@@ -672,17 +686,98 @@ mod roundtrip_tests {
fun: Sqrt,
args: vec![col("col")],
};
- roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+ let ctx = SessionContext::new();
+ roundtrip_expr_test!(test_expr, ctx);
}
#[test]
fn roundtrip_approx_percentile_cont() {
let test_expr = Expr::AggregateFunction {
fun: aggregates::AggregateFunction::ApproxPercentileCont,
- args: vec![col("bananas"), lit(0.42)],
+ args: vec![col("bananas"), lit(0.42_f32)],
distinct: false,
};
- roundtrip_test!(test_expr, super::protobuf::LogicalExprNode, Expr);
+ let ctx = SessionContext::new();
+ roundtrip_expr_test!(test_expr, ctx);
+ }
+
+ #[test]
+ fn roundtrip_aggregate_udf() {
+ #[derive(Debug)]
+ struct Dummy {}
+
+ impl Accumulator for Dummy {
+ fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
+ Ok(vec![])
+ }
+
+ fn update_batch(
+ &mut self,
+ _values: &[ArrayRef],
+ ) -> datafusion::error::Result<()> {
+ Ok(())
+ }
+
+ fn merge_batch(
+ &mut self,
+ _states: &[ArrayRef],
+ ) -> datafusion::error::Result<()> {
+ Ok(())
+ }
+
+ fn evaluate(&self) -> datafusion::error::Result<ScalarValue> {
+ Ok(ScalarValue::Float64(None))
+ }
+ }
+
+ let dummy_agg = create_udaf(
+ // the name; used to represent it in plan descriptions and in the
registry, to use in SQL.
+ "dummy_agg",
+ // the input type; DataFusion guarantees that the first entry of
`values` in `update` has this type.
+ DataType::Float64,
+ // the return type; DataFusion expects this to match the type
returned by `evaluate`.
+ Arc::new(DataType::Float64),
+ Volatility::Immutable,
+ // This is the accumulator factory; DataFusion uses it to create
new accumulators.
+ Arc::new(|| Ok(Box::new(Dummy {}))),
+ // This is the description of the state. `state()` must match the
types here.
+ Arc::new(vec![DataType::Float64, DataType::UInt32]),
+ );
+
+ let test_expr = Expr::AggregateUDF {
+ fun: Arc::new(dummy_agg.clone()),
+ args: vec![lit(1.0_f64)],
+ };
+
+ let mut ctx = SessionContext::new();
+ ctx.register_udaf(dummy_agg);
+
+ roundtrip_expr_test!(test_expr, ctx);
+ }
+
+ #[test]
+ fn roundtrip_scalar_udf() {
+     // Kernel that returns a wrapped clone of its first argument.
+     let first_arg = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as
ArrayRef);
+
+     let dummy = create_udf(
+         "dummy",
+         vec![DataType::Utf8],
+         Arc::new(DataType::Utf8),
+         Volatility::Immutable,
+         make_scalar_function(first_arg),
+     );
+
+     // Register the UDF first so the deserializer can resolve it by name.
+     let mut ctx = SessionContext::new();
+     ctx.register_udf(dummy.clone());
+
+     let test_expr = Expr::ScalarUDF {
+         fun: Arc::new(dummy),
+         args: vec![lit("")],
+     };
+
+     roundtrip_expr_test!(test_expr, ctx);
}
}
diff --git a/datafusion-proto/src/to_proto.rs b/datafusion-proto/src/to_proto.rs
index 29c533a..753e59c 100644
--- a/datafusion-proto/src/to_proto.rs
+++ b/datafusion-proto/src/to_proto.rs
@@ -20,6 +20,7 @@
//! processes.
use crate::protobuf;
+
use datafusion::{
arrow::datatypes::{
DataType, Field, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionMode,
@@ -523,8 +524,27 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
)),
}
}
- Expr::ScalarUDF { .. } => unimplemented!(),
- Expr::AggregateUDF { .. } => unimplemented!(),
+ Expr::ScalarUDF { fun, args } => Self {
+ expr_type:
Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode {
+ fun_name: fun.name.clone(),
+ args: args
+ .iter()
+ .map(|expr| expr.try_into())
+ .collect::<Result<Vec<_>, Error>>()?,
+ })),
+ },
+ Expr::AggregateUDF { fun, args } => Self {
+ expr_type: Some(ExprType::AggregateUdfExpr(
+ protobuf::AggregateUdfExprNode {
+ fun_name: fun.name.clone(),
+ args: args.iter().map(|expr|
expr.try_into()).collect::<Result<
+ Vec<_>,
+ Error,
+ >>(
+ )?,
+ },
+ )),
+ },
Expr::Not(expr) => {
let expr = Box::new(protobuf::Not {
expr: Some(Box::new(expr.as_ref().try_into()?)),