This is an automated email from the ASF dual-hosted git repository.

avantgardner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new c412c74a7e Fix bug in TopK aggregates (#12766)
c412c74a7e is described below

commit c412c74a7e6874ee09b1749aec6fd2a0bc72faed
Author: Brent Gardner <[email protected]>
AuthorDate: Tue Oct 8 10:41:34 2024 -0600

    Fix bug in TopK aggregates (#12766)
    
    Fix bug in TopK aggregates (#12766)
---
 .../physical-optimizer/src/topk_aggregation.rs     | 46 ++++++++++++----------
 datafusion/physical-plan/src/aggregates/mod.rs     |  5 +++
 datafusion/physical-plan/src/coalesce_batches.rs   |  5 +++
 .../physical-plan/src/coalesce_partitions.rs       |  5 +++
 datafusion/physical-plan/src/execution_plan.rs     | 19 +++++++++
 datafusion/physical-plan/src/filter.rs             |  5 +++
 datafusion/physical-plan/src/limit.rs              |  5 +++
 datafusion/physical-plan/src/projection.rs         |  5 +++
 datafusion/physical-plan/src/repartition/mod.rs    |  5 +++
 datafusion/physical-plan/src/sorts/sort.rs         |  9 +++++
 .../sqllogictest/test_files/aggregates_topk.slt    | 11 ++++++
 11 files changed, 100 insertions(+), 20 deletions(-)

diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs
index 804dd165d3..5dec99535c 100644
--- a/datafusion/physical-optimizer/src/topk_aggregation.rs
+++ b/datafusion/physical-optimizer/src/topk_aggregation.rs
@@ -20,9 +20,6 @@
 use std::sync::Arc;
 
 use datafusion_physical_plan::aggregates::AggregateExec;
-use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec;
-use datafusion_physical_plan::filter::FilterExec;
-use datafusion_physical_plan::repartition::RepartitionExec;
 use datafusion_physical_plan::sorts::sort::SortExec;
 use datafusion_physical_plan::ExecutionPlan;
 
@@ -31,9 +28,10 @@ use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::Result;
 use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::PhysicalSortExpr;
 
 use crate::PhysicalOptimizerRule;
+use datafusion_physical_plan::execution_plan::CardinalityEffect;
+use datafusion_physical_plan::projection::ProjectionExec;
 use itertools::Itertools;
 
 /// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed
@@ -48,12 +46,13 @@ impl TopKAggregation {
 
     fn transform_agg(
         aggr: &AggregateExec,
-        order: &PhysicalSortExpr,
+        order_by: &str,
+        order_desc: bool,
         limit: usize,
     ) -> Option<Arc<dyn ExecutionPlan>> {
         // ensure the sort direction matches aggregate function
         let (field, desc) = aggr.get_minmax_desc()?;
-        if desc != order.options.descending {
+        if desc != order_desc {
             return None;
         }
         let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
@@ -66,8 +65,7 @@ impl TopKAggregation {
         }
 
         // ensure the sort is on the same field as the aggregate output
-        let col = order.expr.as_any().downcast_ref::<Column>()?;
-        if col.name() != field.name() {
+        if order_by != field.name() {
             return None;
         }
 
@@ -92,16 +90,11 @@ impl TopKAggregation {
         let child = children.into_iter().exactly_one().ok()?;
         let order = sort.properties().output_ordering()?;
         let order = order.iter().exactly_one().ok()?;
+        let order_desc = order.options.descending;
+        let order = order.expr.as_any().downcast_ref::<Column>()?;
+        let mut cur_col_name = order.name().to_string();
         let limit = sort.fetch()?;
 
-        let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| {
-            plan.as_any()
-                .downcast_ref::<CoalesceBatchesExec>()
-                .is_some()
-                || plan.as_any().downcast_ref::<RepartitionExec>().is_some()
-                || plan.as_any().downcast_ref::<FilterExec>().is_some()
-        };
-
         let mut cardinality_preserved = true;
         let closure = |plan: Arc<dyn ExecutionPlan>| {
             if !cardinality_preserved {
@@ -109,14 +102,27 @@ impl TopKAggregation {
             }
             if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
                 // either we run into an Aggregate and transform it
-                match Self::transform_agg(aggr, order, limit) {
+                match Self::transform_agg(aggr, &cur_col_name, order_desc, limit) {
                     None => cardinality_preserved = false,
                     Some(plan) => return Ok(Transformed::yes(plan)),
                 }
+            } else if let Some(proj) = plan.as_any().downcast_ref::<ProjectionExec>() {
+                // track renames due to successive projections
+                for (src_expr, proj_name) in proj.expr() {
+                    let Some(src_col) = src_expr.as_any().downcast_ref::<Column>() else {
+                        continue;
+                    };
+                    if *proj_name == cur_col_name {
+                        cur_col_name = src_col.name().to_string();
+                    }
+                }
             } else {
-                // or we continue down whitelisted nodes of other types
-                if !is_cardinality_preserving(Arc::clone(&plan)) {
-                    cardinality_preserved = false;
+                // or we continue down through types that don't reduce cardinality
+                match plan.cardinality_effect() {
+                    CardinalityEffect::Equal | CardinalityEffect::GreaterEqual => {}
+                    CardinalityEffect::Unknown | CardinalityEffect::LowerEqual => {
+                        cardinality_preserved = false;
+                    }
                 }
             }
             Ok(Transformed::no(plan))
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index f9dd973c81..d6f16fb0fd 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -48,6 +48,7 @@ use datafusion_physical_expr::{
     PhysicalExpr, PhysicalSortRequirement,
 };
 
+use crate::execution_plan::CardinalityEffect;
 use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
 use itertools::Itertools;
 
@@ -866,6 +867,10 @@ impl ExecutionPlan for AggregateExec {
             }
         }
     }
+
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        CardinalityEffect::LowerEqual
+    }
 }
 
 fn create_schema(
diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs
index 7caf5b8ab6..e1a2f32d8a 100644
--- a/datafusion/physical-plan/src/coalesce_batches.rs
+++ b/datafusion/physical-plan/src/coalesce_batches.rs
@@ -34,6 +34,7 @@ use datafusion_common::Result;
 use datafusion_execution::TaskContext;
 
 use crate::coalesce::{BatchCoalescer, CoalescerState};
+use crate::execution_plan::CardinalityEffect;
 use futures::ready;
 use futures::stream::{Stream, StreamExt};
 
@@ -199,6 +200,10 @@ impl ExecutionPlan for CoalesceBatchesExec {
     fn fetch(&self) -> Option<usize> {
         self.fetch
     }
+
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        CardinalityEffect::Equal
+    }
 }
 
 /// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details.
diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs
index 486ae41901..2ab6e3de1a 100644
--- a/datafusion/physical-plan/src/coalesce_partitions.rs
+++ b/datafusion/physical-plan/src/coalesce_partitions.rs
@@ -30,6 +30,7 @@ use super::{
 
 use crate::{DisplayFormatType, ExecutionPlan, Partitioning};
 
+use crate::execution_plan::CardinalityEffect;
 use datafusion_common::{internal_err, Result};
 use datafusion_execution::TaskContext;
 
@@ -178,6 +179,10 @@ impl ExecutionPlan for CoalescePartitionsExec {
     fn supports_limit_pushdown(&self) -> bool {
         true
     }
+
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        CardinalityEffect::Equal
+    }
 }
 
 #[cfg(test)]
diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs
index b14021f4a9..a89e265ad2 100644
--- a/datafusion/physical-plan/src/execution_plan.rs
+++ b/datafusion/physical-plan/src/execution_plan.rs
@@ -416,6 +416,11 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync {
     fn fetch(&self) -> Option<usize> {
         None
     }
+
+    /// Gets the effect on cardinality, if known
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        CardinalityEffect::Unknown
+    }
 }
 
 /// Extension trait provides an easy API to fetch various properties of
@@ -898,6 +903,20 @@ pub fn get_plan_string(plan: &Arc<dyn ExecutionPlan>) -> Vec<String> {
     actual.iter().map(|elem| elem.to_string()).collect()
 }
 
+/// Indicates the effect an execution plan operator will have on the cardinality
+/// of its input stream
+pub enum CardinalityEffect {
+    /// Unknown effect. This is the default
+    Unknown,
+    /// The operator is guaranteed to produce exactly one row for
+    /// each input row
+    Equal,
+    /// The operator may produce fewer output rows than it receives input rows
+    LowerEqual,
+    /// The operator may produce more output rows than it receives input rows
+    GreaterEqual,
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs
index 417d2098b0..c39a91e251 100644
--- a/datafusion/physical-plan/src/filter.rs
+++ b/datafusion/physical-plan/src/filter.rs
@@ -48,6 +48,7 @@ use datafusion_physical_expr::{
     analyze, split_conjunction, AnalysisContext, ConstExpr, ExprBoundaries, PhysicalExpr,
 };
 
+use crate::execution_plan::CardinalityEffect;
 use futures::stream::{Stream, StreamExt};
 use log::trace;
 
@@ -372,6 +373,10 @@ impl ExecutionPlan for FilterExec {
     fn statistics(&self) -> Result<Statistics> {
         Self::statistics_helper(&self.input, self.predicate(), self.default_selectivity)
     }
+
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        CardinalityEffect::LowerEqual
+    }
 }
 
 /// This function ensures that all bounds in the `ExprBoundaries` vector are
diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs
index 360e942226..a42e2da605 100644
--- a/datafusion/physical-plan/src/limit.rs
+++ b/datafusion/physical-plan/src/limit.rs
@@ -34,6 +34,7 @@ use arrow::record_batch::RecordBatch;
 use datafusion_common::{internal_err, Result};
 use datafusion_execution::TaskContext;
 
+use crate::execution_plan::CardinalityEffect;
 use futures::stream::{Stream, StreamExt};
 use log::trace;
 
@@ -336,6 +337,10 @@ impl ExecutionPlan for LocalLimitExec {
     fn supports_limit_pushdown(&self) -> bool {
         true
     }
+
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        CardinalityEffect::LowerEqual
+    }
 }
 
 /// A Limit stream skips `skip` rows, and then fetch up to `fetch` rows.
diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs
index 4c889d1fc8..49bf059642 100644
--- a/datafusion/physical-plan/src/projection.rs
+++ b/datafusion/physical-plan/src/projection.rs
@@ -42,6 +42,7 @@ use datafusion_execution::TaskContext;
 use datafusion_physical_expr::equivalence::ProjectionMapping;
 use datafusion_physical_expr::expressions::Literal;
 
+use crate::execution_plan::CardinalityEffect;
 use futures::stream::{Stream, StreamExt};
 use log::trace;
 
@@ -233,6 +234,10 @@ impl ExecutionPlan for ProjectionExec {
     fn supports_limit_pushdown(&self) -> bool {
         true
     }
+
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        CardinalityEffect::Equal
+    }
 }
 
 /// If e is a direct column reference, returns the field level
diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index f0f198319e..d9368cf86d 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -48,6 +48,7 @@ use datafusion_execution::memory_pool::MemoryConsumer;
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr};
 
+use crate::execution_plan::CardinalityEffect;
 use futures::stream::Stream;
 use futures::{FutureExt, StreamExt, TryStreamExt};
 use hashbrown::HashMap;
@@ -669,6 +670,10 @@ impl ExecutionPlan for RepartitionExec {
     fn statistics(&self) -> Result<Statistics> {
         self.input.statistics()
     }
+
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        CardinalityEffect::Equal
+    }
 }
 
 impl RepartitionExec {
diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index 50f6f4a930..5d86c2183b 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -55,6 +55,7 @@ use datafusion_execution::TaskContext;
 use datafusion_physical_expr::LexOrdering;
 use datafusion_physical_expr_common::sort_expr::PhysicalSortRequirement;
 
+use crate::execution_plan::CardinalityEffect;
 use futures::{StreamExt, TryStreamExt};
 use log::{debug, trace};
 
@@ -972,6 +973,14 @@ impl ExecutionPlan for SortExec {
     fn fetch(&self) -> Option<usize> {
         self.fetch
     }
+
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        if self.fetch.is_none() {
+            CardinalityEffect::Equal
+        } else {
+            CardinalityEffect::LowerEqual
+        }
+    }
 }
 
 #[cfg(test)]
diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt b/datafusion/sqllogictest/test_files/aggregates_topk.slt
index 2209edc5d1..a67fec695f 100644
--- a/datafusion/sqllogictest/test_files/aggregates_topk.slt
+++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt
@@ -53,6 +53,11 @@ physical_plan
 07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces.timestamp)]
 08)--------------MemoryExec: partitions=1, partition_sizes=[1]
 
+query TI
+select * from (select trace_id, MAX(timestamp) max_ts from traces t group by trace_id) where trace_id != 'b' order by max_ts desc limit 3;
+----
+c 4
+a 1
 
 query TI
 select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4;
@@ -89,6 +94,12 @@ c 1 2
 statement ok
 set datafusion.optimizer.enable_topk_aggregation = true;
 
+query TI
+select * from (select trace_id, MAX(timestamp) max_ts from traces t group by trace_id) where max_ts != 3 order by max_ts desc limit 2;
+----
+c 4
+a 1
+
 query TT
 explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4;
 ----


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to