This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 02326998f0 Add extension hooks for encoding and decoding UDAFs and 
UDWFs (#11417)
02326998f0 is described below

commit 02326998f07a13fda0c93988bf13853413c4a2b2
Author: Georgi Krastev <[email protected]>
AuthorDate: Wed Jul 17 00:52:20 2024 +0300

    Add extension hooks for encoding and decoding UDAFs and UDWFs (#11417)
    
    * Add extension hooks for encoding and decoding UDAFs and UDWFs
    
    * Add tests for encoding and decoding UDAF
---
 .../examples/composed_extension_codec.rs           |  80 +++----
 .../physical-expr-common/src/aggregate/mod.rs      |   5 +
 datafusion/proto/proto/datafusion.proto            |  35 +--
 datafusion/proto/src/generated/pbjson.rs           | 102 +++++++++
 datafusion/proto/src/generated/prost.rs            |  10 +
 datafusion/proto/src/logical_plan/file_formats.rs  |  80 -------
 datafusion/proto/src/logical_plan/from_proto.rs    |  42 ++--
 datafusion/proto/src/logical_plan/mod.rs           |  22 +-
 datafusion/proto/src/logical_plan/to_proto.rs      |  84 ++++---
 datafusion/proto/src/physical_plan/from_proto.rs   |   6 +-
 datafusion/proto/src/physical_plan/mod.rs          |  23 +-
 datafusion/proto/src/physical_plan/to_proto.rs     | 122 +++++-----
 datafusion/proto/tests/cases/mod.rs                |  99 ++++++++
 .../proto/tests/cases/roundtrip_logical_plan.rs    | 171 ++++++--------
 .../proto/tests/cases/roundtrip_physical_plan.rs   | 251 +++++++++++++--------
 15 files changed, 686 insertions(+), 446 deletions(-)

diff --git a/datafusion-examples/examples/composed_extension_codec.rs 
b/datafusion-examples/examples/composed_extension_codec.rs
index 43c6daba21..5c34eccf26 100644
--- a/datafusion-examples/examples/composed_extension_codec.rs
+++ b/datafusion-examples/examples/composed_extension_codec.rs
@@ -30,18 +30,19 @@
 //!           DeltaScan
 //! ```
 
+use std::any::Any;
+use std::fmt::Debug;
+use std::ops::Deref;
+use std::sync::Arc;
+
 use datafusion::common::Result;
 use datafusion::physical_plan::{DisplayAs, ExecutionPlan};
 use datafusion::prelude::SessionContext;
-use datafusion_common::internal_err;
+use datafusion_common::{internal_err, DataFusionError};
 use datafusion_expr::registry::FunctionRegistry;
-use datafusion_expr::ScalarUDF;
+use datafusion_expr::{AggregateUDF, ScalarUDF};
 use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
 use datafusion_proto::protobuf;
-use std::any::Any;
-use std::fmt::Debug;
-use std::ops::Deref;
-use std::sync::Arc;
 
 #[tokio::main]
 async fn main() {
@@ -239,6 +240,25 @@ struct ComposedPhysicalExtensionCodec {
     codecs: Vec<Arc<dyn PhysicalExtensionCodec>>,
 }
 
+impl ComposedPhysicalExtensionCodec {
+    fn try_any<T>(
+        &self,
+        mut f: impl FnMut(&dyn PhysicalExtensionCodec) -> Result<T>,
+    ) -> Result<T> {
+        let mut last_err = None;
+        for codec in &self.codecs {
+            match f(codec.as_ref()) {
+                Ok(node) => return Ok(node),
+                Err(err) => last_err = Some(err),
+            }
+        }
+
+        Err(last_err.unwrap_or_else(|| {
+            DataFusionError::NotImplemented("Empty list of composed 
codecs".to_owned())
+        }))
+    }
+}
+
 impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec {
     fn try_decode(
         &self,
@@ -246,46 +266,26 @@ impl PhysicalExtensionCodec for 
ComposedPhysicalExtensionCodec {
         inputs: &[Arc<dyn ExecutionPlan>],
         registry: &dyn FunctionRegistry,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let mut last_err = None;
-        for codec in &self.codecs {
-            match codec.try_decode(buf, inputs, registry) {
-                Ok(plan) => return Ok(plan),
-                Err(e) => last_err = Some(e),
-            }
-        }
-        Err(last_err.unwrap())
+        self.try_any(|codec| codec.try_decode(buf, inputs, registry))
     }
 
     fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> 
Result<()> {
-        let mut last_err = None;
-        for codec in &self.codecs {
-            match codec.try_encode(node.clone(), buf) {
-                Ok(_) => return Ok(()),
-                Err(e) => last_err = Some(e),
-            }
-        }
-        Err(last_err.unwrap())
+        self.try_any(|codec| codec.try_encode(node.clone(), buf))
     }
 
-    fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> 
Result<Arc<ScalarUDF>> {
-        let mut last_err = None;
-        for codec in &self.codecs {
-            match codec.try_decode_udf(name, _buf) {
-                Ok(plan) => return Ok(plan),
-                Err(e) => last_err = Some(e),
-            }
-        }
-        Err(last_err.unwrap())
+    fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> 
{
+        self.try_any(|codec| codec.try_decode_udf(name, buf))
     }
 
-    fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec<u8>) -> 
Result<()> {
-        let mut last_err = None;
-        for codec in &self.codecs {
-            match codec.try_encode_udf(_node, _buf) {
-                Ok(_) => return Ok(()),
-                Err(e) => last_err = Some(e),
-            }
-        }
-        Err(last_err.unwrap())
+    fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> 
Result<()> {
+        self.try_any(|codec| codec.try_encode_udf(node, buf))
+    }
+
+    fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> 
Result<Arc<AggregateUDF>> {
+        self.try_any(|codec| codec.try_decode_udaf(name, buf))
+    }
+
+    fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> 
Result<()> {
+        self.try_any(|codec| codec.try_encode_udaf(node, buf))
     }
 }
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs 
b/datafusion/physical-expr-common/src/aggregate/mod.rs
index db4581a622..0e245fd0a6 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -283,6 +283,11 @@ impl AggregateFunctionExpr {
     pub fn is_distinct(&self) -> bool {
         self.is_distinct
     }
+
+    /// Return if the aggregation ignores nulls
+    pub fn ignore_nulls(&self) -> bool {
+        self.ignore_nulls
+    }
 }
 
 impl AggregateExpr for AggregateFunctionExpr {
diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto
index 9ef884531e..dc551778c5 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -164,7 +164,7 @@ message CreateExternalTableNode {
   map<string, string> options = 8;
   datafusion_common.Constraints constraints = 12;
   map<string, LogicalExprNode> column_defaults = 13;
- }
+}
 
 message PrepareNode {
   string name = 1;
@@ -249,24 +249,24 @@ message DistinctOnNode {
 }
 
 message CopyToNode {
-    LogicalPlanNode input = 1;
-    string output_url = 2;
-    bytes file_type = 3;
-    repeated string partition_by = 7;
+  LogicalPlanNode input = 1;
+  string output_url = 2;
+  bytes file_type = 3;
+  repeated string partition_by = 7;
 }
 
 message UnnestNode {
-    LogicalPlanNode input = 1;
-    repeated datafusion_common.Column exec_columns = 2;
-    repeated uint64 list_type_columns = 3;
-    repeated uint64 struct_type_columns = 4;
-    repeated uint64 dependency_indices = 5;
-    datafusion_common.DfSchema schema = 6;
-    UnnestOptions options = 7;
+  LogicalPlanNode input = 1;
+  repeated datafusion_common.Column exec_columns = 2;
+  repeated uint64 list_type_columns = 3;
+  repeated uint64 struct_type_columns = 4;
+  repeated uint64 dependency_indices = 5;
+  datafusion_common.DfSchema schema = 6;
+  UnnestOptions options = 7;
 }
 
 message UnnestOptions {
-    bool preserve_nulls = 1;
+  bool preserve_nulls = 1;
 }
 
 message UnionNode {
@@ -488,8 +488,8 @@ enum AggregateFunction {
   // BIT_AND = 19;
   // BIT_OR = 20;
   // BIT_XOR = 21;
-//  BOOL_AND = 22;
-//  BOOL_OR = 23;
+  //  BOOL_AND = 22;
+  //  BOOL_OR = 23;
   // REGR_SLOPE = 26;
   // REGR_INTERCEPT = 27;
   // REGR_COUNT = 28;
@@ -517,6 +517,7 @@ message AggregateUDFExprNode {
   bool distinct = 5;
   LogicalExprNode filter = 3;
   repeated LogicalExprNode order_by = 4;
+  optional bytes fun_definition = 6;
 }
 
 message ScalarUDFExprNode {
@@ -551,6 +552,7 @@ message WindowExprNode {
   repeated LogicalExprNode order_by = 6;
   // repeated LogicalExprNode filter = 7;
   WindowFrame window_frame = 8;
+  optional bytes fun_definition = 10;
 }
 
 message BetweenNode {
@@ -856,6 +858,8 @@ message PhysicalAggregateExprNode {
   repeated PhysicalExprNode expr = 2;
   repeated PhysicalSortExprNode ordering_req = 5;
   bool distinct = 3;
+  bool ignore_nulls = 6;
+  optional bytes fun_definition = 7;
 }
 
 message PhysicalWindowExprNode {
@@ -869,6 +873,7 @@ message PhysicalWindowExprNode {
   repeated PhysicalSortExprNode order_by = 6;
   WindowFrame window_frame = 7;
   string name = 8;
+  optional bytes fun_definition = 9;
 }
 
 message PhysicalIsNull {
diff --git a/datafusion/proto/src/generated/pbjson.rs 
b/datafusion/proto/src/generated/pbjson.rs
index fa989480fa..8f77c24bd9 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -829,6 +829,9 @@ impl serde::Serialize for AggregateUdfExprNode {
         if !self.order_by.is_empty() {
             len += 1;
         }
+        if self.fun_definition.is_some() {
+            len += 1;
+        }
         let mut struct_ser = 
serializer.serialize_struct("datafusion.AggregateUDFExprNode", len)?;
         if !self.fun_name.is_empty() {
             struct_ser.serialize_field("funName", &self.fun_name)?;
@@ -845,6 +848,10 @@ impl serde::Serialize for AggregateUdfExprNode {
         if !self.order_by.is_empty() {
             struct_ser.serialize_field("orderBy", &self.order_by)?;
         }
+        if let Some(v) = self.fun_definition.as_ref() {
+            #[allow(clippy::needless_borrow)]
+            struct_ser.serialize_field("funDefinition", 
pbjson::private::base64::encode(&v).as_str())?;
+        }
         struct_ser.end()
     }
 }
@@ -862,6 +869,8 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
             "filter",
             "order_by",
             "orderBy",
+            "fun_definition",
+            "funDefinition",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -871,6 +880,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
             Distinct,
             Filter,
             OrderBy,
+            FunDefinition,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
@@ -897,6 +907,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
                             "distinct" => Ok(GeneratedField::Distinct),
                             "filter" => Ok(GeneratedField::Filter),
                             "orderBy" | "order_by" => 
Ok(GeneratedField::OrderBy),
+                            "funDefinition" | "fun_definition" => 
Ok(GeneratedField::FunDefinition),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
                         }
                     }
@@ -921,6 +932,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
                 let mut distinct__ = None;
                 let mut filter__ = None;
                 let mut order_by__ = None;
+                let mut fun_definition__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
                         GeneratedField::FunName => {
@@ -953,6 +965,14 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode 
{
                             }
                             order_by__ = Some(map_.next_value()?);
                         }
+                        GeneratedField::FunDefinition => {
+                            if fun_definition__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("funDefinition"));
+                            }
+                            fun_definition__ = 
+                                
map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x|
 x.0)
+                            ;
+                        }
                     }
                 }
                 Ok(AggregateUdfExprNode {
@@ -961,6 +981,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
                     distinct: distinct__.unwrap_or_default(),
                     filter: filter__,
                     order_by: order_by__.unwrap_or_default(),
+                    fun_definition: fun_definition__,
                 })
             }
         }
@@ -12631,6 +12652,12 @@ impl serde::Serialize for PhysicalAggregateExprNode {
         if self.distinct {
             len += 1;
         }
+        if self.ignore_nulls {
+            len += 1;
+        }
+        if self.fun_definition.is_some() {
+            len += 1;
+        }
         if self.aggregate_function.is_some() {
             len += 1;
         }
@@ -12644,6 +12671,13 @@ impl serde::Serialize for PhysicalAggregateExprNode {
         if self.distinct {
             struct_ser.serialize_field("distinct", &self.distinct)?;
         }
+        if self.ignore_nulls {
+            struct_ser.serialize_field("ignoreNulls", &self.ignore_nulls)?;
+        }
+        if let Some(v) = self.fun_definition.as_ref() {
+            #[allow(clippy::needless_borrow)]
+            struct_ser.serialize_field("funDefinition", 
pbjson::private::base64::encode(&v).as_str())?;
+        }
         if let Some(v) = self.aggregate_function.as_ref() {
             match v {
                 
physical_aggregate_expr_node::AggregateFunction::AggrFunction(v) => {
@@ -12670,6 +12704,10 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalAggregateExprNode {
             "ordering_req",
             "orderingReq",
             "distinct",
+            "ignore_nulls",
+            "ignoreNulls",
+            "fun_definition",
+            "funDefinition",
             "aggr_function",
             "aggrFunction",
             "user_defined_aggr_function",
@@ -12681,6 +12719,8 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalAggregateExprNode {
             Expr,
             OrderingReq,
             Distinct,
+            IgnoreNulls,
+            FunDefinition,
             AggrFunction,
             UserDefinedAggrFunction,
         }
@@ -12707,6 +12747,8 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalAggregateExprNode {
                             "expr" => Ok(GeneratedField::Expr),
                             "orderingReq" | "ordering_req" => 
Ok(GeneratedField::OrderingReq),
                             "distinct" => Ok(GeneratedField::Distinct),
+                            "ignoreNulls" | "ignore_nulls" => 
Ok(GeneratedField::IgnoreNulls),
+                            "funDefinition" | "fun_definition" => 
Ok(GeneratedField::FunDefinition),
                             "aggrFunction" | "aggr_function" => 
Ok(GeneratedField::AggrFunction),
                             "userDefinedAggrFunction" | 
"user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
@@ -12731,6 +12773,8 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalAggregateExprNode {
                 let mut expr__ = None;
                 let mut ordering_req__ = None;
                 let mut distinct__ = None;
+                let mut ignore_nulls__ = None;
+                let mut fun_definition__ = None;
                 let mut aggregate_function__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
@@ -12752,6 +12796,20 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalAggregateExprNode {
                             }
                             distinct__ = Some(map_.next_value()?);
                         }
+                        GeneratedField::IgnoreNulls => {
+                            if ignore_nulls__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("ignoreNulls"));
+                            }
+                            ignore_nulls__ = Some(map_.next_value()?);
+                        }
+                        GeneratedField::FunDefinition => {
+                            if fun_definition__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("funDefinition"));
+                            }
+                            fun_definition__ = 
+                                
map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x|
 x.0)
+                            ;
+                        }
                         GeneratedField::AggrFunction => {
                             if aggregate_function__.is_some() {
                                 return 
Err(serde::de::Error::duplicate_field("aggrFunction"));
@@ -12770,6 +12828,8 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalAggregateExprNode {
                     expr: expr__.unwrap_or_default(),
                     ordering_req: ordering_req__.unwrap_or_default(),
                     distinct: distinct__.unwrap_or_default(),
+                    ignore_nulls: ignore_nulls__.unwrap_or_default(),
+                    fun_definition: fun_definition__,
                     aggregate_function: aggregate_function__,
                 })
             }
@@ -15832,6 +15892,9 @@ impl serde::Serialize for PhysicalWindowExprNode {
         if !self.name.is_empty() {
             len += 1;
         }
+        if self.fun_definition.is_some() {
+            len += 1;
+        }
         if self.window_function.is_some() {
             len += 1;
         }
@@ -15851,6 +15914,10 @@ impl serde::Serialize for PhysicalWindowExprNode {
         if !self.name.is_empty() {
             struct_ser.serialize_field("name", &self.name)?;
         }
+        if let Some(v) = self.fun_definition.as_ref() {
+            #[allow(clippy::needless_borrow)]
+            struct_ser.serialize_field("funDefinition", 
pbjson::private::base64::encode(&v).as_str())?;
+        }
         if let Some(v) = self.window_function.as_ref() {
             match v {
                 physical_window_expr_node::WindowFunction::AggrFunction(v) => {
@@ -15886,6 +15953,8 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalWindowExprNode {
             "window_frame",
             "windowFrame",
             "name",
+            "fun_definition",
+            "funDefinition",
             "aggr_function",
             "aggrFunction",
             "built_in_function",
@@ -15901,6 +15970,7 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalWindowExprNode {
             OrderBy,
             WindowFrame,
             Name,
+            FunDefinition,
             AggrFunction,
             BuiltInFunction,
             UserDefinedAggrFunction,
@@ -15930,6 +16000,7 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalWindowExprNode {
                             "orderBy" | "order_by" => 
Ok(GeneratedField::OrderBy),
                             "windowFrame" | "window_frame" => 
Ok(GeneratedField::WindowFrame),
                             "name" => Ok(GeneratedField::Name),
+                            "funDefinition" | "fun_definition" => 
Ok(GeneratedField::FunDefinition),
                             "aggrFunction" | "aggr_function" => 
Ok(GeneratedField::AggrFunction),
                             "builtInFunction" | "built_in_function" => 
Ok(GeneratedField::BuiltInFunction),
                             "userDefinedAggrFunction" | 
"user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction),
@@ -15957,6 +16028,7 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalWindowExprNode {
                 let mut order_by__ = None;
                 let mut window_frame__ = None;
                 let mut name__ = None;
+                let mut fun_definition__ = None;
                 let mut window_function__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
@@ -15990,6 +16062,14 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalWindowExprNode {
                             }
                             name__ = Some(map_.next_value()?);
                         }
+                        GeneratedField::FunDefinition => {
+                            if fun_definition__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("funDefinition"));
+                            }
+                            fun_definition__ = 
+                                
map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x|
 x.0)
+                            ;
+                        }
                         GeneratedField::AggrFunction => {
                             if window_function__.is_some() {
                                 return 
Err(serde::de::Error::duplicate_field("aggrFunction"));
@@ -16016,6 +16096,7 @@ impl<'de> serde::Deserialize<'de> for 
PhysicalWindowExprNode {
                     order_by: order_by__.unwrap_or_default(),
                     window_frame: window_frame__,
                     name: name__.unwrap_or_default(),
+                    fun_definition: fun_definition__,
                     window_function: window_function__,
                 })
             }
@@ -20349,6 +20430,9 @@ impl serde::Serialize for WindowExprNode {
         if self.window_frame.is_some() {
             len += 1;
         }
+        if self.fun_definition.is_some() {
+            len += 1;
+        }
         if self.window_function.is_some() {
             len += 1;
         }
@@ -20365,6 +20449,10 @@ impl serde::Serialize for WindowExprNode {
         if let Some(v) = self.window_frame.as_ref() {
             struct_ser.serialize_field("windowFrame", v)?;
         }
+        if let Some(v) = self.fun_definition.as_ref() {
+            #[allow(clippy::needless_borrow)]
+            struct_ser.serialize_field("funDefinition", 
pbjson::private::base64::encode(&v).as_str())?;
+        }
         if let Some(v) = self.window_function.as_ref() {
             match v {
                 window_expr_node::WindowFunction::AggrFunction(v) => {
@@ -20402,6 +20490,8 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode {
             "orderBy",
             "window_frame",
             "windowFrame",
+            "fun_definition",
+            "funDefinition",
             "aggr_function",
             "aggrFunction",
             "built_in_function",
@@ -20416,6 +20506,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode {
             PartitionBy,
             OrderBy,
             WindowFrame,
+            FunDefinition,
             AggrFunction,
             BuiltInFunction,
             Udaf,
@@ -20445,6 +20536,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode {
                             "partitionBy" | "partition_by" => 
Ok(GeneratedField::PartitionBy),
                             "orderBy" | "order_by" => 
Ok(GeneratedField::OrderBy),
                             "windowFrame" | "window_frame" => 
Ok(GeneratedField::WindowFrame),
+                            "funDefinition" | "fun_definition" => 
Ok(GeneratedField::FunDefinition),
                             "aggrFunction" | "aggr_function" => 
Ok(GeneratedField::AggrFunction),
                             "builtInFunction" | "built_in_function" => 
Ok(GeneratedField::BuiltInFunction),
                             "udaf" => Ok(GeneratedField::Udaf),
@@ -20472,6 +20564,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode {
                 let mut partition_by__ = None;
                 let mut order_by__ = None;
                 let mut window_frame__ = None;
+                let mut fun_definition__ = None;
                 let mut window_function__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
@@ -20499,6 +20592,14 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode {
                             }
                             window_frame__ = map_.next_value()?;
                         }
+                        GeneratedField::FunDefinition => {
+                            if fun_definition__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("funDefinition"));
+                            }
+                            fun_definition__ = 
+                                
map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x|
 x.0)
+                            ;
+                        }
                         GeneratedField::AggrFunction => {
                             if window_function__.is_some() {
                                 return 
Err(serde::de::Error::duplicate_field("aggrFunction"));
@@ -20530,6 +20631,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode {
                     partition_by: partition_by__.unwrap_or_default(),
                     order_by: order_by__.unwrap_or_default(),
                     window_frame: window_frame__,
+                    fun_definition: fun_definition__,
                     window_function: window_function__,
                 })
             }
diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs
index 8407e545fe..605c56fa94 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -756,6 +756,8 @@ pub struct AggregateUdfExprNode {
     pub filter: 
::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
     #[prost(message, repeated, tag = "4")]
     pub order_by: ::prost::alloc::vec::Vec<LogicalExprNode>,
+    #[prost(bytes = "vec", optional, tag = "6")]
+    pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec<u8>>,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
@@ -779,6 +781,8 @@ pub struct WindowExprNode {
     /// repeated LogicalExprNode filter = 7;
     #[prost(message, optional, tag = "8")]
     pub window_frame: ::core::option::Option<WindowFrame>,
+    #[prost(bytes = "vec", optional, tag = "10")]
+    pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec<u8>>,
     #[prost(oneof = "window_expr_node::WindowFunction", tags = "1, 2, 3, 9")]
     pub window_function: 
::core::option::Option<window_expr_node::WindowFunction>,
 }
@@ -1291,6 +1295,10 @@ pub struct PhysicalAggregateExprNode {
     pub ordering_req: ::prost::alloc::vec::Vec<PhysicalSortExprNode>,
     #[prost(bool, tag = "3")]
     pub distinct: bool,
+    #[prost(bool, tag = "6")]
+    pub ignore_nulls: bool,
+    #[prost(bytes = "vec", optional, tag = "7")]
+    pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec<u8>>,
     #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = 
"1, 4")]
     pub aggregate_function: ::core::option::Option<
         physical_aggregate_expr_node::AggregateFunction,
@@ -1320,6 +1328,8 @@ pub struct PhysicalWindowExprNode {
     pub window_frame: ::core::option::Option<WindowFrame>,
     #[prost(string, tag = "8")]
     pub name: ::prost::alloc::string::String,
+    #[prost(bytes = "vec", optional, tag = "9")]
+    pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec<u8>>,
     #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "1, 2, 
3")]
     pub window_function: ::core::option::Option<
         physical_window_expr_node::WindowFunction,
diff --git a/datafusion/proto/src/logical_plan/file_formats.rs 
b/datafusion/proto/src/logical_plan/file_formats.rs
index 106d563948..09e36a650b 100644
--- a/datafusion/proto/src/logical_plan/file_formats.rs
+++ b/datafusion/proto/src/logical_plan/file_formats.rs
@@ -86,22 +86,6 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec {
     ) -> datafusion_common::Result<()> {
         Ok(())
     }
-
-    fn try_decode_udf(
-        &self,
-        name: &str,
-        __buf: &[u8],
-    ) -> datafusion_common::Result<Arc<datafusion_expr::ScalarUDF>> {
-        not_impl_err!("LogicalExtensionCodec is not provided for scalar 
function {name}")
-    }
-
-    fn try_encode_udf(
-        &self,
-        __node: &datafusion_expr::ScalarUDF,
-        __buf: &mut Vec<u8>,
-    ) -> datafusion_common::Result<()> {
-        Ok(())
-    }
 }
 
 #[derive(Debug)]
@@ -162,22 +146,6 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec {
     ) -> datafusion_common::Result<()> {
         Ok(())
     }
-
-    fn try_decode_udf(
-        &self,
-        name: &str,
-        __buf: &[u8],
-    ) -> datafusion_common::Result<Arc<datafusion_expr::ScalarUDF>> {
-        not_impl_err!("LogicalExtensionCodec is not provided for scalar 
function {name}")
-    }
-
-    fn try_encode_udf(
-        &self,
-        __node: &datafusion_expr::ScalarUDF,
-        __buf: &mut Vec<u8>,
-    ) -> datafusion_common::Result<()> {
-        Ok(())
-    }
 }
 
 #[derive(Debug)]
@@ -238,22 +206,6 @@ impl LogicalExtensionCodec for 
ParquetLogicalExtensionCodec {
     ) -> datafusion_common::Result<()> {
         Ok(())
     }
-
-    fn try_decode_udf(
-        &self,
-        name: &str,
-        __buf: &[u8],
-    ) -> datafusion_common::Result<Arc<datafusion_expr::ScalarUDF>> {
-        not_impl_err!("LogicalExtensionCodec is not provided for scalar 
function {name}")
-    }
-
-    fn try_encode_udf(
-        &self,
-        __node: &datafusion_expr::ScalarUDF,
-        __buf: &mut Vec<u8>,
-    ) -> datafusion_common::Result<()> {
-        Ok(())
-    }
 }
 
 #[derive(Debug)]
@@ -314,22 +266,6 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec {
     ) -> datafusion_common::Result<()> {
         Ok(())
     }
-
-    fn try_decode_udf(
-        &self,
-        name: &str,
-        __buf: &[u8],
-    ) -> datafusion_common::Result<Arc<datafusion_expr::ScalarUDF>> {
-        not_impl_err!("LogicalExtensionCodec is not provided for scalar 
function {name}")
-    }
-
-    fn try_encode_udf(
-        &self,
-        __node: &datafusion_expr::ScalarUDF,
-        __buf: &mut Vec<u8>,
-    ) -> datafusion_common::Result<()> {
-        Ok(())
-    }
 }
 
 #[derive(Debug)]
@@ -390,20 +326,4 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec {
     ) -> datafusion_common::Result<()> {
         Ok(())
     }
-
-    fn try_decode_udf(
-        &self,
-        name: &str,
-        __buf: &[u8],
-    ) -> datafusion_common::Result<Arc<datafusion_expr::ScalarUDF>> {
-        not_impl_err!("LogicalExtensionCodec is not provided for scalar 
function {name}")
-    }
-
-    fn try_encode_udf(
-        &self,
-        __node: &datafusion_expr::ScalarUDF,
-        __buf: &mut Vec<u8>,
-    ) -> datafusion_common::Result<()> {
-        Ok(())
-    }
 }
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs 
b/datafusion/proto/src/logical_plan/from_proto.rs
index 095c6a5097..b6b556a8ed 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -308,14 +308,17 @@ pub fn parse_expr(
                     let aggr_function = parse_i32_to_aggregate_function(i)?;
 
                     Ok(Expr::WindowFunction(WindowFunction::new(
-                        
datafusion_expr::expr::WindowFunctionDefinition::AggregateFunction(
-                            aggr_function,
-                        ),
-                        vec![parse_required_expr(expr.expr.as_deref(), 
registry, "expr", codec)?],
+                        
expr::WindowFunctionDefinition::AggregateFunction(aggr_function),
+                        vec![parse_required_expr(
+                            expr.expr.as_deref(),
+                            registry,
+                            "expr",
+                            codec,
+                        )?],
                         partition_by,
                         order_by,
                         window_frame,
-                        None
+                        None,
                     )))
                 }
                 window_expr_node::WindowFunction::BuiltInFunction(i) => {
@@ -329,26 +332,28 @@ pub fn parse_expr(
                             .unwrap_or_else(Vec::new);
 
                     Ok(Expr::WindowFunction(WindowFunction::new(
-                        
datafusion_expr::expr::WindowFunctionDefinition::BuiltInWindowFunction(
+                        expr::WindowFunctionDefinition::BuiltInWindowFunction(
                             built_in_function,
                         ),
                         args,
                         partition_by,
                         order_by,
                         window_frame,
-                        null_treatment
+                        null_treatment,
                     )))
                 }
                 window_expr_node::WindowFunction::Udaf(udaf_name) => {
-                    let udaf_function = registry.udaf(udaf_name)?;
+                    let udaf_function = match &expr.fun_definition {
+                        Some(buf) => codec.try_decode_udaf(udaf_name, buf)?,
+                        None => registry.udaf(udaf_name)?,
+                    };
+
                     let args =
                         parse_optional_expr(expr.expr.as_deref(), registry, 
codec)?
                             .map(|e| vec![e])
                             .unwrap_or_else(Vec::new);
                     Ok(Expr::WindowFunction(WindowFunction::new(
-                        
datafusion_expr::expr::WindowFunctionDefinition::AggregateUDF(
-                            udaf_function,
-                        ),
+                        
expr::WindowFunctionDefinition::AggregateUDF(udaf_function),
                         args,
                         partition_by,
                         order_by,
@@ -357,15 +362,17 @@ pub fn parse_expr(
                     )))
                 }
                 window_expr_node::WindowFunction::Udwf(udwf_name) => {
-                    let udwf_function = registry.udwf(udwf_name)?;
+                    let udwf_function = match &expr.fun_definition {
+                        Some(buf) => codec.try_decode_udwf(udwf_name, buf)?,
+                        None => registry.udwf(udwf_name)?,
+                    };
+
                     let args =
                         parse_optional_expr(expr.expr.as_deref(), registry, 
codec)?
                             .map(|e| vec![e])
                             .unwrap_or_else(Vec::new);
                     Ok(Expr::WindowFunction(WindowFunction::new(
-                        
datafusion_expr::expr::WindowFunctionDefinition::WindowUDF(
-                            udwf_function,
-                        ),
+                        
expr::WindowFunctionDefinition::WindowUDF(udwf_function),
                         args,
                         partition_by,
                         order_by,
@@ -613,7 +620,10 @@ pub fn parse_expr(
             )))
         }
         ExprType::AggregateUdfExpr(pb) => {
-            let agg_fn = registry.udaf(pb.fun_name.as_str())?;
+            let agg_fn = match &pb.fun_definition {
+                Some(buf) => codec.try_decode_udaf(&pb.fun_name, buf)?,
+                None => registry.udaf(&pb.fun_name)?,
+            };
 
             Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
                 agg_fn,
diff --git a/datafusion/proto/src/logical_plan/mod.rs 
b/datafusion/proto/src/logical_plan/mod.rs
index 664cd7e115..2a963fb13c 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -51,7 +51,6 @@ use datafusion_common::{
     context, internal_datafusion_err, internal_err, not_impl_err, 
DataFusionError,
     Result, TableReference,
 };
-use datafusion_expr::Unnest;
 use datafusion_expr::{
     dml,
     logical_plan::{
@@ -60,8 +59,9 @@ use datafusion_expr::{
         EmptyRelation, Extension, Join, JoinConstraint, Limit, Prepare, 
Projection,
         Repartition, Sort, SubqueryAlias, TableScan, Values, Window,
     },
-    DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF,
+    DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, 
WindowUDF,
 };
+use datafusion_expr::{AggregateUDF, Unnest};
 
 use prost::bytes::BufMut;
 use prost::Message;
@@ -144,6 +144,24 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync {
     fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec<u8>) -> 
Result<()> {
         Ok(())
     }
+
+    fn try_decode_udaf(&self, name: &str, _buf: &[u8]) -> 
Result<Arc<AggregateUDF>> {
+        not_impl_err!(
+            "LogicalExtensionCodec is not provided for aggregate function 
{name}"
+        )
+    }
+
+    fn try_encode_udaf(&self, _node: &AggregateUDF, _buf: &mut Vec<u8>) -> 
Result<()> {
+        Ok(())
+    }
+
+    fn try_decode_udwf(&self, name: &str, _buf: &[u8]) -> 
Result<Arc<WindowUDF>> {
+        not_impl_err!("LogicalExtensionCodec is not provided for window 
function {name}")
+    }
+
+    fn try_encode_udwf(&self, _node: &WindowUDF, _buf: &mut Vec<u8>) -> 
Result<()> {
+        Ok(())
+    }
 }
 
 #[derive(Debug, Clone)]
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs 
b/datafusion/proto/src/logical_plan/to_proto.rs
index d8f8ea002b..9607b918eb 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -319,25 +319,37 @@ pub fn serialize_expr(
             // TODO: support null treatment in proto
             null_treatment: _,
         }) => {
-            let window_function = match fun {
-                WindowFunctionDefinition::AggregateFunction(fun) => {
+            let (window_function, fun_definition) = match fun {
+                WindowFunctionDefinition::AggregateFunction(fun) => (
                     protobuf::window_expr_node::WindowFunction::AggrFunction(
                         protobuf::AggregateFunction::from(fun).into(),
-                    )
-                }
-                WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
+                    ),
+                    None,
+                ),
+                WindowFunctionDefinition::BuiltInWindowFunction(fun) => (
                     
protobuf::window_expr_node::WindowFunction::BuiltInFunction(
                         protobuf::BuiltInWindowFunction::from(fun).into(),
-                    )
-                }
+                    ),
+                    None,
+                ),
                 WindowFunctionDefinition::AggregateUDF(aggr_udf) => {
-                    protobuf::window_expr_node::WindowFunction::Udaf(
-                        aggr_udf.name().to_string(),
+                    let mut buf = Vec::new();
+                    let _ = codec.try_encode_udaf(aggr_udf, &mut buf);
+                    (
+                        protobuf::window_expr_node::WindowFunction::Udaf(
+                            aggr_udf.name().to_string(),
+                        ),
+                        (!buf.is_empty()).then_some(buf),
                     )
                 }
                 WindowFunctionDefinition::WindowUDF(window_udf) => {
-                    protobuf::window_expr_node::WindowFunction::Udwf(
-                        window_udf.name().to_string(),
+                    let mut buf = Vec::new();
+                    let _ = codec.try_encode_udwf(window_udf, &mut buf);
+                    (
+                        protobuf::window_expr_node::WindowFunction::Udwf(
+                            window_udf.name().to_string(),
+                        ),
+                        (!buf.is_empty()).then_some(buf),
                     )
                 }
             };
@@ -358,6 +370,7 @@ pub fn serialize_expr(
                 partition_by,
                 order_by,
                 window_frame,
+                fun_definition,
             });
             protobuf::LogicalExprNode {
                 expr_type: Some(ExprType::WindowExpr(window_expr)),
@@ -395,23 +408,30 @@ pub fn serialize_expr(
                     expr_type: 
Some(ExprType::AggregateExpr(Box::new(aggregate_expr))),
                 }
             }
-            AggregateFunctionDefinition::UDF(fun) => protobuf::LogicalExprNode 
{
-                expr_type: Some(ExprType::AggregateUdfExpr(Box::new(
-                    protobuf::AggregateUdfExprNode {
-                        fun_name: fun.name().to_string(),
-                        args: serialize_exprs(args, codec)?,
-                        distinct: *distinct,
-                        filter: match filter {
-                            Some(e) => 
Some(Box::new(serialize_expr(e.as_ref(), codec)?)),
-                            None => None,
-                        },
-                        order_by: match order_by {
-                            Some(e) => serialize_exprs(e, codec)?,
-                            None => vec![],
+            AggregateFunctionDefinition::UDF(fun) => {
+                let mut buf = Vec::new();
+                let _ = codec.try_encode_udaf(fun, &mut buf);
+                protobuf::LogicalExprNode {
+                    expr_type: Some(ExprType::AggregateUdfExpr(Box::new(
+                        protobuf::AggregateUdfExprNode {
+                            fun_name: fun.name().to_string(),
+                            args: serialize_exprs(args, codec)?,
+                            distinct: *distinct,
+                            filter: match filter {
+                                Some(e) => {
+                                    Some(Box::new(serialize_expr(e.as_ref(), 
codec)?))
+                                }
+                                None => None,
+                            },
+                            order_by: match order_by {
+                                Some(e) => serialize_exprs(e, codec)?,
+                                None => vec![],
+                            },
+                            fun_definition: (!buf.is_empty()).then_some(buf),
                         },
-                    },
-                ))),
-            },
+                    ))),
+                }
+            }
         },
 
         Expr::ScalarVariable(_, _) => {
@@ -420,17 +440,13 @@ pub fn serialize_expr(
             ))
         }
         Expr::ScalarFunction(ScalarFunction { func, args }) => {
-            let args = serialize_exprs(args, codec)?;
             let mut buf = Vec::new();
-            let _ = codec.try_encode_udf(func.as_ref(), &mut buf);
-
-            let fun_definition = if buf.is_empty() { None } else { Some(buf) };
-
+            let _ = codec.try_encode_udf(func, &mut buf);
             protobuf::LogicalExprNode {
                 expr_type: 
Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode {
                     fun_name: func.name().to_string(),
-                    fun_definition,
-                    args,
+                    fun_definition: (!buf.is_empty()).then_some(buf),
+                    args: serialize_exprs(args, codec)?,
                 })),
             }
         }
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs 
b/datafusion/proto/src/physical_plan/from_proto.rs
index b7311c694d..5ecca51478 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -164,8 +164,10 @@ pub fn parse_physical_window_expr(
                 WindowFunctionDefinition::BuiltInWindowFunction(f.into())
             }
             
protobuf::physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(udaf_name)
 => {
-                let agg_udf = registry.udaf(udaf_name)?;
-                WindowFunctionDefinition::AggregateUDF(agg_udf)
+                WindowFunctionDefinition::AggregateUDF(match 
&proto.fun_definition {
+                    Some(buf) => codec.try_decode_udaf(udaf_name, buf)?,
+                    None => registry.udaf(udaf_name)?
+                })
             }
         }
     } else {
diff --git a/datafusion/proto/src/physical_plan/mod.rs 
b/datafusion/proto/src/physical_plan/mod.rs
index 948a39bfe0..1220f42ded 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -61,7 +61,7 @@ use datafusion::physical_plan::{
     udaf, AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr, 
WindowExpr,
 };
 use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
-use datafusion_expr::ScalarUDF;
+use datafusion_expr::{AggregateUDF, ScalarUDF};
 
 use crate::common::{byte_to_string, str_to_byte};
 use crate::convert_required;
@@ -491,19 +491,22 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                                                 &ordering_req,
                                                 &physical_schema,
                                                 name.to_string(),
-                                                false,
+                                                agg_node.ignore_nulls,
                                             )
                                         }
                                         
AggregateFunction::UserDefinedAggrFunction(udaf_name) => {
-                                            let agg_udf = 
registry.udaf(udaf_name)?;
+                                            let agg_udf = match 
&agg_node.fun_definition {
+                                                Some(buf) => 
extension_codec.try_decode_udaf(udaf_name, buf)?,
+                                                None => 
registry.udaf(udaf_name)?
+                                            };
+
                                             // TODO: 'logical_exprs' is not 
supported for UDAF yet.
                                             // approx_percentile_cont and 
approx_percentile_cont_weight are not supported for UDAF from protobuf yet.
                                             let logical_exprs = &[];
                                             // TODO: `order by` is not 
supported for UDAF yet
                                             let sort_exprs = &[];
                                             let ordering_req = &[];
-                                            let ignore_nulls = false;
-                                            
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, 
sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false)
+                                            
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, 
sort_exprs, ordering_req, &physical_schema, name, agg_node.ignore_nulls, 
agg_node.distinct)
                                         }
                                     }
                                 }).transpose()?.ok_or_else(|| {
@@ -2034,6 +2037,16 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync {
     ) -> Result<()> {
         not_impl_err!("PhysicalExtensionCodec is not provided")
     }
+
+    fn try_decode_udaf(&self, name: &str, _buf: &[u8]) -> 
Result<Arc<AggregateUDF>> {
+        not_impl_err!(
+            "PhysicalExtensionCodec is not provided for aggregate function 
{name}"
+        )
+    }
+
+    fn try_encode_udaf(&self, _node: &AggregateUDF, _buf: &mut Vec<u8>) -> 
Result<()> {
+        Ok(())
+    }
 }
 
 #[derive(Debug)]
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs 
b/datafusion/proto/src/physical_plan/to_proto.rs
index d8d0291e1c..7ea2902cf3 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -40,6 +40,7 @@ use datafusion::{
     physical_plan::expressions::LikeExpr,
 };
 use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
+use datafusion_expr::WindowFrame;
 
 use crate::protobuf::{
     self, physical_aggregate_expr_node, physical_window_expr_node, 
PhysicalSortExprNode,
@@ -58,13 +59,17 @@ pub fn serialize_physical_aggr_expr(
 
     if let Some(a) = 
aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
         let name = a.fun().name().to_string();
+        let mut buf = Vec::new();
+        codec.try_encode_udaf(a.fun(), &mut buf)?;
         return Ok(protobuf::PhysicalExprNode {
             expr_type: 
Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
                 protobuf::PhysicalAggregateExprNode {
                     aggregate_function: 
Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)),
                     expr: expressions,
                     ordering_req,
-                    distinct: false,
+                    distinct: a.is_distinct(),
+                    ignore_nulls: a.ignore_nulls(),
+                    fun_definition: (!buf.is_empty()).then_some(buf)
                 },
             )),
         });
@@ -86,11 +91,55 @@ pub fn serialize_physical_aggr_expr(
                 expr: expressions,
                 ordering_req,
                 distinct,
+                ignore_nulls: false,
+                fun_definition: None,
             },
         )),
     })
 }
 
+fn serialize_physical_window_aggr_expr(
+    aggr_expr: &dyn AggregateExpr,
+    window_frame: &WindowFrame,
+    codec: &dyn PhysicalExtensionCodec,
+) -> Result<(physical_window_expr_node::WindowFunction, Option<Vec<u8>>)> {
+    if let Some(a) = 
aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
+        if a.is_distinct() || a.ignore_nulls() {
+            // TODO
+            return not_impl_err!(
+                "Distinct aggregate functions not supported in window 
expressions"
+            );
+        }
+
+        let mut buf = Vec::new();
+        codec.try_encode_udaf(a.fun(), &mut buf)?;
+        Ok((
+            physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
+                a.fun().name().to_string(),
+            ),
+            (!buf.is_empty()).then_some(buf),
+        ))
+    } else {
+        let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(aggr_expr)?;
+        if distinct {
+            return not_impl_err!(
+                "Distinct aggregate functions not supported in window 
expressions"
+            );
+        }
+
+        if !window_frame.start_bound.is_unbounded() {
+            return Err(DataFusionError::Internal(format!(
+                "Unbounded start bound in WindowFrame = {window_frame}"
+            )));
+        }
+
+        Ok((
+            physical_window_expr_node::WindowFunction::AggrFunction(inner as 
i32),
+            None,
+        ))
+    }
+}
+
 pub fn serialize_physical_window_expr(
     window_expr: Arc<dyn WindowExpr>,
     codec: &dyn PhysicalExtensionCodec,
@@ -99,7 +148,7 @@ pub fn serialize_physical_window_expr(
     let mut args = window_expr.expressions().to_vec();
     let window_frame = window_expr.get_window_frame();
 
-    let window_function = if let Some(built_in_window_expr) =
+    let (window_function, fun_definition) = if let Some(built_in_window_expr) =
         expr.downcast_ref::<BuiltInWindowExpr>()
     {
         let expr = built_in_window_expr.get_built_in_func_expr();
@@ -160,58 +209,26 @@ pub fn serialize_physical_window_expr(
             return not_impl_err!("BuiltIn function not supported: {expr:?}");
         };
 
-        physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn 
as i32)
+        (
+            
physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn as i32),
+            None,
+        )
     } else if let Some(plain_aggr_window_expr) =
         expr.downcast_ref::<PlainAggregateWindowExpr>()
     {
-        let aggr_expr = plain_aggr_window_expr.get_aggregate_expr();
-        if let Some(a) = 
aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
-            physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
-                a.fun().name().to_string(),
-            )
-        } else {
-            let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
-                plain_aggr_window_expr.get_aggregate_expr().as_ref(),
-            )?;
-
-            if distinct {
-                return not_impl_err!(
-                    "Distinct aggregate functions not supported in window 
expressions"
-                );
-            }
-
-            if !window_frame.start_bound.is_unbounded() {
-                return Err(DataFusionError::Internal(format!("Invalid 
PlainAggregateWindowExpr = {window_expr:?} with WindowFrame = 
{window_frame:?}")));
-            }
-
-            physical_window_expr_node::WindowFunction::AggrFunction(inner as 
i32)
-        }
+        serialize_physical_window_aggr_expr(
+            plain_aggr_window_expr.get_aggregate_expr().as_ref(),
+            window_frame,
+            codec,
+        )?
     } else if let Some(sliding_aggr_window_expr) =
         expr.downcast_ref::<SlidingAggregateWindowExpr>()
     {
-        let aggr_expr = sliding_aggr_window_expr.get_aggregate_expr();
-        if let Some(a) = 
aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
-            physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
-                a.fun().name().to_string(),
-            )
-        } else {
-            let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
-                sliding_aggr_window_expr.get_aggregate_expr().as_ref(),
-            )?;
-
-            if distinct {
-                // TODO
-                return not_impl_err!(
-                    "Distinct aggregate functions not supported in window 
expressions"
-                );
-            }
-
-            if window_frame.start_bound.is_unbounded() {
-                return Err(DataFusionError::Internal(format!("Invalid 
SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame = 
{window_frame:?}")));
-            }
-
-            physical_window_expr_node::WindowFunction::AggrFunction(inner as 
i32)
-        }
+        serialize_physical_window_aggr_expr(
+            sliding_aggr_window_expr.get_aggregate_expr().as_ref(),
+            window_frame,
+            codec,
+        )?
     } else {
         return not_impl_err!("WindowExpr not supported: {window_expr:?}");
     };
@@ -232,6 +249,7 @@ pub fn serialize_physical_window_expr(
         window_frame: Some(window_frame),
         window_function: Some(window_function),
         name: window_expr.name().to_string(),
+        fun_definition,
     })
 }
 
@@ -461,18 +479,14 @@ pub fn serialize_physical_expr(
             ))),
         })
     } else if let Some(expr) = expr.downcast_ref::<ScalarFunctionExpr>() {
-        let args = serialize_physical_exprs(expr.args().to_vec(), codec)?;
-
         let mut buf = Vec::new();
         codec.try_encode_udf(expr.fun(), &mut buf)?;
-
-        let fun_definition = if buf.is_empty() { None } else { Some(buf) };
         Ok(protobuf::PhysicalExprNode {
             expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf(
                 protobuf::PhysicalScalarUdfNode {
                     name: expr.name().to_string(),
-                    args,
-                    fun_definition,
+                    args: serialize_physical_exprs(expr.args().to_vec(), 
codec)?,
+                    fun_definition: (!buf.is_empty()).then_some(buf),
                     return_type: Some(expr.return_type().try_into()?),
                 },
             )),
diff --git a/datafusion/proto/tests/cases/mod.rs 
b/datafusion/proto/tests/cases/mod.rs
index b17289205f..1f837b7f42 100644
--- a/datafusion/proto/tests/cases/mod.rs
+++ b/datafusion/proto/tests/cases/mod.rs
@@ -15,6 +15,105 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::any::Any;
+
+use arrow::datatypes::DataType;
+
+use datafusion_common::plan_err;
+use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::{
+    Accumulator, AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, Signature, 
Volatility,
+};
+
 mod roundtrip_logical_plan;
 mod roundtrip_physical_plan;
 mod serialize;
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+struct MyRegexUdf {
+    signature: Signature,
+    // regex as original string
+    pattern: String,
+}
+
+impl MyRegexUdf {
+    fn new(pattern: String) -> Self {
+        let signature = Signature::exact(vec![DataType::Utf8], 
Volatility::Immutable);
+        Self { signature, pattern }
+    }
+}
+
+/// Implement the ScalarUDFImpl trait for MyRegexUdf
+impl ScalarUDFImpl for MyRegexUdf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "regex_udf"
+    }
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+    fn return_type(&self, args: &[DataType]) -> 
datafusion_common::Result<DataType> {
+        if matches!(args, [DataType::Utf8]) {
+            Ok(DataType::Int64)
+        } else {
+            plan_err!("regex_udf only accepts Utf8 arguments")
+        }
+    }
+    fn invoke(
+        &self,
+        _args: &[ColumnarValue],
+    ) -> datafusion_common::Result<ColumnarValue> {
+        unimplemented!()
+    }
+}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct MyRegexUdfNode {
+    #[prost(string, tag = "1")]
+    pub pattern: String,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+struct MyAggregateUDF {
+    signature: Signature,
+    result: String,
+}
+
+impl MyAggregateUDF {
+    fn new(result: String) -> Self {
+        let signature = Signature::exact(vec![DataType::Int64], 
Volatility::Immutable);
+        Self { signature, result }
+    }
+}
+
+impl AggregateUDFImpl for MyAggregateUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "aggregate_udf"
+    }
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+    fn return_type(
+        &self,
+        _arg_types: &[DataType],
+    ) -> datafusion_common::Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+    fn accumulator(
+        &self,
+        _acc_args: AccumulatorArgs,
+    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
+        unimplemented!()
+    }
+}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct MyAggregateUdfNode {
+    #[prost(string, tag = "1")]
+    pub result: String,
+}
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index d0209d811b..0117502f40 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -28,15 +28,12 @@ use arrow::datatypes::{
     DataType, Field, Fields, Int32Type, IntervalDayTimeType, 
IntervalMonthDayNanoType,
     IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
 };
+use prost::Message;
+
 use datafusion::datasource::file_format::arrow::ArrowFormatFactory;
 use datafusion::datasource::file_format::csv::CsvFormatFactory;
 use datafusion::datasource::file_format::format_as_file_type;
 use datafusion::datasource::file_format::parquet::ParquetFormatFactory;
-use datafusion_proto::logical_plan::file_formats::{
-    ArrowLogicalExtensionCodec, CsvLogicalExtensionCodec, 
ParquetLogicalExtensionCodec,
-};
-use prost::Message;
-
 use datafusion::datasource::provider::TableProviderFactory;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::session_state::SessionStateBuilder;
@@ -62,9 +59,9 @@ use datafusion_expr::expr::{
 };
 use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore};
 use datafusion_expr::{
-    Accumulator, AggregateExt, AggregateFunction, ColumnarValue, ExprSchemable,
-    LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, 
Signature,
-    TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits,
+    Accumulator, AggregateExt, AggregateFunction, AggregateUDF, ColumnarValue,
+    ExprSchemable, Literal, LogicalPlan, Operator, PartitionEvaluator, 
ScalarUDF,
+    Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, 
WindowFrameUnits,
     WindowFunctionDefinition, WindowUDF, WindowUDFImpl,
 };
 use datafusion_functions_aggregate::average::avg_udaf;
@@ -76,12 +73,17 @@ use datafusion_proto::bytes::{
     logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec,
     logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec,
 };
+use datafusion_proto::logical_plan::file_formats::{
+    ArrowLogicalExtensionCodec, CsvLogicalExtensionCodec, 
ParquetLogicalExtensionCodec,
+};
 use datafusion_proto::logical_plan::to_proto::serialize_expr;
 use datafusion_proto::logical_plan::{
     from_proto, DefaultLogicalExtensionCodec, LogicalExtensionCodec,
 };
 use datafusion_proto::protobuf;
 
+use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, 
MyRegexUdfNode};
+
 #[cfg(feature = "json")]
 fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) {
     let string = serde_json::to_string(proto).unwrap();
@@ -744,7 +746,7 @@ pub mod proto {
         pub k: u64,
 
         #[prost(message, optional, tag = "2")]
-        pub expr: 
::core::option::Option<datafusion_proto::protobuf::LogicalExprNode>,
+        pub expr: Option<datafusion_proto::protobuf::LogicalExprNode>,
     }
 
     #[derive(Clone, PartialEq, Eq, ::prost::Message)]
@@ -752,12 +754,6 @@ pub mod proto {
         #[prost(uint64, tag = "1")]
         pub k: u64,
     }
-
-    #[derive(Clone, PartialEq, ::prost::Message)]
-    pub struct MyRegexUdfNode {
-        #[prost(string, tag = "1")]
-        pub pattern: String,
-    }
 }
 
 #[derive(PartialEq, Eq, Hash)]
@@ -890,51 +886,9 @@ impl LogicalExtensionCodec for TopKExtensionCodec {
 }
 
 #[derive(Debug)]
-struct MyRegexUdf {
-    signature: Signature,
-    // regex as original string
-    pattern: String,
-}
-
-impl MyRegexUdf {
-    fn new(pattern: String) -> Self {
-        Self {
-            signature: Signature::uniform(
-                1,
-                vec![DataType::Int32],
-                Volatility::Immutable,
-            ),
-            pattern,
-        }
-    }
-}
-
-/// Implement the ScalarUDFImpl trait for MyRegexUdf
-impl ScalarUDFImpl for MyRegexUdf {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-    fn name(&self) -> &str {
-        "regex_udf"
-    }
-    fn signature(&self) -> &Signature {
-        &self.signature
-    }
-    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
-        if !matches!(args.first(), Some(&DataType::Utf8)) {
-            return plan_err!("regex_udf only accepts Utf8 arguments");
-        }
-        Ok(DataType::Int32)
-    }
-    fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
-        unimplemented!()
-    }
-}
-
-#[derive(Debug)]
-pub struct ScalarUDFExtensionCodec {}
+pub struct UDFExtensionCodec;
 
-impl LogicalExtensionCodec for ScalarUDFExtensionCodec {
+impl LogicalExtensionCodec for UDFExtensionCodec {
     fn try_decode(
         &self,
         _buf: &[u8],
@@ -969,13 +923,11 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec {
 
     fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> 
{
         if name == "regex_udf" {
-            let proto = proto::MyRegexUdfNode::decode(buf).map_err(|err| {
-                DataFusionError::Internal(format!("failed to decode regex_udf: 
{}", err))
+            let proto = MyRegexUdfNode::decode(buf).map_err(|err| {
+                DataFusionError::Internal(format!("failed to decode regex_udf: 
{err}"))
             })?;
 
-            Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new(
-                proto.pattern,
-            ))))
+            Ok(Arc::new(ScalarUDF::from(MyRegexUdf::new(proto.pattern))))
         } else {
             not_impl_err!("unrecognized scalar UDF implementation, cannot 
decode")
         }
@@ -984,11 +936,39 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec {
     fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> 
Result<()> {
         let binding = node.inner();
         let udf = binding.as_any().downcast_ref::<MyRegexUdf>().unwrap();
-        let proto = proto::MyRegexUdfNode {
+        let proto = MyRegexUdfNode {
             pattern: udf.pattern.clone(),
         };
-        proto.encode(buf).map_err(|e| {
-            DataFusionError::Internal(format!("failed to encode udf: {e:?}"))
+        proto.encode(buf).map_err(|err| {
+            DataFusionError::Internal(format!("failed to encode udf: {err}"))
+        })?;
+        Ok(())
+    }
+
+    fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> 
Result<Arc<AggregateUDF>> {
+        if name == "aggregate_udf" {
+            let proto = MyAggregateUdfNode::decode(buf).map_err(|err| {
+                DataFusionError::Internal(format!(
+                    "failed to decode aggregate_udf: {err}"
+                ))
+            })?;
+
+            Ok(Arc::new(AggregateUDF::from(MyAggregateUDF::new(
+                proto.result,
+            ))))
+        } else {
+            not_impl_err!("unrecognized aggregate UDF implementation, cannot 
decode")
+        }
+    }
+
+    fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> 
Result<()> {
+        let binding = node.inner();
+        let udf = binding.as_any().downcast_ref::<MyAggregateUDF>().unwrap();
+        let proto = MyAggregateUdfNode {
+            result: udf.result.clone(),
+        };
+        proto.encode(buf).map_err(|err| {
+            DataFusionError::Internal(format!("failed to encode aggregate_udf: {err}"))
         })?;
         Ok(())
     }
@@ -1563,8 +1543,7 @@ fn roundtrip_null_scalar_values() {
 
     for test_case in test_types.into_iter() {
         let proto_scalar: protobuf::ScalarValue = 
(&test_case).try_into().unwrap();
-        let returned_scalar: datafusion::scalar::ScalarValue =
-            (&proto_scalar).try_into().unwrap();
+        let returned_scalar: ScalarValue = (&proto_scalar).try_into().unwrap();
         assert_eq!(format!("{:?}", &test_case), 
format!("{returned_scalar:?}"));
     }
 }
@@ -1893,22 +1872,19 @@ fn roundtrip_aggregate_udf() {
     struct Dummy {}
 
     impl Accumulator for Dummy {
-        fn state(&mut self) -> datafusion::error::Result<Vec<ScalarValue>> {
+        fn state(&mut self) -> Result<Vec<ScalarValue>> {
             Ok(vec![])
         }
 
-        fn update_batch(
-            &mut self,
-            _values: &[ArrayRef],
-        ) -> datafusion::error::Result<()> {
+        fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
             Ok(())
         }
 
-        fn merge_batch(&mut self, _states: &[ArrayRef]) -> 
datafusion::error::Result<()> {
+        fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> {
             Ok(())
         }
 
-        fn evaluate(&mut self) -> datafusion::error::Result<ScalarValue> {
+        fn evaluate(&mut self) -> Result<ScalarValue> {
             Ok(ScalarValue::Float64(None))
         }
 
@@ -1976,25 +1952,27 @@ fn roundtrip_scalar_udf() {
 
 #[test]
 fn roundtrip_scalar_udf_extension_codec() {
-    let pattern = ".*";
-    let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string()));
-    let test_expr =
-        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf.clone()), 
vec![]));
-
+    let udf = ScalarUDF::from(MyRegexUdf::new(".*".to_owned()));
+    let test_expr = udf.call(vec!["foo".lit()]);
     let ctx = SessionContext::new();
-    ctx.register_udf(udf);
-
-    let extension_codec = ScalarUDFExtensionCodec {};
-    let proto: protobuf::LogicalExprNode =
-        match serialize_expr(&test_expr, &extension_codec) {
-            Ok(p) => p,
-            Err(e) => panic!("Error serializing expression: {:?}", e),
-        };
-    let round_trip: Expr =
-        from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap();
+    let proto = serialize_expr(&test_expr, 
&UDFExtensionCodec).expect("serialize expr");
+    let round_trip =
+        from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse 
expr");
 
     assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}"));
+    roundtrip_json_test(&proto);
+}
+
+#[test]
+fn roundtrip_aggregate_udf_extension_codec() {
+    let udf = AggregateUDF::from(MyAggregateUDF::new("DataFusion".to_owned()));
+    let test_expr = udf.call(vec![42.lit()]);
+    let ctx = SessionContext::new();
+    let proto = serialize_expr(&test_expr, 
&UDFExtensionCodec).expect("serialize expr");
+    let round_trip =
+        from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse 
expr");
 
+    assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}"));
     roundtrip_json_test(&proto);
 }
 
@@ -2120,22 +2098,19 @@ fn roundtrip_window() {
     struct DummyAggr {}
 
     impl Accumulator for DummyAggr {
-        fn state(&mut self) -> datafusion::error::Result<Vec<ScalarValue>> {
+        fn state(&mut self) -> Result<Vec<ScalarValue>> {
             Ok(vec![])
         }
 
-        fn update_batch(
-            &mut self,
-            _values: &[ArrayRef],
-        ) -> datafusion::error::Result<()> {
+        fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
             Ok(())
         }
 
-        fn merge_batch(&mut self, _states: &[ArrayRef]) -> 
datafusion::error::Result<()> {
+        fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> {
             Ok(())
         }
 
-        fn evaluate(&mut self) -> datafusion::error::Result<ScalarValue> {
+        fn evaluate(&mut self) -> Result<ScalarValue> {
             Ok(ScalarValue::Float64(None))
         }
 
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 2fcc65008f..fba6dfe425 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::array::RecordBatch;
 use std::any::Any;
 use std::fmt::Display;
 use std::hash::Hasher;
@@ -23,8 +22,8 @@ use std::ops::Deref;
 use std::sync::Arc;
 use std::vec;
 
+use arrow::array::RecordBatch;
 use arrow::csv::WriterBuilder;
-use datafusion::functions_aggregate::sum::sum_udaf;
 use prost::Message;
 
 use datafusion::arrow::array::ArrayRef;
@@ -40,9 +39,10 @@ use datafusion::datasource::physical_plan::{
     FileSinkConfig, ParquetExec,
 };
 use datafusion::execution::FunctionRegistry;
+use datafusion::functions_aggregate::sum::sum_udaf;
 use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility};
 use datafusion::physical_expr::aggregate::utils::down_cast_any_ref;
-use datafusion::physical_expr::expressions::Max;
+use datafusion::physical_expr::expressions::{Literal, Max};
 use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
 use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr};
 use datafusion::physical_plan::aggregates::{
@@ -70,7 +70,7 @@ use datafusion::physical_plan::windows::{
     BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec,
 };
 use datafusion::physical_plan::{
-    udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics,
+    AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics,
 };
 use datafusion::prelude::SessionContext;
 use datafusion::scalar::ScalarValue;
@@ -79,10 +79,10 @@ use 
datafusion_common::file_options::csv_writer::CsvWriterOptions;
 use datafusion_common::file_options::json_writer::JsonWriterOptions;
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::stats::Precision;
-use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, 
Result};
+use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
 use datafusion_expr::{
     Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, 
ScalarUDF,
-    ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, 
WindowFrameBound,
+    Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound,
 };
 use datafusion_functions_aggregate::average::avg_udaf;
 use datafusion_functions_aggregate::nth_value::nth_value_udaf;
@@ -92,6 +92,8 @@ use datafusion_proto::physical_plan::{
 };
 use datafusion_proto::protobuf;
 
+use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, 
MyRegexUdfNode};
+
 /// Perform a serde roundtrip and assert that the string representation of the 
before and after plans
 /// are identical. Note that this often isn't sufficient to guarantee that no 
information is
 /// lost during serde because the string representation of a plan often only 
shows a subset of state.
@@ -312,7 +314,7 @@ fn roundtrip_window() -> Result<()> {
     );
 
     let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?];
-    let sum_expr = udaf::create_aggregate_expr(
+    let sum_expr = create_aggregate_expr(
         &sum_udaf(),
         &args,
         &[],
@@ -367,7 +369,7 @@ fn rountrip_aggregate() -> Result<()> {
             false,
         )?],
         // NTH_VALUE
-        vec![udaf::create_aggregate_expr(
+        vec![create_aggregate_expr(
             &nth_value_udaf(),
             &[col("b", &schema)?, lit(1u64)],
             &[],
@@ -379,7 +381,7 @@ fn rountrip_aggregate() -> Result<()> {
             false,
         )?],
         // STRING_AGG
-        vec![udaf::create_aggregate_expr(
+        vec![create_aggregate_expr(
             &AggregateUDF::new_from_impl(StringAgg::new()),
             &[
                 cast(col("b", &schema)?, &schema, DataType::Utf8)?,
@@ -490,7 +492,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
     let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
         vec![(col("a", &schema)?, "unused".to_string())];
 
-    let aggregates: Vec<Arc<dyn AggregateExpr>> = 
vec![udaf::create_aggregate_expr(
+    let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(
         &udaf,
         &[col("b", &schema)?],
         &[],
@@ -845,123 +847,161 @@ fn roundtrip_scalar_udf() -> Result<()> {
     roundtrip_test_with_context(Arc::new(project), &ctx)
 }
 
-#[test]
-fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
-    #[derive(Debug)]
-    struct MyRegexUdf {
-        signature: Signature,
-        // regex as original string
-        pattern: String,
+#[derive(Debug)]
+struct UDFExtensionCodec;
+
+impl PhysicalExtensionCodec for UDFExtensionCodec {
+    fn try_decode(
+        &self,
+        _buf: &[u8],
+        _inputs: &[Arc<dyn ExecutionPlan>],
+        _registry: &dyn FunctionRegistry,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        not_impl_err!("No extension codec provided")
     }
 
-    impl MyRegexUdf {
-        fn new(pattern: String) -> Self {
-            Self {
-                signature: Signature::exact(vec![DataType::Utf8], 
Volatility::Immutable),
-                pattern,
-            }
-        }
+    fn try_encode(
+        &self,
+        _node: Arc<dyn ExecutionPlan>,
+        _buf: &mut Vec<u8>,
+    ) -> Result<()> {
+        not_impl_err!("No extension codec provided")
     }
 
-    /// Implement the ScalarUDFImpl trait for MyRegexUdf
-    impl ScalarUDFImpl for MyRegexUdf {
-        fn as_any(&self) -> &dyn Any {
-            self
-        }
+    fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<ScalarUDF>> 
{
+        if name == "regex_udf" {
+            let proto = MyRegexUdfNode::decode(buf).map_err(|err| {
+                DataFusionError::Internal(format!("failed to decode regex_udf: 
{err}"))
+            })?;
 
-        fn name(&self) -> &str {
-            "regex_udf"
+            Ok(Arc::new(ScalarUDF::from(MyRegexUdf::new(proto.pattern))))
+        } else {
+            not_impl_err!("unrecognized scalar UDF implementation, cannot 
decode")
         }
+    }
 
-        fn signature(&self) -> &Signature {
-            &self.signature
+    fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> 
Result<()> {
+        let binding = node.inner();
+        if let Some(udf) = binding.as_any().downcast_ref::<MyRegexUdf>() {
+            let proto = MyRegexUdfNode {
+                pattern: udf.pattern.clone(),
+            };
+            proto.encode(buf).map_err(|err| {
+                DataFusionError::Internal(format!("failed to encode udf: 
{err}"))
+            })?;
         }
+        Ok(())
+    }
 
-        fn return_type(&self, args: &[DataType]) -> Result<DataType> {
-            if !matches!(args.first(), Some(&DataType::Utf8)) {
-                return plan_err!("regex_udf only accepts Utf8 arguments");
-            }
-            Ok(DataType::Int64)
+    fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> 
Result<Arc<AggregateUDF>> {
+        if name == "aggregate_udf" {
+            let proto = MyAggregateUdfNode::decode(buf).map_err(|err| {
+                DataFusionError::Internal(format!(
+                    "failed to decode aggregate_udf: {err}"
+                ))
+            })?;
+
+            Ok(Arc::new(AggregateUDF::from(MyAggregateUDF::new(
+                proto.result,
+            ))))
+        } else {
+            not_impl_err!("unrecognized aggregate UDF implementation, cannot 
decode")
         }
+    }
 
-        fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
-            unimplemented!()
+    fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> 
Result<()> {
+        let binding = node.inner();
+        if let Some(udf) = binding.as_any().downcast_ref::<MyAggregateUDF>() {
+            let proto = MyAggregateUdfNode {
+                result: udf.result.clone(),
+            };
+            proto.encode(buf).map_err(|err| {
+                DataFusionError::Internal(format!("failed to encode aggregate_udf: {err}"))
+            })?;
         }
+        Ok(())
     }
+}
 
-    #[derive(Clone, PartialEq, ::prost::Message)]
-    pub struct MyRegexUdfNode {
-        #[prost(string, tag = "1")]
-        pub pattern: String,
-    }
+#[test]
+fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
+    let field_text = Field::new("text", DataType::Utf8, true);
+    let field_published = Field::new("published", DataType::Boolean, false);
+    let field_author = Field::new("author", DataType::Utf8, false);
+    let schema = Arc::new(Schema::new(vec![field_text, field_published, 
field_author]));
+    let input = Arc::new(EmptyExec::new(schema.clone()));
 
-    #[derive(Debug)]
-    pub struct ScalarUDFExtensionCodec {}
+    let udf_expr = Arc::new(ScalarFunctionExpr::new(
+        "regex_udf",
+        Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))),
+        vec![col("text", &schema)?],
+        DataType::Int64,
+    ));
 
-    impl PhysicalExtensionCodec for ScalarUDFExtensionCodec {
-        fn try_decode(
-            &self,
-            _buf: &[u8],
-            _inputs: &[Arc<dyn ExecutionPlan>],
-            _registry: &dyn FunctionRegistry,
-        ) -> Result<Arc<dyn ExecutionPlan>> {
-            not_impl_err!("No extension codec provided")
-        }
+    let filter = Arc::new(FilterExec::try_new(
+        Arc::new(BinaryExpr::new(
+            col("published", &schema)?,
+            Operator::And,
+            Arc::new(BinaryExpr::new(udf_expr.clone(), Operator::Gt, lit(0))),
+        )),
+        input,
+    )?);
 
-        fn try_encode(
-            &self,
-            _node: Arc<dyn ExecutionPlan>,
-            _buf: &mut Vec<u8>,
-        ) -> Result<()> {
-            not_impl_err!("No extension codec provided")
-        }
+    let window = Arc::new(WindowAggExec::try_new(
+        vec![Arc::new(PlainAggregateWindowExpr::new(
+            Arc::new(Max::new(udf_expr.clone(), "max", DataType::Int64)),
+            &[col("author", &schema)?],
+            &[],
+            Arc::new(WindowFrame::new(None)),
+        ))],
+        filter,
+        vec![col("author", &schema)?],
+    )?);
 
-        fn try_decode_udf(&self, name: &str, buf: &[u8]) -> 
Result<Arc<ScalarUDF>> {
-            if name == "regex_udf" {
-                let proto = MyRegexUdfNode::decode(buf).map_err(|err| {
-                    DataFusionError::Internal(format!(
-                        "failed to decode regex_udf: {}",
-                        err
-                    ))
-                })?;
-
-                Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new(
-                    proto.pattern,
-                ))))
-            } else {
-                not_impl_err!("unrecognized scalar UDF implementation, cannot 
decode")
-            }
-        }
+    let aggregate = Arc::new(AggregateExec::try_new(
+        AggregateMode::Final,
+        PhysicalGroupBy::new(vec![], vec![], vec![]),
+        vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))],
+        vec![None],
+        window,
+        schema.clone(),
+    )?);
 
-        fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> 
Result<()> {
-            let binding = node.inner();
-            if let Some(udf) = binding.as_any().downcast_ref::<MyRegexUdf>() {
-                let proto = MyRegexUdfNode {
-                    pattern: udf.pattern.clone(),
-                };
-                proto.encode(buf).map_err(|e| {
-                    DataFusionError::Internal(format!("failed to encode udf: 
{e:?}"))
-                })?;
-            }
-            Ok(())
-        }
-    }
+    let ctx = SessionContext::new();
+    roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?;
+    Ok(())
+}
 
+#[test]
+fn roundtrip_aggregate_udf_extension_codec() -> Result<()> {
     let field_text = Field::new("text", DataType::Utf8, true);
     let field_published = Field::new("published", DataType::Boolean, false);
     let field_author = Field::new("author", DataType::Utf8, false);
     let schema = Arc::new(Schema::new(vec![field_text, field_published, 
field_author]));
     let input = Arc::new(EmptyExec::new(schema.clone()));
 
-    let pattern = ".*";
-    let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string()));
     let udf_expr = Arc::new(ScalarFunctionExpr::new(
-        udf.name(),
-        Arc::new(udf.clone()),
+        "regex_udf",
+        Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))),
         vec![col("text", &schema)?],
         DataType::Int64,
     ));
 
+    let udaf = AggregateUDF::from(MyAggregateUDF::new("result".to_string()));
+    let aggr_args: [Arc<dyn PhysicalExpr>; 1] =
+        [Arc::new(Literal::new(ScalarValue::from(42)))];
+    let aggr_expr = create_aggregate_expr(
+        &udaf,
+        &aggr_args,
+        &[],
+        &[],
+        &[],
+        &schema,
+        "aggregate_udf",
+        false,
+        false,
+    )?;
+
     let filter = Arc::new(FilterExec::try_new(
         Arc::new(BinaryExpr::new(
             col("published", &schema)?,
@@ -973,7 +1013,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
 
     let window = Arc::new(WindowAggExec::try_new(
         vec![Arc::new(PlainAggregateWindowExpr::new(
-            Arc::new(Max::new(udf_expr.clone(), "max", DataType::Int64)),
+            aggr_expr,
             &[col("author", &schema)?],
             &[],
             Arc::new(WindowFrame::new(None)),
@@ -982,18 +1022,29 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
         vec![col("author", &schema)?],
     )?);
 
+    let aggr_expr = create_aggregate_expr(
+        &udaf,
+        &aggr_args,
+        &[],
+        &[],
+        &[],
+        &schema,
+        "aggregate_udf",
+        true,
+        true,
+    )?;
+
     let aggregate = Arc::new(AggregateExec::try_new(
         AggregateMode::Final,
         PhysicalGroupBy::new(vec![], vec![], vec![]),
-        vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))],
+        vec![aggr_expr],
         vec![None],
         window,
         schema.clone(),
     )?);
 
     let ctx = SessionContext::new();
-    let codec = ScalarUDFExtensionCodec {};
-    roundtrip_test_and_return(aggregate, &ctx, &codec)?;
+    roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?;
     Ok(())
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to