ozankabak commented on code in PR #7192: URL: https://github.com/apache/arrow-datafusion/pull/7192#discussion_r1306157419
########## datafusion/core/src/physical_optimizer/topk_aggregation.rs: ########## @@ -0,0 +1,187 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An optimizer rule that detects aggregate operations that could use a limited bucket count + +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::aggregates::AggregateExec; +use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; +use crate::physical_plan::filter::FilterExec; +use crate::physical_plan::repartition::RepartitionExec; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::ExecutionPlan; +use arrow_schema::DataType; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::PhysicalSortExpr; +use std::sync::Arc; + +/// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed +pub struct TopKAggregation {} + +impl TopKAggregation { + /// Create a new `LimitAggregation` + pub fn new() -> Self { + Self {} + } + + fn transform_agg( + aggr: &AggregateExec, + order: &PhysicalSortExpr, + limit: usize, + ) -> Option<Arc<dyn ExecutionPlan>> { + // ensure the sort direction matches aggregate function + let (field, desc) = aggr.get_minmax_desc()?; + if desc != order.options.descending { + return None; + } + let group_key = match aggr.group_expr().expr() { + [expr] => expr, // only one group key + _ => return None, + }; + let kt = group_key.0.data_type(&aggr.input().schema()).ok()?; + if !kt.is_primitive() && kt != DataType::Utf8 { + return None; // TODO: other types? + } + if aggr.filter_expr.iter().any(|e| e.is_some()) { + return None; + } + + // ensure the sort is on the same field as the aggregate output + let col = order.expr.as_any().downcast_ref::<Column>()?; + if col.name() != field.name() { + return None; + } + + // We found what we want: clone, copy the limit down, and return modified node + let mut new_aggr = AggregateExec::try_new( + aggr.mode, + aggr.group_by.clone(), + aggr.aggr_expr.clone(), + aggr.filter_expr.clone(), + aggr.order_by_expr.clone(), + aggr.input.clone(), + aggr.input_schema.clone(), + ) + .expect("Unable to copy Aggregate!"); + new_aggr.limit = Some(limit); + Some(Arc::new(new_aggr)) + } + + fn transform_sort(plan: Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> { + let sort = plan.as_any().downcast_ref::<SortExec>()?; + + // TODO: support sorting on multiple fields + let children = sort.children(); + let child = match children.as_slice() { + [child] => child.clone(), + _ => return None, + }; + let order = sort.output_ordering()?; + let order = match order { + [order] => order, + _ => return None, + }; + let limit = sort.fetch()?; + + let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| { + plan.as_any() + .downcast_ref::<CoalesceBatchesExec>() + .is_some() + || plan.as_any().downcast_ref::<RepartitionExec>().is_some() + || plan.as_any().downcast_ref::<FilterExec>().is_some() + // TODO: whitelist joins that don't increase row count? + }; + + let mut cardinality_preserved = true; + let mut closure = |plan: Arc<dyn ExecutionPlan>| { + if !cardinality_preserved { + return Ok(Transformed::No(plan)); + } + if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() { + // either we run into an Aggregate and transform it + match Self::transform_agg(aggr, order, limit) { + None => cardinality_preserved = false, + Some(plan) => return Ok(Transformed::Yes(plan)), + } + } else { + // or we continue down whitelisted nodes of other types + if !is_cardinality_preserving(plan.clone()) { + cardinality_preserved = false; + } + } + Ok(Transformed::No(plan)) + }; + let child = transform_down_mut(child, &mut closure).ok()?; + let sort = SortExec::new(sort.expr().to_vec(), child) Review Comment: > We can do it in this PR @ log2(limit) by draining the heap I think this makes sense -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
