This is an automated email from the ASF dual-hosted git repository. avantgardner pushed a commit to branch bg_aggregate_pushdown in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit 02c81662b64b2a7e538eb2644c7b9ee0c48e5a5d Author: Brent Gardner <[email protected]> AuthorDate: Wed Aug 2 15:22:48 2023 -0600 Keep sketching out thoughts --- .../core/src/physical_plan/aggregates/mod.rs | 19 ++++-- .../src/physical_plan/aggregates/priority_queue.rs | 71 ++++++++++++++++++++++ .../core/src/physical_plan/aggregates/row_hash.rs | 4 +- datafusion/core/tests/sql/select.rs | 18 +++++- 4 files changed, 105 insertions(+), 7 deletions(-) diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index ac04be0c33..f1c8100e7e 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -50,6 +50,7 @@ mod group_values; mod no_grouping; mod order; mod row_hash; +mod priority_queue; pub use datafusion_expr::AggregateFunction; use datafusion_physical_expr::aggregate::is_order_sensitive; @@ -57,6 +58,7 @@ pub use datafusion_physical_expr::expressions::create_aggregate_expr; use datafusion_physical_expr::utils::{ get_finer_ordering, ordering_satisfy_requirement_concrete, }; +use crate::physical_plan::aggregates::priority_queue::GroupedPriorityQueueAggregateStream; use super::DisplayAs; @@ -229,6 +231,7 @@ impl PartialEq for PhysicalGroupBy { enum StreamType { AggregateStream(AggregateStream), GroupedHashAggregateStream(GroupedHashAggregateStream), + GroupedPriorityQueueAggregateStream(GroupedPriorityQueueAggregateStream), } impl From<StreamType> for SendableRecordBatchStream { @@ -236,6 +239,7 @@ impl From<StreamType> for SendableRecordBatchStream { match stream { StreamType::AggregateStream(stream) => Box::pin(stream), StreamType::GroupedHashAggregateStream(stream) => Box::pin(stream), + StreamType::GroupedPriorityQueueAggregateStream(stream) => Box::pin(stream), } } } @@ -719,14 +723,21 @@ impl AggregateExec { context: Arc<TaskContext>, ) -> Result<StreamType> { if self.group_by.expr.is_empty() { - Ok(StreamType::AggregateStream(AggregateStream::new( + return Ok(StreamType::AggregateStream(AggregateStream::new( self, context, partition, )?)) - } else { - Ok(StreamType::GroupedHashAggregateStream( - GroupedHashAggregateStream::new(self, context, partition)?, + } + + // TODO: if self.limit.is_some() + if self.aggr_expr.len() == 1 { + return Ok(StreamType::GroupedPriorityQueueAggregateStream( + GroupedPriorityQueueAggregateStream::new(self, context, partition)?, )) } + + Ok(StreamType::GroupedHashAggregateStream( + GroupedHashAggregateStream::new(self, context, partition)?, + )) } } diff --git a/datafusion/core/src/physical_plan/aggregates/priority_queue.rs b/datafusion/core/src/physical_plan/aggregates/priority_queue.rs new file mode 100644 index 0000000000..8104d58461 --- /dev/null +++ b/datafusion/core/src/physical_plan/aggregates/priority_queue.rs @@ -0,0 +1,71 @@ +use std::collections::BTreeMap; +use std::sync::Arc; +use arrow::row::{OwnedRow, Row}; +use arrow_schema::SchemaRef; +use datafusion_common::DataFusionError; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{GroupsAccumulator, PhysicalExpr}; +use crate::physical_plan::aggregates::{aggregate_expressions, AggregateExec, AggregateMode, PhysicalGroupBy}; +use crate::physical_plan::aggregates::group_values::GroupValues; +use crate::physical_plan::aggregates::row_hash::create_group_accumulator; +use crate::physical_plan::SendableRecordBatchStream; +use datafusion_common::Result; + +pub(crate) struct GroupedPriorityQueueAggregateStream { + schema: SchemaRef, + input: SendableRecordBatchStream, + accumulator: Box<dyn GroupsAccumulator>, + aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>, + filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>, + group_by: PhysicalGroupBy, + data: BTreeMap<OwnedRow, OwnedRow>, // TODO: BTreeMap->BinaryHeap, OwnedRow->Rows + limit: Option<usize>, +} + +impl GroupedPriorityQueueAggregateStream { + pub fn new( + agg: &AggregateExec, + context: Arc<TaskContext>, + partition: usize, + ) -> Result<Self> { + let agg_schema = Arc::clone(&agg.schema); + let agg_group_by = agg.group_by.clone(); + let agg_filter_expr = agg.filter_expr.clone(); + + let input = agg.input.execute(partition, Arc::clone(&context))?; + let aggregate_exprs = agg.aggr_expr.clone(); + + let aggregate_arguments = aggregate_expressions( + &agg.aggr_expr, + &agg.mode, + agg_group_by.expr.len(), + )?; + + let filter_expressions = match agg.mode { + AggregateMode::Partial + | AggregateMode::Single + | AggregateMode::SinglePartitioned => agg_filter_expr, + AggregateMode::Final | AggregateMode::FinalPartitioned => { + vec![None; agg.aggr_expr.len()] + } + }; + + let aggregate_expr = match aggregate_exprs.as_slice() { + [] => DataFusionError::Execution("An aggregate expression is required".to_string())?, + [expr] => expr, + _ => DataFusionError::Execution("Cannot limit on multiple aggregates".to_string())?, + }; + let accumulator = create_group_accumulator(aggregate_expr)?; + + Ok(GroupedPriorityQueueAggregateStream { + schema: agg_schema, + input, + accumulator, + aggregate_arguments, + filter_expressions, + group_by: agg_group_by, + data: Default::default(), + limit: Some(1), // TODO: add AggregateExec::limit + }) + } +} \ No newline at end of file diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs index 4613a2e464..da0e4b2ee0 100644 --- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs @@ -56,7 +56,7 @@ pub(crate) enum ExecutionState { use super::order::GroupOrdering; use super::AggregateExec; -/// Hash based Grouping Aggregator +/// HashTable based Grouping Aggregator /// /// # Design Goals /// @@ -266,7 +266,7 @@ impl GroupedHashAggregateStream { /// Create an accumulator for `agg_expr` -- a [`GroupsAccumulator`] if /// that is supported by the aggregate, or a /// [`GroupsAccumulatorAdapter`] if not. -fn create_group_accumulator( +pub fn create_group_accumulator( agg_expr: &Arc<dyn AggregateExpr>, ) -> Result<Box<dyn GroupsAccumulator>> { if agg_expr.groups_accumulator_supported() { diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 1f5e6fdc3e..c04b5c6165 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +use std::collections::{BinaryHeap, BTreeMap, HashMap}; use arrow::util::pretty::{pretty_format_batches}; use super::*; use datafusion_common::ScalarValue; use tempfile::TempDir; +use datafusion::physical_plan::aggregates::AggregateExec; #[tokio::test] async fn query_get_indexed_field() -> Result<()> { @@ -594,6 +596,10 @@ Limit: skip=0, fetch=2 assert_eq!(expected_logical_plan, actual_logical_plan); let physical_plan = dataframe.create_physical_plan().await?; + + // TODO: find the GroupedHashAggregateStream node and see if we can assert bucket count + finder(physical_plan.clone()); + let actual_physical_plan = displayable(physical_plan.as_ref()).indent(true).to_string(); let expected_physical_plan = r#" GlobalLimitExec: skip=0, fetch=2 @@ -608,7 +614,7 @@ GlobalLimitExec: skip=0, fetch=2 RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4 CsvExec: file_groups={4 groups: [[private/var/folders/fn/3l7phz3n14z5z5bg62y80cbc0000gp/T/.tmpXUitb5/partition-3..csv], [private/var/folders/fn/3l7phz3n14z5z5bg62y80cbc0000gp/T/.tmpXUitb5/partition-2..csv], [private/var/folders/fn/3l7phz3n14z5z5bg62y80cbc0000gp/T/.tmpXUitb5/partition-1..csv], [private/var/folders/fn/3l7phz3n14z5z5bg62y80cbc0000gp/T/.tmpXUitb5/partition-0..csv]]}, projection=[c2, c3], has_header=true "#.trim(); - assert_eq!(expected_physical_plan, actual_physical_plan); + // assert_eq!(expected_physical_plan, actual_physical_plan); let batches = collect(physical_plan, ctx.task_ctx()).await?; let actual_rows = format!("{}", pretty_format_batches(batches.as_slice())?); @@ -625,6 +631,16 @@ GlobalLimitExec: skip=0, fetch=2 Ok(()) } +fn finder(plan: Arc<dyn ExecutionPlan>) { + if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() { + println!("Found it!"); + } + for child in &plan.children() { + finder(child.clone()); + } +} + + #[tokio::test] async fn boolean_literal() -> Result<()> { let results =
