alamb commented on code in PR #7192:
URL: https://github.com/apache/arrow-datafusion/pull/7192#discussion_r1301660481
##########
Cargo.toml:
##########
@@ -39,6 +39,7 @@ arrow-flight = { version = "45.0.0", features =
["flight-sql-experimental"] }
arrow-schema = { version = "45.0.0", default-features = false }
parquet = { version = "45.0.0", features = ["arrow", "async", "object_store"] }
sqlparser = { version = "0.36.1", features = ["visitor"] }
+zerocopy = "0.6.1"
Review Comment:
FWIW `zerocopy` is already used by DataFusion transitively (see
https://github.com/apache/arrow-datafusion/issues/7221), so this is not a new
dependency.
##########
datafusion/core/src/physical_plan/aggregates/mod.rs:
##########
@@ -717,14 +725,38 @@ impl AggregateExec {
partition: usize,
context: Arc<TaskContext>,
) -> Result<StreamType> {
+ // no group by at all
if self.group_by.expr.is_empty() {
- Ok(StreamType::AggregateStream(AggregateStream::new(
+ return Ok(StreamType::AggregateStream(AggregateStream::new(
self, context, partition,
- )?))
+ )?));
+ }
+
+ // grouping by an expression that has a sort/limit upstream
+ if let Some(limit) = self.limit {
+ return Ok(StreamType::GroupedPriorityQueue(
+ GroupedTopKAggregateStream::new(self, context, partition,
limit)?,
+ ));
+ }
+
+ // grouping by something else and we need to just materialize all
results
+ Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
+ self, context, partition,
+ )?))
+ }
+
+ /// Finds the DataType and SortDirection for this Aggregate, if there is
one
+ pub fn get_minmax_desc(&self) -> Option<(Field, bool)> {
+ let agg_expr = match self.aggr_expr.as_slice() {
+ [expr] => expr,
+ _ => return None,
+ };
+ if let Some(max) = agg_expr.as_any().downcast_ref::<Max>() {
+ Some((max.field().ok()?, true))
+ } else if let Some(min) = agg_expr.as_any().downcast_ref::<Min>() {
+ Some((min.field().ok()?, true))
Review Comment:
This code always seems to return `true` for the sort direction: both the `Max`
and the `Min` branches return `true`. Is that intended?
##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -145,7 +145,7 @@ pub(crate) struct GroupedHashAggregateStream {
/// accumulator. If present, only those rows for which the filter
/// evaluate to true should be included in the aggregate results.
///
- /// For example, for an aggregate like `SUM(x FILTER x > 100)`,
+ /// For example, for an aggregate like `SUM(x) FILTER (WHERE x >= 100)`,
Review Comment:
👍
##########
datafusion/sqllogictest/test_files/aggregate.slt:
##########
@@ -2291,7 +2291,131 @@ false
true
NULL
+# TopK aggregation
+statement ok
+CREATE TABLE traces(trace_id varchar, timestamp bigint) AS VALUES
+(NULL, 0),
+('a', NULL),
+('a', 1),
+('b', 0),
+('c', 1),
+('c', 2),
+('b', 3);
+
+statement ok
+set datafusion.optimizer.enable_topk_aggregation = false;
+query TT
+explain select trace_id, MAX(timestamp) from traces group by trace_id order by
MAX(timestamp) desc limit 4;
+----
+logical_plan
+Limit: skip=0, fetch=4
+--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4
+----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]]
+------TableScan: traces projection=[trace_id, timestamp]
+physical_plan
+GlobalLimitExec: skip=0, fetch=4
+--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4
+----SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 DESC]
+------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id],
aggr=[MAX(traces.timestamp)]
+--------CoalesceBatchesExec: target_batch_size=8192
+----------RepartitionExec: partitioning=Hash([trace_id@0], 4),
input_partitions=4
+------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id],
aggr=[MAX(traces.timestamp)]
+--------------RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1
+----------------MemoryExec: partitions=1, partition_sizes=[1]
+
+
+query TI
+select trace_id, MAX(timestamp) from traces group by trace_id order by
MAX(timestamp) desc limit 4;
+----
+b 3
+c 2
+a 1
+NULL 0
+
+query TI
+select trace_id, MIN(timestamp) from traces group by trace_id order by
MIN(timestamp) asc limit 4;
Review Comment:
I also recommend a test like
```suggestion
select trace_id, MIN(timestamp) from traces group by trace_id order by
MIN(timestamp) desc limit 4;
```
I realize this isn't a particularly useful query, but it should still work,
and it wasn't clear to me that this case is handled correctly.
##########
datafusion/core/src/physical_optimizer/topk_aggregation.rs:
##########
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! An optimizer rule that detects aggregate operations that could use a
limited bucket count
+
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::aggregates::AggregateExec;
+use crate::physical_plan::coalesce_batches::CoalesceBatchesExec;
+use crate::physical_plan::filter::FilterExec;
+use crate::physical_plan::repartition::RepartitionExec;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::ExecutionPlan;
+use arrow_schema::DataType;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::Result;
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::PhysicalSortExpr;
+use std::sync::Arc;
+
+/// An optimizer rule that passes a `limit` hint to aggregations if the whole
result is not needed
+pub struct TopKAggregation {}
+
+impl TopKAggregation {
+ /// Create a new `LimitAggregation`
+ pub fn new() -> Self {
+ Self {}
+ }
+
+ fn transform_agg(
+ aggr: &AggregateExec,
+ order: &PhysicalSortExpr,
+ limit: usize,
+ ) -> Option<Arc<dyn ExecutionPlan>> {
+ // ensure the sort direction matches aggregate function
+ let (field, desc) = aggr.get_minmax_desc()?;
+ if desc != order.options.descending {
+ return None;
+ }
+ let group_key = match aggr.group_expr().expr() {
+ [expr] => expr, // only one group key
+ _ => return None,
+ };
+ match group_key.0.data_type(&aggr.input_schema).ok() {
+ Some(DataType::Utf8) => {} // only String keys for now
+ _ => return None,
+ }
+ if aggr
+ .filter_expr
+ .iter()
+ .fold(false, |acc, cur| acc | cur.is_some())
+ {
+ return None;
+ }
+
+ // ensure the sort is on the same field as the aggregate output
+ let col = order.expr.as_any().downcast_ref::<Column>()?;
+ if col.name() != field.name() {
+ return None;
+ }
+
+ // We found what we want: clone, copy the limit down, and return
modified node
+ let mut new_aggr = AggregateExec::try_new(
+ aggr.mode,
+ aggr.group_by.clone(),
+ aggr.aggr_expr.clone(),
+ aggr.filter_expr.clone(),
+ aggr.order_by_expr.clone(),
+ aggr.input.clone(),
+ aggr.input_schema.clone(),
+ )
+ .expect("Unable to copy Aggregate!");
+ new_aggr.limit = Some(limit);
+ Some(Arc::new(new_aggr))
+ }
+
+ fn transform_sort(plan: Arc<dyn ExecutionPlan>) -> Option<Arc<dyn
ExecutionPlan>> {
+ let sort = plan.as_any().downcast_ref::<SortExec>()?;
+
+ // TODO: support sorting on multiple fields
+ let children = sort.children();
+ let child = match children.as_slice() {
+ [child] => child.clone(),
+ _ => return None,
+ };
+ let order = sort.output_ordering()?;
+ let order = match order {
+ [order] => order,
+ _ => return None,
+ };
+ let limit = sort.fetch()?;
+
+ let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| {
+ plan.as_any()
+ .downcast_ref::<CoalesceBatchesExec>()
+ .is_some()
+ || plan.as_any().downcast_ref::<RepartitionExec>().is_some()
+ || plan.as_any().downcast_ref::<FilterExec>().is_some()
+ // TODO: whitelist joins that don't increase row count?
+ };
+
+ let mut cardinality_preserved = true;
+ let mut closure = |plan: Arc<dyn ExecutionPlan>| {
+ if !cardinality_preserved {
+ return Ok(Transformed::No(plan));
+ }
+ if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
+ // either we run into an Aggregate and transform it
+ match Self::transform_agg(aggr, order, limit) {
+ None => cardinality_preserved = false,
+ Some(plan) => return Ok(Transformed::Yes(plan)),
+ }
+ } else {
+ // or we continue down whitelisted nodes of other types
+ if !is_cardinality_preserving(plan.clone()) {
+ cardinality_preserved = false;
+ }
+ }
+ Ok(Transformed::No(plan))
+ };
+ let child = transform_down_mut(child, &mut closure).ok()?;
+ let sort = SortExec::new(sort.expr().to_vec(), child)
+ .with_fetch(sort.fetch())
+ .with_preserve_partitioning(sort.preserve_partitioning());
+ Some(Arc::new(sort))
+ }
+}
+
+fn transform_down_mut<F>(
+ me: Arc<dyn ExecutionPlan>,
+ op: &mut F,
+) -> Result<Arc<dyn ExecutionPlan>>
+where
+ F: FnMut(Arc<dyn ExecutionPlan>) -> Result<Transformed<Arc<dyn
ExecutionPlan>>>,
+{
+ let after_op = op(me)?.into();
+ after_op.map_children(|node| transform_down_mut(node, op))
+}
+
+impl Default for TopKAggregation {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl PhysicalOptimizerRule for TopKAggregation {
+ fn optimize(
+ &self,
+ plan: Arc<dyn ExecutionPlan>,
+ config: &ConfigOptions,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ let plan = if config.optimizer.enable_topk_aggregation {
+ plan.transform_down(&|plan| {
+ Ok(
+ if let Some(plan) =
TopKAggregation::transform_sort(plan.clone()) {
+ Transformed::Yes(plan)
+ } else {
+ Transformed::No(plan)
+ },
+ )
+ })?
+ } else {
+ plan
+ };
+ Ok(plan)
+ }
+
+ fn name(&self) -> &str {
+ "LimitAggregation"
+ }
+
+ fn schema_check(&self) -> bool {
+ true
+ }
+}
+
+// TODO: tests
Review Comment:
👍
##########
datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,829 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to
a fixed number
+
+use crate::physical_plan::aggregates::{
+ aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
+ PhysicalGroupBy,
+};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use ahash::RandomState;
+use arrow::util::pretty::print_batches;
+use arrow_array::cast::AsArray;
+use arrow_array::downcast_primitive;
+use arrow_array::{
+ Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray,
RecordBatch,
+ StringArray,
+};
+use arrow_schema::{DataType, SchemaRef};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::PhysicalExpr;
+use futures::stream::{Stream, StreamExt};
+use hashbrown::raw::RawTable;
+use log::{trace, Level};
+use std::cmp::Ordering;
+use std::collections::BTreeSet;
+use std::hash::Hash;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+pub struct GroupedTopKAggregateStream {
+ partition: usize,
+ row_count: usize,
+ started: bool,
+ schema: SchemaRef,
+ input: SendableRecordBatchStream,
+ aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+ group_by: PhysicalGroupBy,
+ aggregator: Box<dyn LimitedAggregator>,
+}
+
+impl GroupedTopKAggregateStream {
+ pub fn new(
+ agg: &AggregateExec,
+ context: Arc<TaskContext>,
+ partition: usize,
+ limit: usize,
+ ) -> Result<Self> {
+ let agg_schema = Arc::clone(&agg.schema);
+ let group_by = agg.group_by.clone();
+
+ let input = agg.input.execute(partition, Arc::clone(&context))?;
+
+ let aggregate_arguments =
+ aggregate_expressions(&agg.aggr_expr, &agg.mode,
group_by.expr.len())?;
+
+ let (val_field, descending) = agg
+ .get_minmax_desc()
+ .ok_or_else(|| DataFusionError::Execution("Min/max
required".to_string()))?;
+
+ let vt = val_field.data_type().clone();
+ let ag = new_group_values(limit, descending, vt)?;
+
+ Ok(GroupedTopKAggregateStream {
+ partition,
+ started: false,
+ row_count: 0,
+ schema: agg_schema,
+ input,
+ aggregate_arguments,
+ group_by,
+ aggregator: ag,
+ })
+ }
+}
+
+pub fn new_group_values(
+ limit: usize,
+ desc: bool,
+ vt: DataType,
+) -> Result<Box<dyn LimitedAggregator>> {
+ macro_rules! downcast_helper {
+ ($vt:ty, $d:ident) => {
+ return Ok(Box::new(PrimitiveAggregator::<$vt>::new(
+ limit,
+ limit * 10,
+ desc,
+ )))
+ };
+ }
+
+ downcast_primitive! {
+ vt => (downcast_helper, vt),
+ _ => {}
+ }
+
+ Err(DataFusionError::Execution(format!(
+ "Can't group type: {vt:?}"
+ )))
+}
+
+impl RecordBatchStream for GroupedTopKAggregateStream {
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+}
+
+pub trait LimitedAggregator: Send {
+ fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()>;
+ fn emit(&mut self) -> Result<Vec<ArrayRef>>;
+ fn is_empty(&self) -> bool;
+}
+
+pub trait ValueType: ArrowNativeTypeOp + Clone {}
+
+impl<T> ValueType for T where T: ArrowNativeTypeOp + Clone {}
+
+pub trait KeyType: Clone + Eq + Hash {}
+
+impl<T> KeyType for T where T: Clone + Eq + Hash {}
+
+struct PrimitiveAggregator<VAL: ArrowPrimitiveType>
+where
+ <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+ priority_map: PriorityMap<Option<String>, VAL::Native>,
+}
+
+impl<VAL: ArrowPrimitiveType> PrimitiveAggregator<VAL>
+where
+ <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+ pub fn new(limit: usize, capacity: usize, descending: bool) -> Self {
+ Self {
+ priority_map: PriorityMap::new(limit, capacity, descending),
+ }
+ }
+}
+
+unsafe impl<VAL: ArrowPrimitiveType> Send for PrimitiveAggregator<VAL> where
+ <VAL as ArrowPrimitiveType>::Native: Clone
+{
+}
+
+impl<VAL: ArrowPrimitiveType> LimitedAggregator for PrimitiveAggregator<VAL>
+where
+ <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+ fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
+ let ids = ids.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
+ DataFusionError::Execution("Expected StringArray".to_string())
+ })?;
+ let vals = vals.as_primitive::<VAL>();
+ let null_count = vals.null_count();
+ for row_idx in 0..ids.len() {
+ if null_count > 0 && vals.is_null(row_idx) {
+ continue;
+ }
+ let val = vals.value(row_idx);
+ let id = if ids.is_null(row_idx) {
+ None
+ } else {
+ // Check goes here, because it is generalizable between
str/String and Row/OwnedRow
+ let id = ids.value(row_idx);
+ if self.priority_map.is_full() {
+ if self.priority_map.desc {
+ if let Some(worst) = self.priority_map.min_val() {
+ if val < *worst {
+ continue;
+ }
+ }
+ } else if let Some(worst) = self.priority_map.max_val() {
+ if val > *worst {
+ continue;
+ }
+ }
+ }
+ Some(id.to_string())
+ };
+
+ self.priority_map.insert(id, val)?;
+ }
+ Ok(())
+ }
+
+ fn emit(&mut self) -> Result<Vec<ArrayRef>> {
+ let (keys, vals): (Vec<_>, Vec<_>) =
+ self.priority_map.drain().into_iter().unzip();
+ let keys = Arc::new(StringArray::from(keys));
+ let vals = Arc::new(PrimitiveArray::<VAL>::from_iter_values(vals));
+ Ok(vec![keys, vals])
+ }
+
+ fn is_empty(&self) -> bool {
+ self.priority_map.is_empty()
+ }
+}
+
+impl Stream for GroupedTopKAggregateStream {
+ type Item = Result<RecordBatch>;
+
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ while let Poll::Ready(res) = self.input.poll_next_unpin(cx) {
+ match res {
+ // got a batch, convert to rows and append to our TreeMap
+ Some(Ok(batch)) => {
+ self.started = true;
+ trace!(
+ "partition {} has {} rows and got batch with {} rows",
+ self.partition,
+ self.row_count,
+ batch.num_rows()
+ );
+ if log::log_enabled!(Level::Trace) && batch.num_rows() <
20 {
+ print_batches(&[batch.clone()])?;
+ }
+ self.row_count += batch.num_rows();
+ let batches = &[batch];
+ let group_by_values =
+ evaluate_group_by(&self.group_by,
batches.first().unwrap())?;
+ let group_by_values =
+ group_by_values.into_iter().last().expect("values");
+ let group_by_values =
+ group_by_values.into_iter().last().expect("values");
+ let input_values = evaluate_many(
+ &self.aggregate_arguments,
+ batches.first().unwrap(),
+ )?;
+ let input_values = match input_values.as_slice() {
+ [] => {
+ Err(DataFusionError::Execution("vals
required".to_string()))?
Review Comment:
```suggestion
Err(DataFusionError::Internal("vals
required".to_string()))?
```
The same probably applies to the other checks below.
##########
datafusion/sqllogictest/test_files/aggregate.slt:
##########
@@ -2291,7 +2291,131 @@ false
true
NULL
+# TopK aggregation
+statement ok
+CREATE TABLE traces(trace_id varchar, timestamp bigint) AS VALUES
+(NULL, 0),
+('a', NULL),
+('a', 1),
+('b', 0),
+('c', 1),
+('c', 2),
+('b', 3);
+
+statement ok
+set datafusion.optimizer.enable_topk_aggregation = false;
+query TT
+explain select trace_id, MAX(timestamp) from traces group by trace_id order by
MAX(timestamp) desc limit 4;
+----
+logical_plan
+Limit: skip=0, fetch=4
+--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4
+----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]]
+------TableScan: traces projection=[trace_id, timestamp]
+physical_plan
+GlobalLimitExec: skip=0, fetch=4
+--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4
+----SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 DESC]
+------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id],
aggr=[MAX(traces.timestamp)]
+--------CoalesceBatchesExec: target_batch_size=8192
+----------RepartitionExec: partitioning=Hash([trace_id@0], 4),
input_partitions=4
+------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id],
aggr=[MAX(traces.timestamp)]
+--------------RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1
+----------------MemoryExec: partitions=1, partition_sizes=[1]
+
+
+query TI
+select trace_id, MAX(timestamp) from traces group by trace_id order by
MAX(timestamp) desc limit 4;
Review Comment:
I recommend using a limit that doesn't keep all of the values. There are
only 4 distinct `trace_id` values in the input (`NULL`, `a`, `b`, `c`), so
`limit 4` returns every group and doesn't actually exercise the limiting.
Maybe you could either add another value to the input or reduce the limit to
`3`, so that the test clearly limits the output.
##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -145,7 +145,7 @@ pub(crate) struct GroupedHashAggregateStream {
/// accumulator. If present, only those rows for which the filter
/// evaluate to true should be included in the aggregate results.
///
- /// For example, for an aggregate like `SUM(x FILTER x > 100)`,
+ /// For example, for an aggregate like `SUM(x) FILTER (WHERE x >= 100)`,
Review Comment:
👍
##########
datafusion/core/tests/sql/select.rs:
##########
@@ -572,6 +574,79 @@ async fn parallel_query_with_filter() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn parallel_query_with_limit() -> Result<()> {
+ let tmp_dir = TempDir::new()?;
+ let partition_count = 4;
+ let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?;
+
+ let dataframe = ctx
+ .sql("SELECT c3, max(c2) as max FROM test group by c3 order by max
desc limit 2")
+ .await?;
+
+ let actual_logical_plan = format!("{:?}", dataframe.logical_plan());
+ let expected_logical_plan = r#"
+Limit: skip=0, fetch=2
+ Sort: max DESC NULLS FIRST
+ Projection: test.c3, MAX(test.c2) AS max
+ Aggregate: groupBy=[[test.c3]], aggr=[[MAX(test.c2)]]
+ TableScan: test
+ "#
+ .trim();
+ assert_eq!(expected_logical_plan, actual_logical_plan);
+
+ let physical_plan = dataframe.create_physical_plan().await?;
+
+ // TODO: find the GroupedHashAggregateStream node and see if we can assert
bucket count
+ finder(physical_plan.clone());
+
+ let actual_phys_plan =
displayable(physical_plan.as_ref()).indent(true).to_string();
+ let mut expected_physical_plan = r#"
+GlobalLimitExec: skip=0, fetch=2
+ SortPreservingMergeExec: [max@1 DESC], fetch=2
+ SortExec: fetch=2, expr=[max@1 DESC]
+ ProjectionExec: expr=[c3@0 as c3, MAX(test.c2)@1 as max]
+ AggregateExec: mode=FinalPartitioned, gby=[c3@0 as c3],
aggr=[MAX(test.c2)]
Review Comment:
I am surprised this `AggregateExec` doesn't show a `lim=` / `fetch` field.
##########
datafusion/core/tests/sql/select.rs:
##########
@@ -572,6 +574,79 @@ async fn parallel_query_with_filter() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn parallel_query_with_limit() -> Result<()> {
+ let tmp_dir = TempDir::new()?;
+ let partition_count = 4;
+ let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?;
+
+ let dataframe = ctx
+ .sql("SELECT c3, max(c2) as max FROM test group by c3 order by max
desc limit 2")
+ .await?;
+
+ let actual_logical_plan = format!("{:?}", dataframe.logical_plan());
+ let expected_logical_plan = r#"
+Limit: skip=0, fetch=2
+ Sort: max DESC NULLS FIRST
+ Projection: test.c3, MAX(test.c2) AS max
+ Aggregate: groupBy=[[test.c3]], aggr=[[MAX(test.c2)]]
+ TableScan: test
+ "#
+ .trim();
+ assert_eq!(expected_logical_plan, actual_logical_plan);
+
+ let physical_plan = dataframe.create_physical_plan().await?;
+
+ // TODO: find the GroupedHashAggregateStream node and see if we can assert
bucket count
Review Comment:
Is this still a TODO?
##########
datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,829 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to
a fixed number
Review Comment:
```suggestion
//! A memory-conscious aggregation implementation that limits group buckets
to a fixed number
//! [`GroupedTopKAggregateStream`]
```
##########
datafusion/core/src/physical_plan/aggregates/mod.rs:
##########
@@ -265,6 +270,8 @@ pub struct AggregateExec {
pub(crate) filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
/// (ORDER BY clause) expression for each aggregate expression
pub(crate) order_by_expr: Vec<Option<LexOrdering>>,
+ /// Set if the output of this aggregation is truncated by a upstream
sort/limit clause
+ pub(crate) limit: Option<usize>,
Review Comment:
I wonder if we can add comments here or give it a different name (like
`group_num_limit`?) to explain what is going on. Initially this might look
like a `fetch` on a sort or limit, but I think it is quite different.
Specifically, I think it means "output only the top groups based on the
value of the aggregates". (I still don't fully understand how MIN/MAX and ORDER
BY ASC/DESC are handled together. I would have expected the optimizer to
specify that here as well, but perhaps it only sets `limit` for the right
combinations.)
##########
datafusion/core/src/physical_optimizer/topk_aggregation.rs:
##########
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! An optimizer rule that detects aggregate operations that could use a
limited bucket count
+
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::aggregates::AggregateExec;
+use crate::physical_plan::coalesce_batches::CoalesceBatchesExec;
+use crate::physical_plan::filter::FilterExec;
+use crate::physical_plan::repartition::RepartitionExec;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::ExecutionPlan;
+use arrow_schema::DataType;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::Result;
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::PhysicalSortExpr;
+use std::sync::Arc;
+
+/// An optimizer rule that passes a `limit` hint to aggregations if the whole
result is not needed
+pub struct TopKAggregation {}
+
+impl TopKAggregation {
+ /// Create a new `LimitAggregation`
+ pub fn new() -> Self {
+ Self {}
+ }
+
+ fn transform_agg(
+ aggr: &AggregateExec,
+ order: &PhysicalSortExpr,
+ limit: usize,
+ ) -> Option<Arc<dyn ExecutionPlan>> {
+ // ensure the sort direction matches aggregate function
+ let (field, desc) = aggr.get_minmax_desc()?;
+ if desc != order.options.descending {
+ return None;
+ }
+ let group_key = match aggr.group_expr().expr() {
+ [expr] => expr, // only one group key
+ _ => return None,
+ };
+ match group_key.0.data_type(&aggr.input_schema).ok() {
+ Some(DataType::Utf8) => {} // only String keys for now
+ _ => return None,
+ }
+ if aggr
+ .filter_expr
+ .iter()
+ .fold(false, |acc, cur| acc | cur.is_some())
+ {
+ return None;
+ }
+
+ // ensure the sort is on the same field as the aggregate output
+ let col = order.expr.as_any().downcast_ref::<Column>()?;
+ if col.name() != field.name() {
+ return None;
+ }
+
+ // We found what we want: clone, copy the limit down, and return
modified node
+ let mut new_aggr = AggregateExec::try_new(
+ aggr.mode,
+ aggr.group_by.clone(),
+ aggr.aggr_expr.clone(),
+ aggr.filter_expr.clone(),
+ aggr.order_by_expr.clone(),
+ aggr.input.clone(),
+ aggr.input_schema.clone(),
+ )
+ .expect("Unable to copy Aggregate!");
+ new_aggr.limit = Some(limit);
+ Some(Arc::new(new_aggr))
+ }
+
+ fn transform_sort(plan: Arc<dyn ExecutionPlan>) -> Option<Arc<dyn
ExecutionPlan>> {
+ let sort = plan.as_any().downcast_ref::<SortExec>()?;
+
+ // TODO: support sorting on multiple fields
+ let children = sort.children();
+ let child = match children.as_slice() {
+ [child] => child.clone(),
+ _ => return None,
+ };
+ let order = sort.output_ordering()?;
+ let order = match order {
+ [order] => order,
+ _ => return None,
+ };
+ let limit = sort.fetch()?;
+
+ let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| {
+ plan.as_any()
+ .downcast_ref::<CoalesceBatchesExec>()
+ .is_some()
+ || plan.as_any().downcast_ref::<RepartitionExec>().is_some()
+ || plan.as_any().downcast_ref::<FilterExec>().is_some()
+ // TODO: whitelist joins that don't increase row count?
+ };
Review Comment:
I think it would be better to add a new method to the `ExecutionPlan` trait
for this property rather than check for specific types. That way extensions and
future plans can work correctly with this optimizer.
I am also somewhat confused that `FilterExec` is labeled as
"cardinality preserving", as it can certainly reduce the cardinality of its
input 🤔 Is the idea to check for "non-increasing" cardinality instead?
##########
datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,829 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to
a fixed number
+
+use crate::physical_plan::aggregates::{
+ aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
+ PhysicalGroupBy,
+};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use ahash::RandomState;
+use arrow::util::pretty::print_batches;
+use arrow_array::cast::AsArray;
+use arrow_array::downcast_primitive;
+use arrow_array::{
+ Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray,
RecordBatch,
+ StringArray,
+};
+use arrow_schema::{DataType, SchemaRef};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::PhysicalExpr;
+use futures::stream::{Stream, StreamExt};
+use hashbrown::raw::RawTable;
+use log::{trace, Level};
+use std::cmp::Ordering;
+use std::collections::BTreeSet;
+use std::hash::Hash;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+pub struct GroupedTopKAggregateStream {
+ partition: usize,
+ row_count: usize,
+ started: bool,
+ schema: SchemaRef,
+ input: SendableRecordBatchStream,
+ aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+ group_by: PhysicalGroupBy,
+ aggregator: Box<dyn LimitedAggregator>,
+}
+
+impl GroupedTopKAggregateStream {
+ pub fn new(
+ agg: &AggregateExec,
+ context: Arc<TaskContext>,
+ partition: usize,
+ limit: usize,
+ ) -> Result<Self> {
+ let agg_schema = Arc::clone(&agg.schema);
+ let group_by = agg.group_by.clone();
+
+ let input = agg.input.execute(partition, Arc::clone(&context))?;
+
+ let aggregate_arguments =
+ aggregate_expressions(&agg.aggr_expr, &agg.mode,
group_by.expr.len())?;
+
+ let (val_field, descending) = agg
+ .get_minmax_desc()
+ .ok_or_else(|| DataFusionError::Execution("Min/max
required".to_string()))?;
Review Comment:
If this code ever runs, it would be due to a bug in one of the optimizer
passes, so an `Internal` error is more appropriate:
```suggestion
.ok_or_else(|| DataFusionError::Internal("Min/max
required".to_string()))?;
```
##########
datafusion/core/src/physical_optimizer/topk_aggregation.rs:
##########
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! An optimizer rule that detects aggregate operations that could use a
limited bucket count
+
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::aggregates::AggregateExec;
+use crate::physical_plan::coalesce_batches::CoalesceBatchesExec;
+use crate::physical_plan::filter::FilterExec;
+use crate::physical_plan::repartition::RepartitionExec;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::ExecutionPlan;
+use arrow_schema::DataType;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::Result;
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::PhysicalSortExpr;
+use std::sync::Arc;
+
+/// An optimizer rule that passes a `limit` hint to aggregations if the whole
result is not needed
+pub struct TopKAggregation {}
+
+impl TopKAggregation {
+ /// Create a new `LimitAggregation`
+ pub fn new() -> Self {
+ Self {}
+ }
+
+ fn transform_agg(
+ aggr: &AggregateExec,
+ order: &PhysicalSortExpr,
+ limit: usize,
+ ) -> Option<Arc<dyn ExecutionPlan>> {
+ // ensure the sort direction matches aggregate function
+ let (field, desc) = aggr.get_minmax_desc()?;
+ if desc != order.options.descending {
+ return None;
+ }
+ let group_key = match aggr.group_expr().expr() {
+ [expr] => expr, // only one group key
+ _ => return None,
+ };
+ match group_key.0.data_type(&aggr.input_schema).ok() {
+ Some(DataType::Utf8) => {} // only String keys for now
+ _ => return None,
+ }
+ if aggr
+ .filter_expr
+ .iter()
+ .fold(false, |acc, cur| acc | cur.is_some())
+ {
+ return None;
+ }
+
+ // ensure the sort is on the same field as the aggregate output
+ let col = order.expr.as_any().downcast_ref::<Column>()?;
+ if col.name() != field.name() {
+ return None;
+ }
+
+ // We found what we want: clone, copy the limit down, and return
modified node
+ let mut new_aggr = AggregateExec::try_new(
+ aggr.mode,
+ aggr.group_by.clone(),
+ aggr.aggr_expr.clone(),
+ aggr.filter_expr.clone(),
+ aggr.order_by_expr.clone(),
+ aggr.input.clone(),
+ aggr.input_schema.clone(),
+ )
+ .expect("Unable to copy Aggregate!");
+ new_aggr.limit = Some(limit);
+ Some(Arc::new(new_aggr))
+ }
+
+ fn transform_sort(plan: Arc<dyn ExecutionPlan>) -> Option<Arc<dyn
ExecutionPlan>> {
+ let sort = plan.as_any().downcast_ref::<SortExec>()?;
+
+ // TODO: support sorting on multiple fields
+ let children = sort.children();
+ let child = match children.as_slice() {
+ [child] => child.clone(),
+ _ => return None,
+ };
+ let order = sort.output_ordering()?;
+ let order = match order {
+ [order] => order,
+ _ => return None,
+ };
+ let limit = sort.fetch()?;
+
+ let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| {
+ plan.as_any()
+ .downcast_ref::<CoalesceBatchesExec>()
+ .is_some()
+ || plan.as_any().downcast_ref::<RepartitionExec>().is_some()
+ || plan.as_any().downcast_ref::<FilterExec>().is_some()
+ // TODO: whitelist joins that don't increase row count?
+ };
Review Comment:
I think it would be better to add a new method to the `ExecutionPlan` trait
for this property rather than checking for specific types. That way extensions and
future plans can work correctly with this optimizer.

I am also somewhat confused that `FilterExec` would be labeled as
"cardinality preserving", as it can certainly reduce the cardinality of its
input 🤔 Is the idea to check for "non-increasing" cardinality instead?
##########
datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,829 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to
a fixed number
+
+use crate::physical_plan::aggregates::{
+ aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
+ PhysicalGroupBy,
+};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use ahash::RandomState;
+use arrow::util::pretty::print_batches;
+use arrow_array::cast::AsArray;
+use arrow_array::downcast_primitive;
+use arrow_array::{
+ Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray,
RecordBatch,
+ StringArray,
+};
+use arrow_schema::{DataType, SchemaRef};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::PhysicalExpr;
+use futures::stream::{Stream, StreamExt};
+use hashbrown::raw::RawTable;
+use log::{trace, Level};
+use std::cmp::Ordering;
+use std::collections::BTreeSet;
+use std::hash::Hash;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+pub struct GroupedTopKAggregateStream {
+ partition: usize,
+ row_count: usize,
+ started: bool,
+ schema: SchemaRef,
+ input: SendableRecordBatchStream,
+ aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+ group_by: PhysicalGroupBy,
+ aggregator: Box<dyn LimitedAggregator>,
+}
+
+impl GroupedTopKAggregateStream {
+ pub fn new(
+ agg: &AggregateExec,
+ context: Arc<TaskContext>,
+ partition: usize,
+ limit: usize,
+ ) -> Result<Self> {
+ let agg_schema = Arc::clone(&agg.schema);
+ let group_by = agg.group_by.clone();
+
+ let input = agg.input.execute(partition, Arc::clone(&context))?;
+
+ let aggregate_arguments =
+ aggregate_expressions(&agg.aggr_expr, &agg.mode,
group_by.expr.len())?;
+
+ let (val_field, descending) = agg
+ .get_minmax_desc()
+ .ok_or_else(|| DataFusionError::Execution("Min/max
required".to_string()))?;
+
+ let vt = val_field.data_type().clone();
+ let ag = new_group_values(limit, descending, vt)?;
+
+ Ok(GroupedTopKAggregateStream {
+ partition,
+ started: false,
+ row_count: 0,
+ schema: agg_schema,
+ input,
+ aggregate_arguments,
+ group_by,
+ aggregator: ag,
+ })
+ }
+}
+
+pub fn new_group_values(
+ limit: usize,
+ desc: bool,
+ vt: DataType,
+) -> Result<Box<dyn LimitedAggregator>> {
+ macro_rules! downcast_helper {
+ ($vt:ty, $d:ident) => {
+ return Ok(Box::new(PrimitiveAggregator::<$vt>::new(
+ limit,
+ limit * 10,
+ desc,
+ )))
+ };
+ }
+
+ downcast_primitive! {
+ vt => (downcast_helper, vt),
+ _ => {}
+ }
+
+ Err(DataFusionError::Execution(format!(
+ "Can't group type: {vt:?}"
+ )))
+}
+
+impl RecordBatchStream for GroupedTopKAggregateStream {
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+}
+
+pub trait LimitedAggregator: Send {
+ fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()>;
+ fn emit(&mut self) -> Result<Vec<ArrayRef>>;
+ fn is_empty(&self) -> bool;
+}
+
+pub trait ValueType: ArrowNativeTypeOp + Clone {}
+
+impl<T> ValueType for T where T: ArrowNativeTypeOp + Clone {}
+
+pub trait KeyType: Clone + Eq + Hash {}
+
+impl<T> KeyType for T where T: Clone + Eq + Hash {}
+
+struct PrimitiveAggregator<VAL: ArrowPrimitiveType>
+where
+ <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+ priority_map: PriorityMap<Option<String>, VAL::Native>,
+}
+
+impl<VAL: ArrowPrimitiveType> PrimitiveAggregator<VAL>
+where
+ <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+ pub fn new(limit: usize, capacity: usize, descending: bool) -> Self {
+ Self {
+ priority_map: PriorityMap::new(limit, capacity, descending),
+ }
+ }
+}
+
+unsafe impl<VAL: ArrowPrimitiveType> Send for PrimitiveAggregator<VAL> where
+ <VAL as ArrowPrimitiveType>::Native: Clone
+{
+}
+
+impl<VAL: ArrowPrimitiveType> LimitedAggregator for PrimitiveAggregator<VAL>
+where
+ <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+ fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
+ let ids = ids.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
+ DataFusionError::Execution("Expected StringArray".to_string())
+ })?;
+ let vals = vals.as_primitive::<VAL>();
+ let null_count = vals.null_count();
+ for row_idx in 0..ids.len() {
+ if null_count > 0 && vals.is_null(row_idx) {
+ continue;
+ }
+ let val = vals.value(row_idx);
+ let id = if ids.is_null(row_idx) {
+ None
+ } else {
+ // Check goes here, because it is generalizable between
str/String and Row/OwnedRow
+ let id = ids.value(row_idx);
+ if self.priority_map.is_full() {
+ if self.priority_map.desc {
+ if let Some(worst) = self.priority_map.min_val() {
+ if val < *worst {
+ continue;
+ }
+ }
+ } else if let Some(worst) = self.priority_map.max_val() {
+ if val > *worst {
+ continue;
+ }
+ }
+ }
+ Some(id.to_string())
+ };
+
+ self.priority_map.insert(id, val)?;
+ }
+ Ok(())
+ }
+
+ fn emit(&mut self) -> Result<Vec<ArrayRef>> {
+ let (keys, vals): (Vec<_>, Vec<_>) =
+ self.priority_map.drain().into_iter().unzip();
+ let keys = Arc::new(StringArray::from(keys));
+ let vals = Arc::new(PrimitiveArray::<VAL>::from_iter_values(vals));
+ Ok(vec![keys, vals])
+ }
+
+ fn is_empty(&self) -> bool {
+ self.priority_map.is_empty()
+ }
+}
+
+impl Stream for GroupedTopKAggregateStream {
+ type Item = Result<RecordBatch>;
+
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ while let Poll::Ready(res) = self.input.poll_next_unpin(cx) {
+ match res {
+ // got a batch, convert to rows and append to our TreeMap
+ Some(Ok(batch)) => {
+ self.started = true;
+ trace!(
+ "partition {} has {} rows and got batch with {} rows",
+ self.partition,
+ self.row_count,
+ batch.num_rows()
+ );
+ if log::log_enabled!(Level::Trace) && batch.num_rows() <
20 {
+ print_batches(&[batch.clone()])?;
+ }
+ self.row_count += batch.num_rows();
+ let batches = &[batch];
+ let group_by_values =
+ evaluate_group_by(&self.group_by,
batches.first().unwrap())?;
+ let group_by_values =
Review Comment:
Maybe it would be clearer to assert that these slices have length 1
and then use `group_by_values[0][0]` or something similar.
##########
datafusion/core/src/physical_optimizer/topk_aggregation.rs:
##########
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! An optimizer rule that detects aggregate operations that could use a
limited bucket count
+
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::aggregates::AggregateExec;
+use crate::physical_plan::coalesce_batches::CoalesceBatchesExec;
+use crate::physical_plan::filter::FilterExec;
+use crate::physical_plan::repartition::RepartitionExec;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::ExecutionPlan;
+use arrow_schema::DataType;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::Result;
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::PhysicalSortExpr;
+use std::sync::Arc;
+
+/// An optimizer rule that passes a `limit` hint to aggregations if the whole
result is not needed
+pub struct TopKAggregation {}
+
+impl TopKAggregation {
+ /// Create a new `LimitAggregation`
+ pub fn new() -> Self {
+ Self {}
+ }
+
+ fn transform_agg(
+ aggr: &AggregateExec,
+ order: &PhysicalSortExpr,
+ limit: usize,
+ ) -> Option<Arc<dyn ExecutionPlan>> {
+ // ensure the sort direction matches aggregate function
+ let (field, desc) = aggr.get_minmax_desc()?;
+ if desc != order.options.descending {
+ return None;
+ }
+ let group_key = match aggr.group_expr().expr() {
+ [expr] => expr, // only one group key
+ _ => return None,
+ };
+ match group_key.0.data_type(&aggr.input_schema).ok() {
+ Some(DataType::Utf8) => {} // only String keys for now
+ _ => return None,
+ }
+ if aggr
+ .filter_expr
+ .iter()
+ .fold(false, |acc, cur| acc | cur.is_some())
+ {
+ return None;
+ }
+
+ // ensure the sort is on the same field as the aggregate output
+ let col = order.expr.as_any().downcast_ref::<Column>()?;
+ if col.name() != field.name() {
+ return None;
+ }
+
+ // We found what we want: clone, copy the limit down, and return
modified node
+ let mut new_aggr = AggregateExec::try_new(
+ aggr.mode,
+ aggr.group_by.clone(),
+ aggr.aggr_expr.clone(),
+ aggr.filter_expr.clone(),
+ aggr.order_by_expr.clone(),
+ aggr.input.clone(),
+ aggr.input_schema.clone(),
+ )
+ .expect("Unable to copy Aggregate!");
+ new_aggr.limit = Some(limit);
+ Some(Arc::new(new_aggr))
+ }
+
+ fn transform_sort(plan: Arc<dyn ExecutionPlan>) -> Option<Arc<dyn
ExecutionPlan>> {
+ let sort = plan.as_any().downcast_ref::<SortExec>()?;
+
+ // TODO: support sorting on multiple fields
+ let children = sort.children();
+ let child = match children.as_slice() {
+ [child] => child.clone(),
+ _ => return None,
+ };
+ let order = sort.output_ordering()?;
+ let order = match order {
+ [order] => order,
+ _ => return None,
+ };
+ let limit = sort.fetch()?;
+
+ let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| {
+ plan.as_any()
+ .downcast_ref::<CoalesceBatchesExec>()
+ .is_some()
+ || plan.as_any().downcast_ref::<RepartitionExec>().is_some()
+ || plan.as_any().downcast_ref::<FilterExec>().is_some()
+ // TODO: whitelist joins that don't increase row count?
+ };
+
+ let mut cardinality_preserved = true;
+ let mut closure = |plan: Arc<dyn ExecutionPlan>| {
+ if !cardinality_preserved {
+ return Ok(Transformed::No(plan));
+ }
+ if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
+ // either we run into an Aggregate and transform it
+ match Self::transform_agg(aggr, order, limit) {
+ None => cardinality_preserved = false,
+ Some(plan) => return Ok(Transformed::Yes(plan)),
+ }
+ } else {
+ // or we continue down whitelisted nodes of other types
+ if !is_cardinality_preserving(plan.clone()) {
+ cardinality_preserved = false;
+ }
+ }
+ Ok(Transformed::No(plan))
+ };
+ let child = transform_down_mut(child, &mut closure).ok()?;
+ let sort = SortExec::new(sort.expr().to_vec(), child)
+ .with_fetch(sort.fetch())
+ .with_preserve_partitioning(sort.preserve_partitioning());
+ Some(Arc::new(sort))
+ }
+}
+
+fn transform_down_mut<F>(
+ me: Arc<dyn ExecutionPlan>,
+ op: &mut F,
+) -> Result<Arc<dyn ExecutionPlan>>
+where
+ F: FnMut(Arc<dyn ExecutionPlan>) -> Result<Transformed<Arc<dyn
ExecutionPlan>>>,
+{
+ let after_op = op(me)?.into();
+ after_op.map_children(|node| transform_down_mut(node, op))
+}
+
+impl Default for TopKAggregation {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl PhysicalOptimizerRule for TopKAggregation {
+ fn optimize(
+ &self,
+ plan: Arc<dyn ExecutionPlan>,
+ config: &ConfigOptions,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ let plan = if config.optimizer.enable_topk_aggregation {
+ plan.transform_down(&|plan| {
+ Ok(
+ if let Some(plan) =
TopKAggregation::transform_sort(plan.clone()) {
+ Transformed::Yes(plan)
+ } else {
+ Transformed::No(plan)
+ },
+ )
+ })?
+ } else {
+ plan
+ };
+ Ok(plan)
+ }
+
+ fn name(&self) -> &str {
+ "LimitAggregation"
+ }
+
+ fn schema_check(&self) -> bool {
+ true
+ }
+}
+
+// TODO: tests
Review Comment:
👍
##########
datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,829 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to
a fixed number
+
+use crate::physical_plan::aggregates::{
+ aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
+ PhysicalGroupBy,
+};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use ahash::RandomState;
+use arrow::util::pretty::print_batches;
+use arrow_array::cast::AsArray;
+use arrow_array::downcast_primitive;
+use arrow_array::{
+ Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray,
RecordBatch,
+ StringArray,
+};
+use arrow_schema::{DataType, SchemaRef};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::PhysicalExpr;
+use futures::stream::{Stream, StreamExt};
+use hashbrown::raw::RawTable;
+use log::{trace, Level};
+use std::cmp::Ordering;
+use std::collections::BTreeSet;
+use std::hash::Hash;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+pub struct GroupedTopKAggregateStream {
Review Comment:
I think it would help to make the use case clear, as this is an optimization
for a fairly special case.

For example:
```suggestion
/// This is a special case for queries of the following form:
///
/// SELECT group_id, MAX(time) FROM t GROUP BY group_id ORDER BY MAX(time) DESC LIMIT k
/// SELECT group_id, MIN(time) FROM t GROUP BY group_id ORDER BY MIN(time) ASC LIMIT k
///
/// It maintains only the current top group_ids in memory and is therefore much more
/// memory efficient than the general-purpose group operator.
pub struct GroupedTopKAggregateStream {
```
##########
datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,820 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to
a fixed number
+
+use crate::physical_plan::aggregates::{
+ aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
+ PhysicalGroupBy,
+};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use ahash::RandomState;
+use arrow::util::pretty::print_batches;
+use arrow_array::cast::AsArray;
+use arrow_array::downcast_primitive;
+use arrow_array::{
+ Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray,
RecordBatch,
+ StringArray,
+};
+use arrow_schema::{DataType, SchemaRef, SortOptions};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::PhysicalExpr;
+use futures::stream::{Stream, StreamExt};
+use hashbrown::raw::RawTable;
+use log::{trace, Level};
+use std::cmp::Ordering;
+use std::collections::BTreeSet;
+use std::hash::{BuildHasher, Hash, Hasher};
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+pub struct GroupedTopKAggregateStream {
+ partition: usize,
+ row_count: usize,
+ started: bool,
+ schema: SchemaRef,
+ input: SendableRecordBatchStream,
+ aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+ group_by: PhysicalGroupBy,
+ aggregator: Box<dyn LimitedAggregator>,
+}
+
+impl GroupedTopKAggregateStream {
+ pub fn new(
+ agg: &AggregateExec,
+ context: Arc<TaskContext>,
+ partition: usize,
+ limit: usize,
+ ) -> Result<Self> {
+ let agg_schema = Arc::clone(&agg.schema);
+ let group_by = agg.group_by.clone();
+
+ let input = agg.input.execute(partition, Arc::clone(&context))?;
+
+ let aggregate_arguments =
+ aggregate_expressions(&agg.aggr_expr, &agg.mode,
group_by.expr.len())?;
+
+ let (val_field, _) = agg
+ .get_minmax_desc()
+ .ok_or_else(|| DataFusionError::Execution("Min/max
required".to_string()))?;
+
+ let vt = val_field.data_type().clone();
+ let ag = new_group_values(limit, agg.sort, vt)?;
+
+ Ok(GroupedTopKAggregateStream {
+ partition,
+ started: false,
+ row_count: 0,
+ schema: agg_schema,
+ input,
+ aggregate_arguments,
+ group_by,
+ aggregator: ag,
+ })
+ }
+}
+
+pub fn new_group_values(
+ limit: usize,
+ sort: SortOptions,
+ vt: DataType,
+) -> Result<Box<dyn LimitedAggregator>> {
+ macro_rules! downcast_helper {
+ ($vt:ty, $d:ident) => {
+ return Ok(Box::new(PrimitiveAggregator::<$vt>::new(
+ limit,
+ limit * 10,
+ sort,
+ )))
+ };
+ }
+
+ downcast_primitive! {
+ vt => (downcast_helper, vt),
+ _ => {}
+ }
+
+ Err(DataFusionError::Execution(format!(
+ "Can't group type: {vt:?}"
+ )))
+}
+
+impl RecordBatchStream for GroupedTopKAggregateStream {
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+}
+
+pub trait LimitedAggregator: Send {
+ fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()>;
+ fn emit(&mut self) -> Result<Vec<ArrayRef>>;
+ fn is_empty(&self) -> bool;
+}
+
+pub trait ValueType: ArrowNativeTypeOp + Clone {}
+
+impl<T> ValueType for T where T: ArrowNativeTypeOp + Clone {}
+
+pub trait KeyType: Clone + Eq + Hash {}
+
+impl<T> KeyType for T where T: Clone + Eq + Hash {}
+
+struct PrimitiveAggregator<VAL: ArrowPrimitiveType>
+where
+ <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+ priority_map: PriorityMap<Option<String>, VAL::Native>,
Review Comment:
I think using the Arrow Row format is another possibility (it would also
allow multiple group keys to be supported, though I am not sure that is an
important use case).
##########
datafusion/core/src/physical_plan/aggregates/priority_map.rs:
##########
@@ -0,0 +1,829 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A memory-conscious aggregation implementation that limits group buckets to
a fixed number
+
+use crate::physical_plan::aggregates::{
+ aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
+ PhysicalGroupBy,
+};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use ahash::RandomState;
+use arrow::util::pretty::print_batches;
+use arrow_array::cast::AsArray;
+use arrow_array::downcast_primitive;
+use arrow_array::{
+ Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray,
RecordBatch,
+ StringArray,
+};
+use arrow_schema::{DataType, SchemaRef};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::PhysicalExpr;
+use futures::stream::{Stream, StreamExt};
+use hashbrown::raw::RawTable;
+use log::{trace, Level};
+use std::cmp::Ordering;
+use std::collections::BTreeSet;
+use std::hash::Hash;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+pub struct GroupedTopKAggregateStream {
+ partition: usize,
+ row_count: usize,
+ started: bool,
+ schema: SchemaRef,
+ input: SendableRecordBatchStream,
+ aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+ group_by: PhysicalGroupBy,
+ aggregator: Box<dyn LimitedAggregator>,
+}
+
+impl GroupedTopKAggregateStream {
+ pub fn new(
+ agg: &AggregateExec,
+ context: Arc<TaskContext>,
+ partition: usize,
+ limit: usize,
+ ) -> Result<Self> {
+ let agg_schema = Arc::clone(&agg.schema);
+ let group_by = agg.group_by.clone();
+
+ let input = agg.input.execute(partition, Arc::clone(&context))?;
+
+ let aggregate_arguments =
+ aggregate_expressions(&agg.aggr_expr, &agg.mode,
group_by.expr.len())?;
+
+ let (val_field, descending) = agg
+ .get_minmax_desc()
+ .ok_or_else(|| DataFusionError::Execution("Min/max
required".to_string()))?;
+
+ let vt = val_field.data_type().clone();
+ let ag = new_group_values(limit, descending, vt)?;
+
+ Ok(GroupedTopKAggregateStream {
+ partition,
+ started: false,
+ row_count: 0,
+ schema: agg_schema,
+ input,
+ aggregate_arguments,
+ group_by,
+ aggregator: ag,
+ })
+ }
+}
+
+pub fn new_group_values(
+ limit: usize,
+ desc: bool,
+ vt: DataType,
+) -> Result<Box<dyn LimitedAggregator>> {
+ macro_rules! downcast_helper {
+ ($vt:ty, $d:ident) => {
+ return Ok(Box::new(PrimitiveAggregator::<$vt>::new(
+ limit,
+ limit * 10,
+ desc,
+ )))
+ };
+ }
+
+ downcast_primitive! {
+ vt => (downcast_helper, vt),
+ _ => {}
+ }
+
+ Err(DataFusionError::Execution(format!(
+ "Can't group type: {vt:?}"
+ )))
+}
+
+impl RecordBatchStream for GroupedTopKAggregateStream {
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+}
+
+pub trait LimitedAggregator: Send {
+ fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()>;
+ fn emit(&mut self) -> Result<Vec<ArrayRef>>;
+ fn is_empty(&self) -> bool;
+}
+
+pub trait ValueType: ArrowNativeTypeOp + Clone {}
+
+impl<T> ValueType for T where T: ArrowNativeTypeOp + Clone {}
+
+pub trait KeyType: Clone + Eq + Hash {}
+
+impl<T> KeyType for T where T: Clone + Eq + Hash {}
+
+struct PrimitiveAggregator<VAL: ArrowPrimitiveType>
+where
+ <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+ priority_map: PriorityMap<Option<String>, VAL::Native>,
+}
+
+impl<VAL: ArrowPrimitiveType> PrimitiveAggregator<VAL>
+where
+ <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+ pub fn new(limit: usize, capacity: usize, descending: bool) -> Self {
+ Self {
+ priority_map: PriorityMap::new(limit, capacity, descending),
+ }
+ }
+}
+
+unsafe impl<VAL: ArrowPrimitiveType> Send for PrimitiveAggregator<VAL> where
+ <VAL as ArrowPrimitiveType>::Native: Clone
+{
+}
+
+impl<VAL: ArrowPrimitiveType> LimitedAggregator for PrimitiveAggregator<VAL>
+where
+ <VAL as ArrowPrimitiveType>::Native: Clone,
+{
+ fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
+ let ids = ids.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
+ DataFusionError::Execution("Expected StringArray".to_string())
+ })?;
+ let vals = vals.as_primitive::<VAL>();
+ let null_count = vals.null_count();
+ for row_idx in 0..ids.len() {
+ if null_count > 0 && vals.is_null(row_idx) {
+ continue;
+ }
+ let val = vals.value(row_idx);
+ let id = if ids.is_null(row_idx) {
+ None
+ } else {
+ // Check goes here, because it is generalizable between
str/String and Row/OwnedRow
+ let id = ids.value(row_idx);
+ if self.priority_map.is_full() {
+ if self.priority_map.desc {
+ if let Some(worst) = self.priority_map.min_val() {
+ if val < *worst {
+ continue;
+ }
+ }
+ } else if let Some(worst) = self.priority_map.max_val() {
+ if val > *worst {
+ continue;
+ }
+ }
+ }
+ Some(id.to_string())
+ };
+
+ self.priority_map.insert(id, val)?;
+ }
+ Ok(())
+ }
+
+ fn emit(&mut self) -> Result<Vec<ArrayRef>> {
+ let (keys, vals): (Vec<_>, Vec<_>) =
+ self.priority_map.drain().into_iter().unzip();
+ let keys = Arc::new(StringArray::from(keys));
+ let vals = Arc::new(PrimitiveArray::<VAL>::from_iter_values(vals));
+ Ok(vec![keys, vals])
+ }
+
+ fn is_empty(&self) -> bool {
+ self.priority_map.is_empty()
+ }
+}
+
+/// Streams the grouped top-K aggregate: every input batch is handed to the
+/// aggregator, and output is produced only after the input stream returns `None`.
+impl Stream for GroupedTopKAggregateStream {
+ type Item = Result<RecordBatch>;
+
+ /// Drains ready input batches into the aggregator. Once the input is exhausted,
+ /// it returns either `None` (nothing aggregated) or one batch built from `aggregator.emit()`.
+ ///
+ /// NOTE(review): after an output batch is returned, the next call polls `self.input`
+ /// again even though it already yielded `None`. The `Stream` contract does not
+ /// guarantee that is safe unless the input is fused. Also confirm that `emit()` leaves
+ /// the aggregator empty so the second pass returns `None` rather than re-emitting.
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ while let Poll::Ready(res) = self.input.poll_next_unpin(cx) {
+ match res {
+ // got a batch: evaluate its group keys and aggregate inputs, then hand
+ // them to the aggregator
+ Some(Ok(batch)) => {
+ // NOTE(review): `started` is not read anywhere in this block
+ self.started = true;
+ trace!(
+ "partition {} has {} rows and got batch with {} rows",
+ self.partition,
+ self.row_count,
+ batch.num_rows()
+ );
+ // at trace level, only dump batches with fewer than 20 rows
+ if log::log_enabled!(Level::Trace) && batch.num_rows() <
20 {
+ print_batches(&[batch.clone()])?;
+ }
+ self.row_count += batch.num_rows();
+ let batches = &[batch];
+ let group_by_values =
+ evaluate_group_by(&self.group_by,
batches.first().unwrap())?;
+ // keep only the last grouping set, then only its last column --
+ // presumably this stream supports exactly one group-by column; TODO confirm
+ let group_by_values =
+ group_by_values.into_iter().last().expect("values");
+ let group_by_values =
+ group_by_values.into_iter().last().expect("values");
+ let input_values = evaluate_many(
+ &self.aggregate_arguments,
+ batches.first().unwrap(),
+ )?;
+ // exactly one entry is required here (presumably one per aggregate
+ // expression); zero or several is an execution error
+ let input_values = match input_values.as_slice() {
+ [] => {
+ Err(DataFusionError::Execution("vals
required".to_string()))?
+ }
+ [vals] => vals,
+ _ => {
+ Err(DataFusionError::Execution("1 val
required".to_string()))?
+ }
+ };
+ // ...and that entry must itself hold exactly one argument array
+ let input_values = match input_values.as_slice() {
+ [] => {
+ Err(DataFusionError::Execution("vals
required".to_string()))?
+ }
+ [vals] => vals,
+ _ => {
+ Err(DataFusionError::Execution("1 val
required".to_string()))?
+ }
+ }
+ .clone();
+
+ // hand this batch's group keys and aggregate input to the aggregator
+ (*self.aggregator).intern(group_by_values, input_values)?;
+ }
+ // input exhausted: return `None` if nothing was aggregated, otherwise
+ // emit the aggregator's contents as a single output batch
+ None => {
+ if self.aggregator.is_empty() {
+ trace!("partition {} emit None", self.partition);
+ return Poll::Ready(None);
+ }
+ let cols = self.aggregator.emit()?;
+ let batch = RecordBatch::try_new(self.schema.clone(),
cols)?;
+ trace!(
+ "partition {} emit batch with {} rows",
+ self.partition,
+ batch.num_rows()
+ );
+ if log::log_enabled!(Level::Trace) {
+ print_batches(&[batch.clone()])?;
+ }
+ return Poll::Ready(Some(Ok(batch)));
+ }
+ // inner had error, return to caller
+ Some(Err(e)) => {
+ return Poll::Ready(Some(Err(e)));
+ }
+ }
+ }
+ // the input is not ready; per the `Stream` contract its poll registered `cx`'s waker
+ Poll::Pending
+ }
+}
+
+/// A dual data structure consisting of a bi-directionally linked Map & Heap
+///
+/// The implementation is optimized for performance because `insert()` will be called on
+/// billions of rows. Because traversal between the map and the heap happens frequently,
+/// it must be fast.
+///
+/// To traverse quickly from heap to map, each heap entry stores the raw bucket index that
+/// `RawTable` exposes, so the bucket does not have to be found again from its hash (this
+/// relies on `RawTable`'s unsafe raw-index API).
+///
+/// Presently this implementation does not traverse quickly from map to heap; instead it
+/// performs a B-ary search of the `BTreeSet`. In the future this could be eliminated by
+/// implementing a custom binary heap with pointers from the map into the heap.
+pub struct PriorityMap<ID: KeyType, VAL: ValueType> {
+ /// maximum number of entries retained; `insert` asserts `len() <= limit`
+ limit: usize,
+ /// when true, the smallest value (`val_to_idx.first()`) is the "worst" entry, so the
+ /// largest values are kept; when false, the largest value (`last()`) is the worst
+ desc: bool,
+ /// hasher state used to hash ids (`rnd.hash_one(&id)`)
+ rnd: RandomState,
+ /// the "map" side: one `MapItem` (id + current value) per retained group
+ id_to_val: RawTable<MapItem<ID, VAL>>,
+ /// the "heap" side: entries ordered by value, each pointing back to its `id_to_val` bucket
+ val_to_idx: BTreeSet<HeapItem<VAL>>,
+}
+
+/// An entry on the map side of a `PriorityMap`: a group id and its current value.
+pub struct MapItem<ID: KeyType, VAL: ValueType> {
+ /// hash of `id`, stored alongside it -- presumably so `RawTable` can rehash
+ /// without recomputing it; TODO confirm against the insert path
+ hash: u64,
+ /// the group key
+ pub id: ID,
+ /// the group's current aggregate value
+ pub val: VAL,
+}
+
+impl<ID: KeyType, VAL: ValueType> MapItem<ID, VAL> {
+ /// Builds an entry from a precomputed `hash` of `id`, the id itself, and its value.
+ pub fn new(hash: u64, id: ID, val: VAL) -> Self {
+ Self { hash, id, val }
+ }
+}
+
+/// An entry on the "heap" side of a `PriorityMap`: a value plus a back-pointer to the
+/// map bucket that holds the matching `MapItem`.
+struct HeapItem<VAL: ValueType> {
+ /// the value this entry is ordered by
+ val: VAL,
+ /// raw `RawTable` bucket index of the corresponding `MapItem` in `id_to_val`
+ buk_idx: usize,
+}
+
+impl<VAL: ValueType> HeapItem<VAL> {
+ /// Builds a heap entry for `val` whose map entry lives in bucket `buk_idx`.
+ pub fn new(val: VAL, buk_idx: usize) -> Self {
+ Self { val, buk_idx }
+ }
+}
+
+// Equality and ordering are all derived from `Ord::cmp` below, so `Eq`, `PartialEq`,
+// `PartialOrd` and `Ord` always agree -- which `BTreeSet` relies on.
+impl<VAL: ValueType> Eq for HeapItem<VAL> {}
+
+impl<VAL: ValueType> PartialEq<Self> for HeapItem<VAL> {
+ /// Two entries are equal only if both their value and their bucket index match.
+ fn eq(&self, other: &Self) -> bool {
+ self.cmp(other) == Ordering::Equal
+ }
+}
+
+impl<VAL: ValueType> PartialOrd<Self> for HeapItem<VAL> {
+ /// Total order; always `Some(self.cmp(other))`.
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+impl<VAL: ValueType> Ord for HeapItem<VAL> {
+ /// Orders by value (via `ValueType::compare`), then by bucket index.
+ fn cmp(&self, other: &Self) -> Ordering {
+ let res = self.val.compare(other.val);
+ if res != Ordering::Equal {
+ return res;
+ }
+ // tie-break on the bucket index so that different groups holding the same value
+ // stay distinct entries instead of being merged by the `BTreeSet`
+ self.buk_idx.cmp(&other.buk_idx)
+ }
+}
+
+impl<ID: KeyType, VAL: ValueType> PriorityMap<ID, VAL>
+where
+ VAL: PartialEq<VAL>,
+{
+ pub fn new(limit: usize, capacity: usize, desc: bool) -> Self {
+ Self {
+ limit,
+ desc,
+ rnd: Default::default(),
+ id_to_val: RawTable::with_capacity(capacity),
+ val_to_idx: Default::default(),
+ }
+ }
+
+ pub fn insert(&mut self, new_id: ID, new_val: VAL) -> Result<()> {
+ let is_full = self.is_full();
+ let desc = self.desc;
+ assert!(self.len() <= self.limit, "Overflow");
+ let (id_to_val, val_to_idx) = (&mut self.id_to_val, &mut
self.val_to_idx);
+
+ // if we're full, and the new val is worse than all our values, just
bail
+ if is_full {
+ let worst_entry = if desc {
+ val_to_idx.first()
+ } else {
+ val_to_idx.last()
+ }
+ .expect("Missing value!");
+ if (!desc && new_val > worst_entry.val) || (desc && new_val <
worst_entry.val)
+ {
+ return Ok(());
+ }
+ }
+
+ // handle new groups we haven't seen yet
+ let new_hash = self.rnd.hash_one(&new_id);
Review Comment:
I haven't reviewed this implementation carefully yet, but the basic idea makes
sense to me. I can review it more carefully if @Dandandan and @thinkharderdev
haven't done so.
I still believe (without evidence yet) that I can use the same code from
https://github.com/apache/arrow-datafusion/pull/7250 and get very similar
performance, but I can do that as a follow-on PR in parallel with this one.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]