This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a4c71e220b [Minor]: Move some repetitive codes to functions(proto)
(#9811)
a4c71e220b is described below
commit a4c71e220be20c591b3dde38d5a9aa410e458466
Author: Mustafa Akur <[email protected]>
AuthorDate: Wed Mar 27 06:09:23 2024 +0300
[Minor]: Move some repetitive codes to functions(proto) (#9811)
* add parse_exprs util
* Minor changes
* Minor changes
* Add vector field converter
* Add serialize exprs
* proto to arrow field conversion
* Simplifications
* All tests pass
* Simplifications
---
datafusion/proto/src/logical_plan/from_proto.rs | 176 ++++++++---------------
datafusion/proto/src/logical_plan/to_proto.rs | 116 ++++++---------
datafusion/proto/src/physical_plan/from_proto.rs | 106 ++++++--------
datafusion/proto/src/physical_plan/mod.rs | 51 ++++---
datafusion/proto/src/physical_plan/to_proto.rs | 124 ++++++++--------
5 files changed, 231 insertions(+), 342 deletions(-)
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index d5eebcb698..4b9874bf8f 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -323,11 +323,7 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for
DataType {
DataType::FixedSizeList(Arc::new(list_type), list_size)
}
arrow_type::ArrowTypeEnum::Struct(strct) => DataType::Struct(
- strct
- .sub_field_types
- .iter()
- .map(Field::try_from)
- .collect::<Result<_, _>>()?,
+ parse_proto_fields_to_fields(&strct.sub_field_types)?.into(),
),
arrow_type::ArrowTypeEnum::Union(union) => {
let union_mode =
protobuf::UnionMode::try_from(union.union_mode)
@@ -336,11 +332,7 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for
DataType {
protobuf::UnionMode::Dense => UnionMode::Dense,
protobuf::UnionMode::Sparse => UnionMode::Sparse,
};
- let union_fields = union
- .union_types
- .iter()
- .map(TryInto::try_into)
- .collect::<Result<Vec<Field>, _>>()?;
+ let union_fields =
parse_proto_fields_to_fields(&union.union_types)?;
// Default to index based type ids if not provided
let type_ids: Vec<_> = match union.type_ids.is_empty() {
@@ -763,10 +755,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
.map(|f| f.field.clone())
.collect::<Option<Vec<_>>>();
let fields = fields.ok_or_else(||
Error::required("UnionField"))?;
- let fields = fields
- .iter()
- .map(Field::try_from)
- .collect::<Result<Vec<_>, _>>()?;
+ let fields = parse_proto_fields_to_fields(&fields)?;
let fields = UnionFields::new(ids, fields);
let v_id = val.value_id as i8;
let val = match &val.value {
@@ -937,11 +926,7 @@ pub fn parse_expr(
match expr_type {
ExprType::BinaryExpr(binary_expr) => {
let op = from_proto_binary_op(&binary_expr.op)?;
- let operands = binary_expr
- .operands
- .iter()
- .map(|expr| parse_expr(expr, registry, codec))
- .collect::<Result<Vec<_>, _>>()?;
+ let operands = parse_exprs(&binary_expr.operands, registry,
codec)?;
if operands.len() < 2 {
return Err(proto_error(
@@ -1025,16 +1010,8 @@ pub fn parse_expr(
.window_function
.as_ref()
.ok_or_else(|| Error::required("window_function"))?;
- let partition_by = expr
- .partition_by
- .iter()
- .map(|e| parse_expr(e, registry, codec))
- .collect::<Result<Vec<_>, _>>()?;
- let mut order_by = expr
- .order_by
- .iter()
- .map(|e| parse_expr(e, registry, codec))
- .collect::<Result<Vec<_>, _>>()?;
+ let partition_by = parse_exprs(&expr.partition_by, registry,
codec)?;
+ let mut order_by = parse_exprs(&expr.order_by, registry, codec)?;
let window_frame = expr
.window_frame
.as_ref()
@@ -1130,10 +1107,7 @@ pub fn parse_expr(
Ok(Expr::AggregateFunction(expr::AggregateFunction::new(
fun,
- expr.expr
- .iter()
- .map(|e| parse_expr(e, registry, codec))
- .collect::<Result<Vec<_>, _>>()?,
+ parse_exprs(&expr.expr, registry, codec)?,
expr.distinct,
parse_optional_expr(expr.filter.as_deref(), registry, codec)?
.map(Box::new),
@@ -1331,11 +1305,7 @@ pub fn parse_expr(
parse_required_expr(negative.expr.as_deref(), registry, "expr",
codec)?,
))),
ExprType::Unnest(unnest) => {
- let exprs = unnest
- .exprs
- .iter()
- .map(|e| parse_expr(e, registry, codec))
- .collect::<Result<Vec<_>, _>>()?;
+ let exprs = parse_exprs(&unnest.exprs, registry, codec)?;
Ok(Expr::Unnest(Unnest { exprs }))
}
ExprType::InList(in_list) => Ok(Expr::InList(InList::new(
@@ -1345,11 +1315,7 @@ pub fn parse_expr(
"expr",
codec,
)?),
- in_list
- .list
- .iter()
- .map(|expr| parse_expr(expr, registry, codec))
- .collect::<Result<Vec<_>, _>>()?,
+ parse_exprs(&in_list.list, registry, codec)?,
in_list.negated,
))),
ExprType::Wildcard(protobuf::Wildcard { qualifier }) =>
Ok(Expr::Wildcard {
@@ -1401,18 +1367,8 @@ pub fn parse_expr(
Ok(factorial(parse_expr(&args[0], registry, codec)?))
}
ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry,
codec)?)),
- ScalarFunction::Round => Ok(round(
- args.to_owned()
- .iter()
- .map(|expr| parse_expr(expr, registry, codec))
- .collect::<Result<Vec<_>, _>>()?,
- )),
- ScalarFunction::Trunc => Ok(trunc(
- args.to_owned()
- .iter()
- .map(|expr| parse_expr(expr, registry, codec))
- .collect::<Result<Vec<_>, _>>()?,
- )),
+ ScalarFunction::Round => Ok(round(parse_exprs(args, registry,
codec)?)),
+ ScalarFunction::Trunc => Ok(trunc(parse_exprs(args, registry,
codec)?)),
ScalarFunction::Signum => {
Ok(signum(parse_expr(&args[0], registry, codec)?))
}
@@ -1442,30 +1398,14 @@ pub fn parse_expr(
parse_expr(&args[0], registry, codec)?,
parse_expr(&args[1], registry, codec)?,
)),
- ScalarFunction::Concat => Ok(concat_expr(
- args.to_owned()
- .iter()
- .map(|expr| parse_expr(expr, registry, codec))
- .collect::<Result<Vec<_>, _>>()?,
- )),
- ScalarFunction::ConcatWithSeparator => Ok(concat_ws_expr(
- args.to_owned()
- .iter()
- .map(|expr| parse_expr(expr, registry, codec))
- .collect::<Result<Vec<_>, _>>()?,
- )),
- ScalarFunction::Lpad => Ok(lpad(
- args.to_owned()
- .iter()
- .map(|expr| parse_expr(expr, registry, codec))
- .collect::<Result<Vec<_>, _>>()?,
- )),
- ScalarFunction::Rpad => Ok(rpad(
- args.to_owned()
- .iter()
- .map(|expr| parse_expr(expr, registry, codec))
- .collect::<Result<Vec<_>, _>>()?,
- )),
+ ScalarFunction::Concat => {
+ Ok(concat_expr(parse_exprs(args, registry, codec)?))
+ }
+ ScalarFunction::ConcatWithSeparator => {
+ Ok(concat_ws_expr(parse_exprs(args, registry, codec)?))
+ }
+ ScalarFunction::Lpad => Ok(lpad(parse_exprs(args, registry,
codec)?)),
+ ScalarFunction::Rpad => Ok(rpad(parse_exprs(args, registry,
codec)?)),
ScalarFunction::EndsWith => Ok(ends_with(
parse_expr(&args[0], registry, codec)?,
parse_expr(&args[1], registry, codec)?,
@@ -1494,12 +1434,9 @@ pub fn parse_expr(
parse_expr(&args[1], registry, codec)?,
parse_expr(&args[2], registry, codec)?,
)),
- ScalarFunction::Coalesce => Ok(coalesce(
- args.to_owned()
- .iter()
- .map(|expr| parse_expr(expr, registry, codec))
- .collect::<Result<Vec<_>, _>>()?,
- )),
+ ScalarFunction::Coalesce => {
+ Ok(coalesce(parse_exprs(args, registry, codec)?))
+ }
ScalarFunction::Pi => Ok(pi()),
ScalarFunction::Power => Ok(power(
parse_expr(&args[0], registry, codec)?,
@@ -1543,9 +1480,7 @@ pub fn parse_expr(
};
Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
scalar_fn,
- args.iter()
- .map(|expr| parse_expr(expr, registry, codec))
- .collect::<Result<Vec<_>, Error>>()?,
+ parse_exprs(args, registry, codec)?,
)))
}
ExprType::AggregateUdfExpr(pb) => {
@@ -1553,10 +1488,7 @@ pub fn parse_expr(
Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
agg_fn,
- pb.args
- .iter()
- .map(|expr| parse_expr(expr, registry, codec))
- .collect::<Result<Vec<_>, Error>>()?,
+ parse_exprs(&pb.args, registry, codec)?,
false,
parse_optional_expr(pb.filter.as_deref(), registry,
codec)?.map(Box::new),
parse_vec_expr(&pb.order_by, registry, codec)?,
@@ -1566,28 +1498,16 @@ pub fn parse_expr(
ExprType::GroupingSet(GroupingSetNode { expr }) => {
Ok(Expr::GroupingSet(GroupingSets(
expr.iter()
- .map(|expr_list| {
- expr_list
- .expr
- .iter()
- .map(|expr| parse_expr(expr, registry, codec))
- .collect::<Result<Vec<_>, Error>>()
- })
+ .map(|expr_list| parse_exprs(&expr_list.expr, registry,
codec))
.collect::<Result<Vec<_>, Error>>()?,
)))
}
ExprType::Cube(CubeNode { expr }) =>
Ok(Expr::GroupingSet(GroupingSet::Cube(
- expr.iter()
- .map(|expr| parse_expr(expr, registry, codec))
- .collect::<Result<Vec<_>, Error>>()?,
+ parse_exprs(expr, registry, codec)?,
))),
- ExprType::Rollup(RollupNode { expr }) => {
- Ok(Expr::GroupingSet(GroupingSet::Rollup(
- expr.iter()
- .map(|expr| parse_expr(expr, registry, codec))
- .collect::<Result<Vec<_>, Error>>()?,
- )))
- }
+ ExprType::Rollup(RollupNode { expr }) => Ok(Expr::GroupingSet(
+ GroupingSet::Rollup(parse_exprs(expr, registry, codec)?),
+ )),
ExprType::Placeholder(PlaceholderNode { id, data_type }) => match
data_type {
None => Ok(Expr::Placeholder(Placeholder::new(id.clone(), None))),
Some(data_type) => Ok(Expr::Placeholder(Placeholder::new(
@@ -1598,6 +1518,24 @@ pub fn parse_expr(
}
}
+/// Parse a vector of `protobuf::LogicalExprNode`s.
+pub fn parse_exprs<'a, I>(
+ protos: I,
+ registry: &dyn FunctionRegistry,
+ codec: &dyn LogicalExtensionCodec,
+) -> Result<Vec<Expr>, Error>
+where
+ I: IntoIterator<Item = &'a protobuf::LogicalExprNode>,
+{
+ let res = protos
+ .into_iter()
+ .map(|elem| {
+ parse_expr(elem, registry, codec).map_err(|e|
plan_datafusion_err!("{}", e))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ Ok(res)
+}
+
/// Parse an optional escape_char for Like, ILike, SimilarTo
fn parse_escape_char(s: &str) -> Result<Option<char>> {
match s.len() {
@@ -1654,12 +1592,7 @@ fn parse_vec_expr(
registry: &dyn FunctionRegistry,
codec: &dyn LogicalExtensionCodec,
) -> Result<Option<Vec<Expr>>, Error> {
- let res = p
- .iter()
- .map(|elem| {
- parse_expr(elem, registry, codec).map_err(|e|
plan_datafusion_err!("{}", e))
- })
- .collect::<Result<Vec<_>>>()?;
+ let res = parse_exprs(p, registry, codec)?;
// Convert empty vector to None.
Ok((!res.is_empty()).then_some(res))
}
@@ -1690,3 +1623,16 @@ fn parse_required_expr(
fn proto_error<S: Into<String>>(message: S) -> Error {
Error::General(message.into())
}
+
+/// Converts a vector of `protobuf::Field`s to a `Vec` of arrow `Field`s.
+fn parse_proto_fields_to_fields<'a, I>(
+ fields: I,
+) -> std::result::Result<Vec<Field>, Error>
+where
+ I: IntoIterator<Item = &'a protobuf::Field>,
+{
+ fields
+ .into_iter()
+ .map(Field::try_from)
+ .collect::<Result<_, _>>()
+}
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index 0432b54acf..1335d511a0 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -19,6 +19,8 @@
//! DataFusion logical plans to be serialized and transmitted between
//! processes.
+use std::sync::Arc;
+
use crate::protobuf::{
self,
arrow_type::ArrowTypeEnum,
@@ -186,10 +188,7 @@ impl TryFrom<&DataType> for
protobuf::arrow_type::ArrowTypeEnum {
field_type: Some(Box::new(item_type.as_ref().try_into()?)),
})),
DataType::Struct(struct_fields) => Self::Struct(protobuf::Struct {
- sub_field_types: struct_fields
- .iter()
- .map(|field| field.as_ref().try_into())
- .collect::<Result<Vec<_>, Error>>()?,
+ sub_field_types:
convert_arc_fields_to_proto_fields(struct_fields)?,
}),
DataType::Union(fields, union_mode) => {
let union_mode = match union_mode {
@@ -197,10 +196,7 @@ impl TryFrom<&DataType> for
protobuf::arrow_type::ArrowTypeEnum {
UnionMode::Dense => protobuf::UnionMode::Dense,
};
Self::Union(protobuf::Union {
- union_types: fields
- .iter()
- .map(|(_, field)| field.as_ref().try_into())
- .collect::<Result<Vec<_>, Error>>()?,
+ union_types:
convert_arc_fields_to_proto_fields(fields.iter().map(|(_, item)|item))?,
union_mode: union_mode.into(),
type_ids: fields.iter().map(|(x, _)| x as i32).collect(),
})
@@ -262,11 +258,7 @@ impl TryFrom<&Schema> for protobuf::Schema {
fn try_from(schema: &Schema) -> Result<Self, Self::Error> {
Ok(Self {
- columns: schema
- .fields()
- .iter()
- .map(|f| f.as_ref().try_into())
- .collect::<Result<Vec<_>, Error>>()?,
+ columns: convert_arc_fields_to_proto_fields(schema.fields())?,
metadata: schema.metadata.clone(),
})
}
@@ -277,11 +269,7 @@ impl TryFrom<SchemaRef> for protobuf::Schema {
fn try_from(schema: SchemaRef) -> Result<Self, Self::Error> {
Ok(Self {
- columns: schema
- .fields()
- .iter()
- .map(|f| f.as_ref().try_into())
- .collect::<Result<Vec<_>, Error>>()?,
+ columns: convert_arc_fields_to_proto_fields(schema.fields())?,
metadata: schema.metadata.clone(),
})
}
@@ -486,6 +474,19 @@ impl TryFrom<&WindowFrame> for protobuf::WindowFrame {
}
}
+pub fn serialize_exprs<'a, I>(
+ exprs: I,
+ codec: &dyn LogicalExtensionCodec,
+) -> Result<Vec<protobuf::LogicalExprNode>, Error>
+where
+ I: IntoIterator<Item = &'a Expr>,
+{
+ exprs
+ .into_iter()
+ .map(|expr| serialize_expr(expr, codec))
+ .collect::<Result<Vec<_>, Error>>()
+}
+
pub fn serialize_expr(
expr: &Expr,
codec: &dyn LogicalExtensionCodec,
@@ -543,11 +544,7 @@ pub fn serialize_expr(
// We need to reverse exprs since operands are expected to be
// linearized from left innermost to right outermost (but while
// traversing the chain we do the exact opposite).
- operands: exprs
- .into_iter()
- .rev()
- .map(|expr| serialize_expr(expr, codec))
- .collect::<Result<Vec<_>, Error>>()?,
+ operands: serialize_exprs(exprs.into_iter().rev(), codec)?,
op: format!("{op:?}"),
};
protobuf::LogicalExprNode {
@@ -639,14 +636,8 @@ pub fn serialize_expr(
} else {
None
};
- let partition_by = partition_by
- .iter()
- .map(|e| serialize_expr(e, codec))
- .collect::<Result<Vec<_>, _>>()?;
- let order_by = order_by
- .iter()
- .map(|e| serialize_expr(e, codec))
- .collect::<Result<Vec<_>, _>>()?;
+ let partition_by = serialize_exprs(partition_by, codec)?;
+ let order_by = serialize_exprs(order_by, codec)?;
let window_frame: Option<protobuf::WindowFrame> =
Some(window_frame.try_into()?);
@@ -744,20 +735,14 @@ pub fn serialize_expr(
let aggregate_expr = protobuf::AggregateExprNode {
aggr_function: aggr_function.into(),
- expr: args
- .iter()
- .map(|v| serialize_expr(v, codec))
- .collect::<Result<Vec<_>, _>>()?,
+ expr: serialize_exprs(args, codec)?,
distinct: *distinct,
filter: match filter {
Some(e) => Some(Box::new(serialize_expr(e, codec)?)),
None => None,
},
order_by: match order_by {
- Some(e) => e
- .iter()
- .map(|expr| serialize_expr(expr, codec))
- .collect::<Result<Vec<_>, _>>()?,
+ Some(e) => serialize_exprs(e, codec)?,
None => vec![],
},
};
@@ -769,19 +754,13 @@ pub fn serialize_expr(
expr_type: Some(ExprType::AggregateUdfExpr(Box::new(
protobuf::AggregateUdfExprNode {
fun_name: fun.name().to_string(),
- args: args
- .iter()
- .map(|expr| serialize_expr(expr, codec))
- .collect::<Result<Vec<_>, Error>>()?,
+ args: serialize_exprs(args, codec)?,
filter: match filter {
Some(e) =>
Some(Box::new(serialize_expr(e.as_ref(), codec)?)),
None => None,
},
order_by: match order_by {
- Some(e) => e
- .iter()
- .map(|expr| serialize_expr(expr, codec))
- .collect::<Result<Vec<_>, _>>()?,
+ Some(e) => serialize_exprs(e, codec)?,
None => vec![],
},
},
@@ -801,10 +780,7 @@ pub fn serialize_expr(
))
}
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
- let args = args
- .iter()
- .map(|expr| serialize_expr(expr, codec))
- .collect::<Result<Vec<_>, Error>>()?;
+ let args = serialize_exprs(args, codec)?;
match func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
let fun: protobuf::ScalarFunction = fun.try_into()?;
@@ -997,10 +973,7 @@ pub fn serialize_expr(
}
Expr::Unnest(Unnest { exprs }) => {
let expr = protobuf::Unnest {
- exprs: exprs
- .iter()
- .map(|expr| serialize_expr(expr, codec))
- .collect::<Result<Vec<_>, Error>>()?,
+ exprs: serialize_exprs(exprs, codec)?,
};
protobuf::LogicalExprNode {
expr_type: Some(ExprType::Unnest(expr)),
@@ -1013,10 +986,7 @@ pub fn serialize_expr(
}) => {
let expr = Box::new(protobuf::InListNode {
expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)),
- list: list
- .iter()
- .map(|expr| serialize_expr(expr, codec))
- .collect::<Result<Vec<_>, Error>>()?,
+ list: serialize_exprs(list, codec)?,
negated: *negated,
});
protobuf::LogicalExprNode {
@@ -1077,18 +1047,12 @@ pub fn serialize_expr(
Expr::GroupingSet(GroupingSet::Cube(exprs)) =>
protobuf::LogicalExprNode {
expr_type: Some(ExprType::Cube(CubeNode {
- expr: exprs
- .iter()
- .map(|expr| serialize_expr(expr, codec))
- .collect::<Result<Vec<_>, Error>>()?,
+ expr: serialize_exprs(exprs, codec)?,
})),
},
Expr::GroupingSet(GroupingSet::Rollup(exprs)) =>
protobuf::LogicalExprNode {
expr_type: Some(ExprType::Rollup(RollupNode {
- expr: exprs
- .iter()
- .map(|expr| serialize_expr(expr, codec))
- .collect::<Result<Vec<_>, Error>>()?,
+ expr: serialize_exprs(exprs, codec)?,
})),
},
Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => {
@@ -1098,10 +1062,7 @@ pub fn serialize_expr(
.iter()
.map(|expr_list| {
Ok(LogicalExprList {
- expr: expr_list
- .iter()
- .map(|expr| serialize_expr(expr, codec))
- .collect::<Result<Vec<_>, Error>>()?,
+ expr: serialize_exprs(expr_list, codec)?,
})
})
.collect::<Result<Vec<_>, Error>>()?,
@@ -1680,3 +1641,16 @@ fn encode_scalar_nested_value(
_ => unreachable!(),
}
}
+
+/// Converts a vector of `Arc<arrow::Field>`s to `protobuf::Field`s
+fn convert_arc_fields_to_proto_fields<'a, I>(
+ fields: I,
+) -> Result<Vec<protobuf::Field>, Error>
+where
+ I: IntoIterator<Item = &'a Arc<Field>>,
+{
+ fields
+ .into_iter()
+ .map(|field| field.as_ref().try_into())
+ .collect::<Result<Vec<_>, Error>>()
+}
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs
b/datafusion/proto/src/physical_plan/from_proto.rs
index ca54d4e803..aaca4dc482 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -83,10 +83,10 @@ pub fn parse_physical_sort_expr(
proto: &protobuf::PhysicalSortExprNode,
registry: &dyn FunctionRegistry,
input_schema: &Schema,
+ codec: &dyn PhysicalExtensionCodec,
) -> Result<PhysicalSortExpr> {
if let Some(expr) = &proto.expr {
- let codec = DefaultPhysicalExtensionCodec {};
- let expr = parse_physical_expr(expr.as_ref(), registry, input_schema,
&codec)?;
+ let expr = parse_physical_expr(expr.as_ref(), registry, input_schema,
codec)?;
let options = SortOptions {
descending: !proto.asc,
nulls_first: proto.nulls_first,
@@ -109,22 +109,12 @@ pub fn parse_physical_sort_exprs(
proto: &[protobuf::PhysicalSortExprNode],
registry: &dyn FunctionRegistry,
input_schema: &Schema,
+ codec: &dyn PhysicalExtensionCodec,
) -> Result<Vec<PhysicalSortExpr>> {
proto
.iter()
.map(|sort_expr| {
- if let Some(expr) = &sort_expr.expr {
- let codec = DefaultPhysicalExtensionCodec {};
- let expr =
- parse_physical_expr(expr.as_ref(), registry, input_schema,
&codec)?;
- let options = SortOptions {
- descending: !sort_expr.asc,
- nulls_first: sort_expr.nulls_first,
- };
- Ok(PhysicalSortExpr { expr, options })
- } else {
- Err(proto_error("Unexpected empty physical expression"))
- }
+ parse_physical_sort_expr(sort_expr, registry, input_schema, codec)
})
.collect::<Result<Vec<_>>>()
}
@@ -144,23 +134,14 @@ pub fn parse_physical_window_expr(
input_schema: &Schema,
) -> Result<Arc<dyn WindowExpr>> {
let codec = DefaultPhysicalExtensionCodec {};
- let window_node_expr = proto
- .args
- .iter()
- .map(|e| parse_physical_expr(e, registry, input_schema, &codec))
- .collect::<Result<Vec<_>>>()?;
+ let window_node_expr =
+ parse_physical_exprs(&proto.args, registry, input_schema, &codec)?;
- let partition_by = proto
- .partition_by
- .iter()
- .map(|p| parse_physical_expr(p, registry, input_schema, &codec))
- .collect::<Result<Vec<_>>>()?;
+ let partition_by =
+ parse_physical_exprs(&proto.partition_by, registry, input_schema,
&codec)?;
- let order_by = proto
- .order_by
- .iter()
- .map(|o| parse_physical_sort_expr(o, registry, input_schema))
- .collect::<Result<Vec<_>>>()?;
+ let order_by =
+ parse_physical_sort_exprs(&proto.order_by, registry, input_schema,
&codec)?;
let window_frame = proto
.window_frame
@@ -186,6 +167,21 @@ pub fn parse_physical_window_expr(
)
}
+pub fn parse_physical_exprs<'a, I>(
+ protos: I,
+ registry: &dyn FunctionRegistry,
+ input_schema: &Schema,
+ codec: &dyn PhysicalExtensionCodec,
+) -> Result<Vec<Arc<dyn PhysicalExpr>>>
+where
+ I: IntoIterator<Item = &'a protobuf::PhysicalExprNode>,
+{
+ protos
+ .into_iter()
+ .map(|p| parse_physical_expr(p, registry, input_schema, codec))
+ .collect::<Result<Vec<_>>>()
+}
+
/// Parses a physical expression from a protobuf.
///
/// # Arguments
@@ -276,10 +272,7 @@ pub fn parse_physical_expr(
"expr",
input_schema,
)?,
- e.list
- .iter()
- .map(|x| parse_physical_expr(x, registry, input_schema, codec))
- .collect::<Result<Vec<_>, _>>()?,
+ parse_physical_exprs(&e.list, registry, input_schema, codec)?,
&e.negated,
input_schema,
)?,
@@ -339,11 +332,7 @@ pub fn parse_physical_expr(
)
})?;
- let args = e
- .args
- .iter()
- .map(|x| parse_physical_expr(x, registry, input_schema, codec))
- .collect::<Result<Vec<_>, _>>()?;
+ let args = parse_physical_exprs(&e.args, registry, input_schema,
codec)?;
// TODO Do not create new the ExecutionProps
let execution_props = ExecutionProps::new();
@@ -363,11 +352,7 @@ pub fn parse_physical_expr(
let signature = udf.signature();
let scalar_fun_def = ScalarFunctionDefinition::UDF(udf.clone());
- let args = e
- .args
- .iter()
- .map(|x| parse_physical_expr(x, registry, input_schema, codec))
- .collect::<Result<Vec<_>, _>>()?;
+ let args = parse_physical_exprs(&e.args, registry, input_schema,
codec)?;
Arc::new(ScalarFunctionExpr::new(
e.name.as_str(),
@@ -452,11 +437,12 @@ pub fn parse_protobuf_hash_partitioning(
match partitioning {
Some(hash_part) => {
let codec = DefaultPhysicalExtensionCodec {};
- let expr = hash_part
- .hash_expr
- .iter()
- .map(|e| parse_physical_expr(e, registry, input_schema,
&codec))
- .collect::<Result<Vec<Arc<dyn PhysicalExpr>>, _>>()?;
+ let expr = parse_physical_exprs(
+ &hash_part.hash_expr,
+ registry,
+ input_schema,
+ &codec,
+ )?;
Ok(Some(Partitioning::Hash(
expr,
@@ -517,24 +503,12 @@ pub fn parse_protobuf_file_scan_config(
let mut output_ordering = vec![];
for node_collection in &proto.output_ordering {
let codec = DefaultPhysicalExtensionCodec {};
- let sort_expr = node_collection
- .physical_sort_expr_nodes
- .iter()
- .map(|node| {
- let expr = node
- .expr
- .as_ref()
- .map(|e| parse_physical_expr(e.as_ref(), registry,
&schema, &codec))
- .unwrap()?;
- Ok(PhysicalSortExpr {
- expr,
- options: SortOptions {
- descending: !node.asc,
- nulls_first: node.nulls_first,
- },
- })
- })
- .collect::<Result<Vec<PhysicalSortExpr>>>()?;
+ let sort_expr = parse_physical_sort_exprs(
+ &node_collection.physical_sort_expr_nodes,
+ registry,
+ &schema,
+ &codec,
+ )?;
output_ordering.push(sort_expr);
}
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index da31c5e762..00dacffe06 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -48,7 +48,7 @@ use datafusion::datasource::physical_plan::ParquetExec;
use datafusion::datasource::physical_plan::{AvroExec, CsvExec};
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::execution::FunctionRegistry;
-use datafusion::physical_expr::PhysicalExprRef;
+use datafusion::physical_expr::{PhysicalExprRef, PhysicalSortRequirement};
use datafusion::physical_plan::aggregates::{create_aggregate_expr,
AggregateMode};
use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy};
use datafusion::physical_plan::analyze::AnalyzeExec;
@@ -492,7 +492,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
let input_phy_expr: Vec<Arc<dyn PhysicalExpr>>
= agg_node.expr.iter()
.map(|e| parse_physical_expr(e, registry,
&physical_schema, extension_codec).unwrap()).collect();
let ordering_req: Vec<PhysicalSortExpr> =
agg_node.ordering_req.iter()
- .map(|e| parse_physical_sort_expr(e,
registry, &physical_schema).unwrap()).collect();
+ .map(|e| parse_physical_sort_expr(e,
registry, &physical_schema, extension_codec).unwrap()).collect();
agg_node.aggregate_function.as_ref().map(|func| {
match func {
AggregateFunction::AggrFunction(i) => {
@@ -736,6 +736,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
&sym_join.left_sort_exprs,
registry,
&left_schema,
+ extension_codec,
)?;
let left_sort_exprs = if left_sort_exprs.is_empty() {
None
@@ -747,6 +748,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
&sym_join.right_sort_exprs,
registry,
&right_schema,
+ extension_codec,
)?;
let right_sort_exprs = if right_sort_exprs.is_empty() {
None
@@ -1018,14 +1020,13 @@ impl AsExecutionPlan for PhysicalPlanNode {
.sort_order
.as_ref()
.map(|collection| {
- collection
- .physical_sort_expr_nodes
- .iter()
- .map(|proto| {
- parse_physical_sort_expr(proto, registry,
&sink_schema)
- .map(Into::into)
- })
- .collect::<Result<Vec<_>>>()
+ parse_physical_sort_exprs(
+ &collection.physical_sort_expr_nodes,
+ registry,
+ &sink_schema,
+ extension_codec,
+ )
+ .map(|item|
PhysicalSortRequirement::from_sort_exprs(&item))
})
.transpose()?;
Ok(Arc::new(FileSinkExec::new(
@@ -1049,14 +1050,13 @@ impl AsExecutionPlan for PhysicalPlanNode {
.sort_order
.as_ref()
.map(|collection| {
- collection
- .physical_sort_expr_nodes
- .iter()
- .map(|proto| {
- parse_physical_sort_expr(proto, registry,
&sink_schema)
- .map(Into::into)
- })
- .collect::<Result<Vec<_>>>()
+ parse_physical_sort_exprs(
+ &collection.physical_sort_expr_nodes,
+ registry,
+ &sink_schema,
+ extension_codec,
+ )
+ .map(|item|
PhysicalSortRequirement::from_sort_exprs(&item))
})
.transpose()?;
Ok(Arc::new(FileSinkExec::new(
@@ -1080,14 +1080,13 @@ impl AsExecutionPlan for PhysicalPlanNode {
.sort_order
.as_ref()
.map(|collection| {
- collection
- .physical_sort_expr_nodes
- .iter()
- .map(|proto| {
- parse_physical_sort_expr(proto, registry,
&sink_schema)
- .map(Into::into)
- })
- .collect::<Result<Vec<_>>>()
+ parse_physical_sort_exprs(
+ &collection.physical_sort_expr_nodes,
+ registry,
+ &sink_schema,
+ extension_codec,
+ )
+ .map(|item|
PhysicalSortRequirement::from_sort_exprs(&item))
})
.transpose()?;
Ok(Arc::new(FileSinkExec::new(
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs
b/datafusion/proto/src/physical_plan/to_proto.rs
index b66709d0c5..e1574f48fb 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -79,18 +79,10 @@ impl TryFrom<Arc<dyn AggregateExpr>> for
protobuf::PhysicalExprNode {
fn try_from(a: Arc<dyn AggregateExpr>) -> Result<Self, Self::Error> {
let codec = DefaultPhysicalExtensionCodec {};
- let expressions: Vec<protobuf::PhysicalExprNode> = a
- .expressions()
- .iter()
- .map(|e| serialize_physical_expr(e.clone(), &codec))
- .collect::<Result<Vec<_>>>()?;
+ let expressions = serialize_physical_exprs(a.expressions(), &codec)?;
- let ordering_req: Vec<protobuf::PhysicalSortExprNode> = a
- .order_bys()
- .unwrap_or(&[])
- .iter()
- .map(|e| e.clone().try_into())
- .collect::<Result<Vec<_>>>()?;
+ let ordering_req = a.order_bys().unwrap_or(&[]).to_vec();
+ let ordering_req = serialize_physical_sort_exprs(ordering_req,
&codec)?;
if let Some(a) = a.as_any().downcast_ref::<AggregateFunctionExpr>() {
let name = a.fun().name().to_string();
@@ -245,22 +237,12 @@ impl TryFrom<Arc<dyn WindowExpr>> for
protobuf::PhysicalWindowExprNode {
return not_impl_err!("WindowExpr not supported: {window_expr:?}");
};
let codec = DefaultPhysicalExtensionCodec {};
- let args = args
- .into_iter()
- .map(|e| serialize_physical_expr(e, &codec))
- .collect::<Result<Vec<protobuf::PhysicalExprNode>>>()?;
-
- let partition_by = window_expr
- .partition_by()
- .iter()
- .map(|p| serialize_physical_expr(p.clone(), &codec))
- .collect::<Result<Vec<protobuf::PhysicalExprNode>>>()?;
+ let args = serialize_physical_exprs(args, &codec)?;
+ let partition_by =
+ serialize_physical_exprs(window_expr.partition_by().to_vec(),
&codec)?;
- let order_by = window_expr
- .order_by()
- .iter()
- .map(|o| o.clone().try_into())
- .collect::<Result<Vec<protobuf::PhysicalSortExprNode>>>()?;
+ let order_by =
+ serialize_physical_sort_exprs(window_expr.order_by().to_vec(),
&codec)?;
let window_frame: protobuf::WindowFrame = window_frame
.as_ref()
@@ -381,6 +363,45 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) ->
Result<AggrFn> {
Ok(AggrFn { inner, distinct })
}
+pub fn serialize_physical_sort_exprs<I>(
+ sort_exprs: I,
+ codec: &dyn PhysicalExtensionCodec,
+) -> Result<Vec<protobuf::PhysicalSortExprNode>, DataFusionError>
+where
+ I: IntoIterator<Item = PhysicalSortExpr>,
+{
+ sort_exprs
+ .into_iter()
+ .map(|sort_expr| serialize_physical_sort_expr(sort_expr, codec))
+ .collect()
+}
+
+pub fn serialize_physical_sort_expr(
+ sort_expr: PhysicalSortExpr,
+ codec: &dyn PhysicalExtensionCodec,
+) -> Result<protobuf::PhysicalSortExprNode, DataFusionError> {
+ let PhysicalSortExpr { expr, options } = sort_expr;
+ let expr = serialize_physical_expr(expr, codec)?;
+ Ok(PhysicalSortExprNode {
+ expr: Some(Box::new(expr)),
+ asc: !options.descending,
+ nulls_first: options.nulls_first,
+ })
+}
+
+pub fn serialize_physical_exprs<I>(
+ values: I,
+ codec: &dyn PhysicalExtensionCodec,
+) -> Result<Vec<protobuf::PhysicalExprNode>, DataFusionError>
+where
+ I: IntoIterator<Item = Arc<dyn PhysicalExpr>>,
+{
+ values
+ .into_iter()
+ .map(|value| serialize_physical_expr(value, codec))
+ .collect()
+}
+
/// Serialize a `PhysicalExpr` to default protobuf representation.
///
/// If required, a [`PhysicalExtensionCodec`] can be provided which can handle
@@ -488,27 +509,16 @@ pub fn serialize_physical_expr(
})
} else if let Some(expr) = expr.downcast_ref::<InListExpr>() {
Ok(protobuf::PhysicalExprNode {
- expr_type: Some(
- protobuf::physical_expr_node::ExprType::InList(
- Box::new(
- protobuf::PhysicalInListNode {
- expr: Some(Box::new(serialize_physical_expr(
- expr.expr().to_owned(),
- codec,
- )?)),
- list: expr
- .list()
- .iter()
- .map(|a| serialize_physical_expr(a.clone(),
codec))
- .collect::<Result<
- Vec<protobuf::PhysicalExprNode>,
- DataFusionError,
- >>()?,
- negated: expr.negated(),
- },
- ),
- ),
- ),
+ expr_type:
Some(protobuf::physical_expr_node::ExprType::InList(Box::new(
+ protobuf::PhysicalInListNode {
+ expr: Some(Box::new(serialize_physical_expr(
+ expr.expr().to_owned(),
+ codec,
+ )?)),
+ list: serialize_physical_exprs(expr.list().to_vec(),
codec)?,
+ negated: expr.negated(),
+ },
+ ))),
})
} else if let Some(expr) = expr.downcast_ref::<NegativeExpr>() {
Ok(protobuf::PhysicalExprNode {
@@ -552,11 +562,7 @@ pub fn serialize_physical_expr(
))),
})
} else if let Some(expr) = expr.downcast_ref::<ScalarFunctionExpr>() {
- let args: Vec<protobuf::PhysicalExprNode> = expr
- .args()
- .iter()
- .map(|e| serialize_physical_expr(e.to_owned(), codec))
- .collect::<Result<Vec<_>, _>>()?;
+ let args = serialize_physical_exprs(expr.args().to_vec(), codec)?;
if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) {
let fun: protobuf::ScalarFunction = (&fun).try_into()?;
@@ -754,18 +760,8 @@ impl TryFrom<&FileScanConfig> for
protobuf::FileScanExecConf {
let mut output_orderings = vec![];
for order in &conf.output_ordering {
- let expr_node_vec = order
- .iter()
- .map(|sort_expr| {
- let expr = serialize_physical_expr(sort_expr.expr.clone(),
&codec)?;
- Ok(PhysicalSortExprNode {
- expr: Some(Box::new(expr)),
- asc: !sort_expr.options.descending,
- nulls_first: sort_expr.options.nulls_first,
- })
- })
- .collect::<Result<Vec<PhysicalSortExprNode>>>()?;
- output_orderings.push(expr_node_vec)
+ let ordering = serialize_physical_sort_exprs(order.to_vec(),
&codec)?;
+ output_orderings.push(ordering)
}
// Fields must be added to the schema so that they can persist in the
protobuf