This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 70a215bc8a support limit in agg exec for ser/deser (#10692)
70a215bc8a is described below
commit 70a215bc8ab357987e81d3c607ccc664ee634790
Author: Kun Liu <[email protected]>
AuthorDate: Tue May 28 21:29:23 2024 +0800
support limit in agg exec for ser/deser (#10692)
---
datafusion/proto/proto/datafusion.proto | 6 ++
datafusion/proto/src/generated/pbjson.rs | 111 +++++++++++++++++++++
datafusion/proto/src/generated/prost.rs | 9 ++
datafusion/proto/src/physical_plan/mod.rs | 18 +++-
.../proto/tests/cases/roundtrip_physical_plan.rs | 27 +++++
5 files changed, 169 insertions(+), 2 deletions(-)
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 6b4e2aae29..448d9f0582 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1125,6 +1125,11 @@ message MaybePhysicalSortExprs {
repeated PhysicalSortExprNode sort_expr = 1;
}
+message AggLimit {
+ // wrap into a message to make it optional
+ uint64 limit = 1;
+}
+
message AggregateExecNode {
repeated PhysicalExprNode group_expr = 1;
repeated PhysicalExprNode aggr_expr = 2;
@@ -1137,6 +1142,7 @@ message AggregateExecNode {
repeated PhysicalExprNode null_expr = 8;
repeated bool groups = 9;
repeated MaybeFilter filter_expr = 10;
+ AggLimit limit = 11;
}
message GlobalLimitExecNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index bbee3311b7..76a367b402 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -1,3 +1,97 @@
+impl serde::Serialize for AggLimit {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ let mut len = 0;
+ if self.limit != 0 {
+ len += 1;
+ }
+ let mut struct_ser =
serializer.serialize_struct("datafusion.AggLimit", len)?;
+ if self.limit != 0 {
+ #[allow(clippy::needless_borrow)]
+ struct_ser.serialize_field("limit",
ToString::to_string(&self.limit).as_str())?;
+ }
+ struct_ser.end()
+ }
+}
+impl<'de> serde::Deserialize<'de> for AggLimit {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ "limit",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ Limit,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut
std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) ->
std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "limit" => Ok(GeneratedField::Limit),
+ _ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = AggLimit;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) ->
std::fmt::Result {
+ formatter.write_str("struct datafusion.AggLimit")
+ }
+
+ fn visit_map<V>(self, mut map_: V) ->
std::result::Result<AggLimit, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut limit__ = None;
+ while let Some(k) = map_.next_key()? {
+ match k {
+ GeneratedField::Limit => {
+ if limit__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("limit"));
+ }
+ limit__ =
+
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+ ;
+ }
+ }
+ }
+ Ok(AggLimit {
+ limit: limit__.unwrap_or_default(),
+ })
+ }
+ }
+ deserializer.deserialize_struct("datafusion.AggLimit", FIELDS,
GeneratedVisitor)
+ }
+}
impl serde::Serialize for AggregateExecNode {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
@@ -36,6 +130,9 @@ impl serde::Serialize for AggregateExecNode {
if !self.filter_expr.is_empty() {
len += 1;
}
+ if self.limit.is_some() {
+ len += 1;
+ }
let mut struct_ser =
serializer.serialize_struct("datafusion.AggregateExecNode", len)?;
if !self.group_expr.is_empty() {
struct_ser.serialize_field("groupExpr", &self.group_expr)?;
@@ -69,6 +166,9 @@ impl serde::Serialize for AggregateExecNode {
if !self.filter_expr.is_empty() {
struct_ser.serialize_field("filterExpr", &self.filter_expr)?;
}
+ if let Some(v) = self.limit.as_ref() {
+ struct_ser.serialize_field("limit", v)?;
+ }
struct_ser.end()
}
}
@@ -96,6 +196,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
"groups",
"filter_expr",
"filterExpr",
+ "limit",
];
#[allow(clippy::enum_variant_names)]
@@ -110,6 +211,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
NullExpr,
Groups,
FilterExpr,
+ Limit,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -141,6 +243,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
"nullExpr" | "null_expr" =>
Ok(GeneratedField::NullExpr),
"groups" => Ok(GeneratedField::Groups),
"filterExpr" | "filter_expr" =>
Ok(GeneratedField::FilterExpr),
+ "limit" => Ok(GeneratedField::Limit),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -170,6 +273,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
let mut null_expr__ = None;
let mut groups__ = None;
let mut filter_expr__ = None;
+ let mut limit__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::GroupExpr => {
@@ -232,6 +336,12 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
}
filter_expr__ = Some(map_.next_value()?);
}
+ GeneratedField::Limit => {
+ if limit__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("limit"));
+ }
+ limit__ = map_.next_value()?;
+ }
}
}
Ok(AggregateExecNode {
@@ -245,6 +355,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
null_expr: null_expr__.unwrap_or_default(),
groups: groups__.unwrap_or_default(),
filter_expr: filter_expr__.unwrap_or_default(),
+ limit: limit__,
})
}
}
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 0354ead9e7..5e0f6613f3 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1735,6 +1735,13 @@ pub struct MaybePhysicalSortExprs {
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct AggLimit {
+ /// wrap into a message to make it optional
+ #[prost(uint64, tag = "1")]
+ pub limit: u64,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
pub struct AggregateExecNode {
#[prost(message, repeated, tag = "1")]
pub group_expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
@@ -1757,6 +1764,8 @@ pub struct AggregateExecNode {
pub groups: ::prost::alloc::vec::Vec<bool>,
#[prost(message, repeated, tag = "10")]
pub filter_expr: ::prost::alloc::vec::Vec<MaybeFilter>,
+ #[prost(message, optional, tag = "11")]
+ pub limit: ::core::option::Option<AggLimit>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index a85bfdc89d..91ed3b7f5e 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -539,14 +539,23 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
})
.collect::<Result<Vec<_>, _>>()?;
- Ok(Arc::new(AggregateExec::try_new(
+ let limit = hash_agg
+ .limit
+ .as_ref()
+ .map(|lit_value| lit_value.limit as usize);
+
+ let agg = AggregateExec::try_new(
agg_mode,
PhysicalGroupBy::new(group_expr, null_expr, groups),
physical_aggr_expr,
physical_filter_expr,
input,
physical_schema,
- )?))
+ )?;
+
+ let agg = agg.with_limit(limit);
+
+ Ok(Arc::new(agg))
}
PhysicalPlanType::HashJoin(hashjoin) => {
let left: Arc<dyn ExecutionPlan> = into_physical_plan(
@@ -1504,6 +1513,10 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
.map(|expr| serialize_physical_expr(expr.0.to_owned(),
extension_codec))
.collect::<Result<Vec<_>>>()?;
+ let limit = exec.limit().map(|value| protobuf::AggLimit {
+ limit: value as u64,
+ });
+
return Ok(protobuf::PhysicalPlanNode {
physical_plan_type: Some(PhysicalPlanType::Aggregate(Box::new(
protobuf::AggregateExecNode {
@@ -1517,6 +1530,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
input_schema: Some(input_schema.as_ref().try_into()?),
null_expr,
groups,
+ limit,
},
))),
});
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 79abecf556..55b346a482 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -370,6 +370,33 @@ fn rountrip_aggregate() -> Result<()> {
Ok(())
}
+#[test]
+fn rountrip_aggregate_with_limit() -> Result<()> {
+ let field_a = Field::new("a", DataType::Int64, false);
+ let field_b = Field::new("b", DataType::Int64, false);
+ let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+
+ let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
+ vec![(col("a", &schema)?, "unused".to_string())];
+
+ let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
+ cast(col("b", &schema)?, &schema, DataType::Float64)?,
+ "AVG(b)".to_string(),
+ DataType::Float64,
+ ))];
+
+ let agg = AggregateExec::try_new(
+ AggregateMode::Final,
+ PhysicalGroupBy::new_single(groups.clone()),
+ aggregates.clone(),
+ vec![None],
+ Arc::new(EmptyExec::new(schema.clone())),
+ schema,
+ )?;
+ let agg = agg.with_limit(Some(12));
+ roundtrip_test(Arc::new(agg))
+}
+
#[test]
fn roundtrip_aggregate_udaf() -> Result<()> {
let field_a = Field::new("a", DataType::Int64, false);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]