alamb commented on code in PR #7192:
URL: https://github.com/apache/arrow-datafusion/pull/7192#discussion_r1301660481


##########
Cargo.toml:
##########
@@ -39,6 +39,7 @@ arrow-flight = { version = "45.0.0", features = 
["flight-sql-experimental"] }
 arrow-schema = { version = "45.0.0", default-features = false }
 parquet = { version = "45.0.0", features = ["arrow", "async", "object_store"] }
 sqlparser = { version = "0.36.1", features = ["visitor"] }
+zerocopy = "0.6.1"

Review Comment:
   FWIW, `zerocopy` is already used by DataFusion transitively (see 
https://github.com/apache/arrow-datafusion/issues/7221), so this is not a new 
dependency.



##########
datafusion/core/src/physical_plan/aggregates/mod.rs:
##########
@@ -717,14 +725,38 @@ impl AggregateExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<StreamType> {
+        // no group by at all
         if self.group_by.expr.is_empty() {
-            Ok(StreamType::AggregateStream(AggregateStream::new(
+            return Ok(StreamType::AggregateStream(AggregateStream::new(
                 self, context, partition,
-            )?))
+            )?));
+        }
+
+        // grouping by an expression that has a sort/limit upstream
+        if let Some(limit) = self.limit {
+            return Ok(StreamType::GroupedPriorityQueue(
+                GroupedTopKAggregateStream::new(self, context, partition, 
limit)?,
+            ));
+        }
+
+        // grouping by something else and we need to just materialize all 
results
+        Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
+            self, context, partition,
+        )?))
+    }
+
+    /// Finds the DataType and SortDirection for this Aggregate, if there is 
one
+    pub fn get_minmax_desc(&self) -> Option<(Field, bool)> {
+        let agg_expr = match self.aggr_expr.as_slice() {
+            [expr] => expr,
+            _ => return None,
+        };
+        if let Some(max) = agg_expr.as_any().downcast_ref::<Max>() {
+            Some((max.field().ok()?, true))
+        } else if let Some(min) = agg_expr.as_any().downcast_ref::<Min>() {
+            Some((min.field().ok()?, true))

Review Comment:
   This code always returns `true` for the sort direction, even in the `Min` 
branch. Is that intended?



##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -145,7 +145,7 @@ pub(crate) struct GroupedHashAggregateStream {
     /// accumulator. If present, only those rows for which the filter
     /// evaluate to true should be included in the aggregate results.
     ///
-    /// For example, for an aggregate like `SUM(x FILTER x > 100)`,
+    /// For example, for an aggregate like `SUM(x) FILTER (WHERE x >= 100)`,

Review Comment:
   👍 



##########
datafusion/sqllogictest/test_files/aggregate.slt:
##########
@@ -2291,7 +2291,131 @@ false
 true
 NULL
 
+# TopK aggregation
+statement ok
+CREATE TABLE traces(trace_id varchar, timestamp bigint) AS VALUES
+(NULL, 0),
+('a', NULL),
+('a', 1),
+('b', 0),
+('c', 1),
+('c', 2),
+('b', 3);
+
+statement ok
+set datafusion.optimizer.enable_topk_aggregation = false;
 
+query TT
+explain select trace_id, MAX(timestamp) from traces group by trace_id order by 
MAX(timestamp) desc limit 4;
+----
+logical_plan
+Limit: skip=0, fetch=4
+--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4
+----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]]
+------TableScan: traces projection=[trace_id, timestamp]
+physical_plan
+GlobalLimitExec: skip=0, fetch=4
+--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4
+----SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 DESC]
+------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], 
aggr=[MAX(traces.timestamp)]
+--------CoalesceBatchesExec: target_batch_size=8192
+----------RepartitionExec: partitioning=Hash([trace_id@0], 4), 
input_partitions=4
+------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], 
aggr=[MAX(traces.timestamp)]
+--------------RepartitionExec: partitioning=RoundRobinBatch(4), 
input_partitions=1
+----------------MemoryExec: partitions=1, partition_sizes=[1]
+
+
+query TI
+select trace_id, MAX(timestamp) from traces group by trace_id order by 
MAX(timestamp) desc limit 4;
+----
+b 3
+c 2
+a 1
+NULL 0
+
+query TI
+select trace_id, MIN(timestamp) from traces group by trace_id order by 
MIN(timestamp) asc limit 4;

Review Comment:
   I also recommend a test like
   
   ```suggestion
   select trace_id, MIN(timestamp) from traces group by trace_id order by 
MIN(timestamp) desc limit 4;
   ```
   
   I realize this isn't a particularly useful query, but it should still work, 
and it wasn't clear to me that this case is handled correctly.



##########
datafusion/core/src/physical_optimizer/topk_aggregation.rs:
##########
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! An optimizer rule that detects aggregate operations that could use a 
limited bucket count
+
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::aggregates::AggregateExec;
+use crate::physical_plan::coalesce_batches::CoalesceBatchesExec;
+use crate::physical_plan::filter::FilterExec;
+use crate::physical_plan::repartition::RepartitionExec;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::ExecutionPlan;
+use arrow_schema::DataType;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::Result;
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::PhysicalSortExpr;
+use std::sync::Arc;
+
+/// An optimizer rule that passes a `limit` hint to aggregations if the whole 
result is not needed
+pub struct TopKAggregation {}
+
+impl TopKAggregation {
+    /// Create a new `LimitAggregation`
+    pub fn new() -> Self {
+        Self {}
+    }
+
+    fn transform_agg(
+        aggr: &AggregateExec,
+        order: &PhysicalSortExpr,
+        limit: usize,
+    ) -> Option<Arc<dyn ExecutionPlan>> {
+        // ensure the sort direction matches aggregate function
+        let (field, desc) = aggr.get_minmax_desc()?;
+        if desc != order.options.descending {
+            return None;
+        }
+        let group_key = match aggr.group_expr().expr() {
+            [expr] => expr, // only one group key
+            _ => return None,
+        };
+        match group_key.0.data_type(&aggr.input_schema).ok() {
+            Some(DataType::Utf8) => {} // only String keys for now
+            _ => return None,
+        }
+        if aggr
+            .filter_expr
+            .iter()
+            .fold(false, |acc, cur| acc | cur.is_some())
+        {
+            return None;
+        }
+
+        // ensure the sort is on the same field as the aggregate output
+        let col = order.expr.as_any().downcast_ref::<Column>()?;
+        if col.name() != field.name() {
+            return None;
+        }
+
+        // We found what we want: clone, copy the limit down, and return 
modified node
+        let mut new_aggr = AggregateExec::try_new(
+            aggr.mode,
+            aggr.group_by.clone(),
+            aggr.aggr_expr.clone(),
+            aggr.filter_expr.clone(),
+            aggr.order_by_expr.clone(),
+            aggr.input.clone(),
+            aggr.input_schema.clone(),
+        )
+        .expect("Unable to copy Aggregate!");
+        new_aggr.limit = Some(limit);
+        Some(Arc::new(new_aggr))
+    }
+
+    fn transform_sort(plan: Arc<dyn ExecutionPlan>) -> Option<Arc<dyn 
ExecutionPlan>> {
+        let sort = plan.as_any().downcast_ref::<SortExec>()?;
+
+        // TODO: support sorting on multiple fields
+        let children = sort.children();
+        let child = match children.as_slice() {
+            [child] => child.clone(),
+            _ => return None,
+        };
+        let order = sort.output_ordering()?;
+        let order = match order {
+            [order] => order,
+            _ => return None,
+        };
+        let limit = sort.fetch()?;
+
+        let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| {
+            plan.as_any()
+                .downcast_ref::<CoalesceBatchesExec>()
+                .is_some()
+                || plan.as_any().downcast_ref::<RepartitionExec>().is_some()
+                || plan.as_any().downcast_ref::<FilterExec>().is_some()
+            // TODO: whitelist joins that don't increase row count?
+        };
+
+        let mut cardinality_preserved = true;
+        let mut closure = |plan: Arc<dyn ExecutionPlan>| {
+            if !cardinality_preserved {
+                return Ok(Transformed::No(plan));
+            }
+            if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
+                // either we run into an Aggregate and transform it
+                match Self::transform_agg(aggr, order, limit) {
+                    None => cardinality_preserved = false,
+                    Some(plan) => return Ok(Transformed::Yes(plan)),
+                }
+            } else {
+                // or we continue down whitelisted nodes of other types
+                if !is_cardinality_preserving(plan.clone()) {
+                    cardinality_preserved = false;
+                }
+            }
+            Ok(Transformed::No(plan))
+        };
+        let child = transform_down_mut(child, &mut closure).ok()?;
+        let sort = SortExec::new(sort.expr().to_vec(), child)
+            .with_fetch(sort.fetch())
+            .with_preserve_partitioning(sort.preserve_partitioning());
+        Some(Arc::new(sort))
+    }
+}
+
+fn transform_down_mut<F>(
+    me: Arc<dyn ExecutionPlan>,
+    op: &mut F,
+) -> Result<Arc<dyn ExecutionPlan>>
+where
+    F: FnMut(Arc<dyn ExecutionPlan>) -> Result<Transformed<Arc<dyn 
ExecutionPlan>>>,
+{
+    let after_op = op(me)?.into();
+    after_op.map_children(|node| transform_down_mut(node, op))
+}
+
+impl Default for TopKAggregation {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl PhysicalOptimizerRule for TopKAggregation {
+    fn optimize(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+        config: &ConfigOptions,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let plan = if config.optimizer.enable_topk_aggregation {
+            plan.transform_down(&|plan| {
+                Ok(
+                    if let Some(plan) = 
TopKAggregation::transform_sort(plan.clone()) {
+                        Transformed::Yes(plan)
+                    } else {
+                        Transformed::No(plan)
+                    },
+                )
+            })?
+        } else {
+            plan
+        };
+        Ok(plan)
+    }
+
+    fn name(&self) -> &str {
+        "LimitAggregation"
+    }
+
+    fn schema_check(&self) -> bool {
+        true
+    }
+}
+
+// TODO: tests

Review Comment:
   👍 



##########
datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,829 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to 
a fixed number
+
+use crate::physical_plan::aggregates::{
+    aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
+    PhysicalGroupBy,
+};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use ahash::RandomState;
+use arrow::util::pretty::print_batches;
+use arrow_array::cast::AsArray;
+use arrow_array::downcast_primitive;
+use arrow_array::{
+    Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, 
RecordBatch,
+    StringArray,
+};
+use arrow_schema::{DataType, SchemaRef};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::PhysicalExpr;
+use futures::stream::{Stream, StreamExt};
+use hashbrown::raw::RawTable;
+use log::{trace, Level};
+use std::cmp::Ordering;
+use std::collections::BTreeSet;
+use std::hash::Hash;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+pub struct GroupedTopKAggregateStream {
+    partition: usize,
+    row_count: usize,
+    started: bool,
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+    group_by: PhysicalGroupBy,
+    aggregator: Box<dyn LimitedAggregator>,
+}
+
+impl GroupedTopKAggregateStream {
+    pub fn new(
+        agg: &AggregateExec,
+        context: Arc<TaskContext>,
+        partition: usize,
+        limit: usize,
+    ) -> Result<Self> {
+        let agg_schema = Arc::clone(&agg.schema);
+        let group_by = agg.group_by.clone();
+
+        let input = agg.input.execute(partition, Arc::clone(&context))?;
+
+        let aggregate_arguments =
+            aggregate_expressions(&agg.aggr_expr, &agg.mode, 
group_by.expr.len())?;
+
+        let (val_field, descending) = agg
+            .get_minmax_desc()
+            .ok_or_else(|| DataFusionError::Execution("Min/max 
required".to_string()))?;
+
+        let vt = val_field.data_type().clone();
+        let ag = new_group_values(limit, descending, vt)?;
+
+        Ok(GroupedTopKAggregateStream {
+            partition,
+            started: false,
+            row_count: 0,
+            schema: agg_schema,
+            input,
+            aggregate_arguments,
+            group_by,
+            aggregator: ag,
+        })
+    }
+}
+
+pub fn new_group_values(
+    limit: usize,
+    desc: bool,
+    vt: DataType,
+) -> Result<Box<dyn LimitedAggregator>> {
+    macro_rules! downcast_helper {
+        ($vt:ty, $d:ident) => {
+            return Ok(Box::new(PrimitiveAggregator::<$vt>::new(
+                limit,
+                limit * 10,
+                desc,
+            )))
+        };
+    }
+
+    downcast_primitive! {
+        vt => (downcast_helper, vt),
+        _ => {}
+    }
+
+    Err(DataFusionError::Execution(format!(
+        "Can't group type: {vt:?}"
+    )))
+}
+
+impl RecordBatchStream for GroupedTopKAggregateStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+pub trait LimitedAggregator: Send {
+    fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()>;
+    fn emit(&mut self) -> Result<Vec<ArrayRef>>;
+    fn is_empty(&self) -> bool;
+}
+
+pub trait ValueType: ArrowNativeTypeOp + Clone {}
+
+impl<T> ValueType for T where T: ArrowNativeTypeOp + Clone {}
+
+pub trait KeyType: Clone + Eq + Hash {}
+
+impl<T> KeyType for T where T: Clone + Eq + Hash {}
+
+struct PrimitiveAggregator<VAL: ArrowPrimitiveType>
+where
+    <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+    priority_map: PriorityMap<Option<String>, VAL::Native>,
+}
+
+impl<VAL: ArrowPrimitiveType> PrimitiveAggregator<VAL>
+where
+    <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+    pub fn new(limit: usize, capacity: usize, descending: bool) -> Self {
+        Self {
+            priority_map: PriorityMap::new(limit, capacity, descending),
+        }
+    }
+}
+
+unsafe impl<VAL: ArrowPrimitiveType> Send for PrimitiveAggregator<VAL> where
+    <VAL as ArrowPrimitiveType>::Native: Clone
+{
+}
+
+impl<VAL: ArrowPrimitiveType> LimitedAggregator for PrimitiveAggregator<VAL>
+where
+    <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+    fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
+        let ids = ids.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
+            DataFusionError::Execution("Expected StringArray".to_string())
+        })?;
+        let vals = vals.as_primitive::<VAL>();
+        let null_count = vals.null_count();
+        for row_idx in 0..ids.len() {
+            if null_count > 0 && vals.is_null(row_idx) {
+                continue;
+            }
+            let val = vals.value(row_idx);
+            let id = if ids.is_null(row_idx) {
+                None
+            } else {
+                // Check goes here, because it is generalizable between 
str/String and Row/OwnedRow
+                let id = ids.value(row_idx);
+                if self.priority_map.is_full() {
+                    if self.priority_map.desc {
+                        if let Some(worst) = self.priority_map.min_val() {
+                            if val < *worst {
+                                continue;
+                            }
+                        }
+                    } else if let Some(worst) = self.priority_map.max_val() {
+                        if val > *worst {
+                            continue;
+                        }
+                    }
+                }
+                Some(id.to_string())
+            };
+
+            self.priority_map.insert(id, val)?;
+        }
+        Ok(())
+    }
+
+    fn emit(&mut self) -> Result<Vec<ArrayRef>> {
+        let (keys, vals): (Vec<_>, Vec<_>) =
+            self.priority_map.drain().into_iter().unzip();
+        let keys = Arc::new(StringArray::from(keys));
+        let vals = Arc::new(PrimitiveArray::<VAL>::from_iter_values(vals));
+        Ok(vec![keys, vals])
+    }
+
+    fn is_empty(&self) -> bool {
+        self.priority_map.is_empty()
+    }
+}
+
+impl Stream for GroupedTopKAggregateStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        while let Poll::Ready(res) = self.input.poll_next_unpin(cx) {
+            match res {
+                // got a batch, convert to rows and append to our TreeMap
+                Some(Ok(batch)) => {
+                    self.started = true;
+                    trace!(
+                        "partition {} has {} rows and got batch with {} rows",
+                        self.partition,
+                        self.row_count,
+                        batch.num_rows()
+                    );
+                    if log::log_enabled!(Level::Trace) && batch.num_rows() < 
20 {
+                        print_batches(&[batch.clone()])?;
+                    }
+                    self.row_count += batch.num_rows();
+                    let batches = &[batch];
+                    let group_by_values =
+                        evaluate_group_by(&self.group_by, 
batches.first().unwrap())?;
+                    let group_by_values =
+                        group_by_values.into_iter().last().expect("values");
+                    let group_by_values =
+                        group_by_values.into_iter().last().expect("values");
+                    let input_values = evaluate_many(
+                        &self.aggregate_arguments,
+                        batches.first().unwrap(),
+                    )?;
+                    let input_values = match input_values.as_slice() {
+                        [] => {
+                            Err(DataFusionError::Execution("vals 
required".to_string()))?

Review Comment:
   ```suggestion
                               Err(DataFusionError::Internal("vals 
required".to_string()))?
   ```
   
   The same probably applies to the other similar checks below (a missing value here would indicate a bug in DataFusion rather than a user error).



##########
datafusion/sqllogictest/test_files/aggregate.slt:
##########
@@ -2291,7 +2291,131 @@ false
 true
 NULL
 
+# TopK aggregation
+statement ok
+CREATE TABLE traces(trace_id varchar, timestamp bigint) AS VALUES
+(NULL, 0),
+('a', NULL),
+('a', 1),
+('b', 0),
+('c', 1),
+('c', 2),
+('b', 3);
+
+statement ok
+set datafusion.optimizer.enable_topk_aggregation = false;
 
+query TT
+explain select trace_id, MAX(timestamp) from traces group by trace_id order by 
MAX(timestamp) desc limit 4;
+----
+logical_plan
+Limit: skip=0, fetch=4
+--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4
+----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]]
+------TableScan: traces projection=[trace_id, timestamp]
+physical_plan
+GlobalLimitExec: skip=0, fetch=4
+--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4
+----SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 DESC]
+------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], 
aggr=[MAX(traces.timestamp)]
+--------CoalesceBatchesExec: target_batch_size=8192
+----------RepartitionExec: partitioning=Hash([trace_id@0], 4), 
input_partitions=4
+------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], 
aggr=[MAX(traces.timestamp)]
+--------------RepartitionExec: partitioning=RoundRobinBatch(4), 
input_partitions=1
+----------------MemoryExec: partitions=1, partition_sizes=[1]
+
+
+query TI
+select trace_id, MAX(timestamp) from traces group by trace_id order by 
MAX(timestamp) desc limit 4;

Review Comment:
   I recommend using a limit that doesn't retain all of the values. There are 
only 4 distinct groups in the input (`NULL`, `a`, `b`, `c`), so there will be only 4 in the output 
regardless of the limit. Maybe you could either add another value to the input or reduce the limit to 
`3` so that the limit clearly takes effect.



##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -145,7 +145,7 @@ pub(crate) struct GroupedHashAggregateStream {
     /// accumulator. If present, only those rows for which the filter
     /// evaluate to true should be included in the aggregate results.
     ///
-    /// For example, for an aggregate like `SUM(x FILTER x > 100)`,
+    /// For example, for an aggregate like `SUM(x) FILTER (WHERE x >= 100)`,

Review Comment:
   👍 



##########
datafusion/core/tests/sql/select.rs:
##########
@@ -572,6 +574,79 @@ async fn parallel_query_with_filter() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn parallel_query_with_limit() -> Result<()> {
+    let tmp_dir = TempDir::new()?;
+    let partition_count = 4;
+    let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?;
+
+    let dataframe = ctx
+        .sql("SELECT c3, max(c2) as max FROM test group by c3 order by max 
desc limit 2")
+        .await?;
+
+    let actual_logical_plan = format!("{:?}", dataframe.logical_plan());
+    let expected_logical_plan = r#"
+Limit: skip=0, fetch=2
+  Sort: max DESC NULLS FIRST
+    Projection: test.c3, MAX(test.c2) AS max
+      Aggregate: groupBy=[[test.c3]], aggr=[[MAX(test.c2)]]
+        TableScan: test
+    "#
+    .trim();
+    assert_eq!(expected_logical_plan, actual_logical_plan);
+
+    let physical_plan = dataframe.create_physical_plan().await?;
+
+    // TODO: find the GroupedHashAggregateStream node and see if we can assert 
bucket count
+    finder(physical_plan.clone());
+
+    let actual_phys_plan = 
displayable(physical_plan.as_ref()).indent(true).to_string();
+    let mut expected_physical_plan = r#"
+GlobalLimitExec: skip=0, fetch=2
+  SortPreservingMergeExec: [max@1 DESC], fetch=2
+    SortExec: fetch=2, expr=[max@1 DESC]
+      ProjectionExec: expr=[c3@0 as c3, MAX(test.c2)@1 as max]
+        AggregateExec: mode=FinalPartitioned, gby=[c3@0 as c3], 
aggr=[MAX(test.c2)]

Review Comment:
   I am surprised this AggregateExec doesn't have a `lim=` / `fetch` field



##########
datafusion/core/tests/sql/select.rs:
##########
@@ -572,6 +574,79 @@ async fn parallel_query_with_filter() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn parallel_query_with_limit() -> Result<()> {
+    let tmp_dir = TempDir::new()?;
+    let partition_count = 4;
+    let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?;
+
+    let dataframe = ctx
+        .sql("SELECT c3, max(c2) as max FROM test group by c3 order by max 
desc limit 2")
+        .await?;
+
+    let actual_logical_plan = format!("{:?}", dataframe.logical_plan());
+    let expected_logical_plan = r#"
+Limit: skip=0, fetch=2
+  Sort: max DESC NULLS FIRST
+    Projection: test.c3, MAX(test.c2) AS max
+      Aggregate: groupBy=[[test.c3]], aggr=[[MAX(test.c2)]]
+        TableScan: test
+    "#
+    .trim();
+    assert_eq!(expected_logical_plan, actual_logical_plan);
+
+    let physical_plan = dataframe.create_physical_plan().await?;
+
+    // TODO: find the GroupedHashAggregateStream node and see if we can assert 
bucket count

Review Comment:
   Is this still a TODO?



##########
datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,829 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to 
a fixed number

Review Comment:
   
   ```suggestion
   //! A memory-conscious aggregation implementation that limits group buckets 
to a fixed number
   //! [`GroupedTopKAggregateStream`]
   ```



##########
datafusion/core/src/physical_plan/aggregates/mod.rs:
##########
@@ -265,6 +270,8 @@ pub struct AggregateExec {
     pub(crate) filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
     /// (ORDER BY clause) expression for each aggregate expression
     pub(crate) order_by_expr: Vec<Option<LexOrdering>>,
+    /// Set if the output of this aggregation is truncated by a upstream 
sort/limit clause
+    pub(crate) limit: Option<usize>,

Review Comment:
   I wonder if we can add comments here or give it a different name (like 
`group_num_limit`?) to explain what is going on. At first glance this might look 
like a `fetch` on a sort or limit, but I think it is quite different.
   
   Specifically, I think it means "output only the top `limit` groups based on the 
value of the aggregates". (I still don't fully understand how MIN/MAX and ORDER 
BY ASC/DESC are handled together. I would have expected the optimizer to have 
to specify that here as well, but perhaps it only sets `limit` for the right 
combinations.)



##########
datafusion/core/src/physical_optimizer/topk_aggregation.rs:
##########
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! An optimizer rule that detects aggregate operations that could use a 
limited bucket count
+
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::aggregates::AggregateExec;
+use crate::physical_plan::coalesce_batches::CoalesceBatchesExec;
+use crate::physical_plan::filter::FilterExec;
+use crate::physical_plan::repartition::RepartitionExec;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::ExecutionPlan;
+use arrow_schema::DataType;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::Result;
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::PhysicalSortExpr;
+use std::sync::Arc;
+
+/// An optimizer rule that passes a `limit` hint to aggregations if the whole 
result is not needed
+pub struct TopKAggregation {}
+
+impl TopKAggregation {
+    /// Create a new `LimitAggregation`
+    pub fn new() -> Self {
+        Self {}
+    }
+
+    fn transform_agg(
+        aggr: &AggregateExec,
+        order: &PhysicalSortExpr,
+        limit: usize,
+    ) -> Option<Arc<dyn ExecutionPlan>> {
+        // ensure the sort direction matches aggregate function
+        let (field, desc) = aggr.get_minmax_desc()?;
+        if desc != order.options.descending {
+            return None;
+        }
+        let group_key = match aggr.group_expr().expr() {
+            [expr] => expr, // only one group key
+            _ => return None,
+        };
+        match group_key.0.data_type(&aggr.input_schema).ok() {
+            Some(DataType::Utf8) => {} // only String keys for now
+            _ => return None,
+        }
+        if aggr
+            .filter_expr
+            .iter()
+            .fold(false, |acc, cur| acc | cur.is_some())
+        {
+            return None;
+        }
+
+        // ensure the sort is on the same field as the aggregate output
+        let col = order.expr.as_any().downcast_ref::<Column>()?;
+        if col.name() != field.name() {
+            return None;
+        }
+
+        // We found what we want: clone, copy the limit down, and return 
modified node
+        let mut new_aggr = AggregateExec::try_new(
+            aggr.mode,
+            aggr.group_by.clone(),
+            aggr.aggr_expr.clone(),
+            aggr.filter_expr.clone(),
+            aggr.order_by_expr.clone(),
+            aggr.input.clone(),
+            aggr.input_schema.clone(),
+        )
+        .expect("Unable to copy Aggregate!");
+        new_aggr.limit = Some(limit);
+        Some(Arc::new(new_aggr))
+    }
+
+    fn transform_sort(plan: Arc<dyn ExecutionPlan>) -> Option<Arc<dyn 
ExecutionPlan>> {
+        let sort = plan.as_any().downcast_ref::<SortExec>()?;
+
+        // TODO: support sorting on multiple fields
+        let children = sort.children();
+        let child = match children.as_slice() {
+            [child] => child.clone(),
+            _ => return None,
+        };
+        let order = sort.output_ordering()?;
+        let order = match order {
+            [order] => order,
+            _ => return None,
+        };
+        let limit = sort.fetch()?;
+
+        let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| {
+            plan.as_any()
+                .downcast_ref::<CoalesceBatchesExec>()
+                .is_some()
+                || plan.as_any().downcast_ref::<RepartitionExec>().is_some()
+                || plan.as_any().downcast_ref::<FilterExec>().is_some()
+            // TODO: whitelist joins that don't increase row count?
+        };

Review Comment:
   I think it would be better to add a new method to the `ExecutionPlan` trait 
for this property rather than check for specific types. That way, extensions and 
future plans can work correctly with this optimizer.
   
   I am also somewhat confused that `FilterExec` would be labeled as 
"cardinality preserving", as it can certainly reduce the cardinality of its 
input 🤔 Is the idea to check for "non-increasing" cardinality instead?



##########
datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,829 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to 
a fixed number
+
+use crate::physical_plan::aggregates::{
+    aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
+    PhysicalGroupBy,
+};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use ahash::RandomState;
+use arrow::util::pretty::print_batches;
+use arrow_array::cast::AsArray;
+use arrow_array::downcast_primitive;
+use arrow_array::{
+    Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, 
RecordBatch,
+    StringArray,
+};
+use arrow_schema::{DataType, SchemaRef};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::PhysicalExpr;
+use futures::stream::{Stream, StreamExt};
+use hashbrown::raw::RawTable;
+use log::{trace, Level};
+use std::cmp::Ordering;
+use std::collections::BTreeSet;
+use std::hash::Hash;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+pub struct GroupedTopKAggregateStream {
+    partition: usize,
+    row_count: usize,
+    started: bool,
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+    group_by: PhysicalGroupBy,
+    aggregator: Box<dyn LimitedAggregator>,
+}
+
+impl GroupedTopKAggregateStream {
+    pub fn new(
+        agg: &AggregateExec,
+        context: Arc<TaskContext>,
+        partition: usize,
+        limit: usize,
+    ) -> Result<Self> {
+        let agg_schema = Arc::clone(&agg.schema);
+        let group_by = agg.group_by.clone();
+
+        let input = agg.input.execute(partition, Arc::clone(&context))?;
+
+        let aggregate_arguments =
+            aggregate_expressions(&agg.aggr_expr, &agg.mode, 
group_by.expr.len())?;
+
+        let (val_field, descending) = agg
+            .get_minmax_desc()
+            .ok_or_else(|| DataFusionError::Execution("Min/max 
required".to_string()))?;

Review Comment:
   If this code gets run somehow, it would be due to a bug in one of the 
optimizer passes, so it should be reported as an internal error:
   
   ```suggestion
               .ok_or_else(|| DataFusionError::Internal("Min/max required".to_string()))?;
   ```



##########
datafusion/core/src/physical_optimizer/topk_aggregation.rs:
##########
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! An optimizer rule that detects aggregate operations that could use a 
limited bucket count
+
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::aggregates::AggregateExec;
+use crate::physical_plan::coalesce_batches::CoalesceBatchesExec;
+use crate::physical_plan::filter::FilterExec;
+use crate::physical_plan::repartition::RepartitionExec;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::ExecutionPlan;
+use arrow_schema::DataType;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::Result;
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::PhysicalSortExpr;
+use std::sync::Arc;
+
+/// An optimizer rule that passes a `limit` hint to aggregations if the whole 
result is not needed
+pub struct TopKAggregation {}
+
+impl TopKAggregation {
+    /// Create a new `LimitAggregation`
+    pub fn new() -> Self {
+        Self {}
+    }
+
+    fn transform_agg(
+        aggr: &AggregateExec,
+        order: &PhysicalSortExpr,
+        limit: usize,
+    ) -> Option<Arc<dyn ExecutionPlan>> {
+        // ensure the sort direction matches aggregate function
+        let (field, desc) = aggr.get_minmax_desc()?;
+        if desc != order.options.descending {
+            return None;
+        }
+        let group_key = match aggr.group_expr().expr() {
+            [expr] => expr, // only one group key
+            _ => return None,
+        };
+        match group_key.0.data_type(&aggr.input_schema).ok() {
+            Some(DataType::Utf8) => {} // only String keys for now
+            _ => return None,
+        }
+        if aggr
+            .filter_expr
+            .iter()
+            .fold(false, |acc, cur| acc | cur.is_some())
+        {
+            return None;
+        }
+
+        // ensure the sort is on the same field as the aggregate output
+        let col = order.expr.as_any().downcast_ref::<Column>()?;
+        if col.name() != field.name() {
+            return None;
+        }
+
+        // We found what we want: clone, copy the limit down, and return 
modified node
+        let mut new_aggr = AggregateExec::try_new(
+            aggr.mode,
+            aggr.group_by.clone(),
+            aggr.aggr_expr.clone(),
+            aggr.filter_expr.clone(),
+            aggr.order_by_expr.clone(),
+            aggr.input.clone(),
+            aggr.input_schema.clone(),
+        )
+        .expect("Unable to copy Aggregate!");
+        new_aggr.limit = Some(limit);
+        Some(Arc::new(new_aggr))
+    }
+
+    fn transform_sort(plan: Arc<dyn ExecutionPlan>) -> Option<Arc<dyn 
ExecutionPlan>> {
+        let sort = plan.as_any().downcast_ref::<SortExec>()?;
+
+        // TODO: support sorting on multiple fields
+        let children = sort.children();
+        let child = match children.as_slice() {
+            [child] => child.clone(),
+            _ => return None,
+        };
+        let order = sort.output_ordering()?;
+        let order = match order {
+            [order] => order,
+            _ => return None,
+        };
+        let limit = sort.fetch()?;
+
+        let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| {
+            plan.as_any()
+                .downcast_ref::<CoalesceBatchesExec>()
+                .is_some()
+                || plan.as_any().downcast_ref::<RepartitionExec>().is_some()
+                || plan.as_any().downcast_ref::<FilterExec>().is_some()
+            // TODO: whitelist joins that don't increase row count?
+        };

Review Comment:
   I think it would be better to add a new method to the `ExecutionPlan` trait 
for this property rather than checking for specific types. That way, extensions 
and future plans can work correctly with this optimizer.
   
   I am also somewhat confused that `FilterExec` would be labeled as 
"cardinality preserving", as it can certainly reduce the cardinality of its 
input 🤔  Is the idea to check for "non-increasing" cardinality, maybe?



##########
datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,829 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to 
a fixed number
+
+use crate::physical_plan::aggregates::{
+    aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
+    PhysicalGroupBy,
+};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use ahash::RandomState;
+use arrow::util::pretty::print_batches;
+use arrow_array::cast::AsArray;
+use arrow_array::downcast_primitive;
+use arrow_array::{
+    Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, 
RecordBatch,
+    StringArray,
+};
+use arrow_schema::{DataType, SchemaRef};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::PhysicalExpr;
+use futures::stream::{Stream, StreamExt};
+use hashbrown::raw::RawTable;
+use log::{trace, Level};
+use std::cmp::Ordering;
+use std::collections::BTreeSet;
+use std::hash::Hash;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+pub struct GroupedTopKAggregateStream {
+    partition: usize,
+    row_count: usize,
+    started: bool,
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+    group_by: PhysicalGroupBy,
+    aggregator: Box<dyn LimitedAggregator>,
+}
+
+impl GroupedTopKAggregateStream {
+    pub fn new(
+        agg: &AggregateExec,
+        context: Arc<TaskContext>,
+        partition: usize,
+        limit: usize,
+    ) -> Result<Self> {
+        let agg_schema = Arc::clone(&agg.schema);
+        let group_by = agg.group_by.clone();
+
+        let input = agg.input.execute(partition, Arc::clone(&context))?;
+
+        let aggregate_arguments =
+            aggregate_expressions(&agg.aggr_expr, &agg.mode, 
group_by.expr.len())?;
+
+        let (val_field, descending) = agg
+            .get_minmax_desc()
+            .ok_or_else(|| DataFusionError::Execution("Min/max 
required".to_string()))?;
+
+        let vt = val_field.data_type().clone();
+        let ag = new_group_values(limit, descending, vt)?;
+
+        Ok(GroupedTopKAggregateStream {
+            partition,
+            started: false,
+            row_count: 0,
+            schema: agg_schema,
+            input,
+            aggregate_arguments,
+            group_by,
+            aggregator: ag,
+        })
+    }
+}
+
+pub fn new_group_values(
+    limit: usize,
+    desc: bool,
+    vt: DataType,
+) -> Result<Box<dyn LimitedAggregator>> {
+    macro_rules! downcast_helper {
+        ($vt:ty, $d:ident) => {
+            return Ok(Box::new(PrimitiveAggregator::<$vt>::new(
+                limit,
+                limit * 10,
+                desc,
+            )))
+        };
+    }
+
+    downcast_primitive! {
+        vt => (downcast_helper, vt),
+        _ => {}
+    }
+
+    Err(DataFusionError::Execution(format!(
+        "Can't group type: {vt:?}"
+    )))
+}
+
+impl RecordBatchStream for GroupedTopKAggregateStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+pub trait LimitedAggregator: Send {
+    fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()>;
+    fn emit(&mut self) -> Result<Vec<ArrayRef>>;
+    fn is_empty(&self) -> bool;
+}
+
+pub trait ValueType: ArrowNativeTypeOp + Clone {}
+
+impl<T> ValueType for T where T: ArrowNativeTypeOp + Clone {}
+
+pub trait KeyType: Clone + Eq + Hash {}
+
+impl<T> KeyType for T where T: Clone + Eq + Hash {}
+
+struct PrimitiveAggregator<VAL: ArrowPrimitiveType>
+where
+    <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+    priority_map: PriorityMap<Option<String>, VAL::Native>,
+}
+
+impl<VAL: ArrowPrimitiveType> PrimitiveAggregator<VAL>
+where
+    <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+    pub fn new(limit: usize, capacity: usize, descending: bool) -> Self {
+        Self {
+            priority_map: PriorityMap::new(limit, capacity, descending),
+        }
+    }
+}
+
+unsafe impl<VAL: ArrowPrimitiveType> Send for PrimitiveAggregator<VAL> where
+    <VAL as ArrowPrimitiveType>::Native: Clone
+{
+}
+
+impl<VAL: ArrowPrimitiveType> LimitedAggregator for PrimitiveAggregator<VAL>
+where
+    <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+    fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
+        let ids = ids.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
+            DataFusionError::Execution("Expected StringArray".to_string())
+        })?;
+        let vals = vals.as_primitive::<VAL>();
+        let null_count = vals.null_count();
+        for row_idx in 0..ids.len() {
+            if null_count > 0 && vals.is_null(row_idx) {
+                continue;
+            }
+            let val = vals.value(row_idx);
+            let id = if ids.is_null(row_idx) {
+                None
+            } else {
+                // Check goes here, because it is generalizable between 
str/String and Row/OwnedRow
+                let id = ids.value(row_idx);
+                if self.priority_map.is_full() {
+                    if self.priority_map.desc {
+                        if let Some(worst) = self.priority_map.min_val() {
+                            if val < *worst {
+                                continue;
+                            }
+                        }
+                    } else if let Some(worst) = self.priority_map.max_val() {
+                        if val > *worst {
+                            continue;
+                        }
+                    }
+                }
+                Some(id.to_string())
+            };
+
+            self.priority_map.insert(id, val)?;
+        }
+        Ok(())
+    }
+
+    fn emit(&mut self) -> Result<Vec<ArrayRef>> {
+        let (keys, vals): (Vec<_>, Vec<_>) =
+            self.priority_map.drain().into_iter().unzip();
+        let keys = Arc::new(StringArray::from(keys));
+        let vals = Arc::new(PrimitiveArray::<VAL>::from_iter_values(vals));
+        Ok(vec![keys, vals])
+    }
+
+    fn is_empty(&self) -> bool {
+        self.priority_map.is_empty()
+    }
+}
+
+impl Stream for GroupedTopKAggregateStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        while let Poll::Ready(res) = self.input.poll_next_unpin(cx) {
+            match res {
+                // got a batch, convert to rows and append to our TreeMap
+                Some(Ok(batch)) => {
+                    self.started = true;
+                    trace!(
+                        "partition {} has {} rows and got batch with {} rows",
+                        self.partition,
+                        self.row_count,
+                        batch.num_rows()
+                    );
+                    if log::log_enabled!(Level::Trace) && batch.num_rows() < 
20 {
+                        print_batches(&[batch.clone()])?;
+                    }
+                    self.row_count += batch.num_rows();
+                    let batches = &[batch];
+                    let group_by_values =
+                        evaluate_group_by(&self.group_by, 
batches.first().unwrap())?;
+                    let group_by_values =

Review Comment:
   Maybe it would be clearer if you asserted that these slices were length 1 
and then used `group_by_values[0][0]` or something similar.



##########
datafusion/core/src/physical_optimizer/topk_aggregation.rs:
##########
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! An optimizer rule that detects aggregate operations that could use a 
limited bucket count
+
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::aggregates::AggregateExec;
+use crate::physical_plan::coalesce_batches::CoalesceBatchesExec;
+use crate::physical_plan::filter::FilterExec;
+use crate::physical_plan::repartition::RepartitionExec;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::ExecutionPlan;
+use arrow_schema::DataType;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::Result;
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::PhysicalSortExpr;
+use std::sync::Arc;
+
+/// An optimizer rule that passes a `limit` hint to aggregations if the whole 
result is not needed
+pub struct TopKAggregation {}
+
+impl TopKAggregation {
+    /// Create a new `LimitAggregation`
+    pub fn new() -> Self {
+        Self {}
+    }
+
+    fn transform_agg(
+        aggr: &AggregateExec,
+        order: &PhysicalSortExpr,
+        limit: usize,
+    ) -> Option<Arc<dyn ExecutionPlan>> {
+        // ensure the sort direction matches aggregate function
+        let (field, desc) = aggr.get_minmax_desc()?;
+        if desc != order.options.descending {
+            return None;
+        }
+        let group_key = match aggr.group_expr().expr() {
+            [expr] => expr, // only one group key
+            _ => return None,
+        };
+        match group_key.0.data_type(&aggr.input_schema).ok() {
+            Some(DataType::Utf8) => {} // only String keys for now
+            _ => return None,
+        }
+        if aggr
+            .filter_expr
+            .iter()
+            .fold(false, |acc, cur| acc | cur.is_some())
+        {
+            return None;
+        }
+
+        // ensure the sort is on the same field as the aggregate output
+        let col = order.expr.as_any().downcast_ref::<Column>()?;
+        if col.name() != field.name() {
+            return None;
+        }
+
+        // We found what we want: clone, copy the limit down, and return 
modified node
+        let mut new_aggr = AggregateExec::try_new(
+            aggr.mode,
+            aggr.group_by.clone(),
+            aggr.aggr_expr.clone(),
+            aggr.filter_expr.clone(),
+            aggr.order_by_expr.clone(),
+            aggr.input.clone(),
+            aggr.input_schema.clone(),
+        )
+        .expect("Unable to copy Aggregate!");
+        new_aggr.limit = Some(limit);
+        Some(Arc::new(new_aggr))
+    }
+
+    fn transform_sort(plan: Arc<dyn ExecutionPlan>) -> Option<Arc<dyn 
ExecutionPlan>> {
+        let sort = plan.as_any().downcast_ref::<SortExec>()?;
+
+        // TODO: support sorting on multiple fields
+        let children = sort.children();
+        let child = match children.as_slice() {
+            [child] => child.clone(),
+            _ => return None,
+        };
+        let order = sort.output_ordering()?;
+        let order = match order {
+            [order] => order,
+            _ => return None,
+        };
+        let limit = sort.fetch()?;
+
+        let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| {
+            plan.as_any()
+                .downcast_ref::<CoalesceBatchesExec>()
+                .is_some()
+                || plan.as_any().downcast_ref::<RepartitionExec>().is_some()
+                || plan.as_any().downcast_ref::<FilterExec>().is_some()
+            // TODO: whitelist joins that don't increase row count?
+        };
+
+        let mut cardinality_preserved = true;
+        let mut closure = |plan: Arc<dyn ExecutionPlan>| {
+            if !cardinality_preserved {
+                return Ok(Transformed::No(plan));
+            }
+            if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
+                // either we run into an Aggregate and transform it
+                match Self::transform_agg(aggr, order, limit) {
+                    None => cardinality_preserved = false,
+                    Some(plan) => return Ok(Transformed::Yes(plan)),
+                }
+            } else {
+                // or we continue down whitelisted nodes of other types
+                if !is_cardinality_preserving(plan.clone()) {
+                    cardinality_preserved = false;
+                }
+            }
+            Ok(Transformed::No(plan))
+        };
+        let child = transform_down_mut(child, &mut closure).ok()?;
+        let sort = SortExec::new(sort.expr().to_vec(), child)
+            .with_fetch(sort.fetch())
+            .with_preserve_partitioning(sort.preserve_partitioning());
+        Some(Arc::new(sort))
+    }
+}
+
+fn transform_down_mut<F>(
+    me: Arc<dyn ExecutionPlan>,
+    op: &mut F,
+) -> Result<Arc<dyn ExecutionPlan>>
+where
+    F: FnMut(Arc<dyn ExecutionPlan>) -> Result<Transformed<Arc<dyn 
ExecutionPlan>>>,
+{
+    let after_op = op(me)?.into();
+    after_op.map_children(|node| transform_down_mut(node, op))
+}
+
+impl Default for TopKAggregation {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl PhysicalOptimizerRule for TopKAggregation {
+    fn optimize(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+        config: &ConfigOptions,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let plan = if config.optimizer.enable_topk_aggregation {
+            plan.transform_down(&|plan| {
+                Ok(
+                    if let Some(plan) = 
TopKAggregation::transform_sort(plan.clone()) {
+                        Transformed::Yes(plan)
+                    } else {
+                        Transformed::No(plan)
+                    },
+                )
+            })?
+        } else {
+            plan
+        };
+        Ok(plan)
+    }
+
+    fn name(&self) -> &str {
+        "LimitAggregation"
+    }
+
+    fn schema_check(&self) -> bool {
+        true
+    }
+}
+
+// TODO: tests

Review Comment:
   👍 



##########
datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,829 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to 
a fixed number
+
+use crate::physical_plan::aggregates::{
+    aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
+    PhysicalGroupBy,
+};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use ahash::RandomState;
+use arrow::util::pretty::print_batches;
+use arrow_array::cast::AsArray;
+use arrow_array::downcast_primitive;
+use arrow_array::{
+    Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, 
RecordBatch,
+    StringArray,
+};
+use arrow_schema::{DataType, SchemaRef};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::PhysicalExpr;
+use futures::stream::{Stream, StreamExt};
+use hashbrown::raw::RawTable;
+use log::{trace, Level};
+use std::cmp::Ordering;
+use std::collections::BTreeSet;
+use std::hash::Hash;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+pub struct GroupedTopKAggregateStream {

Review Comment:
   I think it would help to make the use case clear, as this is an optimization 
for a fairly special case.
   
   For example
   
   ```suggestion
   /// This is a special case for queries of the following form:
   ///
   /// SELECT group_id, MAX(time) FROM t GROUP BY group_id ORDER BY MAX(time) DESC LIMIT k
   /// SELECT group_id, MIN(time) FROM t GROUP BY group_id ORDER BY MIN(time) ASC LIMIT k
   ///
   /// It maintains only the current top `k` group_ids in memory and is therefore
   /// much more memory efficient than the general-purpose group operator.
   pub struct GroupedTopKAggregateStream {
   ```
   



##########
datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,820 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to 
a fixed number
+
+use crate::physical_plan::aggregates::{
+    aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
+    PhysicalGroupBy,
+};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use ahash::RandomState;
+use arrow::util::pretty::print_batches;
+use arrow_array::cast::AsArray;
+use arrow_array::downcast_primitive;
+use arrow_array::{
+    Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, 
RecordBatch,
+    StringArray,
+};
+use arrow_schema::{DataType, SchemaRef, SortOptions};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::PhysicalExpr;
+use futures::stream::{Stream, StreamExt};
+use hashbrown::raw::RawTable;
+use log::{trace, Level};
+use std::cmp::Ordering;
+use std::collections::BTreeSet;
+use std::hash::{BuildHasher, Hash, Hasher};
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+pub struct GroupedTopKAggregateStream {
+    partition: usize,
+    row_count: usize,
+    started: bool,
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+    group_by: PhysicalGroupBy,
+    aggregator: Box<dyn LimitedAggregator>,
+}
+
+impl GroupedTopKAggregateStream {
+    pub fn new(
+        agg: &AggregateExec,
+        context: Arc<TaskContext>,
+        partition: usize,
+        limit: usize,
+    ) -> Result<Self> {
+        let agg_schema = Arc::clone(&agg.schema);
+        let group_by = agg.group_by.clone();
+
+        let input = agg.input.execute(partition, Arc::clone(&context))?;
+
+        let aggregate_arguments =
+            aggregate_expressions(&agg.aggr_expr, &agg.mode, 
group_by.expr.len())?;
+
+        let (val_field, _) = agg
+            .get_minmax_desc()
+            .ok_or_else(|| DataFusionError::Execution("Min/max 
required".to_string()))?;
+
+        let vt = val_field.data_type().clone();
+        let ag = new_group_values(limit, agg.sort, vt)?;
+
+        Ok(GroupedTopKAggregateStream {
+            partition,
+            started: false,
+            row_count: 0,
+            schema: agg_schema,
+            input,
+            aggregate_arguments,
+            group_by,
+            aggregator: ag,
+        })
+    }
+}
+
+pub fn new_group_values(
+    limit: usize,
+    sort: SortOptions,
+    vt: DataType,
+) -> Result<Box<dyn LimitedAggregator>> {
+    macro_rules! downcast_helper {
+        ($vt:ty, $d:ident) => {
+            return Ok(Box::new(PrimitiveAggregator::<$vt>::new(
+                limit,
+                limit * 10,
+                sort,
+            )))
+        };
+    }
+
+    downcast_primitive! {
+        vt => (downcast_helper, vt),
+        _ => {}
+    }
+
+    Err(DataFusionError::Execution(format!(
+        "Can't group type: {vt:?}"
+    )))
+}
+
+impl RecordBatchStream for GroupedTopKAggregateStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+pub trait LimitedAggregator: Send {
+    fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()>;
+    fn emit(&mut self) -> Result<Vec<ArrayRef>>;
+    fn is_empty(&self) -> bool;
+}
+
+pub trait ValueType: ArrowNativeTypeOp + Clone {}
+
+impl<T> ValueType for T where T: ArrowNativeTypeOp + Clone {}
+
+pub trait KeyType: Clone + Eq + Hash {}
+
+impl<T> KeyType for T where T: Clone + Eq + Hash {}
+
+struct PrimitiveAggregator<VAL: ArrowPrimitiveType>
+where
+    <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+    priority_map: PriorityMap<Option<String>, VAL::Native>,

Review Comment:
   I think using the arrow Row format is another possibility (that would also 
allow multiple group keys to be supported, though I'm not sure if that is an 
important use case).



##########
datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,829 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to a fixed number.
+
+use crate::physical_plan::aggregates::{
+    aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
+    PhysicalGroupBy,
+};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use ahash::RandomState;
+use arrow::util::pretty::print_batches;
+use arrow_array::cast::AsArray;
+use arrow_array::downcast_primitive;
+use arrow_array::{
+    Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, 
RecordBatch,
+    StringArray,
+};
+use arrow_schema::{DataType, SchemaRef};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::PhysicalExpr;
+use futures::stream::{Stream, StreamExt};
+use hashbrown::raw::RawTable;
+use log::{trace, Level};
+use std::cmp::Ordering;
+use std::collections::BTreeSet;
+use std::hash::Hash;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+/// Stream that evaluates a single MIN/MAX aggregate grouped by a key column,
+/// retaining only a bounded number of groups via a [`LimitedAggregator`].
+pub struct GroupedTopKAggregateStream {
+    /// Partition index of this stream; used in trace logging.
+    partition: usize,
+    /// Total number of input rows consumed so far.
+    row_count: usize,
+    /// Set to `true` whenever an input batch arrives.
+    started: bool,
+    /// Output schema of the aggregation.
+    schema: SchemaRef,
+    /// Upstream input for this partition.
+    input: SendableRecordBatchStream,
+    /// Argument expressions of each aggregate, evaluated against every input batch.
+    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+    /// Group-by expressions, evaluated against every input batch.
+    group_by: PhysicalGroupBy,
+    /// Bounded aggregator that keeps the best groups seen so far.
+    aggregator: Box<dyn LimitedAggregator>,
+}
+
+impl GroupedTopKAggregateStream {
+    /// Creates the stream for `partition` of `agg`, keeping only the `limit`
+    /// best groups according to the plan's lone MIN/MAX aggregate.
+    ///
+    /// # Errors
+    /// Fails if the input partition cannot be executed, if the aggregate is not
+    /// a single MIN/MAX, or if its value type is not a supported primitive.
+    pub fn new(
+        agg: &AggregateExec,
+        context: Arc<TaskContext>,
+        partition: usize,
+        limit: usize,
+    ) -> Result<Self> {
+        let schema = Arc::clone(&agg.schema);
+        let group_by = agg.group_by.clone();
+        // Note: the input is executed before the MIN/MAX check below.
+        let input = agg.input.execute(partition, Arc::clone(&context))?;
+        let aggregate_arguments =
+            aggregate_expressions(&agg.aggr_expr, &agg.mode, group_by.expr.len())?;
+
+        let (val_field, descending) = match agg.get_minmax_desc() {
+            Some(found) => found,
+            None => {
+                return Err(DataFusionError::Execution(
+                    "Min/max required".to_string(),
+                ))
+            }
+        };
+        let value_type = val_field.data_type().clone();
+        let aggregator = new_group_values(limit, descending, value_type)?;
+
+        Ok(Self {
+            partition,
+            row_count: 0,
+            started: false,
+            schema,
+            input,
+            aggregate_arguments,
+            group_by,
+            aggregator,
+        })
+    }
+}
+
+/// Builds the [`LimitedAggregator`] that retains the best `limit` values of
+/// type `vt`: the largest when `desc` is `true`, the smallest otherwise.
+///
+/// # Errors
+/// Returns `DataFusionError::Execution` when `vt` is not a primitive type.
+pub fn new_group_values(
+    limit: usize,
+    desc: bool,
+    vt: DataType,
+) -> Result<Box<dyn LimitedAggregator>> {
+    // Expanded by `downcast_primitive!` once per primitive arrow type; the
+    // macro also forwards `vt`, which this helper does not need.
+    macro_rules! downcast_helper {
+        ($prim:ty, $ignored:ident) => {{
+            // NOTE(review): `limit * 10` pre-sizes the hash table; it can overflow
+            // (panicking in debug builds) or over-allocate for very large limits.
+            let aggregator = PrimitiveAggregator::<$prim>::new(limit, limit * 10, desc);
+            return Ok(Box::new(aggregator));
+        }};
+    }
+
+    downcast_primitive! {
+        vt => (downcast_helper, vt),
+        _ => {}
+    }
+
+    let message = format!("Can't group type: {vt:?}");
+    Err(DataFusionError::Execution(message))
+}
+
+impl RecordBatchStream for GroupedTopKAggregateStream {
+    /// Returns the output schema (a cheap reference-count bump, not a deep copy).
+    fn schema(&self) -> SchemaRef {
+        SchemaRef::clone(&self.schema)
+    }
+}
+
+/// A group-by accumulator that retains only a bounded number of groups.
+pub trait LimitedAggregator: Send {
+    /// Feeds one batch: `ids` holds the group keys, `vals` the aggregated values.
+    fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()>;
+
+    /// Removes the retained groups and returns them as output columns.
+    fn emit(&mut self) -> Result<Vec<ArrayRef>>;
+
+    /// Returns `true` if no groups are currently retained.
+    fn is_empty(&self) -> bool;
+}
+
+/// Value types the priority map can rank: any arrow native type.
+pub trait ValueType: ArrowNativeTypeOp + Clone {}
+
+impl<T: ArrowNativeTypeOp + Clone> ValueType for T {}
+
+/// Key types the priority map can hash and compare for equality.
+pub trait KeyType: Clone + Eq + Hash {}
+
+impl<T: Clone + Eq + Hash> KeyType for T {}
+
+/// [`LimitedAggregator`] for primitive value columns, keyed by optional
+/// string group ids (a null group key is stored as `None`).
+struct PrimitiveAggregator<VAL>
+where
+    VAL: ArrowPrimitiveType,
+    VAL::Native: Clone,
+{
+    priority_map: PriorityMap<Option<String>, VAL::Native>,
+}
+
+impl<VAL> PrimitiveAggregator<VAL>
+where
+    VAL: ArrowPrimitiveType,
+    VAL::Native: Clone,
+{
+    /// Creates an aggregator that retains at most `limit` groups.
+    ///
+    /// The underlying hash table is pre-sized for `capacity` entries. When
+    /// `descending` is `true` the largest values are kept, otherwise the smallest.
+    pub fn new(limit: usize, capacity: usize, descending: bool) -> Self {
+        let priority_map = PriorityMap::new(limit, capacity, descending);
+        Self { priority_map }
+    }
+}
+
+// `PrimitiveAggregator` needs no manual (`unsafe`) `Send` impl. All of its state
+// is already `Send`: the `RandomState`, the `RawTable` and `BTreeSet` of arrow
+// native values, and the `Option<String>` keys. The auto trait therefore applies.
+// This never-called function makes the compiler prove that for every `VAL`.
+// If a non-`Send` field is ever added, compilation fails here instead of the
+// type silently becoming unsound.
+#[allow(dead_code)]
+fn assert_primitive_aggregator_is_send<VAL>()
+where
+    VAL: ArrowPrimitiveType,
+    VAL::Native: Clone,
+{
+    fn is_send<T: Send>() {}
+    is_send::<PrimitiveAggregator<VAL>>();
+}
+
+impl<VAL> LimitedAggregator for PrimitiveAggregator<VAL>
+where
+    VAL: ArrowPrimitiveType,
+    VAL::Native: Clone,
+{
+    /// Feeds one batch of `(group key, value)` rows into the priority map.
+    ///
+    /// Rows with a null value are ignored. For a non-null key, a full map whose
+    /// current worst entry beats the value rejects the row up front. This happens
+    /// before the key is copied into an owned `String`, so losing rows do not allocate.
+    /// The check lives here because it generalizes between str/String and Row/OwnedRow.
+    fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
+        let ids = match ids.as_any().downcast_ref::<StringArray>() {
+            Some(strings) => strings,
+            None => {
+                return Err(DataFusionError::Execution(
+                    "Expected StringArray".to_string(),
+                ))
+            }
+        };
+        let vals = vals.as_primitive::<VAL>();
+        let vals_have_nulls = vals.null_count() > 0;
+
+        for row in 0..ids.len() {
+            if vals_have_nulls && vals.is_null(row) {
+                continue;
+            }
+            let val = vals.value(row);
+
+            let key = if ids.is_null(row) {
+                None
+            } else {
+                let text = ids.value(row);
+                let map = &self.priority_map;
+                // When full, compare against the entry that would be displaced:
+                // the smallest value when descending, the largest otherwise.
+                let loses = map.is_full()
+                    && if map.desc {
+                        map.min_val().map_or(false, |worst| val < *worst)
+                    } else {
+                        map.max_val().map_or(false, |worst| val > *worst)
+                    };
+                if loses {
+                    continue;
+                }
+                Some(text.to_string())
+            };
+
+            self.priority_map.insert(key, val)?;
+        }
+        Ok(())
+    }
+
+    /// Drains the priority map into a `[keys, values]` column pair.
+    ///
+    /// NOTE(review): `from_iter_values` builds the array with `VAL`'s default data
+    /// type. For parameterized types such as decimals or timezone-aware timestamps,
+    /// this may not match the output schema. Confirm, and consider `with_data_type`.
+    fn emit(&mut self) -> Result<Vec<ArrayRef>> {
+        let drained = self.priority_map.drain();
+        let (keys, vals): (Vec<_>, Vec<_>) = drained.into_iter().unzip();
+        let key_col: ArrayRef = Arc::new(StringArray::from(keys));
+        let val_col: ArrayRef = Arc::new(PrimitiveArray::<VAL>::from_iter_values(vals));
+        Ok(vec![key_col, val_col])
+    }
+
+    /// Returns `true` when the priority map holds no entries.
+    fn is_empty(&self) -> bool {
+        self.priority_map.is_empty()
+    }
+}
+
+impl Stream for GroupedTopKAggregateStream {
+    type Item = Result<RecordBatch>;
+
+    /// Drains the input and feeds each batch into the limited aggregator.
+    ///
+    /// Each batch contributes its group-key column and MIN/MAX input values.
+    /// Once the input is exhausted, all retained groups are emitted as a single
+    /// batch. Input errors are forwarded unchanged, and `Poll::Pending` is
+    /// returned whenever the input is not ready.
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        while let Poll::Ready(res) = self.input.poll_next_unpin(cx) {
+            match res {
+                // got a batch: evaluate its key/value columns and feed the aggregator
+                Some(Ok(batch)) => {
+                    self.started = true;
+                    trace!(
+                        "partition {} has {} rows and got batch with {} rows",
+                        self.partition,
+                        self.row_count,
+                        batch.num_rows()
+                    );
+                    if log::log_enabled!(Level::Trace) && batch.num_rows() < 20 {
+                        print_batches(&[batch.clone()])?;
+                    }
+                    self.row_count += batch.num_rows();
+
+                    // Exactly one grouping set with exactly one key column is
+                    // supported. Any other shape is reported as an error rather
+                    // than panicking or silently using only the last column.
+                    let group_by_values = take_single(
+                        take_single(
+                            evaluate_group_by(&self.group_by, &batch)?,
+                            "group by values required",
+                            "1 grouping set required",
+                        )?,
+                        "group by values required",
+                        "1 group by column required",
+                    )?;
+                    // Exactly one aggregate with exactly one argument column.
+                    let input_values = take_single(
+                        take_single(
+                            evaluate_many(&self.aggregate_arguments, &batch)?,
+                            "vals required",
+                            "1 val required",
+                        )?,
+                        "vals required",
+                        "1 val required",
+                    )?;
+
+                    self.aggregator.intern(group_by_values, input_values)?;
+                }
+                // inner is done, emit all rows and switch to producing output
+                None => {
+                    // NOTE(review): on the next call after emitting, `self.input` is
+                    // polled again even though it already returned `None`. Streams
+                    // are not required to tolerate that. A `done` flag (or a fused
+                    // input) would make this robust.
+                    if self.aggregator.is_empty() {
+                        trace!("partition {} emit None", self.partition);
+                        return Poll::Ready(None);
+                    }
+                    let cols = self.aggregator.emit()?;
+                    let batch = RecordBatch::try_new(self.schema.clone(), cols)?;
+                    trace!(
+                        "partition {} emit batch with {} rows",
+                        self.partition,
+                        batch.num_rows()
+                    );
+                    if log::log_enabled!(Level::Trace) {
+                        print_batches(&[batch.clone()])?;
+                    }
+                    return Poll::Ready(Some(Ok(batch)));
+                }
+                // inner had error, return to caller
+                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+            }
+        }
+        Poll::Pending
+    }
+}
+
+/// Returns the only element of `items`.
+///
+/// # Errors
+/// `DataFusionError::Execution(empty_msg)` if `items` is empty, or
+/// `DataFusionError::Execution(many_msg)` if it holds more than one element.
+fn take_single<T>(items: Vec<T>, empty_msg: &str, many_msg: &str) -> Result<T> {
+    let mut iter = items.into_iter();
+    match (iter.next(), iter.next()) {
+        (Some(item), None) => Ok(item),
+        (None, _) => Err(DataFusionError::Execution(empty_msg.to_string())),
+        (Some(_), Some(_)) => Err(DataFusionError::Execution(many_msg.to_string())),
+    }
+}
+
+/// A dual data structure consisting of a bi-directionally linked Map & Heap.
+///
+/// The implementation is optimized for performance because `insert()` will be called
+/// on billions of rows. Because traversing between the map & heap will happen
+/// frequently, it is important to be highly optimized.
+///
+/// In order to quickly traverse from heap to map, we use the unsafe raw indexes that
+/// `RawTable` exposes to us to avoid needing to find buckets based on their hash.
+///
+/// Presently this implementation does not traverse quickly from map to heap, instead
+/// performing a B-ary search of the `BTreeSet`. In the future this could be eliminated
+/// by implementing a custom binary heap with pointers from the map into the heap.
+pub struct PriorityMap<ID: KeyType, VAL: ValueType> {
+    /// Maximum number of entries retained (`insert` asserts `len() <= limit`).
+    limit: usize,
+    /// When `true` the largest values are kept (the smallest is the first rejected);
+    /// when `false` the smallest values are kept.
+    desc: bool,
+    /// Hasher used to hash ids before probing `id_to_val`.
+    rnd: RandomState,
+    /// Hash table from id to its current value, addressed by raw bucket index.
+    id_to_val: RawTable<MapItem<ID, VAL>>,
+    /// Ordered set of (value, bucket index) pairs. Depending on `desc`, its first
+    /// or last element is the worst retained value.
+    val_to_idx: BTreeSet<HeapItem<VAL>>,
+}
+
+/// One hash-table entry of a [`PriorityMap`]: a group id, its current value,
+/// and the hash supplied by the caller when the entry was created.
+pub struct MapItem<ID: KeyType, VAL: ValueType> {
+    hash: u64,
+    pub id: ID,
+    pub val: VAL,
+}
+
+impl<ID: KeyType, VAL: ValueType> MapItem<ID, VAL> {
+    /// Bundles `id` and `val` together with the precomputed `hash`.
+    pub fn new(hash: u64, id: ID, val: VAL) -> Self {
+        MapItem { id, val, hash }
+    }
+}
+
+/// One entry of the ordered side of a [`PriorityMap`]: a value plus the raw
+/// `RawTable` bucket index of the map entry it belongs to.
+struct HeapItem<VAL: ValueType> {
+    val: VAL,
+    buk_idx: usize,
+}
+
+impl<VAL: ValueType> HeapItem<VAL> {
+    /// Pairs `val` with `buk_idx`, the bucket index of its map entry.
+    pub fn new(val: VAL, buk_idx: usize) -> Self {
+        HeapItem { buk_idx, val }
+    }
+}
+
+// `HeapItem`s are ordered by value first and bucket index second. Two items are
+// therefore equal only when both match, and the index tiebreak lets equal values
+// belonging to different map entries coexist in the `BTreeSet`.
+impl<VAL: ValueType> Eq for HeapItem<VAL> {}
+
+impl<VAL: ValueType> PartialEq<Self> for HeapItem<VAL> {
+    fn eq(&self, other: &Self) -> bool {
+        matches!(Ord::cmp(self, other), Ordering::Equal)
+    }
+}
+
+impl<VAL: ValueType> PartialOrd<Self> for HeapItem<VAL> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(Ord::cmp(self, other))
+    }
+}
+
+impl<VAL: ValueType> Ord for HeapItem<VAL> {
+    fn cmp(&self, other: &Self) -> Ordering {
+        // `compare` comes from `ArrowNativeTypeOp`; ties fall back to the bucket index.
+        self.val
+            .compare(other.val)
+            .then_with(|| self.buk_idx.cmp(&other.buk_idx))
+    }
+}
+
+impl<ID: KeyType, VAL: ValueType> PriorityMap<ID, VAL>
+where
+    VAL: PartialEq<VAL>,
+{
+    pub fn new(limit: usize, capacity: usize, desc: bool) -> Self {
+        Self {
+            limit,
+            desc,
+            rnd: Default::default(),
+            id_to_val: RawTable::with_capacity(capacity),
+            val_to_idx: Default::default(),
+        }
+    }
+
+    pub fn insert(&mut self, new_id: ID, new_val: VAL) -> Result<()> {
+        let is_full = self.is_full();
+        let desc = self.desc;
+        assert!(self.len() <= self.limit, "Overflow");
+        let (id_to_val, val_to_idx) = (&mut self.id_to_val, &mut 
self.val_to_idx);
+
+        // if we're full, and the new val is worse than all our values, just 
bail
+        if is_full {
+            let worst_entry = if desc {
+                val_to_idx.first()
+            } else {
+                val_to_idx.last()
+            }
+            .expect("Missing value!");
+            if (!desc && new_val > worst_entry.val) || (desc && new_val < 
worst_entry.val)
+            {
+                return Ok(());
+            }
+        }
+
+        // handle new groups we haven't seen yet
+        let new_hash = self.rnd.hash_one(&new_id);

Review Comment:
   I haven't reviewed this implementation carefully yet, but the basic idea makes 
sense to me. I can review it more carefully if @Dandandan and @thinkharderdev 
haven't done so. 
   
   I still believe (without evidence yet) that I can use the same code from 
https://github.com/apache/arrow-datafusion/pull/7250 and get very similar 
performance, but I can do that as a follow-on PR in parallel with this one.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to