This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new f5d88d1790 Support serialization/deserialization for custom physical exprs in proto (#11387)
f5d88d1790 is described below

commit f5d88d1790eea85910ae5590a353ae17318f8401
Author: 张林伟 <[email protected]>
AuthorDate: Sun Jul 14 05:44:32 2024 +0800

    Support serialization/deserialization for custom physical exprs in proto (#11387)
    
    * Add PhysicalExtensionExprNode
    
    * regen proto
    
    * Add ser/de extension expr logic
    
    * Add test and fix clippy lint
---
 datafusion/proto/proto/datafusion.proto            |   7 +
 datafusion/proto/src/generated/pbjson.rs           | 124 +++++++++++++++++
 datafusion/proto/src/generated/prost.rs            |  12 +-
 datafusion/proto/src/physical_plan/from_proto.rs   |   8 ++
 datafusion/proto/src/physical_plan/mod.rs          |  16 +++
 datafusion/proto/src/physical_plan/to_proto.rs     |  19 ++-
 .../proto/tests/cases/roundtrip_physical_plan.rs   | 147 ++++++++++++++++++++-
 7 files changed, 330 insertions(+), 3 deletions(-)

diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 345765b08b..9ef884531e 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -836,6 +836,8 @@ message PhysicalExprNode {
     // was PhysicalDateTimeIntervalExprNode date_time_interval_expr = 17;
 
     PhysicalLikeExprNode like_expr = 18;
+
+    PhysicalExtensionExprNode extension = 19;
   }
 }
 
@@ -942,6 +944,11 @@ message PhysicalNegativeNode {
   PhysicalExprNode expr = 1;
 }
 
+message PhysicalExtensionExprNode {
+  bytes expr = 1;
+  repeated PhysicalExprNode inputs = 2;
+}
+
 message FilterExecNode {
   PhysicalPlanNode input = 1;
   PhysicalExprNode expr = 2;
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 905f0d9849..fa989480fa 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -13543,6 +13543,9 @@ impl serde::Serialize for PhysicalExprNode {
                 physical_expr_node::ExprType::LikeExpr(v) => {
                     struct_ser.serialize_field("likeExpr", v)?;
                 }
+                physical_expr_node::ExprType::Extension(v) => {
+                    struct_ser.serialize_field("extension", v)?;
+                }
             }
         }
         struct_ser.end()
@@ -13582,6 +13585,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode {
             "scalarUdf",
             "like_expr",
             "likeExpr",
+            "extension",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -13602,6 +13606,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode {
             WindowExpr,
             ScalarUdf,
             LikeExpr,
+            Extension,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
@@ -13639,6 +13644,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode {
                             "windowExpr" | "window_expr" => 
Ok(GeneratedField::WindowExpr),
                             "scalarUdf" | "scalar_udf" => 
Ok(GeneratedField::ScalarUdf),
                             "likeExpr" | "like_expr" => 
Ok(GeneratedField::LikeExpr),
+                            "extension" => Ok(GeneratedField::Extension),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
                         }
                     }
@@ -13771,6 +13777,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode {
                                 return 
Err(serde::de::Error::duplicate_field("likeExpr"));
                             }
                             expr_type__ = 
map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::LikeExpr)
+;
+                        }
+                        GeneratedField::Extension => {
+                            if expr_type__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("extension"));
+                            }
+                            expr_type__ = 
map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Extension)
 ;
                         }
                     }
@@ -13783,6 +13796,117 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode {
         deserializer.deserialize_struct("datafusion.PhysicalExprNode", FIELDS, 
GeneratedVisitor)
     }
 }
+impl serde::Serialize for PhysicalExtensionExprNode {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if !self.expr.is_empty() {
+            len += 1;
+        }
+        if !self.inputs.is_empty() {
+            len += 1;
+        }
+        let mut struct_ser = 
serializer.serialize_struct("datafusion.PhysicalExtensionExprNode", len)?;
+        if !self.expr.is_empty() {
+            #[allow(clippy::needless_borrow)]
+            struct_ser.serialize_field("expr", 
pbjson::private::base64::encode(&self.expr).as_str())?;
+        }
+        if !self.inputs.is_empty() {
+            struct_ser.serialize_field("inputs", &self.inputs)?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for PhysicalExtensionExprNode {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "expr",
+            "inputs",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            Expr,
+            Inputs,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> 
std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "expr" => Ok(GeneratedField::Expr),
+                            "inputs" => Ok(GeneratedField::Inputs),
+                            _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = PhysicalExtensionExprNode;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result {
+                formatter.write_str("struct 
datafusion.PhysicalExtensionExprNode")
+            }
+
+            fn visit_map<V>(self, mut map_: V) -> 
std::result::Result<PhysicalExtensionExprNode, V::Error>
+                where
+                    V: serde::de::MapAccess<'de>,
+            {
+                let mut expr__ = None;
+                let mut inputs__ = None;
+                while let Some(k) = map_.next_key()? {
+                    match k {
+                        GeneratedField::Expr => {
+                            if expr__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("expr"));
+                            }
+                            expr__ = 
+                                
Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0)
+                            ;
+                        }
+                        GeneratedField::Inputs => {
+                            if inputs__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("inputs"));
+                            }
+                            inputs__ = Some(map_.next_value()?);
+                        }
+                    }
+                }
+                Ok(PhysicalExtensionExprNode {
+                    expr: expr__.unwrap_or_default(),
+                    inputs: inputs__.unwrap_or_default(),
+                })
+            }
+        }
+        
deserializer.deserialize_struct("datafusion.PhysicalExtensionExprNode", FIELDS, 
GeneratedVisitor)
+    }
+}
 impl serde::Serialize for PhysicalExtensionNode {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index b16d26ee6e..8407e545fe 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1218,7 +1218,7 @@ pub struct PhysicalExtensionNode {
 pub struct PhysicalExprNode {
     #[prost(
         oneof = "physical_expr_node::ExprType",
-        tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18"
+        tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19"
     )]
     pub expr_type: ::core::option::Option<physical_expr_node::ExprType>,
 }
@@ -1266,6 +1266,8 @@ pub mod physical_expr_node {
         ScalarUdf(super::PhysicalScalarUdfNode),
         #[prost(message, tag = "18")]
         LikeExpr(::prost::alloc::boxed::Box<super::PhysicalLikeExprNode>),
+        #[prost(message, tag = "19")]
+        Extension(super::PhysicalExtensionExprNode),
     }
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
@@ -1456,6 +1458,14 @@ pub struct PhysicalNegativeNode {
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
+pub struct PhysicalExtensionExprNode {
+    #[prost(bytes = "vec", tag = "1")]
+    pub expr: ::prost::alloc::vec::Vec<u8>,
+    #[prost(message, repeated, tag = "2")]
+    pub inputs: ::prost::alloc::vec::Vec<PhysicalExprNode>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
 pub struct FilterExecNode {
     #[prost(message, optional, boxed, tag = "1")]
     pub input: 
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs
index e94bb3b8ef..52fbd5cbdc 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -394,6 +394,14 @@ pub fn parse_physical_expr(
                 codec,
             )?,
         )),
+        ExprType::Extension(extension) => {
+            let inputs: Vec<Arc<dyn PhysicalExpr>> = extension
+                .inputs
+                .iter()
+                .map(|e| parse_physical_expr(e, registry, input_schema, codec))
+                .collect::<Result<_>>()?;
+            (codec.try_decode_expr(extension.expr.as_slice(), &inputs)?) as _
+        }
     };
 
     Ok(pexpr)
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 56e7027047..e5429945e9 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -2018,6 +2018,22 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync {
     fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec<u8>) -> 
Result<()> {
         Ok(())
     }
+
+    fn try_decode_expr(
+        &self,
+        _buf: &[u8],
+        _inputs: &[Arc<dyn PhysicalExpr>],
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        not_impl_err!("PhysicalExtensionCodec is not provided")
+    }
+
+    fn try_encode_expr(
+        &self,
+        _node: Arc<dyn PhysicalExpr>,
+        _buf: &mut Vec<u8>,
+    ) -> Result<()> {
+        not_impl_err!("PhysicalExtensionCodec is not provided")
+    }
 }
 
 #[derive(Debug)]
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index 5e982ad2af..9c95acc1dc 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -495,7 +495,24 @@ pub fn serialize_physical_expr(
             ))),
         })
     } else {
-        internal_err!("physical_plan::to_proto() unsupported expression 
{value:?}")
+        let mut buf: Vec<u8> = vec![];
+        match codec.try_encode_expr(Arc::clone(&value), &mut buf) {
+            Ok(_) => {
+                let inputs: Vec<protobuf::PhysicalExprNode> = value
+                    .children()
+                    .into_iter()
+                    .map(|e| serialize_physical_expr(Arc::clone(e), codec))
+                    .collect::<Result<_>>()?;
+                Ok(protobuf::PhysicalExprNode {
+                    expr_type: 
Some(protobuf::physical_expr_node::ExprType::Extension(
+                        protobuf::PhysicalExtensionExprNode { expr: buf, 
inputs },
+                    )),
+                })
+            }
+            Err(e) => internal_err!(
+                "Unsupported physical expr and extension codec failed with 
[{e}]. Expr: {value:?}"
+            ),
+        }
     }
 }
 
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index d8d85ace1a..2fcc65008f 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -15,7 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use arrow::array::RecordBatch;
 use std::any::Any;
+use std::fmt::Display;
+use std::hash::Hasher;
 use std::ops::Deref;
 use std::sync::Arc;
 use std::vec;
@@ -38,6 +41,7 @@ use datafusion::datasource::physical_plan::{
 };
 use datafusion::execution::FunctionRegistry;
 use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility};
+use datafusion::physical_expr::aggregate::utils::down_cast_any_ref;
 use datafusion::physical_expr::expressions::Max;
 use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
 use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr};
@@ -75,7 +79,7 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions;
 use datafusion_common::file_options::json_writer::JsonWriterOptions;
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::stats::Precision;
-use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result};
+use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, 
Result};
 use datafusion_expr::{
     Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, 
ScalarUDF,
     ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, 
WindowFrameBound,
@@ -658,6 +662,147 @@ async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> {
     roundtrip_test(ParquetExec::builder(scan_config).build_arc())
 }
 
+#[test]
+fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> {
+    let scan_config = FileScanConfig {
+        object_store_url: ObjectStoreUrl::local_filesystem(),
+        file_schema: Arc::new(Schema::new(vec![Field::new(
+            "col",
+            DataType::Utf8,
+            false,
+        )])),
+        file_groups: vec![vec![PartitionedFile::new(
+            "/path/to/file.parquet".to_string(),
+            1024,
+        )]],
+        statistics: Statistics {
+            num_rows: Precision::Inexact(100),
+            total_byte_size: Precision::Inexact(1024),
+            column_statistics: 
Statistics::unknown_column(&Arc::new(Schema::new(vec![
+                Field::new("col", DataType::Utf8, false),
+            ]))),
+        },
+        projection: None,
+        limit: None,
+        table_partition_cols: vec![],
+        output_ordering: vec![],
+    };
+
+    #[derive(Debug, Hash, Clone)]
+    struct CustomPredicateExpr {
+        inner: Arc<dyn PhysicalExpr>,
+    }
+    impl Display for CustomPredicateExpr {
+        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+            write!(f, "CustomPredicateExpr")
+        }
+    }
+    impl PartialEq<dyn Any> for CustomPredicateExpr {
+        fn eq(&self, other: &dyn Any) -> bool {
+            down_cast_any_ref(other)
+                .downcast_ref::<Self>()
+                .map(|x| self.inner.eq(&x.inner))
+                .unwrap_or(false)
+        }
+    }
+    impl PhysicalExpr for CustomPredicateExpr {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+
+        fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+            unreachable!()
+        }
+
+        fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+            unreachable!()
+        }
+
+        fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
+            unreachable!()
+        }
+
+        fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+            vec![&self.inner]
+        }
+
+        fn with_new_children(
+            self: Arc<Self>,
+            _children: Vec<Arc<dyn PhysicalExpr>>,
+        ) -> Result<Arc<dyn PhysicalExpr>> {
+            todo!()
+        }
+
+        fn dyn_hash(&self, _state: &mut dyn Hasher) {
+            unreachable!()
+        }
+    }
+
+    #[derive(Debug)]
+    struct CustomPhysicalExtensionCodec;
+    impl PhysicalExtensionCodec for CustomPhysicalExtensionCodec {
+        fn try_decode(
+            &self,
+            _buf: &[u8],
+            _inputs: &[Arc<dyn ExecutionPlan>],
+            _registry: &dyn FunctionRegistry,
+        ) -> Result<Arc<dyn ExecutionPlan>> {
+            unreachable!()
+        }
+
+        fn try_encode(
+            &self,
+            _node: Arc<dyn ExecutionPlan>,
+            _buf: &mut Vec<u8>,
+        ) -> Result<()> {
+            unreachable!()
+        }
+
+        fn try_decode_expr(
+            &self,
+            buf: &[u8],
+            inputs: &[Arc<dyn PhysicalExpr>],
+        ) -> Result<Arc<dyn PhysicalExpr>> {
+            if buf == "CustomPredicateExpr".as_bytes() {
+                Ok(Arc::new(CustomPredicateExpr {
+                    inner: inputs[0].clone(),
+                }))
+            } else {
+                internal_err!("Not supported")
+            }
+        }
+
+        fn try_encode_expr(
+            &self,
+            node: Arc<dyn PhysicalExpr>,
+            buf: &mut Vec<u8>,
+        ) -> Result<()> {
+            if node
+                .as_ref()
+                .as_any()
+                .downcast_ref::<CustomPredicateExpr>()
+                .is_some()
+            {
+                buf.extend_from_slice("CustomPredicateExpr".as_bytes());
+                Ok(())
+            } else {
+                internal_err!("Not supported")
+            }
+        }
+    }
+
+    let custom_predicate_expr = Arc::new(CustomPredicateExpr {
+        inner: Arc::new(Column::new("col", 1)),
+    });
+    let exec_plan = ParquetExec::builder(scan_config)
+        .with_predicate(custom_predicate_expr)
+        .build_arc();
+
+    let ctx = SessionContext::new();
+    roundtrip_test_and_return(exec_plan, &ctx, &CustomPhysicalExtensionCodec 
{})?;
+    Ok(())
+}
+
 #[test]
 fn roundtrip_scalar_udf() -> Result<()> {
     let field_a = Field::new("a", DataType::Int64, false);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to