This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 98647e842a Stop copying LogicalPlan and Exprs in `PushDownLimit`
(#10508)
98647e842a is described below
commit 98647e842a85b768ea0cb0f8ccf1016636001abb
Author: Andrew Lamb <[email protected]>
AuthorDate: Fri May 17 04:14:32 2024 -0400
Stop copying LogicalPlan and Exprs in `PushDownLimit` (#10508)
* Stop copying LogicalPlan and Exprs in `PushDownLimit`
* Refine make_limit
---
datafusion/optimizer/src/push_down_limit.rs | 275 +++++++++++++++-------------
1 file changed, 149 insertions(+), 126 deletions(-)
diff --git a/datafusion/optimizer/src/push_down_limit.rs
b/datafusion/optimizer/src/push_down_limit.rs
index 9190881335..b97dff74d9 100644
--- a/datafusion/optimizer/src/push_down_limit.rs
+++ b/datafusion/optimizer/src/push_down_limit.rs
@@ -23,11 +23,10 @@ use std::sync::Arc;
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
-use datafusion_common::Result;
-use datafusion_expr::logical_plan::{
- Join, JoinType, Limit, LogicalPlan, Sort, TableScan, Union,
-};
-use datafusion_expr::CrossJoin;
+use datafusion_common::tree_node::Transformed;
+use datafusion_common::{internal_err, Result};
+use datafusion_expr::logical_plan::tree_node::unwrap_arc;
+use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan};
/// Optimization rule that tries to push down `LIMIT`.
///
@@ -46,131 +45,120 @@ impl PushDownLimit {
impl OptimizerRule for PushDownLimit {
fn try_optimize(
&self,
- plan: &LogicalPlan,
+ _plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
- use std::cmp::min;
+ internal_err!("Should have called PushDownLimit::rewrite")
+ }
+
+ fn supports_rewrite(&self) -> bool {
+ true
+ }
- let LogicalPlan::Limit(limit) = plan else {
- return Ok(None);
+ fn rewrite(
+ &self,
+ plan: LogicalPlan,
+ _config: &dyn OptimizerConfig,
+ ) -> Result<Transformed<LogicalPlan>> {
+ let LogicalPlan::Limit(mut limit) = plan else {
+ return Ok(Transformed::no(plan));
};
- if let LogicalPlan::Limit(child) = &*limit.input {
- // Merge the Parent Limit and the Child Limit.
+ let Limit { skip, fetch, input } = limit;
+ let input = input;
+
+ // Merge the Parent Limit and the Child Limit.
+ if let LogicalPlan::Limit(child) = input.as_ref() {
let (skip, fetch) =
combine_limit(limit.skip, limit.fetch, child.skip,
child.fetch);
let plan = LogicalPlan::Limit(Limit {
skip,
fetch,
- input: Arc::new((*child.input).clone()),
+ input: Arc::clone(&child.input),
});
- return self
- .try_optimize(&plan, _config)
- .map(|opt_plan| opt_plan.or_else(|| Some(plan)));
+
+ // recursively reapply the rule on the new plan
+ return self.rewrite(plan, _config);
}
- let Some(fetch) = limit.fetch else {
- return Ok(None);
+ // no fetch to push, so return the original plan
+ let Some(fetch) = fetch else {
+ return Ok(Transformed::no(LogicalPlan::Limit(Limit {
+ skip,
+ fetch,
+ input,
+ })));
};
- let skip = limit.skip;
- match limit.input.as_ref() {
- LogicalPlan::TableScan(scan) => {
- let limit = if fetch != 0 { fetch + skip } else { 0 };
- let new_fetch = scan.fetch.map(|x| min(x,
limit)).or(Some(limit));
+ match unwrap_arc(input) {
+ LogicalPlan::TableScan(mut scan) => {
+ let rows_needed = if fetch != 0 { fetch + skip } else { 0 };
+ let new_fetch = scan
+ .fetch
+ .map(|x| min(x, rows_needed))
+ .or(Some(rows_needed));
if new_fetch == scan.fetch {
- Ok(None)
+ original_limit(skip, fetch, LogicalPlan::TableScan(scan))
} else {
- let new_input = LogicalPlan::TableScan(TableScan {
- table_name: scan.table_name.clone(),
- source: scan.source.clone(),
- projection: scan.projection.clone(),
- filters: scan.filters.clone(),
- fetch: scan.fetch.map(|x| min(x,
limit)).or(Some(limit)),
- projected_schema: scan.projected_schema.clone(),
- });
- plan.with_new_exprs(plan.expressions(), vec![new_input])
- .map(Some)
+ // push limit into the table scan itself
+ scan.fetch = scan
+ .fetch
+ .map(|x| min(x, rows_needed))
+ .or(Some(rows_needed));
+ transformed_limit(skip, fetch,
LogicalPlan::TableScan(scan))
}
}
- LogicalPlan::Union(union) => {
- let new_inputs = union
+ LogicalPlan::Union(mut union) => {
+ // push limits to each input of the union
+ union.inputs = union
.inputs
- .iter()
- .map(|x| {
- Ok(Arc::new(LogicalPlan::Limit(Limit {
- skip: 0,
- fetch: Some(fetch + skip),
- input: x.clone(),
- })))
- })
- .collect::<Result<_>>()?;
- let union = LogicalPlan::Union(Union {
- inputs: new_inputs,
- schema: union.schema.clone(),
- });
- plan.with_new_exprs(plan.expressions(), vec![union])
- .map(Some)
+ .into_iter()
+ .map(|input| make_arc_limit(0, fetch + skip, input))
+ .collect();
+ transformed_limit(skip, fetch, LogicalPlan::Union(union))
}
- LogicalPlan::CrossJoin(cross_join) => {
- let new_left = LogicalPlan::Limit(Limit {
- skip: 0,
- fetch: Some(fetch + skip),
- input: cross_join.left.clone(),
- });
- let new_right = LogicalPlan::Limit(Limit {
- skip: 0,
- fetch: Some(fetch + skip),
- input: cross_join.right.clone(),
- });
- let new_cross_join = LogicalPlan::CrossJoin(CrossJoin {
- left: Arc::new(new_left),
- right: Arc::new(new_right),
- schema: plan.schema().clone(),
- });
- plan.with_new_exprs(plan.expressions(), vec![new_cross_join])
- .map(Some)
+ LogicalPlan::CrossJoin(mut cross_join) => {
+ // push limit to both inputs
+ cross_join.left = make_arc_limit(0, fetch + skip,
cross_join.left);
+ cross_join.right = make_arc_limit(0, fetch + skip,
cross_join.right);
+ transformed_limit(skip, fetch,
LogicalPlan::CrossJoin(cross_join))
}
- LogicalPlan::Join(join) => {
- if let Some(new_join) = push_down_join(join, fetch + skip) {
- let inputs = vec![LogicalPlan::Join(new_join)];
- plan.with_new_exprs(plan.expressions(), inputs).map(Some)
- } else {
- Ok(None)
- }
- }
+ LogicalPlan::Join(join) => Ok(push_down_join(join, fetch + skip)
+ .update_data(|join| {
+ make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join)))
+ })),
- LogicalPlan::Sort(sort) => {
+ LogicalPlan::Sort(mut sort) => {
let new_fetch = {
let sort_fetch = skip + fetch;
Some(sort.fetch.map(|f|
f.min(sort_fetch)).unwrap_or(sort_fetch))
};
if new_fetch == sort.fetch {
- Ok(None)
+ original_limit(skip, fetch, LogicalPlan::Sort(sort))
} else {
- let new_sort = LogicalPlan::Sort(Sort {
- expr: sort.expr.clone(),
- input: sort.input.clone(),
- fetch: new_fetch,
- });
- plan.with_new_exprs(plan.expressions(), vec![new_sort])
- .map(Some)
+ sort.fetch = new_fetch;
+ limit.input = Arc::new(LogicalPlan::Sort(sort));
+ Ok(Transformed::yes(LogicalPlan::Limit(limit)))
}
}
- child_plan @ (LogicalPlan::Projection(_) |
LogicalPlan::SubqueryAlias(_)) => {
+ LogicalPlan::Projection(mut proj) => {
// commute
- let new_limit = plan.with_new_exprs(
- plan.expressions(),
- vec![child_plan.inputs()[0].clone()],
- )?;
- child_plan
- .with_new_exprs(child_plan.expressions(), vec![new_limit])
- .map(Some)
+ limit.input = Arc::clone(&proj.input);
+ let new_limit = LogicalPlan::Limit(limit);
+ proj.input = Arc::new(new_limit);
+ Ok(Transformed::yes(LogicalPlan::Projection(proj)))
}
- _ => Ok(None),
+ LogicalPlan::SubqueryAlias(mut subquery_alias) => {
+ // commute
+ limit.input = Arc::clone(&subquery_alias.input);
+ let new_limit = LogicalPlan::Limit(limit);
+ subquery_alias.input = Arc::new(new_limit);
+
Ok(Transformed::yes(LogicalPlan::SubqueryAlias(subquery_alias)))
+ }
+ input => original_limit(skip, fetch, input),
}
}
@@ -183,6 +171,61 @@ impl OptimizerRule for PushDownLimit {
}
}
+/// Wrap the input plan with a limit node
+///
+/// Original:
+/// ```text
+/// input
+/// ```
+///
+/// Return
+/// ```text
+/// Limit: skip=skip, fetch=fetch
+/// input
+/// ```
+fn make_limit(skip: usize, fetch: usize, input: Arc<LogicalPlan>) ->
LogicalPlan {
+ LogicalPlan::Limit(Limit {
+ skip,
+ fetch: Some(fetch),
+ input,
+ })
+}
+
+/// Wrap the input plan with a limit node
+fn make_arc_limit(
+ skip: usize,
+ fetch: usize,
+ input: Arc<LogicalPlan>,
+) -> Arc<LogicalPlan> {
+ Arc::new(make_limit(skip, fetch, input))
+}
+
+/// Returns the original limit (non transformed)
+fn original_limit(
+ skip: usize,
+ fetch: usize,
+ input: LogicalPlan,
+) -> Result<Transformed<LogicalPlan>> {
+ Ok(Transformed::no(LogicalPlan::Limit(Limit {
+ skip,
+ fetch: Some(fetch),
+ input: Arc::new(input),
+ })))
+}
+
+/// Returns a transformed limit
+fn transformed_limit(
+ skip: usize,
+ fetch: usize,
+ input: LogicalPlan,
+) -> Result<Transformed<LogicalPlan>> {
+ Ok(Transformed::yes(LogicalPlan::Limit(Limit {
+ skip,
+ fetch: Some(fetch),
+ input: Arc::new(input),
+ })))
+}
+
/// Combines two limits into a single
///
/// Returns the combined limit `(skip, fetch)`
@@ -255,14 +298,15 @@ fn combine_limit(
(combined_skip, combined_fetch)
}
-fn push_down_join(join: &Join, limit: usize) -> Option<Join> {
+/// Adds a limit to the inputs of a join, if possible
+fn push_down_join(mut join: Join, limit: usize) -> Transformed<Join> {
use JoinType::*;
fn is_no_join_condition(join: &Join) -> bool {
join.on.is_empty() && join.filter.is_none()
}
- let (left_limit, right_limit) = if is_no_join_condition(join) {
+ let (left_limit, right_limit) = if is_no_join_condition(&join) {
match join.join_type {
Left | Right | Full => (Some(limit), Some(limit)),
LeftAnti | LeftSemi => (Some(limit), None),
@@ -277,37 +321,16 @@ fn push_down_join(join: &Join, limit: usize) ->
Option<Join> {
}
};
- match (left_limit, right_limit) {
- (None, None) => None,
- _ => {
- let left = match left_limit {
- Some(limit) => Arc::new(LogicalPlan::Limit(Limit {
- skip: 0,
- fetch: Some(limit),
- input: join.left.clone(),
- })),
- None => join.left.clone(),
- };
- let right = match right_limit {
- Some(limit) => Arc::new(LogicalPlan::Limit(Limit {
- skip: 0,
- fetch: Some(limit),
- input: join.right.clone(),
- })),
- None => join.right.clone(),
- };
- Some(Join {
- left,
- right,
- on: join.on.clone(),
- filter: join.filter.clone(),
- join_type: join.join_type,
- join_constraint: join.join_constraint,
- schema: join.schema.clone(),
- null_equals_null: join.null_equals_null,
- })
- }
+ if left_limit.is_none() && right_limit.is_none() {
+ return Transformed::no(join);
+ }
+ if let Some(limit) = left_limit {
+ join.left = make_arc_limit(0, limit, join.left);
+ }
+ if let Some(limit) = right_limit {
+ join.right = make_arc_limit(0, limit, join.right);
}
+ Transformed::yes(join)
}
#[cfg(test)]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]