This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 13569340bc ExprBuilder for Physical Aggregate Expr (#11617)
13569340bc is described below
commit 13569340bce99e4a317ec4d71e5c46d69dfa733d
Author: Jay Zhan <[email protected]>
AuthorDate: Wed Jul 24 22:30:05 2024 +0800
ExprBuilder for Physical Aggregate Expr (#11617)
* aggregate expr builder
Signed-off-by: jayzhan211 <[email protected]>
* replace parts of test
Signed-off-by: jayzhan211 <[email protected]>
* continue
Signed-off-by: jayzhan211 <[email protected]>
* cleanup all
Signed-off-by: jayzhan211 <[email protected]>
* clippy
Signed-off-by: jayzhan211 <[email protected]>
* add sort
Signed-off-by: jayzhan211 <[email protected]>
* rm field
Signed-off-by: jayzhan211 <[email protected]>
* address comment
Signed-off-by: jayzhan211 <[email protected]>
* fix import path
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion/core/src/lib.rs | 5 +
.../src/physical_optimizer/aggregate_statistics.rs | 20 +-
.../combine_partial_final_agg.rs | 41 +--
datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs | 23 +-
.../physical-expr-common/src/aggregate/mod.rs | 286 +++++++++++++++------
datafusion/physical-plan/src/aggregates/mod.rs | 134 ++++------
datafusion/physical-plan/src/windows/mod.rs | 39 ++-
datafusion/proto/src/physical_plan/mod.rs | 11 +-
.../proto/tests/cases/roundtrip_physical_plan.rs | 177 +++++--------
9 files changed, 369 insertions(+), 367 deletions(-)
diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs
index 9ab6ed527d..d9ab9e1c07 100644
--- a/datafusion/core/src/lib.rs
+++ b/datafusion/core/src/lib.rs
@@ -545,6 +545,11 @@ pub mod optimizer {
pub use datafusion_optimizer::*;
}
+/// re-export of [`datafusion_physical_expr_common`] crate
+pub mod physical_expr_common {
+ pub use datafusion_physical_expr_common::*;
+}
+
/// re-export of [`datafusion_physical_expr`] crate
pub mod physical_expr {
pub use datafusion_physical_expr::*;
diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index e7580d3e33..5f08e4512b 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -326,7 +326,7 @@ pub(crate) mod tests {
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::expressions::cast;
use datafusion_physical_expr::PhysicalExpr;
- use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
+ use datafusion_physical_expr_common::aggregate::AggregateExprBuilder;
use datafusion_physical_plan::aggregates::AggregateMode;
/// Mock data using a MemoryExec which has an exact count statistic
@@ -419,19 +419,11 @@ pub(crate) mod tests {
// Return appropriate expr depending if COUNT is for col or table (*)
pub(crate) fn count_expr(&self, schema: &Schema) -> Arc<dyn
AggregateExpr> {
- create_aggregate_expr(
- &count_udaf(),
- &[self.column()],
- &[],
- &[],
- &[],
- schema,
- self.column_name(),
- false,
- false,
- false,
- )
- .unwrap()
+ AggregateExprBuilder::new(count_udaf(), vec![self.column()])
+ .schema(Arc::new(schema.clone()))
+ .name(self.column_name())
+ .build()
+ .unwrap()
}
/// what argument would this aggregate need in the plan?
diff --git
a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
index ddb7d36fb5..6f3274820c 100644
--- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
+++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
@@ -177,7 +177,7 @@ mod tests {
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::col;
- use datafusion_physical_plan::udaf::create_aggregate_expr;
+ use datafusion_physical_expr_common::aggregate::AggregateExprBuilder;
/// Runs the CombinePartialFinalAggregate optimizer and asserts the plan
against the expected
macro_rules! assert_optimized {
@@ -278,19 +278,11 @@ mod tests {
name: &str,
schema: &Schema,
) -> Arc<dyn AggregateExpr> {
- create_aggregate_expr(
- &count_udaf(),
- &[expr],
- &[],
- &[],
- &[],
- schema,
- name,
- false,
- false,
- false,
- )
- .unwrap()
+ AggregateExprBuilder::new(count_udaf(), vec![expr])
+ .schema(Arc::new(schema.clone()))
+ .name(name)
+ .build()
+ .unwrap()
}
#[test]
@@ -368,19 +360,14 @@ mod tests {
#[test]
fn aggregations_with_group_combined() -> Result<()> {
let schema = schema();
-
- let aggr_expr = vec![create_aggregate_expr(
- &sum_udaf(),
- &[col("b", &schema)?],
- &[],
- &[],
- &[],
- &schema,
- "Sum(b)",
- false,
- false,
- false,
- )?];
+ let aggr_expr =
+ vec![
+ AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
+ .schema(Arc::clone(&schema))
+ .name("Sum(b)")
+ .build()
+ .unwrap(),
+ ];
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
vec![(col("c", &schema)?, "c".to_string())];
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 736560da97..6f286c9aeb 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -35,7 +35,7 @@ use datafusion_common::tree_node::{TreeNode,
TreeNodeRecursion, TreeNodeVisitor}
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::PhysicalSortExpr;
-use datafusion_physical_plan::udaf::create_aggregate_expr;
+use datafusion_physical_expr_common::aggregate::AggregateExprBuilder;
use datafusion_physical_plan::InputOrderMode;
use test_utils::{add_empty_batches, StringBatchGenerator};
@@ -103,19 +103,14 @@ async fn run_aggregate_test(input1: Vec<RecordBatch>,
group_by_columns: Vec<&str
.with_sort_information(vec![sort_keys]),
);
- let aggregate_expr = vec![create_aggregate_expr(
- &sum_udaf(),
- &[col("d", &schema).unwrap()],
- &[],
- &[],
- &[],
- &schema,
- "sum1",
- false,
- false,
- false,
- )
- .unwrap()];
+ let aggregate_expr =
+ vec![
+ AggregateExprBuilder::new(sum_udaf(), vec![col("d",
&schema).unwrap()])
+ .schema(Arc::clone(&schema))
+ .name("sum1")
+ .build()
+ .unwrap(),
+ ];
let expr = group_by_columns
.iter()
.map(|elem| (col(elem, &schema).unwrap(), elem.to_string()))
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs
b/datafusion/physical-expr-common/src/aggregate/mod.rs
index 8c5f9f9e5a..b58a5a6faf 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -22,8 +22,8 @@ pub mod stats;
pub mod tdigest;
pub mod utils;
-use arrow::datatypes::{DataType, Field, Schema};
-use datafusion_common::{not_impl_err, DFSchema, Result};
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion_common::{internal_err, not_impl_err, DFSchema, Result};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::type_coercion::aggregates::check_arg_count;
use datafusion_expr::ReversedUDAF;
@@ -33,7 +33,7 @@ use datafusion_expr::{
use std::fmt::Debug;
use std::{any::Any, sync::Arc};
-use self::utils::{down_cast_any_ref, ordering_fields};
+use self::utils::down_cast_any_ref;
use crate::physical_expr::PhysicalExpr;
use crate::sort_expr::{LexOrdering, PhysicalSortExpr};
use crate::utils::reverse_order_bys;
@@ -55,6 +55,8 @@ use datafusion_expr::utils::AggregateOrderSensitivity;
/// `is_reversed` is used to indicate whether the aggregation is running in
reverse order,
/// it could be used to hint Accumulator to accumulate in the reversed order,
/// you can just set to false if you are not reversing expression
+///
+/// You can also create the expression with [`AggregateExprBuilder`]
#[allow(clippy::too_many_arguments)]
pub fn create_aggregate_expr(
fun: &AggregateUDF,
@@ -66,45 +68,23 @@ pub fn create_aggregate_expr(
name: impl Into<String>,
ignore_nulls: bool,
is_distinct: bool,
- is_reversed: bool,
) -> Result<Arc<dyn AggregateExpr>> {
- debug_assert_eq!(sort_exprs.len(), ordering_req.len());
-
- let input_exprs_types = input_phy_exprs
- .iter()
- .map(|arg| arg.data_type(schema))
- .collect::<Result<Vec<_>>>()?;
-
- check_arg_count(
- fun.name(),
- &input_exprs_types,
- &fun.signature().type_signature,
- )?;
-
- let ordering_types = ordering_req
- .iter()
- .map(|e| e.expr.data_type(schema))
- .collect::<Result<Vec<_>>>()?;
-
- let ordering_fields = ordering_fields(ordering_req, &ordering_types);
- let name = name.into();
-
- Ok(Arc::new(AggregateFunctionExpr {
- fun: fun.clone(),
- args: input_phy_exprs.to_vec(),
- logical_args: input_exprs.to_vec(),
- data_type: fun.return_type(&input_exprs_types)?,
- name,
- schema: schema.clone(),
- dfschema: DFSchema::empty(),
- sort_exprs: sort_exprs.to_vec(),
- ordering_req: ordering_req.to_vec(),
- ignore_nulls,
- ordering_fields,
- is_distinct,
- input_type: input_exprs_types[0].clone(),
- is_reversed,
- }))
+ let mut builder =
+ AggregateExprBuilder::new(Arc::new(fun.clone()),
input_phy_exprs.to_vec());
+ builder = builder.sort_exprs(sort_exprs.to_vec());
+ builder = builder.order_by(ordering_req.to_vec());
+ builder = builder.logical_exprs(input_exprs.to_vec());
+ builder = builder.schema(Arc::new(schema.clone()));
+ builder = builder.name(name);
+
+ if ignore_nulls {
+ builder = builder.ignore_nulls();
+ }
+ if is_distinct {
+ builder = builder.distinct();
+ }
+
+ builder.build()
}
#[allow(clippy::too_many_arguments)]
@@ -121,44 +101,196 @@ pub fn create_aggregate_expr_with_dfschema(
is_distinct: bool,
is_reversed: bool,
) -> Result<Arc<dyn AggregateExpr>> {
- debug_assert_eq!(sort_exprs.len(), ordering_req.len());
-
+ let mut builder =
+ AggregateExprBuilder::new(Arc::new(fun.clone()),
input_phy_exprs.to_vec());
+ builder = builder.sort_exprs(sort_exprs.to_vec());
+ builder = builder.order_by(ordering_req.to_vec());
+ builder = builder.logical_exprs(input_exprs.to_vec());
+ builder = builder.dfschema(dfschema.clone());
let schema: Schema = dfschema.into();
+ builder = builder.schema(Arc::new(schema));
+ builder = builder.name(name);
+
+ if ignore_nulls {
+ builder = builder.ignore_nulls();
+ }
+ if is_distinct {
+ builder = builder.distinct();
+ }
+ if is_reversed {
+ builder = builder.reversed();
+ }
+
+ builder.build()
+}
+
+/// Builder for physical [`AggregateExpr`]
+///
+/// `AggregateExpr` contains the information necessary to call
+/// an aggregate expression.
+#[derive(Debug, Clone)]
+pub struct AggregateExprBuilder {
+ fun: Arc<AggregateUDF>,
+ /// Physical expressions of the aggregate function
+ args: Vec<Arc<dyn PhysicalExpr>>,
+ /// Logical expressions of the aggregate function, it will be deprecated
in <https://github.com/apache/datafusion/issues/11359>
+ logical_args: Vec<Expr>,
+ name: String,
+ /// Arrow Schema for the aggregate function
+ schema: SchemaRef,
+ /// Datafusion Schema for the aggregate function
+ dfschema: DFSchema,
+ /// The logical order by expressions, it will be deprecated in
<https://github.com/apache/datafusion/issues/11359>
+ sort_exprs: Vec<Expr>,
+ /// The physical order by expressions
+ ordering_req: LexOrdering,
+ /// Whether to ignore null values
+ ignore_nulls: bool,
+ /// Whether the aggregate function is distinct
+ is_distinct: bool,
+ /// Whether the expression is reversed
+ is_reversed: bool,
+}
+
+impl AggregateExprBuilder {
+ pub fn new(fun: Arc<AggregateUDF>, args: Vec<Arc<dyn PhysicalExpr>>) ->
Self {
+ Self {
+ fun,
+ args,
+ logical_args: vec![],
+ name: String::new(),
+ schema: Arc::new(Schema::empty()),
+ dfschema: DFSchema::empty(),
+ sort_exprs: vec![],
+ ordering_req: vec![],
+ ignore_nulls: false,
+ is_distinct: false,
+ is_reversed: false,
+ }
+ }
+
+ pub fn build(self) -> Result<Arc<dyn AggregateExpr>> {
+ let Self {
+ fun,
+ args,
+ logical_args,
+ name,
+ schema,
+ dfschema,
+ sort_exprs,
+ ordering_req,
+ ignore_nulls,
+ is_distinct,
+ is_reversed,
+ } = self;
+ if args.is_empty() {
+ return internal_err!("args should not be empty");
+ }
+
+ let mut ordering_fields = vec![];
+
+ debug_assert_eq!(sort_exprs.len(), ordering_req.len());
+ if !ordering_req.is_empty() {
+ let ordering_types = ordering_req
+ .iter()
+ .map(|e| e.expr.data_type(&schema))
+ .collect::<Result<Vec<_>>>()?;
+
+ ordering_fields = utils::ordering_fields(&ordering_req,
&ordering_types);
+ }
+
+ let input_exprs_types = args
+ .iter()
+ .map(|arg| arg.data_type(&schema))
+ .collect::<Result<Vec<_>>>()?;
+
+ check_arg_count(
+ fun.name(),
+ &input_exprs_types,
+ &fun.signature().type_signature,
+ )?;
- let input_exprs_types = input_phy_exprs
- .iter()
- .map(|arg| arg.data_type(&schema))
- .collect::<Result<Vec<_>>>()?;
-
- check_arg_count(
- fun.name(),
- &input_exprs_types,
- &fun.signature().type_signature,
- )?;
-
- let ordering_types = ordering_req
- .iter()
- .map(|e| e.expr.data_type(&schema))
- .collect::<Result<Vec<_>>>()?;
-
- let ordering_fields = ordering_fields(ordering_req, &ordering_types);
-
- Ok(Arc::new(AggregateFunctionExpr {
- fun: fun.clone(),
- args: input_phy_exprs.to_vec(),
- logical_args: input_exprs.to_vec(),
- data_type: fun.return_type(&input_exprs_types)?,
- name: name.into(),
- schema: schema.clone(),
- dfschema: dfschema.clone(),
- sort_exprs: sort_exprs.to_vec(),
- ordering_req: ordering_req.to_vec(),
- ignore_nulls,
- ordering_fields,
- is_distinct,
- input_type: input_exprs_types[0].clone(),
- is_reversed,
- }))
+ let data_type = fun.return_type(&input_exprs_types)?;
+
+ Ok(Arc::new(AggregateFunctionExpr {
+ fun: Arc::unwrap_or_clone(fun),
+ args,
+ logical_args,
+ data_type,
+ name,
+ schema: Arc::unwrap_or_clone(schema),
+ dfschema,
+ sort_exprs,
+ ordering_req,
+ ignore_nulls,
+ ordering_fields,
+ is_distinct,
+ input_type: input_exprs_types[0].clone(),
+ is_reversed,
+ }))
+ }
+
+ pub fn name(mut self, name: impl Into<String>) -> Self {
+ self.name = name.into();
+ self
+ }
+
+ pub fn schema(mut self, schema: SchemaRef) -> Self {
+ self.schema = schema;
+ self
+ }
+
+ pub fn dfschema(mut self, dfschema: DFSchema) -> Self {
+ self.dfschema = dfschema;
+ self
+ }
+
+ pub fn order_by(mut self, order_by: LexOrdering) -> Self {
+ self.ordering_req = order_by;
+ self
+ }
+
+ pub fn reversed(mut self) -> Self {
+ self.is_reversed = true;
+ self
+ }
+
+ pub fn with_reversed(mut self, is_reversed: bool) -> Self {
+ self.is_reversed = is_reversed;
+ self
+ }
+
+ pub fn distinct(mut self) -> Self {
+ self.is_distinct = true;
+ self
+ }
+
+ pub fn with_distinct(mut self, is_distinct: bool) -> Self {
+ self.is_distinct = is_distinct;
+ self
+ }
+
+ pub fn ignore_nulls(mut self) -> Self {
+ self.ignore_nulls = true;
+ self
+ }
+
+ pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self {
+ self.ignore_nulls = ignore_nulls;
+ self
+ }
+
+ /// This method will be deprecated in
<https://github.com/apache/datafusion/issues/11359>
+ pub fn sort_exprs(mut self, sort_exprs: Vec<Expr>) -> Self {
+ self.sort_exprs = sort_exprs;
+ self
+ }
+
+ /// This method will be deprecated in
<https://github.com/apache/datafusion/issues/11359>
+ pub fn logical_exprs(mut self, logical_args: Vec<Expr>) -> Self {
+ self.logical_args = logical_args;
+ self
+ }
}
/// An aggregate expression that:
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs
b/datafusion/physical-plan/src/aggregates/mod.rs
index e7cd5cb272..d1152038eb 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -1211,7 +1211,7 @@ mod tests {
use crate::common::collect;
use datafusion_physical_expr_common::aggregate::{
- create_aggregate_expr, create_aggregate_expr_with_dfschema,
+ create_aggregate_expr_with_dfschema, AggregateExprBuilder,
};
use datafusion_physical_expr_common::expressions::Literal;
use futures::{FutureExt, Stream};
@@ -1351,18 +1351,11 @@ mod tests {
],
};
- let aggregates = vec![create_aggregate_expr(
- &count_udaf(),
- &[lit(1i8)],
- &[datafusion_expr::lit(1i8)],
- &[],
- &[],
- &input_schema,
- "COUNT(1)",
- false,
- false,
- false,
- )?];
+ let aggregates = vec![AggregateExprBuilder::new(count_udaf(),
vec![lit(1i8)])
+ .schema(Arc::clone(&input_schema))
+ .name("COUNT(1)")
+ .logical_exprs(vec![datafusion_expr::lit(1i8)])
+ .build()?];
let task_ctx = if spill {
new_spill_ctx(4, 1000)
@@ -1501,18 +1494,13 @@ mod tests {
groups: vec![vec![false]],
};
- let aggregates: Vec<Arc<dyn AggregateExpr>> =
vec![create_aggregate_expr(
- &avg_udaf(),
- &[col("b", &input_schema)?],
- &[datafusion_expr::col("b")],
- &[],
- &[],
- &input_schema,
- "AVG(b)",
- false,
- false,
- false,
- )?];
+ let aggregates: Vec<Arc<dyn AggregateExpr>> =
+ vec![
+ AggregateExprBuilder::new(avg_udaf(), vec![col("b",
&input_schema)?])
+ .schema(Arc::clone(&input_schema))
+ .name("AVG(b)")
+ .build()?,
+ ];
let task_ctx = if spill {
// set to an appropriate value to trigger spill
@@ -1803,21 +1791,11 @@ mod tests {
}
// Median(a)
- fn test_median_agg_expr(schema: &Schema) -> Result<Arc<dyn AggregateExpr>>
{
- let args = vec![col("a", schema)?];
- let fun = median_udaf();
- datafusion_physical_expr_common::aggregate::create_aggregate_expr(
- &fun,
- &args,
- &[],
- &[],
- &[],
- schema,
- "MEDIAN(a)",
- false,
- false,
- false,
- )
+ fn test_median_agg_expr(schema: SchemaRef) -> Result<Arc<dyn
AggregateExpr>> {
+ AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
+ .schema(schema)
+ .name("MEDIAN(a)")
+ .build()
}
#[tokio::test]
@@ -1840,21 +1818,16 @@ mod tests {
// something that allocates within the aggregator
let aggregates_v0: Vec<Arc<dyn AggregateExpr>> =
- vec![test_median_agg_expr(&input_schema)?];
+ vec![test_median_agg_expr(Arc::clone(&input_schema))?];
// use fast-path in `row_hash.rs`.
- let aggregates_v2: Vec<Arc<dyn AggregateExpr>> =
vec![create_aggregate_expr(
- &avg_udaf(),
- &[col("b", &input_schema)?],
- &[datafusion_expr::col("b")],
- &[],
- &[],
- &input_schema,
- "AVG(b)",
- false,
- false,
- false,
- )?];
+ let aggregates_v2: Vec<Arc<dyn AggregateExpr>> =
+ vec![
+ AggregateExprBuilder::new(avg_udaf(), vec![col("b",
&input_schema)?])
+ .schema(Arc::clone(&input_schema))
+ .name("AVG(b)")
+ .build()?,
+ ];
for (version, groups, aggregates) in [
(0, groups_none, aggregates_v0),
@@ -1908,18 +1881,13 @@ mod tests {
let groups = PhysicalGroupBy::default();
- let aggregates: Vec<Arc<dyn AggregateExpr>> =
vec![create_aggregate_expr(
- &avg_udaf(),
- &[col("a", &schema)?],
- &[datafusion_expr::col("a")],
- &[],
- &[],
- &schema,
- "AVG(a)",
- false,
- false,
- false,
- )?];
+ let aggregates: Vec<Arc<dyn AggregateExpr>> =
+ vec![
+ AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
+ .schema(Arc::clone(&schema))
+ .name("AVG(a)")
+ .build()?,
+ ];
let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema),
1));
let refs = blocking_exec.refs();
@@ -1953,18 +1921,13 @@ mod tests {
let groups =
PhysicalGroupBy::new_single(vec![(col("a", &schema)?,
"a".to_string())]);
- let aggregates: Vec<Arc<dyn AggregateExpr>> =
vec![create_aggregate_expr(
- &avg_udaf(),
- &[col("b", &schema)?],
- &[datafusion_expr::col("b")],
- &[],
- &[],
- &schema,
- "AVG(b)",
- false,
- false,
- false,
- )?];
+ let aggregates: Vec<Arc<dyn AggregateExpr>> =
+ vec![
+ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
+ .schema(Arc::clone(&schema))
+ .name("AVG(b)")
+ .build()?,
+ ];
let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema),
1));
let refs = blocking_exec.refs();
@@ -2388,18 +2351,11 @@ mod tests {
],
);
- let aggregates: Vec<Arc<dyn AggregateExpr>> =
vec![create_aggregate_expr(
- count_udaf().as_ref(),
- &[lit(1)],
- &[datafusion_expr::lit(1)],
- &[],
- &[],
- schema.as_ref(),
- "1",
- false,
- false,
- false,
- )?];
+ let aggregates: Vec<Arc<dyn AggregateExpr>> =
+ vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
+ .schema(Arc::clone(&schema))
+ .name("1")
+ .build()?];
let input_batches = (0..4)
.map(|_| {
diff --git a/datafusion/physical-plan/src/windows/mod.rs
b/datafusion/physical-plan/src/windows/mod.rs
index 959796489c..ffe558e215 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -26,16 +26,16 @@ use crate::{
cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal,
NthValue, Ntile,
PhysicalSortExpr, RowNumber,
},
- udaf, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr,
+ ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr,
};
use arrow::datatypes::Schema;
use arrow_schema::{DataType, Field, SchemaRef};
-use datafusion_common::{exec_err, Column, DataFusionError, Result,
ScalarValue};
-use datafusion_expr::Expr;
+use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
+use datafusion_expr::{col, Expr, SortExpr};
use datafusion_expr::{
- BuiltInWindowFunction, PartitionEvaluator, SortExpr, WindowFrame,
- WindowFunctionDefinition, WindowUDF,
+ BuiltInWindowFunction, PartitionEvaluator, WindowFrame,
WindowFunctionDefinition,
+ WindowUDF,
};
use datafusion_physical_expr::equivalence::collapse_lex_req;
use datafusion_physical_expr::{
@@ -44,6 +44,7 @@ use datafusion_physical_expr::{
AggregateExpr, ConstExpr, EquivalenceProperties, LexOrdering,
PhysicalSortRequirement,
};
+use datafusion_physical_expr_common::aggregate::AggregateExprBuilder;
use itertools::Itertools;
mod bounded_window_agg_exec;
@@ -95,7 +96,7 @@ pub fn create_window_expr(
fun: &WindowFunctionDefinition,
name: String,
args: &[Arc<dyn PhysicalExpr>],
- logical_args: &[Expr],
+ _logical_args: &[Expr],
partition_by: &[Arc<dyn PhysicalExpr>],
order_by: &[PhysicalSortExpr],
window_frame: Arc<WindowFrame>,
@@ -129,7 +130,6 @@ pub fn create_window_expr(
))
}
WindowFunctionDefinition::AggregateUDF(fun) => {
- // TODO: Ordering not supported for Window UDFs yet
// Convert `Vec<PhysicalSortExpr>` into `Vec<Expr::Sort>`
let sort_exprs = order_by
.iter()
@@ -137,28 +137,20 @@ pub fn create_window_expr(
let field_name = expr.to_string();
let field_name =
field_name.split('@').next().unwrap_or(&field_name);
Expr::Sort(SortExpr {
- expr: Box::new(Expr::Column(Column::new(
- None::<String>,
- field_name,
- ))),
+ expr: Box::new(col(field_name)),
asc: !options.descending,
nulls_first: options.nulls_first,
})
})
.collect::<Vec<_>>();
- let aggregate = udaf::create_aggregate_expr(
- fun.as_ref(),
- args,
- logical_args,
- &sort_exprs,
- order_by,
- input_schema,
- name,
- ignore_nulls,
- false,
- false,
- )?;
+ let aggregate = AggregateExprBuilder::new(Arc::clone(fun),
args.to_vec())
+ .schema(Arc::new(input_schema.clone()))
+ .name(name)
+ .order_by(order_by.to_vec())
+ .sort_exprs(sort_exprs)
+ .with_ignore_nulls(ignore_nulls)
+ .build()?;
window_expr_from_aggregate_expr(
partition_by,
order_by,
@@ -166,6 +158,7 @@ pub fn create_window_expr(
aggregate,
)
}
+ // TODO: Ordering not supported for Window UDFs yet
WindowFunctionDefinition::WindowUDF(fun) =>
Arc::new(BuiltInWindowExpr::new(
create_udwf_window_expr(fun, args, input_schema, name)?,
partition_by,
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index 8c9e5bbd0e..5c4d41f0ec 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -18,6 +18,7 @@
use std::fmt::Debug;
use std::sync::Arc;
+use datafusion::physical_expr_common::aggregate::AggregateExprBuilder;
use prost::bytes::BufMut;
use prost::Message;
@@ -58,7 +59,7 @@ use
datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMerge
use datafusion::physical_plan::union::{InterleaveExec, UnionExec};
use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec};
use datafusion::physical_plan::{
- udaf, AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr,
WindowExpr,
+ AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr,
};
use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
use datafusion_expr::{AggregateUDF, ScalarUDF};
@@ -501,13 +502,9 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
None =>
registry.udaf(udaf_name)?
};
- // TODO: 'logical_exprs' is not
supported for UDAF yet.
- // approx_percentile_cont and
approx_percentile_cont_weight are not supported for UDAF from protobuf yet.
- let logical_exprs = &[];
+ // TODO: approx_percentile_cont
and approx_percentile_cont_weight are not supported for UDAF from protobuf yet.
// TODO: `order by` is not
supported for UDAF yet
- let sort_exprs = &[];
- let ordering_req = &[];
-
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs,
sort_exprs, ordering_req, &physical_schema, name, agg_node.ignore_nulls,
agg_node.distinct, false)
+ AggregateExprBuilder::new(agg_udf,
input_phy_expr).schema(Arc::clone(&physical_schema)).name(name).with_ignore_nulls(agg_node.ignore_nulls).with_distinct(agg_node.distinct).build()
}
}
}).transpose()?.ok_or_else(|| {
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 31ed0837d2..3ddc122e3d 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -24,6 +24,7 @@ use std::vec;
use arrow::array::RecordBatch;
use arrow::csv::WriterBuilder;
+use datafusion::physical_expr_common::aggregate::AggregateExprBuilder;
use prost::Message;
use datafusion::arrow::array::ArrayRef;
@@ -64,7 +65,6 @@ use
datafusion::physical_plan::placeholder_row::PlaceholderRowExec;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::sorts::sort::SortExec;
-use datafusion::physical_plan::udaf::create_aggregate_expr;
use datafusion::physical_plan::union::{InterleaveExec, UnionExec};
use datafusion::physical_plan::windows::{
BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec,
@@ -86,7 +86,7 @@ use datafusion_expr::{
};
use datafusion_functions_aggregate::average::avg_udaf;
use datafusion_functions_aggregate::nth_value::nth_value_udaf;
-use datafusion_functions_aggregate::string_agg::StringAgg;
+use datafusion_functions_aggregate::string_agg::string_agg_udaf;
use datafusion_proto::physical_plan::{
AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
};
@@ -291,18 +291,13 @@ fn roundtrip_window() -> Result<()> {
));
let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new(
- create_aggregate_expr(
- &avg_udaf(),
- &[cast(col("b", &schema)?, &schema, DataType::Float64)?],
- &[],
- &[],
- &[],
- &schema,
- "avg(b)",
- false,
- false,
- false,
- )?,
+ AggregateExprBuilder::new(
+ avg_udaf(),
+ vec![cast(col("b", &schema)?, &schema, DataType::Float64)?],
+ )
+ .schema(Arc::clone(&schema))
+ .name("avg(b)")
+ .build()?,
&[],
&[],
Arc::new(WindowFrame::new(None)),
@@ -315,18 +310,10 @@ fn roundtrip_window() -> Result<()> {
);
let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?];
- let sum_expr = create_aggregate_expr(
- &sum_udaf(),
- &args,
- &[],
- &[],
- &[],
- &schema,
- "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING",
- false,
- false,
- false,
- )?;
+ let sum_expr = AggregateExprBuilder::new(sum_udaf(), args)
+ .schema(Arc::clone(&schema))
+ .name("SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING")
+ .build()?;
let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new(
sum_expr,
@@ -357,49 +344,28 @@ fn rountrip_aggregate() -> Result<()> {
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
vec![(col("a", &schema)?, "unused".to_string())];
+ let avg_expr = AggregateExprBuilder::new(avg_udaf(), vec![col("b",
&schema)?])
+ .schema(Arc::clone(&schema))
+ .name("AVG(b)")
+ .build()?;
+ let nth_expr =
+ AggregateExprBuilder::new(nth_value_udaf(), vec![col("b", &schema)?,
lit(1u64)])
+ .schema(Arc::clone(&schema))
+ .name("NTH_VALUE(b, 1)")
+ .build()?;
+ let str_agg_expr =
+ AggregateExprBuilder::new(string_agg_udaf(), vec![col("b", &schema)?,
lit(1u64)])
+ .schema(Arc::clone(&schema))
+ .name("NTH_VALUE(b, 1)")
+ .build()?;
+
let test_cases: Vec<Vec<Arc<dyn AggregateExpr>>> = vec![
// AVG
- vec![create_aggregate_expr(
- &avg_udaf(),
- &[col("b", &schema)?],
- &[],
- &[],
- &[],
- &schema,
- "AVG(b)",
- false,
- false,
- false,
- )?],
+ vec![avg_expr],
// NTH_VALUE
- vec![create_aggregate_expr(
- &nth_value_udaf(),
- &[col("b", &schema)?, lit(1u64)],
- &[],
- &[],
- &[],
- &schema,
- "NTH_VALUE(b, 1)",
- false,
- false,
- false,
- )?],
+ vec![nth_expr],
// STRING_AGG
- vec![create_aggregate_expr(
- &AggregateUDF::new_from_impl(StringAgg::new()),
- &[
- cast(col("b", &schema)?, &schema, DataType::Utf8)?,
- lit(ScalarValue::Utf8(Some(",".to_string()))),
- ],
- &[],
- &[],
- &[],
- &schema,
- "STRING_AGG(name, ',')",
- false,
- false,
- false,
- )?],
+ vec![str_agg_expr],
];
for aggregates in test_cases {
@@ -426,18 +392,13 @@ fn rountrip_aggregate_with_limit() -> Result<()> {
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
vec![(col("a", &schema)?, "unused".to_string())];
- let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(
- &avg_udaf(),
- &[col("b", &schema)?],
- &[],
- &[],
- &[],
- &schema,
- "AVG(b)",
- false,
- false,
- false,
- )?];
+ let aggregates: Vec<Arc<dyn AggregateExpr>> =
+ vec![
+ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
+ .schema(Arc::clone(&schema))
+ .name("AVG(b)")
+ .build()?,
+ ];
let agg = AggregateExec::try_new(
AggregateMode::Final,
@@ -498,18 +459,13 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
vec![(col("a", &schema)?, "unused".to_string())];
- let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(
- &udaf,
- &[col("b", &schema)?],
- &[],
- &[],
- &[],
- &schema,
- "example_agg",
- false,
- false,
- false,
- )?];
+ let aggregates: Vec<Arc<dyn AggregateExpr>> =
+ vec![
+ AggregateExprBuilder::new(Arc::new(udaf), vec![col("b", &schema)?])
+ .schema(Arc::clone(&schema))
+ .name("example_agg")
+ .build()?,
+ ];
roundtrip_test_with_context(
Arc::new(AggregateExec::try_new(
@@ -994,21 +950,16 @@ fn roundtrip_aggregate_udf_extension_codec() ->
Result<()> {
DataType::Int64,
));
- let udaf = AggregateUDF::from(MyAggregateUDF::new("result".to_string()));
- let aggr_args: [Arc<dyn PhysicalExpr>; 1] =
- [Arc::new(Literal::new(ScalarValue::from(42)))];
- let aggr_expr = create_aggregate_expr(
- &udaf,
- &aggr_args,
- &[],
- &[],
- &[],
- &schema,
- "aggregate_udf",
- false,
- false,
- false,
- )?;
+ let udaf = Arc::new(AggregateUDF::from(MyAggregateUDF::new(
+ "result".to_string(),
+ )));
+ let aggr_args: Vec<Arc<dyn PhysicalExpr>> =
+ vec![Arc::new(Literal::new(ScalarValue::from(42)))];
+
+ let aggr_expr = AggregateExprBuilder::new(Arc::clone(&udaf),
aggr_args.clone())
+ .schema(Arc::clone(&schema))
+ .name("aggregate_udf")
+ .build()?;
let filter = Arc::new(FilterExec::try_new(
Arc::new(BinaryExpr::new(
@@ -1030,18 +981,12 @@ fn roundtrip_aggregate_udf_extension_codec() ->
Result<()> {
vec![col("author", &schema)?],
)?);
- let aggr_expr = create_aggregate_expr(
- &udaf,
- &aggr_args,
- &[],
- &[],
- &[],
- &schema,
- "aggregate_udf",
- true,
- true,
- false,
- )?;
+ let aggr_expr = AggregateExprBuilder::new(udaf, aggr_args.clone())
+ .schema(Arc::clone(&schema))
+ .name("aggregate_udf")
+ .distinct()
+ .ignore_nulls()
+ .build()?;
let aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Final,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]