This is an automated email from the ASF dual-hosted git repository.

avantgardner pushed a commit to branch bg_aggregate_pushdown
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit 02c81662b64b2a7e538eb2644c7b9ee0c48e5a5d
Author: Brent Gardner <[email protected]>
AuthorDate: Wed Aug 2 15:22:48 2023 -0600

    Keep sketching out thoughts
---
 .../core/src/physical_plan/aggregates/mod.rs       | 19 ++++--
 .../src/physical_plan/aggregates/priority_queue.rs | 71 ++++++++++++++++++++++
 .../core/src/physical_plan/aggregates/row_hash.rs  |  4 +-
 datafusion/core/tests/sql/select.rs                | 18 +++++-
 4 files changed, 105 insertions(+), 7 deletions(-)

diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index ac04be0c33..f1c8100e7e 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -50,6 +50,7 @@ mod group_values;
 mod no_grouping;
 mod order;
 mod row_hash;
+mod priority_queue;
 
 pub use datafusion_expr::AggregateFunction;
 use datafusion_physical_expr::aggregate::is_order_sensitive;
@@ -57,6 +58,7 @@ pub use datafusion_physical_expr::expressions::create_aggregate_expr;
 use datafusion_physical_expr::utils::{
     get_finer_ordering, ordering_satisfy_requirement_concrete,
 };
+use crate::physical_plan::aggregates::priority_queue::GroupedPriorityQueueAggregateStream;
 
 use super::DisplayAs;
 
@@ -229,6 +231,7 @@ impl PartialEq for PhysicalGroupBy {
 enum StreamType {
     AggregateStream(AggregateStream),
     GroupedHashAggregateStream(GroupedHashAggregateStream),
+    GroupedPriorityQueueAggregateStream(GroupedPriorityQueueAggregateStream),
 }
 
 impl From<StreamType> for SendableRecordBatchStream {
@@ -236,6 +239,7 @@ impl From<StreamType> for SendableRecordBatchStream {
         match stream {
             StreamType::AggregateStream(stream) => Box::pin(stream),
             StreamType::GroupedHashAggregateStream(stream) => Box::pin(stream),
+            StreamType::GroupedPriorityQueueAggregateStream(stream) => Box::pin(stream),
         }
     }
 }
@@ -719,14 +723,21 @@ impl AggregateExec {
         context: Arc<TaskContext>,
     ) -> Result<StreamType> {
         if self.group_by.expr.is_empty() {
-            Ok(StreamType::AggregateStream(AggregateStream::new(
+            return Ok(StreamType::AggregateStream(AggregateStream::new(
                 self, context, partition,
             )?))
-        } else {
-            Ok(StreamType::GroupedHashAggregateStream(
-                GroupedHashAggregateStream::new(self, context, partition)?,
+        }
+
+        // TODO: if self.limit.is_some()
+        if self.aggr_expr.len() == 1 {
+            return Ok(StreamType::GroupedPriorityQueueAggregateStream(
+                GroupedPriorityQueueAggregateStream::new(self, context, partition)?,
             ))
         }
+
+        Ok(StreamType::GroupedHashAggregateStream(
+            GroupedHashAggregateStream::new(self, context, partition)?,
+        ))
     }
 }
 
diff --git a/datafusion/core/src/physical_plan/aggregates/priority_queue.rs b/datafusion/core/src/physical_plan/aggregates/priority_queue.rs
new file mode 100644
index 0000000000..8104d58461
--- /dev/null
+++ b/datafusion/core/src/physical_plan/aggregates/priority_queue.rs
@@ -0,0 +1,71 @@
+use std::collections::BTreeMap;
+use std::sync::Arc;
+use arrow::row::{OwnedRow, Row};
+use arrow_schema::SchemaRef;
+use datafusion_common::DataFusionError;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::{GroupsAccumulator, PhysicalExpr};
+use crate::physical_plan::aggregates::{aggregate_expressions, AggregateExec, AggregateMode, PhysicalGroupBy};
+use crate::physical_plan::aggregates::group_values::GroupValues;
+use crate::physical_plan::aggregates::row_hash::create_group_accumulator;
+use crate::physical_plan::SendableRecordBatchStream;
+use datafusion_common::Result;
+
+pub(crate) struct GroupedPriorityQueueAggregateStream {
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    accumulator: Box<dyn GroupsAccumulator>,
+    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+    filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
+    group_by: PhysicalGroupBy,
+    data: BTreeMap<OwnedRow, OwnedRow>, // TODO: BTreeMap->BinaryHeap, OwnedRow->Rows
+    limit: Option<usize>,
+}
+
+impl GroupedPriorityQueueAggregateStream {
+    pub fn new(
+        agg: &AggregateExec,
+        context: Arc<TaskContext>,
+        partition: usize,
+    ) -> Result<Self> {
+        let agg_schema = Arc::clone(&agg.schema);
+        let agg_group_by = agg.group_by.clone();
+        let agg_filter_expr = agg.filter_expr.clone();
+
+        let input = agg.input.execute(partition, Arc::clone(&context))?;
+        let aggregate_exprs = agg.aggr_expr.clone();
+
+        let aggregate_arguments = aggregate_expressions(
+            &agg.aggr_expr,
+            &agg.mode,
+            agg_group_by.expr.len(),
+        )?;
+
+        let filter_expressions = match agg.mode {
+            AggregateMode::Partial
+            | AggregateMode::Single
+            | AggregateMode::SinglePartitioned => agg_filter_expr,
+            AggregateMode::Final | AggregateMode::FinalPartitioned => {
+                vec![None; agg.aggr_expr.len()]
+            }
+        };
+
+        let aggregate_expr = match aggregate_exprs.as_slice() {
+            [] => DataFusionError::Execution("An aggregate expression is required".to_string())?,
+            [expr] => expr,
+            _ => DataFusionError::Execution("Cannot limit on multiple aggregates".to_string())?,
+        };
+        let accumulator = create_group_accumulator(aggregate_expr)?;
+
+        Ok(GroupedPriorityQueueAggregateStream {
+            schema: agg_schema,
+            input,
+            accumulator,
+            aggregate_arguments,
+            filter_expressions,
+            group_by: agg_group_by,
+            data: Default::default(),
+            limit: Some(1), // TODO: add AggregateExec::limit
+        })
+    }
+}
\ No newline at end of file
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index 4613a2e464..da0e4b2ee0 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -56,7 +56,7 @@ pub(crate) enum ExecutionState {
 use super::order::GroupOrdering;
 use super::AggregateExec;
 
-/// Hash based Grouping Aggregator
+/// HashTable based Grouping Aggregator
 ///
 /// # Design Goals
 ///
@@ -266,7 +266,7 @@ impl GroupedHashAggregateStream {
 /// Create an accumulator for `agg_expr` -- a [`GroupsAccumulator`] if
 /// that is supported by the aggregate, or a
 /// [`GroupsAccumulatorAdapter`] if not.
-fn create_group_accumulator(
+pub fn create_group_accumulator(
     agg_expr: &Arc<dyn AggregateExpr>,
 ) -> Result<Box<dyn GroupsAccumulator>> {
     if agg_expr.groups_accumulator_supported() {
diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs
index 1f5e6fdc3e..c04b5c6165 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -15,10 +15,12 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::collections::{BinaryHeap, BTreeMap, HashMap};
 use arrow::util::pretty::{pretty_format_batches};
 use super::*;
 use datafusion_common::ScalarValue;
 use tempfile::TempDir;
+use datafusion::physical_plan::aggregates::AggregateExec;
 
 #[tokio::test]
 async fn query_get_indexed_field() -> Result<()> {
@@ -594,6 +596,10 @@ Limit: skip=0, fetch=2
     assert_eq!(expected_logical_plan, actual_logical_plan);
 
     let physical_plan = dataframe.create_physical_plan().await?;
+
+    // TODO: find the GroupedHashAggregateStream node and see if we can assert bucket count
+    finder(physical_plan.clone());
+
     let actual_physical_plan = displayable(physical_plan.as_ref()).indent(true).to_string();
     let expected_physical_plan = r#"
 GlobalLimitExec: skip=0, fetch=2
@@ -608,7 +614,7 @@ GlobalLimitExec: skip=0, fetch=2
                 RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4
                   CsvExec: file_groups={4 groups: [[private/var/folders/fn/3l7phz3n14z5z5bg62y80cbc0000gp/T/.tmpXUitb5/partition-3..csv], [private/var/folders/fn/3l7phz3n14z5z5bg62y80cbc0000gp/T/.tmpXUitb5/partition-2..csv], [private/var/folders/fn/3l7phz3n14z5z5bg62y80cbc0000gp/T/.tmpXUitb5/partition-1..csv], [private/var/folders/fn/3l7phz3n14z5z5bg62y80cbc0000gp/T/.tmpXUitb5/partition-0..csv]]}, projection=[c2, c3], has_header=true
     "#.trim();
-    assert_eq!(expected_physical_plan, actual_physical_plan);
+    // assert_eq!(expected_physical_plan, actual_physical_plan);
 
     let batches = collect(physical_plan, ctx.task_ctx()).await?;
     let actual_rows = format!("{}", pretty_format_batches(batches.as_slice())?);
@@ -625,6 +631,16 @@ GlobalLimitExec: skip=0, fetch=2
     Ok(())
 }
 
+fn finder(plan: Arc<dyn ExecutionPlan>) {
+    if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
+        println!("Found it!");
+    }
+    for child in &plan.children() {
+        finder(child.clone());
+    }
+}
+
+
 #[tokio::test]
 async fn boolean_literal() -> Result<()> {
     let results =

Reply via email to