This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 70a215bc8a support limit in agg exec for ser/deser (#10692)
70a215bc8a is described below

commit 70a215bc8ab357987e81d3c607ccc664ee634790
Author: Kun Liu <[email protected]>
AuthorDate: Tue May 28 21:29:23 2024 +0800

    support limit in agg exec for ser/deser (#10692)
---
 datafusion/proto/proto/datafusion.proto            |   6 ++
 datafusion/proto/src/generated/pbjson.rs           | 111 +++++++++++++++++++++
 datafusion/proto/src/generated/prost.rs            |   9 ++
 datafusion/proto/src/physical_plan/mod.rs          |  18 +++-
 .../proto/tests/cases/roundtrip_physical_plan.rs   |  27 +++++
 5 files changed, 169 insertions(+), 2 deletions(-)

diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 6b4e2aae29..448d9f0582 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1125,6 +1125,11 @@ message MaybePhysicalSortExprs {
   repeated PhysicalSortExprNode sort_expr = 1;
 }
 
+message AggLimit {
+  // wrap into a message to make it optional
+  uint64 limit = 1;
+}
+
 message AggregateExecNode {
   repeated PhysicalExprNode group_expr = 1;
   repeated PhysicalExprNode aggr_expr = 2;
@@ -1137,6 +1142,7 @@ message AggregateExecNode {
   repeated PhysicalExprNode null_expr = 8;
   repeated bool groups = 9;
   repeated MaybeFilter filter_expr = 10;
+  AggLimit limit = 11;
 }
 
 message GlobalLimitExecNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index bbee3311b7..76a367b402 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -1,3 +1,97 @@
+impl serde::Serialize for AggLimit {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if self.limit != 0 {
+            len += 1;
+        }
+        let mut struct_ser = serializer.serialize_struct("datafusion.AggLimit", len)?;
+        if self.limit != 0 {
+            #[allow(clippy::needless_borrow)]
+            struct_ser.serialize_field("limit", ToString::to_string(&self.limit).as_str())?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for AggLimit {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "limit",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            Limit,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "limit" => Ok(GeneratedField::Limit),
+                            _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = AggLimit;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                formatter.write_str("struct datafusion.AggLimit")
+            }
+
+            fn visit_map<V>(self, mut map_: V) -> std::result::Result<AggLimit, V::Error>
+                where
+                    V: serde::de::MapAccess<'de>,
+            {
+                let mut limit__ = None;
+                while let Some(k) = map_.next_key()? {
+                    match k {
+                        GeneratedField::Limit => {
+                            if limit__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("limit"));
+                            }
+                            limit__ = 
+                                Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+                            ;
+                        }
+                    }
+                }
+                Ok(AggLimit {
+                    limit: limit__.unwrap_or_default(),
+                })
+            }
+        }
+        deserializer.deserialize_struct("datafusion.AggLimit", FIELDS, GeneratedVisitor)
+    }
+}
 impl serde::Serialize for AggregateExecNode {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
@@ -36,6 +130,9 @@ impl serde::Serialize for AggregateExecNode {
         if !self.filter_expr.is_empty() {
             len += 1;
         }
+        if self.limit.is_some() {
+            len += 1;
+        }
         let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExecNode", len)?;
         if !self.group_expr.is_empty() {
             struct_ser.serialize_field("groupExpr", &self.group_expr)?;
@@ -69,6 +166,9 @@ impl serde::Serialize for AggregateExecNode {
         if !self.filter_expr.is_empty() {
             struct_ser.serialize_field("filterExpr", &self.filter_expr)?;
         }
+        if let Some(v) = self.limit.as_ref() {
+            struct_ser.serialize_field("limit", v)?;
+        }
         struct_ser.end()
     }
 }
@@ -96,6 +196,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
             "groups",
             "filter_expr",
             "filterExpr",
+            "limit",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -110,6 +211,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
             NullExpr,
             Groups,
             FilterExpr,
+            Limit,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -141,6 +243,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
                             "nullExpr" | "null_expr" => Ok(GeneratedField::NullExpr),
                             "groups" => Ok(GeneratedField::Groups),
                             "filterExpr" | "filter_expr" => Ok(GeneratedField::FilterExpr),
+                            "limit" => Ok(GeneratedField::Limit),
                             _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                         }
                     }
@@ -170,6 +273,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
                 let mut null_expr__ = None;
                 let mut groups__ = None;
                 let mut filter_expr__ = None;
+                let mut limit__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
                         GeneratedField::GroupExpr => {
@@ -232,6 +336,12 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
                             }
                             filter_expr__ = Some(map_.next_value()?);
                         }
+                        GeneratedField::Limit => {
+                            if limit__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("limit"));
+                            }
+                            limit__ = map_.next_value()?;
+                        }
                     }
                 }
                 Ok(AggregateExecNode {
@@ -245,6 +355,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
                     null_expr: null_expr__.unwrap_or_default(),
                     groups: groups__.unwrap_or_default(),
                     filter_expr: filter_expr__.unwrap_or_default(),
+                    limit: limit__,
                 })
             }
         }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 0354ead9e7..5e0f6613f3 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1735,6 +1735,13 @@ pub struct MaybePhysicalSortExprs {
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
+pub struct AggLimit {
+    /// wrap into a message to make it optional
+    #[prost(uint64, tag = "1")]
+    pub limit: u64,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
 pub struct AggregateExecNode {
     #[prost(message, repeated, tag = "1")]
     pub group_expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
@@ -1757,6 +1764,8 @@ pub struct AggregateExecNode {
     pub groups: ::prost::alloc::vec::Vec<bool>,
     #[prost(message, repeated, tag = "10")]
     pub filter_expr: ::prost::alloc::vec::Vec<MaybeFilter>,
+    #[prost(message, optional, tag = "11")]
+    pub limit: ::core::option::Option<AggLimit>,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index a85bfdc89d..91ed3b7f5e 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -539,14 +539,23 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                     })
                     .collect::<Result<Vec<_>, _>>()?;
 
-                Ok(Arc::new(AggregateExec::try_new(
+                let limit = hash_agg
+                    .limit
+                    .as_ref()
+                    .map(|lit_value| lit_value.limit as usize);
+
+                let agg = AggregateExec::try_new(
                     agg_mode,
                     PhysicalGroupBy::new(group_expr, null_expr, groups),
                     physical_aggr_expr,
                     physical_filter_expr,
                     input,
                     physical_schema,
-                )?))
+                )?;
+
+                let agg = agg.with_limit(limit);
+
+                Ok(Arc::new(agg))
             }
             PhysicalPlanType::HashJoin(hashjoin) => {
                 let left: Arc<dyn ExecutionPlan> = into_physical_plan(
@@ -1504,6 +1513,10 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                 .map(|expr| serialize_physical_expr(expr.0.to_owned(), extension_codec))
                 .collect::<Result<Vec<_>>>()?;
 
+            let limit = exec.limit().map(|value| protobuf::AggLimit {
+                limit: value as u64,
+            });
+
             return Ok(protobuf::PhysicalPlanNode {
                 physical_plan_type: Some(PhysicalPlanType::Aggregate(Box::new(
                     protobuf::AggregateExecNode {
@@ -1517,6 +1530,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                         input_schema: Some(input_schema.as_ref().try_into()?),
                         null_expr,
                         groups,
+                        limit,
                     },
                 ))),
             });
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 79abecf556..55b346a482 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -370,6 +370,33 @@ fn rountrip_aggregate() -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn rountrip_aggregate_with_limit() -> Result<()> {
+    let field_a = Field::new("a", DataType::Int64, false);
+    let field_b = Field::new("b", DataType::Int64, false);
+    let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+
+    let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
+        vec![(col("a", &schema)?, "unused".to_string())];
+
+    let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
+        cast(col("b", &schema)?, &schema, DataType::Float64)?,
+        "AVG(b)".to_string(),
+        DataType::Float64,
+    ))];
+
+    let agg = AggregateExec::try_new(
+        AggregateMode::Final,
+        PhysicalGroupBy::new_single(groups.clone()),
+        aggregates.clone(),
+        vec![None],
+        Arc::new(EmptyExec::new(schema.clone())),
+        schema,
+    )?;
+    let agg = agg.with_limit(Some(12));
+    roundtrip_test(Arc::new(agg))
+}
+
 #[test]
 fn roundtrip_aggregate_udaf() -> Result<()> {
     let field_a = Field::new("a", DataType::Int64, false);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to