This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 02326998f0 Add extension hooks for encoding and decoding UDAFs and
UDWFs (#11417)
02326998f0 is described below
commit 02326998f07a13fda0c93988bf13853413c4a2b2
Author: Georgi Krastev <[email protected]>
AuthorDate: Wed Jul 17 00:52:20 2024 +0300
Add extension hooks for encoding and decoding UDAFs and UDWFs (#11417)
* Add extension hooks for encoding and decoding UDAFs and UDWFs
* Add tests for encoding and decoding UDAF
---
.../examples/composed_extension_codec.rs | 80 +++----
.../physical-expr-common/src/aggregate/mod.rs | 5 +
datafusion/proto/proto/datafusion.proto | 35 +--
datafusion/proto/src/generated/pbjson.rs | 102 +++++++++
datafusion/proto/src/generated/prost.rs | 10 +
datafusion/proto/src/logical_plan/file_formats.rs | 80 -------
datafusion/proto/src/logical_plan/from_proto.rs | 42 ++--
datafusion/proto/src/logical_plan/mod.rs | 22 +-
datafusion/proto/src/logical_plan/to_proto.rs | 84 ++++---
datafusion/proto/src/physical_plan/from_proto.rs | 6 +-
datafusion/proto/src/physical_plan/mod.rs | 23 +-
datafusion/proto/src/physical_plan/to_proto.rs | 122 +++++-----
datafusion/proto/tests/cases/mod.rs | 99 ++++++++
.../proto/tests/cases/roundtrip_logical_plan.rs | 171 ++++++--------
.../proto/tests/cases/roundtrip_physical_plan.rs | 251 +++++++++++++--------
15 files changed, 686 insertions(+), 446 deletions(-)
diff --git a/datafusion-examples/examples/composed_extension_codec.rs
b/datafusion-examples/examples/composed_extension_codec.rs
index 43c6daba21..5c34eccf26 100644
--- a/datafusion-examples/examples/composed_extension_codec.rs
+++ b/datafusion-examples/examples/composed_extension_codec.rs
@@ -30,18 +30,19 @@
//! DeltaScan
//! ```
+use std::any::Any;
+use std::fmt::Debug;
+use std::ops::Deref;
+use std::sync::Arc;
+
use datafusion::common::Result;
use datafusion::physical_plan::{DisplayAs, ExecutionPlan};
use datafusion::prelude::SessionContext;
-use datafusion_common::internal_err;
+use datafusion_common::{internal_err, DataFusionError};
use datafusion_expr::registry::FunctionRegistry;
-use datafusion_expr::ScalarUDF;
+use datafusion_expr::{AggregateUDF, ScalarUDF};
use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
use datafusion_proto::protobuf;
-use std::any::Any;
-use std::fmt::Debug;
-use std::ops::Deref;
-use std::sync::Arc;
#[tokio::main]
async fn main() {
@@ -239,6 +240,25 @@ struct ComposedPhysicalExtensionCodec {
codecs: Vec<Arc<dyn PhysicalExtensionCodec>>,
}
+impl ComposedPhysicalExtensionCodec {
+ fn try_any<T>(
+ &self,
+ mut f: impl FnMut(&dyn PhysicalExtensionCodec) -> Result<T>,
+ ) -> Result<T> {
+ let mut last_err = None;
+ for codec in &self.codecs {
+ match f(codec.as_ref()) {
+ Ok(node) => return Ok(node),
+ Err(err) => last_err = Some(err),
+ }
+ }
+
+ Err(last_err.unwrap_or_else(|| {
+ DataFusionError::NotImplemented("Empty list of composed
codecs".to_owned())
+ }))
+ }
+}
+
impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec {
fn try_decode(
&self,
@@ -246,46 +266,26 @@ impl PhysicalExtensionCodec for
ComposedPhysicalExtensionCodec {
inputs: &[Arc<dyn ExecutionPlan>],
registry: &dyn FunctionRegistry,
) -> Result<Arc<dyn ExecutionPlan>> {
- let mut last_err = None;
- for codec in &self.codecs {
- match codec.try_decode(buf, inputs, registry) {
- Ok(plan) => return Ok(plan),
- Err(e) => last_err = Some(e),
- }
- }
- Err(last_err.unwrap())
+ self.try_any(|codec| codec.try_decode(buf, inputs, registry))
}
fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) ->
Result<()> {
- let mut last_err = None;
- for codec in &self.codecs {
- match codec.try_encode(node.clone(), buf) {
- Ok(_) => return Ok(()),
- Err(e) => last_err = Some(e),
- }
- }
- Err(last_err.unwrap())
+ self.try_any(|codec| codec.try_encode(node.clone(), buf))
}
- fn try_decode_udf(&self, name: &str, _buf: &[u8]) ->
Result<Arc<ScalarUDF>> {
- let mut last_err = None;
- for codec in &self.codecs {
- match codec.try_decode_udf(name, _buf) {
- Ok(plan) => return Ok(plan),
- Err(e) => last_err = Some(e),
- }
- }
- Err(last_err.unwrap())
+ fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>>
{
+ self.try_any(|codec| codec.try_decode_udf(name, buf))
}
- fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec<u8>) ->
Result<()> {
- let mut last_err = None;
- for codec in &self.codecs {
- match codec.try_encode_udf(_node, _buf) {
- Ok(_) => return Ok(()),
- Err(e) => last_err = Some(e),
- }
- }
- Err(last_err.unwrap())
+ fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) ->
Result<()> {
+ self.try_any(|codec| codec.try_encode_udf(node, buf))
+ }
+
+ fn try_decode_udaf(&self, name: &str, buf: &[u8]) ->
Result<Arc<AggregateUDF>> {
+ self.try_any(|codec| codec.try_decode_udaf(name, buf))
+ }
+
+ fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) ->
Result<()> {
+ self.try_any(|codec| codec.try_encode_udaf(node, buf))
}
}
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs
b/datafusion/physical-expr-common/src/aggregate/mod.rs
index db4581a622..0e245fd0a6 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -283,6 +283,11 @@ impl AggregateFunctionExpr {
pub fn is_distinct(&self) -> bool {
self.is_distinct
}
+
+ /// Return if the aggregation ignores nulls
+ pub fn ignore_nulls(&self) -> bool {
+ self.ignore_nulls
+ }
}
impl AggregateExpr for AggregateFunctionExpr {
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 9ef884531e..dc551778c5 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -164,7 +164,7 @@ message CreateExternalTableNode {
map<string, string> options = 8;
datafusion_common.Constraints constraints = 12;
map<string, LogicalExprNode> column_defaults = 13;
- }
+}
message PrepareNode {
string name = 1;
@@ -249,24 +249,24 @@ message DistinctOnNode {
}
message CopyToNode {
- LogicalPlanNode input = 1;
- string output_url = 2;
- bytes file_type = 3;
- repeated string partition_by = 7;
+ LogicalPlanNode input = 1;
+ string output_url = 2;
+ bytes file_type = 3;
+ repeated string partition_by = 7;
}
message UnnestNode {
- LogicalPlanNode input = 1;
- repeated datafusion_common.Column exec_columns = 2;
- repeated uint64 list_type_columns = 3;
- repeated uint64 struct_type_columns = 4;
- repeated uint64 dependency_indices = 5;
- datafusion_common.DfSchema schema = 6;
- UnnestOptions options = 7;
+ LogicalPlanNode input = 1;
+ repeated datafusion_common.Column exec_columns = 2;
+ repeated uint64 list_type_columns = 3;
+ repeated uint64 struct_type_columns = 4;
+ repeated uint64 dependency_indices = 5;
+ datafusion_common.DfSchema schema = 6;
+ UnnestOptions options = 7;
}
message UnnestOptions {
- bool preserve_nulls = 1;
+ bool preserve_nulls = 1;
}
message UnionNode {
@@ -488,8 +488,8 @@ enum AggregateFunction {
// BIT_AND = 19;
// BIT_OR = 20;
// BIT_XOR = 21;
-// BOOL_AND = 22;
-// BOOL_OR = 23;
+ // BOOL_AND = 22;
+ // BOOL_OR = 23;
// REGR_SLOPE = 26;
// REGR_INTERCEPT = 27;
// REGR_COUNT = 28;
@@ -517,6 +517,7 @@ message AggregateUDFExprNode {
bool distinct = 5;
LogicalExprNode filter = 3;
repeated LogicalExprNode order_by = 4;
+ optional bytes fun_definition = 6;
}
message ScalarUDFExprNode {
@@ -551,6 +552,7 @@ message WindowExprNode {
repeated LogicalExprNode order_by = 6;
// repeated LogicalExprNode filter = 7;
WindowFrame window_frame = 8;
+ optional bytes fun_definition = 10;
}
message BetweenNode {
@@ -856,6 +858,8 @@ message PhysicalAggregateExprNode {
repeated PhysicalExprNode expr = 2;
repeated PhysicalSortExprNode ordering_req = 5;
bool distinct = 3;
+ bool ignore_nulls = 6;
+ optional bytes fun_definition = 7;
}
message PhysicalWindowExprNode {
@@ -869,6 +873,7 @@ message PhysicalWindowExprNode {
repeated PhysicalSortExprNode order_by = 6;
WindowFrame window_frame = 7;
string name = 8;
+ optional bytes fun_definition = 9;
}
message PhysicalIsNull {
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index fa989480fa..8f77c24bd9 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -829,6 +829,9 @@ impl serde::Serialize for AggregateUdfExprNode {
if !self.order_by.is_empty() {
len += 1;
}
+ if self.fun_definition.is_some() {
+ len += 1;
+ }
let mut struct_ser =
serializer.serialize_struct("datafusion.AggregateUDFExprNode", len)?;
if !self.fun_name.is_empty() {
struct_ser.serialize_field("funName", &self.fun_name)?;
@@ -845,6 +848,10 @@ impl serde::Serialize for AggregateUdfExprNode {
if !self.order_by.is_empty() {
struct_ser.serialize_field("orderBy", &self.order_by)?;
}
+ if let Some(v) = self.fun_definition.as_ref() {
+ #[allow(clippy::needless_borrow)]
+ struct_ser.serialize_field("funDefinition",
pbjson::private::base64::encode(&v).as_str())?;
+ }
struct_ser.end()
}
}
@@ -862,6 +869,8 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
"filter",
"order_by",
"orderBy",
+ "fun_definition",
+ "funDefinition",
];
#[allow(clippy::enum_variant_names)]
@@ -871,6 +880,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
Distinct,
Filter,
OrderBy,
+ FunDefinition,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -897,6 +907,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
"distinct" => Ok(GeneratedField::Distinct),
"filter" => Ok(GeneratedField::Filter),
"orderBy" | "order_by" =>
Ok(GeneratedField::OrderBy),
+ "funDefinition" | "fun_definition" =>
Ok(GeneratedField::FunDefinition),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -921,6 +932,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
let mut distinct__ = None;
let mut filter__ = None;
let mut order_by__ = None;
+ let mut fun_definition__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::FunName => {
@@ -953,6 +965,14 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode
{
}
order_by__ = Some(map_.next_value()?);
}
+ GeneratedField::FunDefinition => {
+ if fun_definition__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("funDefinition"));
+ }
+ fun_definition__ =
+
map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x|
x.0)
+ ;
+ }
}
}
Ok(AggregateUdfExprNode {
@@ -961,6 +981,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
distinct: distinct__.unwrap_or_default(),
filter: filter__,
order_by: order_by__.unwrap_or_default(),
+ fun_definition: fun_definition__,
})
}
}
@@ -12631,6 +12652,12 @@ impl serde::Serialize for PhysicalAggregateExprNode {
if self.distinct {
len += 1;
}
+ if self.ignore_nulls {
+ len += 1;
+ }
+ if self.fun_definition.is_some() {
+ len += 1;
+ }
if self.aggregate_function.is_some() {
len += 1;
}
@@ -12644,6 +12671,13 @@ impl serde::Serialize for PhysicalAggregateExprNode {
if self.distinct {
struct_ser.serialize_field("distinct", &self.distinct)?;
}
+ if self.ignore_nulls {
+ struct_ser.serialize_field("ignoreNulls", &self.ignore_nulls)?;
+ }
+ if let Some(v) = self.fun_definition.as_ref() {
+ #[allow(clippy::needless_borrow)]
+ struct_ser.serialize_field("funDefinition",
pbjson::private::base64::encode(&v).as_str())?;
+ }
if let Some(v) = self.aggregate_function.as_ref() {
match v {
physical_aggregate_expr_node::AggregateFunction::AggrFunction(v) => {
@@ -12670,6 +12704,10 @@ impl<'de> serde::Deserialize<'de> for
PhysicalAggregateExprNode {
"ordering_req",
"orderingReq",
"distinct",
+ "ignore_nulls",
+ "ignoreNulls",
+ "fun_definition",
+ "funDefinition",
"aggr_function",
"aggrFunction",
"user_defined_aggr_function",
@@ -12681,6 +12719,8 @@ impl<'de> serde::Deserialize<'de> for
PhysicalAggregateExprNode {
Expr,
OrderingReq,
Distinct,
+ IgnoreNulls,
+ FunDefinition,
AggrFunction,
UserDefinedAggrFunction,
}
@@ -12707,6 +12747,8 @@ impl<'de> serde::Deserialize<'de> for
PhysicalAggregateExprNode {
"expr" => Ok(GeneratedField::Expr),
"orderingReq" | "ordering_req" =>
Ok(GeneratedField::OrderingReq),
"distinct" => Ok(GeneratedField::Distinct),
+ "ignoreNulls" | "ignore_nulls" =>
Ok(GeneratedField::IgnoreNulls),
+ "funDefinition" | "fun_definition" =>
Ok(GeneratedField::FunDefinition),
"aggrFunction" | "aggr_function" =>
Ok(GeneratedField::AggrFunction),
"userDefinedAggrFunction" |
"user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
@@ -12731,6 +12773,8 @@ impl<'de> serde::Deserialize<'de> for
PhysicalAggregateExprNode {
let mut expr__ = None;
let mut ordering_req__ = None;
let mut distinct__ = None;
+ let mut ignore_nulls__ = None;
+ let mut fun_definition__ = None;
let mut aggregate_function__ = None;
while let Some(k) = map_.next_key()? {
match k {
@@ -12752,6 +12796,20 @@ impl<'de> serde::Deserialize<'de> for
PhysicalAggregateExprNode {
}
distinct__ = Some(map_.next_value()?);
}
+ GeneratedField::IgnoreNulls => {
+ if ignore_nulls__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("ignoreNulls"));
+ }
+ ignore_nulls__ = Some(map_.next_value()?);
+ }
+ GeneratedField::FunDefinition => {
+ if fun_definition__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("funDefinition"));
+ }
+ fun_definition__ =
+
map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x|
x.0)
+ ;
+ }
GeneratedField::AggrFunction => {
if aggregate_function__.is_some() {
return
Err(serde::de::Error::duplicate_field("aggrFunction"));
@@ -12770,6 +12828,8 @@ impl<'de> serde::Deserialize<'de> for
PhysicalAggregateExprNode {
expr: expr__.unwrap_or_default(),
ordering_req: ordering_req__.unwrap_or_default(),
distinct: distinct__.unwrap_or_default(),
+ ignore_nulls: ignore_nulls__.unwrap_or_default(),
+ fun_definition: fun_definition__,
aggregate_function: aggregate_function__,
})
}
@@ -15832,6 +15892,9 @@ impl serde::Serialize for PhysicalWindowExprNode {
if !self.name.is_empty() {
len += 1;
}
+ if self.fun_definition.is_some() {
+ len += 1;
+ }
if self.window_function.is_some() {
len += 1;
}
@@ -15851,6 +15914,10 @@ impl serde::Serialize for PhysicalWindowExprNode {
if !self.name.is_empty() {
struct_ser.serialize_field("name", &self.name)?;
}
+ if let Some(v) = self.fun_definition.as_ref() {
+ #[allow(clippy::needless_borrow)]
+ struct_ser.serialize_field("funDefinition",
pbjson::private::base64::encode(&v).as_str())?;
+ }
if let Some(v) = self.window_function.as_ref() {
match v {
physical_window_expr_node::WindowFunction::AggrFunction(v) => {
@@ -15886,6 +15953,8 @@ impl<'de> serde::Deserialize<'de> for
PhysicalWindowExprNode {
"window_frame",
"windowFrame",
"name",
+ "fun_definition",
+ "funDefinition",
"aggr_function",
"aggrFunction",
"built_in_function",
@@ -15901,6 +15970,7 @@ impl<'de> serde::Deserialize<'de> for
PhysicalWindowExprNode {
OrderBy,
WindowFrame,
Name,
+ FunDefinition,
AggrFunction,
BuiltInFunction,
UserDefinedAggrFunction,
@@ -15930,6 +16000,7 @@ impl<'de> serde::Deserialize<'de> for
PhysicalWindowExprNode {
"orderBy" | "order_by" =>
Ok(GeneratedField::OrderBy),
"windowFrame" | "window_frame" =>
Ok(GeneratedField::WindowFrame),
"name" => Ok(GeneratedField::Name),
+ "funDefinition" | "fun_definition" =>
Ok(GeneratedField::FunDefinition),
"aggrFunction" | "aggr_function" =>
Ok(GeneratedField::AggrFunction),
"builtInFunction" | "built_in_function" =>
Ok(GeneratedField::BuiltInFunction),
"userDefinedAggrFunction" |
"user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction),
@@ -15957,6 +16028,7 @@ impl<'de> serde::Deserialize<'de> for
PhysicalWindowExprNode {
let mut order_by__ = None;
let mut window_frame__ = None;
let mut name__ = None;
+ let mut fun_definition__ = None;
let mut window_function__ = None;
while let Some(k) = map_.next_key()? {
match k {
@@ -15990,6 +16062,14 @@ impl<'de> serde::Deserialize<'de> for
PhysicalWindowExprNode {
}
name__ = Some(map_.next_value()?);
}
+ GeneratedField::FunDefinition => {
+ if fun_definition__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("funDefinition"));
+ }
+ fun_definition__ =
+
map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x|
x.0)
+ ;
+ }
GeneratedField::AggrFunction => {
if window_function__.is_some() {
return
Err(serde::de::Error::duplicate_field("aggrFunction"));
@@ -16016,6 +16096,7 @@ impl<'de> serde::Deserialize<'de> for
PhysicalWindowExprNode {
order_by: order_by__.unwrap_or_default(),
window_frame: window_frame__,
name: name__.unwrap_or_default(),
+ fun_definition: fun_definition__,
window_function: window_function__,
})
}
@@ -20349,6 +20430,9 @@ impl serde::Serialize for WindowExprNode {
if self.window_frame.is_some() {
len += 1;
}
+ if self.fun_definition.is_some() {
+ len += 1;
+ }
if self.window_function.is_some() {
len += 1;
}
@@ -20365,6 +20449,10 @@ impl serde::Serialize for WindowExprNode {
if let Some(v) = self.window_frame.as_ref() {
struct_ser.serialize_field("windowFrame", v)?;
}
+ if let Some(v) = self.fun_definition.as_ref() {
+ #[allow(clippy::needless_borrow)]
+ struct_ser.serialize_field("funDefinition",
pbjson::private::base64::encode(&v).as_str())?;
+ }
if let Some(v) = self.window_function.as_ref() {
match v {
window_expr_node::WindowFunction::AggrFunction(v) => {
@@ -20402,6 +20490,8 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode {
"orderBy",
"window_frame",
"windowFrame",
+ "fun_definition",
+ "funDefinition",
"aggr_function",
"aggrFunction",
"built_in_function",
@@ -20416,6 +20506,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode {
PartitionBy,
OrderBy,
WindowFrame,
+ FunDefinition,
AggrFunction,
BuiltInFunction,
Udaf,
@@ -20445,6 +20536,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode {
"partitionBy" | "partition_by" =>
Ok(GeneratedField::PartitionBy),
"orderBy" | "order_by" =>
Ok(GeneratedField::OrderBy),
"windowFrame" | "window_frame" =>
Ok(GeneratedField::WindowFrame),
+ "funDefinition" | "fun_definition" =>
Ok(GeneratedField::FunDefinition),
"aggrFunction" | "aggr_function" =>
Ok(GeneratedField::AggrFunction),
"builtInFunction" | "built_in_function" =>
Ok(GeneratedField::BuiltInFunction),
"udaf" => Ok(GeneratedField::Udaf),
@@ -20472,6 +20564,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode {
let mut partition_by__ = None;
let mut order_by__ = None;
let mut window_frame__ = None;
+ let mut fun_definition__ = None;
let mut window_function__ = None;
while let Some(k) = map_.next_key()? {
match k {
@@ -20499,6 +20592,14 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode {
}
window_frame__ = map_.next_value()?;
}
+ GeneratedField::FunDefinition => {
+ if fun_definition__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("funDefinition"));
+ }
+ fun_definition__ =
+
map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x|
x.0)
+ ;
+ }
GeneratedField::AggrFunction => {
if window_function__.is_some() {
return
Err(serde::de::Error::duplicate_field("aggrFunction"));
@@ -20530,6 +20631,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode {
partition_by: partition_by__.unwrap_or_default(),
order_by: order_by__.unwrap_or_default(),
window_frame: window_frame__,
+ fun_definition: fun_definition__,
window_function: window_function__,
})
}
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 8407e545fe..605c56fa94 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -756,6 +756,8 @@ pub struct AggregateUdfExprNode {
pub filter:
::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
#[prost(message, repeated, tag = "4")]
pub order_by: ::prost::alloc::vec::Vec<LogicalExprNode>,
+ #[prost(bytes = "vec", optional, tag = "6")]
+ pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec<u8>>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
@@ -779,6 +781,8 @@ pub struct WindowExprNode {
/// repeated LogicalExprNode filter = 7;
#[prost(message, optional, tag = "8")]
pub window_frame: ::core::option::Option<WindowFrame>,
+ #[prost(bytes = "vec", optional, tag = "10")]
+ pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec<u8>>,
#[prost(oneof = "window_expr_node::WindowFunction", tags = "1, 2, 3, 9")]
pub window_function:
::core::option::Option<window_expr_node::WindowFunction>,
}
@@ -1291,6 +1295,10 @@ pub struct PhysicalAggregateExprNode {
pub ordering_req: ::prost::alloc::vec::Vec<PhysicalSortExprNode>,
#[prost(bool, tag = "3")]
pub distinct: bool,
+ #[prost(bool, tag = "6")]
+ pub ignore_nulls: bool,
+ #[prost(bytes = "vec", optional, tag = "7")]
+ pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec<u8>>,
#[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags =
"1, 4")]
pub aggregate_function: ::core::option::Option<
physical_aggregate_expr_node::AggregateFunction,
@@ -1320,6 +1328,8 @@ pub struct PhysicalWindowExprNode {
pub window_frame: ::core::option::Option<WindowFrame>,
#[prost(string, tag = "8")]
pub name: ::prost::alloc::string::String,
+ #[prost(bytes = "vec", optional, tag = "9")]
+ pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec<u8>>,
#[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "1, 2,
3")]
pub window_function: ::core::option::Option<
physical_window_expr_node::WindowFunction,
diff --git a/datafusion/proto/src/logical_plan/file_formats.rs
b/datafusion/proto/src/logical_plan/file_formats.rs
index 106d563948..09e36a650b 100644
--- a/datafusion/proto/src/logical_plan/file_formats.rs
+++ b/datafusion/proto/src/logical_plan/file_formats.rs
@@ -86,22 +86,6 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec {
) -> datafusion_common::Result<()> {
Ok(())
}
-
- fn try_decode_udf(
- &self,
- name: &str,
- __buf: &[u8],
- ) -> datafusion_common::Result<Arc<datafusion_expr::ScalarUDF>> {
- not_impl_err!("LogicalExtensionCodec is not provided for scalar
function {name}")
- }
-
- fn try_encode_udf(
- &self,
- __node: &datafusion_expr::ScalarUDF,
- __buf: &mut Vec<u8>,
- ) -> datafusion_common::Result<()> {
- Ok(())
- }
}
#[derive(Debug)]
@@ -162,22 +146,6 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec {
) -> datafusion_common::Result<()> {
Ok(())
}
-
- fn try_decode_udf(
- &self,
- name: &str,
- __buf: &[u8],
- ) -> datafusion_common::Result<Arc<datafusion_expr::ScalarUDF>> {
- not_impl_err!("LogicalExtensionCodec is not provided for scalar
function {name}")
- }
-
- fn try_encode_udf(
- &self,
- __node: &datafusion_expr::ScalarUDF,
- __buf: &mut Vec<u8>,
- ) -> datafusion_common::Result<()> {
- Ok(())
- }
}
#[derive(Debug)]
@@ -238,22 +206,6 @@ impl LogicalExtensionCodec for
ParquetLogicalExtensionCodec {
) -> datafusion_common::Result<()> {
Ok(())
}
-
- fn try_decode_udf(
- &self,
- name: &str,
- __buf: &[u8],
- ) -> datafusion_common::Result<Arc<datafusion_expr::ScalarUDF>> {
- not_impl_err!("LogicalExtensionCodec is not provided for scalar
function {name}")
- }
-
- fn try_encode_udf(
- &self,
- __node: &datafusion_expr::ScalarUDF,
- __buf: &mut Vec<u8>,
- ) -> datafusion_common::Result<()> {
- Ok(())
- }
}
#[derive(Debug)]
@@ -314,22 +266,6 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec {
) -> datafusion_common::Result<()> {
Ok(())
}
-
- fn try_decode_udf(
- &self,
- name: &str,
- __buf: &[u8],
- ) -> datafusion_common::Result<Arc<datafusion_expr::ScalarUDF>> {
- not_impl_err!("LogicalExtensionCodec is not provided for scalar
function {name}")
- }
-
- fn try_encode_udf(
- &self,
- __node: &datafusion_expr::ScalarUDF,
- __buf: &mut Vec<u8>,
- ) -> datafusion_common::Result<()> {
- Ok(())
- }
}
#[derive(Debug)]
@@ -390,20 +326,4 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec {
) -> datafusion_common::Result<()> {
Ok(())
}
-
- fn try_decode_udf(
- &self,
- name: &str,
- __buf: &[u8],
- ) -> datafusion_common::Result<Arc<datafusion_expr::ScalarUDF>> {
- not_impl_err!("LogicalExtensionCodec is not provided for scalar
function {name}")
- }
-
- fn try_encode_udf(
- &self,
- __node: &datafusion_expr::ScalarUDF,
- __buf: &mut Vec<u8>,
- ) -> datafusion_common::Result<()> {
- Ok(())
- }
}
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 095c6a5097..b6b556a8ed 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -308,14 +308,17 @@ pub fn parse_expr(
let aggr_function = parse_i32_to_aggregate_function(i)?;
Ok(Expr::WindowFunction(WindowFunction::new(
-
datafusion_expr::expr::WindowFunctionDefinition::AggregateFunction(
- aggr_function,
- ),
- vec![parse_required_expr(expr.expr.as_deref(),
registry, "expr", codec)?],
+
expr::WindowFunctionDefinition::AggregateFunction(aggr_function),
+ vec![parse_required_expr(
+ expr.expr.as_deref(),
+ registry,
+ "expr",
+ codec,
+ )?],
partition_by,
order_by,
window_frame,
- None
+ None,
)))
}
window_expr_node::WindowFunction::BuiltInFunction(i) => {
@@ -329,26 +332,28 @@ pub fn parse_expr(
.unwrap_or_else(Vec::new);
Ok(Expr::WindowFunction(WindowFunction::new(
-
datafusion_expr::expr::WindowFunctionDefinition::BuiltInWindowFunction(
+ expr::WindowFunctionDefinition::BuiltInWindowFunction(
built_in_function,
),
args,
partition_by,
order_by,
window_frame,
- null_treatment
+ null_treatment,
)))
}
window_expr_node::WindowFunction::Udaf(udaf_name) => {
- let udaf_function = registry.udaf(udaf_name)?;
+ let udaf_function = match &expr.fun_definition {
+ Some(buf) => codec.try_decode_udaf(udaf_name, buf)?,
+ None => registry.udaf(udaf_name)?,
+ };
+
let args =
parse_optional_expr(expr.expr.as_deref(), registry,
codec)?
.map(|e| vec![e])
.unwrap_or_else(Vec::new);
Ok(Expr::WindowFunction(WindowFunction::new(
-
datafusion_expr::expr::WindowFunctionDefinition::AggregateUDF(
- udaf_function,
- ),
+
expr::WindowFunctionDefinition::AggregateUDF(udaf_function),
args,
partition_by,
order_by,
@@ -357,15 +362,17 @@ pub fn parse_expr(
)))
}
window_expr_node::WindowFunction::Udwf(udwf_name) => {
- let udwf_function = registry.udwf(udwf_name)?;
+ let udwf_function = match &expr.fun_definition {
+ Some(buf) => codec.try_decode_udwf(udwf_name, buf)?,
+ None => registry.udwf(udwf_name)?,
+ };
+
let args =
parse_optional_expr(expr.expr.as_deref(), registry,
codec)?
.map(|e| vec![e])
.unwrap_or_else(Vec::new);
Ok(Expr::WindowFunction(WindowFunction::new(
-
datafusion_expr::expr::WindowFunctionDefinition::WindowUDF(
- udwf_function,
- ),
+
expr::WindowFunctionDefinition::WindowUDF(udwf_function),
args,
partition_by,
order_by,
@@ -613,7 +620,10 @@ pub fn parse_expr(
)))
}
ExprType::AggregateUdfExpr(pb) => {
- let agg_fn = registry.udaf(pb.fun_name.as_str())?;
+ let agg_fn = match &pb.fun_definition {
+ Some(buf) => codec.try_decode_udaf(&pb.fun_name, buf)?,
+ None => registry.udaf(&pb.fun_name)?,
+ };
Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
agg_fn,
diff --git a/datafusion/proto/src/logical_plan/mod.rs
b/datafusion/proto/src/logical_plan/mod.rs
index 664cd7e115..2a963fb13c 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -51,7 +51,6 @@ use datafusion_common::{
context, internal_datafusion_err, internal_err, not_impl_err,
DataFusionError,
Result, TableReference,
};
-use datafusion_expr::Unnest;
use datafusion_expr::{
dml,
logical_plan::{
@@ -60,8 +59,9 @@ use datafusion_expr::{
EmptyRelation, Extension, Join, JoinConstraint, Limit, Prepare,
Projection,
Repartition, Sort, SubqueryAlias, TableScan, Values, Window,
},
- DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF,
+ DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF,
WindowUDF,
};
+use datafusion_expr::{AggregateUDF, Unnest};
use prost::bytes::BufMut;
use prost::Message;
@@ -144,6 +144,24 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync {
fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec<u8>) ->
Result<()> {
Ok(())
}
+
+ fn try_decode_udaf(&self, name: &str, _buf: &[u8]) ->
Result<Arc<AggregateUDF>> {
+ not_impl_err!(
+ "LogicalExtensionCodec is not provided for aggregate function
{name}"
+ )
+ }
+
+ fn try_encode_udaf(&self, _node: &AggregateUDF, _buf: &mut Vec<u8>) ->
Result<()> {
+ Ok(())
+ }
+
+ fn try_decode_udwf(&self, name: &str, _buf: &[u8]) ->
Result<Arc<WindowUDF>> {
+ not_impl_err!("LogicalExtensionCodec is not provided for window
function {name}")
+ }
+
+ fn try_encode_udwf(&self, _node: &WindowUDF, _buf: &mut Vec<u8>) ->
Result<()> {
+ Ok(())
+ }
}
#[derive(Debug, Clone)]
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index d8f8ea002b..9607b918eb 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -319,25 +319,37 @@ pub fn serialize_expr(
// TODO: support null treatment in proto
null_treatment: _,
}) => {
- let window_function = match fun {
- WindowFunctionDefinition::AggregateFunction(fun) => {
+ let (window_function, fun_definition) = match fun {
+ WindowFunctionDefinition::AggregateFunction(fun) => (
protobuf::window_expr_node::WindowFunction::AggrFunction(
protobuf::AggregateFunction::from(fun).into(),
- )
- }
- WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
+ ),
+ None,
+ ),
+ WindowFunctionDefinition::BuiltInWindowFunction(fun) => (
protobuf::window_expr_node::WindowFunction::BuiltInFunction(
protobuf::BuiltInWindowFunction::from(fun).into(),
- )
- }
+ ),
+ None,
+ ),
WindowFunctionDefinition::AggregateUDF(aggr_udf) => {
- protobuf::window_expr_node::WindowFunction::Udaf(
- aggr_udf.name().to_string(),
+ let mut buf = Vec::new();
+ let _ = codec.try_encode_udaf(aggr_udf, &mut buf);
+ (
+ protobuf::window_expr_node::WindowFunction::Udaf(
+ aggr_udf.name().to_string(),
+ ),
+ (!buf.is_empty()).then_some(buf),
)
}
WindowFunctionDefinition::WindowUDF(window_udf) => {
- protobuf::window_expr_node::WindowFunction::Udwf(
- window_udf.name().to_string(),
+ let mut buf = Vec::new();
+ let _ = codec.try_encode_udwf(window_udf, &mut buf);
+ (
+ protobuf::window_expr_node::WindowFunction::Udwf(
+ window_udf.name().to_string(),
+ ),
+ (!buf.is_empty()).then_some(buf),
)
}
};
@@ -358,6 +370,7 @@ pub fn serialize_expr(
partition_by,
order_by,
window_frame,
+ fun_definition,
});
protobuf::LogicalExprNode {
expr_type: Some(ExprType::WindowExpr(window_expr)),
@@ -395,23 +408,30 @@ pub fn serialize_expr(
expr_type:
Some(ExprType::AggregateExpr(Box::new(aggregate_expr))),
}
}
- AggregateFunctionDefinition::UDF(fun) => protobuf::LogicalExprNode
{
- expr_type: Some(ExprType::AggregateUdfExpr(Box::new(
- protobuf::AggregateUdfExprNode {
- fun_name: fun.name().to_string(),
- args: serialize_exprs(args, codec)?,
- distinct: *distinct,
- filter: match filter {
- Some(e) =>
Some(Box::new(serialize_expr(e.as_ref(), codec)?)),
- None => None,
- },
- order_by: match order_by {
- Some(e) => serialize_exprs(e, codec)?,
- None => vec![],
+ AggregateFunctionDefinition::UDF(fun) => {
+ let mut buf = Vec::new();
+ let _ = codec.try_encode_udaf(fun, &mut buf);
+ protobuf::LogicalExprNode {
+ expr_type: Some(ExprType::AggregateUdfExpr(Box::new(
+ protobuf::AggregateUdfExprNode {
+ fun_name: fun.name().to_string(),
+ args: serialize_exprs(args, codec)?,
+ distinct: *distinct,
+ filter: match filter {
+ Some(e) => {
+ Some(Box::new(serialize_expr(e.as_ref(),
codec)?))
+ }
+ None => None,
+ },
+ order_by: match order_by {
+ Some(e) => serialize_exprs(e, codec)?,
+ None => vec![],
+ },
+ fun_definition: (!buf.is_empty()).then_some(buf),
},
- },
- ))),
- },
+ ))),
+ }
+ }
},
Expr::ScalarVariable(_, _) => {
@@ -420,17 +440,13 @@ pub fn serialize_expr(
))
}
Expr::ScalarFunction(ScalarFunction { func, args }) => {
- let args = serialize_exprs(args, codec)?;
let mut buf = Vec::new();
- let _ = codec.try_encode_udf(func.as_ref(), &mut buf);
-
- let fun_definition = if buf.is_empty() { None } else { Some(buf) };
-
+ let _ = codec.try_encode_udf(func, &mut buf);
protobuf::LogicalExprNode {
expr_type:
Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode {
fun_name: func.name().to_string(),
- fun_definition,
- args,
+ fun_definition: (!buf.is_empty()).then_some(buf),
+ args: serialize_exprs(args, codec)?,
})),
}
}
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs
b/datafusion/proto/src/physical_plan/from_proto.rs
index b7311c694d..5ecca51478 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -164,8 +164,10 @@ pub fn parse_physical_window_expr(
WindowFunctionDefinition::BuiltInWindowFunction(f.into())
}
protobuf::physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(udaf_name)
=> {
- let agg_udf = registry.udaf(udaf_name)?;
- WindowFunctionDefinition::AggregateUDF(agg_udf)
+ WindowFunctionDefinition::AggregateUDF(match
&proto.fun_definition {
+ Some(buf) => codec.try_decode_udaf(udaf_name, buf)?,
+ None => registry.udaf(udaf_name)?
+ })
}
}
} else {
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index 948a39bfe0..1220f42ded 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -61,7 +61,7 @@ use datafusion::physical_plan::{
udaf, AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr,
WindowExpr,
};
use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
-use datafusion_expr::ScalarUDF;
+use datafusion_expr::{AggregateUDF, ScalarUDF};
use crate::common::{byte_to_string, str_to_byte};
use crate::convert_required;
@@ -491,19 +491,22 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
&ordering_req,
&physical_schema,
name.to_string(),
- false,
+ agg_node.ignore_nulls,
)
}
AggregateFunction::UserDefinedAggrFunction(udaf_name) => {
- let agg_udf =
registry.udaf(udaf_name)?;
+ let agg_udf = match
&agg_node.fun_definition {
+ Some(buf) =>
extension_codec.try_decode_udaf(udaf_name, buf)?,
+ None =>
registry.udaf(udaf_name)?
+ };
+
// TODO: 'logical_exprs' is not
supported for UDAF yet.
// approx_percentile_cont and
approx_percentile_cont_weight are not supported for UDAF from protobuf yet.
let logical_exprs = &[];
// TODO: `order by` is not
supported for UDAF yet
let sort_exprs = &[];
let ordering_req = &[];
- let ignore_nulls = false;
-
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs,
sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false)
+
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs,
sort_exprs, ordering_req, &physical_schema, name, agg_node.ignore_nulls,
agg_node.distinct)
}
}
}).transpose()?.ok_or_else(|| {
@@ -2034,6 +2037,16 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync {
) -> Result<()> {
not_impl_err!("PhysicalExtensionCodec is not provided")
}
+
+ fn try_decode_udaf(&self, name: &str, _buf: &[u8]) ->
Result<Arc<AggregateUDF>> {
+ not_impl_err!(
+ "PhysicalExtensionCodec is not provided for aggregate function
{name}"
+ )
+ }
+
+ fn try_encode_udaf(&self, _node: &AggregateUDF, _buf: &mut Vec<u8>) ->
Result<()> {
+ Ok(())
+ }
}
#[derive(Debug)]
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs
b/datafusion/proto/src/physical_plan/to_proto.rs
index d8d0291e1c..7ea2902cf3 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -40,6 +40,7 @@ use datafusion::{
physical_plan::expressions::LikeExpr,
};
use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
+use datafusion_expr::WindowFrame;
use crate::protobuf::{
self, physical_aggregate_expr_node, physical_window_expr_node,
PhysicalSortExprNode,
@@ -58,13 +59,17 @@ pub fn serialize_physical_aggr_expr(
if let Some(a) =
aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
let name = a.fun().name().to_string();
+ let mut buf = Vec::new();
+ codec.try_encode_udaf(a.fun(), &mut buf)?;
return Ok(protobuf::PhysicalExprNode {
expr_type:
Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
protobuf::PhysicalAggregateExprNode {
aggregate_function:
Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)),
expr: expressions,
ordering_req,
- distinct: false,
+ distinct: a.is_distinct(),
+ ignore_nulls: a.ignore_nulls(),
+ fun_definition: (!buf.is_empty()).then_some(buf)
},
)),
});
@@ -86,11 +91,55 @@ pub fn serialize_physical_aggr_expr(
expr: expressions,
ordering_req,
distinct,
+ ignore_nulls: false,
+ fun_definition: None,
},
)),
})
}
+fn serialize_physical_window_aggr_expr(
+ aggr_expr: &dyn AggregateExpr,
+ window_frame: &WindowFrame,
+ codec: &dyn PhysicalExtensionCodec,
+) -> Result<(physical_window_expr_node::WindowFunction, Option<Vec<u8>>)> {
+ if let Some(a) =
aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
+ if a.is_distinct() || a.ignore_nulls() {
+ // TODO
+ return not_impl_err!(
+                "Distinct or ignore-nulls aggregate functions not supported
in window expressions"
+ );
+ }
+
+ let mut buf = Vec::new();
+ codec.try_encode_udaf(a.fun(), &mut buf)?;
+ Ok((
+ physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
+ a.fun().name().to_string(),
+ ),
+ (!buf.is_empty()).then_some(buf),
+ ))
+ } else {
+ let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(aggr_expr)?;
+ if distinct {
+ return not_impl_err!(
+ "Distinct aggregate functions not supported in window
expressions"
+ );
+ }
+
+ if !window_frame.start_bound.is_unbounded() {
+ return Err(DataFusionError::Internal(format!(
+ "Unbounded start bound in WindowFrame = {window_frame}"
+ )));
+ }
+
+ Ok((
+ physical_window_expr_node::WindowFunction::AggrFunction(inner as
i32),
+ None,
+ ))
+ }
+}
+
pub fn serialize_physical_window_expr(
window_expr: Arc<dyn WindowExpr>,
codec: &dyn PhysicalExtensionCodec,
@@ -99,7 +148,7 @@ pub fn serialize_physical_window_expr(
let mut args = window_expr.expressions().to_vec();
let window_frame = window_expr.get_window_frame();
- let window_function = if let Some(built_in_window_expr) =
+ let (window_function, fun_definition) = if let Some(built_in_window_expr) =
expr.downcast_ref::<BuiltInWindowExpr>()
{
let expr = built_in_window_expr.get_built_in_func_expr();
@@ -160,58 +209,26 @@ pub fn serialize_physical_window_expr(
return not_impl_err!("BuiltIn function not supported: {expr:?}");
};
- physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn
as i32)
+ (
+
physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn as i32),
+ None,
+ )
} else if let Some(plain_aggr_window_expr) =
expr.downcast_ref::<PlainAggregateWindowExpr>()
{
- let aggr_expr = plain_aggr_window_expr.get_aggregate_expr();
- if let Some(a) =
aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
- physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
- a.fun().name().to_string(),
- )
- } else {
- let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
- plain_aggr_window_expr.get_aggregate_expr().as_ref(),
- )?;
-
- if distinct {
- return not_impl_err!(
- "Distinct aggregate functions not supported in window
expressions"
- );
- }
-
- if !window_frame.start_bound.is_unbounded() {
- return Err(DataFusionError::Internal(format!("Invalid
PlainAggregateWindowExpr = {window_expr:?} with WindowFrame =
{window_frame:?}")));
- }
-
- physical_window_expr_node::WindowFunction::AggrFunction(inner as
i32)
- }
+ serialize_physical_window_aggr_expr(
+ plain_aggr_window_expr.get_aggregate_expr().as_ref(),
+ window_frame,
+ codec,
+ )?
} else if let Some(sliding_aggr_window_expr) =
expr.downcast_ref::<SlidingAggregateWindowExpr>()
{
- let aggr_expr = sliding_aggr_window_expr.get_aggregate_expr();
- if let Some(a) =
aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
- physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
- a.fun().name().to_string(),
- )
- } else {
- let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
- sliding_aggr_window_expr.get_aggregate_expr().as_ref(),
- )?;
-
- if distinct {
- // TODO
- return not_impl_err!(
- "Distinct aggregate functions not supported in window
expressions"
- );
- }
-
- if window_frame.start_bound.is_unbounded() {
- return Err(DataFusionError::Internal(format!("Invalid
SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame =
{window_frame:?}")));
- }
-
- physical_window_expr_node::WindowFunction::AggrFunction(inner as
i32)
- }
+ serialize_physical_window_aggr_expr(
+ sliding_aggr_window_expr.get_aggregate_expr().as_ref(),
+ window_frame,
+ codec,
+ )?
} else {
return not_impl_err!("WindowExpr not supported: {window_expr:?}");
};
@@ -232,6 +249,7 @@ pub fn serialize_physical_window_expr(
window_frame: Some(window_frame),
window_function: Some(window_function),
name: window_expr.name().to_string(),
+ fun_definition,
})
}
@@ -461,18 +479,14 @@ pub fn serialize_physical_expr(
))),
})
} else if let Some(expr) = expr.downcast_ref::<ScalarFunctionExpr>() {
- let args = serialize_physical_exprs(expr.args().to_vec(), codec)?;
-
let mut buf = Vec::new();
codec.try_encode_udf(expr.fun(), &mut buf)?;
-
- let fun_definition = if buf.is_empty() { None } else { Some(buf) };
Ok(protobuf::PhysicalExprNode {
expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf(
protobuf::PhysicalScalarUdfNode {
name: expr.name().to_string(),
- args,
- fun_definition,
+ args: serialize_physical_exprs(expr.args().to_vec(),
codec)?,
+ fun_definition: (!buf.is_empty()).then_some(buf),
return_type: Some(expr.return_type().try_into()?),
},
)),
diff --git a/datafusion/proto/tests/cases/mod.rs
b/datafusion/proto/tests/cases/mod.rs
index b17289205f..1f837b7f42 100644
--- a/datafusion/proto/tests/cases/mod.rs
+++ b/datafusion/proto/tests/cases/mod.rs
@@ -15,6 +15,105 @@
// specific language governing permissions and limitations
// under the License.
+use std::any::Any;
+
+use arrow::datatypes::DataType;
+
+use datafusion_common::plan_err;
+use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::{
+ Accumulator, AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, Signature,
Volatility,
+};
+
mod roundtrip_logical_plan;
mod roundtrip_physical_plan;
mod serialize;
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+struct MyRegexUdf {
+ signature: Signature,
+ // regex as original string
+ pattern: String,
+}
+
+impl MyRegexUdf {
+ fn new(pattern: String) -> Self {
+ let signature = Signature::exact(vec![DataType::Utf8],
Volatility::Immutable);
+ Self { signature, pattern }
+ }
+}
+
+/// Implement the ScalarUDFImpl trait for MyRegexUdf
+impl ScalarUDFImpl for MyRegexUdf {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+ fn name(&self) -> &str {
+ "regex_udf"
+ }
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+ fn return_type(&self, args: &[DataType]) ->
datafusion_common::Result<DataType> {
+ if matches!(args, [DataType::Utf8]) {
+ Ok(DataType::Int64)
+ } else {
+ plan_err!("regex_udf only accepts Utf8 arguments")
+ }
+ }
+ fn invoke(
+ &self,
+ _args: &[ColumnarValue],
+ ) -> datafusion_common::Result<ColumnarValue> {
+ unimplemented!()
+ }
+}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct MyRegexUdfNode {
+ #[prost(string, tag = "1")]
+ pub pattern: String,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+struct MyAggregateUDF {
+ signature: Signature,
+ result: String,
+}
+
+impl MyAggregateUDF {
+ fn new(result: String) -> Self {
+ let signature = Signature::exact(vec![DataType::Int64],
Volatility::Immutable);
+ Self { signature, result }
+ }
+}
+
+impl AggregateUDFImpl for MyAggregateUDF {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+ fn name(&self) -> &str {
+ "aggregate_udf"
+ }
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+ fn return_type(
+ &self,
+ _arg_types: &[DataType],
+ ) -> datafusion_common::Result<DataType> {
+ Ok(DataType::Utf8)
+ }
+ fn accumulator(
+ &self,
+ _acc_args: AccumulatorArgs,
+ ) -> datafusion_common::Result<Box<dyn Accumulator>> {
+ unimplemented!()
+ }
+}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct MyAggregateUdfNode {
+ #[prost(string, tag = "1")]
+ pub result: String,
+}
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index d0209d811b..0117502f40 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -28,15 +28,12 @@ use arrow::datatypes::{
DataType, Field, Fields, Int32Type, IntervalDayTimeType,
IntervalMonthDayNanoType,
IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
};
+use prost::Message;
+
use datafusion::datasource::file_format::arrow::ArrowFormatFactory;
use datafusion::datasource::file_format::csv::CsvFormatFactory;
use datafusion::datasource::file_format::format_as_file_type;
use datafusion::datasource::file_format::parquet::ParquetFormatFactory;
-use datafusion_proto::logical_plan::file_formats::{
- ArrowLogicalExtensionCodec, CsvLogicalExtensionCodec,
ParquetLogicalExtensionCodec,
-};
-use prost::Message;
-
use datafusion::datasource::provider::TableProviderFactory;
use datafusion::datasource::TableProvider;
use datafusion::execution::session_state::SessionStateBuilder;
@@ -62,9 +59,9 @@ use datafusion_expr::expr::{
};
use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore};
use datafusion_expr::{
- Accumulator, AggregateExt, AggregateFunction, ColumnarValue, ExprSchemable,
- LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl,
Signature,
- TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits,
+ Accumulator, AggregateExt, AggregateFunction, AggregateUDF, ColumnarValue,
+ ExprSchemable, Literal, LogicalPlan, Operator, PartitionEvaluator,
ScalarUDF,
+ Signature, TryCast, Volatility, WindowFrame, WindowFrameBound,
WindowFrameUnits,
WindowFunctionDefinition, WindowUDF, WindowUDFImpl,
};
use datafusion_functions_aggregate::average::avg_udaf;
@@ -76,12 +73,17 @@ use datafusion_proto::bytes::{
logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec,
logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec,
};
+use datafusion_proto::logical_plan::file_formats::{
+ ArrowLogicalExtensionCodec, CsvLogicalExtensionCodec,
ParquetLogicalExtensionCodec,
+};
use datafusion_proto::logical_plan::to_proto::serialize_expr;
use datafusion_proto::logical_plan::{
from_proto, DefaultLogicalExtensionCodec, LogicalExtensionCodec,
};
use datafusion_proto::protobuf;
+use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf,
MyRegexUdfNode};
+
#[cfg(feature = "json")]
fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) {
let string = serde_json::to_string(proto).unwrap();
@@ -744,7 +746,7 @@ pub mod proto {
pub k: u64,
#[prost(message, optional, tag = "2")]
- pub expr:
::core::option::Option<datafusion_proto::protobuf::LogicalExprNode>,
+ pub expr: Option<datafusion_proto::protobuf::LogicalExprNode>,
}
#[derive(Clone, PartialEq, Eq, ::prost::Message)]
@@ -752,12 +754,6 @@ pub mod proto {
#[prost(uint64, tag = "1")]
pub k: u64,
}
-
- #[derive(Clone, PartialEq, ::prost::Message)]
- pub struct MyRegexUdfNode {
- #[prost(string, tag = "1")]
- pub pattern: String,
- }
}
#[derive(PartialEq, Eq, Hash)]
@@ -890,51 +886,9 @@ impl LogicalExtensionCodec for TopKExtensionCodec {
}
#[derive(Debug)]
-struct MyRegexUdf {
- signature: Signature,
- // regex as original string
- pattern: String,
-}
-
-impl MyRegexUdf {
- fn new(pattern: String) -> Self {
- Self {
- signature: Signature::uniform(
- 1,
- vec![DataType::Int32],
- Volatility::Immutable,
- ),
- pattern,
- }
- }
-}
-
-/// Implement the ScalarUDFImpl trait for MyRegexUdf
-impl ScalarUDFImpl for MyRegexUdf {
- fn as_any(&self) -> &dyn Any {
- self
- }
- fn name(&self) -> &str {
- "regex_udf"
- }
- fn signature(&self) -> &Signature {
- &self.signature
- }
- fn return_type(&self, args: &[DataType]) -> Result<DataType> {
- if !matches!(args.first(), Some(&DataType::Utf8)) {
- return plan_err!("regex_udf only accepts Utf8 arguments");
- }
- Ok(DataType::Int32)
- }
- fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
- unimplemented!()
- }
-}
-
-#[derive(Debug)]
-pub struct ScalarUDFExtensionCodec {}
+pub struct UDFExtensionCodec;
-impl LogicalExtensionCodec for ScalarUDFExtensionCodec {
+impl LogicalExtensionCodec for UDFExtensionCodec {
fn try_decode(
&self,
_buf: &[u8],
@@ -969,13 +923,11 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec {
fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>>
{
if name == "regex_udf" {
- let proto = proto::MyRegexUdfNode::decode(buf).map_err(|err| {
- DataFusionError::Internal(format!("failed to decode regex_udf:
{}", err))
+ let proto = MyRegexUdfNode::decode(buf).map_err(|err| {
+ DataFusionError::Internal(format!("failed to decode regex_udf:
{err}"))
})?;
- Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new(
- proto.pattern,
- ))))
+ Ok(Arc::new(ScalarUDF::from(MyRegexUdf::new(proto.pattern))))
} else {
not_impl_err!("unrecognized scalar UDF implementation, cannot
decode")
}
@@ -984,11 +936,39 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec {
fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) ->
Result<()> {
let binding = node.inner();
let udf = binding.as_any().downcast_ref::<MyRegexUdf>().unwrap();
- let proto = proto::MyRegexUdfNode {
+ let proto = MyRegexUdfNode {
pattern: udf.pattern.clone(),
};
- proto.encode(buf).map_err(|e| {
- DataFusionError::Internal(format!("failed to encode udf: {e:?}"))
+ proto.encode(buf).map_err(|err| {
+ DataFusionError::Internal(format!("failed to encode udf: {err}"))
+ })?;
+ Ok(())
+ }
+
+ fn try_decode_udaf(&self, name: &str, buf: &[u8]) ->
Result<Arc<AggregateUDF>> {
+ if name == "aggregate_udf" {
+ let proto = MyAggregateUdfNode::decode(buf).map_err(|err| {
+ DataFusionError::Internal(format!(
+ "failed to decode aggregate_udf: {err}"
+ ))
+ })?;
+
+ Ok(Arc::new(AggregateUDF::from(MyAggregateUDF::new(
+ proto.result,
+ ))))
+ } else {
+ not_impl_err!("unrecognized aggregate UDF implementation, cannot
decode")
+ }
+ }
+
+ fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) ->
Result<()> {
+ let binding = node.inner();
+ let udf = binding.as_any().downcast_ref::<MyAggregateUDF>().unwrap();
+ let proto = MyAggregateUdfNode {
+ result: udf.result.clone(),
+ };
+ proto.encode(buf).map_err(|err| {
+ DataFusionError::Internal(format!("failed to encode udf: {err}"))
})?;
Ok(())
}
@@ -1563,8 +1543,7 @@ fn roundtrip_null_scalar_values() {
for test_case in test_types.into_iter() {
let proto_scalar: protobuf::ScalarValue =
(&test_case).try_into().unwrap();
- let returned_scalar: datafusion::scalar::ScalarValue =
- (&proto_scalar).try_into().unwrap();
+ let returned_scalar: ScalarValue = (&proto_scalar).try_into().unwrap();
assert_eq!(format!("{:?}", &test_case),
format!("{returned_scalar:?}"));
}
}
@@ -1893,22 +1872,19 @@ fn roundtrip_aggregate_udf() {
struct Dummy {}
impl Accumulator for Dummy {
- fn state(&mut self) -> datafusion::error::Result<Vec<ScalarValue>> {
+ fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![])
}
- fn update_batch(
- &mut self,
- _values: &[ArrayRef],
- ) -> datafusion::error::Result<()> {
+ fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
Ok(())
}
- fn merge_batch(&mut self, _states: &[ArrayRef]) ->
datafusion::error::Result<()> {
+ fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> {
Ok(())
}
- fn evaluate(&mut self) -> datafusion::error::Result<ScalarValue> {
+ fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Float64(None))
}
@@ -1976,25 +1952,27 @@ fn roundtrip_scalar_udf() {
#[test]
fn roundtrip_scalar_udf_extension_codec() {
- let pattern = ".*";
- let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string()));
- let test_expr =
- Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf.clone()),
vec![]));
-
+ let udf = ScalarUDF::from(MyRegexUdf::new(".*".to_owned()));
+ let test_expr = udf.call(vec!["foo".lit()]);
let ctx = SessionContext::new();
- ctx.register_udf(udf);
-
- let extension_codec = ScalarUDFExtensionCodec {};
- let proto: protobuf::LogicalExprNode =
- match serialize_expr(&test_expr, &extension_codec) {
- Ok(p) => p,
- Err(e) => panic!("Error serializing expression: {:?}", e),
- };
- let round_trip: Expr =
- from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap();
+ let proto = serialize_expr(&test_expr,
&UDFExtensionCodec).expect("serialize expr");
+ let round_trip =
+ from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse
expr");
assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}"));
+ roundtrip_json_test(&proto);
+}
+
+#[test]
+fn roundtrip_aggregate_udf_extension_codec() {
+ let udf = AggregateUDF::from(MyAggregateUDF::new("DataFusion".to_owned()));
+ let test_expr = udf.call(vec![42.lit()]);
+ let ctx = SessionContext::new();
+ let proto = serialize_expr(&test_expr,
&UDFExtensionCodec).expect("serialize expr");
+ let round_trip =
+ from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse
expr");
+ assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}"));
roundtrip_json_test(&proto);
}
@@ -2120,22 +2098,19 @@ fn roundtrip_window() {
struct DummyAggr {}
impl Accumulator for DummyAggr {
- fn state(&mut self) -> datafusion::error::Result<Vec<ScalarValue>> {
+ fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![])
}
- fn update_batch(
- &mut self,
- _values: &[ArrayRef],
- ) -> datafusion::error::Result<()> {
+ fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
Ok(())
}
- fn merge_batch(&mut self, _states: &[ArrayRef]) ->
datafusion::error::Result<()> {
+ fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> {
Ok(())
}
- fn evaluate(&mut self) -> datafusion::error::Result<ScalarValue> {
+ fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Float64(None))
}
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 2fcc65008f..fba6dfe425 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -15,7 +15,6 @@
// specific language governing permissions and limitations
// under the License.
-use arrow::array::RecordBatch;
use std::any::Any;
use std::fmt::Display;
use std::hash::Hasher;
@@ -23,8 +22,8 @@ use std::ops::Deref;
use std::sync::Arc;
use std::vec;
+use arrow::array::RecordBatch;
use arrow::csv::WriterBuilder;
-use datafusion::functions_aggregate::sum::sum_udaf;
use prost::Message;
use datafusion::arrow::array::ArrayRef;
@@ -40,9 +39,10 @@ use datafusion::datasource::physical_plan::{
FileSinkConfig, ParquetExec,
};
use datafusion::execution::FunctionRegistry;
+use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility};
use datafusion::physical_expr::aggregate::utils::down_cast_any_ref;
-use datafusion::physical_expr::expressions::Max;
+use datafusion::physical_expr::expressions::{Literal, Max};
use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr};
use datafusion::physical_plan::aggregates::{
@@ -70,7 +70,7 @@ use datafusion::physical_plan::windows::{
BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec,
};
use datafusion::physical_plan::{
- udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics,
+ AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics,
};
use datafusion::prelude::SessionContext;
use datafusion::scalar::ScalarValue;
@@ -79,10 +79,10 @@ use
datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::file_options::json_writer::JsonWriterOptions;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::stats::Precision;
-use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError,
Result};
+use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
use datafusion_expr::{
Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue,
ScalarUDF,
- ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame,
WindowFrameBound,
+ Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound,
};
use datafusion_functions_aggregate::average::avg_udaf;
use datafusion_functions_aggregate::nth_value::nth_value_udaf;
@@ -92,6 +92,8 @@ use datafusion_proto::physical_plan::{
};
use datafusion_proto::protobuf;
+use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf,
MyRegexUdfNode};
+
/// Perform a serde roundtrip and assert that the string representation of the
before and after plans
/// are identical. Note that this often isn't sufficient to guarantee that no
information is
/// lost during serde because the string representation of a plan often only
shows a subset of state.
@@ -312,7 +314,7 @@ fn roundtrip_window() -> Result<()> {
);
let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?];
- let sum_expr = udaf::create_aggregate_expr(
+ let sum_expr = create_aggregate_expr(
&sum_udaf(),
&args,
&[],
@@ -367,7 +369,7 @@ fn rountrip_aggregate() -> Result<()> {
false,
)?],
// NTH_VALUE
- vec![udaf::create_aggregate_expr(
+ vec![create_aggregate_expr(
&nth_value_udaf(),
&[col("b", &schema)?, lit(1u64)],
&[],
@@ -379,7 +381,7 @@ fn rountrip_aggregate() -> Result<()> {
false,
)?],
// STRING_AGG
- vec![udaf::create_aggregate_expr(
+ vec![create_aggregate_expr(
&AggregateUDF::new_from_impl(StringAgg::new()),
&[
cast(col("b", &schema)?, &schema, DataType::Utf8)?,
@@ -490,7 +492,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
vec![(col("a", &schema)?, "unused".to_string())];
- let aggregates: Vec<Arc<dyn AggregateExpr>> =
vec![udaf::create_aggregate_expr(
+ let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(
&udaf,
&[col("b", &schema)?],
&[],
@@ -845,123 +847,161 @@ fn roundtrip_scalar_udf() -> Result<()> {
roundtrip_test_with_context(Arc::new(project), &ctx)
}
-#[test]
-fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
- #[derive(Debug)]
- struct MyRegexUdf {
- signature: Signature,
- // regex as original string
- pattern: String,
+#[derive(Debug)]
+struct UDFExtensionCodec;
+
+impl PhysicalExtensionCodec for UDFExtensionCodec {
+ fn try_decode(
+ &self,
+ _buf: &[u8],
+ _inputs: &[Arc<dyn ExecutionPlan>],
+ _registry: &dyn FunctionRegistry,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ not_impl_err!("No extension codec provided")
}
- impl MyRegexUdf {
- fn new(pattern: String) -> Self {
- Self {
- signature: Signature::exact(vec![DataType::Utf8],
Volatility::Immutable),
- pattern,
- }
- }
+ fn try_encode(
+ &self,
+ _node: Arc<dyn ExecutionPlan>,
+ _buf: &mut Vec<u8>,
+ ) -> Result<()> {
+ not_impl_err!("No extension codec provided")
}
- /// Implement the ScalarUDFImpl trait for MyRegexUdf
- impl ScalarUDFImpl for MyRegexUdf {
- fn as_any(&self) -> &dyn Any {
- self
- }
+ fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>>
{
+ if name == "regex_udf" {
+ let proto = MyRegexUdfNode::decode(buf).map_err(|err| {
+ DataFusionError::Internal(format!("failed to decode regex_udf:
{err}"))
+ })?;
- fn name(&self) -> &str {
- "regex_udf"
+ Ok(Arc::new(ScalarUDF::from(MyRegexUdf::new(proto.pattern))))
+ } else {
+ not_impl_err!("unrecognized scalar UDF implementation, cannot
decode")
}
+ }
- fn signature(&self) -> &Signature {
- &self.signature
+ fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) ->
Result<()> {
+ let binding = node.inner();
+ if let Some(udf) = binding.as_any().downcast_ref::<MyRegexUdf>() {
+ let proto = MyRegexUdfNode {
+ pattern: udf.pattern.clone(),
+ };
+ proto.encode(buf).map_err(|err| {
+ DataFusionError::Internal(format!("failed to encode udf:
{err}"))
+ })?;
}
+ Ok(())
+ }
- fn return_type(&self, args: &[DataType]) -> Result<DataType> {
- if !matches!(args.first(), Some(&DataType::Utf8)) {
- return plan_err!("regex_udf only accepts Utf8 arguments");
- }
- Ok(DataType::Int64)
+ fn try_decode_udaf(&self, name: &str, buf: &[u8]) ->
Result<Arc<AggregateUDF>> {
+ if name == "aggregate_udf" {
+ let proto = MyAggregateUdfNode::decode(buf).map_err(|err| {
+ DataFusionError::Internal(format!(
+ "failed to decode aggregate_udf: {err}"
+ ))
+ })?;
+
+ Ok(Arc::new(AggregateUDF::from(MyAggregateUDF::new(
+ proto.result,
+ ))))
+ } else {
+            not_impl_err!("unrecognized aggregate UDF implementation, cannot
decode")
}
+ }
- fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
- unimplemented!()
+ fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) ->
Result<()> {
+ let binding = node.inner();
+ if let Some(udf) = binding.as_any().downcast_ref::<MyAggregateUDF>() {
+ let proto = MyAggregateUdfNode {
+ result: udf.result.clone(),
+ };
+ proto.encode(buf).map_err(|err| {
+                DataFusionError::Internal(format!("failed to encode udf:
{err}"))
+ })?;
}
+ Ok(())
}
+}
- #[derive(Clone, PartialEq, ::prost::Message)]
- pub struct MyRegexUdfNode {
- #[prost(string, tag = "1")]
- pub pattern: String,
- }
+#[test]
+fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
+ let field_text = Field::new("text", DataType::Utf8, true);
+ let field_published = Field::new("published", DataType::Boolean, false);
+ let field_author = Field::new("author", DataType::Utf8, false);
+ let schema = Arc::new(Schema::new(vec![field_text, field_published,
field_author]));
+ let input = Arc::new(EmptyExec::new(schema.clone()));
- #[derive(Debug)]
- pub struct ScalarUDFExtensionCodec {}
+ let udf_expr = Arc::new(ScalarFunctionExpr::new(
+ "regex_udf",
+ Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))),
+ vec![col("text", &schema)?],
+ DataType::Int64,
+ ));
- impl PhysicalExtensionCodec for ScalarUDFExtensionCodec {
- fn try_decode(
- &self,
- _buf: &[u8],
- _inputs: &[Arc<dyn ExecutionPlan>],
- _registry: &dyn FunctionRegistry,
- ) -> Result<Arc<dyn ExecutionPlan>> {
- not_impl_err!("No extension codec provided")
- }
+ let filter = Arc::new(FilterExec::try_new(
+ Arc::new(BinaryExpr::new(
+ col("published", &schema)?,
+ Operator::And,
+ Arc::new(BinaryExpr::new(udf_expr.clone(), Operator::Gt, lit(0))),
+ )),
+ input,
+ )?);
- fn try_encode(
- &self,
- _node: Arc<dyn ExecutionPlan>,
- _buf: &mut Vec<u8>,
- ) -> Result<()> {
- not_impl_err!("No extension codec provided")
- }
+ let window = Arc::new(WindowAggExec::try_new(
+ vec![Arc::new(PlainAggregateWindowExpr::new(
+ Arc::new(Max::new(udf_expr.clone(), "max", DataType::Int64)),
+ &[col("author", &schema)?],
+ &[],
+ Arc::new(WindowFrame::new(None)),
+ ))],
+ filter,
+ vec![col("author", &schema)?],
+ )?);
- fn try_decode_udf(&self, name: &str, buf: &[u8]) ->
Result<Arc<ScalarUDF>> {
- if name == "regex_udf" {
- let proto = MyRegexUdfNode::decode(buf).map_err(|err| {
- DataFusionError::Internal(format!(
- "failed to decode regex_udf: {}",
- err
- ))
- })?;
-
- Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new(
- proto.pattern,
- ))))
- } else {
- not_impl_err!("unrecognized scalar UDF implementation, cannot
decode")
- }
- }
+ let aggregate = Arc::new(AggregateExec::try_new(
+ AggregateMode::Final,
+ PhysicalGroupBy::new(vec![], vec![], vec![]),
+ vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))],
+ vec![None],
+ window,
+ schema.clone(),
+ )?);
- fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) ->
Result<()> {
- let binding = node.inner();
- if let Some(udf) = binding.as_any().downcast_ref::<MyRegexUdf>() {
- let proto = MyRegexUdfNode {
- pattern: udf.pattern.clone(),
- };
- proto.encode(buf).map_err(|e| {
- DataFusionError::Internal(format!("failed to encode udf:
{e:?}"))
- })?;
- }
- Ok(())
- }
- }
+ let ctx = SessionContext::new();
+ roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?;
+ Ok(())
+}
+#[test]
+fn roundtrip_aggregate_udf_extension_codec() -> Result<()> {
let field_text = Field::new("text", DataType::Utf8, true);
let field_published = Field::new("published", DataType::Boolean, false);
let field_author = Field::new("author", DataType::Utf8, false);
let schema = Arc::new(Schema::new(vec![field_text, field_published, field_author]));
let input = Arc::new(EmptyExec::new(schema.clone()));
- let pattern = ".*";
- let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string()));
let udf_expr = Arc::new(ScalarFunctionExpr::new(
- udf.name(),
- Arc::new(udf.clone()),
+ "regex_udf",
+ Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))),
vec![col("text", &schema)?],
DataType::Int64,
));
+ let udaf = AggregateUDF::from(MyAggregateUDF::new("result".to_string()));
+ let aggr_args: [Arc<dyn PhysicalExpr>; 1] =
+ [Arc::new(Literal::new(ScalarValue::from(42)))];
+ let aggr_expr = create_aggregate_expr(
+ &udaf,
+ &aggr_args,
+ &[],
+ &[],
+ &[],
+ &schema,
+ "aggregate_udf",
+ false,
+ false,
+ )?;
+
let filter = Arc::new(FilterExec::try_new(
Arc::new(BinaryExpr::new(
col("published", &schema)?,
@@ -973,7 +1013,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
let window = Arc::new(WindowAggExec::try_new(
vec![Arc::new(PlainAggregateWindowExpr::new(
- Arc::new(Max::new(udf_expr.clone(), "max", DataType::Int64)),
+ aggr_expr,
&[col("author", &schema)?],
&[],
Arc::new(WindowFrame::new(None)),
@@ -982,18 +1022,29 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
vec![col("author", &schema)?],
)?);
+ let aggr_expr = create_aggregate_expr(
+ &udaf,
+ &aggr_args,
+ &[],
+ &[],
+ &[],
+ &schema,
+ "aggregate_udf",
+ true,
+ true,
+ )?;
+
let aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::new(vec![], vec![], vec![]),
- vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))],
+ vec![aggr_expr],
vec![None],
window,
schema.clone(),
)?);
let ctx = SessionContext::new();
- let codec = ScalarUDFExtensionCodec {};
- roundtrip_test_and_return(aggregate, &ctx, &codec)?;
+ roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?;
Ok(())
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]