This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new f5d88d1790 Support serialization/deserialization for custom physical exprs in proto (#11387)
f5d88d1790 is described below
commit f5d88d1790eea85910ae5590a353ae17318f8401
Author: 张林伟 <[email protected]>
AuthorDate: Sun Jul 14 05:44:32 2024 +0800
Support serialization/deserialization for custom physical exprs in proto (#11387)
* Add PhysicalExtensionExprNode
* regen proto
* Add ser/de extension expr logic
* Add test and fix clippy lint
---
datafusion/proto/proto/datafusion.proto | 7 +
datafusion/proto/src/generated/pbjson.rs | 124 +++++++++++++++++
datafusion/proto/src/generated/prost.rs | 12 +-
datafusion/proto/src/physical_plan/from_proto.rs | 8 ++
datafusion/proto/src/physical_plan/mod.rs | 16 +++
datafusion/proto/src/physical_plan/to_proto.rs | 19 ++-
.../proto/tests/cases/roundtrip_physical_plan.rs | 147 ++++++++++++++++++++-
7 files changed, 330 insertions(+), 3 deletions(-)
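
Context for the change (not part of the commit): the diff below adds try_encode_expr/try_decode_expr hooks to PhysicalExtensionCodec and a PhysicalExtensionExprNode proto message to carry the encoded bytes plus child expressions. A minimal sketch of how a downstream crate might implement the new hooks, modeled on the roundtrip test in this commit; MyExprCodec and MyCustomExpr are illustrative names only, and MyCustomExpr is assumed to be a user-defined PhysicalExpr with a single child defined elsewhere:

    use std::sync::Arc;

    use datafusion::execution::FunctionRegistry;
    use datafusion::physical_plan::{ExecutionPlan, PhysicalExpr};
    use datafusion_common::{internal_err, Result};
    use datafusion_proto::physical_plan::PhysicalExtensionCodec;

    #[derive(Debug)]
    struct MyExprCodec; // hypothetical codec; MyCustomExpr is assumed to exist elsewhere

    impl PhysicalExtensionCodec for MyExprCodec {
        // Required plan-level methods; this codec only handles expressions.
        fn try_decode(
            &self,
            _buf: &[u8],
            _inputs: &[Arc<dyn ExecutionPlan>],
            _registry: &dyn FunctionRegistry,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            unreachable!()
        }

        fn try_encode(&self, _node: Arc<dyn ExecutionPlan>, _buf: &mut Vec<u8>) -> Result<()> {
            unreachable!()
        }

        // Called by parse_physical_expr for ExprType::Extension nodes; the child
        // expressions have already been deserialized and arrive in `inputs`.
        fn try_decode_expr(
            &self,
            buf: &[u8],
            inputs: &[Arc<dyn PhysicalExpr>],
        ) -> Result<Arc<dyn PhysicalExpr>> {
            if buf == b"MyCustomExpr" {
                Ok(Arc::new(MyCustomExpr { child: inputs[0].clone() }))
            } else {
                internal_err!("unknown extension expression")
            }
        }

        // Called by serialize_physical_expr when the expression is not a built-in;
        // only the expression's own payload goes into `buf`, while children are
        // serialized separately into PhysicalExtensionExprNode.inputs.
        fn try_encode_expr(&self, node: Arc<dyn PhysicalExpr>, buf: &mut Vec<u8>) -> Result<()> {
            if node.as_ref().as_any().downcast_ref::<MyCustomExpr>().is_some() {
                buf.extend_from_slice(b"MyCustomExpr");
                Ok(())
            } else {
                internal_err!("expression not handled by this codec")
            }
        }
    }
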
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 345765b08b..9ef884531e 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -836,6 +836,8 @@ message PhysicalExprNode {
// was PhysicalDateTimeIntervalExprNode date_time_interval_expr = 17;
PhysicalLikeExprNode like_expr = 18;
+
+ PhysicalExtensionExprNode extension = 19;
}
}
@@ -942,6 +944,11 @@ message PhysicalNegativeNode {
PhysicalExprNode expr = 1;
}
+message PhysicalExtensionExprNode {
+ bytes expr = 1;
+ repeated PhysicalExprNode inputs = 2;
+}
+
message FilterExecNode {
PhysicalPlanNode input = 1;
PhysicalExprNode expr = 2;
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 905f0d9849..fa989480fa 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -13543,6 +13543,9 @@ impl serde::Serialize for PhysicalExprNode {
physical_expr_node::ExprType::LikeExpr(v) => {
struct_ser.serialize_field("likeExpr", v)?;
}
+ physical_expr_node::ExprType::Extension(v) => {
+ struct_ser.serialize_field("extension", v)?;
+ }
}
}
struct_ser.end()
@@ -13582,6 +13585,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode {
"scalarUdf",
"like_expr",
"likeExpr",
+ "extension",
];
#[allow(clippy::enum_variant_names)]
@@ -13602,6 +13606,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode {
WindowExpr,
ScalarUdf,
LikeExpr,
+ Extension,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -13639,6 +13644,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode {
"windowExpr" | "window_expr" =>
Ok(GeneratedField::WindowExpr),
"scalarUdf" | "scalar_udf" =>
Ok(GeneratedField::ScalarUdf),
"likeExpr" | "like_expr" =>
Ok(GeneratedField::LikeExpr),
+ "extension" => Ok(GeneratedField::Extension),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -13771,6 +13777,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode {
return Err(serde::de::Error::duplicate_field("likeExpr"));
}
expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::LikeExpr)
+;
+ }
+ GeneratedField::Extension => {
+ if expr_type__.is_some() {
+ return Err(serde::de::Error::duplicate_field("extension"));
+ }
+ expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Extension)
;
}
}
@@ -13783,6 +13796,117 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode {
deserializer.deserialize_struct("datafusion.PhysicalExprNode", FIELDS,
GeneratedVisitor)
}
}
+impl serde::Serialize for PhysicalExtensionExprNode {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ let mut len = 0;
+ if !self.expr.is_empty() {
+ len += 1;
+ }
+ if !self.inputs.is_empty() {
+ len += 1;
+ }
+ let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExtensionExprNode", len)?;
+ if !self.expr.is_empty() {
+ #[allow(clippy::needless_borrow)]
+ struct_ser.serialize_field("expr",
pbjson::private::base64::encode(&self.expr).as_str())?;
+ }
+ if !self.inputs.is_empty() {
+ struct_ser.serialize_field("inputs", &self.inputs)?;
+ }
+ struct_ser.end()
+ }
+}
+impl<'de> serde::Deserialize<'de> for PhysicalExtensionExprNode {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ "expr",
+ "inputs",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ Expr,
+ Inputs,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "expr" => Ok(GeneratedField::Expr),
+ "inputs" => Ok(GeneratedField::Inputs),
+ _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = PhysicalExtensionExprNode;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ formatter.write_str("struct datafusion.PhysicalExtensionExprNode")
+ }
+
+ fn visit_map<V>(self, mut map_: V) -> std::result::Result<PhysicalExtensionExprNode, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut expr__ = None;
+ let mut inputs__ = None;
+ while let Some(k) = map_.next_key()? {
+ match k {
+ GeneratedField::Expr => {
+ if expr__.is_some() {
+ return Err(serde::de::Error::duplicate_field("expr"));
+ }
+ expr__ =
+ Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0)
+ ;
+ }
+ GeneratedField::Inputs => {
+ if inputs__.is_some() {
+ return Err(serde::de::Error::duplicate_field("inputs"));
+ }
+ inputs__ = Some(map_.next_value()?);
+ }
+ }
+ }
+ Ok(PhysicalExtensionExprNode {
+ expr: expr__.unwrap_or_default(),
+ inputs: inputs__.unwrap_or_default(),
+ })
+ }
+ }
+ deserializer.deserialize_struct("datafusion.PhysicalExtensionExprNode", FIELDS, GeneratedVisitor)
+ }
+}
impl serde::Serialize for PhysicalExtensionNode {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index b16d26ee6e..8407e545fe 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1218,7 +1218,7 @@ pub struct PhysicalExtensionNode {
pub struct PhysicalExprNode {
#[prost(
oneof = "physical_expr_node::ExprType",
- tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18"
+ tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19"
)]
pub expr_type: ::core::option::Option<physical_expr_node::ExprType>,
}
@@ -1266,6 +1266,8 @@ pub mod physical_expr_node {
ScalarUdf(super::PhysicalScalarUdfNode),
#[prost(message, tag = "18")]
LikeExpr(::prost::alloc::boxed::Box<super::PhysicalLikeExprNode>),
+ #[prost(message, tag = "19")]
+ Extension(super::PhysicalExtensionExprNode),
}
}
#[allow(clippy::derive_partial_eq_without_eq)]
@@ -1456,6 +1458,14 @@ pub struct PhysicalNegativeNode {
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct PhysicalExtensionExprNode {
+ #[prost(bytes = "vec", tag = "1")]
+ pub expr: ::prost::alloc::vec::Vec<u8>,
+ #[prost(message, repeated, tag = "2")]
+ pub inputs: ::prost::alloc::vec::Vec<PhysicalExprNode>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
pub struct FilterExecNode {
#[prost(message, optional, boxed, tag = "1")]
pub input: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs
index e94bb3b8ef..52fbd5cbdc 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -394,6 +394,14 @@ pub fn parse_physical_expr(
codec,
)?,
)),
+ ExprType::Extension(extension) => {
+ let inputs: Vec<Arc<dyn PhysicalExpr>> = extension
+ .inputs
+ .iter()
+ .map(|e| parse_physical_expr(e, registry, input_schema, codec))
+ .collect::<Result<_>>()?;
+ (codec.try_decode_expr(extension.expr.as_slice(), &inputs)?) as _
+ }
};
Ok(pexpr)
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 56e7027047..e5429945e9 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -2018,6 +2018,22 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync {
fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec<u8>) -> Result<()> {
Ok(())
}
+
+ fn try_decode_expr(
+ &self,
+ _buf: &[u8],
+ _inputs: &[Arc<dyn PhysicalExpr>],
+ ) -> Result<Arc<dyn PhysicalExpr>> {
+ not_impl_err!("PhysicalExtensionCodec is not provided")
+ }
+
+ fn try_encode_expr(
+ &self,
+ _node: Arc<dyn PhysicalExpr>,
+ _buf: &mut Vec<u8>,
+ ) -> Result<()> {
+ not_impl_err!("PhysicalExtensionCodec is not provided")
+ }
}
#[derive(Debug)]
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index 5e982ad2af..9c95acc1dc 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -495,7 +495,24 @@ pub fn serialize_physical_expr(
))),
})
} else {
- internal_err!("physical_plan::to_proto() unsupported expression
{value:?}")
+ let mut buf: Vec<u8> = vec![];
+ match codec.try_encode_expr(Arc::clone(&value), &mut buf) {
+ Ok(_) => {
+ let inputs: Vec<protobuf::PhysicalExprNode> = value
+ .children()
+ .into_iter()
+ .map(|e| serialize_physical_expr(Arc::clone(e), codec))
+ .collect::<Result<_>>()?;
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::Extension(
+ protobuf::PhysicalExtensionExprNode { expr: buf, inputs },
+ )),
+ })
+ }
+ Err(e) => internal_err!(
+ "Unsupported physical expr and extension codec failed with
[{e}]. Expr: {value:?}"
+ ),
+ }
}
}
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index d8d85ace1a..2fcc65008f 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -15,7 +15,10 @@
// specific language governing permissions and limitations
// under the License.
+use arrow::array::RecordBatch;
use std::any::Any;
+use std::fmt::Display;
+use std::hash::Hasher;
use std::ops::Deref;
use std::sync::Arc;
use std::vec;
@@ -38,6 +41,7 @@ use datafusion::datasource::physical_plan::{
};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility};
+use datafusion::physical_expr::aggregate::utils::down_cast_any_ref;
use datafusion::physical_expr::expressions::Max;
use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr};
@@ -75,7 +79,7 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::file_options::json_writer::JsonWriterOptions;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::stats::Precision;
-use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result};
+use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result};
use datafusion_expr::{
Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF,
ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound,
@@ -658,6 +662,147 @@ async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> {
roundtrip_test(ParquetExec::builder(scan_config).build_arc())
}
+#[test]
+fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> {
+ let scan_config = FileScanConfig {
+ object_store_url: ObjectStoreUrl::local_filesystem(),
+ file_schema: Arc::new(Schema::new(vec![Field::new(
+ "col",
+ DataType::Utf8,
+ false,
+ )])),
+ file_groups: vec![vec![PartitionedFile::new(
+ "/path/to/file.parquet".to_string(),
+ 1024,
+ )]],
+ statistics: Statistics {
+ num_rows: Precision::Inexact(100),
+ total_byte_size: Precision::Inexact(1024),
+ column_statistics: Statistics::unknown_column(&Arc::new(Schema::new(vec![
+ Field::new("col", DataType::Utf8, false),
+ ]))),
+ },
+ projection: None,
+ limit: None,
+ table_partition_cols: vec![],
+ output_ordering: vec![],
+ };
+
+ #[derive(Debug, Hash, Clone)]
+ struct CustomPredicateExpr {
+ inner: Arc<dyn PhysicalExpr>,
+ }
+ impl Display for CustomPredicateExpr {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "CustomPredicateExpr")
+ }
+ }
+ impl PartialEq<dyn Any> for CustomPredicateExpr {
+ fn eq(&self, other: &dyn Any) -> bool {
+ down_cast_any_ref(other)
+ .downcast_ref::<Self>()
+ .map(|x| self.inner.eq(&x.inner))
+ .unwrap_or(false)
+ }
+ }
+ impl PhysicalExpr for CustomPredicateExpr {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+ unreachable!()
+ }
+
+ fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+ unreachable!()
+ }
+
+ fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
+ unreachable!()
+ }
+
+ fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+ vec![&self.inner]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ _children: Vec<Arc<dyn PhysicalExpr>>,
+ ) -> Result<Arc<dyn PhysicalExpr>> {
+ todo!()
+ }
+
+ fn dyn_hash(&self, _state: &mut dyn Hasher) {
+ unreachable!()
+ }
+ }
+
+ #[derive(Debug)]
+ struct CustomPhysicalExtensionCodec;
+ impl PhysicalExtensionCodec for CustomPhysicalExtensionCodec {
+ fn try_decode(
+ &self,
+ _buf: &[u8],
+ _inputs: &[Arc<dyn ExecutionPlan>],
+ _registry: &dyn FunctionRegistry,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ unreachable!()
+ }
+
+ fn try_encode(
+ &self,
+ _node: Arc<dyn ExecutionPlan>,
+ _buf: &mut Vec<u8>,
+ ) -> Result<()> {
+ unreachable!()
+ }
+
+ fn try_decode_expr(
+ &self,
+ buf: &[u8],
+ inputs: &[Arc<dyn PhysicalExpr>],
+ ) -> Result<Arc<dyn PhysicalExpr>> {
+ if buf == "CustomPredicateExpr".as_bytes() {
+ Ok(Arc::new(CustomPredicateExpr {
+ inner: inputs[0].clone(),
+ }))
+ } else {
+ internal_err!("Not supported")
+ }
+ }
+
+ fn try_encode_expr(
+ &self,
+ node: Arc<dyn PhysicalExpr>,
+ buf: &mut Vec<u8>,
+ ) -> Result<()> {
+ if node
+ .as_ref()
+ .as_any()
+ .downcast_ref::<CustomPredicateExpr>()
+ .is_some()
+ {
+ buf.extend_from_slice("CustomPredicateExpr".as_bytes());
+ Ok(())
+ } else {
+ internal_err!("Not supported")
+ }
+ }
+ }
+
+ let custom_predicate_expr = Arc::new(CustomPredicateExpr {
+ inner: Arc::new(Column::new("col", 1)),
+ });
+ let exec_plan = ParquetExec::builder(scan_config)
+ .with_predicate(custom_predicate_expr)
+ .build_arc();
+
+ let ctx = SessionContext::new();
+ roundtrip_test_and_return(exec_plan, &ctx, &CustomPhysicalExtensionCodec {})?;
+ Ok(())
+}
+
#[test]
fn roundtrip_scalar_udf() -> Result<()> {
let field_a = Field::new("a", DataType::Int64, false);